Preprocessing First
from datasets import load_dataset, Dataset, DatasetDict
import torch
import numpy as np
from transformers import AutoModelForMaskedLM, BertTokenizer, DataCollatorForLanguageModeling
from tokenizers import BertWordPieceTokenizer
import os
Preprocessing First
# Fetch the raw (untokenized) WikiText-2 corpus; the result holds the train/validation/test splits.
dataset = load_dataset(path="wikitext", name="wikitext-2-raw-v1")
print(dataset)
DatasetDict({
test: Dataset({
features: ['text'],
num_rows: 4358
})
train: Dataset({
features: ['text'],
num_rows: 36718
})
validation: Dataset({
features: ['text'],
num_rows: 3760
})
})
# Train a WordPiece vocabulary from scratch on the WikiText-2 training split.
tokenizer = BertWordPieceTokenizer(
    clean_text=True,
    handle_chinese_chars=False,
    strip_accents=True,
    lowercase=True,
)

# The tokenizer trainer reads from files, so dump the non-blank training
# samples to a plain-text corpus first.
train_lines = [line for line in dataset["train"]["text"] if line.strip()]
with open("wikitext_train.txt", "w", encoding="utf-8") as f:
    # Strip any newline a sample already carries before adding ours, so every
    # sample occupies exactly one line instead of being followed by a blank one.
    f.writelines(line.rstrip("\n") + "\n" for line in train_lines)

tokenizer.train(
    files=["wikitext_train.txt"],
    vocab_size=16384,
    min_frequency=2,  # ignore word pieces seen fewer than twice
    special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"],
)
# Persist the trained vocabulary so it can be reloaded through `transformers`.
output_dir = "custom-ltg-tokenizer"
# exist_ok avoids the check-then-create race and also creates missing parents.
os.makedirs(output_dir, exist_ok=True)
tokenizer.save_model(output_dir)
['custom-ltg-tokenizer/vocab.txt']
# Reload the trained vocabulary as a `transformers` tokenizer.
# Only vocab.txt was written to this directory, so the normalization options
# used during training are not stored anywhere; restate them here so that
# text is split the same way it was when the vocabulary was learned
# (in particular, BertTokenizer would otherwise default to
# tokenize_chinese_chars=True while training used handle_chinese_chars=False).
tokenizer = BertTokenizer.from_pretrained(
    "custom-ltg-tokenizer",
    do_lower_case=True,
    strip_accents=True,
    tokenize_chinese_chars=False,
)
print(tokenizer.vocab_size)
16384
def tokenize_function(examples):
    """Tokenize a batch of raw text rows.

    Args:
        examples: A batch from ``Dataset.map(batched=True)``; its ``"text"``
            entry is the list of strings to encode.

    Returns:
        The tokenizer output for the batch, truncated to at most 512 tokens
        per row.
    """
    # No padding here: padding every row to 512 tokens up front stores and
    # processes mostly [PAD] for short rows. The data collator pads each batch
    # to the length of its longest row instead.
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=512,
    )


tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
Map: 0%| | 0/4358 [00:00<?, ? examples/s]
Map: 0%| | 0/36718 [00:00<?, ? examples/s]
Map: 0%| | 0/3760 [00:00<?, ? examples/s]
# Show the splits and columns produced by the tokenization step.
print(tokenized_dataset)
DatasetDict({
test: Dataset({
features: ['input_ids', 'token_type_ids', 'attention_mask'],
num_rows: 4358
})
train: Dataset({
features: ['input_ids', 'token_type_ids', 'attention_mask'],
num_rows: 36718
})
validation: Dataset({
features: ['input_ids', 'token_type_ids', 'attention_mask'],
num_rows: 3760
})
})
# Batch collator for masked-language modelling; masking is applied with
# probability 0.15 each time a batch is assembled.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15)
from transformers import BertConfig, BertForMaskedLM

