Batch Normalization Dynamically Normalizes Each Feature to Have Zero Mean and Unit Variance
Basic idea: Normalize the input batch of each layer during the forward pass
Input is minibatch of data \(X^t \in \mathbb{R}^{m \times d}\) at iteration \(t\)
Compute mean and standard deviation for every feature \[ \mu_j^t = \mathbb{E}[x_j^t], \quad \sigma_j^t = \sqrt{\mathbb{E}[(x_j^t - \mu_j^t)^2]}, \quad \forall j \in \{1, \dots, d\} \]
Normalize each feature (note different for every batch) \[ \tilde{x}_{i,j}^t = \frac{x_{i,j}^t - \mu_j^t}{\sigma_j^t} \]
Output \(\tilde{X}^t\)
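A minimal sketch of this normalization with plain tensor ops (PyTorch here, with made-up sizes \(m = 8\), \(d = 4\)):

```python
import torch

# Minibatch X^t of m = 8 samples with d = 4 features
torch.manual_seed(0)
X = 3 * torch.randn(8, 4) + 2

# Per-feature mean and standard deviation (one value per feature j)
mu = X.mean(dim=0)                    # shape (4,)
sigma = X.std(dim=0, unbiased=False)  # shape (4,)

# Normalize each feature of each sample
X_tilde = (X - mu) / sigma
```

After this step, each column of `X_tilde` has zero mean and unit variance, regardless of the scale of the inputs.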
Santurkar, S., Tsipras, D., Ilyas, A., & Madry, A. (2018). How does batch normalization help optimization?. In Advances in Neural Information Processing Systems (pp. 2483-2493).
Because BatchNorm Removes Linear Effects, Extra Linear Parameters Are Also Learned
The form of this final update is: \[ \tilde{x}_{i,j}^t = \frac{x_{i,j}^t - \mu_j^t}{\sigma_j^t} \cdot \gamma_j + \beta_j \]
Where \(\gamma_j\) and \(\beta_j\) are learnable parameters
While \(\mu_j^t\) and \(\sigma_j^t\) are computed from the minibatch
But how do we compute \(\mu_j^t\) and \(\sigma_j^t\) during test time (i.e., no minibatch)?
Use running average of mean and variance: \[ \mu_{run}^t = \lambda \mu_{run}^{t-1} + (1-\lambda)\mu_{batch}^t \]\[ {\sigma^2}_{run}^t = \lambda {\sigma^2}_{run}^{t-1} + (1-\lambda){\sigma^2}_{batch}^t \]
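A sketch of this exponential moving average, assuming a smoothing factor \(\lambda = 0.9\) (in PyTorch's BatchNorm layers, `momentum=0.1` plays the role of \(1 - \lambda\)):

```python
import torch

lam = 0.9  # smoothing factor; PyTorch's momentum=0.1 corresponds to (1 - lam)
running_mean = torch.zeros(3)
running_var = torch.ones(3)

torch.manual_seed(0)
for t in range(100):
    batch = 2 * torch.randn(64, 3) + 5.0  # batches drawn from N(5, 2^2)
    running_mean = lam * running_mean + (1 - lam) * batch.mean(dim=0)
    running_var = lam * running_var + (1 - lam) * batch.var(dim=0, unbiased=False)
```

Over many iterations the running statistics approach the true mean (5) and variance (4), and at test time these fixed values replace the per-batch statistics.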
For CNNs, the Channel Dimension Is Treated as a “Feature”
If the input minibatch tensor is \(X^t \in \mathbb{R}^{m \times c \times h \times w}\), then the channel dimension \(c\) is treated as a feature: \[ \mu_j^t = \mathbb{E}[x_j^t], \quad \sigma_j^t = \sqrt{\mathbb{E}[(x_j^t - \mu_j^t)^2]}, \quad \forall j \in \{1, \dots, c\} \]
Where the mean is taken over both the batch dimension \(m\) and the spatial dimensions \(h\) and \(w\).
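A sketch of the per-channel statistics (example sizes are made up): the mean and standard deviation reduce over dims \((m, h, w)\), leaving one value per channel.

```python
import torch

# Image minibatch: (m, c, h, w) = (16, 3, 8, 8)
torch.manual_seed(0)
X = torch.randn(16, 3, 8, 8)

# One mean/std per channel: reduce over the batch and spatial dims
mu = X.mean(dim=(0, 2, 3))                    # shape (3,)
sigma = X.std(dim=(0, 2, 3), unbiased=False)  # shape (3,)

# Reshape to (1, c, 1, 1) so broadcasting normalizes every pixel of each channel
X_tilde = (X - mu.reshape(1, 3, 1, 1)) / sigma.reshape(1, 3, 1, 1)
```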
BatchNorm Can Stabilize and Accelerate Training of Deep Models
To use in practice:
Only normalize batches during training (model.train())
Turn off after training (model.eval())
Uses running average of mean and variance
Surprisingly effective at stabilizing training, reducing training time, and producing better models
Not fully understood why it works
Santurkar, S., Tsipras, D., Ilyas, A., & Madry, A. (2018). How does batch normalization help optimization?. In Advances in Neural Information Processing Systems (pp. 2483-2493).
BatchNorm Demo: Let’s create and inspect a 2D batchnorm layer (i.e., for images)
# Demo of batchnorm
import torch
import torch.nn as nn

class BatchNormModel(nn.Module):
    def __init__(self, n_channels):
        super().__init__()
        self.bn = nn.BatchNorm2d(n_channels)

    def forward(self, x):
        x = self.bn(x)
        return x

n_channels = 3  # Each channel is treated as a "feature" for images
bn_model = nn.BatchNorm2d(n_channels)
list(bn_model.named_parameters())
Notice that there are weight and bias parameters for each channel.
BatchNorm’s behavior during training
def print_mean_std(A, label='unlabeled'):
    print(f'{label}: Mean and standard deviation across channels')
    print(torch.mean(A, dim=(0, 2, 3)))
    print(torch.std(A, dim=(0, 2, 3), unbiased=False))
    print()

torch.manual_seed(0)
bn_model.train()
batch1 = 2 * torch.randn((100, n_channels, 2, 2)) + torch.arange(n_channels).reshape(1, n_channels, 1, 1)  # (N, C, H, W)
batch2 = 3 * torch.randn((100, n_channels, 2, 2)) - 5  # (N, C, H, W)
out1 = bn_model(batch1)
out2 = bn_model(batch2)
print_mean_std(batch1, 'batch1')
print_mean_std(out1, 'out1')
print_mean_std(batch2, 'batch2')
print_mean_std(out2, 'out2')
batch1: Mean and standard deviation across channels
tensor([0.0107, 1.0870, 2.0128])
tensor([2.0200, 1.9704, 2.1094])
out1: Mean and standard deviation across channels
tensor([ 1.4901e-08, 6.8545e-09, -3.9041e-08], grad_fn=<MeanBackward1>)
tensor([1.0000, 1.0000, 1.0000], grad_fn=<StdBackward0>)
batch2: Mean and standard deviation across channels
tensor([-4.9791, -5.2417, -4.8956])
tensor([3.0027, 3.0281, 2.9813])
out2: Mean and standard deviation across channels
tensor([ 3.3081e-08, 7.1824e-08, -8.6427e-08], grad_fn=<MeanBackward1>)
tensor([1.0000, 1.0000, 1.0000], grad_fn=<StdBackward0>)
Notice that even though the distributions of the batches differ from each other and across channels, the output has been renormalized per channel to always have zero mean and unit variance.
What about BatchNorm’s behavior during test time?
Let’s simulate two simple batches during training and then apply the model at test time
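The demo code for this slide is not shown; a self-contained sketch that reproduces the same behavior is below (the batch means and sizes are assumptions, so the numbers will differ from the outputs that follow):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_channels = 3
bn_model = nn.BatchNorm2d(n_channels)

# Training passes accumulate the running mean/variance
bn_model.train()
batch1 = torch.randn(100, n_channels, 2, 2) + torch.arange(n_channels).reshape(1, n_channels, 1, 1)
batch2 = torch.randn(100, n_channels, 2, 2) + 5
bn_model(batch1)
bn_model(batch2)

# Test time: normalize with the running statistics, not the batch's own
bn_model.eval()
with torch.no_grad():
    out1 = bn_model(batch1)
    out2 = bn_model(batch2)

print(bn_model.running_mean)
print(torch.sqrt(bn_model.running_var))
```

Because the running statistics lag behind any single batch, the eval-mode outputs are no longer zero-mean, unit-variance per batch.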
Running mean and standard deviation
tensor([0.0987, 0.2405, 0.4342])
tensor([1.3707, 1.3690, 1.3793])
batch1: Mean and standard deviation across channels
tensor([0.0054, 1.0435, 2.0064])
tensor([1.0100, 0.9852, 1.0547])
out1: Mean and standard deviation across channels
tensor([-0.0681, 0.5865, 1.1398], grad_fn=<MeanBackward1>)
tensor([0.7368, 0.7197, 0.7647], grad_fn=<StdBackward0>)
batch2: Mean and standard deviation across channels
tensor([5.0070, 4.9194, 5.0348])
tensor([1.0009, 1.0094, 0.9938])
out2: Mean and standard deviation across channels
tensor([3.5808, 3.4178, 3.3355], grad_fn=<MeanBackward1>)
tensor([0.7302, 0.7373, 0.7205], grad_fn=<StdBackward0>)
Notice that the running mean and running standard deviation are used for normalization during test time rather than the batch.
Thus, it is important to set model.eval() or model.train() when running models with BatchNorm or other specialized layers.
Generally, it is just good practice to do this no matter what during training and testing.
Residual Networks Add the Input to the Output of the CNN
Most deep model layers have the form: \[ y = f(x) \]
Where \(f\) could be any function including a convolutional layer like \(f(x) = \sigma(\text{Conv}(\sigma(\text{Conv}(x))))\)
Residual layers add back in the input: \[ y = f(x) + x \]
Notice that \(f(x)\) models the difference between \(x\) and \(y\) (hence the name residual).
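A minimal residual layer in PyTorch, following the form above with \(f(x) = \sigma(\text{Conv}(\sigma(\text{Conv}(x))))\) (the `ResidualBlock` class and channel count are illustrative, not from the ResNet paper):

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """y = f(x) + x, with f(x) = relu(conv(relu(conv(x))))."""
    def __init__(self, channels):
        super().__init__()
        # padding=1 keeps the spatial size so x and f(x) can be added
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        f_x = self.relu(self.conv2(self.relu(self.conv1(x))))  # the residual f(x)
        return f_x + x                                         # add the input back in

x = torch.randn(2, 16, 8, 8)
y = ResidualBlock(16)(x)  # same shape as the input
```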
He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778).
A Residual Network Enables Deeper Networks Because Gradient Information Can Flow Between Layers
A data flow diagram shows the “shortcut” connections.
Consider composing 2 residual layers:
\(z^{(1)} = f_1(x) + x\)
\(z^{(2)} = f_2(z^{(1)}) + z^{(1)}\)
Or, equivalently:
\(z^{(2)} = f_2(f_1(x)+x) + f_1(x) + x\)
If the residuals \(= 0\), then this is merely the identity function.
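A quick numerical check of the identity claim, with zero-initialized linear maps standing in for \(f_1\) and \(f_2\) (so both residuals are exactly 0):

```python
import torch
import torch.nn as nn

# Residual functions with all-zero weights and biases, so f1(x) = f2(x) = 0
f1 = nn.Linear(4, 4)
f2 = nn.Linear(4, 4)
for f in (f1, f2):
    nn.init.zeros_(f.weight)
    nn.init.zeros_(f.bias)

x = torch.randn(5, 4)
z1 = f1(x) + x    # first residual layer
z2 = f2(z1) + z1  # second residual layer: equals x when the residuals are 0
```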
Images from: He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778).
Detail: If the Dimensionality Is Not the Same, Then Use Either a Fully Connected Layer or a Convolution Layer to Match
In the 1D case, suppose \(f(x): \mathbb{R}^d \to \mathbb{R}^m\); then we need to multiply \(x\) by a linear operator to match the dimension: \[ y = f(x) + Wx, \quad \text{where } W \in \mathbb{R}^{m \times d} \]
Similarly, for images, if \(f(x): \mathbb{R}^{c \times h \times w} \to \mathbb{R}^{c' \times h' \times w'}\), we can apply a convolution layer to match the dimensions: \[ y = f(x) + \text{conv}(x), \quad \text{where conv}(\cdot): \mathbb{R}^{c \times h \times w} \to \mathbb{R}^{c' \times h' \times w'} \]
Residual Network Demo: Very simple residual network in PyTorch
(See https://towardsdatascience.com/residual-network-implementing-resnet-a7da63c7b278 for a tutorial on the real ResNet architectures from https://arxiv.org/abs/1512.03385)
The code below simply loads the CIFAR10 dataset as before.
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy of the non-residual CNN on the 10000 test images: %d%%' % (53))
print('Accuracy of the network on the 10000 test images: %d%%' % (100 * correct / total))
Accuracy of the non-residual CNN on the 10000 test images: 53 %
Accuracy of the network on the 10000 test images: 61 %
U-Nets Have an Autoencoder Structure With Skip Connections for Semantic Segmentation Tasks
Concatenation + convolution rather than residual skip connections
Any (pretrained) classification backbone can be used for encoder
State-of-the-art semantic segmentation models are based on this idea
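A sketch of one U-Net-style skip connection (the feature-map sizes are assumptions): instead of adding, the decoder features are concatenated with the encoder features along the channel dimension and then merged by a convolution.

```python
import torch
import torch.nn as nn

# Encoder feature map and upsampled decoder feature map at the same resolution
enc = torch.randn(1, 64, 32, 32)
dec = torch.randn(1, 64, 32, 32)

# U-Net skip: concatenate along channels, then convolve to merge
merged = torch.cat([dec, enc], dim=1)  # (1, 128, 32, 32)
fuse = nn.Conv2d(128, 64, kernel_size=3, padding=1)
out = fuse(merged)                     # back to (1, 64, 32, 32)
```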
Figure from: Ronneberger, O., Fischer, P., & Brox, T. (2015, October). U-net: Convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention (pp. 234-241). Springer, Cham.