pytorch学习笔记-高阶篇(batchnorm)
本篇主要记录一下一些卷积神经网络中batch norm相关知识
一、batch Norm引入
对于sigmoid函数,在有效范围外就会梯度接近0,出现梯度离散的情况,数据长时间得不到更新,这不是我们所希望的,因此就需要把输入值控制在有效范围内,对此,引入batch norm把输入映射到希望的范围内。
二、feature Scaling
对于一个普通的RGB三通道的图片:经过适当的各通道的标准化,可以使得三个通道接下来对卷积层等的作用效果几乎等价,不至于R通道要改变很大时而G通道改变很小时的效果差不多。可以上梯度下降的过程更加平滑
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
```
### 三、batch Norm

  对于batch norm,比如[6, 3, 28, 28]这样shape的数据,batch norm会把batch抽出来,抽出一个shape为[3]的tensor,里面的数据是batch和 28*28所得的一个均值和方差,然后得到些运行时(running)的均值和方差。
  对于该图,在运行时,均值和方差是实时出来的,然后会生成一个运行时的均值,方差来记录。对于γ和β则是在反向传播的时候更新。在test时,由于不需要更新,是没有这两个参数的,对于均值和方差也是直接用的running_mean和running_var。因此在**test**时注意切换模式:
``` python
layer.eval()
# 2.batch norm
x = torch.rand(100, 16, 28, 28)
layer = nn.BatchNorm2d(num_features=16)
out = layer(x)
layer.running_mean
'''
Out[28]:
tensor([0.0500, 0.0500, 0.0500, 0.0500, 0.0499, 0.0500, 0.0498, 0.0500, 0.0501,
0.0500, 0.0501, 0.0500, 0.0502, 0.0501, 0.0502, 0.0498])
'''
layer.running_var
'''
Out[29]:
tensor([0.9083, 0.9083, 0.9084, 0.9083, 0.9083, 0.9083, 0.9082, 0.9083, 0.9084,
0.9083, 0.9083, 0.9083, 0.9084, 0.9083, 0.9083, 0.9083])
'''
# 3.batch norm2d
x = torch.rand(1, 16, 7, 7)
layer = nn.BatchNorm2d(16)
out = layer(x)
# Out[31]: torch.Size([1, 16, 7, 7])
layer.weight
'''
Out[32]:
Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
requires_grad=True)
'''
# 这里weight bias对于上面某个图中的γ和β
layer.weight.shape
# Out[33]: torch.Size([16])
layer.bias.shape
# Out[34]: torch.Size([16])
vars(layer)
'''
Out[35]:
{'training': True,
'_parameters': OrderedDict([('weight', Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
requires_grad=True)),
('bias',
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
requires_grad=True))]),
'_buffers': OrderedDict([('running_mean',
tensor([0.0505, 0.0518, 0.0501, 0.0503, 0.0536, 0.0503, 0.0455, 0.0451, 0.0524,
0.0529, 0.0487, 0.0470, 0.0385, 0.0565, 0.0534, 0.0458])),
('running_var',
tensor([0.9066, 0.9077, 0.9085, 0.9084, 0.9076, 0.9083, 0.9079, 0.9071, 0.9084,
0.9089, 0.9094, 0.9068, 0.9084, 0.9087, 0.9086, 0.9070])),
('num_batches_tracked', tensor(1))]),
'_non_persistent_buffers_set': set(),
'_backward_hooks': OrderedDict(),
'_is_full_backward_hook': None,
'_forward_hooks': OrderedDict(),
'_forward_pre_hooks': OrderedDict(),
'_state_dict_hooks': OrderedDict(),
'_load_state_dict_pre_hooks': OrderedDict(),
'_modules': OrderedDict(),
'num_features': 16,
'eps': 1e-05,
'momentum': 0.1,
'affine': True,
'track_running_stats': True}
'''
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 不听话的兔子君!