Original U-Net in PyTorch
Explanation with Code Implementation of U-Net Research Paper
Before jumping into U-net architecture, let me introduce you to the term “Semantic Segmentation” (sounds fancy😎). Semantic Segmentation is classifying each pixel of an image according to a category of objects.
Say the image above contains two object categories (classes), person and bicycle, plus an additional background category that covers everything we don’t care about.
The image on the right is the prediction of each pixel according to the category.
U-net is the architecture that maps the image to pixel-level classification.
The research paper I will be discussing here is “U-Net: Convolutional Networks for Biomedical Image Segmentation” by Ronneberger et al.
You can download the paper from here.
Overview
This architecture takes as input an image (572x572 according to the paper) with channel dimension “1” for a grayscale image, and outputs a segmentation map of size 388x388 whose channel dimension equals the number of classes (“K” in Fig 1) to be identified.
Doubts
Why is the Input image dimension different from the Output segmentation map?
The authors used unpadded convolutions, so every 3x3 convolution trims one pixel from each side and the output ends up smaller than the input by a constant border width. The segmentation map therefore only covers the pixels for which the full context is available in the input image.
This can be avoided by using padded convolutions with a mirror strategy.
Mirror strategy: instead of zero-padding, the border is filled by mirroring the image pixels next to it.
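The 572 to 388 shrinkage can be checked with a little arithmetic. The helper below is only a sketch (the function name `unet_output_size` is made up for this post); it assumes the layout described in the paper: four 2x2 poolings, and two unpadded 3x3 convolutions at every level.

```python
def unet_output_size(size: int, depth: int = 4) -> int:
    """Trace the spatial size through an unpadded U-Net with `depth` pooling steps."""
    for _ in range(depth):
        size -= 4  # two unpadded 3x3 convolutions, each trims 2 pixels
        if size % 2:
            raise ValueError("size must be even before each 2x2 max-pool")
        size //= 2  # 2x2 max-pool with stride 2
    size -= 4  # double convolution at the bottom of the "U"
    for _ in range(depth):
        size *= 2  # 2x2 transposed convolution with stride 2
        size -= 4  # two unpadded 3x3 convolutions
    return size


print(unet_output_size(572))  # 388
```

The same helper shows which input sizes are valid at all: the size has to stay even before every pooling step.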
Unet contains mainly three parts:
a. The Contracting Path: This is the descending (left) side of the “U”. It captures the context of the objects in the image; its feature maps learn “which class an object belongs to”.
b. The Expanding Path: This is the ascending (right) side of the “U”. It localizes the objects in the image; its feature maps learn “at which position an object is present”.
c. Skip Connections: The long gray arrows from the contracting side to the expanding side are the skip connections.
Skip connections are used to “retain the spatial information” lost while downsampling the image, so that the feature maps in the expanding path get better information about the position of the original pixels.
Architecture Implementation
The architecture implementation is in PyTorch
Simple Double Convolution Layer
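An input image is passed through a 3x3 convolution followed by a ReLU activation, twice in a row, and then through dropout. The channel dimension increases from “1” to “64” and, because the convolutions are unpadded, the spatial size shrinks from 572x572 to 568x568.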
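A minimal sketch of such a block could look like the following. The class name `DoubleConv`, the use of `nn.Dropout2d`, and the dropout rate of 0.1 are assumptions made for illustration, not values taken from the paper.

```python
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    """Two unpadded 3x3 convolutions, each followed by ReLU, then dropout."""

    def __init__(self, in_ch: int, out_ch: int, p_drop: float = 0.1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p_drop),  # assumed placement and rate
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


x = torch.randn(1, 1, 572, 572)  # (batch, channels, height, width)
print(DoubleConv(1, 64)(x).shape)  # torch.Size([1, 64, 568, 568])
```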
Down-Convolution Layer
After the simple double conv layer, the feature map enters a max-pool with a 2x2 kernel, which halves its height and width, so 568x568 becomes 284x284. It is then passed through two 3x3 convolutions, which double the channel dimension to 128.
This down-convolution block is used three times, producing 128, 256, and 512 channels.
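As a rough sketch, the down-convolution step is a max-pool followed by a double convolution. The names `double_conv` and `Down` are hypothetical, and dropout is left out here to keep the example short.

```python
import torch
import torch.nn as nn


def double_conv(in_ch: int, out_ch: int) -> nn.Sequential:
    """Two unpadded 3x3 convolutions, each followed by ReLU."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3), nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3), nn.ReLU(inplace=True),
    )


class Down(nn.Module):
    """2x2 max-pool (halves height and width), then a double convolution."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(self.pool(x))


x = torch.randn(1, 64, 568, 568)
print(Down(64, 128)(x).shape)  # torch.Size([1, 128, 280, 280])
```

The pool takes 568 to 284, and the two unpadded convolutions then trim it to 280.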
Up-Convolution Layer
After the “bridge”, the feature map is sent through two 3x3 convolutions with ReLU and dropout, and is then upsampled by a transposed convolution.
The transposed convolution uses a 2x2 kernel with a stride of 2.
Its output channel dimension is half of its input, because the feature map from the contracting path is concatenated to it afterwards.
Bridge: The bridge consists of a 2x2 max-pooling and an Up-Convolution Layer. The Up-Conv layer of the bridge takes 512 input channels and its double convolution outputs 1024 channels, before the transposed convolution halves them again.
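One possible way to write this block and the bridge is sketched below. The class name `UpConv` is an assumption, and dropout is omitted for brevity.

```python
import torch
import torch.nn as nn


class UpConv(nn.Module):
    """Double 3x3 convolution followed by a 2x2 transposed convolution."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3), nn.ReLU(inplace=True),
        )
        # Halve the channels; concatenating the skip connection brings them back to out_ch.
        self.up = nn.ConvTranspose2d(out_ch, out_ch // 2, kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(self.conv(x))


# Bridge: 2x2 max-pool, then an UpConv going from 512 to 1024 channels.
bridge = nn.Sequential(nn.MaxPool2d(kernel_size=2), UpConv(512, 1024))
x = torch.randn(1, 512, 64, 64)
print(bridge(x).shape)  # torch.Size([1, 512, 56, 56])
```

Spatially, 64 is pooled to 32, trimmed to 28 by the two convolutions, and doubled to 56 by the transposed convolution.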
Last Convolution Layer
This layer is similar to the previous ones: a double 3x3 convolution with ReLU, followed by dropout.
The difference is that the feature map coming out of the double convolution is passed into a final 1x1 convolution, which maps the 64 channels to the desired number of classes (object categories).
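The final 1x1 convolution only changes the channel dimension, which a short sketch makes visible. The variable names and `n_classes = 2` are chosen for illustration.

```python
import torch
import torch.nn as nn

n_classes = 2  # assumed number of classes
head = nn.Conv2d(64, n_classes, kernel_size=1)  # 1x1 conv: per-pixel linear map over channels

features = torch.randn(1, 64, 388, 388)  # output of the last double convolution
logits = head(features)
print(logits.shape)  # torch.Size([1, 2, 388, 388])
```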
Image Cropper
When concatenating a feature map from the contracting path with one from the expanding path, the contracting-path feature map has to be cropped to match the spatial dimensions of the expanding-path feature map.
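A center crop is one straightforward way to do this. The helper below is a sketch with a hypothetical name, `center_crop`; it assumes the contracting-path map is at least as large as the expanding-path map.

```python
import torch


def center_crop(skip: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Crop `skip` (N, C, H, W) around its center to the height and width of `target`."""
    _, _, h, w = target.shape
    top = (skip.shape[2] - h) // 2
    left = (skip.shape[3] - w) // 2
    return skip[:, :, top:top + h, left:left + w]


skip = torch.randn(1, 512, 64, 64)  # from the contracting path
up = torch.randn(1, 512, 56, 56)    # from the expanding path
cropped = center_crop(skip, up)
print(cropped.shape)                         # torch.Size([1, 512, 56, 56])
print(torch.cat([cropped, up], dim=1).shape)  # torch.Size([1, 1024, 56, 56])
```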
Constructing U-Net — Assembling all Modules together
Most of the code is straightforward. I will explain some parts, where doubt may arise.
— “midMaxpool” and “bridge” formed the Bridge as discussed earlier.
— Skip connections are implemented by “torch.cat” to concatenate feature maps along channel dimensions. So, concatenation is done along the “1” axis, i.e. channel dimension in PyTorch.
PyTorch lays out image tensors a bit differently from some other frameworks.
TensorFlow, for example, defaults to (batch_size, height, width, channels).
In PyTorch, they are (batch_size, channels, height, width).
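A small sketch of both points: converting a channels-last batch into the layout PyTorch expects, and concatenating along the channel axis. The tensor shapes are arbitrary examples.

```python
import torch

nhwc = torch.randn(8, 128, 128, 3)  # (batch, height, width, channels)
nchw = nhwc.permute(0, 3, 1, 2)     # (batch, channels, height, width)
print(nchw.shape)  # torch.Size([8, 3, 128, 128])

a = torch.randn(1, 64, 32, 32)
b = torch.randn(1, 64, 32, 32)
print(torch.cat([a, b], dim=1).shape)  # torch.Size([1, 128, 32, 32])
```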
I took some liberties in constructing this architecture to maximize code reusability.
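For orientation, here is a compact, self-contained sketch of how the pieces can fit together. It is not the module layout described above: all names (`UNet`, `double_conv`, `center_crop`, `widths`) are hypothetical, dropout is omitted, and the transposed convolutions are kept as separate layers instead of being part of the bridge.

```python
import torch
import torch.nn as nn


def double_conv(in_ch, out_ch):
    """Two unpadded 3x3 convolutions, each followed by ReLU."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3), nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3), nn.ReLU(inplace=True),
    )


def center_crop(skip, target):
    """Crop `skip` around its center to the spatial size of `target`."""
    h, w = target.shape[2:]
    top, left = (skip.shape[2] - h) // 2, (skip.shape[3] - w) // 2
    return skip[:, :, top:top + h, left:left + w]


class UNet(nn.Module):
    def __init__(self, in_ch=1, n_classes=2, widths=(64, 128, 256, 512)):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.encoders = nn.ModuleList()
        ch = in_ch
        for w in widths:  # contracting path
            self.encoders.append(double_conv(ch, w))
            ch = w
        self.bridge = double_conv(ch, ch * 2)  # bottom of the "U"
        ch *= 2
        self.ups = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for w in reversed(widths):  # expanding path
            self.ups.append(nn.ConvTranspose2d(ch, w, 2, stride=2))
            self.decoders.append(double_conv(2 * w, w))  # 2*w channels after concatenation
            ch = w
        self.head = nn.Conv2d(ch, n_classes, 1)

    def forward(self, x):
        skips = []
        for enc in self.encoders:
            x = enc(x)
            skips.append(x)  # kept for the skip connection
            x = self.pool(x)
        x = self.bridge(x)
        for up, dec, skip in zip(self.ups, self.decoders, reversed(skips)):
            x = up(x)
            x = dec(torch.cat([center_crop(skip, x), x], dim=1))
        return self.head(x)


model = UNet()
with torch.no_grad():
    out = model(torch.randn(1, 1, 188, 188))  # small valid input to keep the check cheap
print(out.shape)  # torch.Size([1, 2, 4, 4])
```

With the paper's 572x572 input, the same model should return a 388x388 map; the smaller 188x188 input is used here only because it runs quickly.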
Optimizer and Loss Function
As discussed in the paper, stochastic gradient descent with momentum is used as the optimizer. Nothing fancy: momentum keeps a moving average of the computed gradients to reduce noise and move faster towards an optimal solution.
# assumes `import torch` and an already constructed `model`
beta = 0.99  # momentum, as per the paper
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=beta)
For the loss function, the paper uses cross-entropy over a pixel-wise softmax.
Explaining Softmax and Cross-entropy will make this blog very lengthy, but I will briefly go through it.
Softmax: a function that maps every value in a vector to the range 0 to 1 such that the values sum to 1. The outputs look like probabilities, but they are only the network’s current scores: they come from the activations, which change as the weights and biases of the network change.
Pixel-wise softmax as mentioned in the paper
x = single pixel position on the output segmentation map
aₖ(x) = activation in feature channel k at pixel position x
K = number of classes (2 in the paper)
k = an individual class
Numerator: the exponential of the activation of class k at pixel position x.
Denominator: the sum of the exponentials of the activations at pixel position x over all K classes, i.e. along the channel dimension of the output segmentation map.
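Written out with the symbols defined above, the pixel-wise softmax from the paper reads as follows, where k' runs over all K classes:

```latex
p_k(x) = \frac{\exp\big(a_k(x)\big)}{\sum_{k'=1}^{K} \exp\big(a_{k'}(x)\big)}
```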
Cross-Entropy Loss Function: It measures how far the model’s prediction is from the ground truth and penalizes the prediction more heavily the further it deviates.
w(x) = weight of the pixel at position x, taken from the weight map computed from the ground-truth segmentation
log(pₗ₍ₓ₎(x)) = log of the softmax output at pixel x for its true class ℓ(x)
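The paper writes this energy as the sum over all pixels of w(x) · log(pₗ₍ₓ₎(x)). A rough PyTorch sketch of the corresponding loss is given below; it minimizes the negative of that sum. The function name `weighted_pixel_ce` and the choice to sum rather than average over pixels are assumptions.

```python
import torch
import torch.nn.functional as F


def weighted_pixel_ce(logits, target, weight_map):
    """logits: (N, K, H, W); target: (N, H, W) int64 labels; weight_map: (N, H, W)."""
    log_p = F.log_softmax(logits, dim=1)  # log p_k(x) for every class k
    # Pick log p of the true class l(x) at every pixel.
    log_p_true = log_p.gather(1, target.unsqueeze(1)).squeeze(1)
    return -(weight_map * log_p_true).sum()


logits = torch.randn(1, 2, 4, 4)
target = torch.randint(0, 2, (1, 4, 4))
w = torch.ones(1, 4, 4)  # uniform weights reduce this to plain cross-entropy
loss = weighted_pixel_ce(logits, target, w)
print(loss.item())
```

With a weight map of all ones, the result matches `F.cross_entropy(logits, target, reduction="sum")`.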
Why not just use something simpler, like just the difference between ground truth and predicted or square of residual?
Cross-entropy loss shoots up as the prediction moves away from the ground truth, so the derivative of the loss gives a strong signal of how much the weights need to be changed.
With something like the squared residual on top of a softmax, the slope of the loss becomes very small when the prediction is confidently wrong, which makes learning through backpropagation (the method used to compute derivatives in a neural network) slow.
Weight Initialization
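The initial weights are drawn from a normal distribution with standard deviation square_root(2/N), where N is the number of incoming nodes of one neuron (for a 3x3 convolution with 64 input channels, N = 9 · 64 = 576).
In torch, this corresponds to kaiming_normal initialization (xavier_normal would instead use a standard deviation of square_root(2/(fan_in + fan_out))).
torch.nn.init.kaiming_normal_(tensor, mode='fan_in', nonlinearity='relu')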
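A standard deviation of square_root(2/N) is what PyTorch's Kaiming (He) normal initialization produces in `fan_in` mode with a ReLU gain. The sketch below applies it to a convolution and compares the measured spread with the target; the helper name `init_weights` and the zeroed biases are assumptions.

```python
import math

import torch
import torch.nn as nn


def init_weights(module: nn.Module) -> None:
    """Kaiming-normal weights for conv layers; zero biases (an assumption)."""
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)


conv = nn.Conv2d(64, 128, kernel_size=3)
conv.apply(init_weights)  # on a full model: model.apply(init_weights)

n_in = 64 * 3 * 3  # N = incoming nodes of one neuron
print(conv.weight.std().item(), math.sqrt(2 / n_in))  # both close to 0.059
```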
Miscellaneous
Overlap Tile Strategy
The paper uses this strategy so that the resolution of the image to be segmented is not limited by GPU memory: a large image is predicted tile by tile.
To predict the segmentation of one part of an image, the neighboring area plays an important role. So if you want to predict the region inside the yellow box, the pixels inside the blue box are also required as input. Where that context falls outside the image, it is filled in by mirroring the image.
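Mirroring is available in PyTorch as reflection padding. The sketch below pads a 388x388 tile out to the 572x572 input size; it mirrors all four sides for simplicity, whereas in a real tiling only the sides that fall outside the image would be mirrored and the rest would come from neighboring pixels.

```python
import torch
import torch.nn.functional as F

tile = torch.randn(1, 1, 388, 388)  # region we want a prediction for
border = (572 - 388) // 2           # 92 pixels of context on each side
padded = F.pad(tile, (border, border, border, border), mode="reflect")
print(padded.shape)  # torch.Size([1, 1, 572, 572])
```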
Weight Map for Ground Truth Segmentation
This is done to force the network to learn the small separation borders between objects that touch or lie close together.
This w(x) is the per-pixel weight used earlier in the cross-entropy function; it scales how much each pixel contributes to the loss.
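For reference, the weight map in the paper is computed as below. Here w_c(x) is a weight map that balances the class frequencies, d_1(x) is the distance to the border of the nearest object, and d_2(x) is the distance to the border of the second nearest object. The paper sets w_0 = 10 and σ ≈ 5 pixels.

```latex
w(x) = w_c(x) + w_0 \cdot \exp\left( -\frac{\big(d_1(x) + d_2(x)\big)^2}{2\sigma^2} \right)
```

Pixels in a narrow gap between two objects have small d_1 and d_2, so they receive a large weight.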
In a future blog, I will write about advancements over Unet and its variations, such as Recurrent-Unet, Attention-Unet, etc.
More Reading References
1. Unet Research Paper — U-Net: Convolutional Networks for Biomedical Image Segmentation
2. Transposed Convolution
3. Softmax
4. Cross-Entropy Loss Function
Thanks for reading and Happy Coding!!