lexical note hlt · adv-nlp

SRL with Logistic Regression

SRL with Logistic Regression

This is a simple implementation of SRL with Logistic Regression, using a sparse feature vector. The dataset used for this experiment is the English CoNLL-U data (Universal Dependencies Consortium) from version 1.0 of the Universal Proposition Banks (Universal Propositions Consortium).

Three features are used:

  • A feature combining the directed dependency path between the token and the predicate, along with the predicate’s lemma. This feature is a string that concatenates the dependency path and the predicate lemma.
  • The token’s form as a string.
  • The token’s position relative to the predicate, combined with the voice, as a string (for example BeforeActive or AfterPassive).

The scikit-learn library is used for the feature vectorization and the logistic regression model.

The learned model is available here: https://drive.google.com/file/d/1Zgq8W1yB6OfIdaMDpS8Xf9ZMcq5V1xkd/view?usp=share_link

import json
import sys
import pickle
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction import DictVectorizer
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

sys.path.append('feature_extraction')
from extract_position_rel2pred import extract_word_position_and_voice
from extract_dependency_path import extract_dependency_paths
from extract_predicate import extract_predicate_lemma

Step 1: Preprocess the data

Step 1.1: Load the datasets

The data is loaded from files within this repository. The files are in CoNLL-U Plus format.

# Locations of the dev, train and test splits, given relative to the working directory.
dev_file_path = 'data/en_ewt-up-dev.conllu'
train_file_path = 'data/en_ewt-up-train.conllu'
test_file_path = 'data/en_ewt-up-test.conllu'

Step 1.2: Create a helper function to preprocess the data

The preprocessing of the data involves parsing the CoNLL-U Plus files and creating a list of lists of objects. Each list of objects represents one proposition in a sentence. Each object represents a token in the sentence.

This process involves an expansion of the data, since one sentence can contain multiple predicate-argument structures.

Furthermore, argument labels V and C-V are both replaced with _ since they are not to be predicted.

def preprocess_data(file_path):
    """
    Parse a CoNLL-U Plus file and expand it into one token list per proposition.

    Every sentence is copied once for each predicate it contains. In each copy
    only that one predicate is kept (all other tokens get predicate '_'), and
    each token carries only its argument label for that predicate. The labels
    'V' and 'C-V' are replaced with '_'. Copies in which no token has an
    argument label are dropped.

    Comment lines (starting with '#'), rows whose ID contains a period and rows
    with fewer than 11 columns are ignored.

    Args:
        file_path (str): Path to the CoNLL-U Plus file to preprocess.

    Returns:
        list: A list of propositions. Each proposition is a list of token
        dicts with the string-valued keys 'form', 'predicate' and 'argument'.
    """
    sentences = []  # All sentences read from the file.
    sentence = []   # Tokens of the sentence currently being read.
    with open(file_path, 'r', encoding="utf8") as file:
        for raw_line in file:
            columns = raw_line.strip().split('\t')
            first_column = columns[0]

            if first_column.startswith('#'):
                # Comment line, ignore it.
                continue

            if first_column.strip() == '':
                # A blank line marks the end of a sentence.
                sentences.append(sentence)
                sentence = []
                continue

            # Create a token if its ID does not contain a period.
            if '.' not in first_column and len(columns) > 10:
                sentence.append({
                    'form': columns[1],
                    'predicate': columns[10],
                    'argument': columns[11:],  # One label column per predicate.
                })

    # Keep the final sentence when the file does not end with a blank line.
    if sentence:
        sentences.append(sentence)

    expanded_sentences = []
    for sentence in sentences:
        # Identify predicates by token position rather than by sense label, so
        # that two predicates sharing a sense are still separated into
        # different propositions.
        predicate_positions = [
            position
            for position, token in enumerate(sentence)
            if token['predicate'] != '_'
        ]

        # The n-th predicate of the sentence owns the n-th argument column.
        for column, predicate_position in enumerate(predicate_positions):
            sentence_copy = []
            for position, token in enumerate(sentence):
                labels = token['argument']
                # A row without a label column for this predicate has no argument.
                label = labels[column] if column < len(labels) else '_'
                sentence_copy.append({
                    'form': token['form'],
                    'predicate': token['predicate'] if position == predicate_position else '_',
                    # 'V' and 'C-V' mark the predicate itself and are not predicted.
                    'argument': '_' if label in ('V', 'C-V') else label,
                })

            # Append only propositions with at least one argument.
            if any(token['argument'] != '_' for token in sentence_copy):
                expanded_sentences.append(sentence_copy)

    return expanded_sentences
