
Figure 1: MaskTune generates a new set of masked samples by obstructing the features discovered by
a model fully trained via empirical risk minimization (ERM). The ERM model is then fine-tuned for
only one epoch using the masked version of the original training data to force new feature exploration.
The features highlighted in yellow, red, and green correspond to features discovered by ERM, the
masked features, and the newly discovered features by MaskTune, respectively.
In previous work, a supervised loss function has been employed to reduce the effect of spurious
correlations [Sagawa et al., 2019]. However, identifying and annotating the spurious correlations
in a large dataset as a training signal is impractical. Other works have attempted to force models
to discard context and background (as spurious features) through input morphing using perceptual
similarities [Taghanaki et al., 2021] or by learning causal variables [Javed et al., 2020]. Discarding
context and background, however, is incompatible with the human visual system that relies on
contextual information when detecting and recognizing objects [Palmer, 1975, Biederman et al., 1982,
Chun and Jiang, 1998, Henderson and Hollingworth, 1999, Torralba, 2003]. In addition, spurious
features may appear on the object itself (e.g., facial attributes). Thus, discarding the context and
background may be a futile strategy in these cases.
Instead of requiring contextual and background information to be discarded or relying on a limited
number of features, we propose a single-epoch finetuning technique called MaskTune that prevents a
model from learning only the “first” simplest mapping (potentially spurious correlations) from the
input to the corresponding target variable. MaskTune forces the model to explore other input variables
by concealing (masking) the ones that have already been deemed discriminative. As we finetune
the model with a new set of masked samples, we force the training to escape its myopic and greedy
feature-seeking approach and encourage exploring and leveraging more input variables. In other
words, as the previous clues are hidden, the model is constrained to find alternative loss-minimizing
input-target mappings. MaskTune conceals the first clues discovered by a fully trained model, whether
they are spurious or not. This forces the model to investigate and leverage new, complementary
discriminative input features. A model relying on a broader array of complementary features
(some may be spurious while others are not) is expected to be more robust to test data missing a
subset of these features.
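The two-stage procedure described above can be sketched as follows. This is a minimal illustration, not the paper's implementation: the saliency function, the fixed masking fraction, and the top-k masking rule are simplifying assumptions (MaskTune itself derives masks from the ERM model's class activation maps on image data).

```python
# Sketch of MaskTune: mask the features a fully trained ERM model relies
# on, then finetune for a single epoch on the masked data.
# `saliency_fn` and `finetune_one_epoch` are assumed, user-supplied
# callables; the top-k masking rule below is an illustrative stand-in
# for the paper's activation-map-based masking.

def mask_top_features(x, saliency, fraction=0.25):
    """Zero out the `fraction` of input features with the highest saliency."""
    k = max(1, int(len(x) * fraction))
    top = sorted(range(len(x)), key=lambda i: saliency[i], reverse=True)[:k]
    masked = list(x)
    for i in top:
        masked[i] = 0.0  # conceal the feature the ERM model relied on
    return masked

def masktune(train_set, model, saliency_fn, finetune_one_epoch):
    # 1) Build a masked copy of every training sample, using the fully
    #    trained ERM model's own saliency to decide what to hide.
    masked_set = [(mask_top_features(x, saliency_fn(model, x)), y)
                  for x, y in train_set]
    # 2) Finetune the ERM model for a single epoch on the masked data,
    #    forcing it to discover complementary discriminative features.
    finetune_one_epoch(model, masked_set)
    return model
```

Masking is applied once, using the converged ERM model's saliency, rather than iteratively; per the method's description, a single finetuning epoch on the masked samples is sufficient to induce the new feature exploration.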
Figure 1 visualizes how MaskTune works via a schematic of the cow-on-grass scenario. Even in the
absence of spurious correlations, models tend to focus on the shortcut (e.g., ears or skin texture of a
cow), which can prevent models from generalizing to scenarios where those specific parts are missing.
However, the object is still recognizable from the remaining parts. As an alternative, MaskTune
generates a diverse set of partially masked training examples, forcing the model to investigate a wider
area of the input feature landscape, e.g., new pixels.
A further disadvantage of relying on a limited number of features is the model’s inability to know
when it does not know. Returning to the cow-camel classification example: if cows only appear
on grass in the training set, then it is unclear whether the grass or the cow corresponds to the "cow" label.
A model that relies only on the grass feature can confidently make a wrong prediction when some
other object appears in the grass at test time. We need the model to predict only if both cow
and grass appear in the picture and abstain otherwise. One method used in the literature to address
this issue is selective classification [Geifman and El-Yaniv, 2019, 2017, Khani et al., 2016], which
allows a network to reject a sample if it is not confident in its prediction. Selective classification
is essential in mission-critical applications such as autonomous driving, medical diagnostics, and