Preprint. Under Review.
ON THE IMPORTANCE OF CALIBRATION IN
SEMI-SUPERVISED LEARNING
Charlotte Loh
MIT EECS
MIT-IBM Watson AI Lab
Rumen Dangovski
MIT EECS
Shivchander Sudalairaj
MIT-IBM Watson AI Lab
Seungwook Han
MIT EECS
MIT-IBM Watson AI Lab
Ligong Han
Rutgers University
MIT-IBM Watson AI Lab
Leonid Karlinsky
MIT-IBM Watson AI Lab
Marin Soljačić
MIT Physics
Akash Srivastava
MIT-IBM Watson AI Lab
ABSTRACT
State-of-the-art (SOTA) semi-supervised learning (SSL) methods have been highly
successful in leveraging a mix of labeled and unlabeled data by combining tech-
niques of consistency regularization and pseudo-labeling. During pseudo-labeling,
the model’s predictions on unlabeled data are used for training and thus, model
calibration is important in mitigating confirmation bias. Yet, many SOTA methods
are optimized for model performance, with little focus directed to improve model
calibration. In this work, we empirically demonstrate that model calibration is
strongly correlated with model performance and propose to improve calibration
via approximate Bayesian techniques. We introduce a family of new SSL models
that optimizes for calibration and demonstrate their effectiveness across standard
vision benchmarks of CIFAR-10, CIFAR-100 and ImageNet, giving up to 15.9%
improvement in test accuracy. Furthermore, we also demonstrate their effectiveness
in additional realistic and challenging problems, such as class-imbalanced datasets
and in photonics science.
1 INTRODUCTION
While deep learning has achieved unprecedented success in recent years, its reliance on vast amounts
of labeled data remains a long-standing challenge. Semi-supervised learning (SSL) aims to mitigate
this by leveraging unlabeled samples in combination with a limited set of annotated data. In
computer vision, two powerful techniques that have emerged are pseudo-labeling (also known as
self-training) (Rosenberg et al., 2005; Xie et al., 2019b) and consistency regularization (Bachman
et al., 2014; Sajjadi et al., 2016). Broadly, pseudo-labeling is the technique where artificial labels
are assigned to unlabeled samples, which are then used to train the model. Consistency regularization
enforces that random perturbations of the unlabeled inputs produce similar predictions. These two
techniques are typically combined by minimizing the cross-entropy between pseudo-labels and
predictions that are derived from differently augmented inputs, and have led to strong performances
on vision benchmarks (Sohn et al., 2020; Assran et al., 2021).
Intuitively, given that pseudo-labels (i.e. the model’s predictions for unlabeled data) are used to
drive training objectives, the calibration of the model should be of paramount importance. Model
calibration (Guo et al., 2017) is a measure of how a model’s output truthfully quantifies its predictive
uncertainty, i.e. it can be understood as the alignment between its prediction confidence and its ground-
truth accuracy. In some SSL methods, the model’s confidence is used as a selection metric (Lee, 2013;
Sohn et al., 2020) to determine pseudo-label acceptance, further highlighting the need for proper
confidence estimates. Even outside this family of methods, the use of cross-entropy minimization
objectives common in SSL implies that models will naturally be driven to output high-confidence
predictions (Grandvalet & Bengio, 2004). Having high-confidence predictions is highly desirable in
SSL since we want the decision boundary to lie in low-density regions of the data manifold, i.e. away
from labeled data points (Murphy, 2022). However, without proper calibration, a model would easily
become over-confident. This is highly detrimental as the model would be encouraged to reinforce its
mistakes, resulting in the phenomenon commonly known as confirmation bias (Arazo et al., 2019).
Despite the fundamental importance of calibration in SSL, many state-of-the-art (SOTA) methods
have thus far been empirically driven and optimized for performance, with little focus on techniques
that specifically target improving calibration to mitigate confirmation bias. In this work, we explore
the generality of the importance of calibration in SSL by focusing on two broad families of SOTA
SSL methods that both use pseudo-labeling and consistency regularization: 1) threshold-mediated
methods (Sohn et al., 2020; Xie et al., 2019a; Lee, 2013) where the model selectively accepts pseudo-
labels whose confidence exceed a threshold and 2) “representation learning” methods adopted from
self-supervised learning (Assran et al., 2021) where pseudo-labels are non-selective and training
consists of two sequential stages of representation learning and fine-tuning.
To motivate our work, we first empirically show that strong baselines like FixMatch (Sohn et al., 2020)
and PAWS (Assran et al., 2021), from each of the two families, employ a set of indirect techniques to
implicitly maintain calibration and that achieving good calibration is strongly correlated to improved
performance. Furthermore, we demonstrate that it is not straightforward to control calibration via
such indirect techniques. To remedy this issue, we propose techniques that are directed to explicitly
improve calibration, by leveraging approximate Bayesian techniques, such as approximate Bayesian
neural networks (Blundell et al., 2015) and weight-ensembling approaches (Izmailov et al., 2018).
Our modification forms a new family of SSL methods that improve upon the SOTA on both standard
benchmarks and real-world applications. Our contributions are summarized as follows:
1. Using SOTA SSL methods as case studies, we empirically show that maintaining good calibration is strongly correlated to better model performance in SSL.
2. We propose to use approximate Bayesian techniques to directly improve calibration and provide theoretical results on generalization bounds for SSL to motivate our approach.
3. We introduce a new family of methods that improves calibration via Bayesian model averaging and weight-averaging techniques and demonstrate their improvements upon a variety of SOTA SSL methods on standard benchmarks including CIFAR-10, CIFAR-100 and ImageNet, notably giving up to 15.9% gains in test accuracy.
4. We further demonstrate the efficacy of our proposed model calibration methods in more challenging and realistic scenarios, such as class-imbalanced datasets and a real-world application in photonic science.
2 RELATED WORK
Semi-supervised learning (SSL) and confirmation bias.
A fundamental problem in SSL methods
based on pseudo-labeling (Rosenberg et al., 2005) is that of confirmation bias (Tarvainen & Valpola,
2017; Murphy, 2022), i.e. the phenomenon where a model overfits to incorrect pseudo-labels. Several
strategies have emerged to tackle this problem: Guo et al. (2020) and Ren et al. (2020) looked
into weighting unlabeled samples, Thulasidasan et al. (2019) and Arazo et al. (2019) propose
augmentation strategies like MixUp (Zhang et al., 2017), while Cascante-Bonilla et al. (2020)
proposes to re-initialize the model before every iteration to overcome confirmation bias. Another
popular technique is to impose a selection metric (Yarowsky, 1995) to retain only the highest quality
pseudo-labels, commonly realized via a fixed threshold on the maximum class probability (Xie et al.,
2019a; Sohn et al., 2020). Recent works have further extended such selection metrics to be based
on dynamic thresholds, either in time (Xu et al., 2021) or class-wise (Zou et al., 2018; Zhang et al.,
2021). Different from the above approaches, our work proposes to overcome confirmation bias in
SSL by directly improving the calibration of the model through approximate Bayesian techniques.
Model calibration and uncertainty quantification.
Proper estimation of a network’s prediction
uncertainty is of practical importance (Amodei et al., 2016) and has been widely studied. A common
approach to improve uncertainty estimates is via Bayesian marginalization (Wilson & Izmailov,
2020), i.e. by weighting solutions by their posterior probabilities. Since exact Bayesian inference
is computationally intractable for neural networks, a series of approximate Bayesian methods have
emerged, such as variational methods (Graves, 2011; Blundell et al., 2015; Kingma et al., 2015),
Hamiltonian methods (Springenberg et al., 2016) and Langevin diffusion methods (Welling & Teh, 2011).
Other methods for achieving Bayesian marginalization also exist, such as deep ensembles (Lakshminarayanan et al., 2016)
and efficient versions of them (Wen et al., 2020; Gal & Ghahramani, 2015), which have been
empirically shown to improve uncertainty quantification. The concepts of uncertainty and calibration
are inherently related, with calibration commonly interpreted as the frequentist notion of uncertainty.
In our work, we adopt some of these techniques specifically in the context of semi-supervised learning
in order to improve model calibration during pseudo-labeling. While other methods for improving model
calibration exist (Platt, 1999; Zadrozny & Elkan, 2002; Guo et al., 2017), these are most commonly
applied in a post-hoc manner using a held-out validation set; instead, we seek to improve calibration
during training and with a scarce set of labels. Finally, at the intersection of SSL and calibration,
Rizve et al. (2021) propose to leverage uncertainty to select a better calibrated subset of pseudo-labels.
Our work builds on a similar motivation; however, in addition to improving the selection metric with
uncertainty estimates, we further seek to directly improve calibration via Bayesian marginalization
(i.e. averaging predictions).
3 NOTATION AND BACKGROUND
Given a small amount of labeled data $L = \{(x_l, y_l)\}_{l=1}^{N_l}$ (here, $y_l \in \{0, 1\}^K$ are one-hot labels) and a large amount of unlabeled data $U = \{x_u\}_{u=1}^{N_u}$, i.e. $N_u \gg N_l$, in SSL we seek to perform a $K$-class classification task. Let $f(\cdot, \theta_f)$ be a backbone encoder (e.g. ResNet or WideResNet) with trainable parameters $\theta_f$, let $h(\cdot, \theta_h)$ be a linear classification head, and let $H$ denote the standard cross-entropy loss.
Threshold-mediated methods.
Threshold-mediated methods such as Pseudo-Labels (Lee, 2013), UDA (Xie et al., 2019a) and FixMatch (Sohn et al., 2020) minimize a cross-entropy loss on augmented copies of unlabeled samples whose confidence exceeds a pre-defined threshold. Let $\alpha_1$ and $\alpha_2$ denote two augmentation transformations and let the corresponding network predictions for sample $x$ be $q_1 = h \circ f(\alpha_1(x))$ and $q_2 = h \circ f(\alpha_2(x))$; the total loss on a batch of unlabeled data has the following form:

$$\mathcal{L}_u = \frac{1}{\mu B} \sum_{u=1}^{\mu B} \mathbb{1}\big(\max(q_{1,u}) \geq \tau\big)\, H\big(\rho_t(q_{1,u}), q_{2,u}\big) \qquad (1)$$

where $B$ denotes the batch size of labeled examples, $\mu$ is a scaling hyperparameter for the unlabeled batch size, $\tau \in [0, 1]$ is a threshold parameter often set close to 1, and $\rho_t$ is either a sharpening operation on the pseudo-labels, i.e. $[\rho_t(q)]_k := [q]_k^{1/t} / \sum_{c=1}^{K} [q]_c^{1/t}$, or an argmax operation (i.e. $t \to 0$). $\rho_t$ also implicitly includes a "stop-gradient" operation, i.e. gradients are not back-propagated through the pseudo-label predictions. $\mathcal{L}_u$ is combined with the expected cross-entropy loss on labeled examples, $\mathcal{L}_l = \frac{1}{B} \sum_{l=1}^{B} H(y_l, q_{1,l})$, to form the combined loss $\mathcal{L}_l + \lambda \mathcal{L}_u$, with hyperparameter $\lambda$. Differences between Pseudo-Labels, UDA and FixMatch are detailed in Appendix C.1.
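To make Eq. (1) concrete, the following is a minimal PyTorch-style sketch of the threshold-mediated unlabeled loss. It is an illustrative sketch rather than the reference implementation of any of the cited methods: we assume a generic `model` that maps augmented images to class logits, and names such as `sharpen`, `unlabeled_loss`, `tau` and `t` are our own.

```python
import torch
import torch.nn.functional as F

def sharpen(probs, t):
    """Sharpening rho_t: raise probabilities to the power 1/t and renormalize."""
    sharpened = probs ** (1.0 / t)
    return sharpened / sharpened.sum(dim=-1, keepdim=True)

def unlabeled_loss(model, x_view1, x_view2, tau=0.95, t=0.5, use_argmax=True):
    """Threshold-mediated unlabeled loss of Eq. (1).

    x_view1, x_view2: two differently augmented views (alpha_1, alpha_2) of the
    same unlabeled batch; pseudo-labels come from the first view and supervise
    the second.
    """
    with torch.no_grad():                          # "stop-gradient" on pseudo-labels
        q1 = F.softmax(model(x_view1), dim=-1)     # q_1 = h(f(alpha_1(x)))
        conf, cls = q1.max(dim=-1)
        mask = (conf >= tau).float()               # 1(max(q_1) >= tau)
        pseudo = F.one_hot(cls, q1.shape[-1]).float() if use_argmax else sharpen(q1, t)

    logits2 = model(x_view2)                       # q_2 = h(f(alpha_2(x)))
    per_sample = -(pseudo * F.log_softmax(logits2, dim=-1)).sum(dim=-1)  # H(rho_t(q_1), q_2)
    return (mask * per_sample).mean()
```

As a design note, FixMatch pairs a weakly augmented view for the pseudo-label with a strongly augmented view for the prediction being supervised, and uses the hard (argmax) pseudo-label.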
Representation learning based methods.
We use PAWS (Assran et al., 2021) as a canonical example of this family. A key difference from threshold-mediated methods is the lack of the parametric classifier $h$, which is replaced by a non-parametric soft nearest-neighbour classifier ($\pi_d$) based on a labeled support set $\{z_s\}_{s=1}^{B}$. Let $z_1 = f(\alpha_1(x))$ and $z_2 = f(\alpha_2(x))$ be the representations of the two views from the backbone encoder; their pseudo-labels ($q_1$, $q_2$) and the unlabeled loss are given by:

$$q_i = \pi_d(z_i, \{z_s\}) = \sum_{s=1}^{B} \frac{d(z_i, z_s) \cdot y_s}{\sum_{s'=1}^{B} d(z_i, z_{s'})}; \qquad \mathcal{L}_u = \frac{1}{2\mu B} \sum_{u=1}^{\mu B} \Big[ H\big(\rho_t(q_{1,u}), q_{2,u}\big) + H\big(\rho_t(q_{2,u}), q_{1,u}\big) \Big] \qquad (2)$$

where $d(a, b) = \exp\!\big(a \cdot b / (\|a\| \|b\| \tau_p)\big)$ is a similarity metric with temperature hyperparameter $\tau_p$ and all other symbols have the meanings defined before. The combined loss is $\mathcal{L}_u + \mathcal{L}_{\text{me-max}}$, where the latter is a regularization term $\mathcal{L}_{\text{me-max}} = -H(\bar{q})$ that seeks to maximize the entropy of the average prediction $\bar{q} := \frac{1}{2\mu B} \sum_{u=1}^{\mu B} \big(\rho_t(q_{1,u}) + \rho_t(q_{2,u})\big)$.
Figure 1: Empirical study on CIFAR-100 with 4000 labels. (a) FixMatch: test accuracy (%) against ECE (lower ECE is better calibration) when varying the threshold $\tau$ (value shown beside each scatter point). (b) PAWS: effect of various parameters on test accuracy and ECE; the scatter plot explores a wider range of parameters, Temp. $\in [0.03, 0.1]$ and Sharpen $\in [0.1, 0.5]$.

Calibration metrics.
A popular empirical metric for measuring a model's calibration is the Expected Calibration Error (ECE). Following Guo et al. (2017) and Minderer et al. (2021), we focus on a slightly weaker condition and consider only the model's most likely class prediction, which can be computed as follows. Let $\rho_0(q)$ denote the model's predicted class (where $\rho_0$ is the argmax operation) as defined before. The model's predictions on a batch of $N$ samples are grouped into $M$ equal-interval confidence bins, i.e. $B_m$ contains the set of samples whose confidence $\max(q)$ falls in $\big(\frac{m-1}{M}, \frac{m}{M}\big]$. ECE is then computed as the expected difference between the accuracy and confidence of each bin over all $N$ samples:

$$\mathrm{ECE} = \sum_{m=1}^{M} \frac{|B_m|}{N} \big| \mathrm{acc}(B_m) - \mathrm{conf}(B_m) \big| \qquad (3)$$

where $\mathrm{acc}(B_m) = \frac{1}{|B_m|} \sum_{i \in B_m} \mathbb{1}\big(\rho_0(q_i) = y_i\big)$ and $\mathrm{conf}(B_m) = \frac{1}{|B_m|} \sum_{i \in B_m} \max q_i$, with $y_i$ the true label of sample $i$. In this work, we estimate ECE using $M = 10$ bins. We also caveat here that while ECE is not free from biases (Minderer et al., 2021), we chose ECE over alternatives (Brier, 1950; DeGroot & Fienberg, 1983) due to its simplicity and widespread adoption.
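For concreteness, the ECE estimator of Eq. (3) with $M$ equal-width confidence bins can be computed as in the following NumPy sketch (a minimal illustration; the function name is our own).

```python
import numpy as np

def expected_calibration_error(probs, labels, n_bins=10):
    """ECE of Eq. (3). probs: (N, K) predicted probabilities, labels: (N,) true classes."""
    confidences = probs.max(axis=1)           # conf_i = max_k [q_i]_k
    predictions = probs.argmax(axis=1)        # rho_0(q_i)
    accuracies = (predictions == labels).astype(float)

    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece, n = 0.0, len(labels)
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)   # B_m
        if in_bin.any():
            acc = accuracies[in_bin].mean()                  # acc(B_m)
            conf = confidences[in_bin].mean()                # conf(B_m)
            ece += (in_bin.sum() / n) * abs(acc - conf)
    return ece
```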
4 CALIBRATION IN SEMI-SUPERVISED LEARNING
Better calibration is correlated with better performance.
To motivate our work, we first perform detailed ablation studies on SOTA SSL methods. In particular, we use FixMatch (Sohn et al., 2020) and PAWS (Assran et al., 2021), one from each of the two families, due to their strong performance. FixMatch relies on a high value of the selection threshold $\tau$ to mitigate confirmation bias by accepting only the most credible pseudo-labels; PAWS, on the other hand, uses multiple techniques that appear to shape the output prediction distribution and implicitly control calibration: label smoothing (Müller et al., 2019), mean-entropy maximization (Joulin & Bach, 2012), sharpening (Berthelot et al., 2019b;a; Xie et al., 2019a) and temperature scaling (Guo et al., 2017). In Fig. 1, we ablate all of these parameters and observe an overall trend, for both methods, that model performance is strongly correlated with better calibration (i.e. lower ECE). However, these trends are inevitably noisy, since none of these parameters tunes for calibration in isolation; they are instead optimized towards performance. Therefore, in this work, we aim to explore techniques that predominantly adjust for model calibration in these methods, to clearly demonstrate the direct effect of improving calibration on model accuracy.
4.1 IMPROVING CALIBRATION WITH BAYESIAN MODEL AVERAGING
Bayesian techniques are widely known to produce well-calibrated uncertainty estimates (Wilson & Izmailov, 2020); thus, in our work, we explore the use of approximate Bayesian Neural Networks (BNNs). To minimize the additional computational overhead, we propose to replace only the final layer of the network with a BNN layer. For threshold-mediated methods this is the linear classification head $h$, and for representation learning methods this is the final linear layer of the projection network (i.e. immediately before the representations). For brevity, we denote this BNN layer by $h$ and an input embedding to this layer by $v$ in this section. Following a Bayesian approach, we first assume a prior distribution on weights $p(\theta_h)$. Given some training data $D_X := (X, Y)$, we seek to calculate the posterior distribution of weights, $p(\theta_h \mid D_X)$, which can then be used to derive the posterior predictive $p(y \mid v, D_X) = \int p(y \mid v, \theta_h)\, p(\theta_h \mid D_X)\, d\theta_h$. This process is also known as "Bayesian model averaging" or "Bayesian marginalization" (Wilson & Izmailov, 2020). Since exact Bayesian inference is computationally intractable for neural networks, we adopt a variational approach following Blundell et al. (2015), where we learn a Gaussian variational approximation to the posterior, $q(\theta_h \mid \phi)$, parameterized by $\phi$, by maximizing the evidence lower bound (ELBO) (see Appendix B.1 for details). The $\mathrm{ELBO} = \mathbb{E}_{q}\big[\log p(Y \mid X; \theta)\big] - \mathrm{KL}\big(q(\theta \mid \phi)\,\|\,p(\theta)\big)$ consists of a log-likelihood (data-dependent) term and a KL (prior-dependent) term.
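A minimal sketch of such a variational final layer, in the spirit of Blundell et al. (2015), is given below. It assumes a Gaussian mean-field posterior and a standard-normal prior; the class name, initialization constants and softplus parameterization of the standard deviation are our own illustrative choices rather than the exact configuration used in this work.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BayesianLinear(nn.Module):
    """Linear layer with a Gaussian mean-field posterior q(theta_h | phi) over its weights."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.w_mu = nn.Parameter(0.1 * torch.randn(out_features, in_features))
        self.w_rho = nn.Parameter(torch.full((out_features, in_features), -5.0))
        self.b_mu = nn.Parameter(torch.zeros(out_features))
        self.b_rho = nn.Parameter(torch.full((out_features,), -5.0))

    def forward(self, v):
        # Reparameterized draw theta_h^(m) ~ q(theta_h | phi); a fresh sample per call.
        w_sigma, b_sigma = F.softplus(self.w_rho), F.softplus(self.b_rho)
        w = self.w_mu + w_sigma * torch.randn_like(w_sigma)
        b = self.b_mu + b_sigma * torch.randn_like(b_sigma)
        return F.linear(v, w, b)

    def kl_to_standard_normal(self):
        """Closed-form KL(q(theta | phi) || N(0, I)), the prior-dependent ELBO term."""
        def kl(mu, sigma):
            return 0.5 * (sigma ** 2 + mu ** 2 - 1.0 - 2.0 * torch.log(sigma)).sum()
        return kl(self.w_mu, F.softplus(self.w_rho)) + kl(self.b_mu, F.softplus(self.b_rho))
```

Training then maximizes the ELBO by minimizing the cross-entropy (the negative log-likelihood term) plus a suitably weighted `kl_to_standard_normal()` term.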
Theoretical results.
We motivate our approach using Corollary 1 below, derived from the PAC-Bayes framework (see Appendix A for the proof). The statement gives a generalization error bound for the variational posterior in the SSL setting and shows that this generalization error is upper bounded by the negative ELBO. Thus, by maximizing the ELBO we improve generalization, since we minimize an upper bound on the generalization error. The second term on the right side of the inequality characterizes the SSL setting and vanishes in the supervised setting.
Corollary 1. Let $\mathcal{D}$ be a data distribution from which i.i.d. training samples are drawn, of which $N_l$ are labeled, $(x, y) \sim \mathcal{D}^{N_l}$, and $N_u$ are unlabeled, $(x, \hat{y}) \sim \mathcal{D}^{N_u}$, where $\hat{y}_i$ ($y_i$) denotes the model-assigned pseudo-label (true label) for input $x_i$. For the negative log-likelihood loss function $\ell$, assuming that $\ell$ is sub-Gaussian (see Germain et al. (2016)) with variance factor $s^2$, then with probability at least $1 - \delta$ the generalization error of the variational posterior $q(\theta \mid \phi)$ satisfies

$$\mathbb{E}_{q(\theta|\phi)}\big[\mathcal{L}^{\ell}_{\mathcal{D}}(q)\big] \;\leq\; -\frac{1}{N}\,\mathrm{ELBO} \;-\; \mathbb{E}_{q(\theta|\phi)}\!\left[\frac{1}{N_u} \sum_{i=1}^{N_u} \log \frac{p(\hat{y}_i \mid x_i; \theta)}{p(y_i \mid x_i; \theta)}\right] \;+\; \frac{1}{N}\log\frac{1}{\delta} \;+\; \frac{s^2}{2} \qquad (4)$$
Incorporating Bayesian model averaging involves two main modifications during pseudo-labeling: 1) $M$ weights of the BNN layer are sampled and predictions are derived from the Monte Carlo estimate of the posterior predictive, i.e. $\hat{q} = \frac{1}{M}\sum_{m=1}^{M} h(v, \theta_h^{(m)})$, and 2) the selection criterion, if present, is based on the variance at the predicted class $c = \arg\max_{c'} [\hat{q}]_{c'}$, i.e. $\sigma^2_c = \frac{1}{M}\sum_{m=1}^{M} \big[ h(v, \theta_h^{(m)}) - \hat{q} \big]^2_c$. This constitutes a better uncertainty measure than the maximum logit value used in threshold-mediated methods and is highly intuitive: if the model's prediction has a large variance, it is highly uncertain and the pseudo-label should not be accepted. In practice, as $\sigma^2_c$ decreases across training, we use a simple quantile $Q$ over the batch to define the threshold, so that pseudo-labels of samples with $\sigma^2_c < Q$ are accepted, with $Q$ a hyperparameter (see Appendix B.1 for pseudocode). In representation learning methods, where predictions are computed using a non-parametric nearest-neighbour classifier, we use the "Bayesian marginalized" versions of the representations, i.e. $q_i = \pi_d(\hat{z}_i, \{\hat{z}_s\})$ in Eq. (2), where $\hat{z} = \frac{1}{M}\sum_{m=1}^{M} h(v, \theta_h^{(m)})$.
Building on the two modifications detailed above, we introduce in this work a new family of SSL methods, named "BAM-X", which builds upon a SOTA SSL method X by incorporating approximate BAyesian Model averaging (BAM) during pseudo-labeling.
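The sketch below illustrates the two modifications for a threshold-mediated method: Monte Carlo averaging over $M$ draws from the BNN head and a variance-based quantile acceptance rule. It is a simplified sketch: `bnn_head` is assumed to be a stochastic layer such as the `BayesianLinear` sketch above, the averaging here is over softmax outputs (whether logits or probabilities are averaged is an implementation detail we leave open), and all names are our own.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def bam_pseudo_labels(backbone, bnn_head, x, num_samples=10, q_quantile=0.5):
    """Monte Carlo pseudo-labels and a variance-based acceptance mask.

    Returns (q_hat, mask): the averaged predictive distribution over M sampled
    heads, and a boolean mask accepting samples whose predicted-class variance
    falls below the batch quantile Q.
    """
    v = backbone(x)                                         # embedding fed to the BNN layer
    samples = torch.stack(
        [F.softmax(bnn_head(v), dim=-1) for _ in range(num_samples)]
    )                                                       # (M, N, K), one draw per theta_h^(m)

    q_hat = samples.mean(dim=0)                             # Monte Carlo posterior predictive
    cls = q_hat.argmax(dim=-1)                              # c = argmax_c' [q_hat]_c'
    var = samples.var(dim=0, unbiased=False)                # per-class variance over the M draws
    var_c = var.gather(1, cls.unsqueeze(1)).squeeze(1)      # sigma^2_c at the predicted class

    threshold = torch.quantile(var_c, q_quantile)           # batch quantile Q
    mask = var_c < threshold                                # accept only low-variance pseudo-labels
    return q_hat, mask
```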
4.2 IMPROVING CALIBRATION WITH WEIGHT AVERAGING TECHNIQUES
A BNN classifier has two desirable characteristics: 1) rather than a single prediction, multiple
predictions are averaged or “Bayesian marginalized” and 2) we obtain a better measure of uncertainty
(i.e. confidence estimates) from the variance across predictions which serves as a better selection
metric. SSL methods that do not use a selection metric, e.g. representation learning methods like
PAWS, cannot explicitly benefit from (2) and in these cases, instead of aggregating over just one
layer, one could seek to aggregate over the entire network. This can be achieved by ensembling
approaches (Lakshminarayanan et al., 2016; Wen et al., 2020), which have been highly successful
at improving uncertainty estimation; however, training multiple networks would add immense
computational overhead. In this work, we propose and explore weight averaging approaches such
as Stochastic Weight Averaging (Izmailov et al., 2018) (SWA) and Exponential Moving Averaging
(EMA) (Tarvainen & Valpola, 2017; He et al., 2020; Grill et al., 2020a) to improve calibration during
pseudo-labeling. Weight averaging differs from ensembling in that model weights are averaged
instead of predictions; however, the approximation of SWA to Fast Geometric Ensembles (Garipov
et al., 2018) has been justified by previous work (Izmailov et al., 2018). To the best of our knowledge,
the connection of EMA to ensembling has not been formally shown.
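As an illustration of the weight-averaging alternative, a minimal EMA sketch is given below; the decay value and class name are our own choices. SWA would instead maintain an equal-weight running average of checkpoints collected along the training trajectory (PyTorch's `torch.optim.swa_utils.AveragedModel` offers a ready-made utility for this).

```python
import copy
import torch

class EMAModel:
    """Maintains an exponential moving average (EMA) of a model's weights.

    The EMA copy, rather than the raw student weights, can then be used to
    produce pseudo-labels during training.
    """

    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.ema = copy.deepcopy(model).eval()
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # theta_ema <- decay * theta_ema + (1 - decay) * theta
        for p_ema, p in zip(self.ema.parameters(), model.parameters()):
            p_ema.mul_(self.decay).add_(p, alpha=1.0 - self.decay)
        # Buffers (e.g. BatchNorm running statistics) are copied directly.
        for b_ema, b in zip(self.ema.buffers(), model.buffers()):
            b_ema.copy_(b)
```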