Training Debiased Subnetworks with Contrastive Weight Pruning

Geon Yeong Park¹, Sangmin Lee², Sang Wan Lee¹, Jong Chul Ye¹,²,³
¹Bio and Brain Engineering, ²Mathematical Sciences, ³Kim Jaechul Graduate School of AI
Korea Advanced Institute of Science and Technology (KAIST), Daejeon, Korea
{pky3436, leeleesang, sangwan, jong.ye}@kaist.ac.kr
Abstract
Neural networks are often biased toward spuriously correlated features that provide misleading statistical evidence and do not generalize. This raises an interesting question: "Does an optimal unbiased functional subnetwork exist in a severely biased network? If so, how can we extract such a subnetwork?" While empirical evidence of the existence of such unbiased subnetworks has accumulated, these observations mainly rely on the guidance of ground-truth unbiased samples. Thus, how to discover the optimal subnetworks from biased training datasets in practice remains unexplored. To address this, we first present a theoretical insight that warns of the potential limitations of existing algorithms in exploring unbiased subnetworks in the presence of strong spurious correlations. We then elucidate the importance of bias-conflicting samples for structure learning. Motivated by these observations, we propose a Debiased Contrastive Weight Pruning (DCWP) algorithm, which probes unbiased subnetworks without expensive group annotations. Experimental results demonstrate that our approach significantly outperforms state-of-the-art debiasing methods despite a considerable reduction in the number of parameters.
1. Introduction
While deep neural networks have made substantial progress in solving challenging tasks, they often undesirably rely on spuriously correlated features, or dataset bias, if present, which is considered one of the major hurdles to deploying models in real-world applications. For example, consider recognizing desert foxes and cats in natural images. If the background scene (e.g., a desert) is spuriously correlated with the type of animal, a neural network might use the background information as a shortcut to classification, resulting in degraded performance on different backgrounds (e.g., a desert fox in a house).
To investigate the origin of spurious correlations, this paper considers shortcut learning as a fundamental architectural design issue of neural networks. Specifically, if any available information channel in a deep network's structure can transmit information about spuriously correlated features (spurious features from now on), the network will exploit those features as long as they are sufficiently predictive. It naturally follows that pruning weights on spurious features can purify the biased latent representations, thereby improving performance on bias-conflicting samples¹. We conjecture that such neural pruning may improve the generalization of the network by reducing the effective dimension of spurious features, considering that failures of out-of-distribution (OOD) generalization may arise from high-dimensional spurious features [26, 34].
Recently, Zhang et al. [37] empirically demonstrated the existence of subnetworks that are less susceptible to spurious features. Based on the modular property of neural networks [5], they prune out weights that are closely related to spurious attributes. While [37] offers valuable insights into the importance of neural architectures, the study is limited in that its neural pruning requires a sufficient number of ground-truth bias-conflicting samples. This raises the question: how can we discover the optimal subnetworks in practice when the dataset is highly biased?
To address this, we first present a simple theoretical observation that reveals the limitations of existing substructure probing methods in searching for unbiased subnetworks. Specifically, we show that the subnetworks obtained by standard pruning algorithms suffer an unavoidable generalization gap in the presence of strong spurious correlations. Our analysis also shows that trained models may inevitably rely on spuriously correlated features in practical training settings with finite training time and a finite number of samples.
In addition, we show that sampling more bias-conflicting data makes it possible to identify spurious weights. Specifically, bias-conflicting samples demand that the weights associated with spurious features be pruned out, since spurious features do not help predict bias-conflicting samples. Our theoretical observations suggest that balancing the ratio between the number of bias-aligned and bias-conflicting samples is crucial for finding the optimal unbiased subnetworks.

¹Bias-aligned samples refer to data with a strong correlation between (potentially latent) spurious features and target labels (e.g., a cat in the house). Bias-conflicting samples refer to the opposite case, where the spurious correlation does not hold (e.g., a cat in the desert).

Figure 1. Concept: We demonstrate an inevitable generalization gap of subnetworks obtained by standard pruning methods, including [37]. Based on these observations, we design a novel subnetwork probing framework that fully exploits unbiased samples.
In practice, datasets may severely lack diversity in bias-conflicting samples due to potential pitfalls in data collection protocols or human prejudice. Since it is often highly laborious to supplement enough bias-conflicting samples, we propose a novel debiasing scheme called Debiased Contrastive Weight Pruning (DCWP) that uses oversampled bias-conflicting data to search for unbiased subnetworks.
As shown in Fig. 1, DCWP comprises two stages: (1) identifying bias-conflicting samples without expensive annotations of spuriously correlated attributes, and (2) training the pruning parameters to obtain weight pruning masks under a sparsity constraint and a debiased loss function. Here, the debiased loss includes a weighted cross-entropy loss for the identified bias-conflicting samples and an alignment loss that further reduces the geometric alignment gap between bias-aligned and bias-conflicting samples within each class.
We demonstrate that DCWP consistently outperforms state-of-the-art debiasing methods across various biased datasets, including Colored MNIST [23, 27], Corrupted CIFAR-10 [13], Biased FFHQ [21] and CelebA [25], even without direct supervision on the bias type. Compared with the second-best model, our approach improves the accuracy on the unbiased evaluation dataset from 86.74% to 93.41% on Colored MNIST and from 27.86% to 35.90% on Corrupted CIFAR-10, even when 99.5% of the samples are bias-aligned.
2. Related works
Spurious correlations. A series of empirical works has shown that deep networks often find shortcut solutions relying on spuriously correlated attributes, such as image texture [10], language biases [12], or sensitive variables such as ethnicity or gender [7, 28]. Such behavior is of practical concern because it deteriorates the reliability of deep networks in sensitive applications like healthcare, finance, and legal services [4].
Debiasing frameworks. Recent studies on training debiased networks robust to spurious correlations can be roughly categorized into approaches that (1) leverage annotations of spurious attributes, i.e., bias labels [29, 36], (2) presume a specific type of bias, e.g., texture [1, 9], or (3) use no explicit supervision about dataset bias [22, 27]. The authors in [15, 29] optimize the worst-group error by using training group information. For practical implementation, reweighting or subsampling protocols are often used together with increased model regularization [30]. Liu et al. [24] and Sohoni et al. [32] extend these approaches to settings without expensive group annotations. Goel et al. [11] and Kim et al. [21] provide bias-tailored augmentations to balance the majority and minority groups. These approaches have mainly focused on better approximation and regularization of the worst-group error, combined with advanced data sampling, augmentation, or retraining strategies.
Studying the impact of neural architectures. Recently, the effects of deep neural network architecture on generalization performance have been explored. Diffenderfer et al. [6] employ recently advanced lottery-ticket-style pruning algorithms [8] to design compact and robust network architectures. Bai et al. [2] directly optimize the neural architecture for accuracy on OOD samples. Zhang et al. [37] demonstrate the effectiveness of pruning weights on spurious attributes, but their procedure for discriminating such spurious weights lacks robust theoretical justification, resulting in marginal performance gains. To fully resolve these issues, we carry out a theoretical case study and build a novel pruning algorithm that distills the representations to be independent of spurious attributes.
3. Theoretical insights
3.1. Problem setup
Consider a supervised setting of predicting labels $Y \in \mathcal{Y}$ from input samples $X \in \mathcal{X}$ with a classifier $f_\theta : \mathcal{X} \to \mathcal{Y}$ parameterized by $\theta \in \Theta$. Following [37], let $(X^e, Y^e) \sim P^e$, where $X^e \in \mathcal{X}$ and $Y^e \in \mathcal{Y}$ refer to the input random variable and the corresponding label, respectively; $e \in \mathcal{E} = \{1, 2, \dots, E\}$ denotes the index of the environment; $P^e$ is the corresponding distribution; and the set $\mathcal{E}$ corresponds to all possible environments. We further assume that $\mathcal{E}$ is divided into training environments $\mathcal{E}_{train}$ and unseen test environments $\mathcal{E}_{test}$, i.e., $\mathcal{E} = \mathcal{E}_{train} \cup \mathcal{E}_{test}$.
For a given loss function $\ell : \mathcal{X} \times \mathcal{Y} \times \Theta \to \mathbb{R}^+$, the standard training protocol for empirical risk minimization (ERM) is to minimize the expected loss over a training environment $e \in \mathcal{E}_{train}$:

$$\hat{\theta}_{ERM} = \arg\min_{\theta}\; \mathbb{E}_{(X^e, Y^e) \sim \hat{P}^e}\,\ell(X^e, Y^e; \theta), \qquad (1)$$

where $\hat{P}^e$ is the empirical distribution over the training data. Our goal is to learn a model with good performance on OOD samples from $e \in \mathcal{E}_{test}$.
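For reference, a minimal ERM update implementing (1) might look as follows in PyTorch; this is a generic sketch of ours, where the model, optimizer, and batch are placeholders rather than components specified by the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def erm_step(model: nn.Module, optimizer: torch.optim.Optimizer,
             x: torch.Tensor, y: torch.Tensor) -> float:
    """One ERM update (Eq. 1) on a batch drawn from the training environment."""
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)  # empirical risk on the batch
    loss.backward()
    optimizer.step()
    return loss.item()
```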
3.2. Motivating example

We conjecture that neural networks trained by ERM indiscriminately rely on predictive features, including spuriously correlated ones [34].

To verify this conjecture, we present a simple binary classification example $(X^e, Y^e) \sim P^e$, where $Y^e \in \mathcal{Y} = \{-1, 1\}$ represents the target label and a sample $X^e \in \mathcal{X} = \{-1, 1\}^{D+1} \subset \mathbb{R}^{D+1}$ consists of both an invariant feature $Z^e_{inv} \in \{-1, 1\}$ and spurious features $Z^e_{sp} \in \{-1, 1\}^D$, i.e., $X^e = (Z^e_{inv}, Z^e_{sp})$. Furthermore, let $Z^e_{sp,i}$ denote the $i$-th spurious feature component of $Z^e_{sp}$. Note that we assume $D \gg 1$ to simulate a model that relies heavily on the spurious features $Z^e_{sp}$ [26, 37].
We consider a setting where the training environment $e \in \mathcal{E}_{train}$ is highly biased. In other words, we suppose that $Z^e_{inv} = Y^e$, and that each spurious feature component $Z^e_{sp,i}$ is an independent and identically distributed (i.i.d.) Bernoulli variable: $Z^e_{sp,i}$ independently takes the value $Y^e$ with probability $p^e$ and $-Y^e$ with probability $1 - p^e$, where $p^e \in (0.5, 1]$, $\forall e \in \mathcal{E}_{train}$. Note that $p^e \to 1$ as the environment becomes severely biased. A test environment $e \in \mathcal{E}_{test}$ is assumed to have $p^e = 0.5$, which implies that the spurious features are totally independent of $Y^e$. We then introduce a linear classifier $f$ parameterized by a weight vector $w = (w_{inv}, w_{sp}) \in \mathbb{R}^{D+1}$, where $w_{inv} \in \mathbb{R}$ and $w_{sp} \in \mathbb{R}^D$. In this example, we consider a class of pretrained classifiers parameterized by $\tilde{w}(t) = (\tilde{w}_{inv}(t), \tilde{w}_{sp,1}(t), \dots, \tilde{w}_{sp,D}(t))$, where $t < T$ is a finite pretraining time for some sufficiently large $T$. The time $t$ will often be omitted from the notation for simplicity.
Our goal is to obtain the optimal sparse classifier from a highly biased training dataset. To achieve this, we introduce a binary weight pruning mask $m = (m_{inv}, m_{sp}) \in \{0, 1\}^{D+1}$ for the pretrained weights, which is a significant departure from the theoretical setting in [37]. Specifically, let $m_{inv} \sim \mathrm{Bern}(\pi_{inv})$, where $\pi_{inv}$ and $1 - \pi_{inv}$ represent the probabilities of preserving ($m_{inv} = 1$) and pruning out ($m_{inv} = 0$) the weight, respectively. Similarly, let $m_{sp,i} \sim \mathrm{Bern}(\pi_{sp,i})$, $\forall i$. Our optimization goal is then to estimate the pruning probability parameter $\pi = (\pi_1, \dots, \pi_{D+1}) = (\pi_{inv}, \pi_{sp,1}, \dots, \pi_{sp,D})$, where $m \sim P(\pi)$ is a mask sampled with probability parameters $\pi$. Accordingly, our main loss function for the pruning parameters given the environment $e$ is defined as follows:

$$\ell^e(\pi) = \frac{1}{2}\,\mathbb{E}_{X^e, Y^e, m}\big[1 - Y^e\hat{Y}^e\big] = \frac{1}{2}\,\mathbb{E}_{X^e, Y^e, m}\Big[1 - Y^e \cdot \mathrm{sgn}\big(\tilde{w}^{T}(X^e \odot m)\big)\Big], \qquad (2)$$

where $\hat{Y}^e$ is the prediction of the binary classifier, $\tilde{w}$ is the pretrained weight vector, $\mathrm{sgn}(\cdot)$ represents the sign function, and $\odot$ represents the element-wise product.
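To make this setup concrete, the following NumPy simulation contrasts the empirical 0-1 error of (2) when the spurious weights are kept versus pruned; the dimension $D$, weight ratio, and bias level are illustrative choices of ours, not values from the paper.

```python
# Monte Carlo sketch of the motivating example in Sec. 3.2 (our illustrative
# constants). We sample a biased environment, apply a fixed binary pruning
# mask to a pretrained linear classifier w, and estimate the 0-1 loss (2).
import numpy as np

rng = np.random.default_rng(0)

def sample_env(n, D, p_e):
    """Sample (X, Y) with Z_inv = Y and i.i.d. spurious coordinates that
    equal Y with probability p_e and -Y otherwise."""
    y = rng.choice([-1, 1], size=n)
    agree = rng.random((n, D)) < p_e                 # Bernoulli(p_e) agreement
    z_sp = np.where(agree, y[:, None], -y[:, None])
    return np.concatenate([y[:, None], z_sp], axis=1), y

def masked_error(w, mask, X, y):
    """Empirical 0-1 error of sign(w^T (x ⊙ m)), cf. Eq. (2)."""
    return np.mean(np.sign(X @ (w * mask)) != y)

D = 50
w = np.concatenate([[1.0], 0.2 * np.ones(D)])        # alpha_i(t) = 0.2 for all i
keep_all = np.ones(D + 1)                            # preserve spurious weights
prune_sp = np.concatenate([[1.0], np.zeros(D)])      # keep only w_inv

X_tr, y_tr = sample_env(10_000, D, p_e=0.95)         # biased training env
X_te, y_te = sample_env(10_000, D, p_e=0.5)          # unbiased test env
print(masked_error(w, keep_all, X_tr, y_tr))         # ~0: shortcut works here
print(masked_error(w, keep_all, X_te, y_te))         # large: shortcut fails
print(masked_error(w, prune_sp, X_te, y_te))         # 0: invariant feature only
```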
We first derive an upper bound on the training loss $\ell^e(\pi)$ to illustrate the difficulty of learning optimal pruning parameters in a biased data setting.
Theorem 1 (Training and test bound). Assume that $p^e > 1/2$ in the biased training environment $e \in \mathcal{E}_{train}$. Define $\tilde{w}(t)$ as weights pretrained for a finite time $t < T$. Then the upper bound of the error of the training environment w.r.t. the pruning parameters $\pi$ is given as:

$$\ell^e(\pi) \le 2\exp\left(-\,\frac{2\big(\pi_{inv} + (2p^e - 1)\sum_{i=1}^{D}\alpha_i(t)\,\pi_{sp,i}\big)^2}{4\sum_{i=1}^{D}\alpha_i(t)^2 + 1}\right), \qquad (3)$$

where the weight ratio $\alpha_i(t) = \tilde{w}_{sp,i}(t)/\tilde{w}_{inv}(t)$ is bounded below by some positive constant. Given a test environment $e \in \mathcal{E}_{test}$ with $p^e = \frac{1}{2}$, the upper bound of the error of the test environment w.r.t. $\pi$ is given as:

$$\ell^e(\pi) \le 2\exp\left(-\,\frac{2\pi_{inv}^2}{4\sum_{i=1}^{D}\alpha_i(t)^2 + 1}\right), \qquad (4)$$

which implies that there is an unavoidable gap between the training bound and the test bound.
The detailed proof of Theorem 1 is provided in the supplementary material. The mismatch between the bounds is attributed to the contribution of $\pi_{sp,i}$ to the training bound (3). Intuitively, the network prefers to preserve both $\tilde{w}_{inv}$ and $\tilde{w}_{sp,i}$ in the presence of strong spurious correlations, due to the inherent sensitivity of ERM to all kinds of predictive features [17, 34]. This behavior is directly reflected in the training bound, where increasing either $\pi_{inv}$ or $\pi_{sp,i}$, i.e., the probability of preserving weights, decreases the bound. This inertia of the spurious weights may prevent them from being preferentially pruned under the sparsity constraint.
We note that the unintended reliance on spurious features is fundamentally rooted in the positivity of the weight ratio $\alpha_i(t)$. In the proof of Theorem 1 in the supplementary material, we show some intriguing properties of $\alpha_i(t)$: (1) If infinitely many data and sufficient training time are provided, the gradient flow converges to the optimal solution that is invariant to $Z^e_{sp}$, i.e., $\alpha_i(t) \to 0$. In this ideal situation, the gap between the training and test bounds closes, thereby guaranteeing the generalization of the obtained subnetworks. (2) However, given a finite time $t < T$ and a strongly biased dataset in practice, $\alpha_i(t)$ is bounded below by some positive constant, resulting in an inevitable generalization gap.
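To visualize the resulting gap, the short sketch below evaluates the bounds (3) and (4) numerically under illustrative constants of our choosing; note that increasing $\pi_{sp,i}$ tightens only the training bound.

```python
# Numeric illustration of Theorem 1 (illustrative constants, not from the
# paper): the training bound (3) rewards preserving spurious weights, while
# the test bound (4) depends on pi_inv alone.
import numpy as np

def train_bound(pi_inv, pi_sp, alpha, p_e):
    num = 2.0 * (pi_inv + (2 * p_e - 1) * np.sum(alpha * pi_sp)) ** 2
    return 2 * np.exp(-num / (4 * np.sum(alpha ** 2) + 1))

def test_bound(pi_inv, alpha):
    return 2 * np.exp(-2 * pi_inv ** 2 / (4 * np.sum(alpha ** 2) + 1))

D, p_e = 50, 0.95
alpha = 0.2 * np.ones(D)                          # positive weight ratios alpha_i(t)
print(train_bound(1.0, np.ones(D), alpha, p_e))   # tiny: keep all spurious weights
print(train_bound(1.0, np.zeros(D), alpha, p_e))  # larger: spurious weights pruned
print(test_bound(1.0, alpha))                     # equals the line above
```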
Theorem 1 implies that the classifier may preserve spurious weights due to the lack of bias-conflicting samples, which serve as counterexamples that spurious features themselves fail to explain. This motivates us to analyze the training bound in another environment $\eta$ where we can systematically augment bias-conflicting samples. Specifically, consider $X^\eta = (Z^\eta_{inv}, Z^\eta_{sp})$, where $Z^\eta_{inv} = Y^\eta$ and the mixture distribution of $Z^\eta_{sp}$ given $Y^\eta = y$ is defined element-wise as follows:

$$P^\eta_{mix}(Z^\eta_{sp,i} \mid Y^\eta = y) = \phi\, P^\eta_{debias}(Z^\eta_{sp,i} \mid Y^\eta = y) + (1 - \phi)\, P^\eta_{bias}(Z^\eta_{sp,i} \mid Y^\eta = y), \qquad (5)$$
where $\phi$ is a scalar mixture weight,

$$P^\eta_{debias}(Z^\eta_{sp,i} \mid Y^\eta = y) = \begin{cases} 1, & \text{if } Z^\eta_{sp,i} = -y \\ 0, & \text{if } Z^\eta_{sp,i} = y \end{cases} \qquad (6)$$

is a debiasing distribution that weakens the correlation between $Y^\eta$ and $Z^\eta_{sp,i}$ by setting the value of $Z^\eta_{sp,i}$ to $-Y^\eta$, and

$$P^\eta_{bias}(Z^\eta_{sp,i} \mid Y^\eta = y) = \begin{cases} p^\eta, & \text{if } Z^\eta_{sp,i} = y \\ 1 - p^\eta, & \text{if } Z^\eta_{sp,i} = -y \end{cases} \qquad (7)$$

is a biased distribution defined analogously to the previous environment $e \in \mathcal{E}_{train}$. Given this new environment $\eta$, the degree of spurious correlation can be controlled by $\phi$. This leads to the following training bound:
Theorem 2 (Training bound with the mixture distribution). Assume that the defined mixture distribution $P^\eta_{mix}$ is biased, i.e., for all $i \in \{1, \dots, D\}$,

$$P^\eta_{mix}(Z^\eta_{sp,i} = y \mid Y^\eta = y) \ge P^\eta_{mix}(Z^\eta_{sp,i} = -y \mid Y^\eta = y). \qquad (8)$$

Then $\phi$ satisfies $0 \le \phi \le 1 - \frac{1}{2p^\eta}$, and the upper bound of the error of the training environment $\eta$ w.r.t. the pruning parameters is given by

$$\ell^\eta(\pi) \le 2\exp\left(-\,\frac{2\big(\pi_{inv} + (2p^\eta(1 - \phi) - 1)\sum_{i=1}^{D}\alpha_i(t)\,\pi_{sp,i}\big)^2}{4\sum_{i=1}^{D}\alpha_i(t)^2 + 1}\right). \qquad (9)$$

Furthermore, when $\phi = 1 - \frac{1}{2p^\eta}$, the mixture distribution is perfectly debiased, and we have

$$\ell^\eta(\pi) \le 2\exp\left(-\,\frac{2\pi_{inv}^2}{4\sum_{i=1}^{D}\alpha_i(t)^2 + 1}\right), \qquad (10)$$

which is equivalent to the test bound in (4).
The detailed proof is provided in the supplementary material. Our new training bound (9) suggests that the significance of $\pi_{sp,i}$ in the training bound decreases as $\phi$ progressively increases; at the extreme end with $\phi = 1 - \frac{1}{2p^\eta}$, it can easily be shown that $P^\eta_{mix}(Z^\eta_{sp,i} \mid Y^\eta = y) = \frac{1}{2}$ for both $y = 1$ and $y = -1$, so that $Z^\eta_{sp,i}$ becomes random. In other words, by plugging $\phi = 1 - \frac{1}{2p^\eta}$ into (9), we can close the gap between the training and test error bounds, which guarantees improved OOD generalization.
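The following sketch, under the same illustrative constants as before, shows how the bound (9) tightens toward the test bound (4) as $\phi$ approaches $1 - \frac{1}{2p^\eta}$.

```python
# Sketch of the Theorem 2 bound (9) as a function of the mixture weight phi
# (our illustrative constants). At phi = 1 - 1/(2 p_eta) the coefficient on
# pi_sp vanishes and (9) reduces to the test bound (4).
import numpy as np

def mix_bound(pi_inv, pi_sp, alpha, p_eta, phi):
    coeff = 2 * p_eta * (1 - phi) - 1        # replaces (2p^e - 1) in Eq. (3)
    num = 2.0 * (pi_inv + coeff * np.sum(alpha * pi_sp)) ** 2
    return 2 * np.exp(-num / (4 * np.sum(alpha ** 2) + 1))

D, p_eta = 50, 0.95
alpha, pi_sp = 0.2 * np.ones(D), np.ones(D)
for phi in (0.0, 0.25, 1 - 1 / (2 * p_eta)):
    print(f"phi={phi:.3f}  bound={mix_bound(1.0, pi_sp, alpha, p_eta, phi):.4f}")
# phi = 1 - 1/(2*0.95) ≈ 0.474 gives coeff = 0, matching the test bound (4).
```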
4. Debiased Contrastive Weight Pruning

Our theoretical observations elucidate the importance of balancing bias-aligned and bias-conflicting samples when discovering the optimal unbiased subnetwork structure. While the true analytical form of the debiasing distribution is unknown in practice, we aim to approximate this unknown distribution with existing bias-conflicting samples and to simulate the mixture distribution $P^\eta_{mix}$ by modifying the sampling strategy. To this end, we propose the Debiased Contrastive Weight Pruning (DCWP) algorithm, which learns the unbiased subnetwork structure from the original full-size network.
Consider an $L$-layer neural network as a function $f_W : \mathcal{X} \to \mathbb{R}^C$ parameterized by weights $W = \{W_1, \dots, W_L\}$, where $C = |\mathcal{Y}|$ is the number of classes. Analogous to earlier works on pruning, we introduce binary weight pruning masks $m = \{m_1, \dots, m_L\}$ to model the subnetwork as $f(\cdot\,; m_1 \odot W_1, \dots, m_L \odot W_L)$. We denote such a subnetwork as $f_{m \odot W}$ for notational simplicity. We treat each entry of $m_l$ as an independent Bernoulli variable and model their logits as our new pruning parameters $\Theta = \{\Theta_1, \dots, \Theta_L\}$, where $\Theta_l \in \mathbb{R}^{n_l}$ and $n_l$ represents the dimensionality of the $l$-th layer weights $W_l$. Then $\pi_{l,i} = \sigma(\Theta_{l,i})$ denotes the probability of preserving the $i$-th weight of the $l$-th layer $W_{l,i}$, where $\sigma$ refers to the sigmoid function. To enable end-to-end training, the Gumbel-softmax trick [18] for sampling masks is adopted together with an $\ell_1$ regularization term on $\Theta$ as a sparsity constraint. With a slight abuse of notation, $m \sim G(\Theta)$ denotes a set of masks sampled with logits $\Theta$ by applying the Gumbel-softmax trick.
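As a hedged illustration, differentiable mask sampling with the binary-concrete (Gumbel-softmax) relaxation could be implemented as follows in PyTorch; the temperature and the straight-through rounding are common implementation choices that the paper does not pin down here.

```python
# Minimal sketch of m ~ G(Theta): theta holds the per-weight logits, and
# sigmoid(theta) is the keep probability pi. Hard masks in the forward pass,
# relaxed gradients in the backward pass (straight-through estimator).
import torch

def sample_mask(theta: torch.Tensor, tau: float = 0.5) -> torch.Tensor:
    """Draw a relaxed binary mask with straight-through rounding."""
    u = torch.rand_like(theta).clamp_(1e-6, 1 - 1e-6)
    logistic_noise = torch.log(u) - torch.log1p(-u)       # Logistic(0, 1) sample
    soft = torch.sigmoid((theta + logistic_noise) / tau)  # relaxed Bernoulli
    hard = (soft > 0.5).float()
    return hard + soft - soft.detach()                    # hard forward, soft grad

theta = torch.zeros(10, requires_grad=True)  # pi = 0.5 for every weight
m = sample_mask(theta)
m.sum().backward()                           # gradients flow into theta
```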
Then our main optimization problem is defined as follows:

$$\min_{\Theta}\; \ell_{debias}\big(\{(x_i, y_i)\}_{i=1}^{|S|};\, \tilde{W}, \Theta\big) + \lambda_1 \sum_{l,i} |\Theta_{l,i}|, \qquad (11)$$

where $S$ denotes the index set of all training samples, $\lambda_1 > 0$ is a Lagrangian multiplier, $\tilde{W}$ represents the pretrained weights, and $\ell_{debias}$ is our main objective, which will be illustrated later. Note that we freeze the pretrained weights $\tilde{W}$ while training the pruning parameters $\Theta$. We use $\ell_{debias}(\{(x_i, y_i)\}_{i=1}^{|S|}; \Theta)$ and $\ell_{debias}(S; \Theta)$ interchangeably in the rest of the paper. For comparison with our formulation, we recast the optimization problem of [37] in our notation as follows:

$$\min_{\Theta}\; \ell\big(\{(x_i, y_i)\}_{i=1}^{|S|};\, \tilde{W}, \Theta\big) + \lambda_1 \sum_{l,i} |\Theta_{l,i}|, \qquad (12)$$

where [37] uses the cross-entropy (CE) loss function for $\ell$.
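A sketch of one optimization step for (11) is shown below, assuming PyTorch and the `sample_mask` helper sketched above; `masked_forward` and `debias_loss` are hypothetical placeholders for the frozen pretrained network and the debiased objective defined in the remainder of this section.

```python
# One step of Eq. (11): W_tilde stays frozen; only the mask logits theta are
# trained, with an L1 penalty on theta acting as the sparsity constraint.
import torch

def pruning_step(theta, optimizer, masked_forward, debias_loss, batch, lam1=1e-4):
    optimizer.zero_grad()
    m = sample_mask(theta)                    # fresh stochastic mask each step
    logits = masked_forward(m, batch["x"])    # f(x; m ⊙ W_tilde)
    loss = debias_loss(logits, batch) + lam1 * theta.abs().sum()
    loss.backward()                           # gradients reach theta only
    optimizer.step()
    return loss.item()
```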
Bias-conflicting sample mining. In the first stage, we identify bias-conflicting training samples, which empower functional modular probing. Specifically, we train a bias-capturing model and treat the error set $S_{bc}$, the indices of misclassified training samples, as bias-conflicting sample proxies. Our framework is broadly compatible with various bias-capturing models; we mainly leverage an ERM model trained with the generalized cross-entropy (GCE) loss [39]:

$$\ell_{GCE}(x_i, y_i; W_B) = \frac{1 - p_{y_i}(x_i; W_B)^q}{q}, \qquad (13)$$

where $q \in (0, 1]$ is a hyperparameter controlling the degree of bias amplification, $W_B$ denotes the parameters of the bias-capturing model, and $p_{y_i}(x_i; W_B)$ is the softmax output of the bias-capturing model for the target label $y_i$. Compared to the CE loss, the gradient of the GCE loss up-weights samples with a high probability of predicting the correct target, amplifying the network's bias by putting more emphasis on easy-to-predict samples [27].
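A direct PyTorch transcription of (13) together with the error-set mining step might read as follows; `bias_model`, `x_all`, and `y_all` are placeholders, and q = 0.7 is a common choice in the literature rather than a value fixed by this paper.

```python
# Sketch of the GCE loss (13) and error-set mining for S_bc.
import torch
import torch.nn.functional as F

def gce_loss(logits: torch.Tensor, targets: torch.Tensor, q: float = 0.7):
    """(1 - p_y^q) / q averaged over the batch; q -> 0 recovers CE."""
    p_y = F.softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
    return ((1.0 - p_y.pow(q)) / q).mean()

def mine_error_set(bias_model, x_all, y_all):
    """Indices misclassified by the bias-capturing model: proxies for S_bc."""
    with torch.no_grad():
        preds = bias_model(x_all).argmax(dim=1)
    return torch.nonzero(preds != y_all).squeeze(1)
```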
To preclude the possibility that the generalization performance of DCWP depends strongly on the behavior of the bias-capturing model, we demonstrate in Section 5 that DCWP is reasonably robust to degradation in the accuracy of capturing bias-conflicting samples. Details about the bias-capturing model and simulation settings are presented in the supplementary material.
Upweighting bias-conflicting samples. After mining the index set of bias-conflicting sample proxies $S_{bc}$, we treat $S_{ba} = S \setminus S_{bc}$ as the index set of majority bias-aligned samples. We then calculate the weighted cross-entropy (WCE) loss $\ell_{WCE}(\{(x_i, y_i)\}_{i=1}^{|S|}; \tilde{W}, \Theta)$ as follows:

$$\ell_{WCE}(S; \tilde{W}, \Theta) := \mathbb{E}_{m \sim G(\Theta)}\Big[\lambda_{up}\,\ell_{bc}(S_{bc}; m, \tilde{W}) + \ell_{ba}(S_{ba}; m, \tilde{W})\Big], \qquad (14)$$

where $\lambda_{up} \ge 1$ is an upweighting hyperparameter, and

$$\ell_{bc}(S_{bc}; m, \tilde{W}) = \frac{1}{|S_{bc}|}\sum_{i \in S_{bc}} \ell_{CE}(x_i, y_i; m \odot \tilde{W}), \qquad (15)$$

where $\ell_{CE}$ denotes the cross-entropy loss; $\ell_{ba}$ is defined analogously to $\ell_{bc}$.
The expectation is approximated with Monte Carlo estimates, where the number of masks $m$ sampled per iteration is set to 1 in practice. To implement (14), we oversample the samples in $S_{bc}$ by a factor of $\lambda_{up}$ relative to the samples in $S_{ba}$. This sampling strategy is aimed at increasing the mixture weight $\phi$ of the proposed mixture distribution $P^\eta_{mix}$ in (5), while we empirically approximate the unknown bias-conflicting group distribution with the sample set $S_{bc}$.
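A sketch of (14)-(15) under a single Monte Carlo mask draw, together with the oversampling view just described, is given below in PyTorch; the upweighting factor and the boolean index `is_bc` (True for samples in $S_{bc}$) are placeholders.

```python
# Sketch of Eqs. (14)-(15): explicit reweighting, plus the equivalent
# oversampling view via a weighted sampler (our reading of the strategy).
import torch
import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler

def wce_loss(logits, y, is_bc, lam_up=10.0):
    """lambda_up * mean CE over S_bc + mean CE over S_ba (Eq. 14)."""
    ce = F.cross_entropy(logits, y, reduction="none")
    l_bc = ce[is_bc].mean() if is_bc.any() else logits.new_zeros(())
    l_ba = ce[~is_bc].mean() if (~is_bc).any() else logits.new_zeros(())
    return lam_up * l_bc + l_ba

def make_oversampler(is_bc: torch.Tensor, lam_up=10.0):
    """Draw bias-conflicting samples lambda_up times more often."""
    weights = torch.where(is_bc, torch.tensor(lam_up), torch.tensor(1.0))
    return WeightedRandomSampler(weights.tolist(), num_samples=len(weights))
```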
Note that although simple oversampling of bias-conflicting samples may not lead to OOD generalization, due to the inductive bias of overparameterized neural networks toward memorizing a few counterexamples [30], such failure is unlikely to be reproduced when learning pruning parameters under a strong sparsity constraint. We sample new weight masks $m$ at each training iteration in a stochastic manner, effectively preventing the overparameterized network from memorizing the minority samples. As a result, DCWP exhibits reasonable performance even with few bias-conflicting samples.
Bridging the alignment gap by pruning. To fully utilize the bias-conflicting samples, we consider the sample-wise relation between bias-conflicting samples and the majority bias-aligned samples. Zhang et al. [38] demonstrate that deteriorated OOD generalization is potentially attributable to the distance gap between same-class representations: bias-aligned representations are more closely aligned than bias-conflicting representations, even though they are generated from same-class samples. We hypothesize that well-designed pruning masks can alleviate such geometric misalignment. Specifically, ideal weight sparsification may guide each latent dimension to be independent of the spurious attributes, thereby preventing representations from being misaligned along spuriously correlated latent dimensions. This motivates us to explore pruning masks via contrastive learning. (A related illustrative example appears in the appendix.)

Following the conventional notation of contrastive learning, we denote by $f^{enc}_W : \mathcal{X} \to \mathbb{R}^{n_{L-1}}$ the encoder.
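The extracted text is truncated at this point, so the following is only a sketch of a supervised-contrastive alignment term consistent with the description above, in which same-class pairs (bias-aligned and bias-conflicting alike) are pulled together in the masked encoder's representation space; this is our reading, not necessarily the paper's exact loss, and the temperature is an assumption.

```python
# Hypothetical alignment term: supervised contrastive loss on the masked
# encoder's features z = f_enc(x; m ⊙ W_tilde), pulling same-class pairs
# together regardless of bias alignment.
import torch
import torch.nn.functional as F

def alignment_loss(z: torch.Tensor, y: torch.Tensor, tau: float = 0.1):
    """Supervised contrastive loss over L2-normalized features z of shape (N, d)."""
    z = F.normalize(z, dim=1)
    sim = z @ z.t() / tau                               # (N, N) cosine similarities
    eye = torch.eye(z.size(0), dtype=torch.bool, device=z.device)
    pos = (y[:, None] == y[None, :]) & ~eye             # same-class, non-self pairs
    sim = sim.masked_fill(eye, float("-inf"))           # exclude self-comparisons
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)
    n_pos = pos.sum(dim=1).clamp(min=1)                 # avoid division by zero
    pos_log_prob = log_prob.masked_fill(~pos, 0.0)      # keep positive pairs only
    return -(pos_log_prob.sum(dim=1) / n_pos).mean()
```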