
Based on this finding, Foret et al. [17] propose a novel approach to improve model generalization
called sharpness-aware minimization (SAM), which simultaneously minimizes loss value and loss
sharpness. SAM quantifies the landscape sharpness as the maximized difference of loss when a
perturbation is added to the weight. When the model reaches a sharp area, the perturbed gradients in
SAM help the model jump out of the sharp minima. In practice, SAM requires two forward-backward
computations for each optimization step, where the first computation is to obtain the perturbation and
the second one is for parameter update. Despite its remarkable performance [17, 34, 14, 7], this property makes SAM double the computational cost of the conventional optimizer, e.g., SGD [3].
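To make this cost concrete, below is a minimal PyTorch-style sketch of one SAM update with its two forward-backward passes. The function name sam_step, the radius rho, and the base_optimizer argument are our own illustrative choices under common assumptions, not the authors' released code.

```python
import torch

def sam_step(model, loss_fn, x, y, base_optimizer, rho=0.05):
    # --- first forward-backward pass: gradients at the current weights ---
    loss_fn(model(x), y).backward()

    # Build the perturbation e = rho * g / ||g|| and add it to every parameter.
    params = [p for p in model.parameters() if p.grad is not None]
    grad_norm = torch.norm(torch.stack([p.grad.norm(p=2) for p in params]), p=2)
    perturbations = []
    with torch.no_grad():
        for p in params:
            e = rho * p.grad / (grad_norm + 1e-12)
            p.add_(e)                      # climb to the (approximate) worst-case point
            perturbations.append((p, e))
    model.zero_grad()

    # --- second forward-backward pass: gradients at the perturbed weights ---
    loss_fn(model(x), y).backward()

    # Undo the perturbation, then update with the perturbed gradients.
    with torch.no_grad():
        for p, e in perturbations:
            p.sub_(e)
    base_optimizer.step()                  # e.g., SGD on the sharpness-aware gradients
    model.zero_grad()
```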
Since SAM calculates perturbations indiscriminately for all parameters, a natural question arises:
Do we need to calculate perturbations for all parameters?
Above all, we notice that in most deep neural networks, only about 5% of parameters are sharp and rise steeply during optimization [31]. We then explore the effect of SAM in different dimensions to answer the above question and find (i) little difference between SGD and SAM gradients in most dimensions (see Fig. 1); (ii) flatter landscapes without SAM in some dimensions (see Fig. 4 and Fig. 5).
Inspired by the above discoveries, we propose a novel scheme to improve the efficiency of SAM via
sparse perturbation, termed Sparse SAM (SSAM). The sparse perturbation in SSAM plays the role of a regularizer, yielding better generalization, while the sparse operation also ensures efficient optimization. Specifically,
the perturbation in SSAM is multiplied by a binary sparse mask to determine which parameters
should be perturbed. To obtain the sparse mask, we provide two implementations. The first solution
is to use Fisher information [16] of the parameters to formulate the binary mask, dubbed SSAM-F.
The other one is to employ dynamic sparse training to jointly optimize model parameters and the
sparse mask, dubbed SSAM-D. The first solution is relatively more stable but a bit time-consuming,
while the latter is more efficient.
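As a concrete illustration of how such a mask could be constructed, the sketch below approximates the diagonal Fisher information by squared gradients accumulated over a few batches and keeps the largest fraction of entries as the binary mask. This is a hedged sketch of the SSAM-F idea rather than the exact implementation; the names fisher_mask, keep_ratio, and n_batches are ours.

```python
import torch

def fisher_mask(model, loss_fn, data_loader, keep_ratio=0.5, n_batches=8):
    # Empirical (diagonal) Fisher information, approximated by squared gradients.
    fisher = {n: torch.zeros_like(p)
              for n, p in model.named_parameters() if p.requires_grad}
    for i, (x, y) in enumerate(data_loader):
        if i >= n_batches:
            break
        model.zero_grad()
        loss_fn(model(x), y).backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2

    # Keep the globally largest `keep_ratio` fraction of entries as the binary mask.
    scores = torch.cat([f.flatten() for f in fisher.values()])
    k = max(1, int(keep_ratio * scores.numel()))
    threshold = torch.topk(scores, k).values.min()
    return {n: (f >= threshold).float() for n, f in fisher.items()}

# During the perturbation step, the mask simply gates each entry, e.g.
#   e = rho * p.grad / grad_norm * mask[name]   # only masked entries are perturbed
```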
In addition to these solutions, we provide a theoretical convergence analysis of SAM and SSAM in the non-convex stochastic setting, proving that our SSAM can converge at the same rate as SAM, i.e., O(log T/√T). Finally, we evaluate the performance and effectiveness of SSAM on CIFAR10 [33], CIFAR100 [33], and ImageNet [8] with various models. The experiments confirm that SSAM
contributes to a flatter landscape than SAM, and its performance is on par with or even better than
SAM with only about 50% perturbation. These results coincide with our motivations and findings.
To sum up, the contributions of this paper are three-fold:
• We rethink the role of perturbation in SAM and find that indiscriminate perturbations are suboptimal and computationally inefficient.
• We propose a sparsified perturbation approach called Sparse SAM (SSAM) with two variants, i.e., Fisher SSAM (SSAM-F) and Dynamic SSAM (SSAM-D), both of which enjoy better efficiency and effectiveness than SAM. We also theoretically prove that SSAM can converge at the same rate as SAM, i.e., O(log T/√T).
• We evaluate SSAM with various models on CIFAR and ImageNet, showing that WideResNet with SSAM at high sparsity outperforms SAM on CIFAR, that SSAM achieves competitive performance with high sparsity, and that SSAM has a convergence rate comparable to SAM.
2 Related Work
In this section, we briefly review the studies on sharpness-aware minimization (SAM),
Fisher information in deep learning, and dynamic sparse training.
SAM and flat minima.
Hochreiter et al. [27] first reveal that there is a strong correlation between the generalization of a model and flat minima. After that, a growing amount of research builds on this finding. Keskar et al. [31] conduct experiments with larger batch sizes and consequently observe a degradation of model generalization capability. They [31] also attribute this phenomenon to the model's tendency to converge to sharp minima. Keskar et al. [31] and Dinh et al. [12] state that sharpness can be evaluated by the eigenvalues of the Hessian. However, they fail to find flat minima due to the notorious computational cost of the Hessian.
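As a side note on why the Hessian route is expensive: even estimating only the largest eigenvalue typically relies on repeated Hessian-vector products, each of which costs an additional backward pass. The sketch below shows a standard power-iteration estimate in PyTorch; it is our illustration of the general technique, not the procedure used in [31] or [12].

```python
import torch

def top_hessian_eigenvalue(loss, params, iters=20):
    # Power iteration using Hessian-vector products (the Hessian is never formed).
    grads = torch.autograd.grad(loss, params, create_graph=True)
    v = [torch.randn_like(p) for p in params]
    eigenvalue = 0.0
    for _ in range(iters):
        # Normalize the current direction v.
        v_norm = torch.sqrt(sum((vi ** 2).sum() for vi in v))
        v = [vi / (v_norm + 1e-12) for vi in v]
        # Hessian-vector product: Hv = d(g.v)/dw.
        gv = sum((g * vi).sum() for g, vi in zip(grads, v))
        hv = torch.autograd.grad(gv, params, retain_graph=True)
        # Rayleigh quotient v^T H v approximates the top eigenvalue.
        eigenvalue = sum((h * vi).sum() for h, vi in zip(hv, v)).item()
        v = [h.detach() for h in hv]
    return eigenvalue
```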
Inspired by this, Foret et al. [17] introduce sharpness-aware minimization (SAM) to find a flat minimum for improved generalization capability, which is achieved by solving a mini-max problem, as written below.
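For reference, the mini-max objective of [17] can be written as follows, where ρ is the radius of the perturbation neighborhood (the additional weight-decay term in [17] is omitted here for brevity):

\[
\min_{w} \; \max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon).
\]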