Transfer Learning with PyTorch

In this post, we'll explore how to perform transfer learning using PyTorch.

We will use a subset of Food-11k that contains 11 different food categories. We will go over dataset preparation, data augmentation, and the steps to build the classifier. With transfer learning, we reuse the low-level image features (edges, textures, etc.) learnt by a pretrained model, ResNet50, which has already been trained on ImageNet with millions of images, and only train our classifier head to learn the higher-level details of our dataset images.

The original Food-11k dataset contains about 11k images in 11 food categories. Training on the whole dataset would take hours, so we are going to use a subset of it. The 11 food categories are Bread, Dairy product, Dessert, Egg, Fried food, Meat, Noodles, Rice, Seafood, Soup, and Vegetable. I've split the sub-dataset into train, valid, and test sets. The train set has 11 folders, one per food category, and each folder contains 100 images of that category. The valid and test sets follow the same structure, but with 20 and 40 images per category respectively.

So in total, we have 1100 training images, 220 validation images, and 440 test images across the 11 classes of food.

In [0]:
# import libraries
import torch
from torchvision import models, datasets
from torchvision import transforms
from torch import nn, optim
from torch.utils.data.dataloader import DataLoader

import time
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
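
Before moving on, it's worth a quick sanity check that the dataset follows the one-folder-per-class layout that torchvision's ImageFolder expects, with the class counts described above. A minimal sketch, assuming the subset was extracted to food-11k-sub/:

In [0]:
# sanity check: each split should contain one sub-folder per class
for split in ['train', 'valid', 'test']:
    split_dir = os.path.join('food-11k-sub', split)
    classes = sorted(os.listdir(split_dir))
    counts = [len(os.listdir(os.path.join(split_dir, c))) for c in classes]
    print(split, '-', len(classes), 'classes,', sum(counts), 'images')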

Data Augmentations

The images in the training set can be modified in a number of ways to introduce more variation into the training process, so that the trained model generalizes better and performs well on different kinds of test data. In addition, the input images come in a variety of sizes; they need to be normalized to a fixed size and format before batches of data can be used together for training.

Let us go over the transformations we used for our data augmentation.

The RandomResizedCrop transform crops the input image to a random size (within a scale range of 0.8 to 1.0 of the original image area, and a random aspect ratio in the default range of 0.75 to 1.33). The crop is then resized to 256×256.

RandomRotation rotates the image by an angle randomly chosen between -15 and 15 degrees.

RandomHorizontalFlip randomly flips the image horizontally with a default probability of 50%.

CenterCrop crops a 224×224 image from the center.

ToTensor converts the PIL Image, which has values in the range 0-255, to a floating point Tensor with values in the range 0-1, by dividing by 255.

Normalize takes in a 3-channel Tensor and normalizes each channel by the given per-channel mean and standard deviation, supplied as 3-element vectors. Each channel in the tensor is normalized as T = (T - mean) / std.

In [0]:
# applying transforms to the data
image_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(size=256, scale=(0.8,1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])
}

Note that for the validation and test data, we do not apply the RandomResizedCrop, RandomRotation, and RandomHorizontalFlip transformations, since those sets are only used to evaluate model performance; we simply resize, center-crop, and normalize them.
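
To sanity-check what the training pipeline actually produces, we can apply it to a single image, undo the normalization, and display the result. This is only an illustrative sketch; the image path below is a placeholder for any file in the training set.

In [0]:
# visualize one augmented training sample (illustrative; pick any image from the train set)
sample_path = 'food-11k-sub/train/Bread/0_0.jpg'  # placeholder path

img = Image.open(sample_path).convert('RGB')
augmented = image_transforms['train'](img)  # tensor of shape [3, 224, 224]

# undo the Normalize step so matplotlib can display the image
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
display_img = (augmented * std + mean).clamp(0, 1)

plt.imshow(display_img.permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()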

In [53]:
# Load data
# Set train, valid, and test directory
train_directory = 'food-11k-sub/train'
valid_directory = 'food-11k-sub/valid'
test_directory = 'food-11k-sub/test'

# batch size
bs = 32

# number of epochs
epochs = 20

# number of classes
num_classes = 11

# device 
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load data from directory
data = {
    'train': datasets.ImageFolder(root=train_directory,
                                  transform=image_transforms['train']),
    'valid': datasets.ImageFolder(root=valid_directory,
                                  transform=image_transforms['valid']),
    'test': datasets.ImageFolder(root=test_directory,
                                 transform=image_transforms['test'])

}

# Get a mapping of the indices to the class names, in order to see the output classes of the test images.
idx_to_class = {v: k for k, v in data['train'].class_to_idx.items()}
print(idx_to_class)

# size of data, to be used for calculating average loss and accuracy
train_data_size = len(data['train'])
valid_data_size = len(data['valid'])
test_data_size = len(data['test'])

# Create iterators for the data using the DataLoader module
train_data = DataLoader(data['train'], batch_size=bs, shuffle=True)
valid_data = DataLoader(data['valid'], batch_size=bs, shuffle=True)
test_data = DataLoader(data['test'], batch_size=bs, shuffle=True)

train_data_size, valid_data_size, test_data_size
{0: 'Bread', 1: 'Dairy product', 2: 'Dessert', 3: 'Egg', 4: 'Fried food', 5: 'Meat', 6: 'Noodles', 7: 'Rice', 8: 'Seafood', 9: 'Soup', 10: 'Vegetable'}
Out[53]:
(1100, 220, 440)

The torchvision.transforms package and the DataLoader are very important PyTorch features that make the data augmentation and loading process very easy.
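
As a quick check that everything is wired up correctly, we can pull a single batch from the training loader and inspect its shape; with the transforms above, each batch should be a [batch_size, 3, 224, 224] image tensor plus a vector of integer class labels. A minimal sketch:

In [0]:
# fetch one batch from the training loader and inspect it
inputs, labels = next(iter(train_data))
print(inputs.shape)   # expected: torch.Size([32, 3, 224, 224])
print(labels.shape)   # expected: torch.Size([32])
print(labels[:8])     # integer class indices; see idx_to_class above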

Transfer Learning

We are going to use ResNet50 as the base model. It offers a good trade-off between model size, inference speed, and prediction accuracy.

First we load the pretrained ResNet50. Then, since we are doing transfer learning, we freeze the parameters of its convolutional layers so that they act as a fixed feature extractor.

In [0]:
# load pretrained resnet50
resnet_50 = models.resnet50(pretrained=True)

# Freeze model parameters so that only the new classifier head is trained
for param in resnet_50.parameters():
  param.requires_grad = False

Then we replace the final layer of the ResNet50 model with a small stack of Sequential layers. The 2048 input features of ResNet50's original fully connected layer are fed into a Linear layer with 256 outputs, followed by ReLU and Dropout layers, and then by a 256×11 Linear layer whose 11 outputs correspond to the 11 classes.

In [0]:
# replace the final layer of the ResNet50 model with our classifier head
fc_inputs = resnet_50.fc.in_features

resnet_50.fc = nn.Sequential(
    nn.Linear(fc_inputs, 256),
    nn.ReLU(),
    nn.Dropout(0.4), 
    nn.Linear(256, 11),
    nn.LogSoftmax(dim=1) # for using NLLLoss()
)

# move model to the GPU (if available)
resnet_50 = resnet_50.to(device)

# define optimizer and loss function
loss_func = nn.NLLLoss()
optimizer = optim.Adam(resnet_50.parameters())
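
A quick aside on the loss choice: a LogSoftmax layer followed by NLLLoss is mathematically equivalent to feeding the raw logits (with no LogSoftmax layer) into CrossEntropyLoss, which applies log-softmax internally. Also, since the backbone is frozen, we could equivalently pass only resnet_50.fc.parameters() to the optimizer. A minimal sketch verifying the loss equivalence:

In [0]:
# LogSoftmax + NLLLoss on logits equals CrossEntropyLoss on the same logits
logits = torch.randn(4, 11)
targets = torch.randint(0, 11, (4,))
a = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
b = nn.CrossEntropyLoss()(logits, targets)
print(torch.allclose(a, b))  # True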

Start training

Now that the entire model has been set up, let's start training.

First, take a look at the model summary. Since we froze the layers of the base model, only the 527,371 parameters of the added layers are trainable.

In [56]:
from torchsummary import summary
summary(resnet_50, input_size=(3,224,224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256, 56, 56]             512
             ReLU-15          [-1, 256, 56, 56]               0
       Bottleneck-16          [-1, 256, 56, 56]               0
           Conv2d-17           [-1, 64, 56, 56]          16,384
      BatchNorm2d-18           [-1, 64, 56, 56]             128
             ReLU-19           [-1, 64, 56, 56]               0
           Conv2d-20           [-1, 64, 56, 56]          36,864
      BatchNorm2d-21           [-1, 64, 56, 56]             128
             ReLU-22           [-1, 64, 56, 56]               0
           Conv2d-23          [-1, 256, 56, 56]          16,384
      BatchNorm2d-24          [-1, 256, 56, 56]             512
             ReLU-25          [-1, 256, 56, 56]               0
       Bottleneck-26          [-1, 256, 56, 56]               0
           Conv2d-27           [-1, 64, 56, 56]          16,384
      BatchNorm2d-28           [-1, 64, 56, 56]             128
             ReLU-29           [-1, 64, 56, 56]               0
           Conv2d-30           [-1, 64, 56, 56]          36,864
      BatchNorm2d-31           [-1, 64, 56, 56]             128
             ReLU-32           [-1, 64, 56, 56]               0
           Conv2d-33          [-1, 256, 56, 56]          16,384
      BatchNorm2d-34          [-1, 256, 56, 56]             512
             ReLU-35          [-1, 256, 56, 56]               0
       Bottleneck-36          [-1, 256, 56, 56]               0
           Conv2d-37          [-1, 128, 56, 56]          32,768
      BatchNorm2d-38          [-1, 128, 56, 56]             256
             ReLU-39          [-1, 128, 56, 56]               0
           Conv2d-40          [-1, 128, 28, 28]         147,456
      BatchNorm2d-41          [-1, 128, 28, 28]             256
             ReLU-42          [-1, 128, 28, 28]               0
           Conv2d-43          [-1, 512, 28, 28]          65,536
      BatchNorm2d-44          [-1, 512, 28, 28]           1,024
           Conv2d-45          [-1, 512, 28, 28]         131,072
      BatchNorm2d-46          [-1, 512, 28, 28]           1,024
             ReLU-47          [-1, 512, 28, 28]               0
       Bottleneck-48          [-1, 512, 28, 28]               0
           Conv2d-49          [-1, 128, 28, 28]          65,536
      BatchNorm2d-50          [-1, 128, 28, 28]             256
             ReLU-51          [-1, 128, 28, 28]               0
           Conv2d-52          [-1, 128, 28, 28]         147,456
      BatchNorm2d-53          [-1, 128, 28, 28]             256
             ReLU-54          [-1, 128, 28, 28]               0
           Conv2d-55          [-1, 512, 28, 28]          65,536
      BatchNorm2d-56          [-1, 512, 28, 28]           1,024
             ReLU-57          [-1, 512, 28, 28]               0
       Bottleneck-58          [-1, 512, 28, 28]               0
           Conv2d-59          [-1, 128, 28, 28]          65,536
      BatchNorm2d-60          [-1, 128, 28, 28]             256
             ReLU-61          [-1, 128, 28, 28]               0
           Conv2d-62          [-1, 128, 28, 28]         147,456
      BatchNorm2d-63          [-1, 128, 28, 28]             256
             ReLU-64          [-1, 128, 28, 28]               0
           Conv2d-65          [-1, 512, 28, 28]          65,536
      BatchNorm2d-66          [-1, 512, 28, 28]           1,024
             ReLU-67          [-1, 512, 28, 28]               0
       Bottleneck-68          [-1, 512, 28, 28]               0
           Conv2d-69          [-1, 128, 28, 28]          65,536
      BatchNorm2d-70          [-1, 128, 28, 28]             256
             ReLU-71          [-1, 128, 28, 28]               0
           Conv2d-72          [-1, 128, 28, 28]         147,456
      BatchNorm2d-73          [-1, 128, 28, 28]             256
             ReLU-74          [-1, 128, 28, 28]               0
           Conv2d-75          [-1, 512, 28, 28]          65,536
      BatchNorm2d-76          [-1, 512, 28, 28]           1,024
             ReLU-77          [-1, 512, 28, 28]               0
       Bottleneck-78          [-1, 512, 28, 28]               0
           Conv2d-79          [-1, 256, 28, 28]         131,072
      BatchNorm2d-80          [-1, 256, 28, 28]             512
             ReLU-81          [-1, 256, 28, 28]               0
           Conv2d-82          [-1, 256, 14, 14]         589,824
      BatchNorm2d-83          [-1, 256, 14, 14]             512
             ReLU-84          [-1, 256, 14, 14]               0
           Conv2d-85         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-86         [-1, 1024, 14, 14]           2,048
           Conv2d-87         [-1, 1024, 14, 14]         524,288
      BatchNorm2d-88         [-1, 1024, 14, 14]           2,048
             ReLU-89         [-1, 1024, 14, 14]               0
       Bottleneck-90         [-1, 1024, 14, 14]               0
           Conv2d-91          [-1, 256, 14, 14]         262,144
      BatchNorm2d-92          [-1, 256, 14, 14]             512
             ReLU-93          [-1, 256, 14, 14]               0
           Conv2d-94          [-1, 256, 14, 14]         589,824
      BatchNorm2d-95          [-1, 256, 14, 14]             512
             ReLU-96          [-1, 256, 14, 14]               0
           Conv2d-97         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-98         [-1, 1024, 14, 14]           2,048
             ReLU-99         [-1, 1024, 14, 14]               0
      Bottleneck-100         [-1, 1024, 14, 14]               0
          Conv2d-101          [-1, 256, 14, 14]         262,144
     BatchNorm2d-102          [-1, 256, 14, 14]             512
            ReLU-103          [-1, 256, 14, 14]               0
          Conv2d-104          [-1, 256, 14, 14]         589,824
     BatchNorm2d-105          [-1, 256, 14, 14]             512
            ReLU-106          [-1, 256, 14, 14]               0
          Conv2d-107         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-108         [-1, 1024, 14, 14]           2,048
            ReLU-109         [-1, 1024, 14, 14]               0
      Bottleneck-110         [-1, 1024, 14, 14]               0
          Conv2d-111          [-1, 256, 14, 14]         262,144
     BatchNorm2d-112          [-1, 256, 14, 14]             512
            ReLU-113          [-1, 256, 14, 14]               0
          Conv2d-114          [-1, 256, 14, 14]         589,824
     BatchNorm2d-115          [-1, 256, 14, 14]             512
            ReLU-116          [-1, 256, 14, 14]               0
          Conv2d-117         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-118         [-1, 1024, 14, 14]           2,048
            ReLU-119         [-1, 1024, 14, 14]               0
      Bottleneck-120         [-1, 1024, 14, 14]               0
          Conv2d-121          [-1, 256, 14, 14]         262,144
     BatchNorm2d-122          [-1, 256, 14, 14]             512
            ReLU-123          [-1, 256, 14, 14]               0
          Conv2d-124          [-1, 256, 14, 14]         589,824
     BatchNorm2d-125          [-1, 256, 14, 14]             512
            ReLU-126          [-1, 256, 14, 14]               0
          Conv2d-127         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-128         [-1, 1024, 14, 14]           2,048
            ReLU-129         [-1, 1024, 14, 14]               0
      Bottleneck-130         [-1, 1024, 14, 14]               0
          Conv2d-131          [-1, 256, 14, 14]         262,144
     BatchNorm2d-132          [-1, 256, 14, 14]             512
            ReLU-133          [-1, 256, 14, 14]               0
          Conv2d-134          [-1, 256, 14, 14]         589,824
     BatchNorm2d-135          [-1, 256, 14, 14]             512
            ReLU-136          [-1, 256, 14, 14]               0
          Conv2d-137         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-138         [-1, 1024, 14, 14]           2,048
            ReLU-139         [-1, 1024, 14, 14]               0
      Bottleneck-140         [-1, 1024, 14, 14]               0
          Conv2d-141          [-1, 512, 14, 14]         524,288
     BatchNorm2d-142          [-1, 512, 14, 14]           1,024
            ReLU-143          [-1, 512, 14, 14]               0
          Conv2d-144            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-145            [-1, 512, 7, 7]           1,024
            ReLU-146            [-1, 512, 7, 7]               0
          Conv2d-147           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-148           [-1, 2048, 7, 7]           4,096
          Conv2d-149           [-1, 2048, 7, 7]       2,097,152
     BatchNorm2d-150           [-1, 2048, 7, 7]           4,096
            ReLU-151           [-1, 2048, 7, 7]               0
      Bottleneck-152           [-1, 2048, 7, 7]               0
          Conv2d-153            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-154            [-1, 512, 7, 7]           1,024
            ReLU-155            [-1, 512, 7, 7]               0
          Conv2d-156            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-157            [-1, 512, 7, 7]           1,024
            ReLU-158            [-1, 512, 7, 7]               0
          Conv2d-159           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-160           [-1, 2048, 7, 7]           4,096
            ReLU-161           [-1, 2048, 7, 7]               0
      Bottleneck-162           [-1, 2048, 7, 7]               0
          Conv2d-163            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-164            [-1, 512, 7, 7]           1,024
            ReLU-165            [-1, 512, 7, 7]               0
          Conv2d-166            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-167            [-1, 512, 7, 7]           1,024
            ReLU-168            [-1, 512, 7, 7]               0
          Conv2d-169           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-170           [-1, 2048, 7, 7]           4,096
            ReLU-171           [-1, 2048, 7, 7]               0
      Bottleneck-172           [-1, 2048, 7, 7]               0
AdaptiveAvgPool2d-173           [-1, 2048, 1, 1]               0
          Linear-174                  [-1, 256]         524,544
            ReLU-175                  [-1, 256]               0
         Dropout-176                  [-1, 256]               0
          Linear-177                   [-1, 11]           2,827
      LogSoftmax-178                   [-1, 11]               0
================================================================
Total params: 24,035,403
Trainable params: 527,371
Non-trainable params: 23,508,032
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 286.56
Params size (MB): 91.69
Estimated Total Size (MB): 378.82
----------------------------------------------------------------
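
The trainable count can also be confirmed directly, without torchsummary, by summing the parameters that still require gradients. A minimal sketch:

In [0]:
# count trainable vs. frozen parameters directly
trainable = sum(p.numel() for p in resnet_50.parameters() if p.requires_grad)
total = sum(p.numel() for p in resnet_50.parameters())
print(f"Trainable: {trainable:,} / Total: {total:,}")  # should match the summary above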
In [0]:
def train_and_validate(model, loss_criterion, optimizer, epochs=25):
    '''
    Function to train and validate
    Parameters
        :param model: Model to train and validate
        :param loss_criterion: Loss Criterion to minimize
        :param optimizer: Optimizer for computing gradients
        :param epochs: Number of epochs (default=25)
  
    Returns
        model: Trained model (after the final epoch)
        history: list of per-epoch [avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc]
    '''
    
    start = time.time()
    history = []
    best_acc = 0.0

    for epoch in range(epochs):
        epoch_start = time.time()
        print("Epoch: {}/{}".format(epoch+1, epochs))
        
        # Set to training mode
        model.train()
        
        # Loss and Accuracy within the epoch
        train_loss = 0.0
        train_acc = 0.0
        
        valid_loss = 0.0
        valid_acc = 0.0
        
        for i, (inputs, labels) in enumerate(train_data):

            inputs = inputs.to(device)
            labels = labels.to(device)
            
            # Clean existing gradients
            optimizer.zero_grad()
            
            # Forward pass - compute outputs on input data using the model
            outputs = model(inputs)
            
            # Compute loss
            loss = loss_criterion(outputs, labels)
            
            # Backpropagate the gradients
            loss.backward()
            
            # Update the parameters
            optimizer.step()
            
            # Compute the total loss for the batch and add it to train_loss
            train_loss += loss.item() * inputs.size(0)
            
            # Compute the accuracy
            ret, predictions = torch.max(outputs.data, 1)
            correct_counts = predictions.eq(labels.data.view_as(predictions))
            
            # Convert correct_counts to float and then compute the mean
            acc = torch.mean(correct_counts.type(torch.FloatTensor))
            
            # Compute total accuracy in the whole batch and add to train_acc
            train_acc += acc.item() * inputs.size(0)
            
            #print("Batch number: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}".format(i, loss.item(), acc.item()))

            
        # Validation - No gradient tracking needed
        with torch.no_grad():

            # Set to evaluation mode
            model.eval()

            # Validation loop
            for j, (inputs, labels) in enumerate(valid_data):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Forward pass - compute outputs on input data using the model
                outputs = model(inputs)

                # Compute loss
                loss = loss_criterion(outputs, labels)

                # Compute the total loss for the batch and add it to valid_loss
                valid_loss += loss.item() * inputs.size(0)

                # Calculate validation accuracy
                ret, predictions = torch.max(outputs.data, 1)
                correct_counts = predictions.eq(labels.data.view_as(predictions))

                # Convert correct_counts to float and then compute the mean
                acc = torch.mean(correct_counts.type(torch.FloatTensor))

                # Compute total accuracy in the whole batch and add to valid_acc
                valid_acc += acc.item() * inputs.size(0)

                #print("Validation Batch number: {:03d}, Validation: Loss: {:.4f}, Accuracy: {:.4f}".format(j, loss.item(), acc.item()))
            
        # Find average training loss and training accuracy
        avg_train_loss = train_loss/train_data_size 
        avg_train_acc = train_acc/train_data_size

        # Find average validation loss and validation accuracy
        avg_valid_loss = valid_loss/valid_data_size 
        avg_valid_acc = valid_acc/valid_data_size

        history.append([avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc])
                
        epoch_end = time.time()
    
        print("Epoch : {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\tValidation : Loss : {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format(epoch+1, avg_train_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))
        
        # Optionally save a checkpoint after each epoch
        # torch.save(model, 'model_'+str(epoch)+'.pt')
            
    return model, history
In [58]:
num_epochs = 25
trained_model, history = train_and_validate(resnet_50, loss_func, optimizer, num_epochs)
torch.save(history, 'history.pt')
Epoch: 1/25
Epoch : 001, Training: Loss: 2.0360, Accuracy: 32.7273%, 
		Validation : Loss : 1.3534, Accuracy: 65.0000%, Time: 15.9276s
Epoch: 2/25
Epoch : 002, Training: Loss: 1.3181, Accuracy: 59.7273%, 
		Validation : Loss : 0.9446, Accuracy: 73.6364%, Time: 15.8867s
Epoch: 3/25
Epoch : 003, Training: Loss: 1.0782, Accuracy: 65.2727%, 
		Validation : Loss : 0.8406, Accuracy: 76.8182%, Time: 15.8777s
Epoch: 4/25
Epoch : 004, Training: Loss: 0.9117, Accuracy: 72.0909%, 
		Validation : Loss : 0.7171, Accuracy: 79.0909%, Time: 16.2870s
Epoch: 5/25
Epoch : 005, Training: Loss: 0.8877, Accuracy: 72.6364%, 
		Validation : Loss : 0.7464, Accuracy: 77.7273%, Time: 15.8715s
Epoch: 6/25
Epoch : 006, Training: Loss: 0.8385, Accuracy: 72.9091%, 
		Validation : Loss : 0.7248, Accuracy: 75.9091%, Time: 15.8559s
Epoch: 7/25
Epoch : 007, Training: Loss: 0.7604, Accuracy: 74.7273%, 
		Validation : Loss : 0.6933, Accuracy: 79.5455%, Time: 15.8478s
Epoch: 8/25
Epoch : 008, Training: Loss: 0.7379, Accuracy: 76.8182%, 
		Validation : Loss : 0.5714, Accuracy: 85.0000%, Time: 15.9050s
Epoch: 9/25
Epoch : 009, Training: Loss: 0.6839, Accuracy: 78.4545%, 
		Validation : Loss : 0.5900, Accuracy: 83.1818%, Time: 15.9014s
Epoch: 10/25
Epoch : 010, Training: Loss: 0.6969, Accuracy: 76.5455%, 
		Validation : Loss : 0.6262, Accuracy: 80.0000%, Time: 15.9232s
Epoch: 11/25
Epoch : 011, Training: Loss: 0.5962, Accuracy: 80.0909%, 
		Validation : Loss : 0.5857, Accuracy: 79.5455%, Time: 15.8562s
Epoch: 12/25
Epoch : 012, Training: Loss: 0.5672, Accuracy: 81.7273%, 
		Validation : Loss : 0.5673, Accuracy: 80.4545%, Time: 15.8696s
Epoch: 13/25
Epoch : 013, Training: Loss: 0.6247, Accuracy: 80.5455%, 
		Validation : Loss : 0.5577, Accuracy: 81.3636%, Time: 15.8772s
Epoch: 14/25
Epoch : 014, Training: Loss: 0.5686, Accuracy: 81.8182%, 
		Validation : Loss : 0.5743, Accuracy: 80.9091%, Time: 15.8630s
Epoch: 15/25
Epoch : 015, Training: Loss: 0.5647, Accuracy: 80.7273%, 
		Validation : Loss : 0.5236, Accuracy: 84.0909%, Time: 15.9140s
Epoch: 16/25
Epoch : 016, Training: Loss: 0.5239, Accuracy: 82.0000%, 
		Validation : Loss : 0.6071, Accuracy: 79.5455%, Time: 15.9032s
Epoch: 17/25
Epoch : 017, Training: Loss: 0.5532, Accuracy: 81.3636%, 
		Validation : Loss : 0.5373, Accuracy: 84.0909%, Time: 15.8482s
Epoch: 18/25
Epoch : 018, Training: Loss: 0.5236, Accuracy: 82.9091%, 
		Validation : Loss : 0.5534, Accuracy: 82.7273%, Time: 15.8996s
Epoch: 19/25
Epoch : 019, Training: Loss: 0.4936, Accuracy: 83.0000%, 
		Validation : Loss : 0.5223, Accuracy: 82.2727%, Time: 16.0922s
Epoch: 20/25
Epoch : 020, Training: Loss: 0.5198, Accuracy: 82.2727%, 
		Validation : Loss : 0.4858, Accuracy: 84.5455%, Time: 15.8731s
Epoch: 21/25
Epoch : 021, Training: Loss: 0.4850, Accuracy: 82.6364%, 
		Validation : Loss : 0.6114, Accuracy: 78.6364%, Time: 15.8514s
Epoch: 22/25
Epoch : 022, Training: Loss: 0.5058, Accuracy: 83.3636%, 
		Validation : Loss : 0.5244, Accuracy: 82.7273%, Time: 15.9262s
Epoch: 23/25
Epoch : 023, Training: Loss: 0.4296, Accuracy: 86.1818%, 
		Validation : Loss : 0.6065, Accuracy: 78.1818%, Time: 15.9853s
Epoch: 24/25
Epoch : 024, Training: Loss: 0.4808, Accuracy: 83.4545%, 
		Validation : Loss : 0.5916, Accuracy: 81.8182%, Time: 16.1919s
Epoch: 25/25
Epoch : 025, Training: Loss: 0.5062, Accuracy: 82.8182%, 
		Validation : Loss : 0.5311, Accuracy: 80.0000%, Time: 15.8669s
In [0]:
torch.save(trained_model,'trained_model.pt')
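
Saving the whole model object works, but it pickles the full class structure and is sensitive to later code changes. A commonly preferred alternative is to save only the state_dict and rebuild the architecture at load time; a minimal sketch (the weights filename here is just an example):

In [0]:
# alternative: save only the weights rather than the whole pickled model
torch.save(trained_model.state_dict(), 'trained_model_weights.pt')

# to load later, rebuild the same architecture and restore the weights:
# model = models.resnet50(pretrained=False)
# model.fc = nn.Sequential(nn.Linear(model.fc.in_features, 256), nn.ReLU(),
#                          nn.Dropout(0.4), nn.Linear(256, 11), nn.LogSoftmax(dim=1))
# model.load_state_dict(torch.load('trained_model_weights.pt'))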
In [61]:
history = np.array(history)
plt.plot(history[:,0:2])
plt.legend(['Tr Loss', 'Val Loss'])
plt.xlabel('Epoch Number')
plt.ylabel('Loss')
plt.ylim(0,1)
plt.savefig('loss_curve.png')
plt.show()
In [62]:
plt.plot(history[:,2:4])
plt.legend(['Tr Accuracy', 'Val Accuracy'])
plt.xlabel('Epoch Number')
plt.ylabel('Accuracy')
plt.ylim(0,1)
plt.savefig('_accuracy_curve.png')
plt.show()

We achieved over 80% accuracy on the validation set. The result is acceptable for such a small dataset. If we trained on the entire dataset, the accuracy would likely be even better.

In the next post, I will try fine-tuning the same model on the same dataset with Keras, so we can see how the implementations differ. Stay tuned ~.

Test set accuracy

In [0]:
def computeTestSetAccuracy(model, loss_criterion):
    '''
    Function to compute the accuracy on the test set
    Parameters
        :param model: Model to test
        :param loss_criterion: Loss Criterion to minimize
    '''

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    test_acc = 0.0
    test_loss = 0.0

    # Test - no gradient tracking needed
    with torch.no_grad():

        # Set to evaluation mode
        model.eval()

        # Test loop
        for j, (inputs, labels) in enumerate(test_data):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass - compute outputs on input data using the model
            outputs = model(inputs)

            # Compute loss
            loss = loss_criterion(outputs, labels)

            # Compute the total loss for the batch and add it to test_loss
            test_loss += loss.item() * inputs.size(0)

            # Calculate test accuracy
            ret, predictions = torch.max(outputs.data, 1)
            correct_counts = predictions.eq(labels.data.view_as(predictions))

            # Convert correct_counts to float and then compute the mean
            acc = torch.mean(correct_counts.type(torch.FloatTensor))

            # Compute total accuracy in the whole batch and add to test_acc
            test_acc += acc.item() * inputs.size(0)

            print("Test Batch number: {:03d}, Test: Loss: {:.4f}, Accuracy: {:.4f}".format(j, loss.item(), acc.item()))

    # Find average test loss and test accuracy
    avg_test_loss = test_loss/test_data_size 
    avg_test_acc = test_acc/test_data_size

    print("Test accuracy : " + str(avg_test_acc))
In [65]:
computeTestSetAccuracy(trained_model, loss_func)
Test Batch number: 000, Test: Loss: 0.5162, Accuracy: 0.7812
Test Batch number: 001, Test: Loss: 0.5739, Accuracy: 0.7812
Test Batch number: 002, Test: Loss: 0.7250, Accuracy: 0.8125
Test Batch number: 003, Test: Loss: 0.7017, Accuracy: 0.7812
Test Batch number: 004, Test: Loss: 0.6244, Accuracy: 0.8125
Test Batch number: 005, Test: Loss: 0.7207, Accuracy: 0.7812
Test Batch number: 006, Test: Loss: 0.7719, Accuracy: 0.8125
Test Batch number: 007, Test: Loss: 0.8399, Accuracy: 0.6562
Test Batch number: 008, Test: Loss: 0.3801, Accuracy: 0.8750
Test Batch number: 009, Test: Loss: 0.9372, Accuracy: 0.7812
Test Batch number: 010, Test: Loss: 0.9332, Accuracy: 0.6875
Test Batch number: 011, Test: Loss: 0.4517, Accuracy: 0.8438
Test Batch number: 012, Test: Loss: 0.7051, Accuracy: 0.7812
Test Batch number: 013, Test: Loss: 0.6603, Accuracy: 0.8333
Test accuracy : 0.7863636352799156

Predict on test images

In [0]:
def predict(model, test_image_name):
    '''
    Function to predict the class of a single test image
    Parameters
        :param model: Model to test
        :param test_image_name: Test image

    '''
    
    transform = image_transforms['test']

    test_image = Image.open(test_image_name).convert('RGB')
    plt.imshow(test_image)
    
    test_image_tensor = transform(test_image)

    if torch.cuda.is_available():
        test_image_tensor = test_image_tensor.view(1, 3, 224, 224).cuda()
    else:
        test_image_tensor = test_image_tensor.view(1, 3, 224, 224)
    
    with torch.no_grad():
        model.eval()
        # Model outputs log probabilities
        out = model(test_image_tensor)
        ps = torch.exp(out)
        topk, topclass = ps.topk(3, dim=1)
        for i in range(3):
            print("Predcition", i+1, ":", idx_to_class[topclass.cpu().numpy()[0][i]], ", Score: ", topk.cpu().numpy()[0][i])
In [67]:
model = torch.load('trained_model.pt')
predict(model, 'food-11k-sub/test/Egg/3_22.jpg')
Prediction 1 : Egg , Score:  0.947523
Prediction 2 : Bread , Score:  0.04919945
Prediction 3 : Meat , Score:  0.0022846945

Summary

The above results and test accuracy are pretty good, I would say, considering we are using a fairly small dataset.

In this post, we've tried transfer learning on a small dataset using PyTorch with the pre-trained ResNet50 model. In a future post, I will try to do the same thing using Keras. Stay tuned.


