Study Notes | PyTorch Tutorial 25 (Batch Normalization)


These notes are mainly summarized from the "深度之眼" (DeepShare) course, as a quick reference.
The PyTorch version used is 1.2.

  • The concept of Batch Normalization
  • PyTorch's Batch Normalization 1d/2d/3d implementations

I. The Concept of Batch Normalization

Batch Normalization: batch standardization
Batch: a batch of data, usually a mini-batch
Standardization: zero mean, unit variance (see the small sketch below)
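A minimal sketch of what this standardization means, using plain tensor ops on a made-up mini-batch (this is not the nn.BatchNorm layers used later, just the idea):

import torch

x = torch.randn(16, 4) * 3 + 5                                                   # mini-batch: 16 samples, 4 features
x_hat = (x - x.mean(dim=0)) / torch.sqrt(x.var(dim=0, unbiased=False) + 1e-5)    # per-feature standardization
print(x_hat.mean(dim=0))   # ~0 for every feature
print(x_hat.std(dim=0))    # ~1 for every feature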
Advantages:

  • 1. A larger learning rate can be used, speeding up convergence
  • 2. Weight initialization no longer needs to be designed so carefully
  • 3. Dropout can be removed or reduced
  • 4. L2 regularization / weight decay can be removed or reduced
  • 5. LRN (local response normalization) is no longer needed
  • Reference: 《Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift》

Test code:

import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # set the random seed


class MLP(nn.Module):
    def __init__(self, neural_num, layers=100):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
        self.bns = nn.ModuleList([nn.BatchNorm1d(neural_num) for i in range(layers)])
        self.neural_num = neural_num

    def forward(self, x):

        for (i, linear), bn in zip(enumerate(self.linears), self.bns):
            x = linear(x)
            # x = bn(x)
            x = torch.relu(x)

            if torch.isnan(x.std()):
                print("output is nan in {} layers".format(i))
                break

            print("layers:{}, std:{}".format(i, x.std().item()))

        return x

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):

                # method 1
                nn.init.normal_(m.weight.data, std=1)    # normal: mean=0, std=1

                # method 2 kaiming
                # nn.init.kaiming_normal_(m.weight.data)


neural_nums = 256
layer_nums = 100
batch_size = 16

net = MLP(neural_nums, layer_nums)
# net.initialize()

inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1

output = net(inputs)
print(output)

Output:

layers:0, std:0.3342404067516327
layers:1, std:0.13787388801574707
layers:2, std:0.05783054977655411
layers:3, std:0.02498556487262249
layers:4, std:0.009679116308689117
layers:5, std:0.0040797945111989975
layers:6, std:0.0016723505686968565
layers:7, std:0.000768698868341744
......
layers:93, std:7.51512610937515e-38
layers:94, std:2.6169094678434883e-38
layers:95, std:1.1516209894049713e-38
layers:96, std:4.344910860036386e-39
layers:97, std:1.5943525511579185e-39
layers:98, std:5.721221370145363e-40
layers:99, std:2.4877251637158477e-40
tensor([[0.0000e+00, 2.1158e-41, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 5.1800e-41, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 5.8066e-41, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]], grad_fn=<ReluBackward0>)

We can see that by layer 100 the layer outputs have become extremely small.
Now enable the initialization: net.initialize()
Output:

layers:0, std:9.35224723815918
layers:1, std:112.47123718261719
layers:2, std:1322.805419921875
layers:3, std:14569.419921875
layers:4, std:154672.703125
layers:5, std:1834037.125
layers:6, std:18807968.0
layers:7, std:209552880.0
......
layers:28, std:3.221297392084588e+30
layers:29, std:3.530939139138446e+31
layers:30, std:4.525336236359181e+32
layers:31, std:4.714992054712809e+33
layers:32, std:5.369568386632447e+34
layers:33, std:6.712290740934239e+35
layers:34, std:7.451081630611702e+36
output is nan in 35 layers
tensor([[3.2625e+36, 0.0000e+00, 7.2931e+37,  ..., 0.0000e+00, 0.0000e+00,
         2.5465e+38],
        [3.9236e+36, 0.0000e+00, 7.5033e+37,  ..., 0.0000e+00, 0.0000e+00,
         2.1274e+38],
        [0.0000e+00, 0.0000e+00, 4.4931e+37,  ..., 0.0000e+00, 0.0000e+00,
         1.7016e+38],
        ...,
        [0.0000e+00, 0.0000e+00, 2.4222e+37,  ..., 0.0000e+00, 0.0000e+00,
         2.5295e+38],
        [4.7380e+37, 0.0000e+00, 2.1579e+37,  ..., 0.0000e+00, 0.0000e+00,
         2.6028e+38],
        [0.0000e+00, 0.0000e+00, 6.0877e+37,  ..., 0.0000e+00, 0.0000e+00,
         2.1695e+38]], grad_fn=<ReluBackward0>)

The network already produces nan at layer 35.
Now use Kaiming initialization: nn.init.kaiming_normal_(m.weight.data)
Output:

layers:0, std:0.826629638671875
layers:1, std:0.878681480884552
layers:2, std:0.9134420156478882
layers:3, std:0.8892467617988586
layers:4, std:0.8344276547431946
layers:5, std:0.87453693151474
layers:6, std:0.792696475982666
layers:7, std:0.7806451916694641
......
layers:92, std:0.6094536185264587
layers:93, std:0.6019036173820496
layers:94, std:0.595414936542511
layers:95, std:0.6624482870101929
layers:96, std:0.6377813220024109
layers:97, std:0.6079217195510864
layers:98, std:0.6579239368438721
layers:99, std:0.6668398976325989
tensor([[0.0000, 1.3437, 0.0000,  ..., 0.0000, 0.6444, 1.1867],
        [0.0000, 0.9757, 0.0000,  ..., 0.0000, 0.4645, 0.8594],
        [0.0000, 1.0023, 0.0000,  ..., 0.0000, 0.5147, 0.9196],
        ...,
        [0.0000, 1.2873, 0.0000,  ..., 0.0000, 0.6454, 1.1411],
        [0.0000, 1.3588, 0.0000,  ..., 0.0000, 0.6749, 1.2437],
        [0.0000, 1.1807, 0.0000,  ..., 0.0000, 0.5668, 1.0600]],
       grad_fn=<ReluBackward0>)

The std still fluctuates somewhat from layer to layer. Now add the BN layer: x = bn(x)
Output:

layers:0, std:0.5872595906257629
layers:1, std:0.579325795173645
layers:2, std:0.5757012367248535
layers:3, std:0.5840616822242737
layers:4, std:0.5781518220901489
layers:5, std:0.5856173634529114
layers:6, std:0.5862171053886414
......
layers:95, std:0.5735476016998291
layers:96, std:0.5807774662971497
layers:97, std:0.5868753790855408
layers:98, std:0.5801646113395691
layers:99, std:0.5738694667816162
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.4841, 0.0000],
        [1.5034, 0.0000, 0.2277,  ..., 0.3768, 0.0000, 0.0000],
        [0.9003, 0.0000, 1.7231,  ..., 0.0000, 0.0000, 1.1034],
        ...,
        [0.0000, 0.0000, 0.6059,  ..., 0.0000, 0.0000, 0.0000],
        [0.7283, 0.6607, 0.4622,  ..., 0.0000, 0.0000, 0.0000],
        [0.0331, 0.0000, 1.0855,  ..., 1.2032, 0.0000, 0.3746]],
       grad_fn=<ReluBackward0>)

The data scale is kept very stable. Now drop the initialization entirely: # net.initialize()
Output:

layers:0, std:0.5751240849494934
layers:1, std:0.5803307890892029
layers:2, std:0.5825020670890808
layers:3, std:0.5823132395744324
......
layers:97, std:0.5814812183380127
layers:98, std:0.5802980661392212
layers:99, std:0.5824452638626099
tensor([[2.4655, 0.3893, 0.0000,  ..., 1.9130, 0.7964, 0.7588],
        [0.3542, 0.1579, 2.3155,  ..., 0.0500, 0.2595, 0.0000],
        [0.0000, 0.0000, 0.2838,  ..., 0.0000, 0.9119, 0.2732],
        ...,
        [0.0000, 1.5330, 0.0000,  ..., 0.1120, 0.0000, 1.9477],
        [0.0000, 0.0000, 2.0451,  ..., 0.0000, 0.0000, 0.0000],
        [0.5085, 0.8023, 0.3493,  ..., 0.2117, 0.0000, 0.0000]],
       grad_fn=<ReluBackward0>)

The data scale is still kept well even without explicit weight initialization.
Note: the BN layer should be placed before the activation function (see the small sketch below).
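A minimal ordering sketch (the layer sizes here are made up): BN sits between the linear/conv layer and the activation.

import torch.nn as nn

block = nn.Sequential(
    nn.Linear(256, 256, bias=False),   # bias can be dropped: BN adds its own shift (beta)
    nn.BatchNorm1d(256),               # normalize before the non-linearity
    nn.ReLU()
)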
Build a LeNet with BN layers:

import torch.nn as nn
import torch.nn.functional as F


class LeNet_bn(nn.Module):
    def __init__(self, classes):
        super(LeNet_bn, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.bn1 = nn.BatchNorm2d(num_features=6)

        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(num_features=16)

        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn3 = nn.BatchNorm1d(num_features=120)

        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = F.max_pool2d(out, 2)

        out = self.conv2(out)
        out = self.bn2(out)
        out = F.relu(out)

        out = F.max_pool2d(out, 2)

        out = out.view(out.size(0), -1)

        out = self.fc1(out)
        out = self.bn3(out)
        out = F.relu(out)

        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data, 0, 1)
                m.bias.data.zero_()

Next, test how the BN layer behaves in a full training run.
The complete training code is from: Study Notes | PyTorch Tutorial 05 (Dataloader and Dataset)
First train without the BN layers and without weight initialization:

# ============================ step 2/5 model ============================

#net = LeNet_bn(classes=2)
net = LeNet(classes=2)
# net.initialize_weights()

Output:

Training:Epoch[000/010] Iteration[010/010] Loss: 0.6966 Acc:50.00%
Valid:   Epoch[000/010] Iteration[002/002] Loss: 1.3483 Acc:50.00%
Training:Epoch[001/010] Iteration[010/010] Loss: 0.6888 Acc:53.75%
Valid:   Epoch[001/010] Iteration[002/002] Loss: 1.3469 Acc:53.75%
Training:Epoch[002/010] Iteration[010/010] Loss: 0.6822 Acc:60.62%
Valid:   Epoch[002/010] Iteration[002/002] Loss: 1.3270 Acc:60.62%
Training:Epoch[003/010] Iteration[010/010] Loss: 0.6739 Acc:81.25%
Valid:   Epoch[003/010] Iteration[002/002] Loss: 1.2961 Acc:81.25%
Training:Epoch[004/010] Iteration[010/010] Loss: 0.6466 Acc:83.75%
Valid:   Epoch[004/010] Iteration[002/002] Loss: 1.1401 Acc:83.75%
Training:Epoch[005/010] Iteration[010/010] Loss: 0.5422 Acc:95.62%
Valid:   Epoch[005/010] Iteration[002/002] Loss: 0.6329 Acc:95.62%
Training:Epoch[006/010] Iteration[010/010] Loss: 0.2208 Acc:96.88%
Valid:   Epoch[006/010] Iteration[002/002] Loss: 0.0163 Acc:96.88%
Training:Epoch[007/010] Iteration[010/010] Loss: 0.1321 Acc:95.62%
Valid:   Epoch[007/010] Iteration[002/002] Loss: 0.0006 Acc:95.62%
Training:Epoch[008/010] Iteration[010/010] Loss: 0.2649 Acc:93.75%
Valid:   Epoch[008/010] Iteration[002/002] Loss: 1.2047 Acc:93.75%
Training:Epoch[009/010] Iteration[010/010] Loss: 0.4774 Acc:87.50%
Valid:   Epoch[009/010] Iteration[002/002] Loss: 0.4023 Acc:87.50%

Now enable initialization: net.initialize_weights()
Output:

Training:Epoch[000/010] Iteration[010/010] Loss: 0.6846 Acc:53.75%
Valid:   Epoch[000/010] Iteration[002/002] Loss: 0.9805 Acc:53.75%
Training:Epoch[001/010] Iteration[010/010] Loss: 0.4099 Acc:85.00%
Valid:   Epoch[001/010] Iteration[002/002] Loss: 0.0829 Acc:85.00%
Training:Epoch[002/010] Iteration[010/010] Loss: 0.1470 Acc:94.38%
Valid:   Epoch[002/010] Iteration[002/002] Loss: 0.0035 Acc:94.38%
Training:Epoch[003/010] Iteration[010/010] Loss: 0.4276 Acc:88.12%
Valid:   Epoch[003/010] Iteration[002/002] Loss: 0.2250 Acc:88.12%
Training:Epoch[004/010] Iteration[010/010] Loss: 0.3169 Acc:87.50%
Valid:   Epoch[004/010] Iteration[002/002] Loss: 0.1232 Acc:87.50%
Training:Epoch[005/010] Iteration[010/010] Loss: 0.2026 Acc:91.88%
Valid:   Epoch[005/010] Iteration[002/002] Loss: 0.0132 Acc:91.88%
Training:Epoch[006/010] Iteration[010/010] Loss: 0.1064 Acc:95.62%
Valid:   Epoch[006/010] Iteration[002/002] Loss: 0.0002 Acc:95.62%
Training:Epoch[007/010] Iteration[010/010] Loss: 0.0482 Acc:99.38%
Valid:   Epoch[007/010] Iteration[002/002] Loss: 0.0006 Acc:99.38%
Training:Epoch[008/010] Iteration[010/010] Loss: 0.0069 Acc:100.00%
Valid:   Epoch[008/010] Iteration[002/002] Loss: 0.0000 Acc:100.00%
Training:Epoch[009/010] Iteration[010/010] Loss: 0.0133 Acc:99.38%
Valid:   Epoch[009/010] Iteration[002/002] Loss: 0.0000 Acc:99.38%

Next, train with the BN version of the network: net = LeNet_bn(classes=2)
Output:

Training:Epoch[000/010] Iteration[010/010] Loss: 0.6666 Acc:60.00%
Valid:   Epoch[000/010] Iteration[002/002] Loss: 1.2814 Acc:60.00%
Training:Epoch[001/010] Iteration[010/010] Loss: 0.4274 Acc:93.12%
Valid:   Epoch[001/010] Iteration[002/002] Loss: 0.4916 Acc:93.12%
Training:Epoch[002/010] Iteration[010/010] Loss: 0.1601 Acc:98.75%
Valid:   Epoch[002/010] Iteration[002/002] Loss: 0.0878 Acc:98.75%
Training:Epoch[003/010] Iteration[010/010] Loss: 0.0688 Acc:100.00%
Valid:   Epoch[003/010] Iteration[002/002] Loss: 0.0104 Acc:100.00%
Training:Epoch[004/010] Iteration[010/010] Loss: 0.0406 Acc:98.75%
Valid:   Epoch[004/010] Iteration[002/002] Loss: 0.0109 Acc:98.75%
Training:Epoch[005/010] Iteration[010/010] Loss: 0.0895 Acc:95.62%
Valid:   Epoch[005/010] Iteration[002/002] Loss: 0.0061 Acc:95.62%
Training:Epoch[006/010] Iteration[010/010] Loss: 0.0765 Acc:95.62%
Valid:   Epoch[006/010] Iteration[002/002] Loss: 0.0675 Acc:95.62%
Training:Epoch[007/010] Iteration[010/010] Loss: 0.0370 Acc:98.75%
Valid:   Epoch[007/010] Iteration[002/002] Loss: 0.0069 Acc:98.75%
Training:Epoch[008/010] Iteration[010/010] Loss: 0.0144 Acc:100.00%
Valid:   Epoch[008/010] Iteration[002/002] Loss: 0.0028 Acc:100.00%
Training:Epoch[009/010] Iteration[010/010] Loss: 0.0365 Acc:98.75%
Valid:   Epoch[009/010] Iteration[002/002] Loss: 0.0015 Acc:98.75%


II. PyTorch's Batch Normalization 1d/2d/3d Implementations

1. _BatchNorm

  • nn.BatchNorm1d
  • nn.BatchNorm2d
  • nn.BatchNorm3d
    Parameters:
  • num_features: the number of features (channels) in a sample (the most important one)
  • eps: a small term added to the denominator for numerical stability
  • momentum: coefficient of the exponentially weighted average used to estimate the current mean/var
  • affine: whether to apply the learnable affine transform
  • track_running_stats: training state vs. testing state (whether running statistics are tracked and updated)

$$
\begin{aligned}
\widehat{x}_{i} &\leftarrow \frac{x_{i}-\mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2}+\epsilon}} \\
y_{i} &\leftarrow \gamma \widehat{x}_{i}+\beta \equiv \mathrm{BN}_{\gamma,\beta}\left(x_{i}\right)
\end{aligned}
$$

  • Training: the running mean and variance are updated with an exponentially weighted average
  • Testing: the tracked running statistics are used as-is (see the small sketch below)
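A small sketch of this train/test difference (the shapes and momentum here are made up): in train() mode BN normalizes with the batch statistics and updates running_mean/running_var by the exponentially weighted average; in eval() mode it normalizes with the stored running statistics and leaves them untouched.

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(num_features=3, momentum=0.1)
x = torch.randn(8, 3)

bn.train()
_ = bn(x)                  # normalizes with the batch mean/var and updates the running stats
print(bn.running_mean)     # has moved toward the batch mean

bn.eval()
_ = bn(x)                  # normalizes with running_mean/running_var; they are not updated
print(bn.running_mean)     # unchanged by the eval-mode forward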

Main attributes:

  • running_mean: the running mean
  • running_var: the running variance
  • weight: gamma in the affine transform
  • bias: beta in the affine transform

1. nn.BatchNorm1d
Test code:

import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed



set_seed(1)  # set the random seed

# ======================================== nn.BatchNorm1d
flag = 1
# flag = 0
if flag:

    batch_size = 3
    num_features = 5
    momentum = 0.3

    features_shape = (1)      # note: (1) is just the int 1, i.e. a single-element feature map

    feature_map = torch.ones(features_shape)                                                    # 1D
    feature_maps = torch.stack([feature_map*(i+1) for i in range(num_features)], dim=0)         # 2D
    feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)             # 3D
    print("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))

    bn = nn.BatchNorm1d(num_features=num_features, momentum=momentum)

    running_mean, running_var = 0, 1    # manually track the 2nd feature's running stats (BN initializes them to 0 and 1)

    for i in range(2):
        outputs = bn(feature_maps_bs)

        print("\niteration:{}, running mean: {} ".format(i, bn.running_mean))
        print("iteration:{}, running var:{} ".format(i, bn.running_var))

        mean_t, var_t = 2, 0     # batch mean/var of the 2nd feature: all of its values are 2, so the variance is 0

        running_mean = (1 - momentum) * running_mean + momentum * mean_t
        running_var = (1 - momentum) * running_var + momentum * var_t

        print("iteration:{}, 第二个特征的running mean: {} ".format(i, running_mean))
        print("iteration:{}, 第二个特征的running var:{}".format(i, running_var))

Output:

input data:
tensor([[[1.],
         [2.],
         [3.],
         [4.],
         [5.]],

        [[1.],
         [2.],
         [3.],
         [4.],
         [5.]],

        [[1.],
         [2.],
         [3.],
         [4.],
         [5.]]]) shape is torch.Size([3, 5, 1])

iteration:0, running mean: tensor([0.3000, 0.6000, 0.9000, 1.2000, 1.5000]) 
iteration:0, running var:tensor([0.7000, 0.7000, 0.7000, 0.7000, 0.7000]) 
iteration:0, running mean of the 2nd feature: 0.6 
iteration:0, running var of the 2nd feature: 0.7

iteration:1, running mean: tensor([0.5100, 1.0200, 1.5300, 2.0400, 2.5500]) 
iteration:1, running var:tensor([0.4900, 0.4900, 0.4900, 0.4900, 0.4900]) 
iteration:1, running mean of the 2nd feature: 1.02 
iteration:1, running var of the 2nd feature: 0.48999999999999994
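These numbers can be checked by hand. The 2nd feature is always 2, so its batch mean is 2 and its batch variance is 0; with momentum = 0.3 and initial values running_mean = 0, running_var = 1:

iteration 0: running_mean = (1 - 0.3) * 0   + 0.3 * 2 = 0.6,   running_var = (1 - 0.3) * 1   + 0.3 * 0 = 0.7
iteration 1: running_mean = (1 - 0.3) * 0.6 + 0.3 * 2 = 1.02,  running_var = (1 - 0.3) * 0.7 + 0.3 * 0 = 0.49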

2. nn.BatchNorm2d
Test code:

# ======================================== nn.BatchNorm2d
flag = 1
# flag = 0
if flag:

    batch_size = 3
    num_features = 6
    momentum = 0.3
    
    features_shape = (2, 2)

    feature_map = torch.ones(features_shape)                                                    # 2D
    feature_maps = torch.stack([feature_map*(i+1) for i in range(num_features)], dim=0)         # 3D
    feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)             # 4D

    print("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))

    bn = nn.BatchNorm2d(num_features=num_features, momentum=momentum)

    running_mean, running_var = 0, 1

    for i in range(2):
        outputs = bn(feature_maps_bs)

        print("\niter:{}, running_mean.shape: {}".format(i, bn.running_mean.shape))
        print("iter:{}, running_var.shape: {}".format(i, bn.running_var.shape))

        print("iter:{}, weight.shape: {}".format(i, bn.weight.shape))
        print("iter:{}, bias.shape: {}".format(i, bn.bias.shape))

Output:

input data:
tensor([[[[1., 1.],
          [1., 1.]],

         [[2., 2.],
          [2., 2.]],

         [[3., 3.],
          [3., 3.]],

         [[4., 4.],
          [4., 4.]],

         [[5., 5.],
          [5., 5.]],

         [[6., 6.],
          [6., 6.]]],


        [[[1., 1.],
          [1., 1.]],

         [[2., 2.],
          [2., 2.]],

         [[3., 3.],
          [3., 3.]],

         [[4., 4.],
          [4., 4.]],

         [[5., 5.],
          [5., 5.]],

         [[6., 6.],
          [6., 6.]]],


        [[[1., 1.],
          [1., 1.]],

         [[2., 2.],
          [2., 2.]],

         [[3., 3.],
          [3., 3.]],

         [[4., 4.],
          [4., 4.]],

         [[5., 5.],
          [5., 5.]],

         [[6., 6.],
          [6., 6.]]]]) shape is torch.Size([3, 6, 2, 2])

iter:0, running_mean.shape: torch.Size([6])
iter:0, running_var.shape: torch.Size([6])
iter:0, weight.shape: torch.Size([6])
iter:0, bias.shape: torch.Size([6])

iter:1, running_mean.shape: torch.Size([6])
iter:1, running_var.shape: torch.Size([6])
iter:1, weight.shape: torch.Size([6])
iter:1, bias.shape: torch.Size([6])
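All four attributes have shape torch.Size([6]) because nn.BatchNorm2d keeps one mean/var and one gamma/beta per channel: for an input of shape (B, C, H, W), the statistics of channel c are computed over the B × H × W values of that channel. As a hand check (the code above only prints shapes, not values), every entry of channel c equals c + 1 here, so after the first forward pass running_mean should be 0.3 * [1, 2, 3, 4, 5, 6] = [0.3, 0.6, 0.9, 1.2, 1.5, 1.8].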

3. nn.BatchNorm3d
Test code:

# ======================================== nn.BatchNorm3d
flag = 1
# flag = 0
if flag:

    batch_size = 3
    num_features = 4
    momentum = 0.3

    features_shape = (2, 2, 3)

    feature = torch.ones(features_shape)                                                # 3D
    feature_map = torch.stack([feature * (i + 1) for i in range(num_features)], dim=0)  # 4D
    feature_maps = torch.stack([feature_map for i in range(batch_size)], dim=0)         # 5D

    print("input data:\n{} shape is {}".format(feature_maps, feature_maps.shape))

    bn = nn.BatchNorm3d(num_features=num_features, momentum=momentum)

    running_mean, running_var = 0, 1

    for i in range(2):
        outputs = bn(feature_maps)

        print("\niter:{}, running_mean.shape: {}".format(i, bn.running_mean.shape))
        print("iter:{}, running_var.shape: {}".format(i, bn.running_var.shape))

        print("iter:{}, weight.shape: {}".format(i, bn.weight.shape))
        print("iter:{}, bias.shape: {}".format(i, bn.bias.shape))

Output:

input data:
tensor([[[[[1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.]]],


         [[[2., 2., 2.],
           [2., 2., 2.]],

          [[2., 2., 2.],
           [2., 2., 2.]]],


         [[[3., 3., 3.],
           [3., 3., 3.]],

          [[3., 3., 3.],
           [3., 3., 3.]]],


         [[[4., 4., 4.],
           [4., 4., 4.]],

          [[4., 4., 4.],
           [4., 4., 4.]]]],



        [[[[1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.]]],


         [[[2., 2., 2.],
           [2., 2., 2.]],

          [[2., 2., 2.],
           [2., 2., 2.]]],


         [[[3., 3., 3.],
           [3., 3., 3.]],

          [[3., 3., 3.],
           [3., 3., 3.]]],


         [[[4., 4., 4.],
           [4., 4., 4.]],

          [[4., 4., 4.],
           [4., 4., 4.]]]],



        [[[[1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.]]],


         [[[2., 2., 2.],
           [2., 2., 2.]],

          [[2., 2., 2.],
           [2., 2., 2.]]],


         [[[3., 3., 3.],
           [3., 3., 3.]],

          [[3., 3., 3.],
           [3., 3., 3.]]],


         [[[4., 4., 4.],
           [4., 4., 4.]],

          [[4., 4., 4.],
           [4., 4., 4.]]]]]) shape is torch.Size([3, 4, 2, 2, 3])

iter:0, running_mean.shape: torch.Size([4])
iter:0, running_var.shape: torch.Size([4])
iter:0, weight.shape: torch.Size([4])
iter:0, bias.shape: torch.Size([4])

iter:1, running_mean.shape: torch.Size([4])
iter:1, running_var.shape: torch.Size([4])
iter:1, weight.shape: torch.Size([4])
iter:1, bias.shape: torch.Size([4])
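To summarize the three experiments, a minimal shape sketch (num_features = 4 is just an example): in every case the running statistics, weight and bias have shape (num_features,), and only the expected input layout changes.

import torch
import torch.nn as nn

out1 = nn.BatchNorm1d(4)(torch.randn(3, 4))           # input (N, C) or (N, C, L); stats of shape (C,)
out2 = nn.BatchNorm2d(4)(torch.randn(3, 4, 2, 2))     # input (N, C, H, W);        stats of shape (C,)
out3 = nn.BatchNorm3d(4)(torch.randn(3, 4, 2, 2, 3))  # input (N, C, D, H, W);     stats of shape (C,)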