
The assumption is that samples from minority patterns are
difficult to model and frequently misclassified in the early-
stage of training. JTT uses the misclassified samples in
the early stage of training as a signal to partition the data.
Specifically, a supervised learner, parameterized by $\theta$, is trained on the data $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^{n}$ using the ERM loss (up to a constant) $\mathcal{J}_{\mathrm{ERM}}(\theta) := \sum_{(x,y)\in\mathcal{D}} \ell(x, y; \theta)$ with early stopping. To avoid overfitting, a small validation set with group annotation is used to select the best identification epoch $T$ to create the appropriate partitions. Then misclassified samples (Nam et al., 2020) are saved in a buffer $\mathcal{B} := \{(x_b, y_b) \;\text{s.t.}\; \hat{f}_T(x_b) \neq y_b\}$.
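As an illustration, a minimal sketch of this identification step is given below, assuming a model already trained with ERM and early-stopped at epoch $T$, and a standard PyTorch data loader; this is a sketch of the idea, not the reference JTT implementation.

```python
import torch

def build_error_buffer(model, train_loader, device="cpu"):
    """Collect training samples misclassified by the early-stopped model f_T.

    Returns a list of (x, y) pairs forming the buffer B used in phase II).
    """
    model.eval()
    buffer = []
    with torch.no_grad():
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)   # predictions of the early-stopped model
            wrong = preds != y               # misclassification mask
            for xb, yb in zip(x[wrong], y[wrong]):
                buffer.append((xb.cpu(), yb.cpu()))
    return buffer
```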
Once the partitions are identified, the same learner is trained a second time on a re-weighted version of the data, where samples in $\mathcal{B}$ are oversampled $\lambda$ times. The intuition is that, if the examples in the error set $\mathcal{B}$ come from challenging groups, such as those where the spurious correlation does not hold, then up-weighting them will lead to better worst-group performance (Liu et al., 2021). Specifically, the loss for phase II) can be written as (up to a constant):
$$\mathcal{J}_{\mathrm{JTT}}(\theta) := \lambda \sum_{(x_b, y_b)\in\mathcal{B}} \ell(x_b, y_b; \theta) \;+ \sum_{(x_{\bar{b}}, y_{\bar{b}})\notin\mathcal{B}} \ell(x_{\bar{b}}, y_{\bar{b}}; \theta),$$
where $\lambda$ is the up-sampling rate for samples in $\mathcal{B}$ identified in phase I). A similar approach was proposed in (Nam et al., 2020), with the difference that the groups are re-weighted after each epoch using the loss magnitude in phase one. This approach is more general, but in practice more difficult to train and scale.
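For intuition, a minimal sketch of the phase II) objective follows, assuming per-sample boolean flags `in_buffer` marking membership in $\mathcal{B}$ are available from phase I); the default up-sampling rate `lam` is a placeholder value, not one prescribed by the text.

```python
import torch
import torch.nn.functional as F

def jtt_loss(logits, targets, in_buffer, lam=6.0):
    """Re-weighted ERM loss: samples in the error buffer B count lambda times."""
    per_sample = F.cross_entropy(logits, targets, reduction="none")
    weights = torch.where(in_buffer,
                          lam * torch.ones_like(per_sample),
                          torch.ones_like(per_sample))
    return (weights * per_sample).sum()
```

Weighting the per-sample loss by $\lambda$ is equivalent, in expectation, to replicating the buffer samples $\lambda$ times in the training set.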
Limitations.
GroupDRO and JTT are effective methods but
can be used only in specific scenarios: full-group annotation
on the train set for the former, and group annotation on the
validation set for the latter. These methods are challenging
to adapt for hybrid scenarios where group annotation can
be fine-grained for some classes (group annotation on each
sample) and coarse for others (partition-based, cluster-based
annotation). Such "sparse" and incomplete group annotation
is the most common in practice for large, diverse datasets.
To deal with such issues, we propose a simple method to
improve worst-group accuracy with different levels of anno-
tation granularity available.
3 Just Mix Once
We address the limitations of GroupDRO and JTT by propos-
ing Just Mix Once (JM1). Our goal is to improve the classifi-
cation accuracy of minority groups, with or without explicit
annotation on the groups at training time. In the follow-
ing we will use mixing and interpolation interchangeably.
JM1 consists of two phases: phase I) discovers and clus-
ters minority groups in the data, and phase II) uses class-
conditional mixing to improve worst-group performance.
JM1 can be used with and without group annotation on the
train set.
I) (Optional) Group Identification.
In the absence of group annotation on the train set, similarly to phase I) of JTT, we assume that groups are annotated on a small validation set and resort to a self-supervised signal (Nam et al., 2020; Arazo et al., 2019) based on the misclassification rate (and the loss magnitude) in the early stage of training. Fig. 1 illustrates how relevant samples from minority groups are identified. Note that self-supervised group identification is used only in the absence of group annotation on the train set: if group annotations are available on the train set, we always use such oracle groups to identify the majority and minority groups.
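For concreteness, one possible way to realize this switch between oracle and self-supervised partitions is sketched below; the function and argument names are illustrative and not taken from the paper's code.

```python
import torch

def partition_groups(group_labels=None, misclassified_mask=None, minority_ids=None):
    """Return boolean masks (minority partition g, majority partition g_bar).

    With oracle annotation, the minority partition is defined by `minority_ids`;
    otherwise we fall back on the early-training misclassification mask.
    """
    if group_labels is not None:
        minority = torch.isin(group_labels, minority_ids)
    else:
        minority = misclassified_mask
    return minority, ~minority
```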
II) Class-conditional Interpolation.
To generalize on
minority groups at test time, a model should not rely on
spurious correlations during training (e.g., texture or color)
but on the signal of interest (e.g., shape). Unlike the over-
sampling approach taken by JTT, we resort to a better aug-
mentation (Fig. 2) to improve generalization for low-density
groups. Our mixing strategy is inspired by MixUp (Verma
et al.,2019;Zhang et al.,2017;Carratino et al.,2020), and
we propose a novel class-conditional variant that achieves
robust worst-group generalization. Specifically, for any two samples $(h^i_g, h^j_{\bar{g}})$ in the input (or representation) space from the two partitions $g, \bar{g}$ (i.e., the difficult/misclassified/minority group $g$ vs. the other, majority group $\bar{g}$) with the same class label $y^i_g = y^j_{\bar{g}} = y$, we mix them up using a convex interpolation:
$$h_{\mathrm{mix}} = \alpha\, h^i_g + (1 - \alpha)\, h^j_{\bar{g}}, \qquad y_{\mathrm{mix}} = y, \qquad (4)$$
where $\alpha$ is the mixing parameter (details on how to choose $\alpha$ are deferred to Sec. 4). Notably, we empirically show that naively applying the standard MixUp without the proposed two-stage approach fails to generalize in terms of worst-group performance (Fig. 8).
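A minimal sketch of this interpolation (Eq. 4) is shown below, assuming two same-class batches have already been drawn from the minority and majority partitions; the Beta prior on $\alpha$ is a common MixUp choice used here for illustration, not a detail fixed by the text.

```python
import torch

def mix_same_class(h_minority, h_majority, y, concentration=2.0):
    """Class-conditional convex interpolation (Eq. 4): the label is left unchanged."""
    alpha = torch.distributions.Beta(concentration, concentration).sample()
    h_mix = alpha * h_minority + (1.0 - alpha) * h_majority
    return h_mix, y   # y_mix = y
```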
Finally, JM1 uses a simple ERM loss over mixed data:
$$\mathcal{J}^{\mathrm{ERM}}_{\mathrm{JM1}} = \sum_{h^i_g,\, g\in G} \;\sum_{h^j_{\bar{g}},\, \bar{g}\in\bar{G}} \ell(h_{\mathrm{mix}}, y).$$
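Putting the two pieces together, one JM1 training step could look like the sketch below, reusing the `mix_same_class` helper from the previous snippet; the pairing of same-class batches from $G$ and $\bar{G}$ is assumed to be handled by the data loader.

```python
import torch.nn.functional as F

def jm1_step(model, h_minority, h_majority, y, optimizer):
    """One JM1 update: plain ERM loss on class-conditionally mixed samples."""
    h_mix, y_mix = mix_same_class(h_minority, h_majority, y)  # Eq. 4
    loss = F.cross_entropy(model(h_mix), y_mix)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```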
3.1 How JM1 Works
In this subsection we explain why JM1, based on augment-
ing the group distribution by interpolating samples, works.
First, we discuss how our mixing strategy enables JM1 to build a "continuous" spectrum of groups in the data by sampling $\alpha$. Using MixUp with mixing rate $\alpha$, we generate new samples drawn from a continuous mixture (Eq. 4), which has the training distribution as a limiting case. In this view, we can interpret each mixed-up sample as drawn from a "mixed-up group" $G_\alpha$ parameterized by $\alpha$. This means that, assuming an oracle partitioning in phase I), JM1 mixes a majority and a minority group in the data at each iteration while generating a continuous group distribution. We denote by $\mathcal{J}(\theta; G_\alpha)$ our per-group loss for an interpolated group $G_\alpha$ with mixing rate $\alpha$, and write the group-average loss