lexical note hlt · adv-nlp

CheckList tests of two SRL systems

CheckList tests of two SRL systems

In this notebook, we’ll apply CheckList tests to two Semantic Role Labeling (SRL) models:

  1. A logistic regression model, trained on three features.
  2. A DistilBERT model, fine-tuned on a CoNLL SRL dataset.

Importing dependencies

# For the BERT model

import time
import pandas as pd
import transformers
import numpy as np
import torch
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification
from datasets import Dataset
from utils import read_data_as_sentence,map_labels_in_dataframe,tokenize_and_align_labels,get_label_mapping,get_labels_from_map,load_srl_model,load_dataset,compute_metrics,write_predictions_to_csv,compute_evaluation_metrics_from_csv, print_sentences
from bert_srl import main, define_args

# For the logistic regression model
import json
import sys
import pickle
from datetime import datetime
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction import DictVectorizer
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

sys.path.append('feature_extraction')
from extract_position_rel2pred import extract_word_position_and_voice
from extract_dependency_path import extract_dependency_paths
from extract_predicate import extract_predicate_lemma

1. Declare standalone functions

1.1 Logistic regression

def extract_features(data):
    """
    Build one feature dictionary for every non-predicate token.

    Args:
        data (list): Sentences, each a list of token dicts carrying at least
            the 'form' and 'predicate' keys (a 'predicate' value of '_' marks
            a non-predicate token).

    Returns:
        list[dict]: One dict per token whose 'predicate' is '_', in sentence
        order, with the keys 'token', 'position_rel2pred' (position relative
        to the predicate concatenated with the verb voice) and
        'dep_path+lemma' (dependency path concatenated with the predicate
        lemma).
    """
    samples = []

    for tokens in data:
        # Sentence-level feature extraction; the per-token values are indexed below.
        positions, voice = extract_word_position_and_voice(tokens)
        paths = extract_dependency_paths(tokens)
        lemma = extract_predicate_lemma(tokens)

        # Predicate tokens themselves are not classified, so only '_' tokens yield a sample.
        samples.extend(
            {
                'token': tok['form'],
                'position_rel2pred': positions[idx] + voice,
                'dep_path+lemma': paths[idx] + lemma,
            }
            for idx, tok in enumerate(tokens)
            if tok['predicate'] == '_'
        )

    return samples

def format_sentence(sentence, predicate_location, argument_labels, predicate_form):
    """
    Convert parallel word / predicate / label lists into the token-dict format
    consumed by the feature extraction functions.

    Args:
        sentence (list): A list of words.
        predicate_location (list): A one-hot vector, indicating the location of the predicate.
        argument_labels (list): A list of argument labels, one per word.
        predicate_form (str): The sense label of the predicate.

    Returns:
        list[dict]: One dict per word with the keys 'form' (the word),
        'predicate' (``predicate_form`` where ``predicate_location`` is 1,
        '_' elsewhere) and 'argument' (the word's argument label).
    """
    return [
        {
            "form": word,
            "predicate": predicate_form if predicate_location[pos] == 1 else "_",
            "argument": argument_labels[pos],
        }
        for pos, word in enumerate(sentence)
    ]

def classify_sentence_logreg(sentence, predicate_location, argument_labels, predicate_sense, model, vectorizer):
    """
    Predict an argument label for every word of ``sentence`` with the
    logistic regression model.

    Args:
        sentence (list): A list of words.
        predicate_location (list): A one-hot vector, indicating the location of the predicate.
        argument_labels (list): A list of argument labels.
        predicate_sense (str): The sense label of the predicate.
        model: Fitted classifier exposing ``predict``.
        vectorizer: Fitted vectorizer exposing ``transform`` that turns the
            feature dicts from ``extract_features`` into model input.

    Returns:
        numpy.ndarray: One label per word. The predicate position is filled
        with '_' because ``extract_features`` produces no sample for it.
    """
    token_dicts = format_sentence(sentence, predicate_location, argument_labels, predicate_sense)
    features = vectorizer.transform(extract_features([token_dicts]))
    labels = model.predict(features)

    # Re-insert a placeholder at the predicate so the output lines up with ``sentence``.
    predicate_index = predicate_location.index(1)
    return np.insert(labels, predicate_index, '_')

