BN

概述

Batch Normalization 是来自谷歌的 Ioffe 和 Szegedy 与 2015 年提出的一种提高神经网络训练速度和稳定性的方法。 在当时 VGG、Inception 等结构的出现似乎表明加深网络层数是提高网络精度的最可行的手段。 然而更深层的网络也更难训练。在实践中,我们需要非常小心的设置学习率,稍大的学习率容易导致网络训练发散,而较小的学习率会让网络收敛过慢。

这种困难一部分归因于深度网络带来的更大的梯度消失和梯度爆炸问题。 另一方面,作者猜想这种困难是一种称为 ICS(Internal Covariate Shift) 的现象导致的,即在训练过程中,每层的输入的分布会随前层参数的变化而变化,导致本层参数不断适应新的数据分布,进而导致了训练的难度。

而 BN 正是为了解决这个问题而提出的一种方法。在神经元激励送入下一层之前,先对其进行某种“白化”(归一化,解相关)操作,将输入的分布“固定”下来,降低其受前层参数 W 的影响。从而削弱 ICS 现象,提高训练的稳定性和网络收敛的速度。

时至今日,BN 也已经成为了深度网络的标配,卷积、ReLU、BN 级联组成的基本单元也称为网络的基本构成模块。

而关于 BN 的工作机理在目前也是一个积极的研究的方向。

BN 算法

高斯分布的标准化

在进入 BN 之前,我们先复习一点数学知识。

XX 是一个服从高斯分布 N(μ,σ2)\mathcal{N}(\mu, \sigma^2) 的随机变量, 对随机变量进行变换,则有 Xμσ\frac{X-\mu}{\sigma} 服从标准正态分布 N(0,1)\mathcal{N}(0, 1), 我们称这个操作为“标准化”。 在 BN 相关的文献中,这个操作也成为归一化。

如果我们知道 XX 服从某个参数的正态分布,但是不知道参数的具体值,我们可以采用统计推断的方法,对 XX 进行 mm 次独立的采样 {xi}i=1m\{x_i\}_{i=1}^m,并估计出参数

μ^=1mi=1m(xi)σ2^=1m1i=1m(xiμ^)\hat{\mu} = \frac{1}{m}\sum_{i=1}^m(x_i) \\ \hat{\sigma^2} = \frac{1}{m-1}\sum_{i=1}^m(x_i-\hat{\mu})

并基于估计的参数值对随机变量进行标准化操作,即 Xμ^σ^\frac{X-\hat{\mu}}{\hat{\sigma}} 近似服从标准正态分布。

对卷积层的输出进行标准化

在实践中,BN 通常应用于卷积层。 考虑一个 batch 的图像样本 B={x1,x2,...,xm}\mathcal{B} = \{x_1, x_2, ..., x_m\} ,经过卷积网络,在某一卷积层的输出为一系列特征图 {z1,...zm}\{z_1, ... z_m\}

这里每一个特征图都是一个三维的数组。 我们用 zi(h,w,c)z_{i}^{(h, w, c)} 表示第 ii 个特征图的第 cc 个通道,在空间位置 (h,w)(h, w) 的激活值,是个标量。

对于卷积层,我们希望 BN 也能与卷积相同,呈现位移不变性,即不同位置的神经元用相同的方式进行标准化。 因此在计算均值和方差时,我们将同一个 channel 中,来自不同样本、不同位置的神经元组织在一起计算统计量,再对应进行标准化操作,即

μB(c)=1m×H×Wi=1mh=1Hw=1Wzi(h,w,c)σB2(c)=1m×H×Wi=1mh=1Hw=1W(zi(h,w,c)μB(c))2z^i(h,w,c)=zi(h,w,c)μB(c)σB(c)\begin{align} \mu_\mathcal{B}^{(c)} &= \frac{1}{m\times H\times W}\sum_{i=1}^m\sum_{h=1}^H\sum_{w=1}^W { z_{i}^{(h, w, c)} } \\ \sigma_\mathcal{B}^{2(c)} &= \frac{1}{m\times H\times W}\sum_{i=1}^m\sum_{h=1}^H\sum_{w=1}^W { \left(z_{i}^{(h, w, c)} -\mu_\mathcal{B}^{(c)}\right )^2 } \\ \hat{z}_i^{(h, w, c)} &= \frac{z_i^{(h, w, c)} - \mu_\mathcal{B}^{(c)}}{\sigma_\mathcal{B}^{(c)}} \end{align}

