Just Mix Once: Worst-group Generalization by Group Interpolation
Giorgio Giannone1, Serhii Havrylov, Jordan Massiah, Emine Yilmaz, Yunlong Jiao1
DTU2, Amazon, Amazon, Amazon & UCL, Amazon
Abstract
Advances in deep learning theory have revealed how average generalization relies on superficial patterns in data. The consequence is brittle models that perform poorly under shifts in the group distribution at test time. When group annotation is available, we can use robust optimization tools to tackle the problem. However, identification and annotation are time-consuming, especially on large datasets. A recent line of work leverages self-supervision and oversampling to improve generalization on minority groups without group annotation. We propose to unify and generalize these approaches using a class-conditional variant of mixup tailored for worst-group generalization. Our approach, Just Mix Once (JM1), interpolates samples during learning, augmenting the training distribution with a continuous mixture of groups. JM1 is domain agnostic and computationally efficient, can be used with any level of group annotation, and performs on par with or better than the state of the art on worst-group generalization. Additionally, we provide a simple explanation of why JM1 works.
1 Introduction
Supervised learning aims to fit a model on a train set to maximize a global metric at test time (Vapnik, 1999; James et al., 2013; Hastie et al., 2009). However, the optimization process can exploit spurious correlations between the target $y$ and superficial patterns $c$ in the data (Geirhos et al., 2018a,b; Hermann et al., 2020; Recht et al., 2019). In this work, we are especially interested in the setting in which labels and patterns can be categorized into groups $g = (c, y)$, and we study generalization performance in the presence of shifts in group distribution at test time.

1 Correspondence to: gigi@dtu.dk, jyunlong@amazon.co.uk.
2 Work done when the first author was at Amazon Cambridge.
Preprint.
For example, in the case of coloredMNIST (Figure 1), a model can use information about color (spuriously correlated with the target class) instead of shape to classify a digit during training (almost all 6s are blue at train time). Then, in the presence of a shift in group distribution at test time (Hendrycks et al., 2021), the model relies on the color bias and its generalization performance degrades (for example, 6s are green and blue in equal proportion at test time). Powerful deep learning models exploit easy superficial correlations (Geirhos et al., 2018a, 2017), such as texture, color, and background (Hermann et al., 2020; Hendrycks and Dietterich, 2019), to improve average generalization. Solving this issue involves constraining the model (Ilyas et al., 2019; Tsipras et al., 2018; Schmidt et al., 2018), for example leveraging explicit constraints (Zemel et al., 2013; Hardt et al., 2016; Zafar et al., 2017), invariances (Arjovsky et al., 2019; Creager et al., 2021), and causal structure (Oneto and Chiappa, 2020).
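To make the group structure $g = (c, y)$ concrete, here is a small NumPy sketch (ours, not from the paper) that assigns a spurious color to each label with high correlation at train time and none at test time; the even/odd color rule and the function `make_colored_split` are illustrative assumptions rather than the paper's exact coloredMNIST setup.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_colored_split(labels, p_spurious):
    """Assign a color c to each digit label y with correlation p_spurious,
    yielding a group id g = (c, y) per sample. Toy stand-in for coloredMNIST;
    the even/odd color rule is an illustrative assumption."""
    default_color = labels % 2                   # 0 = "blue", 1 = "green"
    flip = rng.random(len(labels)) > p_spurious  # occasionally break the correlation
    colors = np.where(flip, 1 - default_color, default_color)
    return list(zip(colors, labels))             # groups g = (c, y)

train_groups = make_colored_split(rng.integers(0, 10, 10_000), p_spurious=0.95)
test_groups = make_colored_split(rng.integers(0, 10, 2_000), p_spurious=0.50)

# At train time color alone predicts the class well, so minority groups (color
# disagreeing with the class's usual color) are rare; at test time the correlation
# is removed and a color-reliant model degrades on exactly those groups.
```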
Similarly, in the presence of minority groups in a dataset, the model will tend to ignore such groups, relying on frequent patterns in majority groups. Recent work has shown how to tackle unbalanced subpopulations or groups in data using explicit supervision (Sagawa et al., 2019; Yao et al., 2022), self-supervision (Nam et al., 2020), and oversampling (Liu et al., 2021), improving generalization on minority groups. However, these methods have limitations in terms of applicability: they can be applied only in specific settings, such as full group annotation or a complete absence of group annotation, but fail to handle hybrid scenarios such as partial labeling, group clustering, and changes in the training scheme.
In this work, we focus on improving worst-group generalization in realistic scenarios with any level of group annotation, devising a group-annotation-agnostic method. Our approach, Just Mix Once (JM1), builds on supervised (Sagawa et al., 2019) and self-supervised approaches (Liu et al., 2021) to improve worst-group performance, generalizing and unifying such methods. JM1 leverages an augmentation scheme based on a class-conditional variant of mixup (Zhang et al., 2017; Verma et al., 2019; Carratino et al., 2020) to interpolate groups and improve worst-group generalization. JM1 is scalable, does not require undersampling or oversampling, and is flexible, easily adaptable to a variety of problem configurations (full, partial, or absent group annotation) with minimal modifications.
Contribution. Our contributions are the following:

• We propose a simple, general mechanism, JM1, based on a class-conditional variant of mixup, to improve generalization on minority groups in the data. JM1 is group-annotation agnostic, i.e., it can be employed with different levels of annotation, generalizing and unifying previously proposed methods for improving group generalization.

• We propose a novel interpretation of why interpolating majority and minority groups is effective in improving worst-group generalization, and justify it theoretically in an ideal case and empirically in realistic scenarios.

• We perform extensive experiments with different levels of group annotation and ablations of the mixup and class-conditional strategies. JM1 is successfully employed with I) fine-grained annotation, II) coarse annotation, and III) only a small annotated validation set, demonstrating that our method outperforms or is on par with the state of the art on vision and language datasets for worst-group generalization.
Figure 1: Group Identification Phase. Left: train set. Right: partitioned train set. In the first phase, the training dynamics are exploited to split the data into two partitions. The assumption is that the samples identified as difficult (samples in the orange circle) typically include samples from minority groups. The misclassification rate in the early stage of training is used as a clustering signal to estimate the group distribution: patterns that are superficially frequent in the data (e.g., color, texture, background) are easy to classify and have a small loss, forming a majority group; infrequent patterns (e.g., shape) are challenging to model and are misclassified during the early stage of training, forming a minority group.
2 Background
Our work tackles the problem of improving worst-group
generalization with different levels of group annotation avail-
able during training. Here we present two main classes of
methods to deal with groups in the data with and without
group annotation on the train set.
Figure 2: Class-conditional Mixing Phase. After the identification phase, we exploit label information to augment the training data, using a MixUp-inspired strategy to augment samples from different partitions. We use the label information to select two samples from different partitions, i.e., $(h^i_g, y^i_g)$ and $(h^j_{\bar g}, y^j_{\bar g})$ where $y^i_g = y^j_{\bar g} = y$ and $g \neq \bar g$. Then we create a mixed-up sample $(h_{mix}, y_{mix})$ s.t. $h_{mix} = \alpha\, h^i_g + (1 - \alpha)\, h^j_{\bar g}$, $y_{mix} = y$. Note that $h$ can either be an input (Zhang et al., 2017) or a learned representation (Verma et al., 2019). This simple mechanism gives us a principled and domain-agnostic way to augment samples, marginalizing the group information in the data: by sampling $\alpha$, we augment the training data and build a continuous distribution of groups in the data. Therefore, the model cannot rely on frequent patterns, because each sample has a "slightly different" group pattern.

Assuming that the training data $\mathcal{D} = \{x_i, y_i\}_{i=1}^{n}$ are partitioned into (non-overlapping) groups $\mathcal{G}_1, \ldots, \mathcal{G}_m$ through sub-populations represented by tuples $(y, c)$ of label $y$ and confounding factors $c$, we can define the following per-group and group-average loss:

$$J(\theta; \mathcal{G}) := \frac{1}{|\mathcal{G}|} \sum_{(x_g, y_g) \in \mathcal{G}} l(x_g, y_g; \theta), \qquad J(\theta) := \frac{1}{m} \sum_{k=1}^{m} J(\theta; \mathcal{G}_k). \tag{1}$$
Notice that in the case $m = n$, each group contains a single data point, and the group-average loss collapses to the standard ERM loss $J_{ERM}$. Note that if we define a uniform per-group loss weight $p_k = 1/m$, $k = 1, \ldots, m$, we can rewrite the group-average loss as:

$$J(\theta) = \sum_{k} p_k\, J(\theta; \mathcal{G}_k) = \mathbb{E}_{p(\mathcal{G})}\, J(\theta; \mathcal{G}). \tag{2}$$
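To make Eqs. 1-2 concrete, below is a minimal PyTorch-style sketch (not taken from the paper) computing the per-group losses, their uniform average, and the worst-group maximum optimized by the GroupDRO objective introduced next; the function `group_losses` and the explicit `group_ids` input are our illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def group_losses(logits, labels, group_ids, num_groups):
    """Per-group losses J(theta; G_k) as in Eq. 1 (illustrative sketch).

    logits:    (N, C) model outputs
    labels:    (N,)   class labels y
    group_ids: (N,)   integer group index for g = (c, y), assumed given
    """
    per_sample = F.cross_entropy(logits, labels, reduction="none")  # l(x, y; theta)
    losses = []
    for k in range(num_groups):
        mask = group_ids == k
        # empty groups contribute zero here; a real implementation would skip them
        losses.append(per_sample[mask].mean() if mask.any() else torch.tensor(0.0))
    return torch.stack(losses)

# Toy usage with random tensors and m = 4 groups.
logits = torch.randn(32, 10)
labels = torch.randint(0, 10, (32,))
group_ids = torch.randint(0, 4, (32,))

per_group = group_losses(logits, labels, group_ids, num_groups=4)
group_average_loss = per_group.mean()  # Eq. 2 with uniform p_k = 1/m
worst_group_loss = per_group.max()     # the max over groups used by GroupDRO (Eq. 3)
```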
Distributionally Robust Optimization (DRO) (Sagawa et al., 2019) aims to optimize

$$J_{DRO} = \max_{k} J(\theta; \mathcal{G}_k), \tag{3}$$

which corresponds in Eq. 2 to a pointwise Dirac distribution with $p_k = 1$ if $k = \arg\max_{k'} J(\theta; \mathcal{G}_{k'})$ and $p_k = 0$ otherwise. This approach differs from the standard training routine, where the goal is to optimize the average error among groups. If group information on the train set is absent, methods such as Just Train Twice (JTT) (Liu et al., 2021; Nam et al., 2020) can be used. JTT is a powerful approach that tackles worst-group generalization in a simple two-stage procedure: in the first phase, the goal is to partition the data into two clusters, one with majority groups (frequent superficial patterns) and one with minority groups (uncommon patterns).
The assumption is that samples from minority patterns are difficult to model and frequently misclassified in the early stage of training. JTT uses the misclassified samples in the early stage of training as a signal to partition the data. Specifically, a supervised learner, parameterized by $\theta$, is trained on the data $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^{n}$ using the ERM loss (up to a constant) $J_{ERM}(\theta) := \sum_{(x, y) \in \mathcal{D}} l(x, y; \theta)$ with early stopping. To avoid overfitting, a small validation set with group annotation is used to select the best identification epoch $T$ to create the appropriate partitions. Then the misclassified samples (Nam et al., 2020) are saved in a buffer $\mathcal{B} := \{(x_b, y_b) \text{ s.t. } \hat{f}_T(x_b) \neq y_b\}$.
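As an illustration of the identification phase, here is a minimal PyTorch sketch of building the error buffer $\mathcal{B}$ from an early-stopped model; it assumes a standard classifier `model` and a non-shuffled `loader`, and is our sketch of the idea rather than the authors' implementation.

```python
import torch

@torch.no_grad()
def build_error_set(model, loader, device="cpu"):
    """Indices of samples misclassified by an early-stopped model f_T.

    Sketch of the JTT-style identification phase: the returned indices form
    the buffer B of presumed minority-group samples. `loader` is assumed to
    iterate in a fixed (non-shuffled) order so positional indices are meaningful.
    """
    model.eval()
    error_indices, offset = [], 0
    for x, y in loader:
        preds = model(x.to(device)).argmax(dim=-1).cpu()
        wrong = (preds != y).nonzero(as_tuple=True)[0] + offset  # f_T(x) != y
        error_indices.extend(wrong.tolist())
        offset += y.shape[0]
    return error_indices
```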
Once the partitions are identified, the same learner is trained a second time on a re-weighted version of the data, where samples in $\mathcal{B}$ are oversampled $\lambda$ times ($\lambda \geq 1$). The intuition is that, if the examples in the error set $\mathcal{B}$ come from challenging groups, such as those where the spurious correlation does not hold, then up-weighting them will lead to better worst-group performance (Liu et al., 2021). Specifically, the loss for phase II) can be written as (up to a constant):

$$J_{JTT}(\theta) := \lambda \sum_{(x_b, y_b) \in \mathcal{B}} l(x_b, y_b; \theta) + \sum_{(x_{\bar b}, y_{\bar b}) \notin \mathcal{B}} l(x_{\bar b}, y_{\bar b}; \theta),$$
where $\lambda$ is the up-sampling rate for the samples in $\mathcal{B}$ identified in phase I). A similar approach was proposed in (Nam et al., 2020), with the difference that the groups are re-weighted after each epoch using the loss magnitude from phase one. This approach is more general, but in practice more difficult to train and scale.
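For concreteness, a short sketch of the phase II) re-weighted loss, where up-weighting by $\lambda$ plays the role of oversampling the buffer $\mathcal{B}$; the function name and the default value of `lam` are illustrative assumptions, not values from the paper.

```python
import torch
import torch.nn.functional as F

def jtt_phase2_loss(logits, labels, in_error_set, lam=6.0):
    """Up-weighted ERM loss for JTT's second phase (illustrative sketch).

    in_error_set: (N,) boolean tensor marking samples that belong to B.
    lam:          up-weighting rate lambda; the default value is a guess.
    """
    per_sample = F.cross_entropy(logits, labels, reduction="none")
    weights = torch.where(in_error_set,
                          torch.full_like(per_sample, lam),
                          torch.ones_like(per_sample))
    return (weights * per_sample).sum()
```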
Limitations. GroupDRO and JTT are effective methods but can be used only in specific scenarios: full group annotation on the train set for the former, and group annotation on the validation set for the latter. These methods are challenging to adapt to hybrid scenarios where group annotation may be fine-grained for some classes (group annotation on each sample) and coarse for others (partition-based or cluster-based annotation). Such "sparse" and incomplete group annotation is the most common in practice for large, diverse datasets. To deal with these issues, we propose a simple method that improves worst-group accuracy with different levels of annotation granularity available.
3 Just Mix Once
We address the limitations of GroupDRO and JTT by proposing Just Mix Once (JM1). Our goal is to improve the classification accuracy of minority groups, with or without explicit group annotation at training time. In the following, we use mixing and interpolation interchangeably. JM1 consists of two phases: phase I) discovers and clusters minority groups in the data, and phase II) uses class-conditional mixing to improve worst-group performance. JM1 can be used with or without group annotation on the train set.
I) (Optional) Group Identification. In the absence of group annotation on the train set, similarly to phase I) of JTT, we assume that groups are annotated on a small validation set and resort to a self-supervised signal (Nam et al., 2020; Arazo et al., 2019) based on the misclassification rate (and the loss magnitude) in the early stage of training. Fig. 1 illustrates how relevant samples from minority groups are identified. Note that self-supervised group identification is used only in the absence of group annotation on the train set: if group annotations are available on the train set, we always use these oracle groups to identify the majority and minority groups.
II) Class-conditional Interpolation. To generalize on minority groups at test time, a model should not rely on spurious correlations during training (e.g., texture or color) but on the signal of interest (e.g., shape). Unlike the oversampling approach taken by JTT, we resort to a better augmentation (Fig. 2) to improve generalization for low-density groups. Our mixing strategy is inspired by MixUp (Verma et al., 2019; Zhang et al., 2017; Carratino et al., 2020), and we propose a novel class-conditional variant that achieves robust worst-group generalization. Specifically, for any two samples $(h^i_g, h^j_{\bar g})$ in the input (or representation) space from the two partitions $g, \bar g$ (i.e., the difficult/misclassified/minority group $g$ vs. the other, majority group $\bar g$) with the same class label $y^i_g = y^j_{\bar g} = y$, we mix them up using a convex interpolation:

$$h_{mix} = \alpha\, h^i_g + (1 - \alpha)\, h^j_{\bar g}, \qquad y_{mix} = y, \tag{4}$$

where $\alpha$ is the mixing parameter (details on how to choose $\alpha$ are deferred to Sec. 4). Notably, we empirically show that naively applying standard MixUp without the proposed two-stage approach fails to generalize in terms of worst-group performance (Fig. 8).
Finally, JM1 uses a simple ERM loss over the mixed data:

$$J^{ERM}_{JM1} = \sum_{h^i_g,\; g \in \mathcal{G}} \;\sum_{h^j_{\bar g},\; \bar g \in \bar{\mathcal{G}}} l(h_{mix}, y).$$
3.1 How JM1 Works
In this subsection, we explain why JM1, which augments the group distribution by interpolating samples, works. First, we discuss how our mixing strategy enables JM1 to build a "continuous" spectrum of groups in the data by sampling $\alpha$. Using MixUp with mixing rate $\alpha$, we generate new samples drawn from a continuous mixture (Eq. 4), which has the training distribution as a limiting case. In this view, we can interpret each mixed-up sample as drawn from a "mixed-up group" $\mathcal{G}_\alpha$ parametrized by $\alpha$. This means that, assuming an oracle partitioning in phase I), JM1 mixes a majority and a minority group in the data at each iteration while generating a continuous group distribution. We denote by $J(\theta; \mathcal{G}_\alpha)$ our per-group loss for an interpolated group $\mathcal{G}_\alpha$ with mixing rate $\alpha$, and write the group-average loss