Semantic Segmentation using FCN and DeepLabV3

Semantic segmentation is an image analysis task in which we classify each pixel in the image into a class.

In this post, we will perform semantic segmentation using two pre-trained models available in PyTorch's torchvision library: FCN and DeepLabV3.

Understanding model inputs and outputs

Before we start, let's look at the inputs and outputs of these semantic segmentation models.

These models expect a 3-channel RGB image whose pixel values are scaled to [0, 1] and then normalized with the ImageNet mean and standard deviation, i.e.,
mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]
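In code terms, `T.ToTensor()` divides 8-bit pixel values by 255 and moves the channels first, and `T.Normalize` then computes (x - mean) / std for each channel. The NumPy sketch below reproduces that arithmetic, assuming an H x W x 3 uint8 RGB array as input; `normalize_imagenet` is a hypothetical helper, not a torchvision function.

```python
import numpy as np

# ImageNet channel statistics (RGB order), as listed above
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize_imagenet(img_uint8):
    """Hypothetical helper: H x W x 3 uint8 image -> 3 x H x W float32 array."""
    x = img_uint8.astype(np.float32) / 255.0   # scale to [0, 1], like T.ToTensor
    x = (x - MEAN) / STD                        # per-channel normalization, like T.Normalize
    return x.transpose(2, 0, 1)                 # H x W x C -> C x H x W

# A 2 x 2 image whose pixels all equal the (rounded) ImageNet mean color
img = np.tile((MEAN * 255).round().astype(np.uint8), (2, 2, 1))
out = normalize_imagenet(img)
print(out.shape)  # (3, 2, 2)
```

A pixel equal to the mean color maps to values close to 0, which is the point of the normalization: the network sees inputs centered the same way as its ImageNet training data.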

So, the input is [Ni x Ci x Hi x Wi]
where,

  • Ni -> the batch size
  • Ci -> the number of channels (which is 3)
  • Hi -> the height of the image
  • Wi -> the width of the image

And the output of the model is [No x Co x Ho x Wo]
where,

  • No -> the batch size (same as Ni)
  • Co -> the number of classes in the dataset, so each class gets its own score map in the output
  • Ho -> the height of the image (the same as Hi, since the models upsample their prediction back to the input size)
  • Wo -> the width of the image (the same as Wi, for the same reason)

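Note that the torchvision segmentation models return an OrderedDict, not a torch.Tensor. The final prediction is stored under the key out, so to get the output we take the value stored under that key. The pretrained models also include an auxiliary classifier whose prediction is stored under the key aux; it is used as an extra loss during training and can be ignored at inference time.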
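To make the structure concrete, here is a framework-free mock, assuming NumPy arrays in place of tensors: the model returns an `OrderedDict` whose `'out'` entry holds one score map per class, with shape [N x C x H x W]. `fake_model_output` is hypothetical and only imitates the shapes; only the `'out'` entry is modeled.

```python
from collections import OrderedDict
import numpy as np

def fake_model_output(inp, num_classes=21):
    """Hypothetical stand-in for a segmentation model's forward pass (shapes only)."""
    n, _, h, w = inp.shape
    result = OrderedDict()
    # One score map per class, at the same spatial size as the input
    result['out'] = np.random.randn(n, num_classes, h, w).astype(np.float32)
    return result

inp = np.zeros((1, 3, 224, 224), dtype=np.float32)  # [Ni x Ci x Hi x Wi]
output = fake_model_output(inp)
out = output['out']                                  # [No x Co x Ho x Wo]
print(out.shape)  # (1, 21, 224, 224)
```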

FCN with Resnet-101 backbone

FCN (Fully Convolutional Networks) is one of the earliest neural network architectures designed for semantic segmentation.

Let's load one up!

In [0]:
from torchvision import models
fcn = models.segmentation.fcn_resnet101(pretrained=True).eval()
Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /root/.cache/torch/checkpoints/resnet101-5d3b4d8f.pth

Downloading: "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth" to /root/.cache/torch/checkpoints/fcn_resnet101_coco-7ecb50ca.pth

And that's it: we have a pretrained FCN model with a ResNet-101 backbone :)

Now, let's get an image!

In [0]:
from PIL import Image
import matplotlib.pyplot as plt
import torch

!wget -nv https://www.goodfreephotos.com/cache/other-photos/car-and-traffic-on-the-road-coming-towards-me.jpg -O car.png
img = Image.open('./car.png')
plt.imshow(img)
plt.axis('off')
plt.show()
2020-06-14 05:56:51 URL:https://www.goodfreephotos.com/cache/other-photos/car-and-traffic-on-the-road-coming-towards-me_800.jpg?cached=1522560655 [409997/409997] -> "car.png" [1]

Now that we have the image, we need to preprocess and normalize it. The preprocessing follows the standard ImageNet pipeline: resize, center-crop, convert to a tensor, and normalize.

In [0]:
# Apply the transformations needed
import torchvision.transforms as T
trf = T.Compose([T.Resize(256),
                 T.CenterCrop(224),
                 T.ToTensor(), 
                 T.Normalize(mean = [0.485, 0.456, 0.406], 
                             std = [0.229, 0.224, 0.225])])
inp = trf(img)
print(inp.shape)
inp = inp.unsqueeze(0)
print(inp.shape)
torch.Size([3, 224, 224])
torch.Size([1, 3, 224, 224])
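The shapes above follow from two steps: `T.Resize(256)` scales the shorter side of the image to 256 pixels while keeping the aspect ratio, and `T.CenterCrop(224)` cuts a 224 x 224 window from the middle; `unsqueeze(0)` then adds the batch dimension. The plain-Python sketch below reproduces the size arithmetic, assuming torchvision truncates the longer side to an integer and rounds the crop offsets; the helper names and the 800 x 533 image size are hypothetical.

```python
def resize_shorter_side(w, h, size):
    """(width, height) after resizing so the shorter side equals `size`,
    keeping the aspect ratio (longer side truncated to an int)."""
    if w <= h:
        return size, int(size * h / w)
    return int(size * w / h), size

def center_crop_box(w, h, crop):
    """(left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((w - crop) / 2.0))
    top = int(round((h - crop) / 2.0))
    return left, top, left + crop, top + crop

# Hypothetical 800 x 533 landscape photo
w, h = resize_shorter_side(800, 533, 256)
print(w, h)                        # 384 256
print(center_crop_box(w, h, 224))  # (80, 16, 304, 240)
```

This also shows why the crop discards content: everything outside the central 224 x 224 window of the resized image is dropped before the model sees it.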

Now we pass the input tensor through the network. The model returns an OrderedDict, so we take the value stored under the out key to get the output of the model.

In [0]:
# Pass the input through the net; no_grad skips gradient bookkeeping during inference
with torch.no_grad():
  out = fcn(inp)['out']
print (out.shape)
torch.Size([1, 21, 224, 224])

Alright! So out is the final output of the model, and as we can see, its shape is [1 x 21 x H x W], as discussed earlier. The model was trained on 21 classes, so our output has 21 channels!

The pretrained torchvision segmentation models were trained on a subset of COCO, using only the 20 categories that are also present in the PASCAL VOC dataset, plus 1 background class. That gives 21 classes in total. Now we need to collapse the 21 channels into a single 2D map that holds the predicted class label for each pixel, by taking the argmax over the class dimension.
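Collapsing the channels amounts to picking, for each pixel, the index of the class with the highest score. A small NumPy illustration with made-up scores for 3 classes on a 2 x 2 image (shape [C x H x W], i.e. the model output after removing the batch dimension):

```python
import numpy as np

# Hypothetical scores: one 2 x 2 map per class
scores = np.array([
    [[2.0, 0.1], [0.3, 0.0]],   # class 0 (background)
    [[0.5, 3.0], [0.2, 0.1]],   # class 1
    [[0.1, 0.2], [1.5, 0.9]],   # class 2
])
# For every pixel, keep the index of the channel with the highest score
label_map = np.argmax(scores, axis=0)
print(label_map.tolist())  # [[0, 1], [2, 2]]
```

`torch.argmax(out.squeeze(), dim=0)` in the next cell performs the same reduction on the 21-channel tensor.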

In [0]:
import numpy as np
om = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy()
print (om.shape)
print (np.unique(om))
(224, 224)
[0 7]
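The values printed by `np.unique(om)` are class indices. A minimal sketch for turning them into names, assuming the standard PASCAL VOC label order (the same order as the color table in `decode_segmap`); `present_classes` is a hypothetical helper.

```python
import numpy as np

# PASCAL VOC label order: index 0 is background, 1-20 are the object classes
VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
               'bus', 'car', 'cat', 'chair', 'cow', 'dining table', 'dog',
               'horse', 'motorbike', 'person', 'potted plant', 'sheep',
               'sofa', 'train', 'tv/monitor']

def present_classes(label_map):
    """Map each class index found in a 2D label map to its name and pixel count."""
    labels, counts = np.unique(label_map, return_counts=True)
    return {VOC_CLASSES[l]: int(c) for l, c in zip(labels, counts)}

# Tiny hypothetical label map containing background (0) and car (7)
print(present_classes(np.array([[0, 7], [7, 7]])))  # {'background': 1, 'car': 3}
```

Applied to `om`, this would report the pixel counts for background and car, the two labels printed above.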

As we can see, we now have a 2D image in which each pixel holds a class label (here 0 = background and 7 = car). The last step is to convert this label map into an RGB segmentation map, where each class label is mapped to a color, which makes the result easy to visualize.

We will use the following function to convert this 2D image to an RGB image where each label is mapped to its corresponding color.

In [0]:
# Define the helper function
def decode_segmap(image, nc=21):
  
  label_colors = np.array([(0, 0, 0),  # 0=background
               # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle
               (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
               # 6=bus, 7=car, 8=cat, 9=chair, 10=cow
               (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
               # 11=dining table, 12=dog, 13=horse, 14=motorbike, 15=person
               (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
               # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor
               (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)])

  r = np.zeros_like(image).astype(np.uint8)
  g = np.zeros_like(image).astype(np.uint8)
  b = np.zeros_like(image).astype(np.uint8)
  
  for l in range(0, nc):
    idx = image == l
    r[idx] = label_colors[l, 0]
    g[idx] = label_colors[l, 1]
    b[idx] = label_colors[l, 2]
    
  rgb = np.stack([r, g, b], axis=2)
  return rgb
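As an alternative to the per-class loop in `decode_segmap`, NumPy fancy indexing can do the whole lookup in one step: indexing an array of colors with shape [nc x 3] by the 2D label map returns an H x W x 3 image. A sketch, assuming every label is a valid row index into the color table; `decode_segmap_vectorized` is a hypothetical name.

```python
import numpy as np

def decode_segmap_vectorized(image, label_colors):
    """Map an H x W array of class indices to an H x W x 3 uint8 RGB image.

    Indexing the (num_classes x 3) color table with the label map selects
    one color row per pixel, replacing the loop over classes."""
    return np.asarray(label_colors, dtype=np.uint8)[image]

# Hypothetical 3-class palette: background, dark red, dark green
palette = [(0, 0, 0), (128, 0, 0), (0, 128, 0)]
labels = np.array([[0, 1], [2, 1]])
rgb = decode_segmap_vectorized(labels, palette)
print(rgb.shape)            # (2, 2, 3)
print(rgb[0, 1].tolist())   # [128, 0, 0]
```

Passing the same 21-color table used in `decode_segmap` should give the same RGB image as the loop version.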
In [0]:
rgb = decode_segmap(om)
plt.imshow(rgb); plt.show()

Since we are doing semantic segmentation, the two cars in the middle are assigned to the same class, car; they are not separated into individual instances.

Also note that the segmented image is smaller than the original, because the preprocessing step resized and center-cropped the image to 224 x 224.

Next, let's wrap all of this in one function and try it on a few more images!

In [0]:
def segment(net, path, show_orig=True, dev='cuda'):
  # convert('RGB') guarantees 3 channels (e.g. for PNGs with an alpha channel)
  img = Image.open(path).convert('RGB')
  if show_orig: plt.imshow(img); plt.axis('off'); plt.show()
  # Resize the shorter side to 640 and skip CenterCrop for better inference results
  trf = T.Compose([T.Resize(640), 
                   #T.CenterCrop(224), 
                   T.ToTensor(), 
                   T.Normalize(mean = [0.485, 0.456, 0.406], 
                               std = [0.229, 0.224, 0.225])])
  inp = trf(img).unsqueeze(0).to(dev)
  # Inference only: no_grad avoids storing activations for backprop
  with torch.no_grad():
    out = net.to(dev)(inp)['out']
  # Per-pixel class index = channel with the highest score
  om = torch.argmax(out.squeeze(), dim=0).cpu().numpy()
  rgb = decode_segmap(om)
  plt.imshow(rgb); plt.axis('off'); plt.show()

And let's get a new image!

In [0]:
!wget -nv https://upload.wikimedia.org/wikipedia/commons/thumb/f/f4/Sussex_cow_4.JPG/1200px-Sussex_cow_4.JPG -O cow.jpg
segment(fcn, './cow.jpg')
2020-06-14 06:01:04 URL:https://upload.wikimedia.org/wikipedia/commons/thumb/f/f4/Sussex_cow_4.JPG/1200px-Sussex_cow_4.JPG [223575/223575] -> "cow.jpg" [1]
In [0]:
!wget -nv https://storage.needpix.com/rsynced_images/pedestrian-zone-456909_1280.jpg -O pedestrian.jpg
segment(fcn, './pedestrian.jpg')
2020-06-14 06:03:03 URL:https://storage.needpix.com/rsynced_images/pedestrian-zone-456909_1280.jpg [409534/409534] -> "pedestrian.jpg" [1]

DeepLabv3 with Resnet-101 backbone

In [0]:
dlab = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()
Downloading: "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth" to /root/.cache/torch/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth

Alright! Now we have a stronger segmentation model.
Let's see how it performs on the same image!

In [0]:
segment(dlab, './car.png')

Wow. It looks like DeepLabv3 is able to find the third car, which is very dark in the original picture.

Let's see how it performs on other pictures.

In [0]:
segment(dlab, './cow.jpg')
In [0]:
segment(dlab, './pedestrian.jpg')

Comparison

Now let's see how these two models compare on 3 metrics:

  • Inference time
  • Size of the model
  • GPU memory used by the model

Inference Time

In [0]:
import time

def infer_time(net, path='./pedestrian.jpg', dev='cuda'):
  img = Image.open(path)
  trf = T.Compose([T.Resize(256), 
                   T.CenterCrop(224), 
                   T.ToTensor(), 
                   T.Normalize(mean = [0.485, 0.456, 0.406], 
                               std = [0.229, 0.224, 0.225])])
  
  inp = trf(img).unsqueeze(0).to(dev)
  
  st = time.time()
  out1 = net.to(dev)(inp)
  et = time.time()
  
  return et - st
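The timed region in `infer_time` also includes `net.to(dev)`, and on a GPU, `time.time()` can be read before the CUDA kernels, which run asynchronously, have finished. A more careful harness, sketched below, keeps setup out of the timed region, runs a few untimed warm-up calls, and accepts an optional synchronization callback. `benchmark` is a hypothetical helper, not part of PyTorch or torchvision.

```python
import statistics
import time

def benchmark(fn, iters=100, warmup=5, sync=None):
    """Return a list of per-call wall-clock times (seconds) for fn().

    warmup: untimed calls made first, so one-time setup costs are excluded.
    sync:   optional callable run before reading the clock, e.g.
            torch.cuda.synchronize when fn launches asynchronous GPU work.
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing starts
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for queued work to finish before stopping the clock
        times.append(time.perf_counter() - start)
    return times

# Usage with a cheap stand-in workload
times = benchmark(lambda: sum(range(10_000)), iters=20)
print(f"median: {statistics.median(times):.6f} s over {len(times)} calls")
```

With the models above already moved to the GPU together with `inp`, a call such as `benchmark(lambda: fcn(inp), sync=torch.cuda.synchronize)` inside a `torch.no_grad()` block would time only the forward pass. The median is less sensitive than the mean to an occasional slow call.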

On CPU

In [0]:
avg_over = 100

fcn_infer_time_list_cpu = [infer_time(fcn, dev='cpu') for _ in range(avg_over)]
fcn_infer_time_avg_cpu = sum(fcn_infer_time_list_cpu) / avg_over

dlab_infer_time_list_cpu = [infer_time(dlab, dev='cpu') for _ in range(avg_over)]
dlab_infer_time_avg_cpu = sum(dlab_infer_time_list_cpu) / avg_over


print ('Inference time for first few calls for FCN      : {}'.format(fcn_infer_time_list_cpu[:10]))
print ('Inference time for first few calls for DeepLabv3: {}'.format(dlab_infer_time_list_cpu[:10]))

print ('The Average Inference time on FCN is:     {:.2f}s'.format(fcn_infer_time_avg_cpu))
print ('The Average Inference time on DeepLab is: {:.2f}s'.format(dlab_infer_time_avg_cpu))
Inference time for first few calls for FCN      : [1.7313017845153809, 1.57572603225708, 1.5629489421844482, 1.5728049278259277, 1.5671417713165283, 1.5791828632354736, 1.5742602348327637, 1.5786330699920654, 1.5627250671386719, 1.578148365020752]
Inference time for first few calls for DeepLabv3: [2.0515308380126953, 1.8824033737182617, 1.8790647983551025, 1.8520596027374268, 1.860480546951294, 1.8530280590057373, 1.837153434753418, 1.8442959785461426, 1.8274402618408203, 1.8295633792877197]
The Average Inference time on FCN is:     1.57s
The Average Inference time on DeepLab is: 1.85s

On GPU

In [0]:
avg_over = 100

fcn_infer_time_list_gpu = [infer_time(fcn) for _ in range(avg_over)]
fcn_infer_time_avg_gpu = sum(fcn_infer_time_list_gpu) / avg_over

dlab_infer_time_list_gpu = [infer_time(dlab) for _ in range(avg_over)]
dlab_infer_time_avg_gpu = sum(dlab_infer_time_list_gpu) / avg_over

print ('Inference time for first few calls for FCN      : {}'.format(fcn_infer_time_list_gpu[:10]))
print ('Inference time for first few calls for DeepLabv3: {}'.format(dlab_infer_time_list_gpu[:10]))

print ('The Average Inference time on FCN is:     {:.3f}s'.format(fcn_infer_time_avg_gpu))
print ('The Average Inference time on DeepLab is: {:.3f}s'.format(dlab_infer_time_avg_gpu))
Inference time for first few calls for FCN      : [0.1053011417388916, 0.017090559005737305, 0.01671576499938965, 0.01747894287109375, 0.018079280853271484, 0.01729726791381836, 0.01781606674194336, 0.024182796478271484, 0.017487287521362305, 0.017383575439453125]
Inference time for first few calls for DeepLabv3: [0.12458348274230957, 0.019745349884033203, 0.01925039291381836, 0.018856048583984375, 0.020597457885742188, 0.01930522918701172, 0.01948070526123047, 0.018863677978515625, 0.018346786499023438, 0.019052743911743164]
The Average Inference time on FCN is:     0.019s
The Average Inference time on DeepLab is: 0.020s

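We can see that in both cases (CPU and GPU), DeepLabv3 takes longer than FCN, although on the GPU the gap is only about 1 ms per image. Both models share the same ResNet-101 backbone, so the extra time comes from DeepLabv3's heavier segmentation head, which uses Atrous Spatial Pyramid Pooling (several parallel dilated convolutions) instead of FCN's simpler convolutional head.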

In [0]:
plt.bar([0.1, 0.2], [fcn_infer_time_avg_cpu, dlab_infer_time_avg_cpu], width=0.08)
plt.ylabel('Time/s')
plt.xticks([0.1, 0.2], ['FCN', 'DeepLabv3'])
plt.title('Inference time of FCN and DeepLabv3 with Resnet-101 backbone on CPU')
plt.show()
In [0]:
plt.bar([0.1, 0.2], [fcn_infer_time_avg_gpu, dlab_infer_time_avg_gpu], width=0.08)
plt.ylabel('Time/s')
plt.xticks([0.1, 0.2], ['FCN', 'DeepLabv3'])
plt.title('Inference time of FCN and DeepLabv3 with Resnet-101 backbone on GPU')
plt.show()

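Okay! Now let's move on to the next comparison: the size of the two models. The code below adds the size of the ResNet-101 backbone checkpoint to each segmentation checkpoint. Keep in mind that the segmentation checkpoints already contain the backbone weights (the separate ResNet-101 file is only downloaded to initialize the backbone before the full checkpoint is loaded), so these totals count the backbone twice; the difference between the two models is unaffected.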

Model Size

In [0]:
import os

resnet101_size = os.path.getsize('/root/.cache/torch/checkpoints/resnet101-5d3b4d8f.pth')
fcn_size = os.path.getsize('/root/.cache/torch/checkpoints/fcn_resnet101_coco-7ecb50ca.pth')
dlab_size = os.path.getsize('/root/.cache/torch/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth')

fcn_total = fcn_size + resnet101_size
dlab_total = dlab_size + resnet101_size
    
print ('Size of the FCN model with Resnet101 backbone is:       {:.2f} MB'.format(fcn_total /  (1024 * 1024)))
print ('Size of the DeepLabv3 model with Resnet101 backbone is: {:.2f} MB'.format(dlab_total / (1024 * 1024)))
Size of the FCN model with Resnet101 backbone is:       378.16 MB
Size of the DeepLabv3 model with Resnet101 backbone is: 403.67 MB
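File size is a proxy for the number of stored values: with 32-bit floats each parameter takes 4 bytes, so the size in MB is roughly params x 4 / 2^20 (buffers such as batch-norm statistics and serialization overhead add a little). In PyTorch, a model's parameter count can be obtained with `sum(p.numel() for p in net.parameters())`. A plain-Python sketch of the conversion; `model_size_mb` is a hypothetical helper and the 10-million parameter count is made up.

```python
def model_size_mb(num_params, bytes_per_param=4):
    """Approximate size in MB (2**20 bytes) of num_params float32 values."""
    return num_params * bytes_per_param / (1024 * 1024)

# Hypothetical parameter count, for illustration only
print(round(model_size_mb(10_000_000), 2))  # 38.15

# Reverse direction: the gap between the two printed totals, in millions of float32 values
gap_mb = 403.67 - 378.16
print(round(gap_mb * 1024 * 1024 / 4 / 1e6, 1))  # 6.7
```

So the roughly 25.5 MB gap corresponds to about 6.7 million additional float32 values in the DeepLabv3 checkpoint.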
In [0]:
plt.bar([0, 1], [fcn_total / (1024 * 1024), dlab_total / (1024 * 1024)])
plt.ylabel('Size of the model in MegaBytes')
plt.xticks([0, 1], ['FCN', 'DeepLabv3'])
plt.title('Comparison of the model size of FCN and DeepLabv3 with Resnet-101 backbone')
plt.show()

