The Dynamics of Sharpness-Aware Minimization:
Bouncing Across Ravines
and Drifting Towards Wide Minima
Peter L. Bartlett*, Philip M. Long and Olivier Bousquet
Google
1600 Amphitheatre Parkway
Mountain View, CA 94040
{peterbartlett,plong,obousquet}@google.com
Abstract
We consider Sharpness-Aware Minimization (SAM), a gradient-based optimization method for deep networks that has exhibited performance improvements on image and language prediction problems. We show that when SAM is applied with a convex quadratic objective, for most random initializations it converges to a cycle that oscillates between either side of the minimum in the direction with the largest curvature, and we provide bounds on the rate of convergence. In the non-quadratic case, we show that such oscillations effectively perform gradient descent, with a smaller step-size, on the spectral norm of the Hessian. In such cases, SAM's update may be regarded as exploiting a third derivative (the derivative of the Hessian in the leading eigenvector direction) that encourages drift toward wider minima.
1 Introduction
The broad practical impact of deep learning has heightened interest in many of its surprising characteristics: simple gradient methods applied to deep neural networks seem to efficiently optimize nonconvex criteria, reliably giving a near-perfect fit to training data, but exhibiting good predictive accuracy nonetheless [see Bartlett et al., 2021]. Optimization methodology is widely believed to affect statistical performance by imposing some kind of implicit regularization, and there has been considerable effort devoted to understanding the behavior of optimization methods and the nature of solutions that they find. For instance, Barrett and Dherin [2020] and Smith et al. [2021] show that discrete-time gradient descent and stochastic gradient descent can be viewed as gradient flow methods applied to penalized losses that encourage smoothness, and Soudry et al. [2018] and Azulay et al. [2021] identify the implicit regularization imposed by gradient flow in specific examples, including linear networks.
We consider Sharpness-Aware Minimization (SAM), a recently introduced [Foret et al., 2021] gradient optimization method that has exhibited substantial improvements in prediction performance for deep networks applied to image classification [Foret et al., 2021] and NLP [Bahri et al., 2022] problems.
*Also affiliated with University of California, Berkeley.
In introducing SAM, Foret et al. motivate it using a minimax optimization problem
$$\min_{w}\ \max_{\|\epsilon\|\le\rho} \ell(w+\epsilon), \tag{1}$$
where $\ell : \mathbb{R}^d \to \mathbb{R}$ is an empirical loss defined on the parameter space $\mathbb{R}^d$, $\|\cdot\|$ is the Euclidean norm on the parameter space, and $\rho$ is a scale parameter. By viewing the difference
$$\max_{\|\epsilon\|\le\rho} \ell(w+\epsilon) - \ell(w)$$
as a measure of the sharpness of the empirical loss $\ell$ at the parameter value $w$, the criterion in (1) allows a trade-off between the empirical loss and the sharpness,
$$\max_{\|\epsilon\|\le\rho} \ell(w+\epsilon) = \ell(w) + \underbrace{\max_{\|\epsilon\|\le\rho} \ell(w+\epsilon) - \ell(w)}_{\text{sharpness}}.$$
In practice, SAM works with a simplification based on gradient measurements, starting with an initial parameter vector $w_0 \in \mathbb{R}^d$ and updating the parameters at iteration $t$ via
$$w_{t+1} = w_t - \eta\,\nabla\ell\!\left(w_t + \rho\,\frac{\nabla\ell(w_t)}{\|\nabla\ell(w_t)\|}\right), \tag{2}$$
where $\eta$ is a step-size parameter. Our goal in this paper is to understand the nature of the solutions that the SAM updates (2) lead to.
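To make the update concrete, here is a minimal NumPy sketch of one step of (2); the helper names (`sam_step`, `grad_fn`) and the toy quadratic used to exercise it are our own illustration, not code from the paper.

```python
import numpy as np

def sam_step(w, grad_fn, eta, rho, eps=1e-12):
    """One SAM update (2): w <- w - eta * grad(w + rho * g / ||g||), with g = grad(w)."""
    g = grad_fn(w)
    g_norm = np.linalg.norm(g)
    if g_norm < eps:                   # at a stationary point the ascent direction is undefined
        return w
    w_adv = w + rho * g / g_norm       # move to the (approximate) worst-case point nearby
    return w - eta * grad_fn(w_adv)    # descend using the gradient measured there

# Usage on a toy quadratic l(w) = 0.5 * w^T Lambda w, so grad l(w) = Lambda w.
Lam = np.diag([1.0, 0.5])
w = np.array([2.0, 2.0])
for _ in range(30):
    w = sam_step(w, lambda u: Lam @ u, eta=0.2, rho=1.0)
print(w)
```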
In Sections 3 and 4, we consider SAM with a convex quadratic criterion. The key insight is that it is equivalent to a gradient descent method for a certain non-convex criterion whose stationary points correspond to oscillations around the minimum in the directions of the eigenvectors of the Hessian of the loss. The only stable stationary point corresponds to the leading eigenvector direction: 'bouncing across the ravine'. (Notice that this is not the solution to the motivating minimax optimization problem (1), which is the minimum of the quadratic criterion.)
In Section 5, we consider SAM near a smooth minimum of the loss function $\ell$ with a positive semidefinite Hessian. For parameters corresponding to the solutions for the quadratic case, we see that the SAM updates can be decomposed into two components. There is a large component in the direction of the oscillation (bouncing across the ravine), and there is a smaller component in the orthogonal subspace that corresponds to descending the gradient of the spectral norm of the Hessian. Thus, SAM is able to drift towards wide minima by exploiting a specific third derivative (the gradient of the second derivative in the leading eigenvector direction) with only two gradient computations per iteration. In Section 7, we present some open problems, the most important of which is elucidating the relationship between wide minima of the empirical loss and statistical performance.
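As a purely illustrative sketch of this drift (a toy example of ours, not an experiment from the paper), consider the loss $\ell(w) = \frac{1}{2}(1+w_2^2)\,w_1^2$: its minima form the line $w_1 = 0$, and the Hessian there has spectral norm $1 + w_2^2$, so wider minima correspond to smaller $|w_2|$. Running SAM on this loss, the first coordinate settles into the oscillation while the second drifts toward $w_2 = 0$, the widest part of the valley; plain gradient descent started at the same point leaves $w_2$ essentially unchanged.

```python
import numpy as np

# Toy loss with a line of minima {w1 = 0}; at (0, w2) the Hessian is
# diag(1 + w2**2, 0), so its spectral norm 1 + w2**2 is smallest at w2 = 0.
def grad(w):
    w1, w2 = w
    return np.array([(1.0 + w2**2) * w1, w2 * w1**2])

eta, rho = 0.05, 0.5
w = np.array([1.0, 2.0])
for _ in range(5000):
    g = grad(w)
    w = w - eta * grad(w + rho * g / (np.linalg.norm(g) + 1e-12))

print(w)  # w[0] oscillates with small amplitude; w[1] has drifted close to 0
```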
2 Additional Related Work
Du et al. [2022] proposed a more computationally efficient variant of SAM. Beugnot et al. [2022] studied the effect of a large learning rate with early stopping on the spectrum of the Hessian in the case of quadratic loss.
Cohen et al. [2020] provided a variety of natural settings where, empirically, when neural networks are trained with batch gradient descent and a fixed learning rate $\eta$, the spectral norm of the Hessian tends toward $2/\eta$, the "edge of stability". Here, if the gradient is aligned with the principal direction of the Hessian, the solution "bounces across the ravine", as in the analysis of this paper. A number of theoretical treatments of this phenomenon have since been proposed [Ahn et al., 2022, Arora et al., 2022, Damian et al., 2022]. The most closely related of those to this paper is the work of Damian et al. [2022], who also described conditions under which "bouncing across the ravine" tends to decrease the spectral norm of the Hessian.
In independent work posted to arXiv after the initial version of this paper, Wen et al. [2022] performed a variety of analyses of SAM and some related algorithms. Their results included showing that SAM almost surely converges in the limit in the convex quadratic case, along with asymptotic analysis showing that, once SAM gets close enough to the manifold of loss minimizers, it approximately tracks the path, on a loss-minimizing manifold, of gradient flow with respect to the spectral norm of the Hessian, under smoothness assumptions on the loss. They also showed that the stochastic version of SAM, in which both gradients at each step are estimated from a single training example, approximately tracks the path of gradient flow with respect to the trace of the Hessian.
3 SAM with Quadratic Loss: Bouncing Across Ravines

We first consider the application of SAM to minimize a convex quadratic objective $\ell$. Without loss of generality, we assume that the minimum of $\ell$ is at zero, the eigenvectors of $\ell$'s Hessian are the coordinate axes, and the eigenvalues are sorted by the indices of the eigenvectors. Accordingly, for $\Lambda = \mathrm{diag}(\lambda_1, \dots, \lambda_d)$ with $\lambda_1 \ge \cdots \ge \lambda_d > 0$, we consider the loss $\ell(w) = \frac{1}{2} w^\top \Lambda w$. Then $\nabla\ell(w) = \Lambda w$ and SAM sets
$$w_{t+1} = w_t - \eta\,\nabla\ell\!\left(w_t + \rho\,\frac{\nabla\ell(w_t)}{\|\nabla\ell(w_t)\|}\right)
= \left(I - \eta\Lambda - \frac{\eta\rho}{\|\Lambda w_t\|}\Lambda^2\right) w_t. \tag{3}$$
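As a sanity check on (3), the following sketch (ours, not the authors' code) iterates the closed-form quadratic update and confirms that it matches the generic SAM step (2) with $\nabla\ell(w) = \Lambda w$.

```python
import numpy as np

lam = np.array([1.0, 0.5])       # eigenvalues lambda_1 > lambda_2 > 0
Lam = np.diag(lam)
eta, rho = 0.2, 1.0              # step size eta < 1/(2*lambda_1), neighborhood size rho

def sam_quadratic_step(w):
    """Closed-form update (3): w <- (I - eta*Lam - (eta*rho/||Lam w||) Lam^2) w."""
    return (np.eye(len(w)) - eta * Lam - (eta * rho / np.linalg.norm(Lam @ w)) * Lam @ Lam) @ w

def sam_generic_step(w):
    """Generic SAM update (2) specialized to grad l(w) = Lam w."""
    g = Lam @ w
    return w - eta * Lam @ (w + rho * g / np.linalg.norm(g))

w = np.array([2.0, 2.0])
for _ in range(30):
    assert np.allclose(sam_quadratic_step(w), sam_generic_step(w))
    w = sam_quadratic_step(w)
print(w)  # the iterates settle into an oscillation along the leading eigenvector e_1
```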
The following is our main result.
Theorem 1. There are polynomials $p$ and $p'$ and an absolute constant $c$ such that the following holds. For any eigenvalues $\lambda_1 > \lambda_2 \ge \cdots \ge \lambda_d > 0$, loss $\ell(w) = \frac{1}{2} w^\top \Lambda w$ with $\Lambda = \mathrm{diag}(\lambda_1, \dots, \lambda_d)$, any neighborhood size $\rho > 0$, any step size $0 < \eta < \frac{1}{2\lambda_1}$, and any $\delta > 0$, if $w_0$ is sampled from a continuous probability distribution over $\mathbb{R}^d$

• whose density is bounded above by $A_R$, and

• for $R > \eta\rho\lambda_1$ and $q > 0$, with probability at least $1-\delta$, $\|w_0\| \le R$ and $w_{0,1}^2 \ge q$,

and $w_1, w_2, \dots$ are obtained through the SAM update (2), then, if $\kappa = \lambda_1/\lambda_d$, for all
$$\epsilon < p'(1/\lambda_1, \lambda_d, \eta, \rho, \delta, 1/\rho, A, R),$$
with probability $1 - 2\delta$, for all
$$t \;\ge\; \frac{c\,\kappa^5}{\eta\lambda_d\,\min\left\{\eta\lambda_d,\ \lambda_1^2/\lambda_2^2 - 1\right\}}\,(1+d)\,p\,\log\frac{1}{\epsilon},$$
one of the following holds:

• $\left\|w_t - \frac{\eta\rho\lambda_1}{2-\eta\lambda_1}\,e_1\right\| \le \epsilon$ and $\left\|w_{t+1} + \frac{\eta\rho\lambda_1}{2-\eta\lambda_1}\,e_1\right\| \le \epsilon$, or

• $\left\|w_t + \frac{\eta\rho\lambda_1}{2-\eta\lambda_1}\,e_1\right\| \le \epsilon$ and $\left\|w_{t+1} - \frac{\eta\rho\lambda_1}{2-\eta\lambda_1}\,e_1\right\| \le \epsilon$.

Figure 1: The first 30 iterates of SAM, initialized at $(2,2)$, with $\lambda_1 = 1$, $\lambda_2 = 1/2$, $\eta = 1/5$ and $\rho = 1$. (a) All the iterates. (b) The iterates close to the origin.
Theorem 1 has the following corollary.
Theorem 2. For any eigenvalues $\lambda_1 > \lambda_2 \ge \cdots \ge \lambda_d > 0$, any neighborhood size $\rho > 0$, and any step size $0 < \eta < \frac{1}{2\lambda_1}$, if $w_0$ is sampled from a continuous probability distribution over $\mathbb{R}^d$ with $\mathbb{E}[\|w_0\|^2] < \infty$, then, almost surely, for all $\epsilon > 0$, for all large enough $t$, the iterates of SAM applied to the quadratic loss $\ell(w) = \frac{1}{2} w^\top \mathrm{diag}(\lambda_1, \dots, \lambda_d)\, w$ satisfy:

• $\left\|w_t - \frac{\eta\rho\lambda_1}{2-\eta\lambda_1}\,e_1\right\| \le \epsilon$ and $\left\|w_{t+1} + \frac{\eta\rho\lambda_1}{2-\eta\lambda_1}\,e_1\right\| \le \epsilon$, or

• $\left\|w_t + \frac{\eta\rho\lambda_1}{2-\eta\lambda_1}\,e_1\right\| \le \epsilon$ and $\left\|w_{t+1} - \frac{\eta\rho\lambda_1}{2-\eta\lambda_1}\,e_1\right\| \le \epsilon$.
Our analysis shows that, when SAM is initialized far from the optimum, training proceeds in two stages. Early on, the objective function is reduced exponentially fast, with the most rapid progress made in the directions with the highest curvature. This can be seen, for example, in Figure 1a, which plots the first 30 iterates of SAM initialized at $(2,2)$ in the case that $\lambda_1 = 1$, $\lambda_2 = 1/2$, $\eta = 1/5$ and $\rho = 1$. After a certain point, however, SAM's iterates "overshoot" in the direction of highest curvature, as can be seen in Figure 1b, which is the same as Figure 1a, except zoomed in to the region near the origin, where the details of the later iterates can be seen. During this second phase, the share of the length of the parameter vector in the first component increases, and the process converges to the oscillation described in Theorem 2. Note that, as illustrated in Figure 1a, due to the normalization by $\|\nabla\ell(w_t)\|$, the parameter vector can jump away from a position very close to the origin, with a correspondingly very small loss. However, as we will see, the training process makes steady progress with respect to a potential function that we will define in Section 4.3.
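The two phases and the limiting oscillation are easy to check numerically. The sketch below (ours, not the authors' code) re-runs the setting of Figure 1 and compares the late iterates with the cycle points $\pm\frac{\eta\rho\lambda_1}{2-\eta\lambda_1}e_1$ of Theorems 1 and 2.

```python
import numpy as np

lam1, lam2, eta, rho = 1.0, 0.5, 0.2, 1.0
Lam = np.diag([lam1, lam2])
w = np.array([2.0, 2.0])                 # same initialization as Figure 1

iterates = [w.copy()]
for _ in range(200):
    g = Lam @ w
    w = w - eta * Lam @ (w + rho * g / np.linalg.norm(g))
    iterates.append(w.copy())

amplitude = eta * rho * lam1 / (2 - eta * lam1)   # predicted oscillation amplitude
print(iterates[-2], iterates[-1], amplitude)
# The last two iterates alternate between roughly (+amplitude, 0) and (-amplitude, 0).
```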
4 Proof of Theorem 1
In this section, we prove the following theorem, which implies Theorem 1. We denote $\max\{z, 0\}$ by $[z]_+$.
Theorem 3. There is an absolute constant $c$ such that, for any eigenvalues $\lambda_1 > \lambda_2 \ge \cdots \ge \lambda_d > 0$, loss $\ell(w) = \frac{1}{2} w^\top \Lambda w$ with $\Lambda = \mathrm{diag}(\lambda_1, \dots, \lambda_d)$, any neighborhood size $\rho > 0$, any initialization parameters $R, A, q > 0$, and any step size $0 < \eta < \frac{1}{2\lambda_1}$, for all
$$0 < \epsilon < \min\left\{\sqrt{\eta\lambda_1}/2,\ 1/(2\rho\lambda_1),\ \eta\rho\lambda_1^2/2\right\},$$
for all $\delta > 0$, if $w_0$ is sampled from a continuous probability distribution over $\mathbb{R}^d$

• whose density is bounded above by $A_R$, and

• with probability at least $1-\delta$, $\|w_0\| \le R$ and $w_{0,1}^2 \ge q$,

and $w_1, w_2, \dots$ are obtained through the SAM update (2), then, with probability $1 - 2\delta$, for all
$$
\begin{aligned}
t \;\ge\;& \frac{6\lambda_1^5}{\eta\lambda_d^6 \min\left\{\eta\lambda_d,\ \frac{\lambda_1^2}{\lambda_2^2}-1\right\}}
\Bigg( \log\frac{4}{\eta\lambda_1}
\;+\; \frac{1}{\min\left\{\eta\lambda_d,\ \frac{\lambda_1^2}{\lambda_2^2}-1\right\}}\,
\log\frac{4(1+\eta\rho\lambda_1^2)^2}{\lambda_d^2\,\epsilon^2}
\;+\; \log\frac{R^2}{q} \\
&\quad+\; \frac{2\left[\log\frac{R}{\eta\rho\lambda_1}\right]_+}{\eta\lambda_d \min\left\{\eta\lambda_d,\ \frac{\lambda_1^2}{\lambda_2^2}-1\right\}}
\Bigg( \log(2\lambda_1 R)
+ \frac{\left[\log\frac{R}{\eta\rho\lambda_1}\right]_+ \log\frac{9\cdot 6^{d+3}R^3}{(\eta\lambda_d)^{d+3}(\eta\rho\lambda_1)^3}}{\eta\lambda_d}
+ \log\frac{4\pi^{d/2}(4\eta\rho\lambda_1^2)^{d-1}\left[\log\frac{R}{\eta\rho\lambda_1}\right]_+ A}{\Gamma(d/2)\,\delta\,\eta\lambda_d}
\Bigg) \Bigg) \\
&+\; \frac{6}{\eta\lambda_1}\,\ln\frac{2(1+\eta\rho\lambda_1^2)}{\lambda_d\,\epsilon},
\end{aligned}
$$
one of the following holds:

• $\left\|w_t - \frac{\eta\rho\lambda_1}{2-\eta\lambda_1}\,e_1\right\| \le \epsilon$ and $\left\|w_{t+1} + \frac{\eta\rho\lambda_1}{2-\eta\lambda_1}\,e_1\right\| \le \epsilon$, or

• $\left\|w_t + \frac{\eta\rho\lambda_1}{2-\eta\lambda_1}\,e_1\right\| \le \epsilon$ and $\left\|w_{t+1} - \frac{\eta\rho\lambda_1}{2-\eta\lambda_1}\,e_1\right\| \le \epsilon$.
The proof of Theorem 3 requires some lemmas, which we prove first. Throughout this section, we assume that $\eta\lambda_1 < 1/2$ and we highlight where the assumption $\lambda_1 > \lambda_2$ is used.
The evolution of the gradient $\nabla\ell(w_t) = \Lambda w_t$ plays a key role in the dynamics of SAM. To simplify expressions, we refer to it using the shorthand $v_t$. Substituting into the SAM update (3) for the quadratic loss gives
$$v_{t+1} = \left(I - \eta\Lambda - \frac{\eta\rho}{\|v_t\|}\Lambda^2\right) v_t,$$
so, for all $i \in [d]$ and all $t$, we have
$$v_{t+1,i} = \left(1 - \eta\lambda_i - \frac{\eta\rho\lambda_i^2}{\|v_t\|}\right) v_{t,i}
= (1-\eta\lambda_i)\,\frac{\|v_t\| - \frac{\eta\rho\lambda_i^2}{1-\eta\lambda_i}}{\|v_t\|}\, v_{t,i}
= (1-\eta\lambda_i)\left(\|v_t\| - \gamma_i\right)\frac{v_{t,i}}{\|v_t\|},$$
where $\gamma_i := \frac{\eta\rho\lambda_i^2}{1-\eta\lambda_i}$.