欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Pytorch:Batch Normalization批标准化

程序员文章站 2022-07-04 22:59:07
...

原文地址

分类目录——Pytorch

  • 首先我觉得莫烦老师关于Batch Normalization解释很贴切,引用在这里

    Pytorch:Batch Normalization批标准化

    在神经网络中, 数据分布对训练会产生影响. 比如某个神经元 x 的值为1, 某个 Weights 的初始值为 0.1, 这样后一层神经元计算结果就是 Wx = 0.1; 又或者 x = 20, 这样 Wx 的结果就为 2. 现在还不能看出什么问题, 但是, 当我们加上一层激励函数, **这个 Wx 值的时候, 问题就来了. 如果使用 像 tanh 的激励函数, Wx 的**值就变成了 ~0.1 和 ~1, 接近于 1 的部分已经处在了 激励函数的饱和阶段, 也就是如果 x 无论再怎么扩大, tanh 激励函数输出值也还是 接近1. 换句话说, 神经网络在初始阶段已经不对那些比较大的 x 特征范围 敏感了. 这样很糟糕, 想象我轻轻拍自己的感觉和重重打自己的感觉居然没什么差别, 这就证明我的感官系统失效了. 当然我们是可以用之前提到的对数据做 normalization 预处理, 使得输入的 x 变化范围不会太大, 让输入值经过激励函数的敏感部分. 但刚刚这个不敏感问题不仅仅发生在神经网络的输入层, 而且在隐藏层中也经常会发生.

    引自 什么是批标准化 (Batch Normalization)

    我们知道在中间层也是可以用tahn**函数的,这时候就要用批标准化来处理了

  • 看一下对比效果

    Pytorch:Batch Normalization批标准化
  • 关键代码

    net_bn = torch.nn.Sequential(
        # torch.nn.BatchNorm1d(1),
        torch.nn.Linear(1, N_HIDDEN),
        torch.nn.BatchNorm1d(N_HIDDEN),		# 批标准化
        torch.nn.ReLU(),
        torch.nn.Linear(N_HIDDEN, N_HIDDEN),
        torch.nn.BatchNorm1d(N_HIDDEN),
        torch.nn.ReLU(),
        torch.nn.Linear(N_HIDDEN, 1),
    )
    
  • 一个批标准化与无批标准化的对比实例

  • 参考文献

    Batch Normalization 批标准化

相关标签: Python # Pytorch