# Example input in the format expected by the classify_sentence_* functions
# (not referenced by the CheckList evaluation loop).
sentence = ["The", "dog", "ran", "and", "the", "man", "fell", "."]

predicate_location = [0, 0, 1, 0, 0, 0, 0, 0]  # one-hot: "ran" is the predicate
argument_labels = ['_', 'ARG0', '_', '_', '_', '_', '_', '_']
predicate_sense = "run.01"

# Load the fitted DictVectorizer and LogisticRegression model.
# NOTE: pickle.load can execute arbitrary code; only load model files from a trusted source.
with open("learned-models/vectorizer.pkl", "rb") as pickled_vectorizer:
    vectorizer = pickle.load(pickled_vectorizer)

with open("learned-models/model.pkl", "rb") as pickled_model:
    model = pickle.load(pickled_model)
/Users/krisstallenberg/anaconda3/envs/adv-nlp-final-exam/lib/python3.12/site-packages/sklearn/base.py:380: InconsistentVersionWarning: Trying to unpickle estimator DictVectorizer from version 1.6.0 when using version 1.6.1. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
  warnings.warn(
/Users/krisstallenberg/anaconda3/envs/adv-nlp-final-exam/lib/python3.12/site-packages/sklearn/base.py:380: InconsistentVersionWarning: Trying to unpickle estimator LogisticRegression from version 1.6.0 when using version 1.6.1. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
  warnings.warn(

1.2 DistilBERT

def create_input_sequence(sentence, predicate_position, argument_labels):
    """
    Build a one-row DataFrame holding the model input for a single sentence.

    The input form is the sentence followed by a '[SEP]' marker and a copy of
    the predicate word. The two appended positions get the label None.

    Parameters:
    - sentence (list of str): The words in the sentence.
    - predicate_position (list of int): One-hot encoding indicating the predicate position.
    - argument_labels (list of str): The argument labels for each token in the sentence.

    Returns:
    - DataFrame with a single row and two columns: 'input_form' and 'argument'.

    Raises:
    - AssertionError if the three lists differ in length.
    - ValueError if ``predicate_position`` contains no 1.
    """
    assert len(sentence) == len(predicate_position) == len(argument_labels), "Input lists must have the same length."

    predicate_word = sentence[predicate_position.index(1)]

    row = {
        "input_form": sentence + ['[SEP]', predicate_word],
        "argument": argument_labels + [None, None],
    }
    return pd.DataFrame([row])

def map_labels_to_words(predicted_labels, gold_labels, dataset, tokenizer=None):
    """
    Merge WordPiece sub-tokens back into words and pair each word with its
    gold and predicted label.

    A sub-token starting with '##' is appended (without the prefix) to the
    preceding piece. Each reconstructed word takes the labels of its first
    sub-token.

    Args:
        predicted_labels (list[list]): Per-sentence predicted labels, consumed
            position by position alongside the sub-tokens of
            ``dataset[i]["input_ids"]`` (special tokens skipped).
        gold_labels (list[list]): Per-sentence gold labels, consumed the same way.
        dataset: Indexable collection whose items have an "input_ids" entry.
        tokenizer: Object with ``convert_ids_to_tokens``. Defaults to the
            module-level ``tokenizer`` (the previous, implicit behaviour).

    Returns:
        pandas.DataFrame: Columns "word", "gold_label" and "predicted_label",
        with one row per reconstructed word across all sentences. Nothing is
        written to disk.
    """
    if tokenizer is None:
        # Backward compatibility: the function used to read the notebook-global tokenizer.
        tokenizer = globals()["tokenizer"]

    rows = []
    for sent_idx, (sent_preds, sent_golds) in enumerate(zip(predicted_labels, gold_labels)):
        subword_tokens = tokenizer.convert_ids_to_tokens(dataset[sent_idx]["input_ids"], skip_special_tokens=True)

        current_word = ""
        current_gold = None
        current_pred = None

        # zip stops at the shortest of the three sequences.
        for subword, gold, pred in zip(subword_tokens, sent_golds, sent_preds):
            if subword.startswith("##"):  # continuation of the current word
                current_word += subword[2:]
                continue
            if current_word:  # flush the finished word before starting a new one
                rows.append((current_word, current_gold, current_pred))
            current_word, current_gold, current_pred = subword, gold, pred

        if current_word:
            rows.append((current_word, current_gold, current_pred))

    return pd.DataFrame(rows, columns=["word", "gold_label", "predicted_label"])

# Argument label -> class id mapping of the fine-tuned DistilBERT SRL model.
# The trailing None entry covers the None labels that create_input_sequence
# assigns to the appended [SEP] and predicate positions.
SRL_LABEL_MAP = {'_': 0, 'ARG0': 1, 'ARG1': 2, 'ARG1-DSP': 3, 'ARG2': 4, 'ARG3': 5, 'ARG4': 6, 'ARG5': 7, 'ARGA': 8, 'ARGM-ADJ': 9, 'ARGM-ADV': 10, 'ARGM-CAU': 11, 'ARGM-COM': 12, 'ARGM-CXN': 13, 'ARGM-DIR': 14, 'ARGM-DIS': 15, 'ARGM-EXT': 16, 'ARGM-GOL': 17, 'ARGM-LOC': 18, 'ARGM-LVB': 19, 'ARGM-MNR': 20, 'ARGM-MOD': 21, 'ARGM-NEG': 22, 'ARGM-PRD': 23, 'ARGM-PRP': 24, 'ARGM-PRR': 25, 'ARGM-REC': 26, 'ARGM-TMP': 27, 'C-ARG0': 28, 'C-ARG1': 29, 'C-ARG1-DSP': 30, 'C-ARG2': 31, 'C-ARG3': 32, 'C-ARG4': 33, 'C-ARGM-ADV': 34, 'C-ARGM-COM': 35, 'C-ARGM-CXN': 36, 'C-ARGM-DIR': 37, 'C-ARGM-EXT': 38, 'C-ARGM-GOL': 39, 'C-ARGM-LOC': 40, 'C-ARGM-MNR': 41, 'C-ARGM-PRP': 42, 'C-ARGM-PRR': 43, 'C-ARGM-TMP': 44, 'R-ARG0': 45, 'R-ARG1': 46, 'R-ARG2': 47, 'R-ARG3': 48, 'R-ARG4': 49, 'R-ARGM-ADJ': 50, 'R-ARGM-ADV': 51, 'R-ARGM-CAU': 52, 'R-ARGM-COM': 53, 'R-ARGM-DIR': 54, 'R-ARGM-GOL': 55, 'R-ARGM-LOC': 56, 'R-ARGM-MNR': 57, 'R-ARGM-TMP': 58, None: None}

def classify_sentence_bert(sentence, predicate_location, argument_labels, predicate_sense, trainer, tokenizer):
    """
    The standalone function that takes a sentence and predicts the argument labels, using DistilBERT.

    Args:
        sentence (list): A list of words.
        predicate_location (list): A one-hot vector, indicating the location of the predicate.
        argument_labels (list): A list of argument labels. They are fed through
            label alignment, and only positions whose aligned label is not -100
            appear in the output.
        predicate_sense (str): The sense label of the predicate. Not used by
            this model; kept so the signature matches classify_sentence_logreg.
        trainer: The HuggingFace Trainer instance to predict labels with.
        tokenizer: The HuggingFace Tokenizer to tokenize input sequences with.

    Returns:
        list: The predicted label at every position whose aligned label is not -100.
    """
    inference_input = create_input_sequence(sentence, predicate_location, argument_labels)
    # Copy per call so that helpers in utils cannot mutate the shared constant.
    label_map = dict(SRL_LABEL_MAP)
    inference_data = map_labels_in_dataframe(inference_input, label_map)
    # label_all_tokens=False: under the HuggingFace tokenize_and_align_labels
    # convention, only the first sub-token of each word keeps its label and
    # continuation sub-tokens get -100. The -100 filter below then yields one
    # prediction per word, aligned by index with ``sentence`` / ``argument_labels``,
    # which is how callers index the result. With label_all_tokens=True every
    # sub-token was kept, so a word split into several pieces shifted all later
    # predictions out of alignment.
    # NOTE(review): assumes utils.tokenize_and_align_labels follows that convention.
    tokenized_input = tokenize_and_align_labels(tokenizer, inference_data, label_all_tokens=False)
    dataset_inference_sample = load_dataset(tokenized_input)
    label_list = get_labels_from_map(label_map)

    predictions, labels, _ = trainer.predict(dataset_inference_sample)
    # Argmax over the last (class) axis turns per-position scores into class ids.
    argmax_predictions = np.argmax(predictions, axis=2)

    # Keep only positions that carry a real label (special tokens etc. are -100).
    predicted_labels = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(argmax_predictions, labels)
    ]

    # A single sentence was classified, so return its label sequence.
    return predicted_labels[0]

# Load the fine-tuned DistilBERT SRL tokenizer and token-classification model from disk.
tokenizer = AutoTokenizer.from_pretrained("learned-models/tokenizer.save_pretrained.distillbert-base-uncased-finetuned-srl")
bert_model = AutoModelForTokenClassification.from_pretrained("learned-models/model.save_pretrained.distillbert-base-uncased-finetuned-srl")
# In this notebook the Trainer is only used for trainer.predict (inference).
training_args = TrainingArguments(output_dir="learned-models/trainer.save_model.distillbert-base-uncased-finetuned-srl")

# `processing_class` replaces the `tokenizer` argument, which transformers
# deprecates for Trainer.__init__ (scheduled for removal in 5.0.0).
trainer = Trainer(
    model=bert_model,
    args=training_args,
    processing_class=tokenizer,
)
/var/folders/d9/p0hwqj9x1sx30sdq622dyn1r0000gn/T/ipykernel_1437/3635224896.py:99: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.
  trainer = Trainer(

2. Perform CheckList evaluation

Load testing-dataset.json from the data directory; it parses to a Python list of capability dictionaries.

# Each entry is one capability with its tests and samples.
with open("data/testing-dataset.json", encoding="utf-8") as dataset_file:
    checklist_dataset = json.load(dataset_file)

Iterate over the capabilities, tests and samples from the CheckList challenge dataset. The dataset contains MFT and INV tests. Each is evaluated in its own way:

  • INV: For each model, the label predicted for the gold argument token of each of the two items in a sample is compared. If the two labels are identical, the check passes.
  • MFT: For each model, the label predicted for every token that has a gold argument label is compared to that gold label. If they are identical, the check passes.

Print the per-test and per-capability failure rates, and write the detailed results to checklist-results.json.

def get_argument_index(argument_labels):
    """
    Return the position of the single gold argument label.

    Args:
        argument_labels (list[str]): Gold labels, one per token; '_' marks
            tokens that are not arguments.

    Returns:
        int: Index of the only label that differs from '_'.

    Raises:
        AssertionError: If no label, or more than one label, differs from '_'.
    """
    candidates = [pos for pos, lab in enumerate(argument_labels) if lab != "_"]
    assert len(candidates) == 1, f"Expected exactly one relevant label, found: {candidates}"
    return candidates[0]

def _one_hot(position, length):
    """
    Return a one-hot list of ``length`` ints with a 1 at ``position``.

    Produces the same list as
    ``torch.nn.functional.one_hot(torch.tensor([position]), length).squeeze(0).tolist()``.

    Raises:
        ValueError: If ``position`` is not a valid index for ``length`` tokens.
    """
    if not 0 <= position < length:
        raise ValueError(f"Predicate position {position} is out of range for {length} tokens.")
    return [1 if i == position else 0 for i in range(length)]


def _run_models(item):
    """
    Run both SRL models on one CheckList item.

    Args:
        item (dict): CheckList item with 'tokens', 'argument_labels',
            'predicate_name' and 'predicate_position' keys.

    Returns:
        tuple: (logistic regression label sequence, DistilBERT label sequence).
    """
    tokens = item['tokens']
    argument_labels = item['argument_labels']
    predicate_sense = item['predicate_name']
    predicate_position = _one_hot(item['predicate_position'], len(tokens))

    logreg_labels = classify_sentence_logreg(
        tokens, predicate_position, argument_labels, predicate_sense, model, vectorizer
    )
    bert_labels = classify_sentence_bert(
        tokens, predicate_position, argument_labels, predicate_sense, trainer, tokenizer
    )
    return logreg_labels, bert_labels


def _evaluate_mft_sample(sample):
    """
    Score an MFT sample. Every token whose gold label is not '_' is one check.

    Returns:
        tuple: (per-check result dicts, number of checks,
                logistic regression failures, DistilBERT failures)
    """
    results = []
    failures_logreg = 0
    failures_bert = 0

    for item in sample:
        gold_labels = item['argument_labels']
        logreg_labels, bert_labels = _run_models(item)

        for gold_label, logreg_label, bert_label in zip(gold_labels, logreg_labels, bert_labels):
            if gold_label == "_":
                continue

            # bool()/str() turn NumPy scalars from the logistic regression
            # output into plain Python types for json.dump.
            logreg_correct = bool(gold_label == logreg_label)
            bert_correct = bool(gold_label == bert_label)

            if not logreg_correct:
                failures_logreg += 1
            if not bert_correct:
                failures_bert += 1

            results.append({
                "gold_label": gold_label,
                "logreg_label": str(logreg_label),
                "bert_label": bert_label,
                "logreg_correct": logreg_correct,
                "bert_correct": bert_correct
            })

    return results, len(results), failures_logreg, failures_bert


def _evaluate_inv_sample(sample):
    """
    Score an INV sample. This is a single check that each model predicts the
    same label for the gold argument token in both items.

    Returns:
        tuple | None: (result dicts, 1, logistic regression failures,
        DistilBERT failures), or None (after printing a warning) when the
        sample does not contain exactly two items.
    """
    # INV test samples have two items (a 'before' and 'after')
    if len(sample) != 2:
        print("Warning: INV test sample does not contain exactly 2 items.")
        return None

    item1, item2 = sample[0], sample[1]

    argument_index_1 = get_argument_index(item1['argument_labels'])
    logreg_inference_1, bert_inference_1 = _run_models(item1)

    argument_index_2 = get_argument_index(item2['argument_labels'])
    logreg_inference_2, bert_inference_2 = _run_models(item2)

    arg_logreg_1 = logreg_inference_1[argument_index_1]
    arg_bert_1 = bert_inference_1[argument_index_1]
    arg_logreg_2 = logreg_inference_2[argument_index_2]
    arg_bert_2 = bert_inference_2[argument_index_2]

    # Test passes if the predicted labels are the same for both items
    logreg_correct = bool(arg_logreg_1 == arg_logreg_2)
    bert_correct = bool(arg_bert_1 == arg_bert_2)

    result = {
        "relevant_index_item1": argument_index_1,
        "relevant_index_item2": argument_index_2,
        "logreg_prediction_item1": str(arg_logreg_1),
        "logreg_prediction_item2": str(arg_logreg_2),
        "bert_prediction_item1": arg_bert_1,
        "bert_prediction_item2": arg_bert_2,
        "logreg_correct": logreg_correct,
        "bert_correct": bert_correct
    }
    return [result], 1, int(not logreg_correct), int(not bert_correct)


def _failure_rate(failures, total):
    """Return ``failures`` as a percentage of ``total``, or None when ``total`` is 0."""
    return 100.0 * failures / total if total else None


# Initialize a dictionary to collect all results
aggregated_results = {"capabilities": []}

# Iterate over capabilities in the dataset
for capability in checklist_dataset:
    cap_name = capability['capability_name']
    cap_description = capability['capability_description']
    cap_category = capability['capability_category']
    print(f"Capability: {cap_name} ({cap_category})")

    cap_result = {
        "capability_name": cap_name,
        "capability_description": cap_description,
        "capability_category": cap_category,
        "tests": []
    }

    # Totals across all tests of this capability
    cap_total_gold = 0
    cap_failures_logreg = 0
    cap_failures_bert = 0

    # Iterate over tests for current capability
    for test in capability['tests']:
        test_type = test['test_type']
        test_result = {
            "test_name": test['test_name'],
            "test_description": test['test_description'],
            "test_type": test_type,
            "samples": []
        }

        print(f"\n  * Test: {test_result['test_name']} ({test_result['test_type']})")

        # Totals for this test
        test_total_gold = 0
        test_failures_logreg = 0
        test_failures_bert = 0

        for sample in test['samples']:
            if test_type == "MFT":
                outcome = _evaluate_mft_sample(sample)
            elif test_type == "INV":
                outcome = _evaluate_inv_sample(sample)
                if outcome is None:
                    # Malformed INV sample: skipped entirely (not recorded).
                    continue
            else:
                # Unknown test types are recorded as an empty sample.
                outcome = ([], 0, 0, 0)

            sample_results, n_checks, n_fail_logreg, n_fail_bert = outcome
            test_total_gold += n_checks
            test_failures_logreg += n_fail_logreg
            test_failures_bert += n_fail_bert

            # Save the detailed results for this sample
            test_result["samples"].append(sample_results)

        cap_total_gold += test_total_gold
        cap_failures_logreg += test_failures_logreg
        cap_failures_bert += test_failures_bert

        # Failure rates for this test (as percentages)
        test_result["failure_rate_logreg"] = _failure_rate(test_failures_logreg, test_total_gold)
        test_result["failure_rate_bert"] = _failure_rate(test_failures_bert, test_total_gold)

        print(f"""
      * Failure rates:
            Logistic regression: {test_result["failure_rate_logreg"]}%
            DistillBERT: {test_result["failure_rate_bert"]}%""")

        cap_result["tests"].append(test_result)

    # Overall failure rates for the capability
    cap_result["failure_rate_logreg"] = _failure_rate(cap_failures_logreg, cap_total_gold)
    cap_result["failure_rate_bert"] = _failure_rate(cap_failures_bert, cap_total_gold)

    print(f"""
  > Failure rates (total for capability):
        Logistic regression: {cap_result["failure_rate_logreg"]}%
        DistillBERT: {cap_result["failure_rate_bert"]}%

=====================================================================================
""")

    aggregated_results["capabilities"].append(cap_result)

# Write the CheckList test results to a JSON file
with open("checklist-results.json", "w", encoding="utf-8") as outfile:
    json.dump(aggregated_results, outfile, indent=2)
Capability: Long-distance dependencies between predicate and ARG0 (syntactic)

  * Test: Effect of injecting relative clause between predicate and ARG0 (INV)



















































































      * Failure rates:
            Logistic regression: 20.0%
            DistillBERT: 30.0%

  * Test: Sentences without relative clause between predicate and ARG0 (MFT)











































      * Failure rates:
            Logistic regression: 70.0%
            DistillBERT: 20.0%

  * Test: Sentences with relative clause between predicate and ARG0 (MFT)











































      * Failure rates:
            Logistic regression: 90.0%
            DistillBERT: 50.0%

  > Failure rates (total for capability):
        Logistic regression: 60.0%
        DistillBERT: 33.333333333333336%

=====================================================================================

Capability: Long-distance dependencies between predicate and ARG1 (syntactic)

  * Test: Effect of injecting adverbial or participial phrase between predicate and ARG1 (INV)



















































































      * Failure rates:
            Logistic regression: 60.0%
            DistillBERT: 60.0%

  * Test: Sentence without adverbial or participial phrase between predicate and ARG1 (MFT)











































      * Failure rates:
            Logistic regression: 10.0%
            DistillBERT: 0.0%

  * Test: Sentence with adverbial or participial phrase between predicate and ARG1 (MFT)











































      * Failure rates:
            Logistic regression: 70.0%
            DistillBERT: 60.0%

  > Failure rates (total for capability):
        Logistic regression: 46.666666666666664%
        DistillBERT: 40.0%

=====================================================================================

Capability: Robustness to noise in the form of typos in ARG0 on ARG0 labeling (lexical)

  * Test: Effect of typos in proper nouns as ARG0 on ARG0 labeling (INV)



















































































      * Failure rates:
            Logistic regression: 30.0%
            DistillBERT: 20.0%

  * Test: Effect of typos in ARG0 as common noun (INV)



















































































      * Failure rates:
            Logistic regression: 20.0%
            DistillBERT: 0.0%

  * Test: Sentences without typos in ARG0 as proper noun (MFT)











































      * Failure rates:
            Logistic regression: 60.0%
            DistillBERT: 0.0%

  * Test: Sentences without typos in ARG0 as proper noun (MFT)











































      * Failure rates:
            Logistic regression: 90.0%
            DistillBERT: 20.0%

  * Test: Sentences without typos in ARG0 as common noun (MFT)











































      * Failure rates:
            Logistic regression: 60.0%
            DistillBERT: 0.0%

  * Test: Sentences with typos in ARG0 as common noun (MFT)











































      * Failure rates:
            Logistic regression: 80.0%
            DistillBERT: 0.0%

  > Failure rates (total for capability):
        Logistic regression: 56.666666666666664%
        DistillBERT: 6.666666666666667%

=====================================================================================

Capability: Effect of semantic atypicality in active voice syntactically simple SVO sentences on ARG0 labeling (lexical)

  * Test: Animate objects as ARG0 (MFT)











































      * Failure rates:
            Logistic regression: 90.0%
            DistillBERT: 0.0%

  * Test: Inanimate objects as ARG0 (MFT)











































      * Failure rates:
            Logistic regression: 100.0%
            DistillBERT: 40.0%

  * Test: Effect of animate versus inanimate concepts as ARG0 (INV)



















































































      * Failure rates:
            Logistic regression: 10.0%
            DistillBERT: 40.0%

  * Test: Non-abstract concepts as ARG0 (MFT)











































      * Failure rates:
            Logistic regression: 100.0%
            DistillBERT: 70.0%

  * Test: Abstract concepts as ARG0 (MFT)











































      * Failure rates:
            Logistic regression: 100.0%
            DistillBERT: 90.0%

  * Test: Effect of abstract versus non-abstract concepts as ARG0 (INV)



















































































      * Failure rates:
            Logistic regression: 0.0%
            DistillBERT: 30.0%

  > Failure rates (total for capability):
        Logistic regression: 66.66666666666667%
        DistillBERT: 45.0%

=====================================================================================

Capability: Dealing with dative verb alternations (syntactic)

  * Test: Effect of dative verb alternations on ARG1 (INV)



















































































      * Failure rates:
            Logistic regression: 0.0%
            DistillBERT: 0.0%

  * Test: Dative verb alternations with prepositional dative construction for ARG1 (MFT)











































      * Failure rates:
            Logistic regression: 80.0%
            DistillBERT: 30.0%

  * Test: Dative verb alternations with double object construction for ARG1 (MFT)











































      * Failure rates:
            Logistic regression: 80.0%
            DistillBERT: 30.0%

  > Failure rates (total for capability):
        Logistic regression: 53.333333333333336%
        DistillBERT: 20.0%

=====================================================================================