Improving Sharpness-Aware Minimization with Fisher Mask
for Better Generalization on Language Models
Qihuang Zhong1, Liang Ding2, Li Shen2, Peng Mi3, Juhua Liu4†, Bo Du1†, Dacheng Tao2
1National Engineering Research Center for Multimedia Software, Institute of Artificial Intelligence, School of Computer Science
and Hubei Key Laboratory of Multimedia and Network Communication Engineering, Wuhan University, China
2JD Explore Academy, China 3School of Informatics, Xiamen University, China
4Research Center for Graphic Communication, Printing and Packaging, and Institute of Artificial Intelligence, Wuhan University, China
{zhongqihuang,liujuhua,dubo}@whu.edu.cn,{dingliang1,shenli100}@jd.com,mipeng@stu.xmu.edu.cn,dacheng.tao@gmail.com
Abstract
Fine-tuning large pretrained language models
on a limited training corpus usually suffers
from poor generalization. Prior works show
that the recently-proposed sharpness-aware
minimization (SAM) optimization method can
improve the model generalization. However,
SAM adds a perturbation to each model pa-
rameter equally (but not all parameters con-
tribute equally to the optimization of training),
which we argue is sub-optimal and will lead to
excessive computation. In this paper, we pro-
pose a novel optimization procedure, namely
FSAM¹, which introduces a Fisher mask to improve the efficiency and performance of SAM.
In short, instead of adding perturbation to all parameters, FSAM uses the Fisher information to identify the important parameters and formulates a Fisher mask to obtain the sparse perturbation, i.e., making the optimizer focus on these important parameters. Experiments
on various tasks in GLUE and SuperGLUE
benchmarks show that FSAM consistently outperforms the vanilla SAM by 0.67∼1.98 average score among four different pretrained models. We also empirically show that FSAM
works well in other complex scenarios, e.g.,
fine-tuning on generation tasks or limited train-
ing data. Encouragingly, when training data is
limited, FSAM improves the SAM by a large
margin, i.e., up to 15.1.
1 Introduction
The “pretraining-finetuning” paradigm has become
the de facto standard for the community of natural
language processing (NLP) (Devlin et al.,2019;
Liu et al.,2019;Clark et al.,2019b;Raffel et al.,
2020;Brown et al.,2020;Lewis et al.,2020). Given
a pretrained language model (PLM), the dominant
Work was done when Qihuang was interning at JD Explore Academy.
†Corresponding Authors: Juhua Liu (e-mail: liujuhua@whu.edu.cn), Bo Du (e-mail: dubo@whu.edu.cn).
¹ https://github.com/WHU-ZQH/FSAM4PLM
fine-tuning manner is tuning the entire pretrained
parameters for each downstream task (Radford
et al.,2018;Devlin et al.,2019). While fine-tuning
the entire PLM can improve performance on a wide
range of NLP tasks, it usually suffers from over-
fitting and poorer generalization ability (Xu et al.,
2021;Bahri et al.,2022), especially in the large-
scale PLMs and limited training data scenarios.
Hence, some existing efforts attempt to provide
more regularization in the fine-tuning stage (Zhang
et al.,2018;Müller et al.,2019;Xu et al.,2021),
among which the optimization of the training loss
is an intuitive and effective method. Specifically,
motivated by the finding (Keskar et al., 2016; Neyshabur et al., 2017) that a smoother loss landscape corresponds to better model generalization, Foret et al. (2020) propose “sharpness-aware minimization” (SAM) to simultaneously
minimize loss value and loss sharpness, where the
sharpness can be quantified as the maximized dif-
ference of loss when a perturbation is added to the
current weights. In practice, SAM performs two
forward-backward computations for each optimiza-
tion step, where the first forward-backward is to
obtain the perturbation for each model parameter
and the second one is to update the parameters.
Many prior works (Wu et al., 2020; Zheng et al., 2021) show the effectiveness of SAM in the vision domain; motivated by this, Bahri et al. (2022) recently became the first to apply SAM to the language domain.
Although Bahri et al. (2022) empirically show
the remarkable performance of SAM on several lan-
guage understanding tasks, SAM calculates pertur-
bations indiscriminately for all parameters, which
is time-consuming and hinders the application of
SAM. Furthermore, inspired by the finding (Keskar et al., 2016) that only about 5% of parameters are sharp and rise steeply during optimization, we notice that not all parameters contribute equally to the optimization of training. Hence, this raises the question of whether we can calculate perturbations for only some individual parameters, and thus make the optimizer focus on these important parameters.
To this end, we propose a novel optimization
approach, Fisher SAM (FSAM), which introduces
a Fisher mask to improve the efficiency and effec-
tiveness of SAM. In short, FSAM first uses the Fisher information (Fisher, 1922) as the metric to identify the sharper parameters² and formulates a binary Fisher mask correspondingly. Then, the Fisher mask is multiplied with the perturbations to obtain sparse perturbations, which are lastly used to perform regularization in the parameter update. In this way, only part of the sharper parameters receive perturbations, and the optimizer can thus focus more on these important parameters. Also, the sparse perturbations could enable training acceleration via sparse back-propagation³. Moreover, one may be concerned that the sparse Fisher mask could affect the convergence rate of FSAM (Lin et al., 2019). Hence, we theoretically provide a convergence analysis of FSAM, showing that its convergence is not affected by the Fisher mask.
We conduct a large-scale and systematic study
to evaluate the performance and effectiveness of
FSAM. Firstly, we apply SAM and FSAM to fine-
tune various PLMs on parts of GLUE and Super-
GLUE benchmarks, where the results show that
FSAM consistently outperforms the vanilla SAM by 0.67∼1.98 average score among these PLMs, and surpasses the Adam (Kingma and Ba, 2015) optimizer by 1.41∼1.91 points. Secondly, we
conduct experiments on two popular generation
tasks (i.e., XSUM and CoNLL2014) and prove that
FSAM can deliver promising results against SAM.
Lastly, quantitative analysis and in-depth discus-
sion demonstrate the universality and effectiveness
of FSAM in various complex scenarios, and prove
that FSAM indeed brings better model generaliza-
tion. Specifically, we show that our Fisher mask
strategy not only works well in the SAM, but also
can be applied to other SAM variants.
To summarize, our contributions are two-fold:
² We refer to these parameters as the important ones, because they rise steeply during optimization and significantly affect model generalization.
³ Since fine-grained sparse training is limited by current hardware, we do not achieve an actual sparse speedup in this work. Despite this, we still believe that FSAM has great potential to achieve true training acceleration in the future, with the development of hardware for fine-grained sparse operations.
(1) We propose a novel optimization approach
(namely FSAM) with theoretical convergence guar-
antee for PLMs. Specifically, FSAM improves the
performance and efficiency of recently-proposed
SAM via a Fisher mask strategy, which can also be
applied to more SAM variants. (2) Extensive exper-
iments show that FSAM consistently outperforms
the SAM by a large margin on both language un-
derstanding and generation tasks. The systematic
study demonstrates the effectiveness and universal-
ity of FSAM on improving model generalization.
2 Related Work
SAM and its variants.
Hochreiter and Schmidhuber (1994) first show the strong correlation between flat minima and the generalization of a model; inspired by this, Foret et al. (2020) propose SAM to find a flat minimum and thus
improve model generalization. While many ex-
isting works prove the effectiveness of SAM on
various computer vision tasks (Wu et al.,2020;
Chen et al.,2021;Zheng et al.,2021), the double
forward-propagation process of SAM brings more
computational cost. To this end, Du et al. (2021)
propose an Efficient SAM (ESAM) for reducing the
computational cost of SAM. Additionally, there are
also some efforts that focus on more efficient and
effective SAM optimization (Zhuang et al.,2021;
Kwon et al.,2021;Mi et al.,2022).
Improving Generalization.
Recently, we have
witnessed numerous PLMs that achieved tremen-
dous success in the community of NLP (Yang et al.,
2019;Devlin et al.,2019;Brown et al.,2020;Lewis
et al.,2020;Raffel et al.,2020;Joshi et al.,2020;
He et al.,2020;Qi et al.,2021;Zhong et al.,2022).
The current dominant fine-tuning approach needs to
tune all pretrained parameters for each downstream
task, which makes the PLM easily memorize the
training data and thus leads to overfitting. To tackle
this issue, some works attempt to provide implicit
and explicit regularization into the training of mod-
els, such as dropout (Srivastava et al.,2014), label
smoothing (Müller et al.,2019), mixup (Zhang
et al.,2018) and other data-augmentation meth-
ods (Sennrich et al.,2016;Wang et al.,2018b;
Zhong et al.,2021;Wang et al.,2022;Ding et al.,
2022). On the other hand, motivated by the suc-
cessful applications of SAM in the vision domain,
Bahri et al. (2022) apply SAM to optimize the T5 (Raffel et al., 2020) model on multiple
language tasks and show that SAM can improve
the generalization of PLMs effectively.
Our work departs from the prior work (Bahri et al., 2022) as follows: 1) different motivations: instead of verifying the effect of vanilla SAM on several language understanding tasks, we aim to improve the efficiency and effectiveness of SAM; 2) different contributions: our main contribution is a Fisher mask strategy, which can be applied to both SAM and its variants; 3) more analysis: we provide more experimental results and analysis of the effectiveness of our method in more complex scenarios.
3 Methodology
In this section, we first review the Sharpness-Aware
Minimization, and then propose our Sharpness-
Aware Minimization with Fisher mask, coined as
FSAM. Finally, we theoretically analyze the con-
vergence of FSAM with adaptive learning rate.
3.1 Sharpness-Aware Minimization
Preliminary.
In this paper, we denote the weights of a neural network as w ∈ R^d. Suppose the training dataset S = {(x_i, y_i)}_{i=1}^{n} is drawn i.i.d. from the distribution D. The objective function on data x_i from S is denoted as f_S(x_i). Since Adam (Kingma and Ba, 2015) and its variants are widely used in NLP tasks, the learning rate is estimated in RMSProp/Adam style.
Sharpness-Aware Minimization.
Foret et al. (2020) propose Sharpness-Aware Minimization (SAM) to improve generalization, which is achieved by the following min-max problem:

    \min_{w} \max_{\|\epsilon\|_2 \le \rho} f(w + \epsilon),    (1)

where ρ is a predefined value that controls the neighborhood size, and ε is the perturbation vector added to the model weights. The optimization encourages the model loss not to rise significantly under any weight change of magnitude at most ρ, which is intuitively consistent with the generalization capacity of the model.
With the Taylor expansion, the perturbation vector can be obtained approximately:

    \epsilon^{*} = \arg\max_{\|\epsilon\|_2 \le \rho} f_S(w + \epsilon)    (2)
               \approx \arg\max_{\|\epsilon\|_2 \le \rho} f_S(w) + \epsilon \cdot \nabla_w f(w)    (3)
               = \rho \cdot \nabla_w f(w) / \|\nabla_w f(w)\|_2,    (4)

and the objective function can be simplified as

    \min_{w} f\left(w + \rho \frac{\nabla_w f(w)}{\|\nabla_w f(w)\|_2}\right).    (5)

The solution of the above objective is obtained by a two-step gradient descent. In the first gradient descent step, the perturbation vector ε is calculated by Equation 2. The second gradient descent step is the actual weight update.
However, despite the improvements of SAM on many tasks, SAM requires two gradient computations per step, which doubles the overhead compared to conventional optimizers, e.g., Stochastic Gradient Descent (SGD) and Adam.
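To make the two-step procedure above concrete, the following PyTorch-style sketch performs one SAM update around a generic base optimizer. It is a minimal illustration of Eq. 2 and Eq. 5 under our own naming (e.g., `sam_step`), not the authors' released implementation, and it assumes a standard supervised loss.

    import torch

    def sam_step(model, loss_fn, batch, base_optimizer, rho=0.05):
        """One SAM update: take the gradient at w, move to w + epsilon (Eq. 2),
        then update w with the gradient computed at the perturbed point (Eq. 5)."""
        inputs, targets = batch
        base_optimizer.zero_grad()

        # First forward-backward: gradient at the current weights w.
        loss = loss_fn(model(inputs), targets)
        loss.backward()

        params = [p for p in model.parameters() if p.grad is not None]
        grads = [p.grad.detach().clone() for p in params]
        grad_norm = torch.norm(torch.stack([g.norm(p=2) for g in grads]), p=2) + 1e-12

        # epsilon = rho * grad / ||grad||_2, added in place to the weights.
        eps_list = []
        with torch.no_grad():
            for p, g in zip(params, grads):
                eps = rho * g / grad_norm
                p.add_(eps)
                eps_list.append(eps)

        # Second forward-backward: gradient at w + epsilon.
        base_optimizer.zero_grad()
        loss_fn(model(inputs), targets).backward()

        # Restore the original weights, then let the base optimizer take its step.
        with torch.no_grad():
            for p, eps in zip(params, eps_list):
                p.sub_(eps)
        base_optimizer.step()
        base_optimizer.zero_grad()
        return loss.item()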
3.2 Sharpness-Aware Minimization with Fisher Mask
In this subsection, we present the proposed Sharpness-Aware Minimization with Fisher Mask (FSAM) in detail, which reduces the computation of SAM via sparse calculation.
To be specific, we compute only a fraction of the elements in the perturbation vector ε by multiplying it with a sparse binary mask m ∈ {0,1}^d. To control the amount of perturbation, the sparse mask m satisfies 1^T m = (1 − s) · d, where s is a predefined sparsity ratio, empirically set to 0.9. The objective function of FSAM is

    \min_{w} f_S\left(w + \rho \frac{\nabla_w f(w) \odot m}{\|\nabla_w f(w)\|_2}\right),    (6)

where ⊙ is the Hadamard product, i.e., element-wise multiplication. For the stability of optimization, we update the mask m at a fixed interval (denoted as T_m in Algorithm 1) during training. The overall procedure of FSAM is shown in Algorithm 1.
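As a minimal sketch of the sparse perturbation in Eq. 6 (flattened tensors and illustrative names of our own; the construction of m from the Fisher information is described next):

    import torch

    def masked_perturbation(grad: torch.Tensor, mask: torch.Tensor, rho: float) -> torch.Tensor:
        """Eq. 6: epsilon = rho * (grad ⊙ m) / ||grad||_2, where mask has (1 - s) * d ones."""
        return rho * (grad * mask) / (grad.norm(p=2) + 1e-12)

    # Example with a random mask at sparsity s = 0.9, i.e., 10% of coordinates are perturbed.
    d, s = 1000, 0.9
    grad = torch.randn(d)
    mask = torch.zeros(d)
    mask[torch.randperm(d)[: int((1 - s) * d)]] = 1.0
    eps = masked_perturbation(grad, mask, rho=1e-2)
    assert int(eps.ne(0).sum()) <= int((1 - s) * d)   # at most (1 - s) * d non-zero entries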
To find the optimal mask during training, we apply the Fisher information to achieve sparse perturbation. The Fisher information was proposed by Fisher (1922) to measure the information that an observable random variable carries about the unknown parameters of a distribution. It is defined as

    F = \mathbb{E}_{x}\,\mathbb{E}_{y}\left[\nabla \log p(y|x)\, \nabla \log p(y|x)^{T}\right],    (7)

where p(y|x) is the output of the model in machine learning. However, due to the over-parameterization of deep models, computing the full Fisher information is unacceptable, i.e., F ∈ R^{|w|×|w|}.
Algorithm 1 Fisher SAM (FSAM)
Input: sparse ratio s, dense model w, binary mask m, update interval T_m, base learning rate γ, \hat{v}_{-1} = δ^2, training set S.
1:  Initialize w and m randomly.
2:  for epoch t = 1, 2, ..., T do
3:    for each training iteration do
4:      Sample a batch B from S
5:      Compute perturbation ε by Eq. 2
6:      if t mod T_m = 0 then
7:        Sample N_Fisher data from S
8:        Compute Empirical Fisher \hat{F} by Equation 9
9:        m_1 ← ArgTopK(\hat{F}, (1 − s) · |w|)
10:       m_0 ← ArgTopK(−\hat{F}, s · |w|)
11:       Update mask m by merging: m = m_0 ∪ m_1
12:     end if
13:     ε ← ε ⊙ m
14:     Compute SAM gradient g_t = ∇f_B(w + ε)
15:     v_t = β_2 v_{t−1} + (1 − β_2)[g_t]^2
16:     \hat{v}_t = max(\hat{v}_{t−1}, v_t)
17:     w ← w − γ g_t / \sqrt{\hat{v}_t}
18:   end for
19: end for
20: return Final weight of model w
To save computational effort, we approximate the Fisher information by its diagonal, i.e., F ∈ R^{|w|}. Consider the two expectations in Equation 7. The first is over the data distribution x ∼ p(x), which is not available in most tasks; we approximate it by sampling N_Fisher examples from p(x):

    F = \frac{1}{N_{Fisher}} \sum_{i=1}^{N_{Fisher}} \mathbb{E}_{y}\left[\nabla \log p(y|x_i)\right]^{2}.    (8)

The second expectation is over p(y|x), which can be replaced by the label y_i of example x_i in supervised learning. Finally, we calculate the Fisher information as the “Empirical Fisher”:

    \hat{F} = \frac{1}{N_{Fisher}} \sum_{i=1}^{N_{Fisher}} \left[\nabla \log p(y_i|x_i)\right]^{2}.    (9)
Since the empirical Fisher has the same size as the weights, i.e., \hat{F} ∈ R^{|w|}, each element of \hat{F} represents the importance of the corresponding element of w. Thus, we sort the elements of \hat{F} in descending order, and the weights with the top (1 − s) · |w| Fisher values will be perturbed, i.e., the corresponding elements of the mask are set to 1:

    m_1 \leftarrow \mathrm{ArgTopK}\big(\hat{F},\ (1 - s) \cdot |w|\big),    (10)

where m_1 is the set of mask elements equal to 1, i.e., m_1 = \{m_i = 1 \mid m_i \in m\}, and ArgTopK(x, k) returns the indices of the top-k largest values of x. On the other hand, the remaining weights with small Fisher values will not be perturbed, i.e., the corresponding elements of the mask are set to 0:

    m_0 \leftarrow \mathrm{ArgTopK}\big(-\hat{F},\ s \cdot |w|\big).    (11)
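The sketch below illustrates Equations 9-11 in PyTorch: it accumulates squared gradients of the log-likelihood over N_Fisher examples as the diagonal empirical Fisher and keeps the top (1 − s)·|w| coordinates in the mask. It is our own simplified rendering (batch-level gradients, illustrative names such as `compute_fisher_mask`), not the released FSAM code.

    import torch
    import torch.nn.functional as F

    def compute_fisher_mask(model, fisher_loader, n_fisher: int, s: float):
        """Diagonal empirical Fisher (Eq. 9) and binary mask (Eqs. 10-11).
        Returns one {0,1} tensor per parameter with roughly (1 - s) * |w| ones in total."""
        params = [p for p in model.parameters() if p.requires_grad]
        fisher = [torch.zeros_like(p) for p in params]

        seen = 0
        for inputs, labels in fisher_loader:
            if seen >= n_fisher:
                break
            model.zero_grad()
            # Cross-entropy is the negative log-likelihood, so its gradient equals
            # -grad log p(y_i | x_i); squaring removes the sign (Eq. 9).
            F.cross_entropy(model(inputs), labels).backward()
            for f, p in zip(fisher, params):
                f += p.grad.detach() ** 2
            seen += inputs.size(0)

        flat = torch.cat([f.flatten() for f in fisher]) / max(seen, 1)
        k = max(1, int((1 - s) * flat.numel()))        # number of perturbed coordinates
        threshold = torch.topk(flat, k).values.min()   # k-th largest Fisher value
        model.zero_grad()
        # Coordinates at or above the threshold form m_1 (mask = 1); the rest form m_0 (mask = 0).
        return [(f / max(seen, 1) >= threshold).float() for f in fisher]

The returned masks can then be multiplied into the perturbation as in Eq. 6 and refreshed every T_m epochs, as in Algorithm 1.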
3.3 Theoretical Analysis
In this subsection, we theoretically analyze the con-
vergence and generalization of FSAM. Due to the
space limitation, we only show the convergence
analysis here, and the generalization analysis and
whole proof are presented in Appendix A.1.
Assumption 1. (L-smooth.) Consider f differentiable with the gradient Lipschitz property: there exists L > 0 such that

    \|\nabla f(w) - \nabla f(v)\| \le L \|w - v\|, \quad \forall w, v \in \mathbb{R}^{d}.

Assumption 2. (Bounded stochastic gradients.) The variance of the stochastic gradient is bounded:

    \mathbb{E}\big[\|\nabla f_i(x) - \nabla f(x)\|^{2}\big] \le \sigma^{2}.

Assumption 3. (Bounded gradient.) The stochastic gradient is bounded: there exists G ≥ 0 such that

    \|\nabla f_i(w)\| \le G.
Theorem 1. Consider the function f under Assumptions 1, 2 and 3, and a fixed base learning rate γ_t satisfying γ_t ≤ δ/(8L). Then we have

    \frac{1}{T}\sum_{t=0}^{T-1}\mathbb{E}\|\nabla f(x_t)\|^{2}
      \le \frac{2G\big(f(x_0)-f^{*}\big)}{\gamma_t T}
      + \frac{20GL^{2}\rho^{2}}{\delta}
      + \frac{2G^{3}}{T}\, d\left(\frac{1}{\delta}-\frac{1}{G}\right)
      + \frac{4\gamma_t L}{\delta}\left(\frac{2}{\delta}+\frac{4\gamma_t L}{\delta}\right)\sigma^{2}
      + \frac{4\gamma_t L G^{3}}{T}\, d\left(G^{2}-\delta^{2}\right).
Theorem 1 shows that, when T is large, FSAM achieves a linear-speedup convergence rate with respect to the mini-batch size b under the setting of γ_t = O(\sqrt{b/T}) and ρ = O(\sqrt{1/(bT)}), i.e.,

    \frac{1}{T}\sum_{t=0}^{T-1}\mathbb{E}\|\nabla f(x_t)\|^{2} = O\left(\sqrt{\frac{1}{bT}}\right).
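As a quick sanity check of this rate (our own back-of-the-envelope, using only the first two terms of the bound in Theorem 1), substituting γ_t = O(√(b/T)) and ρ = O(√(1/(bT))) gives

    \frac{2G\big(f(x_0)-f^{*}\big)}{\gamma_t T} = O\left(\frac{1}{\sqrt{b/T}\cdot T}\right) = O\left(\frac{1}{\sqrt{bT}}\right),
    \qquad
    \frac{20GL^{2}\rho^{2}}{\delta} = O\left(\frac{1}{bT}\right),

so the leading terms vanish at the stated O(√(1/(bT))) rate.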
4 Experimental Setup
4.1 Tasks and Datasets
To investigate the effectiveness and universality of
our FSAM method, we conduct extensive experi-
ments on various NLP tasks. Specifically, different from Bahri et al. (2022), who only verify the method on several language understanding tasks, we evaluate our method on both language understanding and generation tasks.
Table 1: Experimental results (dev scores) on various language understanding benchmarks: comparison between vanilla SAM and our proposed FSAM applied to four widely used large-scale PLMs. The best results for each setting are in bold. “AVG.” denotes the average score over all tasks. Results show that our FSAM brings consistent improvements across all understanding tasks among different PLMs.
Method        CoLA   MRPC        STS-B         RTE    CB     BoolQ  WSC    WiC    AVG.
              Mcc.   Acc.  F1.   Pear.  Spea.  Acc.   Acc.   Acc.   Acc.   Acc.
BERT-large
Adam          62.8   87.3  91.1  89.5   89.3   70.7   87.5   74.3   68.3   72.7   79.35
Adam+SAM      62.1   87.9  91.4  89.8   89.4   71.5   91.1   72.9   68.3   74.1   79.85
Adam+FSAM     63.4   89.0  92.0  90.4   89.9   74.4   94.6   75.3   68.5   74.4   81.19
ELECTRA-large
Adam          69.0   89.2  92.4  92.1   92.1   87.3   91.1   85.6   83.6   74.4   85.68
Adam+SAM      63.9   91.9  94.2  92.4   92.4   89.2   92.9   82.2   84.6   72.4   85.61
Adam+FSAM     69.6   92.4  94.5  92.3   92.5   88.8   96.4   85.9   89.4   74.1   87.59
ALBERT-xxlarge
Adam          71.1   90.7  93.3  92.9   92.7   87.0   89.3   86.8   85.6   75.5   86.49
Adam+SAM      69.9   90.7  93.2  92.6   92.4   88.1   91.1   87.7   82.7   76.6   86.50
Adam+FSAM     72.3   91.9  94.2  93.0   92.8   88.8   91.1   87.9   86.5   76.6   87.51
RoBERTa-large
Adam          66.7   90.4  93.1  92.1   92.0   87.0   92.8   86.0   78.1   73.3   85.15
Adam+SAM      68.5   90.7  93.3  91.5   91.3   87.7   96.4   84.2   81.3   74.0   85.89
Adam+FSAM     69.5   90.7  93.2  91.9   91.6   87.7   98.2   86.8   81.5   74.5   86.56
Language Understanding Tasks.
Following
many previous works (Vu et al.,2022;Bahri et al.,
2022;Zhong et al.,2022), we conduct experiments
on a combination of tasks from GLUE (Wang
et al.,2018a) and SuperGLUE (Wang et al.,
2019) benchmarks, including linguistic acceptabil-
ity (CoLA), natural language inference (RTE, CB),
paraphrase and similarity (MRPC and STS-B),
question answering (BoolQ), word sense disam-
biguation (WiC) and coreference resolution (WSC).
In practice, we evaluate performance with the accuracy (“Acc.”) metric for most tasks, and additionally report the F1 score for MRPC, the Pearson and Spearman correlations (“Pear./Spea.”) for STS-B, and the Matthews correlation (“Mcc.”) for CoLA.
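For reference, these metrics can be computed with standard libraries; the snippet below is a toy sketch using scikit-learn and SciPy (our assumption, not necessarily the authors' evaluation scripts):

    from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef
    from scipy.stats import pearsonr, spearmanr

    y_true, y_pred = [1, 0, 1, 1], [1, 0, 0, 1]            # toy classification labels
    s_true, s_pred = [4.5, 2.0, 3.1], [4.2, 2.5, 3.0]      # toy STS-B similarity scores

    acc = accuracy_score(y_true, y_pred)                    # "Acc." for most tasks
    f1 = f1_score(y_true, y_pred)                           # additional F1 for MRPC
    mcc = matthews_corrcoef(y_true, y_pred)                 # "Mcc." for CoLA
    pear, spea = pearsonr(s_true, s_pred)[0], spearmanr(s_true, s_pred)[0]  # "Pear./Spea." for STS-B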
Language Generation Tasks.
We also use two
popular generation tasks following Liu et al.
(2021); Zhang et al. (2022) as the benchmarks, i.e.,
abstractive summarization (XSUM) and grammati-
cal error correction (CoNLL2014). For XSUM, we report results in terms of the standard ROUGE metrics (Lin, 2004), i.e., ROUGE-1, ROUGE-2 and ROUGE-L, respectively. For CoNLL2014, MaxMatch scores (Dahlmeier and Ng, 2012) are used for evaluation, reported as Precision, Recall, and F_0.5 values.⁴
⁴ Due to the space limitation, we present the details of all used tasks and datasets in Appendix A.2.
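As an illustration of the summarization metrics, the `rouge-score` package can be used as below (a hedged example; the CoNLL2014 MaxMatch/M2 scorer is a separate official script and is not sketched here):

    from rouge_score import rouge_scorer

    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
    reference = "the cat sat on the mat"
    prediction = "the cat lay on the mat"
    scores = scorer.score(reference, prediction)           # dict of precision/recall/fmeasure
    print({name: round(score.fmeasure, 3) for name, score in scores.items()})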
4.2 Implementations
In practice, we use the pretrained models and code from HuggingFace⁵ (Wolf et al., 2019). Specifically, for the understanding tasks, we employ 4 widely used PLMs in our study, i.e., BERT (Devlin et al., 2019), ELECTRA (Clark et al., 2019b), ALBERT (Lan et al., 2019) and RoBERTa (Liu et al., 2019). Furthermore, a representative sequence-to-sequence model, BART (Lewis et al., 2020), is used for the generation tasks.
We compare our proposed FSAM method with the base optimizer (without using any SAM approach) and the vanilla SAM method. Specifically, Adam (Kingma and Ba, 2015) is used as the base optimizer to tune our models; its β2 and weight decay are set to 0.999 and 0.01. SAM and FSAM use the same settings as above. More specifically, we grid-search the neighborhood size ρ of SAM and FSAM over {1e-2, 5e-3, 1e-3}. Additionally, for each downstream task, we follow the same hyper-parameter settings as the prior works (Lewis et al., 2020; Xu et al., 2021). The detailed hyper-parameters for fine-tuning on these downstream tasks can be seen in Appendix A.3.
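A sketch of the optimizer configuration described above, with PyTorch's Adam as the base optimizer; the `run_finetuning` call and the learning rate are placeholders for the task-specific settings in Appendix A.3, and the SAM/FSAM wrapper itself is assumed to be provided elsewhere:

    import torch

    def build_base_optimizer(model, lr):
        # Adam base optimizer with beta2 = 0.999 and weight decay 0.01, as in Section 4.2.
        return torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0.01)

    RHO_GRID = [1e-2, 5e-3, 1e-3]   # grid-searched neighborhood size for SAM / FSAM
    SPARSITY = 0.9                  # Fisher-mask sparsity s (Section 3.2)

    for rho in RHO_GRID:
        # Placeholder: plug the base optimizer into a SAM/FSAM wrapper and run fine-tuning,
        # e.g., run_finetuning(model, build_base_optimizer(model, lr), rho=rho, sparsity=SPARSITY)
        pass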
We report the averaged results over 5 random seeds for NLU tasks, while for NLG tasks, we follow ex-
⁵ https://github.com/huggingface/transformers