Adaptively Weighted Data Augmentation Consistency
Regularization for Robust Optimization under Concept Shift
Yijun Dong*1, Yuege Xie†2, and Rachel Ward1
1University of Texas at Austin
2Snap Inc.
February 1, 2023
Abstract
Concept shift is a prevailing problem in natural tasks like medical image seg-
mentation where samples usually come from different subpopulations with varying
correlations between features and labels. One common type of concept shift in med-
ical image segmentation is the “information imbalance” between label-sparse sam-
ples with few (if any) segmentation labels and label-dense samples with plentiful la-
beled pixels. Existing distributionally robust algorithms have focused on adaptively
truncating/down-weighting the “less informative” (i.e., label-sparse in our context)
samples. To exploit data features of label-sparse samples more efficiently, we pro-
pose an adaptively weighted online optimization algorithm — AdaWAC — to in-
corporate data augmentation consistency regularization in sample reweighting. Our
method introduces a set of trainable weights to balance the supervised loss and unsu-
pervised consistency regularization of each sample separately. At the saddle point of
the underlying objective, the weights assign label-dense samples to the supervised
loss and label-sparse samples to the unsupervised consistency regularization. We
provide a convergence guarantee by recasting the optimization as online mirror de-
scent on a saddle point problem. Our empirical results demonstrate that AdaWAC
not only enhances the segmentation performance and sample efficiency but also im-
proves the robustness to concept shift on various medical image segmentation tasks
with different UNet-style backbones.
1 Introduction
Modern machine learning is revolutionizing the field of medical imaging, especially
in computer-aided diagnosis with computed tomography (CT) and magnetic resonance
imaging (MRI) scans. However, classical learning objectives like empirical risk mini-
mization (ERM) generally assume that training samples are independently and identically
*Equal contribution. Correspondence to: ydong@utexas.edu
†Work done at University of Texas at Austin.
(i.i.d.) distributed, whereas real-world medical image data rarely satisfy this assumption.
Figure 1.1 instantiates a common observation in medical image segmentation where the
segmentation labels corresponding to different cross-sections of the human body tend to
have distinct proportions of labeled (i.e., non-background) pixels, which is accurately re-
flected by the evaluation of supervised cross-entropy loss during training. We refer to this
as the “information imbalance” among samples, as opposed to the well-studied “class
imbalance” [Taghanaki et al., 2019b, Wong et al., 2018, Yeung et al., 2022] among the
numbers of segmentation labels in different classes. Such information imbalance induces
distinct difficulty/paces of learning with the cross-entropy loss for different samples [Ha-
cohen and Weinshall, 2019, Tang et al., 2018, Tullis and Benjamin, 2011, Wang et al.,
2021b]. Specifically, we say a sample is label-sparse when it contains very few (if any)
segmentation labels; in contrast, a sample is label-dense when its segmentation labels
are prolific. Motivated by the information imbalance among samples, we explore the
following questions:
What is the effect of separation between sparse and dense labels on segmentation?
Can we leverage such information imbalance to improve the segmentation accuracy?
We formulate the mixture of label-sparse and label-dense samples as a concept shift
— a type of distribution shift in the conditional distribution of labels given features
P(y|x). To cope with concept shift, prior works have focused on adaptively truncat-
ing (hard-thresholding) the empirical loss associated with label-sparse samples. These
include the Trimmed Loss Estimator [Shen and Sanghavi, 2019], MKL-SGD [Shah et al., 2020],
Ordered SGD [Kawaguchi and Lu, 2020], and the quantile-based Kaczmarz algorithm
[Haddock et al., 2022]. Alternatively, another line of works [Sagawa et al., 2020,
Wang et al., 2018] proposes to relax the hard-thresholding operation to soft-thresholding
by down-weighting instead of truncating the less informative samples. However, dimin-
ishing sample weights reduces the importance of both the features and the labels simulta-
neously, which is still not ideal as the potentially valuable information in the features of
the label-sparse samples may not be fully used.
For further exploitation of the features of training samples, we propose the incor-
poration of data augmentation consistency regularization on label-sparse samples. As
a prevalent strategy for utilizing unlabeled data, consistency regularization [Bachman
et al., 2014, Laine and Aila, 2016, Sohn et al., 2020] encourages data augmentations of
the same samples to lie in the vicinity of each other on a proper manifold. For medi-
cal imaging segmentation, consistency regularization has been extensively studied in the
semi-supervised learning setting [Basak et al., 2022, Bortsova et al., 2019, Li et al., 2020,
Wang et al., 2021a, Zhang et al., 2021, Zhao et al., 2019, Zhou et al., 2021] as a strategy
for overcoming label scarcity. Nevertheless, unlike general vision tasks, for medical im-
age segmentation, the scantiness of unlabeled image data can also be a problem due to
regulations and privacy considerations [Karimi et al., 2020], which makes it worthwhile
to revisit the more classical supervised learning setting. In contrast to the aforemen-
tioned semi-supervised strategies, we explore the potency of consistency regularization
in the supervised learning setting by leveraging the information in the features of label-
sparse samples via data augmentation consistency regularization.
Figure 1.1: Evolution of cross-entropy losses versus consistency regularization terms for
slices at different cross-sections of the human body in the Synapse dataset (described in
Section 5) during training.

To naturally distinguish the label-sparse and label-dense samples, we make a key ob-
servation that the unsupervised consistency regularization on encoder layer outputs (of
a UNet-style architecture) is much more uniform across different subpopulations than
the supervised cross-entropy loss (as exemplified in Figure 1.1). Since the consistency
regularization is characterized by the marginal distribution of features P(x) but not labels,
and is therefore less affected by the concept shift in P(y|x), it serves as a natural
reference for separating the label-sparse and label-dense samples. In light of this observa-
tion, we present the weighted data augmentation consistency (WAC) regularization — a
minimax formulation that reweights the cross-entropy loss versus the consistency regular-
ization associated with each sample via a set of trainable weights. At the saddle point of
this minimax formulation, the WAC regularization automatically separates samples from
different subpopulations by assigning all weights to the consistency regularization for
label-sparse samples, and all weights to the cross-entropy terms for label-dense samples.
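As a concrete illustration (a minimal sketch under our own naming assumptions, not the exact formulation of Section 3), each sample can be thought of as carrying a trainable weight beta in [0, 1] that trades off its cross-entropy loss against its consistency regularization, with the model parameters minimizing and the weights maximizing the resulting objective; ce_loss, consistency_loss, and beta below are illustrative placeholders.

import torch

def wac_objective(ce_loss: torch.Tensor, consistency_loss: torch.Tensor,
                  beta: torch.Tensor) -> torch.Tensor:
    # Per-sample trade-off between the supervised cross-entropy loss and the
    # unsupervised consistency regularization; all inputs have shape (n,)
    # with beta in [0, 1]. The model minimizes this objective while the
    # weights beta are maximized, so at the saddle point label-dense samples
    # (large cross-entropy) get beta near 1 and label-sparse samples
    # (small cross-entropy) get beta near 0.
    return (beta * ce_loss + (1.0 - beta) * consistency_loss).mean()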
We further introduce an adaptively weighted online optimization algorithm, AdaWAC,
for solving the minimax problem posed by the WAC regularization, which is inspired by
a mirror-descent-based algorithm for distributionally robust optimization [Sagawa et al.,
2020]. By adaptively learning the weights between the cross-entropy loss and consistency
regularization of different samples, AdaWAC comes with both a convergence guarantee
and empirical success.
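To convey the flavor of such adaptive reweighting (the exact update rule is deferred to Section 4), the following is an illustrative exponentiated-gradient-style ascent step on each sample's weight, treating the weight as a point on a two-point simplex; the step size lr and the name update_beta are assumptions of this sketch, not the algorithm's actual parameters.

import torch

def update_beta(beta: torch.Tensor, ce_loss: torch.Tensor,
                consistency_loss: torch.Tensor, lr: float = 0.1) -> torch.Tensor:
    # Multiplicative (mirror-ascent-style) update on the per-sample weights:
    # samples whose cross-entropy loss currently exceeds their consistency
    # regularization have beta pushed toward 1, and vice versa.
    w_ce = beta * torch.exp(lr * ce_loss)
    w_ac = (1.0 - beta) * torch.exp(lr * consistency_loss)
    return w_ce / (w_ce + w_ac)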
The main contributions are summarized as follows:
• We introduce the WAC regularization that leverages the consistency regularization on
the encoder layer outputs (of a UNet-style architecture) as a natural reference to distin-
guish the label-sparse and label-dense samples (Section 3).
• We propose an adaptively weighted online optimization algorithm — AdaWAC — for
solving the WAC regularization problem with a convergence guarantee (Section 4).
• Through extensive experiments on different medical image segmentation tasks with dif-
ferent UNet-style backbone architectures, we demonstrate the effectiveness of AdaWAC
not only for enhancing the segmentation performance and sample efficiency but also for
improving the robustness to concept shift (Section 5).
1.1 Related Work
Sample reweighting. Sample reweighting is a popular strategy for dealing with dis-
tribution/subpopulation shifts in training data where different weights are assigned to
samples from different subpopulations. In particular, the distributionally-robust opti-
mization (DRO) framework [Ben-Tal et al., 2013, Duchi and Namkoong, 2018, Duchi
et al., 2016, Sagawa et al., 2020] considers a collection of training sample groups from
different distributions. With the explicit grouping of samples, the goal is to minimize
the worst-case loss over the groups. Without prior knowledge of sample grouping, im-
portance sampling [Alain et al., 2015, Gopal, 2016, Katharopoulos and Fleuret, 2018,
Loshchilov and Hutter, 2015, Needell et al., 2014, Zhao and Zhang, 2015], iterative trim-
ming [Kawaguchi and Lu, 2020, Shen and Sanghavi, 2019], and empirical-loss-based
reweighting [Wu et al., 2022] are commonly incorporated in the stochastic optimization
process for adaptive reweighting and separation of samples from different subpopulations.
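To make the worst-case flavor of DRO concrete, a minimal group-DRO-style sketch (the function name and tensor shapes are assumptions of this illustration) averages per-sample losses within each group and optimizes the largest group average:

import torch

def worst_group_loss(losses: torch.Tensor, group_ids: torch.Tensor) -> torch.Tensor:
    # losses: shape (n,), per-sample losses; group_ids: shape (n,), integer group labels.
    # Returns the worst (largest) per-group average loss, the quantity that
    # group-wise distributionally robust optimization seeks to minimize.
    group_means = torch.stack([losses[group_ids == g].mean()
                               for g in torch.unique(group_ids)])
    return group_means.max()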
Data augmentation consistency regularization. As a popular way of exploiting data
augmentations, consistency regularization encourages models to learn the vicinity among
augmentations of the same sample based on the assumption that data augmentations gen-
erally preserve the semantic information in data and therefore lie closely on proper mani-
folds. Beyond being a powerful building block in semi-supervised [Bachman et al., 2014,
Berthelot et al., 2019, Laine and Aila, 2016, Sajjadi et al., 2016, Sohn et al., 2020] and
self-supervised [Chen et al., 2020, Grill et al., 2020, He et al., 2020, Wu et al., 2018] learn-
ing, the incorporation of data augmentation and consistency regularization also provably
improves generalization and feature learning even in the supervised learning setting [Shen
et al., 2022, Yang et al., 2022].
For medical imaging, data augmentation consistency regularization is generally lever-
aged as a semi-supervised learning tool [Basak et al., 2022, Bortsova et al., 2019, Li et al.,
2020, Wang et al., 2021a, Zhang et al., 2021, Zhao et al., 2019, Zhou et al., 2021]. In ef-
forts to incorporate consistency regularization in segmentation tasks with augmentation-
sensitive labels, Li et al. [2020] encourages transformation consistency between predic-
tions with augmentations applied to the image inputs and the segmentation outputs. Basak
et al. [2022] penalizes inconsistent segmentation outputs between teacher-student mod-
els, with MixUp [Zhang et al., 2017] applied to image inputs of the teacher model and
segmentation outputs of the student model. Instead of enforcing consistency in the seg-
mentation output space as above, we leverage the insensitivity of sparse labels to aug-
mentations and encourage consistent encodings (in the latent space of encoder outputs)
on label-sparse samples.
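As a rough sketch of what such an encoder-output consistency term could look like (the encoder, the stochastic augmentation, and the squared-distance penalty below are illustrative assumptions rather than the exact regularizer defined in Section 3):

import torch
import torch.nn as nn

def encoder_consistency(encoder: nn.Module, x: torch.Tensor, augment) -> torch.Tensor:
    # Two random augmentations of the same images are encouraged to have
    # nearby encodings in the latent space of encoder outputs.
    # augment: any stochastic transform mapping a batch x to an augmented batch.
    z1 = encoder(augment(x))
    z2 = encoder(augment(x))
    return ((z1 - z2) ** 2).flatten(start_dim=1).sum(dim=1).mean()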
2 Problem Setup
Notation. For any $K \in \mathbb{N}$, we denote $[K] = \{1, \dots, K\}$. We represent the elements
and subtensors of an arbitrary tensor by adapting the syntax for Python slicing on the
subscript (except counting from 1). For example, $\mathbf{x}_{[i,j]}$ denotes the $(i,j)$-entry of the two-
dimensional tensor $\mathbf{x}$, and $\mathbf{x}_{[i,:]}$ denotes the $i$-th row. Let $\mathbb{I}$ be a function onto $\{0,1\}$ such
that, for any event $e$, $\mathbb{I}\{e\} = 1$ if $e$ is true and $0$ otherwise. For any distribution $P$ and
$n \in \mathbb{N}$, we let $P^n$ denote the joint distribution of $n$ samples drawn i.i.d. from $P$. Finally,
we say that an event happens with high probability (w.h.p.) if the event takes place with
probability $1 - \Omega(\mathrm{poly}(n))^{-1}$.
2.1 Pixel-wise Classification with Sparse and Dense Labels
We consider medical image segmentation as a pixel-wise multi-class classification prob-
lem where we aim to learn a pixel-wise classifier $h: \mathcal{X} \to [K]^d$ that serves as a good
approximation to the ground truth $h^*: \mathcal{X} \to [K]^d$.
Recall the separation of cross-entropy losses between samples with different propor-
tions of background pixels from Figure 1.1. We refer to a sample $(\mathbf{x},\mathbf{y}) \in \mathcal{X} \times [K]^d$ as
label-sparse if most pixels in $\mathbf{y}$ are labeled as background; for these samples, the cross-
entropy loss on $(\mathbf{x},\mathbf{y})$ converges rapidly in the early stage of training. Otherwise, we say
that $(\mathbf{x},\mathbf{y})$ is label-dense. Formally, we describe such variation as a concept shift in the
data distribution.
Definition 1 (Mixture of label-sparse and label-dense subpopulations). We assume that
label-sparse and label-dense samples are drawn from $P_0$ and $P_1$ with distinct conditional
distributions $P_0(\mathbf{y}|\mathbf{x})$ and $P_1(\mathbf{y}|\mathbf{x})$ but a common marginal distribution $P(\mathbf{x})$ such that
$P_i(\mathbf{x},\mathbf{y}) = P_i(\mathbf{y}|\mathbf{x})\,P(\mathbf{x})$ $(i = 0, 1)$. For $\xi \in [0,1]$, we define $P_\xi$ as a data distribution
where $(\mathbf{x},\mathbf{y}) \sim P_\xi$ is drawn either from $P_1$ with probability $\xi$ or from $P_0$ with probability
$1 - \xi$.
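For intuition, a minimal sketch of drawing a training set from the mixture $P_\xi$ of Definition 1, where sample_sparse and sample_dense are hypothetical callables standing in for the conditionals $P_0(\mathbf{y}|\mathbf{x})$ and $P_1(\mathbf{y}|\mathbf{x})$:

import random

def draw_from_mixture(xs, sample_sparse, sample_dense, xi: float):
    # For each feature x (drawn from the shared marginal P(x)), the label is
    # drawn from the label-dense conditional with probability xi and from the
    # label-sparse conditional with probability 1 - xi.
    dataset = []
    for x in xs:
        if random.random() < xi:
            dataset.append((x, sample_dense(x)))
        else:
            dataset.append((x, sample_sparse(x)))
    return dataset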
We aim to learn a pixel-wise classifier from a function class $\mathcal{H}$ where every $h_\theta \in \mathcal{H}$
satisfies $h_\theta(\mathbf{x})_{[j]} = \operatorname{argmax}_{k \in [K]} f_\theta(\mathbf{x})_{[j,k]}$ for all $j \in [d]$, and the underlying function
$f_\theta \in \mathcal{F}$, parameterized by some $\theta \in \mathcal{F}_\theta$, admits an encoder-decoder structure:
$$\mathcal{F} = \big\{ f_\theta = \psi_\theta \circ \varphi_\theta \;\big|\; \varphi_\theta: \mathcal{X} \to \mathcal{Z},\ \psi_\theta: \mathcal{Z} \to [0,1]^{d \times K} \big\}. \qquad (2.1)$$
Here $\varphi_\theta, \psi_\theta$ correspond to the encoder and decoder functions, respectively. The parameter
space $\mathcal{F}_\theta$ is equipped with the norm $\|\cdot\|_{\mathcal{F}}$ and its dual norm $\|\cdot\|_{\mathcal{F},*}$.¹ $(\mathcal{Z}, \varrho)$ is a latent
metric space.
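A minimal stand-in for such an encoder-decoder function on flattened inputs is sketched below (the experiments use UNet-style backbones; the layer choices, latent dimension, and flattened-input shape here are assumptions of this illustration only):

import torch
import torch.nn as nn

class EncoderDecoder(nn.Module):
    # f_theta = psi_theta o phi_theta: the encoder maps an input x to a latent
    # code z, and the decoder maps z to per-pixel class probabilities.
    def __init__(self, d: int, K: int, latent_dim: int = 64):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(d, latent_dim), nn.ReLU())  # phi: X -> Z
        self.decoder = nn.Linear(latent_dim, d * K)                        # psi (up to softmax)
        self.d, self.K = d, K

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.encoder(x)                               # latent encoding
        logits = self.decoder(z).view(-1, self.d, self.K)
        return logits.softmax(dim=-1)                     # values in [0, 1]^{d x K}

def pixelwise_classifier(f: EncoderDecoder, x: torch.Tensor) -> torch.Tensor:
    # h_theta(x)[j] = argmax_k f_theta(x)[j, k]
    return f(x).argmax(dim=-1)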
To learn from segmentation labels, we consider the averaged cross-entropy loss:
$$\ell_{\mathrm{CE}}(\theta; (\mathbf{x},\mathbf{y})) = -\frac{1}{d} \sum_{j=1}^{d} \sum_{k=1}^{K} \mathbb{I}\{\mathbf{y}_{[j]} = k\} \cdot \log f_\theta(\mathbf{x})_{[j,k]} = -\frac{1}{d} \sum_{j=1}^{d} \log f_\theta(\mathbf{x})_{[j,\mathbf{y}_{[j]}]}. \qquad (2.2)$$
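For concreteness, Eq. (2.2) can be computed as follows (a minimal sketch with 0-indexed labels, assuming the per-pixel class probabilities $f_\theta(\mathbf{x})$ are given as a d-by-K matrix):

import torch

def averaged_cross_entropy(probs: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # probs: shape (d, K), per-pixel class probabilities f_theta(x).
    # y: shape (d,), integer labels in {0, ..., K-1}.
    # Returns -(1/d) * sum_j log probs[j, y[j]].
    d = probs.shape[0]
    return -torch.log(probs[torch.arange(d), y]).mean()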
We assume proper learning with $\theta^* \in \bigcap_{\xi \in [0,1]} \operatorname{argmin}_{\theta \in \mathcal{F}_\theta} \mathbb{E}_{(\mathbf{x},\mathbf{y}) \sim P_\xi}\left[\ell_{\mathrm{CE}}(\theta; (\mathbf{x},\mathbf{y}))\right]$ being
invariant with respect to $\xi$.²
¹For AdaWAC (Proposition 2 in Section 4), $\mathcal{F}_\theta$ is simply a subspace in the Euclidean space with di-
mension equal to the total number of parameters for each $\theta \in \mathcal{F}_\theta$, with $\|\cdot\|_{\mathcal{F}}$ and $\|\cdot\|_{\mathcal{F},*}$ both being the
$\ell_2$-norm.
²We assume proper learning only to (i) highlight the invariance of the desired ground truth to $\xi$, which can
be challenging to learn with finite samples in practice, and (ii) provide a natural pivot for the convex and
compact neighborhood $\mathcal{F}_\theta(\gamma)$ of the ground truth $\theta^*$ in Assumption 1 granted by the pretrained initialization,
where $\theta^*$ can also be replaced with the pretrained initialization weights $\theta_0 \in \mathcal{F}_\theta(\gamma)$. In particular, neither
our theory nor the AdaWAC algorithm requires the function class $\mathcal{F}$ to be expressive enough to truly contain
such a $\theta^*$.