
Figure 1: An illustration of learning group annotations in the presence of outliers. (a) A toy dataset in two dimensions. There are four groups g = 1, 2, 3, 4 and an outlier. Groups g = 1 and g = 3 are the majority groups, distributed as mixtures of three components each; g = 2 and g = 4 are unimodal minority groups. The y-axis is the decision boundary of a logistic regression classifier. Figures (b, c, d) compare different data views for learning group annotations and detecting outliers via clustering of samples with y = 0: (b) loss values can confuse outliers and minority samples, which both can have high loss; (c) in the original feature space it is difficult to distinguish one of the majority group modes from the minority group; (d) the gradient space (bias gradient omitted for visualization) simplifies the data structure, making it easier to identify the minority group and to detect outliers.
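To make the gradient-space view in panel (d) concrete, the following minimal sketch builds an analogous toy dataset, fits a logistic regression model, and computes the per-sample loss gradients that panels (b)-(d) compare against loss values and raw features. The data construction and all numeric values are illustrative assumptions, not the exact data of Figure 1.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Illustrative 2-D toy data (an assumption, not the data of Figure 1): for each
# class, a majority group that is a mixture of three components and a unimodal
# minority group, plus one mislabeled outlier.
maj1 = np.concatenate([rng.normal([2.0, c], 0.3, size=(50, 2)) for c in (-2, 0, 2)])   # g=1, y=1
min2 = rng.normal([2.0, 4.0], 0.3, size=(15, 2))                                        # g=2, y=1
maj3 = np.concatenate([rng.normal([-2.0, c], 0.3, size=(50, 2)) for c in (-2, 0, 2)])  # g=3, y=0
min4 = rng.normal([-2.0, 4.0], 0.3, size=(15, 2))                                       # g=4, y=0
outlier = np.array([[2.0, 0.0]])  # labeled y=0 but lies on the y=1 side

X = np.concatenate([maj1, min2, maj3, min4, outlier])
y = np.concatenate([np.ones(len(maj1) + len(min2)),
                    np.zeros(len(maj3) + len(min4) + 1)])

clf = LogisticRegression().fit(X, y)   # decision boundary is roughly the vertical axis

# Per-sample gradient of the logistic loss w.r.t. the weights w:
#   grad_i = (sigma(w^T x_i + b) - y_i) * x_i   (bias gradient omitted, as in panel (d))
p = clf.predict_proba(X)[:, 1]
grads = (p - y)[:, None] * X

# Panels (b)-(d) then compare, for the y=0 samples: per-sample loss values,
# the raw features X, and these per-sample gradients.
losses = -(y * np.log(p) + (1 - y) * np.log(1 - p))
```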
identities with toxicity (Dixon et al., 2018; Garg et al., 2019; Yurochkin & Sun, 2020). A related phenomenon is subpopulation shift (Koh et al., 2021), i.e., when the test distribution differs from the train distribution in terms of group proportions. Under subpopulation shift, poor performance on the minority groups in the training data translates into poor overall performance on the test distribution, where these groups are more prevalent or more heavily weighted. Subpopulation shift occurs in many application domains (Tatman, 2017; Beery et al., 2018; Oakden-Rayner et al., 2020; Santurkar et al., 2020; Koh et al., 2021).
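To see why this matters, consider a simple illustrative calculation; the per-group accuracies and group proportions below are assumed numbers, not results from this paper. A model that looks accurate on average in training can still perform poorly on a test distribution where the minority group is more prevalent.

```python
# Illustrative (assumed) per-group accuracies and group proportions.
acc = {"majority": 0.95, "minority": 0.60}
train_mix = {"majority": 0.95, "minority": 0.05}   # train distribution
test_mix = {"majority": 0.50, "minority": 0.50}    # shifted test distribution

train_acc = sum(acc[g] * train_mix[g] for g in acc)  # ~0.93: looks fine on train
test_acc = sum(acc[g] * test_mix[g] for g in acc)    # ~0.78: drops under the shift
print(train_acc, test_acc)
```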
Prior work offers a variety of methods for training models robust to subpopulation shift and spurious correlations, including group distributionally robust optimization (gDRO) (Hu et al., 2018; Sagawa et al., 2019), importance weighting (Shimodaira, 2000; Byrd & Lipton, 2019), subsampling (Sagawa et al., 2020; Idrissi et al., 2022; Maity et al., 2022), and variations of tilted ERM (Li et al., 2020, 2021). These methods succeed in achieving comparable performance across groups in the data, but they require group annotations. Such annotations can be expensive to obtain, e.g., labeling spurious backgrounds in image recognition (Beery et al., 2018) or labeling identity mentions in the toxicity example. It can also be challenging to anticipate all potential spurious correlations in advance: the spurious attribute could be the background, the time of day, the camera angle, or unanticipated identities subject to harassment.
Recently, methods have emerged for learning group annotations (Sohoni et al., 2020; Liu et al., 2021; Creager et al., 2021) and variations of DRO that do not require groups (Hashimoto et al., 2018; Zhai et al., 2021). One common theme is to treat the data points on which an ERM model makes mistakes (i.e., high-loss points) as a minority group (Hashimoto et al., 2018; Liu et al., 2021) and to increase the weighting of these points. Unfortunately, such methods are at risk of overfitting to outliers (e.g., mislabeled data or corrupted images), which are also high-loss points. Indeed, existing methods for outlier-robust training propose to ignore the high-loss points (Shen & Sanghavi, 2019), the direct opposite of the approach in (Hashimoto et al., 2018; Liu et al., 2021).
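This tension can be illustrated with a small sketch; it is a simplified caricature under assumed loss values and weighting factors, not the exact procedures of the cited works. Error-based minority heuristics upweight the highest-loss points, while trimmed-loss training zeroes out the weights of exactly those points.

```python
import numpy as np

def split_by_loss(losses, frac=0.1):
    """Indices of the top-`frac` highest-loss points and the rest.

    `frac` is a hypothetical hyperparameter used only for this illustration.
    """
    k = max(1, int(frac * len(losses)))
    order = np.argsort(losses)
    return order[-k:], order[:-k]   # (highest-loss points, remaining points)

rng = np.random.default_rng(0)
losses = rng.exponential(size=200)  # stand-in for per-sample losses of an ERM model
high, low = split_by_loss(losses)

# Error-based minority heuristics UPWEIGHT the high-loss points...
weights_minority_heuristic = np.ones_like(losses)
weights_minority_heuristic[high] *= 10.0    # upweighting factor is illustrative

# ...while outlier-robust (trimmed-loss) training IGNORES exactly those points.
weights_trimming = np.ones_like(losses)
weights_trimming[high] = 0.0
```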
In this paper, our goal is to learn group annotations in the presence of outliers. Rather than using loss values (which, as discussed above, create opposing tradeoffs), we propose to first represent the data by the gradients of each datum's loss w.r.t. the model parameters. Such gradients tell us how a specific data point wants the parameters of the model to change to fit it better. In this gradient space, we anticipate groups (conditioned on the label) to correspond to clusters of gradients. Outliers, on the other hand, mostly correspond to isolated gradients: they are likely to want the model parameters to change differently from any of the groups and from other outliers. See Figure 1 for an illustration. The gradient space structure allows us to separate out the outliers and learn the group annotations via traditional clustering techniques such as DBSCAN (Ester et al., 1996). We use the learned group annotations to train models with improved worst-group performance (measured w.r.t. the true group annotations).
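As a concrete illustration of this idea, the sketch below clusters per-sample loss gradients of a logistic regression model with DBSCAN, separately within each class. The choice of model, the DBSCAN hyperparameters, and the bookkeeping details are illustrative assumptions rather than the exact procedure developed in the remainder of the paper.

```python
import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.linear_model import LogisticRegression

def infer_groups(X, y, eps=0.5, min_samples=10):
    """Cluster per-sample loss gradients (within each class) to infer groups.

    Returns pseudo-group labels; -1 marks points DBSCAN leaves unclustered,
    which we treat as outliers. `eps` and `min_samples` are placeholder
    hyperparameters for this sketch.
    """
    clf = LogisticRegression().fit(X, y)              # ERM reference model
    p = clf.predict_proba(X)[:, 1]
    grads = (p - y)[:, None] * X                      # per-sample weight gradients

    groups = np.full(len(y), -1)
    next_id = 0
    for label in np.unique(y):                        # cluster conditioned on the label
        idx = np.where(y == label)[0]
        cluster = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(grads[idx])
        cluster[cluster >= 0] += next_id              # make cluster ids globally unique
        groups[idx] = cluster
        next_id = groups.max() + 1
    return groups

# The pseudo-group labels can then be passed to a group-robust trainer such as
# gDRO, after dropping the points marked -1 (the suspected outliers).
```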
We summarize our contributions below: