MaskTune: Mitigating Spurious Correlations by
Forcing to Explore
Saeid Asgari Taghanaki*
Autodesk AI Lab
Aliasghar Khani*
Autodesk AI Lab
Fereshte Khani*
Stanford University
Ali Gholami*
Autodesk AI Lab
Linh Tran
Autodesk AI Lab
Ali Mahdavi-Amiri
Simon Fraser University
Ghassan Hamarneh
Simon Fraser University
Abstract
A fundamental challenge of over-parameterized deep learning models is learning
meaningful data representations that yield good performance on a downstream
task without over-fitting to spurious input features. This work proposes MaskTune,
a masking strategy that prevents over-reliance on spurious (or a limited number
of) features. MaskTune forces the trained model to explore new features during
a single epoch of finetuning by masking previously discovered features. Unlike
earlier approaches for mitigating shortcut learning, MaskTune does not require any
supervision, such as annotating spurious features or labels for subgroup samples
in a dataset. Our empirical results on biased MNIST, CelebA, Waterbirds, and
ImageNet-9L datasets show that MaskTune is effective on tasks that often suffer
from the existence of spurious correlations. Finally, we show that MaskTune
outperforms or achieves similar performance to the competing methods when
applied to the selective classification (classification with rejection option) task.
Code for MaskTune is available at https://github.com/aliasgharkhani/Masktune.
1 Introduction
Spurious correlations are coincidental feature associations formed between a subset of the input and
target variables, which may be caused by factors such as data selection bias [Torralba and Efros, 2011,
Jabri et al., 2016]. The presence of spurious correlations in training data can cause over-parameterized
deep neural networks to fail, often drastically, when such correlations do not hold in test data [Sagawa
et al., 2019] or when encountering domain shift [Arjovsky et al., 2019]. Consider the classification
problem of cows and camels [Beery et al., 2018], where most of the images of cows vs. camels
are captured on green fields vs. desert backgrounds due to selection bias (and perhaps the nature
of the problem that camels are often in the desert). A model trained on such data may rely on the
background as the key discriminative feature between cows and camels, thus failing on images of
cows on non-green backgrounds or camels on non-desert backgrounds.
In over-parameterized regimes, there are often several solutions with almost identical loss values,
and the optimizer (e.g., SGD) typically selects a low-capacity one [Wilson et al., 2017, Valle-Perez
et al., 2018, Arpit et al., 2017, Kalimeris et al., 2019]. In the presence of spurious correlations, the
optimizer might choose to leverage them as they generally demand less capacity than the expected
semantic cues of interest, e.g., relying on the local color or texture of grass instead of the elaborate
visual features that give a cow its appearance [Bruna and Mallat, 2013, Bruna et al., 2015, Brendel
and Bethge, 2019, Khani and Liang, 2021].
*These authors contributed equally to this work.
36th Conference on Neural Information Processing Systems (NeurIPS 2022).
arXiv:2210.00055v2 [cs.LG] 8 Oct 2022
[Figure 1 schematic: panels show the raw input, the class activations and first set of discriminative features found by ERM, the masked input, and the second set of features discovered after a single epoch of MaskTune finetuning, with and without spurious features.]
Figure 1: MaskTune generates a new set of masked samples by obstructing the features discovered by
a model fully trained via empirical risk minimization (ERM). The ERM model is then fine-tuned for
only one epoch using the masked version of the original training data to force new feature exploration.
The features highlighted in yellow, red, and green correspond to features discovered by ERM, the
masked features, and the newly discovered features by MaskTune, respectively.
In previous work, a supervised loss function has been employed to reduce the effect of spurious
correlations [Sagawa et al., 2019]. However, identifying and annotating the spurious correlations
in a large dataset to use as a training signal is impractical. Other works have attempted to force models
to discard context and background (as spurious features) through input morphing using perceptual
similarities [Taghanaki et al., 2021] or by learning causal variables [Javed et al., 2020]. Discarding
context and background, however, is incompatible with the human visual system, which relies on
contextual information when detecting and recognizing objects [Palmer, 1975, Biederman et al., 1982,
Chun and Jiang, 1998, Henderson and Hollingworth, 1999, Torralba, 2003]. In addition, spurious
features may appear on the object itself (e.g., facial attributes); thus, discarding the context and
background may be a futile strategy in these cases.
Instead of requiring contextual and background information to be discarded or relying on a limited
number of features, we propose a single-epoch finetuning technique called MaskTune that prevents a
model from learning only the “first” simplest mapping (potentially spurious correlations) from the
input to the corresponding target variable. MaskTune forces the model to explore other input variables
by concealing (masking) the ones that have already been deemed discriminative. As we finetune
the model with a new set of masked samples, we force the training to escape its myopic and greedy
feature-seeking behavior and encourage it to explore and leverage more input variables. In other
words, as the previous clues are hidden, the model is constrained to find alternative loss-minimizing
input-target mappings. MaskTune conceals the first clues discovered by a fully trained model, whether
they are spurious or not. This forces the model to investigate and leverage new, complementary
discriminative input features. A model relying on a broader array of complementary features
(some may be spurious while others are not) is expected to be more robust to test data missing a
subset of these features.
Figure 1 visualizes how MaskTune works via a schematic of the cow-on-grass scenario. Even in the
absence of spurious correlations, models tend to focus on a shortcut (e.g., the ears or skin texture of a
cow), which can prevent them from generalizing to scenarios where those specific parts are missing,
even though the object is still recognizable from the remaining parts. As an alternative, MaskTune
generates a diverse set of partially masked training examples, forcing the model to investigate a wider
area of the input feature landscape, e.g., new pixels.
A further disadvantage of relying on a limited number of features is the model’s inability to know
when it does not know. Returning to the cow-camel classification example: if cows only appear
on grass in the training set, then it is unclear whether the grass or the cow corresponds to the “cow” label.
A model that relies only on the grass feature can confidently make a wrong prediction when some
other object appears on the grass at test time. We need the model to predict only if both cow
and grass appear in the picture and abstain otherwise. One method used in the literature to address
this issue is selective classification [Geifman and El-Yaniv, 2019, 2017, Khani et al., 2016], which
allows a network to reject a sample if it is not confident in its prediction. Selective classification
is essential in mission-critical applications such as autonomous driving, medical diagnostics, and
robotics, as these systems need to defer the prediction to a human when they are uncertain.
Learning different sets of discriminative features, in addition to reducing the effect of spurious
features, enables MaskTune to be applied to the problem of selective classification.
We apply MaskTune to two main tasks: (a) robustness to spurious correlations, and (b) selective
classification. Under (a) we cover four different datasets, including MNIST with synthetic spurious
features, CelebA and Waterbirds with spurious features in different subgroups [Sagawa et al., 2019],
and the Background Challenge [Xiao et al., 2020], a dataset for measuring how much methods rely
on background information for prediction. Under (b) we test MaskTune on the CIFAR-10
[Krizhevsky et al., 2009], SVHN [Netzer et al., 2011], and Cats vs. Dogs [Geifman and El-Yaniv,
2019] datasets. On both tasks, our simple technique outperforms or performs similarly to previous,
more complex methods.
To the best of our knowledge, this is the first work to present a finetuning technique using masked
data to overcome spurious correlations. Our contributions are summarized as follows:
1. We propose MaskTune, a new technique to reduce the effect of spurious correlations or over-reliance on a limited number of input features without any supervision such as object location or data subgroup labels.
2. We show that MaskTune leads to learning a model that does not rely solely on the initially discovered features.
3. We show how our method can be applied to selective classification tasks.
4. We empirically verify the robustness of the learned representations to spurious correlations on a variety of datasets.
2 Method
Setup. We consider the supervised learning setting with inputs $x \in \mathcal{X} \subseteq \mathbb{R}^d$ and corresponding labels $y \in \mathcal{Y} = \{1, \dots, k\}$. We assume having access to samples $D_0 = \{(x_i, y_i)\}_{i=1}^{n}$ drawn from an unknown underlying distribution $p_{\text{data}}(x, y)$.

Our goal is to learn the parameters $\theta \in \Theta$ of a prediction model $m_\theta : \mathcal{X} \to \mathcal{Y}$ that obtains low classification error w.r.t. some loss function (e.g., cross entropy) $\ell : \Theta \times (\mathcal{X} \times \mathcal{Y}) \to \mathbb{R}$. Specifically, we minimize:

$$\mathcal{L}(\theta) = \mathbb{E}_{x, y \sim p_{\text{data}}(x, y)}\big[\ell(m_\theta(x), y)\big] \approx \frac{1}{n} \sum_{i=1}^{n} \ell(m_\theta(x_i), y_i) \qquad (1)$$

where $n$ is the number of pairs in the training data.
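For concreteness, the following is a minimal sketch of optimizing the empirical objective in Eq. 1 with a standard PyTorch classifier; the names `model`, `train_loader`, and `optimizer`, and all hyperparameters, are illustrative placeholders rather than the paper's exact setup.

```python
import torch
import torch.nn.functional as F

def erm_epoch(model, train_loader, optimizer, device="cuda"):
    """One pass of standard ERM training: minimize the average
    cross-entropy loss over the training set (Eq. 1)."""
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        # Empirical estimate of E[l(m_theta(x), y)] over the minibatch.
        loss = F.cross_entropy(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```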
Besides having good prediction accuracy, we aim to develop a model which does not rely solely on
spurious or a limited number of input features. We propose to mask the input training data to create a
new set. We then finetune (for only one epoch) a fully trained ERM model with the new masked data
to reduce over-reliance on spurious or a limited number of features. The single-epoch finetuning is
done using a small learning rate, e.g., the last decayed learning rate that the ERM model used in the first
step. We found that large learning rates, or finetuning for more than one epoch, lead to forgetting the
discriminative features learned by ERM.
Input Masking. A key ingredient of our approach is a masking function $G$ that is applied offline (i.e., after full training). The goal here is to construct a new masked dataset by concealing the most discriminative features in the input discovered by a model after full training. This should encourage the model to investigate new features with the masked training set during finetuning. As for $G$, we adopt xGradCAM [Selvaraju et al., 2017], which was originally designed for visual explanation of deep models by creating rough localization maps based on the gradient of the model loss w.r.t. the output of a desired model layer. Given an input image of size $H \times W \times C$, xGradCAM outputs a localization map $A$ of size $H \times W \times 1$, which shows the contribution of each pixel of the input image to predicting the most probable class, i.e., it calculates the loss by choosing the class with the highest logit value (not the true label) as the target class. After acquiring the localization map, for each sample $(x_i, y_i)$, where $x_i \in \mathcal{X}$ and $y_i \in \mathcal{Y}$, we mask the locations with the most contribution as:

$$\hat{x}_i = T(A_{x_i}; \tau) \odot x_i; \qquad A_{x_i} = G(m_\theta(x_i), y_i) \qquad (2)$$

where $T$ refers to a thresholding function with threshold factor $\tau$ (i.e., $T = A_{x_i} \leq \tau$), and $\odot$ denotes element-wise multiplication. As the resolution of $A$ is typically coarser than that of the input data, $T(A_{x_i})$ is up-sampled to the size of the input.
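As a rough illustration of Eq. 2, the sketch below builds a Grad-CAM-style localization map for the top-scoring class, thresholds it at $\tau$, upsamples the resulting binary mask, and applies it element-wise to the input. It uses a plain Grad-CAM weighting for brevity rather than the xGradCAM variant used in the paper, and the choice of `feature_layer` and the threshold value are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

def masked_input(model, feature_layer, x, tau=0.5):
    """Sketch of Eq. 2: mask the most-contributing input locations found by a
    Grad-CAM-style localization map for the most probable (not true) class."""
    model.eval()

    feats = {}
    def hook(_module, _inputs, output):
        feats["a"] = output                      # activations (B, K, h, w) of feature_layer
    handle = feature_layer.register_forward_hook(hook)

    logits = model(x)                            # forward pass records activations
    handle.remove()
    top_class = logits.argmax(dim=1)             # class with highest logit, not the label
    score = logits.gather(1, top_class[:, None]).sum()

    grads = torch.autograd.grad(score, feats["a"])[0]               # d(score) / d(activations)
    weights = grads.mean(dim=(2, 3), keepdim=True)                  # per-channel importance
    cam = F.relu((weights * feats["a"]).sum(dim=1, keepdim=True))   # A: (B, 1, h, w)
    cam = cam / (cam.amax(dim=(2, 3), keepdim=True) + 1e-8)         # normalize to [0, 1]

    keep = (cam <= tau).float()                                     # T(A; tau): drop top features
    keep = F.interpolate(keep, size=x.shape[-2:], mode="nearest")   # upsample to H x W
    return x * keep                                                 # x_hat = T(A; tau) ⊙ x
```

With a torchvision ResNet, for example, `feature_layer` could be the last convolutional block (e.g., `model.layer4[-1]`); any explainability method that yields a per-location attribution map could be substituted here.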
Procedurally, we first learn a model $m_\theta^{\text{initial}}$ using the original unmasked training data $D^{\text{initial}}$. Then we use $m_\theta^{\text{initial}}$, $G$, and $T$ to create the masked set $D^{\text{masked}}$. Finally, the fully trained predictor $m_\theta^{\text{initial}}$ is tuned using $D^{\text{masked}}$ to obtain $m_\theta^{\text{final}}$.
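A high-level sketch of these three steps is given below, reusing the `erm_epoch` and `masked_input` sketches above; the optimizer, the number of ERM epochs, and the small finetuning learning rate are placeholder choices, not the exact settings reported in the paper.

```python
import torch

def masktune(model, train_loader, feature_layer,
             erm_epochs=100, erm_lr=0.1, finetune_lr=1e-4, tau=0.5, device="cuda"):
    """(1) Train m_initial with ERM on D_initial, (2) build D_masked offline with
    the fully trained model, (3) finetune for a single epoch with a small
    learning rate to obtain m_final."""
    # Step 1: standard ERM training on the original, unmasked data.
    optimizer = torch.optim.SGD(model.parameters(), lr=erm_lr, momentum=0.9)
    for _ in range(erm_epochs):
        erm_epoch(model, train_loader, optimizer, device)

    # Step 2: create the masked set D_masked using the trained model, G, and T.
    masked_set = []
    for x, y in train_loader:
        x_hat = masked_input(model, feature_layer, x.to(device), tau=tau)
        masked_set.append((x_hat.detach().cpu(), y))

    # Step 3: single-epoch finetuning on D_masked with a small learning rate
    # (e.g., the last decayed ERM learning rate), so the discriminative
    # features already learned by ERM are not forgotten.
    finetune_optimizer = torch.optim.SGD(model.parameters(), lr=finetune_lr, momentum=0.9)
    erm_epoch(model, masked_set, finetune_optimizer, device)
    return model
```

For the selective-classification variant described below, one would keep a copy of the ERM model (e.g., via `copy.deepcopy`) before the finetuning step, since both $m_\theta^{\text{initial}}$ and $m_\theta^{\text{final}}$ are needed at inference time.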
As for the masking step, any explainability approach can be applied (although some, such as ScoreCAM [Wang et al., 2020], have higher computational complexity). We use xGradCAM [Selvaraju et al., 2017] because it is fast and produces denser heat-maps than other methods [Srinivas and Fleuret, 2019, Selvaraju et al., 2017, Wang et al., 2020].
2.1 MaskTune in Over-parameterized Regimes
Consider the over-parameterized regime, in which the model family has sufficient complexity to fully fit
the training data. It has been shown that deep neural nets are over-parameterized and can fit completely
random data [Zhang et al., 2021]. The generalization ability of deep neural nets is still not well understood,
but there is some speculation connecting their generalization to a tendency to choose
simple functions that fit the training data [Valle-Perez et al., 2018, Arpit et al., 2017]. However,
this simplicity bias can cause side effects, such as poor performance on adversarial
examples [Raghunathan et al., 2019] or under distribution shifts [Khani and Liang, 2021, Shah et al.,
2020].
Here we study the effect of masking input features on the complexity of a model in a setting where
the training procedure indeed chooses the least complex model that fits the training data. We show that in
this case masking results in learning a more complex model that discovers new features, as the
previous ones are blocked.
Formally, let $\mathcal{C}$ denote a function that measures model complexity, and assume that the masking function $T$ (as described in Eq. 2) returns only binary values, i.e., it is an indicator function that keeps some features and zeros out the rest. We show that if the training procedure returns the least complex model, then masking results in a more complex model.
Proposition 1. Consider an optimizing procedure that finds $\min \mathcal{C}(m_\theta)$ s.t. $\ell(m_\theta) = 0$, as defined in Eq. 1. Let the masking function $T$ return binary values. If both models $m_\theta^{\text{final}}$ and $m_\theta^{\text{initial}}$ fit the training data (i.e., zero loss), then we have $\mathcal{C}(m_\theta^{\text{final}}) \geq \mathcal{C}(m_\theta^{\text{initial}})$.
Proof. Note that both models belong to the model family ($m_\theta^{\text{initial}}, m_\theta^{\text{final}} \in \Theta$), and they both fit the training data. In the first step, the training procedure chooses $m_\theta^{\text{initial}}$ over $m_\theta^{\text{final}}$; therefore, according to our assumption, $\mathcal{C}(m_\theta^{\text{final}}) \geq \mathcal{C}(m_\theta^{\text{initial}})$.
2.2 Adapting MaskTune for Selective Classification
Here we show how to use MaskTune for the selective classification problem. In order to make a more reliable prediction, we ensemble the original model ($m_\theta^{\text{initial}}$) and the MaskTune model ($m_\theta^{\text{final}}$) and only predict if both models agree. As a result, if there exist two sets of features that can predict the label, our method only predicts if both agree on the label (e.g., grass and cow in Figure 1).
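A minimal sketch of this agreement rule is shown below, assuming `model_initial` (the ERM model) and `model_final` (the MaskTune-finetuned model) are both available at inference time; returning -1 for rejected samples is an arbitrary convention for illustration.

```python
import torch

@torch.no_grad()
def selective_predict(model_initial, model_final, x, abstain_label=-1):
    """Predict a class only when the ERM model and the MaskTune model agree;
    otherwise abstain (reject the sample)."""
    pred_initial = model_initial(x).argmax(dim=1)
    pred_final = model_final(x).argmax(dim=1)
    agree = pred_initial == pred_final
    prediction = torch.full_like(pred_initial, abstain_label)
    prediction[agree] = pred_initial[agree]
    return prediction
```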
To get an intuition for the performance of MaskTune on selective classification, similar to [Khani et al., 2016] we analyze noiseless over-parameterized linear regression. We show that the MaskTune adaptation explained above abstains in the presence of covariate shift, thus leading to more reliable predictions.

In particular, we show that MaskTune only predicts if the relationship between masked and unmasked features in the training data holds at test time. For example, if features describing “cow” are