
emerged, such as variational methods (Graves, 2011; Blundell et al., 2015; Kingma et al., 2015),
Hamiltonian methods (Springenberg et al., 2016) and Langevin diffusion methods (Welling & Teh,
2011). Other methods to achieve Bayesian marginalization also exist, such as deep ensembles (Lakshminarayanan et al., 2016) and efficient versions of them (Wen et al., 2020; Gal & Ghahramani, 2015), which have been empirically shown to improve uncertainty quantification. The concepts of uncertainty and calibration are inherently related, where calibration is commonly interpreted as the frequentist
notion of uncertainty. In our work, we will adopt some of these techniques specifically for the context
of semi-supervised learning in order to improve model calibration during pseudo-labeling. While
other methods for improving model calibration exist (Platt, 1999; Zadrozny & Elkan, 2002; Guo
et al., 2017), these are most commonly achieved in a post-hoc manner using a held-out validation
set; instead, we seek to improve calibration during training and with a scarce set of labels. Finally,
in the intersection of SSL and calibration, Rizve et al. (2021) propose to leverage uncertainty to select a better-calibrated subset of pseudo-labels. Our work builds on a similar motivation; however,
in addition to improving the selection metric with uncertainty estimates, we further seek to directly
improve calibration via Bayesian marginalization (i.e. averaging predictions).
3 NOTATION AND BACKGROUND
Given a small amount of labeled data $\mathcal{L} = \{(x_l, y_l)\}_{l=1}^{N_l}$ (here, $y_l \in \{0, 1\}^K$ are one-hot labels) and a large amount of unlabeled data $\mathcal{U} = \{x_u\}_{u=1}^{N_u}$, i.e. $N_u \gg N_l$, in SSL, we seek to perform a $K$-class classification task. Let $f(\cdot, \theta_f)$ be a backbone encoder (e.g. ResNet or WideResNet) with trainable parameters $\theta_f$, $h(\cdot, \theta_h)$ be a linear classification head, and $H$ denote the standard cross-entropy loss.
Threshold-mediated methods.
Threshold-mediated methods such as Pseudo-Labels (Lee, 2013),
UDA (Xie et al., 2019a) and FixMatch (Sohn et al., 2020) minimize a cross-entropy loss on augmented copies of unlabeled samples whose confidence exceeds a pre-defined threshold. Let $\alpha_1$ and $\alpha_2$ denote two augmentation transformations, and let the corresponding network predictions for sample $x$ be $q_1 = h \circ f(\alpha_1(x))$ and $q_2 = h \circ f(\alpha_2(x))$. The total loss on a batch of unlabeled data has the following form:
$$\mathcal{L}_u = \frac{1}{\mu B} \sum_{u=1}^{\mu B} \mathbb{1}\big(\max(q_{1,u}) \geq \tau\big) \, H\big(\rho_t(q_{1,u}), q_{2,u}\big) \tag{1}$$
where $B$ denotes the batch-size of labeled examples, $\mu$ a scaling hyperparameter for the unlabeled batch-size, $\tau \in [0, 1]$ is a threshold parameter often set close to 1, and $\rho_t$ is either a sharpening operation on the pseudo-labels, i.e. $[\rho_t(q)]_k := [q]_k^{1/t} / \sum_{c=1}^{K} [q]_c^{1/t}$, or an $\arg\max$ operation (i.e. $t \to 0$). $\rho_t$ also implicitly includes a “stop-gradient” operation, i.e. gradients are not back-propagated from predictions of pseudo-labels. $\mathcal{L}_u$ is combined with the expected cross-entropy loss on labeled examples, $\mathcal{L}_l = \frac{1}{B} \sum_{l=1}^{B} H(y_l, q_{1,l})$, to form the combined loss $\mathcal{L}_l + \lambda \mathcal{L}_u$, with hyperparameter $\lambda$. Differences between Pseudo-Labels, UDA and FixMatch are detailed in Appendix C.1.
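For concreteness, the following is a minimal PyTorch-style sketch of the unlabeled loss in Eq. (1). The function name, tensor layout, and the default values of the threshold $\tau$ and temperature $t$ are illustrative assumptions, not the exact implementation of any of the cited methods.

```python
import torch
import torch.nn.functional as F

def threshold_pseudo_label_loss(logits_w, logits_s, tau=0.95, t=0.0):
    """Sketch of Eq. (1): pseudo-labels from the first view (logits_w, shape [mu*B, K])
    supervise the second view (logits_s) when confidence exceeds tau."""
    with torch.no_grad():                               # "stop-gradient" on pseudo-labels
        q1 = F.softmax(logits_w, dim=-1)                # predictions q_{1,u}
        mask = (q1.max(dim=-1).values >= tau).float()   # 1(max(q_{1,u}) >= tau)
        if t > 0:                                       # sharpening rho_t
            q1 = q1 ** (1.0 / t)
            target = q1 / q1.sum(dim=-1, keepdim=True)
        else:                                           # argmax, i.e. t -> 0 (hard labels)
            target = F.one_hot(q1.argmax(dim=-1), q1.shape[-1]).float()
    # cross-entropy H(rho_t(q_{1,u}), q_{2,u}) against the second view
    ce = -(target * F.log_softmax(logits_s, dim=-1)).sum(dim=-1)
    return (mask * ce).mean()                           # average over the mu*B unlabeled samples
```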
Representation learning based methods.
We use PAWS (Assran et al., 2021) as a canonical example for this family. A key difference from threshold-mediated methods is the lack of the parametric classifier $h$, which is replaced by a non-parametric soft nearest-neighbour classifier ($\pi_d$) based on a labeled support set $\{z_s\}_{s=1}^{B}$. Let $z_1 = f(\alpha_1(x))$ and $z_2 = f(\alpha_2(x))$ be the representations of the two views from the backbone encoder; their pseudo-labels ($q_1$, $q_2$) and the unlabeled loss are given by:
$$q_i = \pi_d(z_i, \{z_s\}) = \sum_{s=1}^{B} \frac{d(z_i, z_s) \cdot y_s}{\sum_{s'=1}^{B} d(z_i, z_{s'})}; \qquad \mathcal{L}_u = \frac{1}{2 \mu B} \sum_{u=1}^{\mu B} \Big[ H\big(\rho_t(q_{1,u}), q_{2,u}\big) + H\big(\rho_t(q_{2,u}), q_{1,u}\big) \Big] \tag{2}$$
where $d(a, b) = \exp\big(a \cdot b / (\|a\| \|b\| \tau_p)\big)$ is a similarity metric with temperature hyperparameter $\tau_p$, and all other symbols have the same meanings defined before. The combined loss is $\mathcal{L}_u + \mathcal{L}_{\text{me-max}}$, where the latter is a regularization term $\mathcal{L}_{\text{me-max}} = -H(\bar{q})$ that seeks to maximize the entropy of the average of predictions $\bar{q} := \frac{1}{2 \mu B} \sum_{u=1}^{\mu B} \big(\rho_t(q_{1,u}) + \rho_t(q_{2,u})\big)$.
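As an illustration, below is a minimal PyTorch-style sketch of the soft nearest-neighbour pseudo-labeling step $\pi_d$ in Eq. (2). The function name, tensor shapes, and default temperature are illustrative assumptions and not PAWS's actual implementation.

```python
import torch
import torch.nn.functional as F

def soft_nn_pseudo_labels(z, z_support, y_support, tau_p=0.1):
    """Soft nearest-neighbour classifier pi_d of Eq. (2): a similarity-weighted
    average of support-set labels. Shapes: z [N, D], z_support [B, D],
    y_support [B, K] one-hot; tau_p is the temperature (value assumed)."""
    z = F.normalize(z, dim=-1)                          # unit-normalize so z . z_s is cosine similarity
    z_s = F.normalize(z_support, dim=-1)
    d = torch.exp(z @ z_s.T / tau_p)                    # d(z_i, z_s) = exp(z_i . z_s / (|z_i||z_s| tau_p))
    q = (d @ y_support) / d.sum(dim=-1, keepdim=True)   # similarity-weighted average of labels y_s
    return q                                            # [N, K] pseudo-label distributions q_i
```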
Calibration metrics.
A popular empirical metric for measuring a model's calibration is the Expected Calibration Error (ECE). Following (Guo et al., 2017; Minderer et al., 2021), we focus on