# Small BERT encoder: 12 layers, hidden size 192 split over 3 heads (64 per head).
# `position_bucket_size` was dropped from this config: it is not a BertConfig
# parameter, so it was only stored as an unused extra attribute and had no
# effect on the BertForMaskedLM architecture built below.
config = BertConfig(
    attention_probs_dropout_prob=0.1,
    hidden_dropout_prob=0.1,
    hidden_size=192,
    intermediate_size=512,
    max_position_embeddings=512,  # matches the tokenizer's max_length
    num_attention_heads=3,
    num_hidden_layers=12,
    vocab_size=tokenizer.vocab_size,
    layer_norm_eps=1e-7,
    pad_token_id=tokenizer.pad_token_id,
)

# Randomly initialised model (trained from scratch, not loaded from a checkpoint).
model = BertForMaskedLM(config)
from transformers import TrainingArguments

# Training configuration. Effective batch size per device is
# per_device_train_batch_size * gradient_accumulation_steps = 32.
training_args = TrainingArguments(
    output_dir="ltgbert-wikitext2-checkpoints",
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    learning_rate=5e-5,
    warmup_steps=1000,
    weight_decay=0.01,
    save_strategy="epoch",  # one checkpoint per epoch, loaded individually later on
    save_total_limit=10,  # equals num_train_epochs, so no epoch checkpoint is pruned
    # Mixed precision shortens the run but requires a CUDA device; fall back to
    # full precision instead of failing when no GPU is available.
    fp16=torch.cuda.is_available(),
    report_to="none",  # disable external experiment trackers
)
from transformers import Trainer

# Wire the model, arguments, training split and MLM collator together.
# Only a training split is passed; no evaluation dataset is configured.
trainer_kwargs = {
    "model": model,
    "args": training_args,
    "train_dataset": tokenized_dataset["train"],
    "data_collator": data_collator,
}
trainer = Trainer(**trainer_kwargs)

# Run the full training loop.
trainer.train()
<div>
<progress value='11480' max='11480' style='width:300px; height:20px; vertical-align: middle;'></progress>
[11480/11480 43:15, Epoch 10/10]
</div>
<table border="1" class="dataframe">
Step
Training Loss
500
9.227700
1000
7.757700
1500
7.108000
2000
6.992500
2500
6.922600
3000
6.883600
3500
6.847000
4000
6.817000
4500
6.794800
5000
6.759200
5500
6.753700
6000
6.728600
6500
6.716300
7000
6.700100
7500
6.676900
8000
6.664200
8500
6.655600
9000
6.651900
9500
6.639900
10000
6.624600
10500
6.644500
11000
6.642200
TrainOutput(global_step=11480, training_loss=6.9061379249918335, metrics={'train_runtime': 2597.3206, 'train_samples_per_second': 141.369, 'train_steps_per_second': 4.42, 'total_flos': 4748620572917760.0, 'train_loss': 6.9061379249918335, 'epoch': 10.0})
# Write the final trained model to its own directory, separate from the per-epoch checkpoints.
# NOTE(review): the Trainer was constructed without a tokenizer, so this directory
# presumably will not contain tokenizer files — confirm whether it needs them.
trainer.save_model("XS-ltgbert-wikitext2")
Extracting the static (input) embeddings and the embedding of a specific token
# Pull the input-embedding weight matrix out of the model as a NumPy array.
embedding_weights = model.get_input_embeddings().weight
static_embs = embedding_weights.detach().cpu().numpy()
print(static_embs.shape)
(16384, 192)
# Look up the static embedding row for each word piece of a sample word.
tokens = tokenizer.tokenize("clouds")
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens)
print(token_ids)

# One embedding row per word piece.
vectors = [static_embs[token_id] for token_id in token_ids]
for vector in vectors:
    print(vector.shape)
['clouds']
[12228]
(192,)
Loading the saved checkpoints epoch by epoch
# Reload the checkpoints saved at the end of epochs 1-3
# (one checkpoint every 1148 optimizer steps).
epoch_checkpoint_dirs = (
    "ltgbert-wikitext2-checkpoints/checkpoint-1148",
    "ltgbert-wikitext2-checkpoints/checkpoint-2296",
    "ltgbert-wikitext2-checkpoints/checkpoint-3444",
)
model_epoch1, model_epoch2, model_epoch3 = map(
    BertForMaskedLM.from_pretrained, epoch_checkpoint_dirs
)
Next step: run the probing experiments on each per-epoch model.