def count_before_preprocessing(file_path):
    """
    Count the sentences and tokens in a CoNLL-U Plus file
    before any duplication or token filtering.

    Comment lines (starting with '#') are ignored, every blank line closes a
    sentence, and every other line counts as one token. A final sentence that
    is not followed by a blank line is counted as well.

    Args:
        file_path (str): Path to the CoNLL-U Plus file.

    Returns:
        tuple: (number_of_sentences, number_of_tokens)
    """
    num_sentences = 0
    num_tokens = 0
    open_sentence = False  # True while token lines were seen since the last blank line.

    with open(file_path, 'r', encoding="utf8") as file:
        for line in file:
            line = line.strip()
            if line.startswith('#'):
                continue  # Ignore comment lines
            if line == '':
                num_sentences += 1  # Sentence boundary
                open_sentence = False
            else:
                num_tokens += 1  # Count tokens
                open_sentence = True

    # The file may end without a trailing blank line; count that last sentence too.
    if open_sentence:
        num_sentences += 1

    return num_sentences, num_tokens

Step 1.3: Preprocess the data

Use the helper function to preprocess the data. The data is stored in a list of lists of objects:

  • Each list of objects represents one proposition.
  • Each object represents one token.
# Expand every split into one token list per proposition.
dev_data = preprocess_data(dev_file_path)
train_data = preprocess_data(train_file_path)
test_data = preprocess_data(test_file_path)

Step 1.4: Inspect the preprocessed data

First, let’s inspect a sample of the preprocessed data.

# Pretty-print the first three propositions of the dev data.
for proposition in dev_data[:3]:
    pretty = json.dumps(proposition, indent=4)
    print(pretty)
[
    {
        "form": "From",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "the",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "AP",
        "predicate": "_",
        "argument": "ARG2"
    },
    {
        "form": "comes",
        "predicate": "come.03",
        "argument": "_"
    },
    {
        "form": "this",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "story",
        "predicate": "_",
        "argument": "ARG1"
    },
    {
        "form": ":",
        "predicate": "_",
        "argument": "_"
    }
]
[
    {
        "form": "President",
        "predicate": "_",
        "argument": "ARG0"
    },
    {
        "form": "Bush",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "on",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "Tuesday",
        "predicate": "_",
        "argument": "ARGM-TMP"
    },
    {
        "form": "nominated",
        "predicate": "nominate.01",
        "argument": "_"
    },
    {
        "form": "two",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "individuals",
        "predicate": "_",
        "argument": "ARG1"
    },
    {
        "form": "to",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "replace",
        "predicate": "_",
        "argument": "ARG2"
    },
    {
        "form": "retiring",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "jurists",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "on",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "federal",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "courts",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "in",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "the",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "Washington",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "area",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": ".",
        "predicate": "_",
        "argument": "_"
    }
]
[
    {
        "form": "President",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "Bush",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "on",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "Tuesday",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "nominated",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "two",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "individuals",
        "predicate": "_",
        "argument": "ARG0"
    },
    {
        "form": "to",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "replace",
        "predicate": "replace.01",
        "argument": "_"
    },
    {
        "form": "retiring",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "jurists",
        "predicate": "_",
        "argument": "ARG1"
    },
    {
        "form": "on",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "federal",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "courts",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "in",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "the",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "Washington",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": "area",
        "predicate": "_",
        "argument": "_"
    },
    {
        "form": ".",
        "predicate": "_",
        "argument": "_"
    }
]

For each dataset, print the number of tokens and number of sentences before and after preprocessing.

The test tokens that are actually classified are a subset (80027) of the 84328 tokens in the test data after preprocessing. This is because predicate tokens are excluded from argument identification and classification: the predicate is never an argument to itself.

# Raw corpus statistics, counted directly from the files.
print("Before preprocessing:\n")
test_sentences, test_tokens = count_before_preprocessing(test_file_path)
train_sentences, train_tokens = count_before_preprocessing(train_file_path)
dev_sentences, dev_tokens = count_before_preprocessing(dev_file_path)

for split_name, sentence_count, token_count in (
    ("train", train_sentences, train_tokens),
    ("dev", dev_sentences, dev_tokens),
    ("test", test_sentences, test_tokens),
):
    print(f"{sentence_count} sentences and {token_count} tokens in {split_name} data")

# Statistics after expansion: every proposition counts as one "sentence".
print("\nAfter preprocessing:\n")

train_tokens_after = sum(map(len, train_data))
dev_tokens_after = sum(map(len, dev_data))
test_tokens_after = sum(map(len, test_data))