这里,z^i(h,w,c)\hat{z}_i^{(h, w, c)} 即为标准化后的特征图,为 BN 层的输出。 为了扩展激活值的动态范围,BN 算法还对标准化后的特征图进行一个仿射变换,每个通道分别进行。

yi(h,w,c)=γ(c)z^i(h,w,c)+β(c)y_i^{(h, w, c)} = \gamma^{(c)} \cdot \hat{z}_i^{(h, w, c)} + \beta ^{(c)}

这里 γ(c)\gamma^{(c)}β(c)\beta ^{(c)} 为可学习的参数,也是 BN 唯一引入的可学习参数,共 2c2c 个标量值。

我们可以把 γ(c)\gamma^{(c)}β(c)\beta ^{(c)} 的引入看作是对权重的“长度”和“方向”的分离。

BN 的后传

在后传计算中,我们需要根据 L/yi(h,w,c){\partial L} / {\partial y_i^{(h, w, c)}} 计算出 L/γ(c){\partial L} / {\partial \gamma^{(c)}}L/β(c){\partial L} / {\partial \beta ^{(c)}}L/zi(h,w,c){\partial L} / {\partial z_i^{(h, w, c)}} 三组梯度值。 三组梯度的计算均可通过常规的复合函数求导规则导出。 但需要注意的是,均值和方差并不是 BN 引入的参数,而是计算的中间结果。因此,在计算梯度 L/zi(h,w,c){\partial L} / {\partial z_i^{(h, w, c)}} 时,必须考虑 μB(c)/zi(h,w,c){\partial \mu_\mathcal{B}^{(c)}} / {\partial z_i^{(h, w, c)}}σB2(c)/zi(h,w,c){\partial \sigma_\mathcal{B}^{2(c)}} / {\partial z_i^{(h, w, c)}}

在不使用 BN 时,一个 batch 中不同样本的前传计算时独立的。相应的,后传计算也是独立的。

L=1mi=1mLiLW=1mi=1mLiW\begin{align} L &= \frac{1}{m}\sum_{i=1}^m L_i \\ \frac{\partial L}{\partial W} &= \frac{1}{m}\sum_{i=1}^m \frac{\partial L_i}{\partial W} \end{align}

引入 BN 后,由于梯度需要经过统计量后传,因而不同样本之前的前后传会互相影响,后传计算复杂了很多。

BN 在推理阶段的计算

与 DropOut 类似,BN 在训练和推理阶段的行为是不同的。 在训练阶段,我们使用 batch 中的样本计算统计量,并对卷积层的输出进行标准化。 在推理阶段,我们使用全局统计量对样本进行标准化操作。在实践中,这个统计量是通过滑动平均的方式在训练过程中逐渐积累出来的。

由于 BN 在测试阶段仅仅是一个线性操作,为了提高计算效率,训练好的统计量、beta 和 gamma 通常会融合到相邻的线性计算中(如卷积层),减少计算的次数。

sync BN 与 frozen BN

我们知道,均值和方差的估计的准确性依赖样本的数量,即 batch 的大小。而小的 batch 对于 BN 来说是灾难性的,网络将很容易训练发散。

在训练图像分类网络时,使用较大 batch(如 64 以上)是相对容易的。 然而在使用迁移学习技术,将预训练模型应用于其他视觉任务并进行 finetune 训练时,可能会受到显存的限制,不能将 batchsize 设置过大。

针对这个问题,在工程上通常有两种方法。一种是成为 frozen BN 的策略,即在训练过程中,使用全局同计量作为 BN 的统计量,改统计量也不更新。这时,BN 与一个线性变换层相当。 另一种方法成为 sync BN,是将不同 GPU 设备上计算的统计量进行跨设备通信,再进行求平均运算。

BN 的作用

在前传阶段,BN 可以控制每一层输出的方差,在后传阶段,BN 也可以将梯度的方差控制在一个范围内。因而有助于降低梯度爆炸和消失现象。在使用 BN 后,我们可以使用更大的学习率,更快并更稳定地训练深层神经网络。

Last updated