Improving Robustness with Adaptive Weight Decay
Amin Ghiasi, Ali Shafahi, Reza Ardekani
Apple
Cupertino, CA, 95014
{mghiasi2, ashafahi, rardekani} @apple.com
Abstract
We propose adaptive weight decay, which automatically tunes the hyper-parameter
for weight decay during each training iteration. For classification problems, we
propose changing the value of the weight decay hyper-parameter on the fly based on
the strength of updates from the classification loss (i.e., gradient of cross-entropy)
and the regularization loss (i.e., $\ell_2$-norm of the weights). We show that this simple
modification can result in large improvements in adversarial robustness, an
area which suffers from robust overfitting, without requiring extra data across
various datasets and architecture choices. For example, our reformulation results in
20% relative robustness improvement for CIFAR-100, and 10% relative robustness
improvement on CIFAR-10 compared to the best tuned hyper-parameters of
traditional weight decay, resulting in models that have comparable performance to
SOTA robustness methods. In addition, this method has other desirable properties,
such as less sensitivity to learning rate and smaller weight norms, the latter of which
contributes to robustness to overfitting to label noise and to pruning.
1 Introduction
Deep Neural Networks (DNNs) have exceeded human capability on many computer vision tasks. Due
to their high capacity for memorizing training examples (Zhang et al., 2021), DNN generalization
heavily relies on the training algorithm. To reduce memorization and improve generalization, several
approaches have been taken including regularization and augmentation. Some of these augmentation
techniques alter the network input (DeVries & Taylor, 2017; Chen et al., 2020; Cubuk et al., 2019,
2020; Müller & Hutter, 2021), some alter hidden states of the network (Srivastava et al., 2014; Ioffe
& Szegedy, 2015; Gastaldi, 2017; Yamada et al., 2019), some alter the expected output (Warde-Farley
& Goodfellow, 2016; Kannan et al., 2018), and some affect multiple levels (Zhang et al., 2017; Yun
et al., 2019; Hendrycks et al., 2019b). Typically, augmentation methods aim to enhance generalization
by increasing the diversity of the dataset. The utilization of regularizers, such as weight decay (Plaut
et al., 1986; Krogh & Hertz, 1991), serves to prevent overfitting by eliminating solutions that solely
memorize training examples and by constraining the complexity of the DNN. Regularization methods
are most beneficial in areas such as adversarial robustness and noisy-data settings, which
suffer from catastrophic overfitting. In this paper, we revisit weight decay, a regularizer mainly used
to avoid overfitting.
The rest of the paper is organized as follows: In Section 2, we revisit tuning the weight decay
hyper-parameter to improve adversarial robustness and introduce Adaptive Weight Decay. Also in
Section 2, through extensive experiments on various image classification datasets, we show that
adversarial training with Adaptive Weight Decay improves both robustness and natural generalization
compared to traditional non-adaptive weight decay. Next, in Section 3, we briefly mention other
potential applications of Adaptive Weight Decay to network pruning, robustness to sub-optimal
learning-rates, and training on noisy labels.
37th Conference on Neural Information Processing Systems (NeurIPS 2023).
arXiv:2210.00094v2 [cs.LG] 2 Dec 2023
2 Adversarial Robustness
DNNs are susceptible to adversarial perturbations (Szegedy et al., 2013; Biggio et al., 2013). In the
adversarial setting, the adversary adds a small imperceptible noise to the image, which fools the
network into making an incorrect prediction. To ensure that the adversarial noise is imperceptible to
the human eye, noise with bounded $\ell_p$-norm is usually studied (Sharif et al., 2018). In such
settings, the objective for the adversary is to maximize the following loss:
$$\max_{\|\delta\|_p \le \epsilon} \mathrm{Xent}(f(x+\delta, w), y), \tag{1}$$
where $\mathrm{Xent}$ is the cross-entropy loss, $\delta$ is the adversarial perturbation, $x$ is the clean example, $y$ is
the ground truth label, $\epsilon$ is the adversarial perturbation budget, and $w$ denotes the DNN parameters.
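The maximization in eq. 1 is usually approximated with projected gradient steps. The following is a minimal sketch for the $\ell_\infty$ case, assuming a toy linear classifier with labels in $\{-1, +1\}$ instead of a DNN; the model, the helper names (`xent_logistic`, `input_gradient`, `pgd_linf`), and all numeric values are illustrative assumptions and are not taken from the paper.

```python
import math

def xent_logistic(w, b, x, y):
    """Cross-entropy of a linear classifier for a label y in {-1, +1}."""
    margin = y * (sum(wi * xi for wi, xi in zip(w, x)) + b)
    # log(1 + exp(-margin)), written to stay stable for large |margin|
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))

def input_gradient(w, b, x, y):
    """Gradient of the cross-entropy with respect to the input x."""
    margin = y * (sum(wi * xi for wi, xi in zip(w, x)) + b)
    coeff = -y / (1.0 + math.exp(margin))  # d loss / d (w.x + b)
    return [coeff * wi for wi in w]

def sign(v):
    return (v > 0) - (v < 0)

def pgd_linf(w, b, x, y, eps, step_size, num_steps):
    """Projected gradient ascent on the loss inside the l_inf ball of radius eps."""
    delta = [0.0] * len(x)
    for _ in range(num_steps):
        x_adv = [xi + di for xi, di in zip(x, delta)]
        grad = input_gradient(w, b, x_adv, y)
        # Ascent step along the gradient sign, then projection onto the ball.
        delta = [min(eps, max(-eps, di + step_size * sign(gi)))
                 for di, gi in zip(delta, grad)]
    return delta

# Hypothetical toy model and example.
w, b = [2.0, -1.0], 0.5
x, y = [1.0, 1.0], 1
delta = pgd_linf(w, b, x, y, eps=0.1, step_size=0.03, num_steps=7)
clean_loss = xent_logistic(w, b, x, y)
adv_loss = xent_logistic(w, b, [xi + di for xi, di in zip(x, delta)], y)
print(delta, clean_loss, adv_loss)
```

For a linear model the perturbation saturates at the corners of the $\epsilon$-ball, which makes the effect of the projection easy to inspect; with a DNN the gradient would come from backpropagation instead of the closed-form `input_gradient`.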
A multitude of papers concentrate on the adversarial task and propose methods to generate robust
adversarial examples through various approaches, including the modification of the loss function
and the provision of optimization techniques to effectively optimize the adversarial generation loss
functions (Goodfellow et al., 2014; Madry et al., 2017; Carlini & Wagner, 2017; Izmailov et al.,
2018; Croce & Hein, 2020a; Andriushchenko et al., 2020). An additional area of research centers on
mitigating the impact of potent adversarial examples. While certain studies on adversarial defense
prioritize approaches with theoretical guarantees (Wong & Kolter, 2018; Cohen et al., 2019), in
practical applications, variations of adversarial training have emerged as the prevailing defense
strategy against adversarial attacks (Madry et al., 2017; Shafahi et al., 2019; Wong et al., 2020;
Rebuffi et al., 2021; Gowal et al., 2020). Adversarial training involves on-the-fly generation of
adversarial examples during the training process and subsequently training the model using these
examples. The adversarial training loss can be formulated as a min-max optimization problem:
$$\min_{w} \max_{\|\delta\|_p \le \epsilon} \mathrm{Xent}(f(x+\delta, w), y). \tag{2}$$
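As an illustration of the min-max structure in eq. 2, the sketch below trains a toy linear classifier adversarially. It relies on an assumption that does not hold for DNNs: for a linear model with labels in $\{-1, +1\}$ and an $\ell_\infty$ budget, the inner maximization has the closed-form solution $\delta = -\epsilon\, y\, \mathrm{sign}(w)$, so no iterative attack is needed. The data, function names, and hyper-parameter values are hypothetical.

```python
import math

def worst_case_delta(w, y, eps):
    """Closed-form inner maximizer for a linear model under an l_inf budget."""
    return [-eps * y * ((wi > 0) - (wi < 0)) for wi in w]

def adversarial_training(data, eps, lr, epochs):
    dim = len(data[0][0])
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        for x, y in data:
            # Inner maximization: build the adversarial example on the fly.
            delta = worst_case_delta(w, y, eps)
            x_adv = [xi + di for xi, di in zip(x, delta)]
            # Outer minimization: one SGD step on the adversarial example.
            margin = y * (sum(wi * xi for wi, xi in zip(w, x_adv)) + b)
            coeff = -y / (1.0 + math.exp(margin))
            w = [wi - lr * coeff * xi for wi, xi in zip(w, x_adv)]
            b -= lr * coeff
    return w, b

def robust_accuracy(w, b, data, eps):
    correct = 0
    for x, y in data:
        # Worst-case margin is y * (w.x + b) - eps * ||w||_1.
        margin = y * (sum(wi * xi for wi, xi in zip(w, x)) + b)
        correct += margin - eps * sum(abs(wi) for wi in w) > 0
    return correct / len(data)

# Hypothetical, linearly separable toy data with labels in {-1, +1}.
data = [([1.0, 2.0], 1), ([2.0, 1.5], 1), ([-1.0, -1.5], -1), ([-2.0, -1.0], -1)]
w, b = adversarial_training(data, eps=0.1, lr=0.1, epochs=50)
print(w, b, robust_accuracy(w, b, data, eps=0.1))
```

The loop mirrors the description above: adversarial examples are generated during training and the parameters are then updated on them. In the DNN setting the closed-form `worst_case_delta` would be replaced by an iterative attack such as PGD.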
2.1 Robust overfitting and relationship to weight decay
Adversarial training is a strong baseline for defending against adversarial attacks; however, it often
suffers from a phenomenon referred to as Robust Overfitting (Rice et al., 2020). Weight decay
regularization, as discussed in 2.1.1, is a common technique used for preventing overfitting.
2.1.1 Weight Decay
Weight decay encourages weights of networks to have smaller magnitudes (Zhang et al., 2018)
and is widely used to improve generalization. Weight decay regularization can have many forms
(Loshchilov & Hutter, 2017), and we focus on the popular $\ell_2$-norm variant. More precisely, we focus
on classification problems with cross-entropy as the main loss, such as adversarial training, and
weight decay as the regularizer, which was popularized by Krizhevsky et al. (2017):
$$\mathrm{Loss}_w(x, y) = \mathrm{Xent}(f(x, w), y) + \frac{\lambda_{wd}}{2}\|w\|_2^2, \tag{3}$$
where $w$ is the network parameters, $(x, y)$ is the training data, and $\lambda_{wd}$ is the weight-decay
hyper-parameter. $\lambda_{wd}$ is a crucial hyper-parameter in weight decay, determining the weight penalty
compared to the main loss (e.g., cross-entropy). A small $\lambda_{wd}$ may cause overfitting, while a large
value can yield a low weight-norm solution that poorly fits the training data. Thus, selecting an
appropriate $\lambda_{wd}$ value is essential for achieving an optimal balance.
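To make the role of $\lambda_{wd}$ in eq. 3 concrete, here is a small sketch that evaluates the regularized loss and the gradient contributed by the penalty term. The cross-entropy value and the weight vector are hypothetical placeholders, and the function names are introduced only for this illustration.

```python
def weight_decay_loss(xent_value, w, lambda_wd):
    """Eq. 3: cross-entropy plus (lambda_wd / 2) * ||w||_2^2."""
    squared_norm = sum(wi * wi for wi in w)
    return xent_value + 0.5 * lambda_wd * squared_norm

def weight_decay_gradient(w, lambda_wd):
    """Gradient of the penalty term with respect to w, i.e. lambda_wd * w."""
    return [lambda_wd * wi for wi in w]

w = [3.0, -4.0]      # hypothetical weights with ||w||_2 = 5
xent_value = 0.7     # hypothetical cross-entropy value
for lambda_wd in (0.0, 0.00089, 0.1):
    print(lambda_wd, weight_decay_loss(xent_value, w, lambda_wd))
```

Because the penalty gradient is $\lambda_{wd} \cdot w$, a larger $\lambda_{wd}$ pulls every weight toward zero more strongly at each step, regardless of how large the cross-entropy gradient currently is.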
2.1.2 Robust overfitting phenomenon revisited
To study robust overfitting, we focus on evaluating the $\ell_\infty$ adversarial robustness on the CIFAR-10
dataset while limiting the adversarial budget of the attacker to $\epsilon = 8$, a common setting for
evaluating robustness. For these experiments, we use a WideResNet 28-10 architecture (Zagoruyko
& Komodakis, 2016) and widely adopted PGD adversarial training (Madry et al., 2017) to solve the
adversarial training loss with weight decay regularization:
$$\min_{w} \max_{\|\delta\|_\infty \le 8} \mathrm{Xent}(f(x+\delta, w), y) + \frac{\lambda_{wd}}{2}\|w\|_2^2. \tag{4}$$
Figure 1: Robust validation accuracy (a), validation loss (b), and training loss (c) on CIFAR-10
subsets. $\lambda_{wd} = 0.00089$ is the best performing hyper-parameter we found by doing a grid-search.
The other two hyper-parameters are two points from our grid-search, one with a larger and the other
with a smaller hyper-parameter for weight decay. The thickness of the plot-lines corresponds to the
magnitude of the weight-norm penalties. As can be seen in (a) and (b), networks trained with small
values of $\lambda_{wd}$ suffer from robust overfitting, while networks trained with larger values of $\lambda_{wd}$ do not
suffer from robust overfitting, but the larger $\lambda_{wd}$ further prevents the network from fitting the data (c),
resulting in reduced overall robustness.
We reserve 10% of the training examples as a held-out validation set for early stopping and checkpoint
selection. In practice, to solve eq. 4, the network parameters
w
are updated after generating adversarial
examples in real-time using a 7-step PGD adversarial attack. We train for 200 epochs, using an initial
learning-rate of 0.1 combined with a cosine learning-rate schedule. Throughout training, at the end of
each epoch, the robust accuracy and robustness loss (i.e., cross-entropy loss of adversarial examples)
are evaluated on the validation set by subjecting the held-out validation examples to a 3-step PGD
attack. For further details, please refer to A.1.
To further understand the robust overfitting phenomenon in the presence of weight decay, we train
different models by varying the weight-norm hyperparameter λwd in eq. 4.
Figure 1 illustrates the accuracy and cross-entropy loss on the adversarial examples built for the
held-out validation set for three choices¹ of $\lambda_{wd}$ throughout training. As seen in Figure 1(a), for
small $\lambda_{wd}$ choices, the robust validation accuracy does not monotonically increase towards the end
of training. This non-monotonic behavior, which is related to robust overfitting, is even more
pronounced if we look at the robustness loss computed on the held-out validation set (Figure 1(b)). Note
that this behavior is still evident even if we look at the best hyper-parameter value according to the
validation set ($\lambda_{wd} = 0.00089$).
Various methods have been proposed to rectify robust overfitting, including early stopping (Rice
et al., 2020), use of extra unlabeled data (Carmon et al., 2019), synthesized images (Gowal et al.,
2020), pre-training (Hendrycks et al., 2019a), use of data augmentations (Rebuffi et al., 2021), and
stochastic weight averaging (Izmailov et al., 2018).
In Fig. 1, we observe that simply having smaller weight-norms (by increasing $\lambda_{wd}$) could reduce this
non-monotonic behavior on the validation set adversarial examples. However, this comes at the cost
of larger cross-entropy loss on the training set adversarial examples, as shown in Figure 1(c). Even
though the overall loss function from eq. 4 is a minimization problem, the terms in the loss function
implicitly have conflicting objectives: During the training process, when the cross-entropy term holds
dominance, effectively reducing the weight norm becomes challenging, resulting in non-monotonic
behavior of robust validation metrics towards the later stages of training. Conversely, when the
weight-norm term takes precedence, the cross-entropy objective encounters difficulties in achieving
significant reductions. In the next section, we introduce Adaptive Weight Decay, which explicitly
strikes a balance between these two terms during training.
2.2 Adaptive Weight Decay
Inspired by the findings in 2.1.2, we propose Adaptive Weight Decay (AWD). The goal of AWD is to
maintain a balance between weight decay and cross-entropy updates during training in order to guide
the optimization to a solution which satisfies both objectives more effectively. To derive AWD, we
study one gradient descent step for updating the parameter $w$ at step $t+1$ from its value at step $t$:
$$w_{t+1} = w_t - \nabla w_t \cdot lr - w_t \cdot \lambda_{wd} \cdot lr, \tag{5}$$
where $\nabla w_t$ is the gradient computed from the cross-entropy objective, and $w_t \cdot \lambda_{wd}$ is the gradient
computed from the weight decay term from eq. 3. We define $\lambda_{awd}$ as a metric that keeps track of the
ratio of the magnitudes coming from each objective:
$$\lambda_{awd(t)} = \frac{\lambda_{wd}\,\|w_t\|}{\|\nabla w_t\|}. \tag{6}$$
¹ Figure 2 captures the complete set of $\lambda_{wd}$ values we tested.
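The update in eq. 5 and the ratio in eq. 6 can be written out directly. The sketch below does so with plain Python lists; the function names and the numeric values are illustrative assumptions, and the gradient is supplied by hand instead of being computed from a network.

```python
import math

def l2_norm(v):
    return math.sqrt(sum(vi * vi for vi in v))

def sgd_step(w, grad, lr, lambda_wd):
    """Eq. 5: w_{t+1} = w_t - lr * grad - lr * lambda_wd * w_t."""
    return [wi - lr * gi - lr * lambda_wd * wi for wi, gi in zip(w, grad)]

def awd_ratio(w, grad, lambda_wd):
    """Eq. 6: norm of the weight-decay gradient over norm of the cross-entropy gradient."""
    return lambda_wd * l2_norm(w) / l2_norm(grad)

w = [3.0, -4.0]       # hypothetical weights, ||w_t|| = 5
grad = [0.6, 0.8]     # hypothetical cross-entropy gradient, norm close to 1
print(awd_ratio(w, grad, lambda_wd=0.01))
print(sgd_step(w, grad, lr=0.1, lambda_wd=0.01))
```

With a fixed $\lambda_{wd}$, the ratio drifts whenever the gradient norm or the weight norm changes during training, which is the imbalance that motivates making $\lambda_{wd}$ adaptive.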
To keep a balance between the two objectives, we aim to keep this ratio constant during training.
AWD is a simple yet effective way of maintaining this balance. Adaptive weight decay shares
similarities with non-adaptive (traditional) weight decay, with the only distinction being that the
hyper-parameter $\lambda_{wd}$ is not fixed throughout training. Instead, $\lambda_{wd}$ dynamically changes in each
iteration to ensure $\lambda_{awd(t)} \approx \lambda_{awd(t-1)} \approx \lambda_{awd}$. To keep this ratio constant at every step $t$, we can
rewrite the $\lambda_{awd}$ equation (eq. 6) as:
$$\lambda_{wd(t)} = \frac{\lambda_{awd} \cdot \|\nabla w_t\|}{\|w_t\|}. \tag{7}$$
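A short sketch of eq. 7 follows, checking that the resulting per-iteration coefficient keeps the ratio from eq. 6 at the chosen $\lambda_{awd}$. The weights and gradients are hypothetical values, and `adaptive_lambda_wd` is a name introduced only for this example.

```python
import math

def l2_norm(v):
    return math.sqrt(sum(vi * vi for vi in v))

def adaptive_lambda_wd(w, grad, lambda_awd):
    """Eq. 7: lambda_wd(t) = lambda_awd * ||grad_t|| / ||w_t||."""
    return lambda_awd * l2_norm(grad) / l2_norm(w)

lambda_awd = 0.02
w = [3.0, -4.0]
# Hypothetical gradients: large early in training, small near convergence.
for grad in ([3.0, 4.0], [0.3, 0.4], [0.003, 0.004]):
    lam = adaptive_lambda_wd(w, grad, lambda_awd)
    ratio = lam * l2_norm(w) / l2_norm(grad)  # eq. 6 evaluated with lambda_wd(t)
    print(lam, ratio)
```

The coefficient shrinks together with the gradient norm while the ratio stays at $\lambda_{awd}$. The sketch assumes non-zero norms; a practical implementation would need to guard against a zero weight or gradient norm.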
Eq. 7 allows us to have a different weight decay hyperparameter value ($\lambda_{wd}$) for every optimization
iteration $t$, which keeps the gradients received from the cross entropy and weight decay balanced
throughout the optimization. Note that the weight decay penalty $\lambda_t$ can be computed on the fly with
almost no computational overhead during training. Using the exponential weighted average
$\bar{\lambda}_t = 0.1 \times \bar{\lambda}_{t-1} + 0.9 \times \lambda_t$, we could make $\lambda_t$ more stable (Algorithm 1).
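The exponential weighted average above amounts to a one-line update. The sketch below applies it to a hypothetical sequence of per-iteration values; the `keep` argument is a name chosen here for the 0.1 factor.

```python
def smooth_lambda(lambda_bar_prev, lambda_t, keep=0.1):
    """lambda_bar_t = keep * lambda_bar_{t-1} + (1 - keep) * lambda_t."""
    return keep * lambda_bar_prev + (1.0 - keep) * lambda_t

lambda_bar = 0.0  # the running average starts at zero
history = []
for lambda_t in (0.010, 0.002, 0.012, 0.001):  # hypothetical per-iteration values
    lambda_bar = smooth_lambda(lambda_bar, lambda_t)
    history.append(lambda_bar)
print(history)
```

With a weight of 0.9 on the newest value, the average tracks $\lambda_t$ closely and only damps abrupt iteration-to-iteration jumps.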
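Algorithm 1 Adaptive Weight Decay
Input: $\lambda_{awd} > 0$
$\bar{\lambda} \leftarrow 0$
for $(x, y) \in$ loader do
    $p \leftarrow$ model$(x)$    ▷ Get model's prediction.
    main $\leftarrow$ CrossEntropy$(p, y)$    ▷ Compute CrossEntropy.
    $\nabla w \leftarrow$ backward(main)    ▷ Compute the gradients of main loss w.r.t. weights.
    $\lambda \leftarrow \frac{\|\nabla w\| \, \lambda_{awd}}{\|w\|}$    ▷ Compute iteration's weight decay hyperparameter.
    $\bar{\lambda} \leftarrow 0.1 \times \bar{\lambda} + 0.9 \times$ stop_gradient$(\lambda)$    ▷ Compute the weighted average as a scalar.
    $w \leftarrow w - lr\,(\nabla w + \bar{\lambda} \times w)$    ▷ Update network's parameters.
end for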
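Algorithm 1 can be sketched end to end on a toy problem. The version below is an illustration under several assumptions that are not part of the paper: a linear classifier with labels in $\{-1, +1\}$ stands in for the network, the gradient is computed in closed form instead of by backpropagation, and a small constant `tiny` guards against division by a zero weight norm. Because $\lambda$ is a plain Python float here, nothing is differentiated through it, which plays the role of the `stop_gradient` step.

```python
import math

def l2_norm(v):
    return math.sqrt(sum(vi * vi for vi in v))

def xent_gradient(w, x, y):
    """Gradient of log(1 + exp(-y * w.x)) with respect to w, for y in {-1, +1}."""
    margin = y * sum(wi * xi for wi, xi in zip(w, x))
    coeff = -y / (1.0 + math.exp(margin))
    return [coeff * xi for xi in x]

def train_awd(loader, w, lambda_awd, lr, epochs, tiny=1e-12):
    lambda_bar = 0.0
    for _ in range(epochs):
        for x, y in loader:
            grad = xent_gradient(w, x, y)  # gradient of the main loss
            # Per-iteration weight decay coefficient (eq. 7), kept as a scalar.
            lam = lambda_awd * l2_norm(grad) / (l2_norm(w) + tiny)
            lambda_bar = 0.1 * lambda_bar + 0.9 * lam  # smoothed coefficient
            w = [wi - lr * (gi + lambda_bar * wi) for wi, gi in zip(w, grad)]
    return w, lambda_bar

# Hypothetical toy data with labels in {-1, +1}.
loader = [([1.0, 2.0], 1), ([2.0, 1.5], 1), ([-1.0, -1.5], -1), ([-2.0, -1.0], -1)]
w_plain, _ = train_awd(loader, [0.1, 0.1], lambda_awd=0.0, lr=0.1, epochs=100)
w_awd, lambda_bar = train_awd(loader, [0.1, 0.1], lambda_awd=0.5, lr=0.1, epochs=100)
print(l2_norm(w_plain), l2_norm(w_awd), lambda_bar)
```

Setting `lambda_awd=0.0` recovers plain gradient descent, which gives a baseline for comparing the final weight norms on this toy problem.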
2.2.1 Differences between Adaptive and Non-Adaptive Weight Decay
To study the differences between adaptive and non-adaptive weight decay and to build intuition, we
can plug in $\lambda_t$ of the adaptive method (eq. 7) directly into the equation for traditional weight decay
(eq. 3) and derive the total loss based on Adaptive Weight Decay:
$$\mathrm{Loss}_{w_t}(x, y) = \mathrm{Xent}(f(x, w_t), y) + \frac{\lambda_{awd} \cdot \|\nabla w_t\| \|w_t\|}{2}. \tag{8}$$
Please note that directly solving eq. 8 will invoke the computation of second-order derivatives since
$\lambda_t$ is computed using the first-order derivatives. However, as stated in Alg. 1, we convert $\lambda_t$ into a
non-tensor scalar to save computation and avoid second-order derivatives. We treat $\|\nabla w_t\|$ in eq. 8
as a constant and do not allow gradients to back-propagate through it. As a result, adaptive weight
decay has negligible computation overhead compared to traditional non-adaptive weight decay.
Figure 2: Robust accuracy (a) and loss (b) on CIFAR-10 validation subset. Both figures highlight the
best performing hyper-parameter for non-adaptive weight decay, $\lambda_{wd} = 0.00089$, with sharp strokes.
As can be seen, lower values of $\lambda_{wd}$ cause robust overfitting, while high values prevent the network
from fitting entirely. However, training with adaptive weight decay prevents overfitting and achieves
the highest performance in robustness.
By comparing the weight decay term in the adaptive weight decay loss (eq. 8), $\frac{\lambda_{awd}}{2}\|w\|\|\nabla w\|$,
with that of the traditional weight decay loss (eq. 3), $\frac{\lambda_{wd}}{2}\|w\|_2^2$, we can build intuition on some of
the differences between the two. For example, the non-adaptive weight decay regularization term
approaches zero only when the weight norms are close to zero, whereas, in AWD, it also happens
when the cross-entropy gradients are close to zero. Consequently, AWD prevents over-optimization
of weight norms in flat minima, allowing for more (relative) weight to be given to the cross-entropy
objective. Additionally, AWD penalizes weight norms more when the gradient of cross-entropy is
large, preventing it from falling into steep local minima and hence overfitting early in training.
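The contrast between the two regularization terms can be checked numerically. The sketch below evaluates both terms for a fixed weight vector while the cross-entropy gradient shrinks toward zero; the weights, gradients, and coefficient values are hypothetical.

```python
import math

def l2_norm(v):
    return math.sqrt(sum(vi * vi for vi in v))

def wd_term(w, lambda_wd):
    """Traditional penalty from eq. 3: (lambda_wd / 2) * ||w||^2."""
    return 0.5 * lambda_wd * l2_norm(w) ** 2

def awd_term(w, grad, lambda_awd):
    """Adaptive penalty from eq. 8: (lambda_awd / 2) * ||w|| * ||grad||."""
    return 0.5 * lambda_awd * l2_norm(w) * l2_norm(grad)

w = [3.0, -4.0]
for scale in (1.0, 0.1, 0.0):  # hypothetical shrinking cross-entropy gradient
    grad = [0.6 * scale, 0.8 * scale]
    print(scale, wd_term(w, 0.01), awd_term(w, grad, 0.05))
```

The traditional term stays the same for all three gradients because it depends only on the weights, while the adaptive term decreases with the gradient norm and vanishes once the gradient is zero.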
We verify our intuition of AWD being capable of reducing robust overfitting in practice by replacing
the non-adaptive weight decay with AWD and monitoring the same two metrics from 2.1.2. The
results for a good choice of the AWD hyper-parameter ($\lambda_{awd}$) and various choices of the non-adaptive
weight decay hyper-parameter ($\lambda_{wd}$) are summarized in Figure 2.²
2.2.2 Related works to Adaptive Weight Decay
The studies most closely related to AWD are AdaDecay (Nakamura & Hong, 2019) and LARS (You et al.,
2017). AdaDecay changes the weight decay hyper-parameter adaptively for each individual parameter,
whereas we tune a single hyper-parameter for the entire network. LARS is a common
optimizer when using large batch sizes, which adaptively changes the learning rate for each layer. We
evaluate these relevant methods in the context of improving adversarial robustness and experimentally
compare them with AWD in Table 2 and Appendix D.³
2.3 Experimental Robustness results for Adaptive Weight Decay
AWD can help improve the robustness on various datasets which suffer from robust overfitting. To
illustrate this, we focus on six datasets: SVHN, FashionMNIST, Flowers, CIFAR-10, CIFAR-100,
and Tiny ImageNet. Tiny ImageNet is a subset of ImageNet, consisting of 200 classes and images of
size $64 \times 64 \times 3$. For all experiments, we use the widely accepted 7-step PGD adversarial training to
solve eq. 4 (Madry et al., 2017) while keeping 10% of the examples from the training set as a held-out
validation set for the purpose of early stopping. For early stopping, we select the checkpoint with the
highest $\ell_\infty = 8$ robustness accuracy measured by a 3-step PGD attack on the held-out validation set.
For CIFAR-10, CIFAR-100, and Tiny ImageNet experiments, we use a WideResNet 28-10 architecture,
and for SVHN, FashionMNIST, and Flowers, we use a ResNet18 architecture. Other details about
the experimental setup can be found in Appendix A.1. For all experiments, we tune the conventional
non-adaptive weight decay parameter ($\lambda_{wd}$) for improving robustness generalization and compare
that to tuning the $\lambda_{awd}$ hyper-parameter for adaptive weight decay. To ensure that we search over
enough values for $\lambda_{wd}$, we use up to twice as many values for $\lambda_{wd}$ compared to $\lambda_{awd}$.
Figure 3 plots the robustness accuracy measured by applying AutoAttack (Croce & Hein, 2020b)
on the test examples for the CIFAR-10, CIFAR-100, and Tiny ImageNet datasets, respectively. We
² See Appendix C.4 for similar analysis on other datasets.
³ Due to space limitations, we defer detailed discussions and comparisons to Appendix D.