for split_name, propositions, token_count in (
    ("train", train_data, train_tokens_after),
    ("dev", dev_data, dev_tokens_after),
    ("test", test_data, test_tokens_after),
):
    print(f"{len(propositions)} sentences and {token_count} tokens in {split_name} data")
Before preprocessing:

12543 sentences and 204609 tokens in train data
2002 sentences and 25150 tokens in dev data
2077 sentences and 25097 tokens in test data

After preprocessing:

33521 sentences and 852388 tokens in train data
4143 sentences and 87697 tokens in dev data
3971 sentences and 84328 tokens in test data

Step 2: Extract features

Step 2.1: Create a helper function to extract features

The extracted features are:

1. A feature combining the directed dependency path between the token and the predicate along with the predicate’s lemma. This feature is a string that concatenates the dependency path and the predicate lemma. This feature is designed to capture the syntactic relation between the target word and the predicate of the sentence, and leverages syntactic typicality (Palmer, Gildea & Xue, 2010). The path in the parse tree provides a compact representation of various kinds of grammatical relationships between the target word and the predicate. For instance, the path can indicate whether a target word is the subject, direct object or adjunct related to the predicate. The semantic role of a subject is likely to be ARG0, whereas the direct object is likely to be an ARG1 variant. In essence, the dependency path acts as a crucial link between the syntactic structure of a sentence and its underlying semantic organization.

2. The token’s form as a string. A word’s form carries intrinsic semantic meaning. Different words represent different concepts, and are thus more likely to fulfill a particular semantic role. The word form provides direct access to this lexical semantic quality of a target token. Furthermore, predicates often introduce selectional restrictions (Palmer, Gildea & Xue, 2010). For example, the predicate eat.01 has an ARG1-PPT (meal) that’s likely to be something edible. Word forms such as “bread” or “spaghetti” are thus much more likely to fulfill the ARG1-PPT role for this predicate than words referring to abstract concepts such as “liberty”, or non-edible concepts such as “painting”. Furthermore, certain words frequently appear in specific semantic roles with particular predicates. The target token’s form can capture these common co-occurrences. For example, a “chef” often cooks (cook.01), and a “shooter” often shoots (shoot.01).

3. The token’s position relative to the predicate as a string, combined with the voice.

Encoded as the position immediately followed by the voice, without a separator (one of):

  • BeforeActive
  • AfterActive
  • BeforePassive
  • AfterPassive

In English, the position of a constituent relative to the predicate is highly correlated with its grammatical function: subjects generally appear before the verb, and objects typically appear after it. Position is combined with voice because active and passive constructions often present the same semantic roles in different syntactic positions (Jurafsky & Martin, 2024). In an active sentence, the agent is typically the subject and the patient is the object. In a passive sentence, however, the patient often becomes the subject and the agent may be expressed in a prepositional phrase, or omitted entirely. Consequently, a noun phrase before the verb is more likely to be the agent in an active sentence, but the patient in a passive sentence.

All three strings will later be vectorized using a DictVectorizer, resulting in one-hot encoded features.

def extract_features(data):
    """
    Build one feature dictionary per non-predicate token.

    For every proposition the three feature extraction helpers are called once;
    their results are then combined per token. The helpers are indexed by token
    position, so they are assumed to return one entry per token.

    :param data: a list of propositions, each a list of token dicts with at
        least the keys 'form' and 'predicate'.
    :return: a list of dicts with the keys 'token', 'position_rel2pred' and
        'dep_path+lemma', in corpus order.
    """
    feature_dicts = []

    for proposition in data:
        relative_positions, voice = extract_word_position_and_voice(proposition)
        dependency_paths = extract_dependency_paths(proposition)
        lemma = extract_predicate_lemma(proposition)

        # The predicate itself is never classified, so it yields no sample.
        feature_dicts.extend(
            {
                'token': tok['form'],
                'position_rel2pred': relative_positions[idx] + voice,
                'dep_path+lemma': dependency_paths[idx] + lemma,
            }
            for idx, tok in enumerate(proposition)
            if tok['predicate'] == '_'
        )

    return feature_dicts

Use the helper function to extract features from the data.

The result is a list of samples, where each sample represents a token. Every sample is a dictionary of the three extracted features.

# Build one feature dict per non-predicate token for each split.
samples_train = extract_features(train_data)
samples_dev = extract_features(dev_data)
samples_test = extract_features(test_data)

Step 2.2 Extract the gold labels

Create a helper function to extract the gold labels from the data.

