K-SAM: Sharpness-Aware Minimization at the Speed
of SGD
Renkun Ni
University of Maryland
rn9zm@umd.edu
Ping-yeh Chiang
University of Maryland
Jonas Geiping
University of Maryland
Micah Goldblum
New York University
Andrew Gordon Wilson
New York University
Tom Goldstein
University of Maryland
Abstract
Sharpness-Aware Minimization (SAM) has recently emerged as a robust technique
for improving the accuracy of deep neural networks. However, SAM incurs a high
computational cost in practice, requiring up to twice as much computation as vanilla
SGD. The computational challenge posed by SAM arises because each iteration
requires both ascent and descent steps and thus double the gradient computations.
To address this challenge, we propose to compute gradients in both stages of SAM
on only the top-k samples with the highest loss. K-SAM is simple and extremely
easy to implement while providing significant generalization boosts over vanilla
SGD at little to no additional cost.
1 Introduction
Methods for promoting good generalization are of tremendous value to deep learning practitioners
and a point of fascination for deep learning theorists. While machine learning models can easily
achieve perfect training accuracy on any dataset given enough parameters, it is unclear when a model
will generalize well to test data. Improving the ability of models to extrapolate their knowledge,
learned during training, to held out test samples is the key to performing well in the wild.
Recently, there has been a line of work that argues for the geometry of the loss landscape as a major
contributor to generalization performance for deep learning models. A number of researchers have
argued that flatter minima lead to models that generalize better (Keskar et al., 2016; Xing et al., 2018;
Jiang et al., 2019; Smith et al., 2021). Underlying this work is the intuition that small changes in
parameters yield perturbations to decision boundaries so that flat minima yield wide-margin decision
boundaries (Huang et al., 2020). Motivated by these investigations, Foret et al. (2020) propose an
effective algorithm, Sharpness-Aware Minimization (SAM), to optimize models toward flatter
minima and better generalization performance. The proposed algorithm entails performing one-step
adversarial training in parameter space, finding a loss function minimum that is “flat” in the sense
that perturbations to the network parameters in worst-case directions still yield low training loss.
This simple concept achieves impressive performance on a wide variety of tasks. For example, Foret
et al. (2020) achieved notable improvements on various benchmark vision datasets (e.g., CIFAR-10,
ImageNet) by simply swapping out the optimizer. Later, Chen et al. (2021) found that SAM improves
sample complexity and performance of vision transformer models so that these transformers are
competitive with ResNets even without pre-training.
Moreover, further innovations to the SAM setup, such as modifying the radius of the adversarial
step to be invariant to parameter re-scaling (Kwon et al., 2021), yield additional improvements to
generalization on vision tasks. In addition, Bahri et al. (2021) recently found that SAM not only works
in the vision domain, but also improves the performance of language models on GLUE, SuperGLUE,
Web Questions, Natural Questions, Trivia QA, and TyDiQA (Wang et al., 2018, 2019; Joshi et al.,
2017; Clark et al., 2020).
Despite the simplicity of SAM, the improved performance comes with a steep cost of twice as much
compute, given that SAM requires two forward and backward passes for each optimization step: one
for the ascent step and another for the descent step. The additional cost may make SAM too expensive
for widespread adoption by practitioners, thus motivating studies to decrease the computational cost
of SAM. For example, Efficient SAM (Du et al., 2021) decreases the computational cost of SAM by
using examples with the largest increase in losses for the descent step. Bahri et al. (2021) randomly
select a subset of examples for the ascent step, making that step faster. However, both methods
still require a full forward and backward pass for either the ascent or descent step on all samples in
the batch, so they are more expensive than vanilla SGD. In comparison to these works, we develop a
version of SAM that is as fast as SGD so that practitioners can adopt our variant at no cost.
In this paper, we propose K-SAM, a simple modification to SAM that reduces the computational
costs to that of vanilla SGD, while achieving comparable performance to the original SAM optimizer.
K-SAM exploits the fact that a small subset of training examples is sufficient for both gradient
computation steps (Du et al., 2021; Fan et al., 2017; Bahri et al., 2021), and the examples with largest
losses dominate the average gradient over a large batch. To decrease computational cost, we only
use the $K$ examples with the largest loss values in the training batch for both gradient computations.
When $K$ is chosen properly, our proposed K-SAM can be as fast as vanilla SGD while
improving generalization by seeking flat minima similarly to SAM.
We empirically verify the effectiveness of our proposed approach across datasets and models. We
demonstrate that a small number of samples with high loss produces gradients that are well-aligned
with the entire batch loss. Moreover, we show that our proposed method can achieve comparable
performance to the original (non-accelerated) SAM for vision tasks, such as image classification on
CIFAR-{10,100}, and language models on the GLUE benchmark while keeping the training cost
roughly as low as vanilla SGD. On the other hand, we observe that on large-scale many-class image
classification tasks, such as ImageNet, the average gradient within a batch is broadly distributed, so
that a very small number of samples with the highest losses are not representative of the batch. This
phenomenon makes it hard to simultaneously achieve training efficiency and strong generalization by
subsampling. Nonetheless, we show that K-SAM can achieve comparable generalization to SAM
with around 65% training cost.
2 Related Work
2.1 Sharpness-Aware Minimization
In this section, we briefly introduce how SAM simultaneously minimizes the loss while also decreas-
ing loss sharpness, and we detail why it requires additional gradient steps that in turn double the
computational costs during training. Instead of finding parameters yielding low training loss only,
SAM attempts to find the parameter vector whose neighborhood possesses uniformly low loss, thus
leading to a flat minimum. Formally, SAM achieves this goal by solving the mini-max optimization
problem,
$$\min_{w} L_S^{SAM}(w), \quad \text{where} \quad L_S^{SAM}(w) = \max_{\|\epsilon\|_2 \leq \rho} L_S(w + \epsilon), \tag{1}$$
where $w$ are the parameters of the neural network, $L$ is the loss function, $S$ is the training set, and $\epsilon$ is
a small perturbation within an $\ell_2$ ball of norm $\rho$. In order to solve the outer minimization problem,
SAM applies a first-order approximation to solve the inner maximization problem,
$$\hat{\epsilon} = \arg\max_{\|\epsilon\|_2 \leq \rho} L_S(w + \epsilon) \approx \arg\max_{\|\epsilon\|_2 \leq \rho} L_S(w) + \epsilon^{T} \nabla_w L_S(w) = \rho \, \nabla_w L_S(w) / \|\nabla_w L_S(w)\|_2. \tag{2}$$
After computing the approximate maximizer $\hat{\epsilon}$, SAM obtains the "sharpness-aware" gradient for the descent step:
$$\nabla L_S^{SAM}(w) \approx \nabla_w L_S(w) \big|_{w + \hat{\epsilon}}. \tag{3}$$
In short, given a base optimizer, e.g., SGD, instead of computing the gradient of the model at the
current parameter vector $w$, SAM updates the model parameters using the gradient with respect to the
perturbed model at $w + \hat{\epsilon}$. Therefore, a SAM update step requires two forward and backward passes
on each sample in a batch, namely a gradient ascent step to obtain the perturbation $\hat{\epsilon}$ and a gradient
descent step to update the current model, which doubles the training time compared to the base optimizer.
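To make the two-pass structure concrete, the following is a minimal PyTorch-style sketch of a single SAM update implementing Eqs. (2) and (3). The function name `sam_step` and arguments such as `loss_fn` and `rho` are our own placeholders, not the reference implementation of Foret et al. (2020).

```python
import torch

def sam_step(model, loss_fn, x, y, optimizer, rho=0.05):
    # Ascent pass: gradient at w, used to build the worst-case
    # perturbation epsilon-hat of Eq. (2).
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    grad_norm = torch.norm(torch.stack(
        [p.grad.norm(2) for p in model.parameters() if p.grad is not None]), 2)

    perturbations = []
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is None:
                continue
            eps = rho * p.grad / (grad_norm + 1e-12)
            p.add_(eps)                    # move to w + epsilon-hat
            perturbations.append((p, eps))

    # Descent pass: gradient at the perturbed point (Eq. (3)).
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    with torch.no_grad():
        for p, eps in perturbations:
            p.sub_(eps)                    # restore the original weights w
    optimizer.step()                       # base-optimizer update with the SAM gradient
    optimizer.zero_grad()
```

Both backward passes run over the full mini-batch, which is the source of the doubled cost relative to the base optimizer.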
2.2 Efficient Sharpness-Aware Minimization
Recently, several works have improved the efficiency of SAM while retaining its performance benefits.
Bahri et al. (2021) improve the efficiency of SAM by reducing the computational cost of the gradient
ascent step. Instead of computing the ascent gradient on the whole mini-batch, they estimate the
gradient using a random subset. This tweak makes the computational cost only about 25% slower
than the vanilla SGD training routine when using $1/4$ of the training batch for the ascent gradient,
while maintaining performance comparable to SAM. However, reducing the computational cost of
only the ascent step cannot make SAM as fast or faster than the base optimizer. In addition, we show
in Section 4 that using a random subset of the mini-batch for both gradient computations in SAM has
a negative impact on performance, which suggests that we should seek a better selection method than
random selection.
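The random-subset ascent can be illustrated in a few lines. This is our own sketch with hypothetical names, not Bahri et al.'s code; the returned subset would replace the full batch in the ascent pass of a SAM step such as the one sketched in Section 2.1.

```python
import torch

def random_ascent_subset(x, y, frac=0.25):
    # Sample a random fraction of the mini-batch; only this subset is used to
    # estimate the ascent gradient, while the descent step still runs on the
    # full batch (x, y).
    k = max(1, int(frac * x.size(0)))
    idx = torch.randperm(x.size(0), device=x.device)[:k]
    return x[idx], y[idx]
```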
Du et al. (2021) propose an efficient SAM (ESAM) algorithm employing two strategies, namely
stochastic weight perturbation and sharpness-sensitive data selection. Stochastic weight perturbation
updates only part of the weights during the gradient ascent step, which offers limited speed-ups.
During the descent step, the gradient is only computed using the examples whose loss values increase
the most after the parameter perturbation. With these two strategies, ESAM achieves up to 40.3%
acceleration compared to SAM. However, since ESAM selects the examples with the largest loss
differences, it has to compute the gradient over all samples in the batch for the ascent step and then
must similarly perform a forward pass over all samples in the batch for the descent step, which limits
the possible speed-ups from this approach.
2.3 Top-K Optimization
Top-k optimization has been applied to vanilla SGD (Fan et al., 2017; Kawaguchi & Lu, 2020).
Given a training mini-batch, Ordered SGD selects the $K$ examples with the largest losses on which
to perform the minimization and computes a subgradient based only on these examples. This work
theoretically shows that on convex loss functions, Ordered-SGD is guaranteed to converge sublinearly
to a global optimum and to a critical point with weakly convex losses. In addition, they empirically
show that Ordered-SGD can achieve comparable results to vanilla SGD on multiple machine learning
models such as SVM and deep neural networks. In this paper, we show that top-k optimization can be
effectively applied to both gradient computation steps in SAM and is actually essential to achieving
good performance when $K$ is small.
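A minimal sketch of this top-k objective (our own illustration of Ordered SGD, assuming a cross-entropy loss) is given below; only the $K$ largest per-example losses contribute to the update.

```python
import torch
import torch.nn.functional as F

def ordered_sgd_step(model, x, y, optimizer, k):
    # One forward pass over the full mini-batch, keeping per-example losses.
    per_example = F.cross_entropy(model(x), y, reduction="none")
    # Keep only the k largest losses; the (sub)gradient is taken over these.
    topk_losses, _ = per_example.topk(k)
    optimizer.zero_grad()
    topk_losses.mean().backward()
    optimizer.step()
```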
3 K-SAM: Efficient Sharpness-Aware Minimization by Subsampling
We now introduce our proposed method, K-SAM, where we select the $K_1$ and $K_2$ examples with the largest
loss values to estimate the gradients in the ascent and descent steps of a SAM update, respectively.
When $K_1$ and $K_2$ are small, the computational complexity of SAM will be vastly reduced, since
a large proportion of the compute required for neural network training is concentrated in gradient
computations.
Recall from Section 2.1 that, given a mini-batch of examples $\mathcal{B}$, SAM first computes the ascent
gradient with respect to the current parameters $w$ to find the perturbation via Eq. (2). Then, the
final gradient descent direction is formed by Eq. (3). In order to decrease the computational cost,
we approximate both gradients using subsets $\mathcal{M}_{K_1}, \mathcal{M}_{K_2}$ of the mini-batch $\mathcal{B}$,
where each subset contains the training examples with the $K_1$ and $K_2$ largest loss values,
respectively. Formally, given the losses $l = L_{\mathcal{B}}(w)$ of a batch $\mathcal{B}$ and its index set
$I = \{1, 2, \cdots, |\mathcal{B}|\}$, the subsets can be generated as follows:
$$\mathcal{M}_{\{K_1,K_2\}} = \Big\{ (x_i, y_i) \in \mathcal{B} : i \in Q_{\{1,2\}} \Big\}, \quad \text{where} \quad Q_{\{1,2\}} = \arg\max_{Q \subseteq I,\, |Q| = \{K_1, K_2\}} \sum_{i \in Q} l_i.$$
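Putting these pieces together, one K-SAM update might look like the sketch below. This is our own paraphrase of the procedure described above, assuming a no-gradient forward pass is used to rank the batch by loss; helper names such as `rho` are placeholders, and the authors' implementation details may differ.

```python
import torch
import torch.nn.functional as F

def ksam_step(model, x, y, optimizer, k1, k2, rho=0.05):
    # Rank the mini-batch by per-example loss (no gradient graph needed).
    with torch.no_grad():
        losses = F.cross_entropy(model(x), y, reduction="none")
    idx1 = losses.topk(k1).indices         # subset M_{K1} for the ascent step
    idx2 = losses.topk(k2).indices         # subset M_{K2} for the descent step

    # Ascent step on the K1 highest-loss examples (Eq. (2)).
    optimizer.zero_grad()
    F.cross_entropy(model(x[idx1]), y[idx1]).backward()
    grad_norm = torch.norm(torch.stack(
        [p.grad.norm(2) for p in model.parameters() if p.grad is not None]), 2)
    perturbations = []
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is None:
                continue
            eps = rho * p.grad / (grad_norm + 1e-12)
            p.add_(eps)                    # move to w + epsilon-hat
            perturbations.append((p, eps))

    # Descent step on the K2 highest-loss examples at the perturbed weights (Eq. (3)).
    optimizer.zero_grad()
    F.cross_entropy(model(x[idx2]), y[idx2]).backward()
    with torch.no_grad():
        for p, eps in perturbations:
            p.sub_(eps)                    # restore w
    optimizer.step()
    optimizer.zero_grad()
```

When $K_1$ and $K_2$ are small fractions of the batch, both backward passes run on only a handful of examples, which is where the SGD-level training cost comes from.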