Understanding the Failure of Batch Normalization for
Transformers in NLP
Jiaxi Wang¹, Ji Wu¹,², Lei Huang³
¹Department of Electronic Engineering, Tsinghua University
²Institute for Precision Medicine, Tsinghua University
{wjx20@mails, wuji_ee@mail}.tsinghua.edu.cn
³SKLSDE, Institute of Artificial Intelligence, Beihang University
huangleiAI@buaa.edu.cn
Abstract
Batch Normalization (BN) is a core and prevalent technique in accelerating the
training of deep neural networks and improving the generalization on Computer
Vision (CV) tasks. However, it fails to defend its position in Natural Language
Processing (NLP), which is dominated by Layer Normalization (LN). In this paper,
we try to answer why BN usually performs worse than LN in NLP tasks
with Transformer models. We find that the inconsistency between training and
inference of BN is the leading cause of the failure of BN in NLP.
We define Training Inference Discrepancy (TID) to quantitatively measure this
inconsistency and reveal that TID can indicate BN’s performance, supported by
extensive experiments, including image classification, neural machine translation,
language modeling, sequence labeling, and text classification tasks. We find that
BN can obtain much better test performance than LN when TID remains small throughout
training. To suppress the explosion of TID, we propose Regularized BN (RBN) that
adds a simple regularization term to narrow the gap between batch statistics and
population statistics of BN. RBN improves the performance of BN consistently and
outperforms or is on par with LN on 17 out of 20 settings, involving ten datasets
and two common variants of Transformer¹.
1 Introduction
Deep learning [21] has revolutionized Computer Vision (CV) [20] and Natural Language Processing (NLP) [41]. Normalization layers are key components to stabilize and accelerate the training of Deep Neural Networks (DNNs). In CV, Batch Normalization (BN) [17] is the default normalization technique and reveals superior performance over other normalization techniques in image recognition tasks by enforcing the input of a neuron to have zero mean and unit variance within a mini-batch of data. Furthermore, a growing number of theoretical works analyze the excellent properties of BN in benefiting optimization [17, 36, 4, 13, 8, 9]. While BN almost dominates in CV with empirical success and theoretical properties, Layer Normalization (LN) is the leading normalization technique in NLP, especially for Transformer models that achieve state-of-the-art performance on extensive tasks, including machine translation [41], natural language understanding [10], text generation [34], and few-shot learning [5], to name a few. As a direct substitute for LN, BN performs poorly in Transformers for neural machine translation [38]. It remains elusive to explain the failure of BN in the NLP community. In this work, we take a step forward. Our contributions are summarized as follows:
¹Our code is available at https://github.com/wjxts/RegularizedBN
36th Conference on Neural Information Processing Systems (NeurIPS 2022).
arXiv:2210.05153v1 [cs.CL] 11 Oct 2022
• We find that the inconsistency between training and inference leads to the failure of BN in NLP, supported by our extensive experiments on image classification, neural machine translation, language modeling, sequence labeling, and text classification tasks.
• We define Training Inference Discrepancy (TID) to quantitatively measure this inconsistency and show that TID can serve as an indicator of BN's performance. In particular, BN reaches much better test performance than LN when TID remains small throughout training, e.g., in image recognition and language modeling tasks.
• We propose Regularized BN (RBN), which adds a regularization term in BN to penalize and reduce the TID when the TID of BN is large. We reveal the optimization advantages of RBN over LN by exploring the layer-wise training dynamics of Transformer.
• We empirically show that RBN can exceed or match the performance of LN, sometimes by a large margin, on 17 out of 20 settings, involving ten datasets and two common variants of Transformer. Besides, RBN introduces no extra computation at inference compared to LN.
2 Related Work
Analyses of BN's Success
As BN has become an indispensable component of deep neural networks deployed for CV tasks, a number of works explore the theoretical reasons behind its success. From the view of optimization, the original BN paper [17] argues that BN can reduce internal covariate shift and thus stabilize training, while Santurkar et al. [36] contend that BN smooths the loss landscape and thus enables training of neural networks with larger learning rates [4]. Daneshmand et al. [8, 9] prove that a stack of randomized linear layers and BN layers endows the intermediate features of a neural network with sufficient numerical rank as depth increases, which is beneficial for optimization and for learning discriminative hierarchical features. Huang et al. [13] show that BN can improve the layer-wise conditioning of neural network optimization by exploring the spectrum of the Hessian matrix with a block-diagonal approximation [28]. From the view of generalization, Ioffe and Szegedy [17], Luo et al. [25], Li et al. [22], and Wu and Johnson [43] argue that BN serves as a regularizer that reduces over-fitting when its stochasticity is small and may have a detrimental effect when it is large [43]. Huang et al. [12] further propose Stochastic Normalization Disturbance (SND) to measure such stochasticity and show that large SND hinders the training of neural networks.
Training Inference Inconsistency of BN
Normalizing along the batch dimension usually introduces training inference inconsistency, since mini-batch data is neither necessary nor desirable during inference. BN instead uses population statistics, estimated by a running average over mini-batch statistics, for inference. This training inference inconsistency usually harms the performance of BN for small-batch-size training, since the estimation of the population statistics can be inaccurate [42]. One way to reduce the inconsistency between training and inference is to exploit the estimated population statistics for normalization during training [16, 6, 47, 50, 49]. These methods may outperform BN when the batch size is small, where inaccurate estimation is the main issue [17, 18], but they usually perform worse than BN under moderate-batch-size training [24]. Another way to reduce the inconsistency is to estimate corrected normalization statistics during inference only, either for domain adaptation [23], corruption robustness [37, 31, 2], or small-batch-size training [39, 40]. We note that a recent work [14] investigates the estimation shift problem of BN. Unlike that work, which addresses the accumulated estimation shift caused by stacked BNs in CNNs on CV tasks, our work focuses on how the training inference inconsistency of BN correlates with its performance for Transformers on NLP tasks. Besides, the estimation shift of BN defined in [14], which concerns the difference between the estimated population statistics and the expected statistics, differs from our TID of BN, which concerns the difference between the mini-batch statistics and the population statistics.
Exploring the Failure of BN in Transformer
Similar to our work, Power Normalization (PN) [38] also investigates the reason behind the failure of BN in Transformers. Our work differs significantly from PN [38] in the following facets. PN attributes the failure of BN to unstable training caused by fluctuating forward and backward batch statistics with outlier values, while we observe that the training of BN is as good as that of LN and that the inconsistency between training and inference of BN matters more. Based on this observation, we propose a regularization term to reduce the TID of BN. Compared with PN, which incorporates a layer-scale layer (root mean square layer normalization [51] without the affine transformation [45]), our method introduces no extra computation at inference. Besides, we use a more reasonable index to measure the inconsistency, one that is invariant to the scale of the data. Furthermore, we show that our RBN can improve the layer-wise training dynamics relative to LN, which reveals the optimization advantages of RBN.
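The exact form of RBN's regularization term is given later in the paper; purely to illustrate the idea stated above (penalizing the gap between batch statistics and population statistics), a hypothetical auxiliary penalty could look like the following PyTorch sketch. The function name, the squared-error form, the weight `lam`, and the choice to treat the running estimates as fixed targets are assumptions of this sketch, not the paper's formulation.

```python
import torch

def batch_population_gap_penalty(x, running_mean, running_var, lam=0.1):
    """Hypothetical auxiliary loss that penalizes the gap between a BN layer's
    mini-batch statistics and its (detached) population estimates.
    Illustrates the general idea behind RBN, not the paper's exact regularizer.

    x: (m, d) pre-normalization activations of one BN layer.
    running_mean, running_var: (d,) population estimates of that layer.
    """
    mu_b = x.mean(dim=0)
    var_b = x.var(dim=0, unbiased=False)
    # The population estimates act as constant targets; gradients flow only
    # through the mini-batch statistics.
    gap = (mu_b - running_mean.detach()).pow(2).mean() \
        + (var_b - running_var.detach()).pow(2).mean()
    return lam * gap
```

Such a term would be added to the task loss during training only, which is consistent with the claim that RBN introduces no extra computation at inference.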
Figure 1: Training loss and validation loss/BLEU of the Transformer trained on IWSLT14 with BN and LN. The training of Transformer_BN is better than that of Transformer_LN, while the validation loss/BLEU of Transformer_BN underperforms that of Transformer_LN after about epoch 8. At the end of training, Transformer_BN trails Transformer_LN by a large BLEU margin. Lower loss and higher BLEU scores indicate better performance. Based on this inconsistency between the training and validation performance of BN, we hypothesize that the training inference discrepancy of BN causes its performance degradation.
3 Analyses of Training Inference Inconsistency in Transformer_BN
3.1 Preliminary
Batch Normalization (BN) [17] is typically used to stabilize and accelerate DNN training. Let $x \in \mathbb{R}^d$ denote the $d$-dimensional input to a neural network layer. During training, BN standardizes each neuron/channel within the $m$ mini-batch data by²

$$\hat{x}_j = \mathrm{BN}_{\mathrm{train}}(x_j) = \frac{x_j - \mu_{B,j}}{\sqrt{\sigma^2_{B,j}}}, \quad j = 1, 2, \ldots, d, \qquad (1)$$

where $\mu_{B,j} = \frac{1}{m}\sum_{i=1}^{m} x_j^{(i)}$ and $\sigma^2_{B,j} = \frac{1}{m}\sum_{i=1}^{m} \big(x_j^{(i)} - \mu_{B,j}\big)^2$ are the mini-batch mean and variance for each neuron, respectively. Note that a small constant $\epsilon$ is usually added to the variance in practice to prevent numerical instability. During inference, the population mean $\mu$ and variance $\sigma^2$ of the layer input are required for BN to make a deterministic prediction [17]:

$$\hat{x}_j = \mathrm{BN}_{\mathrm{inf}}(x_j) = \frac{x_j - \mu_j}{\sqrt{\sigma^2_j}}, \quad j = 1, 2, \ldots, d. \qquad (2)$$

These population statistics $\{\mu, \sigma^2\}$ are usually calculated as a running average of the mini-batch statistics over training iterations $t$ with an update factor $\alpha$:

$$\mu^{(t)} = (1-\alpha)\,\mu^{(t-1)} + \alpha\,\mu_B^{(t)}, \qquad (\sigma^2)^{(t)} = (1-\alpha)\,(\sigma^2)^{(t-1)} + \alpha\,(\sigma^2_B)^{(t)}. \qquad (3)$$
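To make Eqns. 1–3 concrete, below is a minimal PyTorch sketch of the two normalization modes and the running-average update. The helper names, batch size, feature dimension, and momentum $\alpha = 0.1$ are illustrative choices rather than settings from the paper, and the learnable scale/shift parameters are omitted (see footnote 2).

```python
import torch

def bn_train_step(x, running_mean, running_var, alpha=0.1, eps=1e-5):
    """Training-mode BN (Eqn. 1) plus the running-average update (Eqn. 3).
    x: (m, d) mini-batch; running_mean, running_var: (d,) population estimates.
    The learnable scale/shift of standard BN are omitted (footnote 2)."""
    mu_b = x.mean(dim=0)                    # mini-batch mean per neuron
    var_b = x.var(dim=0, unbiased=False)    # mini-batch variance per neuron
    x_hat = (x - mu_b) / torch.sqrt(var_b + eps)
    # Eqn. 3: exponential running average of the mini-batch statistics
    running_mean = (1 - alpha) * running_mean + alpha * mu_b
    running_var = (1 - alpha) * running_var + alpha * var_b
    return x_hat, running_mean, running_var

def bn_inference(x, running_mean, running_var, eps=1e-5):
    """Inference-mode BN (Eqn. 2): normalize with the fixed population estimates."""
    return (x - running_mean) / torch.sqrt(running_var + eps)

# Toy usage: the same input is normalized differently in the two modes whenever
# the mini-batch statistics deviate from the population estimates.
torch.manual_seed(0)
m, d = 32, 8
mean, var = torch.zeros(d), torch.ones(d)
for _ in range(100):
    batch = 2.0 * torch.randn(m, d) + 1.0
    _, mean, var = bn_train_step(batch, mean, var)
test_batch = 2.0 * torch.randn(m, d) + 1.0
out = bn_inference(test_batch, mean, var)
```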
The discrepancy between BN's normalization during training (using Eqn. 1) and inference (using Eqn. 2) produces stochasticity, since the population statistics of BN are estimated from mini-batch statistics that depend on the sampled mini-batch inputs. This discrepancy is believed to benefit generalization [17, 12] if the stochasticity is well controlled. However, it usually harms the performance of small-batch-size training [42], since the estimation of the population statistics can be inaccurate. To address this problem, a number of batch-free normalizations have been proposed that use consistent operations during training and inference, e.g., Layer Normalization (LN) [1].
Basic Observations
To analyze the failure of BN in NLP tasks, we first plot the training loss and validation loss/BLEU [33] of BN and LN on the IWSLT14 (De-En) dataset with the original Transformer model (see Figure 1). We observe that the training of Transformer_BN is faster than that of Transformer_LN: the training nll_loss of BN is even smaller than that of LN, especially at the beginning. However, the validation loss/BLEU of BN becomes worse than that of LN after around the seventh epoch. This phenomenon cannot be attributed to over-fitting, since BN introduces more stochasticity than LN in the training phase. The inconsistency between training and inference of BN may play a role.
²BN usually uses extra learnable scale and shift parameters [17] to recover the potentially reduced representation capacity; we omit them since they are not relevant to our discussion.
Figure 2: Top: the average deviation of the batch mean $\mu_B$ (left) and batch variance $\sigma^2_B$ (right) from the population mean $\mu$ and population variance $\sigma^2$ over all BN layers through training in ResNet18 and Transformer_BN. There are 21 BN layers in ResNet18 and 12 BN layers in the encoder of Transformer_BN. At the end of training, ResNet18 has a mean/variance deviation of around 4%/4%, while those of Transformer_BN are around 11%/13%. A large deviation of statistics hurts the performance of Transformer_BN. Bottom: variance deviation of BN layers at different depths at the end of training (left), and variance deviation over depth and training progress (right).
Since BN in ResNet18 also involves training inference inconsistency, we conjecture that the degree of this inconsistency differs between ResNet18 and Transformer_BN. We therefore plot the deviation of batch statistics from the population statistics of BN in ResNet18 and Transformer_BN in Figure 2 (top) for comparison. ResNet18 is trained on CIFAR-10 [19], and its accuracy drops by 2% if we replace BN with LN. We find that at the end of training, Transformer_BN has a much larger mean and variance deviation than ResNet18. Besides, the last several BN layers of Transformer_BN, which are close to the output, have a large variance deviation (Figure 2, bottom), which negatively impacts the model output. Furthermore, the performance degradation of Transformer_BN coincides with the increase of the variance deviation, as can be seen by comparing Figure 1 (right) and Figure 2 (bottom right). Based on these observations, we hypothesize that the inconsistency between training and inference of BN causes its performance degradation in neural machine translation. We first mathematically define the training inference discrepancy of BN in the next subsection.
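As a companion to Figure 2, the sketch below shows one plausible way to measure the per-layer relative deviation of a BN layer's batch statistics from its running population estimates. The norm-ratio form and the function names are assumptions made for illustration; the paper's formal TID measure is defined in Section 3.2.

```python
import torch

def relative_deviation(batch_stat, pop_stat, eps=1e-8):
    """Relative deviation of a mini-batch statistic from its population estimate,
    aggregated over the d neurons as a norm ratio. The exact form is an
    illustrative assumption, not necessarily the metric behind Figure 2."""
    return (torch.norm(batch_stat - pop_stat) / (torch.norm(pop_stat) + eps)).item()

def bn_layer_deviation(x, running_mean, running_var):
    """Deviation of one BN layer's batch mean/variance from its running estimates,
    for a mini-batch x of shape (m, d)."""
    mu_b = x.mean(dim=0)
    var_b = x.var(dim=0, unbiased=False)
    return {
        "mean_deviation": relative_deviation(mu_b, running_mean),
        "var_deviation": relative_deviation(var_b, running_var),
    }

# Example: a batch drawn slightly "hotter" than the running estimates.
torch.manual_seed(0)
x = 0.5 + 1.1 * torch.randn(64, 512)
print(bn_layer_deviation(x, running_mean=0.5 * torch.ones(512),
                         running_var=torch.ones(512)))
```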
3.2 Training Inference Discrepancy
By observing Eqns. 1 and 2, the normalized output during training can be written as:

$$\frac{x_j - \mu_{B,j}}{\sigma_{B,j}} = \left( \frac{x_j - \mu_j}{\sigma_j} + \frac{\mu_j - \mu_{B,j}}{\sigma_j} \right) \cdot \frac{\sigma_j}{\sigma_{B,j}}, \quad j = 1, 2, \ldots, d, \qquad (4)$$

where $\sigma_{B,j} > 0$ and $\sigma_j > 0$ are the mini-batch and population standard deviations for the $j$-th dimension, respectively. (The identity follows by adding and subtracting $\mu_j$ in the numerator and then multiplying and dividing by $\sigma_j$.) The terms $\frac{\mu_j - \mu_{B,j}}{\sigma_j}$ and $\frac{\sigma_j}{\sigma_{B,j}}$ can be viewed as random variables. Their magnitude characterizes the diversity of mini-batch examples during training and indicates how hard the estimation of the population statistics is. We thus define the training inference discrepancy to quantitatively measure this inconsistency as follows.