Original U-Net in PyTorch

Explanation with Code Implementation of U-Net Research Paper

Ankan Sharma
Towards Dev


Semantic Segmentation (Credit: Rockers technology)

Before jumping into the U-Net architecture, let me introduce the term “Semantic Segmentation” (sounds fancy😎). Semantic segmentation is the task of classifying each pixel of an image into a category of objects.

Semantic Segmentation of Image

Say that for the image above there are two object categories (classes), person and bicycle, plus an additional background category that covers everything we don’t care about.
The image on the right shows the predicted category of each pixel.
U-Net is an architecture that maps an image to a pixel-level classification.
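To make “pixel-level classification” concrete, here is a minimal sketch with random numbers standing in for a network's output: the network produces one score per class at every pixel, and the predicted label map is the argmax over the class (channel) dimension. The class order (0 = background, 1 = person, 2 = bicycle) and the 4x4 size are assumptions for illustration.

```python
import torch

# Hypothetical network output: batch of 1, K = 3 classes
# (0 = background, 1 = person, 2 = bicycle) on a 4x4 image.
logits = torch.randn(1, 3, 4, 4)

# Pixel-wise classification: pick the highest-scoring class at each pixel.
label_map = logits.argmax(dim=1)  # shape (1, 4, 4), values in {0, 1, 2}
print(label_map.shape)  # torch.Size([1, 4, 4])
```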

The research paper I will be discussing here is “U-Net: Convolutional Networks for Biomedical Image Segmentation” by Ronneberger et al.
You can download the paper from here.

U-Net Architecture

Overview

The architecture takes as input a grayscale image with a channel dimension of “1” (572x572 in the paper) and outputs a 388x388 segmentation map whose channel dimension equals the number of classes to be identified (“K” in Fig 1).
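The 572 → 388 shrinkage can be checked with a few lines of arithmetic. This sketch assumes the layout of Fig 1: two unpadded 3x3 convolutions per level, four 2x2 max-pools and four 2x2 up-convolutions. `unet_output_size` is a hypothetical helper, not part of any library.

```python
def unet_output_size(size, depth=4, convs_per_level=2):
    """Trace the spatial size through an unpadded U-Net (hypothetical helper)."""
    # Contracting path: each unpadded 3x3 conv trims 1 px per side (-2),
    # then a 2x2 max-pool halves the size.
    for _ in range(depth):
        size -= 2 * convs_per_level
        if size % 2:
            raise ValueError("feature map size must be even before every pooling")
        size //= 2
    size -= 2 * convs_per_level  # convolutions at the bottom of the "U"
    # Expanding path: each 2x2 transposed conv (stride 2) doubles the size,
    # then two unpadded 3x3 convs trim it again.
    for _ in range(depth):
        size = size * 2 - 2 * convs_per_level
    return size

print(unet_output_size(572))  # 388
```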

Fig: 1 U-NET

Doubts
Why is the input image dimension different from that of the output segmentation map?
The authors used unpadded convolutions: each 3x3 convolution removes a one-pixel border, so the output is smaller than the input by a constant border width. As a result, the segmentation map only contains pixels for which the full context is available in the input image.
The size difference can be avoided by using padded convolutions with a mirror strategy.
**Mirror strategy: instead of zero-padding, the pixels near the border are mirrored (reflected) into the padding region.
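A minimal PyTorch sketch of the mirror strategy, using `F.pad(..., mode="reflect")` and the `padding_mode="reflect"` option of `nn.Conv2d`. Note that the original U-Net does not pad its convolutions; this only illustrates the padded alternative mentioned above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.arange(1.0, 10.0).reshape(1, 1, 3, 3)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# Mirror (reflect) padding: border pixels are reflected instead of zero-filled.
padded = F.pad(x, (1, 1, 1, 1), mode="reflect")
print(padded[0, 0, 0])  # tensor([5., 4., 5., 6., 5.])

# A padded 3x3 convolution with mirror padding keeps the spatial size.
conv = nn.Conv2d(1, 8, kernel_size=3, padding=1, padding_mode="reflect")
with torch.no_grad():
    y = conv(torch.randn(1, 1, 572, 572))
print(y.shape)  # torch.Size([1, 8, 572, 572])
```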

U-Net mainly consists of three parts:

a. The Contracting Path: the descending (left) side of the “U”. It captures the context of objects in the image; its feature maps learn “which class an object belongs to”.

b. The Expanding Path: the ascending (right) side of the “U”. It helps localize the objects in the image; its feature maps learn “where the object is located”.

c. Skip Connections: the long gray arrows from the contracting side to the expanding side.
Skip connections “retain the spatial information” lost while downsampling the image, so that the feature maps in the expanding path get better information about the positions of the original pixels.

Architecture Implementation

The architecture is implemented in PyTorch.

Simple Double Convolution Layer
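An input image is passed through a 3x3 convolution followed by a ReLU activation and dropout, and this is done twice (hence “double”). The channel dimension increases from “1” to “64”, and because the convolutions are unpadded, the spatial size shrinks from 572x572 to 568x568.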

Simple Double Convolution Layer
Code of Simple Double Convolution
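Since the original code image is not available, here is a minimal sketch of the block as described: two unpadded 3x3 convolutions, each followed by ReLU and dropout. The helper name `double_conv` and the dropout probability of 0.1 are assumptions.

```python
import torch
import torch.nn as nn

def double_conv(in_ch, out_ch, p_drop=0.1):
    """Two unpadded 3x3 convolutions, each followed by ReLU and dropout.

    p_drop = 0.1 is an assumed value; the text does not specify it.
    """
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3),  # padding=0 trims 1 px per side
        nn.ReLU(inplace=True),
        nn.Dropout(p_drop),
        nn.Conv2d(out_ch, out_ch, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Dropout(p_drop),
    )

first = double_conv(1, 64)
with torch.no_grad():
    out = first(torch.randn(1, 1, 572, 572))
print(out.shape)  # torch.Size([1, 64, 568, 568])
```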

Down-Convolution Layer
After passing through the simple double conv layer, the feature map enters a 2x2 max-pooling layer, which halves its height and width: 568x568 becomes 284x284. It is then passed through two 3x3 convolutions, which double the channel dimension to 128 and, being unpadded, shrink the map to 280x280.
This down-convolution is repeated three times, doubling the channels each time (64 → 128 → 256 → 512).

Down Convolution Layer
Code of Down Convolution
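A sketch of the down-convolution step under the same assumptions (hypothetical `double_conv` helper, dropout probability 0.1): a 2x2 max-pool followed by a double convolution that doubles the channels. `DownConv` is a hypothetical name.

```python
import torch
import torch.nn as nn

def double_conv(in_ch, out_ch, p_drop=0.1):
    # Two unpadded 3x3 convs with ReLU + dropout (p_drop is an assumed value).
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3), nn.ReLU(inplace=True), nn.Dropout(p_drop),
        nn.Conv2d(out_ch, out_ch, 3), nn.ReLU(inplace=True), nn.Dropout(p_drop),
    )

class DownConv(nn.Module):
    """2x2 max-pool (halves H and W), then a double conv that doubles the channels."""

    def __init__(self, in_ch):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv = double_conv(in_ch, in_ch * 2)

    def forward(self, x):
        return self.conv(self.pool(x))

down1 = DownConv(64)
with torch.no_grad():
    out = down1(torch.randn(1, 64, 568, 568))
print(out.shape)  # torch.Size([1, 128, 280, 280])
```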

Up-Convolution Layer
After the “bridge”, the (concatenated) feature map is sent through two 3x3 convolutions with ReLU and dropout. It is then passed through a transposed convolution that upsamples the feature map.
The transposed convolution uses a 2x2 kernel with a stride of 2, which doubles the height and width.
The output channel dimension of the transposed convolution is halved because the feature map from the contracting path will be concatenated to it.

Bridge in Unet

Bridge: the bridge consists of a 2x2 max-pooling layer and an Up-Convolution layer. The Up-Convolution layer of the bridge has 512 input channels and 1024 output channels.

Up-Convolution Layer
Code for Up-Convolution
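A sketch of the Up-Convolution layer and the bridge as described above: a double 3x3 convolution followed by a 2x2 transposed convolution with stride 2 that halves the channels. The names `UpConv`, `mid_maxpool` and `double_conv` are hypothetical; the 64x64 input corresponds to the 512-channel feature map at the end of the contracting path for a 572x572 image.

```python
import torch
import torch.nn as nn

def double_conv(in_ch, out_ch, p_drop=0.1):
    # Two unpadded 3x3 convs with ReLU + dropout (p_drop is an assumed value).
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3), nn.ReLU(inplace=True), nn.Dropout(p_drop),
        nn.Conv2d(out_ch, out_ch, 3), nn.ReLU(inplace=True), nn.Dropout(p_drop),
    )

class UpConv(nn.Module):
    """Double conv (in_ch -> out_ch), then a 2x2 transposed conv with stride 2
    that doubles H and W and halves the channels to out_ch // 2."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = double_conv(in_ch, out_ch)
        self.up = nn.ConvTranspose2d(out_ch, out_ch // 2, kernel_size=2, stride=2)

    def forward(self, x):
        return self.up(self.conv(x))

# Bridge: 2x2 max-pool, then UpConv with 512 input and 1024 output channels.
mid_maxpool = nn.MaxPool2d(2)
bridge = UpConv(512, 1024)
with torch.no_grad():
    y = bridge(mid_maxpool(torch.randn(1, 512, 64, 64)))
print(y.shape)  # torch.Size([1, 512, 56, 56])
```

The 512 channels that come out match the 512 channels of the cropped skip connection, so the concatenated map has 1024 channels again.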

Last Convolution Layer
This layer is similar to the previous ones: a double 3x3 convolution, each convolution followed by ReLU and dropout.
The difference is at the end: the feature map from the double conv layer is passed into a 1x1 convolution that maps the 64 channels to the desired number of classes (object categories).

Last Convolution Layer
Code for Last Convolution Layer
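A sketch of the last layer, assuming its input is the 128-channel concatenation of the 64 upsampled channels and the 64 cropped skip channels, and assuming K = 2 classes as in the paper. `LastConv`, `double_conv` and the dropout value are hypothetical.

```python
import torch
import torch.nn as nn

def double_conv(in_ch, out_ch, p_drop=0.1):
    # Two unpadded 3x3 convs with ReLU + dropout (p_drop is an assumed value).
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3), nn.ReLU(inplace=True), nn.Dropout(p_drop),
        nn.Conv2d(out_ch, out_ch, 3), nn.ReLU(inplace=True), nn.Dropout(p_drop),
    )

class LastConv(nn.Module):
    def __init__(self, in_ch=128, mid_ch=64, num_classes=2):
        super().__init__()
        self.conv = double_conv(in_ch, mid_ch)
        # 1x1 convolution: a per-pixel linear map from 64 channels to K class scores.
        self.classifier = nn.Conv2d(mid_ch, num_classes, kernel_size=1)

    def forward(self, x):
        return self.classifier(self.conv(x))

with torch.no_grad():
    logits = LastConv()(torch.randn(1, 128, 392, 392))
print(logits.shape)  # torch.Size([1, 2, 388, 388])
```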

Image Cropper
Before a feature map from the contracting path is concatenated with a feature map from the expanding path, it has to be cropped to match the spatial size of the expanding-path feature map, because the unpadded convolutions leave the contracting-path maps larger.

Image Cropping
Code for Image Cropping
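The code image is missing, so here is a minimal center-crop sketch (the helper name `center_crop` is hypothetical). It crops the contracting-path map around its center and then concatenates along the channel dimension (dim=1); the sizes are those of the top level for a 572x572 input.

```python
import torch

def center_crop(skip, target):
    """Crop `skip` (N, C, H, W) around its center to the H and W of `target`."""
    _, _, h, w = skip.shape
    th, tw = target.shape[2], target.shape[3]
    top, left = (h - th) // 2, (w - tw) // 2
    return skip[:, :, top:top + th, left:left + tw]

skip = torch.randn(1, 64, 568, 568)       # from the contracting path
upsampled = torch.randn(1, 64, 392, 392)  # from the expanding path
merged = torch.cat([center_crop(skip, upsampled), upsampled], dim=1)
print(merged.shape)  # torch.Size([1, 128, 392, 392])
```

Putting the skip features first in `torch.cat` is an arbitrary choice; the order only has to be consistent with the layer that consumes the result.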

Constructing U-Net — Assembling all Modules together
Most of the code is straightforward, so I will only explain the parts where doubts may arise.
- “midMaxpool” and “bridge” form the Bridge discussed earlier.
- Skip connections are implemented with “torch.cat”, which concatenates feature maps along the channel dimension. In PyTorch, the channel dimension is axis “1”, so the concatenation is done along dim=1.

In PyTorch, image tensors use a different layout.
Many other frameworks (for example TensorFlow, by default) use (batch_size, height, width, channels).
PyTorch uses (batch_size, channels, height, width).
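For example, a batch stored channels-last can be converted to PyTorch's layout with `permute` (the shapes here are illustrative):

```python
import torch

nhwc = torch.randn(4, 572, 572, 1)  # (batch_size, height, width, channels)
nchw = nhwc.permute(0, 3, 1, 2)     # (batch_size, channels, height, width)
print(nchw.shape)  # torch.Size([4, 1, 572, 572])
```

`permute` only changes the view of the data; calling `.contiguous()` afterwards copies it into the new memory order if an operation requires that.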

I took some liberties in constructing this architecture to maximize code reusability.

Code for complete Unet
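Since the original code image is not available, below is a self-contained sketch that assembles the pieces described above. The attribute names `midMaxpool` and `bridge` follow the text; everything else (helper names, dropout probability, module layout) is an assumption and not the author's exact code.

```python
import torch
import torch.nn as nn

def double_conv(in_ch, out_ch, p_drop=0.1):
    # Two unpadded 3x3 convs, each with ReLU + dropout (p_drop is an assumed value).
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3), nn.ReLU(inplace=True), nn.Dropout(p_drop),
        nn.Conv2d(out_ch, out_ch, 3), nn.ReLU(inplace=True), nn.Dropout(p_drop),
    )

def up_conv(in_ch, out_ch):
    # Double conv, then a 2x2 transposed conv (stride 2) that halves the channels.
    return nn.Sequential(
        double_conv(in_ch, out_ch),
        nn.ConvTranspose2d(out_ch, out_ch // 2, kernel_size=2, stride=2),
    )

def center_crop(skip, target):
    # Crop the contracting-path map to the H and W of the expanding-path map.
    h, w = skip.shape[2:]
    th, tw = target.shape[2:]
    top, left = (h - th) // 2, (w - tw) // 2
    return skip[:, :, top:top + th, left:left + tw]

class UNet(nn.Module):
    def __init__(self, in_ch=1, num_classes=2):
        super().__init__()
        self.first = double_conv(in_ch, 64)  # 1 -> 64 channels
        self.downs = nn.ModuleList(
            [nn.Sequential(nn.MaxPool2d(2), double_conv(c, c * 2)) for c in (64, 128, 256)]
        )  # 64 -> 128 -> 256 -> 512 channels
        self.midMaxpool = nn.MaxPool2d(2)
        self.bridge = up_conv(512, 1024)  # returns 512 channels after upsampling
        self.ups = nn.ModuleList([up_conv(c, c // 2) for c in (1024, 512, 256)])
        self.last = nn.Sequential(
            double_conv(128, 64), nn.Conv2d(64, num_classes, kernel_size=1)
        )

    def forward(self, x):
        skips = [self.first(x)]
        for down in self.downs:
            skips.append(down(skips[-1]))
        x = self.bridge(self.midMaxpool(skips[-1]))
        # Skip connections: crop, then concatenate along the channel axis (dim=1).
        for up, skip in zip(self.ups, reversed(skips[1:])):
            x = up(torch.cat([center_crop(skip, x), x], dim=1))
        return self.last(torch.cat([center_crop(skips[0], x), x], dim=1))

model = UNet().eval()  # eval() disables dropout
with torch.no_grad():
    out = model(torch.randn(1, 1, 188, 188))
print(out.shape)  # torch.Size([1, 2, 4, 4])
```

The example uses a small valid input size to keep it fast; with a 572x572 input, the same model returns a 2x388x388 map, as in Fig 1.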

Optimizer and Loss Function

As described in the paper, stochastic gradient descent (SGD) with momentum is used as the optimizer. Nothing fancy: momentum keeps a running accumulation of past gradients, which reduces noise and speeds up convergence toward an optimal solution.

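import torch
model = torch.nn.Conv2d(1, 2, kernel_size=3)  # stand-in; pass the U-Net model here
beta = 0.99  # momentum value used in the paper
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=beta)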
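To see what the momentum term does, this sketch compares `torch.optim.SGD` with the update rule PyTorch documents for momentum without dampening or Nesterov (v ← μ·v + g, θ ← θ − lr·v), using a toy parameter with a constant gradient of 2. Strictly, this is an exponentially decaying sum of past gradients rather than a normalized average.

```python
import torch

w = torch.nn.Parameter(torch.tensor([1.0]))
opt = torch.optim.SGD([w], lr=0.01, momentum=0.99)

v, w_manual = 0.0, 1.0  # manual momentum buffer and parameter
for step in range(3):
    opt.zero_grad()
    loss = 2.0 * w.sum()  # d(loss)/dw = 2.0 at every step
    loss.backward()
    opt.step()
    # Same rule by hand: v = momentum * v + grad, then w -= lr * v
    v = 0.99 * v + 2.0
    w_manual -= 0.01 * v
print(w.item(), w_manual)  # both about 0.8808
```

Each step moves further than the previous one, because the buffer keeps accumulating the consistent gradient direction.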

The loss function is a cross-entropy loss computed over a pixel-wise softmax.
Explaining softmax and cross-entropy in depth would make this blog very long, but I will briefly go through them.
Softmax: a function that maps the values of a vector to the range 0 to 1 such that all the values sum to 1. The outputs look like probabilities, but they are not fixed: they are computed from the network's activations, which change whenever the weights and biases of the network change.

Pixel-wise softmax as defined in the paper:

$$p_k(x) = \frac{\exp\big(a_k(x)\big)}{\sum_{k'=1}^{K} \exp\big(a_{k'}(x)\big)}$$

x = a single pixel position on the output segmentation map
aₖ(x) = activation in feature channel k at pixel position x
pₖ(x) = softmax output for class k at pixel position x
K = number of classes (2 in the paper)
k = index of an individual class

Numerator: the exponential of the activation aₖ(x) at pixel position x.
Denominator: the sum of the exponentials of the activations of all K classes at pixel position x, i.e. along the channel dimension of the output segmentation map.
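The pixel-wise softmax in PyTorch, written once directly from the formula and once with `torch.softmax`; the random activations and the 388x388 size are placeholders.

```python
import torch

torch.manual_seed(0)
a = torch.randn(1, 2, 388, 388)  # activations a_k(x): (batch, K classes, H, W)

# Exponentiate and normalise over the channel (class) dimension, dim=1.
p = torch.exp(a) / torch.exp(a).sum(dim=1, keepdim=True)

# Same result with the built-in, numerically safer function.
p_builtin = torch.softmax(a, dim=1)
print(torch.allclose(p, p_builtin, atol=1e-6))                # True
print(torch.allclose(p.sum(dim=1), torch.ones(1, 388, 388)))  # True
```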

Cross-Entropy Loss Function

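Cross-Entropy Loss Function: it measures how far the model's prediction is from the ground truth, and the penalty grows heavily as the deviation increases. Written as a loss to be minimized, the weighted pixel-wise cross-entropy of the paper is:

$$E = -\sum_{x \in \Omega} w(x) \log\big(p_{l(x)}(x)\big)$$

Ω = set of all pixel positions of the output segmentation map
l(x) = ground-truth class of the pixel at position x
w(x) = weight of the pixel at position x, taken from the precomputed weight map
log(pₗ₍ₓ₎(x)) = log of the softmax output at pixel x for its true class l(x)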
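A minimal sketch of the weighted pixel-wise cross-entropy: the negative log of the softmax probability of the true class at each pixel, multiplied by w(x) and summed over all pixels. `F.cross_entropy` applies the log-softmax internally, so the raw activations are passed in. The function name is hypothetical, and summing (rather than averaging) is a choice made here.

```python
import torch
import torch.nn.functional as F

def weighted_pixel_ce(logits, target, weight_map):
    """Weighted pixel-wise cross-entropy (hypothetical helper).

    logits:     (N, K, H, W) raw activations a_k(x)
    target:     (N, H, W) int64 true class l(x) of every pixel
    weight_map: (N, H, W) per-pixel weights w(x)
    """
    # -log p_{l(x)}(x) at every pixel, softmax applied inside cross_entropy
    nll = F.cross_entropy(logits, target, reduction="none")  # (N, H, W)
    return (weight_map * nll).sum()

torch.manual_seed(0)
logits = torch.randn(2, 3, 5, 5)
target = torch.randint(0, 3, (2, 5, 5))
w = torch.rand(2, 5, 5)
print(weighted_pixel_ce(logits, target, w))
```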

Why not use something simpler, like the difference between the ground truth and the prediction, or the squared residual?
Cross-entropy loss grows steeply as the predicted probability of the true class moves away from 1, so its derivative gives a strong signal of how much the weights need to be adjusted.
With the squared residual on softmax outputs, the slope of the loss becomes very small when the prediction is confidently wrong (the softmax saturates), which makes learning through backpropagation (the method for computing derivatives in a neural network) much slower.
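A small numerical check of this argument (toy numbers, not from the paper): two classes, the true class is 1, and the network assigns it a probability of about 4.5e-5. The gradient of the cross-entropy with respect to the true-class logit stays close to -1, while the squared error on the softmax output gives a gradient thousands of times smaller.

```python
import torch
import torch.nn.functional as F

target = torch.tensor([1])  # true class is 1, but its logit is much lower
logits_ce = torch.tensor([[0.0, -10.0]], requires_grad=True)
logits_mse = torch.tensor([[0.0, -10.0]], requires_grad=True)

# Cross-entropy over the softmax
F.cross_entropy(logits_ce, target).backward()

# Squared residual between softmax output and one-hot ground truth
probs = torch.softmax(logits_mse, dim=1)
((probs - F.one_hot(target, 2).float()) ** 2).sum().backward()

ce_grad = logits_ce.grad[0, 1].item()
mse_grad = logits_mse.grad[0, 1].item()
print(ce_grad, mse_grad)  # about -1.0 and about -1.8e-4
```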

Weight Initialization
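Initial weights are drawn from a normal distribution with standard deviation √(2/N), where N is the number of incoming nodes of one neuron (for a 3x3 convolution with 64 input channels, N = 9·64 = 576).
In torch, this corresponds to He (Kaiming) normal initialization in fan_in mode, not Xavier normal initialization (which uses a standard deviation of gain·√(2/(fan_in + fan_out))).

torch.nn.init.kaiming_normal_(tensor, mode='fan_in', nonlinearity='relu')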
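A sketch of applying this initialization to every convolution of a model with `Module.apply`. The function name `init_unet_weights` and the zero-initialized biases are assumptions. The measured standard deviation should be close to √(2/576) ≈ 0.0589.

```python
import math
import torch
import torch.nn as nn

def init_unet_weights(module):
    # He/Kaiming normal: std = sqrt(2 / N), N = fan_in = in_channels * kH * kW
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)  # assumption: biases start at zero

torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3), nn.ReLU())
model.apply(init_unet_weights)

expected = math.sqrt(2 / (9 * 64))  # N = 3 * 3 * 64 = 576
print(model[0].weight.std().item(), expected)  # both close to 0.0589
```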

Miscellaneous

Overlap Tile Strategy
The paper uses this strategy so that large images can be segmented tile by tile, and the resolution of the image to be predicted is not limited by GPU memory.

Overlap Tile Strategy

To segment one part of an image, the neighboring area provides important context. So to predict the segmentation inside the yellow box, the pixels inside the larger blue box are also required as input. At the image border, where this context is missing, it is extrapolated by mirroring the input image.
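A minimal sketch of overlap-tile prediction, assuming the image height and width are multiples of the 388-pixel output tile (a real implementation would also pad the bottom and right edges). `predict_tiled` and the stand-in model are hypothetical; the stand-in just returns the center 388x388 of each tile, so the stitched result should reproduce the input exactly.

```python
import torch
import torch.nn.functional as F

def predict_tiled(model_fn, image, in_tile=572, out_tile=388):
    """image: (N, C, H, W) with H and W assumed to be multiples of out_tile."""
    _, _, h, w = image.shape
    border = (in_tile - out_tile) // 2  # 92 px of context on each side
    # Mirror the image so border tiles also get full context.
    padded = F.pad(image, (border, border, border, border), mode="reflect")
    rows = []
    for top in range(0, h, out_tile):
        cols = []
        for left in range(0, w, out_tile):
            tile = padded[:, :, top:top + in_tile, left:left + in_tile]
            cols.append(model_fn(tile))  # each call returns an out_tile x out_tile map
        rows.append(torch.cat(cols, dim=3))
    return torch.cat(rows, dim=2)

# Stand-in "model": keep only the valid center of each 572x572 tile.
fake_model = lambda t: t[..., 92:-92, 92:-92]
img = torch.rand(1, 1, 776, 776)
out = predict_tiled(fake_model, img)
print(out.shape)  # torch.Size([1, 1, 776, 776])
```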

Weight Map for Ground Truth Segmentation
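A weight map is precomputed for each ground-truth segmentation so that the network learns the small separation borders between objects that are close together. It also compensates for the different frequencies of pixels from each class in the training data.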

Weight map

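This w(x) is the per-pixel weight in the cross-entropy loss function: it scales each pixel's contribution to the loss, so that errors on the thin borders between touching objects are penalized more heavily.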
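The paper defines the weight map as w(x) = w_c(x) + w₀·exp(−(d₁(x) + d₂(x))² / (2σ²)), where w_c is the map that balances the class frequencies, d₁ and d₂ are the distances to the border of the nearest and second-nearest object, and w₀ = 10, σ ≈ 5 pixels. Below is a brute-force sketch for small images. Computing w_c is left out (it defaults to 1 here, an assumption), distances are measured to the nearest object pixel, and `unet_weight_map` is a hypothetical name.

```python
import torch

def unet_weight_map(instances, w_c=None, w0=10.0, sigma=5.0):
    """Brute-force weight map for a small (H, W) instance-label image.

    instances: int tensor, 0 = background, 1..M = individual objects.
    w_c: optional (H, W) class-balancing map; defaults to ones (assumption).
    """
    h, w = instances.shape
    if w_c is None:
        w_c = torch.ones(h, w)
    ids = [i for i in instances.unique().tolist() if i != 0]
    if len(ids) < 2:
        return w_c  # the border term needs at least two objects
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    coords = torch.stack([ys.reshape(-1), xs.reshape(-1)], dim=1).float()
    flat = instances.reshape(-1)
    # Distance from every pixel to the closest pixel of each object: (H*W, M)
    dists = torch.stack(
        [torch.cdist(coords, coords[flat == i], compute_mode="donot_use_mm_for_euclid_dist")
         .min(dim=1).values for i in ids],
        dim=1,
    )
    d_sorted = dists.sort(dim=1).values
    d1, d2 = d_sorted[:, 0], d_sorted[:, 1]  # nearest and second-nearest object
    border = w0 * torch.exp(-((d1 + d2) ** 2) / (2 * sigma ** 2))
    return w_c + border.reshape(h, w)

labels = torch.zeros(20, 20, dtype=torch.long)
labels[5:8, 3:6] = 1   # object 1
labels[5:8, 8:11] = 2  # object 2, separated from object 1 by a 2-pixel gap
wmap = unet_weight_map(labels)
print(wmap[6, 6].item(), wmap[19, 19].item())  # high in the gap, about 1 far away
```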

In a future blog, I will write about the advancements over U-Net and its variations, such as Recurrent U-Net and Attention U-Net.

More Reading References
1. Unet Research Paper — U-Net: Convolutional Networks for Biomedical Image Segmentation
2. Transposed Convolution
3. Softmax
4. Cross-Entropy Loss Function

Thanks for reading and Happy Coding!!


Avid Learner | Obsessed with ML and DL | Areas of Interest are CV, NLP, and Graph DL | Mobile Developer in Flutter| FullStack Developer | Gamer at Heart