In this post, we'll demonstrate how to save the best model during model training.

Saving the best model is a good technique when we're not sure about the optimal number of epochs we should use for training.

We'll use the same pipeline as in previous posts to fine-tune a BERT model on a text classification task.

If you're already familiar with the training process, you can skip ahead to the subsections on saving the best model.


Load data

We'll use the emotion dataset from the Hugging Face Hub.

The emotion dataset consists of three sets: train, validation, and test set, and has six kinds of emotion: sadness, joy, love, anger, fear, and surprise.
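
For reference, here is a minimal sketch of how the dataset can be loaded, assuming the Hugging Face datasets library is installed:

from datasets import load_dataset

# download the emotion dataset from the Hugging Face Hub
emotion = load_dataset("emotion")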

emotion
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 16000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
})
label_names = emotion["train"].features['label'].names
label_names
['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']

Let's take a look at what the text is like:

emotion.set_format(type="pandas")
train_df = emotion['train'][:]
valid_df = emotion['validation'][:]
test_df = emotion['test'][:]
train_df.head()
text label
0 i didnt feel humiliated 0
1 i can go from feeling so hopeless to so damned... 0
2 im grabbing a minute to post i feel greedy wrong 3
3 i am ever feeling nostalgic about the fireplac... 2
4 i am feeling grouchy 3

In this post, we'll use only 350 samples per class for training, 70 per class for validation, and 50 per class for testing.

train_df = train_df.groupby('label').apply(lambda x: x.sample(350)).reset_index(drop=True)
valid_df = valid_df.groupby('label').apply(lambda x: x.sample(70)).reset_index(drop=True)
test_df = test_df.groupby('label').apply(lambda x: x.sample(50)).reset_index(drop=True)
train_df['label'].value_counts()
0    350
1    350
2    350
3    350
4    350
5    350
Name: label, dtype: int64
valid_df['label'].value_counts()
0    70
1    70
2    70
3    70
4    70
5    70
Name: label, dtype: int64
test_df['label'].value_counts()
0    50
1    50
2    50
3    50
4    50
5    50
Name: label, dtype: int64


Tokenization

Tokenization is the process of splitting raw text into tokens and encoding those tokens as numeric data.

To do this, we first initialize a BertTokenizer:

from transformers import BertTokenizer
PRETRAINED_LM = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_LM, do_lower_case=True)
tokenizer
PreTrainedTokenizer(name_or_path='bert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})
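
To get a feel for what the tokenizer does, you can run it on a single sentence (a quick sketch using the first training text shown above):

sample_text = "i didnt feel humiliated"
tokens = tokenizer.tokenize(sample_text)              # WordPiece tokens of the sentence
token_ids = tokenizer.convert_tokens_to_ids(tokens)   # numeric ids from the BERT vocabulary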

Define a function for encoding:

def encode(docs):
    '''
    This function takes a list of texts and returns the input_ids and attention_mask of the texts
    '''
    encoded_dict = tokenizer.batch_encode_plus(docs, add_special_tokens=True, max_length=128,
                                               padding='max_length', return_attention_mask=True,
                                               truncation=True, return_tensors='pt')
    input_ids = encoded_dict['input_ids']
    attention_masks = encoded_dict['attention_mask']
    return input_ids, attention_masks

Use the encode function to get the input IDs and attention masks of the datasets:

train_input_ids, train_att_masks = encode(train_df['text'].values.tolist())
valid_input_ids, valid_att_masks = encode(valid_df['text'].values.tolist())
test_input_ids, test_att_masks = encode(test_df['text'].values.tolist())


Creating Datasets and DataLoaders

We'll use PyTorch's Dataset and DataLoader to split the data into batches. For more details, you can check out another post on DataLoader.

Turn the labels into tensors:

import torch
train_y = torch.LongTensor(train_df['label'].values.tolist())
valid_y = torch.LongTensor(valid_df['label'].values.tolist())
test_y = torch.LongTensor(test_df['label'].values.tolist())
train_y.size(),valid_y.size(),test_y.size()
(torch.Size([2100]), torch.Size([420]), torch.Size([300]))

Create DataLoaders for the training, validation, and test sets:

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

BATCH_SIZE = 16
train_dataset = TensorDataset(train_input_ids, train_att_masks, train_y)
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=BATCH_SIZE)

valid_dataset = TensorDataset(valid_input_ids, valid_att_masks, valid_y)
valid_sampler = SequentialSampler(valid_dataset)
valid_dataloader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=BATCH_SIZE)

test_dataset = TensorDataset(test_input_ids, test_att_masks, test_y)
test_sampler = SequentialSampler(test_dataset)
test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=BATCH_SIZE)
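
As a quick sanity check, you can pull one batch from the training DataLoader and confirm the tensor shapes (a sketch):

batch_input_ids, batch_att_masks, batch_labels = next(iter(train_dataloader))
# expected shapes: [16, 128] for the ids and masks, [16] for the labels
batch_input_ids.shape, batch_att_masks.shape, batch_labels.shape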


BertForSequenceClassification Model

We will initialize the BertForSequenceClassification model from Hugging Face, which makes it easy to fine-tune the pretrained BERT model for a classification task.

You will see a warning that some parts of the model are randomly initialized. This is normal since the classification head has not yet been trained.

from transformers import BertForSequenceClassification
N_labels = len(train_df.label.unique())
model = BertForSequenceClassification.from_pretrained(PRETRAINED_LM,
                                                      num_labels=N_labels,
                                                      output_attentions=False,
                                                      output_hidden_states=False)
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
model = model.cuda()


Fine-tuning

Optimizer and Scheduler

An optimizer updates the model's parameters based on the gradients computed during backpropagation.

The learnable parameters (i.e. weights and biases) of a torch.nn.Module are contained in the model's parameters, which we can access with model.parameters().
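
For example, a quick way to check how many trainable parameters the optimizer will update (a sketch):

n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{n_trainable:,} trainable parameters')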

Hence, we initialize an AdamW optimizer with the model parameters and a learning rate using the following code:

from torch.optim import AdamW

LEARNING_RATE = 2e-6
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

Selection of the learning rate is important. In practice, it's common to use a scheduler to decrease the learning rate during training.

from transformers import get_linear_schedule_with_warmup

EPOCHS = 30
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_dataloader) * EPOCHS)
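
The scheduler lowers the learning rate a little every time scheduler.step() is called. If you want to monitor the current value during training, you can query the scheduler (a sketch):

scheduler.get_last_lr()  # list of the learning rate(s) currently set on the optimizer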


Define a training and validation loop

The training loop is where the magic of deep learning happens. The model will be fine-tuned on the emotion dataset for the classification task.

from tqdm.notebook import tqdm
from torch.nn.utils import clip_grad_norm_
def train(model, train_dataloader):
    model.train()
    train_loss = 0
    for step_num, batch_data in enumerate(tqdm(train_dataloader, desc='Training')):
        # get the inputs and move them to the device
        input_ids, att_mask, labels = [data.to(device) for data in batch_data]

        # zero the parameter gradients
        model.zero_grad()

        # forward + backward + optimize
        output = model(input_ids=input_ids, attention_mask=att_mask, labels=labels)
        loss = output.loss
        loss.backward()
        clip_grad_norm_(parameters=model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        train_loss += loss.item()

    return train_loss / len(train_dataloader)

Define a validation loop to see how the model performs:

from tqdm.notebook import tqdm
import numpy as np

def evaluate(model, dataloader, desc='Validation'):
    model.eval()
    valid_loss = 0
    valid_pred = []
    with torch.no_grad():
        for step_num_e, batch_data in enumerate(tqdm(dataloader, desc=desc)):
            # move the batch to the device
            input_ids, att_mask, labels = [data.to(device) for data in batch_data]

            output = model(input_ids=input_ids, attention_mask=att_mask, labels=labels)
            loss = output.loss
            valid_loss += loss.item()
            # store the predicted class for each sample in the batch
            valid_pred.append(np.argmax(output.logits.cpu().detach().numpy(), axis=-1))

    valid_pred = np.concatenate(valid_pred)

    return valid_loss / len(dataloader), valid_pred


Train and save the best model

The best model is the one whose parameters yield the lowest validation loss.

import copy

train_loss_per_epoch = []
val_loss_per_epoch = []

best_val_loss = float('inf')
best_model = None

for epoch_num in range(EPOCHS):
    print('Epoch: ', epoch_num + 1)

    # Training
    train_loss = train(model, train_dataloader)
    train_loss_per_epoch.append(train_loss)              

    # Validation
    valid_loss, valid_pred = evaluate(model, valid_dataloader)
    val_loss_per_epoch.append(valid_loss)

    # Loss message
    print(f"train loss: {train_loss} |  val loss: {valid_loss}" )

    # save best model
    if valid_loss < best_val_loss:
        best_val_loss = valid_loss
        best_model = copy.deepcopy(model)

Epoch:  1
train loss: 1.80937030098655 |  val loss: 1.7834540826302987
Epoch:  2
train loss: 1.7831896841526031 |  val loss: 1.770483926490501
Epoch:  3
train loss: 1.7660167388843768 |  val loss: 1.7524053388171725
Epoch:  4
train loss: 1.7304552576758645 |  val loss: 1.7150162590874567
Epoch:  5
train loss: 1.6872752543651697 |  val loss: 1.6650207219300446
Epoch:  6
train loss: 1.6259240074591204 |  val loss: 1.6075689174510814
Epoch:  7
train loss: 1.5426361163457234 |  val loss: 1.5362377255051225
Epoch:  8
train loss: 1.4585662565448068 |  val loss: 1.4692661011660542
Epoch:  9
train loss: 1.3859010347814271 |  val loss: 1.4036017038204052
Epoch:  10
train loss: 1.2940114294037675 |  val loss: 1.3318283491664462
Epoch:  11
train loss: 1.2193167494101957 |  val loss: 1.2636116919694123
Epoch:  12
train loss: 1.1350717223954923 |  val loss: 1.1939950143849407
Epoch:  13
train loss: 1.052193590185859 |  val loss: 1.1205124810889915
Epoch:  14
train loss: 0.9720014762697797 |  val loss: 1.0528035053500422
Epoch:  15
train loss: 0.8954907055153991 |  val loss: 0.9932596705577992
Epoch:  16
train loss: 0.8330346842606863 |  val loss: 0.9356682808310898
Epoch:  17
train loss: 0.7613115613207673 |  val loss: 0.8837055652229874
Epoch:  18
train loss: 0.7175321741537615 |  val loss: 0.8406207395924462
Epoch:  19
train loss: 0.656489051426902 |  val loss: 0.8039783747107895
Epoch:  20
train loss: 0.6229531875613964 |  val loss: 0.7734741369883219
Epoch:  21
train loss: 0.596022024073384 |  val loss: 0.7514391420064149
Epoch:  22
train loss: 0.5664069943807342 |  val loss: 0.7316006300625978
Epoch:  23
train loss: 0.5331556061000535 |  val loss: 0.7149963958395852
Epoch:  24
train loss: 0.5119351231013284 |  val loss: 0.701943866080708
Epoch:  25
train loss: 0.49288606417901587 |  val loss: 0.69107067419423
Epoch:  26
train loss: 0.48085053265094757 |  val loss: 0.6813045464180134
Epoch:  27
train loss: 0.47528503067565686 |  val loss: 0.674878223074807
Epoch:  28
train loss: 0.4632551791993054 |  val loss: 0.6706337260979193
Epoch:  29
train loss: 0.4563892047965165 |  val loss: 0.6679090079334047
Epoch:  30
train loss: 0.45349600132216106 |  val loss: 0.6671353473707482

The benefit of saving the best model is not obvious in this post: over 30 epochs, the validation loss still decreases every epoch, so the final model is also the best one.

If we trained for, say, 100 epochs, the validation loss would likely start rising at some point as the model overfits, and keeping the best checkpoint would pay off.
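
For longer runs, the same bookkeeping can also drive simple early stopping, so that training halts once the validation loss has stopped improving for a while. A sketch of how the loop above could be extended (the patience value is an arbitrary choice):

PATIENCE = 5               # stop after 5 epochs without improvement (arbitrary choice)
epochs_no_improve = 0

for epoch_num in range(EPOCHS):
    train_loss = train(model, train_dataloader)
    valid_loss, valid_pred = evaluate(model, valid_dataloader)

    if valid_loss < best_val_loss:
        best_val_loss = valid_loss
        best_model = copy.deepcopy(model)
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE:
            print(f'Stopping early after epoch {epoch_num + 1}')
            break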


Save the best model to file

We use the following code to save the best model to file:

torch.save(best_model.state_dict(), 'best_model.pt')

You can see that we're saving a model's state_dict.

A state_dict is a Python dictionary object that maps each layer to its parameter tensor.

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

Model's state_dict:
bert.embeddings.position_ids 	 torch.Size([1, 512])
bert.embeddings.word_embeddings.weight 	 torch.Size([30522, 768])
bert.embeddings.position_embeddings.weight 	 torch.Size([512, 768])
bert.embeddings.token_type_embeddings.weight 	 torch.Size([2, 768])
bert.embeddings.LayerNorm.weight 	 torch.Size([768])
bert.embeddings.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.0.attention.self.query.weight 	 torch.Size([768, 768])
bert.encoder.layer.0.attention.self.query.bias 	 torch.Size([768])
bert.encoder.layer.0.attention.self.key.weight 	 torch.Size([768, 768])
bert.encoder.layer.0.attention.self.key.bias 	 torch.Size([768])
bert.encoder.layer.0.attention.self.value.weight 	 torch.Size([768, 768])
bert.encoder.layer.0.attention.self.value.bias 	 torch.Size([768])
bert.encoder.layer.0.attention.output.dense.weight 	 torch.Size([768, 768])
bert.encoder.layer.0.attention.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.0.intermediate.dense.weight 	 torch.Size([3072, 768])
bert.encoder.layer.0.intermediate.dense.bias 	 torch.Size([3072])
bert.encoder.layer.0.output.dense.weight 	 torch.Size([768, 3072])
bert.encoder.layer.0.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.0.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.0.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.1.attention.self.query.weight 	 torch.Size([768, 768])
bert.encoder.layer.1.attention.self.query.bias 	 torch.Size([768])
bert.encoder.layer.1.attention.self.key.weight 	 torch.Size([768, 768])
bert.encoder.layer.1.attention.self.key.bias 	 torch.Size([768])
bert.encoder.layer.1.attention.self.value.weight 	 torch.Size([768, 768])
bert.encoder.layer.1.attention.self.value.bias 	 torch.Size([768])
bert.encoder.layer.1.attention.output.dense.weight 	 torch.Size([768, 768])
bert.encoder.layer.1.attention.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.1.attention.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.1.attention.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.1.intermediate.dense.weight 	 torch.Size([3072, 768])
bert.encoder.layer.1.intermediate.dense.bias 	 torch.Size([3072])
bert.encoder.layer.1.output.dense.weight 	 torch.Size([768, 3072])
bert.encoder.layer.1.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.1.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.1.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.2.attention.self.query.weight 	 torch.Size([768, 768])
bert.encoder.layer.2.attention.self.query.bias 	 torch.Size([768])
bert.encoder.layer.2.attention.self.key.weight 	 torch.Size([768, 768])
bert.encoder.layer.2.attention.self.key.bias 	 torch.Size([768])
bert.encoder.layer.2.attention.self.value.weight 	 torch.Size([768, 768])
bert.encoder.layer.2.attention.self.value.bias 	 torch.Size([768])
bert.encoder.layer.2.attention.output.dense.weight 	 torch.Size([768, 768])
bert.encoder.layer.2.attention.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.2.attention.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.2.attention.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.2.intermediate.dense.weight 	 torch.Size([3072, 768])
bert.encoder.layer.2.intermediate.dense.bias 	 torch.Size([3072])
bert.encoder.layer.2.output.dense.weight 	 torch.Size([768, 3072])
bert.encoder.layer.2.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.2.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.2.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.3.attention.self.query.weight 	 torch.Size([768, 768])
bert.encoder.layer.3.attention.self.query.bias 	 torch.Size([768])
bert.encoder.layer.3.attention.self.key.weight 	 torch.Size([768, 768])
bert.encoder.layer.3.attention.self.key.bias 	 torch.Size([768])
bert.encoder.layer.3.attention.self.value.weight 	 torch.Size([768, 768])
bert.encoder.layer.3.attention.self.value.bias 	 torch.Size([768])
bert.encoder.layer.3.attention.output.dense.weight 	 torch.Size([768, 768])
bert.encoder.layer.3.attention.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.3.attention.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.3.attention.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.3.intermediate.dense.weight 	 torch.Size([3072, 768])
bert.encoder.layer.3.intermediate.dense.bias 	 torch.Size([3072])
bert.encoder.layer.3.output.dense.weight 	 torch.Size([768, 3072])
bert.encoder.layer.3.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.3.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.3.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.4.attention.self.query.weight 	 torch.Size([768, 768])
bert.encoder.layer.4.attention.self.query.bias 	 torch.Size([768])
bert.encoder.layer.4.attention.self.key.weight 	 torch.Size([768, 768])
bert.encoder.layer.4.attention.self.key.bias 	 torch.Size([768])
bert.encoder.layer.4.attention.self.value.weight 	 torch.Size([768, 768])
bert.encoder.layer.4.attention.self.value.bias 	 torch.Size([768])
bert.encoder.layer.4.attention.output.dense.weight 	 torch.Size([768, 768])
bert.encoder.layer.4.attention.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.4.attention.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.4.attention.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.4.intermediate.dense.weight 	 torch.Size([3072, 768])
bert.encoder.layer.4.intermediate.dense.bias 	 torch.Size([3072])
bert.encoder.layer.4.output.dense.weight 	 torch.Size([768, 3072])
bert.encoder.layer.4.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.4.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.4.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.5.attention.self.query.weight 	 torch.Size([768, 768])
bert.encoder.layer.5.attention.self.query.bias 	 torch.Size([768])
bert.encoder.layer.5.attention.self.key.weight 	 torch.Size([768, 768])
bert.encoder.layer.5.attention.self.key.bias 	 torch.Size([768])
bert.encoder.layer.5.attention.self.value.weight 	 torch.Size([768, 768])
bert.encoder.layer.5.attention.self.value.bias 	 torch.Size([768])
bert.encoder.layer.5.attention.output.dense.weight 	 torch.Size([768, 768])
bert.encoder.layer.5.attention.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.5.attention.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.5.attention.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.5.intermediate.dense.weight 	 torch.Size([3072, 768])
bert.encoder.layer.5.intermediate.dense.bias 	 torch.Size([3072])
bert.encoder.layer.5.output.dense.weight 	 torch.Size([768, 3072])
bert.encoder.layer.5.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.5.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.5.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.6.attention.self.query.weight 	 torch.Size([768, 768])
bert.encoder.layer.6.attention.self.query.bias 	 torch.Size([768])
bert.encoder.layer.6.attention.self.key.weight 	 torch.Size([768, 768])
bert.encoder.layer.6.attention.self.key.bias 	 torch.Size([768])
bert.encoder.layer.6.attention.self.value.weight 	 torch.Size([768, 768])
bert.encoder.layer.6.attention.self.value.bias 	 torch.Size([768])
bert.encoder.layer.6.attention.output.dense.weight 	 torch.Size([768, 768])
bert.encoder.layer.6.attention.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.6.attention.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.6.attention.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.6.intermediate.dense.weight 	 torch.Size([3072, 768])
bert.encoder.layer.6.intermediate.dense.bias 	 torch.Size([3072])
bert.encoder.layer.6.output.dense.weight 	 torch.Size([768, 3072])
bert.encoder.layer.6.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.6.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.6.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.7.attention.self.query.weight 	 torch.Size([768, 768])
bert.encoder.layer.7.attention.self.query.bias 	 torch.Size([768])
bert.encoder.layer.7.attention.self.key.weight 	 torch.Size([768, 768])
bert.encoder.layer.7.attention.self.key.bias 	 torch.Size([768])
bert.encoder.layer.7.attention.self.value.weight 	 torch.Size([768, 768])
bert.encoder.layer.7.attention.self.value.bias 	 torch.Size([768])
bert.encoder.layer.7.attention.output.dense.weight 	 torch.Size([768, 768])
bert.encoder.layer.7.attention.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.7.attention.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.7.attention.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.7.intermediate.dense.weight 	 torch.Size([3072, 768])
bert.encoder.layer.7.intermediate.dense.bias 	 torch.Size([3072])
bert.encoder.layer.7.output.dense.weight 	 torch.Size([768, 3072])
bert.encoder.layer.7.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.7.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.7.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.8.attention.self.query.weight 	 torch.Size([768, 768])
bert.encoder.layer.8.attention.self.query.bias 	 torch.Size([768])
bert.encoder.layer.8.attention.self.key.weight 	 torch.Size([768, 768])
bert.encoder.layer.8.attention.self.key.bias 	 torch.Size([768])
bert.encoder.layer.8.attention.self.value.weight 	 torch.Size([768, 768])
bert.encoder.layer.8.attention.self.value.bias 	 torch.Size([768])
bert.encoder.layer.8.attention.output.dense.weight 	 torch.Size([768, 768])
bert.encoder.layer.8.attention.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.8.attention.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.8.attention.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.8.intermediate.dense.weight 	 torch.Size([3072, 768])
bert.encoder.layer.8.intermediate.dense.bias 	 torch.Size([3072])
bert.encoder.layer.8.output.dense.weight 	 torch.Size([768, 3072])
bert.encoder.layer.8.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.8.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.8.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.9.attention.self.query.weight 	 torch.Size([768, 768])
bert.encoder.layer.9.attention.self.query.bias 	 torch.Size([768])
bert.encoder.layer.9.attention.self.key.weight 	 torch.Size([768, 768])
bert.encoder.layer.9.attention.self.key.bias 	 torch.Size([768])
bert.encoder.layer.9.attention.self.value.weight 	 torch.Size([768, 768])
bert.encoder.layer.9.attention.self.value.bias 	 torch.Size([768])
bert.encoder.layer.9.attention.output.dense.weight 	 torch.Size([768, 768])
bert.encoder.layer.9.attention.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.9.attention.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.9.attention.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.9.intermediate.dense.weight 	 torch.Size([3072, 768])
bert.encoder.layer.9.intermediate.dense.bias 	 torch.Size([3072])
bert.encoder.layer.9.output.dense.weight 	 torch.Size([768, 3072])
bert.encoder.layer.9.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.9.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.9.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.10.attention.self.query.weight 	 torch.Size([768, 768])
bert.encoder.layer.10.attention.self.query.bias 	 torch.Size([768])
bert.encoder.layer.10.attention.self.key.weight 	 torch.Size([768, 768])
bert.encoder.layer.10.attention.self.key.bias 	 torch.Size([768])
bert.encoder.layer.10.attention.self.value.weight 	 torch.Size([768, 768])
bert.encoder.layer.10.attention.self.value.bias 	 torch.Size([768])
bert.encoder.layer.10.attention.output.dense.weight 	 torch.Size([768, 768])
bert.encoder.layer.10.attention.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.10.attention.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.10.attention.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.10.intermediate.dense.weight 	 torch.Size([3072, 768])
bert.encoder.layer.10.intermediate.dense.bias 	 torch.Size([3072])
bert.encoder.layer.10.output.dense.weight 	 torch.Size([768, 3072])
bert.encoder.layer.10.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.10.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.10.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.11.attention.self.query.weight 	 torch.Size([768, 768])
bert.encoder.layer.11.attention.self.query.bias 	 torch.Size([768])
bert.encoder.layer.11.attention.self.key.weight 	 torch.Size([768, 768])
bert.encoder.layer.11.attention.self.key.bias 	 torch.Size([768])
bert.encoder.layer.11.attention.self.value.weight 	 torch.Size([768, 768])
bert.encoder.layer.11.attention.self.value.bias 	 torch.Size([768])
bert.encoder.layer.11.attention.output.dense.weight 	 torch.Size([768, 768])
bert.encoder.layer.11.attention.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.11.attention.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.11.attention.output.LayerNorm.bias 	 torch.Size([768])
bert.encoder.layer.11.intermediate.dense.weight 	 torch.Size([3072, 768])
bert.encoder.layer.11.intermediate.dense.bias 	 torch.Size([3072])
bert.encoder.layer.11.output.dense.weight 	 torch.Size([768, 3072])
bert.encoder.layer.11.output.dense.bias 	 torch.Size([768])
bert.encoder.layer.11.output.LayerNorm.weight 	 torch.Size([768])
bert.encoder.layer.11.output.LayerNorm.bias 	 torch.Size([768])
bert.pooler.dense.weight 	 torch.Size([768, 768])
bert.pooler.dense.bias 	 torch.Size([768])
classifier.weight 	 torch.Size([6, 768])
classifier.bias 	 torch.Size([6])
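
As an aside, since this is a Hugging Face model, the weights and config can also be written out together with save_pretrained and reloaded later with from_pretrained (a sketch; the directory name is arbitrary):

best_model.save_pretrained('best_model_dir')
# later: model = BertForSequenceClassification.from_pretrained('best_model_dir')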


Plot training and validation losses

from matplotlib import pyplot as plt
epochs = range(1, EPOCHS + 1)
fig, ax = plt.subplots()
ax.plot(epochs, train_loss_per_epoch, label='training loss')
ax.plot(epochs, val_loss_per_epoch, label='validation loss')
ax.set_title('Training and Validation loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

Performance Metrics

It's common to use precision, recall, and F1-score as the performance metrics.
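
As a reminder, for each class: precision = TP / (TP + FP), recall = TP / (TP + FN), and the F1-score is the harmonic mean of the two, 2 * precision * recall / (precision + recall).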

from sklearn.metrics import classification_report
print('classification report')
print(classification_report(valid_pred, valid_df['label'].to_numpy(), target_names=label_names))
classification report
              precision    recall  f1-score   support

     sadness       0.77      0.79      0.78        68
         joy       0.64      0.71      0.68        63
        love       0.79      0.81      0.80        68
       anger       0.80      0.78      0.79        72
        fear       0.81      0.79      0.80        72
    surprise       0.93      0.84      0.88        77

    accuracy                           0.79       420
   macro avg       0.79      0.79      0.79       420
weighted avg       0.80      0.79      0.79       420

Error Analysis

With the predictions, we can plot the confusion matrix:

from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
def plot_confusion_matrix(y_preds, y_true, labels=None):
  cm = confusion_matrix(y_true, y_preds, normalize="true")
  fig, ax = plt.subplots(figsize=(6, 6))
  disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels) 
  disp.plot(cmap="Blues", values_format=".2f", ax=ax, colorbar=False) 
  plt.title("Normalized confusion matrix")
  plt.show()
plot_confusion_matrix(valid_pred,valid_df['label'].to_numpy(),labels=label_names)

You can see that sadness is more likely to be misclassified as anger or fear, which leads to a lower F1-score.


Inference

Now let's use the best model to make predictions on the test set.

Load the best model

First, we have to initialize the model again.

from transformers import BertForSequenceClassification
N_labels = len(train_df.label.unique())
model = BertForSequenceClassification.from_pretrained(PRETRAINED_LM,
                                                      num_labels=N_labels,
                                                      output_attentions=False,
                                                      output_hidden_states=False)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

And then, we use:

  • torch.load() to load the best_model.pt file
  • load_state_dict() to load the learned parameters into the newly initialized model
model.load_state_dict(torch.load('best_model.pt'))
<All keys matched successfully>
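
Note that load_state_dict copies the weights into the model in place, so the freshly initialized model still lives on the CPU at this point. If you want to run inference with the reloaded model on the GPU (or load the checkpoint on a CPU-only machine), torch.load accepts a map_location argument and the model can be moved with .to() (a sketch):

model.load_state_dict(torch.load('best_model.pt', map_location=device))
model = model.to(device)  # move the reloaded weights to the same device as the data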

Make predictions

Evaluate the best model on the test data:

test_loss, test_pred = evaluate(best_model, test_dataloader, desc= 'Testing' )

Output the classification report:

print('classification report')
print(classification_report(test_pred, test_df['label'].to_numpy(), target_names=label_names))
classification report
              precision    recall  f1-score   support

     sadness       0.88      0.75      0.81        59
         joy       0.86      0.80      0.83        54
        love       0.72      0.82      0.77        44
       anger       0.78      0.95      0.86        41
        fear       0.74      0.79      0.76        47
    surprise       0.94      0.85      0.90        55

    accuracy                           0.82       300
   macro avg       0.82      0.83      0.82       300
weighted avg       0.83      0.82      0.82       300

With the predictions, we can plot the confusion matrix again:

plot_confusion_matrix(test_pred,test_df['label'].to_numpy(),labels=label_names)

Output the misclassified texts:

test_df['pred'] = test_pred
test_df.reset_index(level=0)
print(test_df[test_df['label']!=test_df['pred']].shape)
test_df[test_df['label']!=test_df['pred']][['text','label','pred']].head(10)
(54, 3)
text label pred
5 i feel like an ugly monster where i cannot sho... 0 4
12 i am not surprised cause its like ok when you ... 0 4
24 i feel defeated but others i feel refreshed 0 1
38 i wasnt very interested in it but it evoked th... 0 4
44 i feel anger i feel sad i feel joy and i feel ... 0 2
45 i don t know how i feel about my submissive le... 0 2
63 im feeling all jolly and warm inside but i jus... 1 0
70 i feel very cheated since i am supporting the ... 1 0
75 i don t know about you but it makes me feel ge... 1 2
78 i suppose if one was feeling generous one coul... 1 0