In short, given a base optimizer, e.g., SGD, instead of computing the gradient of the model at the current parameter vector $w$, SAM updates the model parameters using the gradient with respect to the perturbed model at $w + \hat{\epsilon}$. Therefore, a SAM update step requires two forward and backward passes on each sample in a batch, namely a gradient ascent step to obtain the perturbation $\hat{\epsilon}$ and a gradient descent step to update the current model, which doubles the training time compared to the base optimizer.
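For concreteness, one SAM update with SGD as the base optimizer can be sketched in PyTorch roughly as follows; this is a minimal illustration, and the names model, loss_fn, optimizer, and the perturbation radius rho are placeholder assumptions rather than the implementation used in this paper.

import torch

def sam_step(model, loss_fn, x, y, optimizer, rho=0.05):
    # First forward/backward pass: gradient of the batch loss at the current weights w.
    loss_fn(model(x), y).backward()

    # Gradient ascent step: form the perturbation eps_hat = rho * g / ||g||
    # and move to the perturbed weights w + eps_hat.
    params = [p for p in model.parameters() if p.grad is not None]
    grad_norm = torch.norm(torch.stack([p.grad.norm(p=2) for p in params]), p=2)
    eps = []
    with torch.no_grad():
        for p in params:
            e = rho * p.grad / (grad_norm + 1e-12)
            p.add_(e)                      # w -> w + eps_hat
            eps.append(e)
    optimizer.zero_grad()

    # Second forward/backward pass: gradient at the perturbed weights.
    loss_fn(model(x), y).backward()

    # Undo the perturbation and apply the base-optimizer update at w.
    with torch.no_grad():
        for p, e in zip(params, eps):
            p.sub_(e)                      # back to w
    optimizer.step()
    optimizer.zero_grad()

The two full-batch forward/backward passes in this sketch are exactly the cost that the efficient variants discussed next aim to reduce.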
2.2 Efficient Sharpness-Aware Minimization
Recently, several works have improved the efficiency of SAM while retaining its performance benefits. Bahri et al. (2021) reduce the computational cost of the gradient ascent step: instead of computing the ascent gradient on the whole mini-batch, they estimate it using a random subset. With $1/4$ of the training batch used for the ascent gradient, this tweak makes training only about 25% slower than the vanilla SGD routine while maintaining performance comparable to SAM. However, reducing the computational cost of
only the ascent step cannot make SAM as fast or faster than the base optimizer. In addition, we show
in Section 4 that using a random subset of the mini-batch for both gradient computations in SAM has
a negative impact on performance, which suggests that we should seek a better selection method than
random selection.
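As a hedged sketch of this idea (not the authors' code), the ascent gradient can be estimated on a random fraction of the mini-batch as follows; the helper name, frac, and the loss function are our assumptions.

import torch

def random_subset_ascent_grad(model, loss_fn, x, y, frac=0.25):
    # Estimate the ascent gradient on a random 1/4 of the batch instead of the whole batch.
    n = x.size(0)
    idx = torch.randperm(n, device=x.device)[: max(1, int(frac * n))]
    loss_fn(model(x[idx]), y[idx]).backward()
    # The parameter gradients now hold a subset-based estimate of the ascent direction.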
Du et al. (2021) propose an efficient SAM (ESAM) algorithm employing two strategies, namely
stochastic weight perturbation and sharpness-sensitive data selection. Stochastic weight perturbation
updates only part of the weights during the gradient ascent step, which offers limited speed-ups.
During the descent step, the gradient is only computed using the examples whose loss values increase
the most after the parameter perturbation. With these two strategies, ESAM achieves up to 40.3%
acceleration compared to SAM. However, since ESAM selects the examples with the largest loss
differences, it has to compute the gradient over all samples in the batch for the ascent step and then
must similarly perform a forward pass over all samples in the batch for the descent step, which limits
the possible speed-ups from this approach.
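A rough sketch of the sharpness-sensitive data selection step might look as follows, assuming per-example losses computed with reduction="none"; the helper name and keep_frac are our placeholders, not ESAM's exact setting.

import torch

def sharpness_sensitive_indices(loss_at_w, loss_at_w_perturbed, keep_frac=0.5):
    # Keep the examples whose loss increased the most after the perturbation;
    # only these examples are used in the descent-step gradient.
    k = max(1, int(keep_frac * loss_at_w.numel()))
    loss_increase = loss_at_w_perturbed - loss_at_w
    return torch.topk(loss_increase, k).indices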
2.3 Top-K Optimization
Top-k optimization has been applied to vanilla SGD (Fan et al., 2017; Kawaguchi & Lu, 2020).
Given a training mini-batch, Ordered SGD selects the $K$ examples with the largest losses on which to perform the minimization and computes a subgradient based only on these examples. This work theoretically shows that Ordered SGD is guaranteed to converge sublinearly to a global optimum for convex loss functions and to a critical point for weakly convex losses. In addition, they empirically show that Ordered SGD achieves results comparable to vanilla SGD on multiple machine learning models such as SVMs and deep neural networks. In this paper, we show that top-k optimization can be effectively applied to both gradient computation steps in SAM and is in fact essential to achieving good performance when $K$ is small.
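As a minimal sketch of an Ordered-SGD-style step (assuming a classification loss and a placeholder value of k), only the k largest per-example losses contribute to the update:

import torch
import torch.nn.functional as F

def ordered_sgd_step(model, x, y, optimizer, k=32):
    # Per-example losses over the mini-batch.
    per_example_loss = F.cross_entropy(model(x), y, reduction="none")
    # Keep only the k largest losses and take a (sub)gradient step on their average.
    topk_loss = torch.topk(per_example_loss, min(k, x.size(0))).values
    optimizer.zero_grad()
    topk_loss.mean().backward()
    optimizer.step()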
3 K-SAM: Efficient Sharpness-Aware Minimization by Subsampling
We now introduce our proposed method, K-SAM, in which we select the $K_1$ and $K_2$ examples with the largest loss values to estimate the gradients in the ascent and descent steps of a SAM update, respectively. When $K_1$ and $K_2$ are small, the computational complexity of SAM is vastly reduced, since a large proportion of the compute required for neural network training is concentrated in gradient computations.
Recall from Section 2.1 that, given a mini-batch of examples $\mathcal{B}$, SAM first computes the ascent gradient with respect to the current parameters $w$ to find the perturbation via Eq. (2). Then, the final gradient descent direction is formed via Eq. (3). In order to decrease the computational cost, we approximate both gradients using subsets $\mathcal{M}_{K_1}, \mathcal{M}_{K_2}$ of the mini-batch $\mathcal{B}$, where each subset contains the training examples with the largest $K_1$ and $K_2$ loss values, respectively. Formally, given the losses $l = L_{\mathcal{B}}(w)$ of a batch $\mathcal{B}$ and its index set $\mathcal{I} = \{1, 2, \cdots, |\mathcal{B}|\}$ over the samples, the subsets can be generated as follows:
$$\mathcal{M}_{\{K_1, K_2\}} = \Big\{ (x_i, y_i) \in \mathcal{B} : i \in \mathcal{Q}_{\{1,2\}} \Big\}, \quad \text{where } \mathcal{Q}_{\{1,2\}} = \mathop{\arg\max}_{\mathcal{Q} \subseteq \mathcal{I},\, |\mathcal{Q}| = \{K_1, K_2\}} \sum_{i \in \mathcal{Q}} l_i.$$
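A minimal sketch of this selection step, assuming per-example losses from a classification objective (the helper name and the use of cross-entropy are our assumptions):

import torch
import torch.nn.functional as F

def topk_subsets(model, x, y, k1, k2):
    # l = L_B(w): per-example losses of the full mini-batch at the current weights.
    with torch.no_grad():
        losses = F.cross_entropy(model(x), y, reduction="none")
    # Q_1, Q_2: indices of the K1 / K2 largest losses.
    idx1 = torch.topk(losses, min(k1, x.size(0))).indices
    idx2 = torch.topk(losses, min(k2, x.size(0))).indices
    # M_{K1} for the ascent gradient, M_{K2} for the descent gradient.
    return (x[idx1], y[idx1]), (x[idx2], y[idx2])

In this formulation, both subsets are determined by the same loss values $l$ computed at the current parameters $w$.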