def extract_gold_labels(data):
    '''
    Collect the gold argument label of every non-predicate token.

    Predicate tokens are skipped, so the returned labels line up one-to-one
    with the tokens that are not predicates, in corpus order.

    :param data: a list of propositions, each a list of token dicts with the
        keys 'predicate' and 'argument'.
    :return: a flat list of gold labels.
    '''
    return [
        token['argument']
        for proposition in data
        for token in proposition
        if token['predicate'] == '_'
    ]

Use the helper function to extract the gold labels from the three datasets.

# Collect the gold argument label of every non-predicate token for each split.
gold_labels_train = extract_gold_labels(train_data)
gold_labels_dev = extract_gold_labels(dev_data)
gold_labels_test = extract_gold_labels(test_data)

Inspect the first fifteen samples and their gold labels.

# Show the first 15 training samples next to their gold labels.
# Both lists use the same slice bound so it is obvious how many rows are printed.
for gold, sample in zip(gold_labels_train[:15], samples_train[:15]):
    print(f"gold: {gold:<6} sample: {sample}")
gold: _      sample: {'token': 'Al', 'position_rel2pred': 'BeforeActive', 'dep_path+lemma': '↑compound↓acl:kill'}
gold: _      sample: {'token': '-', 'position_rel2pred': 'BeforeActive', 'dep_path+lemma': '↑punct↓acl:kill'}
gold: _      sample: {'token': 'Zaman', 'position_rel2pred': 'BeforeActive', 'dep_path+lemma': '↓acl:kill'}
gold: _      sample: {'token': ':', 'position_rel2pred': 'BeforeActive', 'dep_path+lemma': '↑punct↓acl:kill'}
gold: _      sample: {'token': 'American', 'position_rel2pred': 'BeforeActive', 'dep_path+lemma': '↑amod↑nsubj:kill'}
gold: ARG0   sample: {'token': 'forces', 'position_rel2pred': 'BeforeActive', 'dep_path+lemma': '↑nsubj:kill'}
gold: ARG1   sample: {'token': 'Shaikh', 'position_rel2pred': 'AfterActive', 'dep_path+lemma': '↑compound↑dobj:kill'}
gold: _      sample: {'token': 'Abdullah', 'position_rel2pred': 'AfterActive', 'dep_path+lemma': '↑compound↑dobj:kill'}
gold: _      sample: {'token': 'al', 'position_rel2pred': 'AfterActive', 'dep_path+lemma': '↑compound↑dobj:kill'}
gold: _      sample: {'token': '-', 'position_rel2pred': 'AfterActive', 'dep_path+lemma': '↑punct↑dobj:kill'}
gold: _      sample: {'token': 'Ani', 'position_rel2pred': 'AfterActive', 'dep_path+lemma': '↑dobj:kill'}
gold: _      sample: {'token': ',', 'position_rel2pred': 'AfterActive', 'dep_path+lemma': '↑punct↑dobj:kill'}
gold: _      sample: {'token': 'the', 'position_rel2pred': 'AfterActive', 'dep_path+lemma': '↑det↑appos↑dobj:kill'}
gold: _      sample: {'token': 'preacher', 'position_rel2pred': 'AfterActive', 'dep_path+lemma': '↑appos↑dobj:kill'}
gold: _      sample: {'token': 'at', 'position_rel2pred': 'AfterActive', 'dep_path+lemma': '↑prep↑appos↑dobj:kill'}

Step 3: Train a model

Use the scikit-learn library to instantiate a DictVectorizer and transform all samples into sparse feature vectors.

# Learn the feature vocabulary from the training samples and encode them in one step.
vectorizer = DictVectorizer()
feature_vectors_train = vectorizer.fit_transform(samples_train)

# Encode the dev and test samples with the vocabulary learned from the training data.
feature_vectors_dev = vectorizer.transform(samples_dev)
feature_vectors_test = vectorizer.transform(samples_test)

Train a logistic regression model on the feature vectors and gold labels of the training data.

# Train a logistic regression model using the scikit-learn library.
# NOTE(review): the recorded run of this cell emitted a ConvergenceWarning
# (max_iter was reached before the coefficients converged); consider raising
# max_iter and re-running the evaluation.
model = LogisticRegression(solver='saga')
model.fit(feature_vectors_train, gold_labels_train)
/Users/krisstallenberg/anaconda3/lib/python3.12/site-packages/sklearn/linear_model/_sag.py:350: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge
  warnings.warn(
LogisticRegression(solver='saga')
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
# Persist the trained model and the fitted vectorizer, each to its own pickle file.
for pickle_path, artifact in (('model.pkl', model), ('vectorizer.pkl', vectorizer)):
    with open(pickle_path, 'wb') as file:
        pickle.dump(artifact, file)

Step 4: Infer and evaluate the model

Predict the argument labels for the development data. Store the predictions in a list.

# Predict an argument label for every dev sample.
predictions_dev = model.predict(feature_vectors_dev)

Step 4.1: Evaluate the model

Create a helper function to create the classification report using the scikit-learn library.

def create_classification_report(gold_labels, predictions, label_set):
    """
    Create and print a classification report.

    The report is computed twice with identical settings: once as a dictionary
    (for serialisation) and once as formatted text (for printing).

    :param gold_labels: The gold labels.
    :param predictions: The predictions.
    :param label_set: The labels to report on, in the order they should appear.
    :return: A tuple (report_dict, report, label_set).
    """
    # Passing `labels` together with `target_names` pins the rows of the report
    # to `label_set`; otherwise a predicted label that is missing from
    # `label_set` makes scikit-learn raise because the number of classes no
    # longer matches the number of target names.
    # `zero_division=0.0` is used for both calls so that labels without any
    # predicted samples score 0.0 without emitting UndefinedMetricWarning.
    shared_options = {
        'labels': label_set,
        'target_names': label_set,
        'digits': 3,
        'zero_division': 0.0,
    }

    report_dict = classification_report(gold_labels, predictions, output_dict=True, **shared_options)
    report = classification_report(gold_labels, predictions, **shared_options)

    # Print the classification report.
    print(report)

    return report_dict, report, label_set

Create a helper function to create a confusion matrix using the scikit-learn library.

def plot_confusion_matrix(gold_labels, predictions, label_set):
    """
    Compute, plot and return the confusion matrix.

    :param gold_labels: The gold labels.
    :param predictions: The predictions.
    :param label_set: The labels to include, in the order of the matrix rows
        and columns.
    :return: The confusion matrix.
    """
    # Build the matrix over exactly `label_set`, so that its rows and columns
    # are guaranteed to line up with the tick labels passed to the display.
    # Without `labels`, the matrix covers the union of gold and predicted
    # labels, which can differ from `label_set`.
    cf_matrix = confusion_matrix(gold_labels, predictions, labels=label_set)

    # Create a display for the confusion matrix.
    display = ConfusionMatrixDisplay(confusion_matrix=cf_matrix, display_labels=label_set)

    # Create a large figure; the label set is too big for the default size.
    _, ax = plt.subplots(figsize=(15, 15))

    # Display the confusion matrix with vertical x tick labels to keep them legible.
    display.plot(ax=ax)
    plt.xticks(rotation=90)
    plt.show()

    return cf_matrix

Use the helper functions to create the classification report and confusion matrix.

# Collect the distinct gold labels of the dev data as an alphabetically sorted list.
label_set_dev = sorted(set(gold_labels_dev))

# Print the classification report and plot the confusion matrix for the dev data.
report_dict_dev, report_dev, label_set_dev = create_classification_report(gold_labels_dev, predictions_dev, label_set_dev)
cf_matrix_dev = plot_confusion_matrix(gold_labels_dev, predictions_dev, label_set_dev)
/Users/krisstallenberg/anaconda3/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
/Users/krisstallenberg/anaconda3/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
/Users/krisstallenberg/anaconda3/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

        ARG0      0.858     0.525     0.652      1733
        ARG1      0.798     0.529     0.636      3322
        ARG2      0.706     0.517     0.597      1212
        ARG3      0.579     0.143     0.229        77
        ARG4      0.409     0.383     0.396        47
        ARG5      0.000     0.000     0.000         1
    ARGM-ADJ      0.771     0.257     0.386       249
    ARGM-ADV      0.684     0.194     0.302       479
    ARGM-CAU      0.833     0.071     0.132        70
    ARGM-COM      0.000     0.000     0.000        14
    ARGM-CXN      0.000     0.000     0.000        14
    ARGM-DIR      0.500     0.167     0.250        48
    ARGM-DIS      0.870     0.323     0.471       186
    ARGM-EXT      0.903     0.248     0.389       113
    ARGM-GOL      0.000     0.000     0.000        26
    ARGM-LOC      0.425     0.070     0.121       242
    ARGM-LVB      1.000     0.293     0.454        75
    ARGM-MNR      0.638     0.171     0.270       175
    ARGM-MOD      0.909     0.660     0.765       377
    ARGM-NEG      0.903     0.563     0.693       215
    ARGM-PRD      1.000     0.020     0.039        50
    ARGM-PRP      0.600     0.049     0.091        61
    ARGM-PRR      1.000     0.467     0.636        75
    ARGM-REC      0.000     0.000     0.000         4
    ARGM-TMP      0.786     0.335     0.469       550
      C-ARG0      0.000     0.000     0.000         4
      C-ARG1      0.900     0.170     0.286        53
      C-ARG2      0.000     0.000     0.000         7
      C-ARG3      0.000     0.000     0.000         7
  C-ARGM-CXN      0.000     0.000     0.000         7
  C-ARGM-EXT      0.000     0.000     0.000         2
  C-ARGM-LOC      0.000     0.000     0.000         3
  C-ARGM-MNR      0.000     0.000     0.000         1
      R-ARG0      0.641     0.417     0.505        60
      R-ARG1      0.394     0.197     0.263        66
      R-ARG2      0.000     0.000     0.000         4
      R-ARG3      0.000     0.000     0.000         1
  R-ARGM-ADV      0.000     0.000     0.000         1
  R-ARGM-CAU      0.000     0.000     0.000         1
  R-ARGM-COM      0.000     0.000     0.000         1
  R-ARGM-LOC      0.400     0.200     0.267        10
  R-ARGM-MNR      0.000     0.000     0.000         2
  R-ARGM-TMP      0.000     0.000     0.000         8
           _      0.937     0.991     0.963     73478

    accuracy                          0.927     83131
   macro avg      0.419     0.181     0.233     83131
weighted avg      0.918     0.927     0.915     83131



png

Save the evaluation report of the development data to a JSON file.

Step 5: Evaluate the model on the test data

Predict the argument labels for the test data.

# Predict an argument label for every test sample.
predictions_test = model.predict(feature_vectors_test)

Step 5.1: Store predictions to file

Create a TSV file with rows representing tokens. Every row consists of:

  • Token form
  • Gold label
  • Predicted label

The results are saved in a file named predictions.tsv.

# Collect the form of every non-predicate token, in corpus order.
word_forms = [
    token['form']
    for proposition in test_data
    for token in proposition
    if token['predicate'] == '_'
]

# Write one row per token: form, gold label, predicted label.
with open('predictions.tsv', 'w', encoding='utf-8') as tsv:
    tsv.write("Word\tGold\tPrediction\n")
    tsv.writelines(
        f"{form}\t{gold_label}\t{predicted_label}\n"
        for form, predicted_label, gold_label in zip(word_forms, predictions_test, gold_labels_test)
    )

Evaluate the model on the test data.

# Collect the distinct gold labels of the test data as an alphabetically sorted list.
label_set_test = sorted(set(gold_labels_test))

# Print the classification report and plot the confusion matrix for the test data.
report_dict_test, report_test, label_set_test = create_classification_report(gold_labels_test, predictions_test, label_set_test)
cf_matrix_test = plot_confusion_matrix(gold_labels_test, predictions_test, label_set_test)
/Users/krisstallenberg/anaconda3/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
/Users/krisstallenberg/anaconda3/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
/Users/krisstallenberg/anaconda3/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

        ARG0      0.871     0.525     0.655      1733
        ARG1      0.799     0.533     0.639      3241
    ARG1-DSP      0.000     0.000     0.000         4
        ARG2      0.708     0.516     0.597      1129
        ARG3      0.733     0.149     0.247        74
        ARG4      0.481     0.446     0.463        56
        ARG5      0.000     0.000     0.000         1
        ARGA      0.000     0.000     0.000         2
    ARGM-ADJ      0.774     0.317     0.450       227
    ARGM-ADV      0.596     0.200     0.300       495
    ARGM-CAU      0.667     0.087     0.154        46
    ARGM-COM      0.000     0.000     0.000        13
    ARGM-CXN      1.000     0.083     0.154        12
    ARGM-DIR      0.400     0.043     0.077        47
    ARGM-DIS      0.701     0.258     0.378       182
    ARGM-EXT      0.929     0.371     0.531       105
    ARGM-GOL      0.000     0.000     0.000        24
    ARGM-LOC      0.730     0.130     0.221       207
    ARGM-LVB      0.944     0.246     0.391        69
    ARGM-MNR      0.786     0.074     0.136       148
    ARGM-MOD      0.919     0.697     0.793       442
    ARGM-NEG      0.887     0.620     0.730       216
    ARGM-PRD      0.000     0.000     0.000        44
    ARGM-PRP      0.600     0.080     0.141        75
    ARGM-PRR      0.769     0.435     0.556        69
    ARGM-TMP      0.805     0.326     0.464       543
      C-ARG0      0.000     0.000     0.000         3
      C-ARG1      1.000     0.212     0.349        52
  C-ARG1-DSP      0.000     0.000     0.000         1
      C-ARG2      0.000     0.000     0.000         7
      C-ARG3      0.000     0.000     0.000         2
  C-ARGM-CXN      0.000     0.000     0.000         5
  C-ARGM-LOC      0.000     0.000     0.000         1
      R-ARG0      0.742     0.343     0.469        67
      R-ARG1      0.400     0.231     0.293        52
      R-ARG2      0.000     0.000     0.000         1
  R-ARGM-ADJ      0.000     0.000     0.000         1
  R-ARGM-ADV      0.000     0.000     0.000         1
  R-ARGM-DIR      0.000     0.000     0.000         1
  R-ARGM-LOC      1.000     0.111     0.200         9
  R-ARGM-MNR      0.000     0.000     0.000         8
  R-ARGM-TMP      0.000     0.000     0.000         2
           _      0.937     0.991     0.963     70610

    accuracy                          0.927     80027
   macro avg      0.446     0.187     0.241     80027
weighted avg      0.918     0.927     0.916     80027



png

Create a helper function to save the evaluation report to a JSON file.

def save_evaluation_report(report_dict, cf_matrix):
    """
    Write the classification report and confusion matrix to a JSON file.

    The file is created in the current working directory and is named
    ``evaluation_report_<timestamp>.json``, where the timestamp is the local
    date and time of the call. The file name is printed afterwards.

    :param report_dict: The classification report dictionary.
    :param cf_matrix: The confusion matrix; it must provide ``tolist()``.
    """
    payload = {
        "classification_report": report_dict,
        "confusion_matrix": cf_matrix.tolist(),
    }

    # Stamp the file name with the current date and time.
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    filename = f"evaluation_report_{timestamp}.json"

    with open(filename, 'w') as report_file:
        json.dump(payload, report_file, indent=4)

    print(f"Evaluation report saved to: {filename}")
# Save the classification report and confusion matrix of the test data to a timestamped JSON file.
save_evaluation_report(report_dict_test, cf_matrix_test)
Evaluation report saved to: evaluation_report_2025-03-28_22-19-23.json

The evaluation report of the test data is saved to a JSON file.

The logistic regression model achieves a macro-average F1 score of 24% on the test dataset with an overall accuracy of 0.927 (92.7% of all predictions, including those for the label '_', are correct).

Note that macro-average F1 scores are unweighted and impacted by the presence of many rare labels with low precision and recall scores.

Standalone function

The classify_sentence_logreg() function takes as input (in addition to the trained model and vectorizer):

  • A list (of length n) of strings, representing the tokens of the sentence.
  • A one-hot encoding (of length n) representing the location of the predicate.
  • A list (of length n) of argument labels (these are only used to create a data structure that matches the output of the standard preprocessing, letting me reuse the original feature extraction function without adjustments)
  • The predicate sense (this is only used to create a data structure that matches the output of the standard preprocessing, letting me reuse the original feature extraction function without adjustments)
def extract_features(data):
    """
    Turn preprocessed sentences into one feature dictionary per non-predicate token.

    For every sentence the helper functions imported from ``feature_extraction``
    supply the per-token positions relative to the predicate (plus a voice
    string), the per-token dependency paths and the predicate lemma.

    Each returned sample has three keys:
      - 'token': the token's 'form' value.
      - 'position_rel2pred': the token's position string with the voice string appended.
      - 'dep_path+lemma': the token's dependency path with the predicate lemma appended.

    :param data: A list of sentences, each a list of token dictionaries.
    :return: A flat list of samples covering all sentences, in input order.
    """
    samples = []

    for sentence in data:
        # Sentence-level information shared by all tokens of this sentence.
        positions_rel2pred, verb_voice = extract_word_position_and_voice(sentence)
        d_paths = extract_dependency_paths(sentence)
        predicate_lemma = extract_predicate_lemma(sentence)

        for idx, token in enumerate(sentence):
            # Predicate tokens (anything other than '_') do not get a sample.
            if token['predicate'] != '_':
                continue

            samples.append({
                'token': token['form'],
                'position_rel2pred': positions_rel2pred[idx] + verb_voice,
                'dep_path+lemma': d_paths[idx] + predicate_lemma,
            })

    return samples

def format_sentence(sentence, predicate_location, argument_labels, predicate_form):
    """
    Build the list of per-token dictionaries that the feature extraction functions take as input.

    Args:
        sentence (list): A list of words.
        predicate_location (list): A one-hot vector, indicating the location of the predicate.
        argument_labels (list): A list of argument labels.
        predicate_form (str): The sense label of the predicate.

    Returns:
        list: One dictionary per word with the keys "form", "predicate" and
        "argument". "predicate" holds predicate_form where predicate_location
        is 1 and "_" everywhere else.
    """
    return [
        {
            "form": token,
            "predicate": "_" if predicate_location[idx] != 1 else predicate_form,
            "argument": argument_labels[idx],
        }
        for idx, token in enumerate(sentence)
    ]

def classify_sentence_logreg(sentence, predicate_location, argument_labels, predicate_sense, model, vectorizer):
    """
    The standalone function that takes a sentence and predicts the argument labels, using logistic regression.

    The sentence is formatted like the preprocessed data, features are extracted for
    every non-predicate token, and the model's predictions are padded with '_' at the
    predicate position(s) so that the result lines up with the input tokens.

    Args:
        sentence (list): A list of words.
        predicate_location (list): A one-hot vector, indicating the location of the predicate.
            Any sequence works; every position equal to 1 is treated as a predicate token.
        argument_labels (list): A list of argument labels (only used to build the formatted sentence).
        predicate_sense (str): The sense label of the predicate.
        model: The classifier; its ``predict`` method is called on the feature vectors.
        vectorizer: The feature vectorizer; its ``transform`` method is called on the samples.

    Returns:
        The predictions with '_' inserted at each predicate position, as returned by ``np.insert``.

    Raises:
        ValueError: If predicate_location does not mark any token with 1.
    """
    formatted_output = format_sentence(sentence, predicate_location, argument_labels, predicate_sense)
    samples = extract_features([formatted_output])
    feature_vectors = vectorizer.transform(samples)
    predictions = model.predict(feature_vectors)

    # Feature extraction skips every token flagged as predicate, so each flagged
    # position needs its own '_' placeholder to keep the output aligned with the
    # sentence (a single list.index(1) lookup would only restore the first one).
    predicate_indices = [idx for idx, flag in enumerate(predicate_location) if flag == 1]
    if not predicate_indices:
        raise ValueError("1 is not in predicate_location")

    # np.insert interprets indices relative to the array *before* insertion, so the
    # k-th predicate (0-based) has to be inserted k positions earlier.
    insert_at = [position - offset for offset, position in enumerate(predicate_indices)]
    return np.insert(predictions, insert_at, '_')

# Demo: label the same sentence once per predicate ("ran" and "fell").
sentence = ["The", "dog", "ran", "and", "the", "man", "fell", "."]

# NOTE: pickle.load can execute arbitrary code while unpickling; only load
# vectorizer/model files that come from a trusted source.
with open("vectorizer.pkl", "rb") as f:
    vectorizer = pickle.load(f)

with open("model.pkl", "rb") as f:
    model = pickle.load(f)

# Each example: (table header, one-hot predicate location, gold argument labels, predicate sense).
examples = (
    (
        "Token\t\tPrediction\tGold\n=====================================",
        [0, 0, 1, 0, 0, 0, 0, 0],
        ['_', 'ARG0', '_', '_', '_', '_', '_', '_'],
        "run.01",
    ),
    (
        "\n\nToken\t\tPrediction\tGold\n=====================================",
        [0, 0, 0, 0, 0, 0, 1, 0],
        ['_', '_', '_', '_', '_', 'ARG1', '_', '_'],
        "fall.01",
    ),
)

for header, predicate_location, argument_labels, predicate_sense in examples:
    predictions = classify_sentence_logreg(sentence, predicate_location, argument_labels, predicate_sense, model, vectorizer)
    print(header)
    for word, prediction, gold, pred_location in zip(sentence, predictions, argument_labels, predicate_location):
        # The predicate token is annotated with its sense in the first column.
        label = f"{word} ({predicate_sense})" if pred_location == 1 else word
        print(f"{label}\t\t{prediction}\t\t{gold}")
Token		Prediction	Gold
=====================================
The		_		_
dog		ARG0		ARG0
ran (run.01)		_		_
and		_		_
the		_		_
man		_		_
fell		_		_
.		_		_


Token		Prediction	Gold
=====================================
The		_		_
dog		_		_
ran		_		_
and		_		_
the		_		_
man		_		ARG1
fell (fall.01)		_		_
.		_		_