arXiv:2210.12196v1 [cs.LG] 21 Oct 2022
Augmentation by Counterfactual Explanation -
Fixing an Overconfident Classifier
Sumedha Singla
University of Pittsburgh
sumedha.singla@pitt.edu
Nihal Murali
University of Pittsburgh
nihal.murali@pitt.edu
Forough Arabshahi
Meta AI
forough@meta.com
Sofia Triantafyllou
University of Crete
sof.triantafillou@gmail.com
Kayhan Batmanghelich
University of Pittsburgh
kayhan@pitt.edu
Abstract
A highly accurate but overconfident model is ill-suited
for deployment in critical applications such as healthcare
and autonomous driving. The classification outcome should
reflect a high uncertainty on ambiguous in-distribution
samples that lie close to the decision boundary. The model
should also refrain from making overconfident decisions on
samples that lie far outside its training distribution, far-out-
of-distribution (far-OOD), or on unseen samples from novel
classes that lie near its training distribution (near-OOD).
This paper proposes an application of counterfactual expla-
nations in fixing an over-confident classifier. Specifically,
we propose to fine-tune a given pre-trained classifier using
augmentations from a counterfactual explainer (ACE) to fix
its uncertainty characteristics while retaining its predictive
performance. We perform extensive experiments with de-
tecting far-OOD, near-OOD, and ambiguous samples. Our
empirical results show that the revised model has improved
uncertainty measures, and its performance is competitive with
state-of-the-art methods.
1. Introduction
Deep neural networks (DNN) are increasingly being
used in decision-making pipelines for real-world high-
stake applications such as medical diagnostics [6] and au-
tonomous driving [7]. For optimal decision making, the
DNN should produce accurate predictions as well as quan-
tify uncertainty over its predictions [8, 38]. While substan-
tial efforts are made to engineer highly accurate architec-
tures [23], many existing state-of-the-art DNNs do not cap-
ture the uncertainty correctly [9].
We consider two types of uncertainty: epistemic uncertainty, caused by limited data and knowledge of the model, and aleatoric uncertainty, caused by inherent noise or ambiguity in the data [30]. We evaluate these uncertainties with respect to three test distributions (see Fig. 1):
Ambiguous in-Distribution (AiD): These are the
samples within the training distribution that have an
inherent ambiguity in their class labels. Such ambigu-
ity represents high aleatoric uncertainty arising from
class overlap or noise [60], e.g. an image of a ‘5’ that
is similar to a ‘6’.
Near-OOD: Near-OOD represents a label shift where
label space is different between ID and OOD data. It
has high epistemic uncertainty arising from the classi-
fier’s limited information on unseen data. We use sam-
ples from unseen classes of the training distribution as
near-OOD.
Far-OOD: Far-OOD represents data distribution that
is significantly different from the training distribution.
It has high epistemic uncertainty arising from mis-
match between different data distributions.
Much of the earlier work focuses on threshold-based de-
tectors that use information from a pre-trained DNN to iden-
tify OOD samples [15, 19, 24, 68, 21]. Such methods focus
on far-OOD detection and often do not address the overconfidence problem in DNNs. In another line of research, variants of Bayesian models [52, 9] and ensemble learning [22, 33] were explored to provide reliable uncertainty estimates. Recently, there has been a shift towards designing generalizable DNNs that provide robust uncertainty estimates in a single forward pass [65, 4, 48]. Such methods usu-
ally propose changes to the DNN architecture [62], training
procedure [71] or loss functions [50] to encourage separa-
tion between ID and OOD data.

[Figure 1: AFHQ (cat / dog) with near-OOD (wild) and far-OOD (CIFAR10); skin lesion (nv / bkl+mel) with near-OOD (other lesions) and far-OOD (CelebA); ambiguous iD (AiD) samples; columns A-D.]
Figure 1. Comparison of the uncertainty estimates from the baseline, before (dotted line) and after (solid line) fine-tuning with augmentation by counterfactual explanation (ACE). The plots visualize the distribution of predicted entropy (columns A-C) from the classifier and the density score from the discriminator (column D). The y-axis of each density plot is a probability density function whose value is meaningful only for relative comparisons between groups, summarized in the legend. A) visualizes the impact of fine-tuning on the in-distribution (iD) samples; a large overlap suggests minimal changes to the classification outcome for iD samples. The next columns visualize the change in distribution for ambiguous iD (AiD) (B) and near-OOD (C) samples; the peak of the distribution for AiD and near-OOD samples shifts right, assigning higher uncertainty and reducing overlap with iD samples. D) compares the density score from the discriminator for iD (blue solid) and far-OOD (orange solid) samples; the overlap between the distributions is minimal, resulting in a high AUC-ROC for binary classification between uncertain and iD samples. Our method improves the uncertainty estimates across the spectrum.

Popular methods include
training deterministic DNN with a distance-aware feature
space [66, 42] and regularizing DNN training with a gener-
ative model that simulates OOD data [36]. However, these
methods require a DNN model to be trained from scratch
and are not compatible with an existing pre-trained DNN.
Also, they may use auxiliary data to learn to distinguish
OOD inputs [43].
Most of the DNN-based classification models are trained
to improve accuracy on a test set. Accuracy only captures
the proportion of samples that are on the correct side of the
decision boundary. However, it ignores the relative distance
of a sample from the decision boundary [31]. Ideally, sam-
ples closer to the boundary should have high uncertainty.
The actual predicted value from the classifier should reflect
this uncertainty via a low confidence score [25]. Conven-
tionally, DNNs are trained on hard-label datasets to min-
imize a negative log-likelihood (NLL) loss. Such mod-
els tend to over-saturate on NLL and end up learning very
sharp decision boundaries [16, 49]. The resulting classifiers
extrapolate over-confidently on ambiguous, near-boundary
samples, and the problem amplifies as we move to OOD
regions [8].
In this paper, we propose to mitigate the overconfidence
problem of a pre-trained DNN by fine-tuning it with aug-
mentations derived from a counterfactual explainer (ACE).
We derive counterfactuals using a progressive counterfactual explainer (PCE) that creates a series of perturbations of an input image such that the classification decision is changed to a different class [58, 34]. The PCE is trained to generate on-manifold samples in the regions between the classes. These samples, along with soft labels that mimic their distance from the decision boundary, are used to fine-tune the classifier. We hypothesize that fine-tuning on such data would broaden the classifier's decision boundary. Our empirical results show that the fine-tuned classifier exhibits better uncertainty quantification over ambiguous-iD and OOD samples. Our contributions are as follows: (1) We present a novel strategy to fine-tune an existing pre-trained DNN using ACE to improve its uncertainty estimates. (2) We propose a refined architecture for generating counterfactual explanations that accounts for a continuous condition and multiple target classes. (3) We use the discriminator of our GAN-based counterfactual explainer as a selection function to reject far-OOD samples. (4) The fine-tuned classifier with a rejection head successfully captures uncertainty over ambiguous-iD and OOD samples, and also exhibits better robustness to popular adversarial attacks.
2. Method
In this paper, we consider a pre-trained DNN classifier, fθ, with good prediction accuracy but sub-optimal uncertainty estimates. We assume fθ is a differentiable function and that we have access to its gradient with respect to the input, ∇_x fθ(x), and to its final prediction outcome fθ(x). We also assume access to either the training data for fθ, or an equivalent dataset with competitive prediction accuracy. We further assume that the training dataset for fθ has hard labels {0, 1} for all the classes.
[Figure 2 schematic: (a) Progressive Counterfactual Explainer (PCE) with generator G, discriminator D, and the fixed pre-trained classifier fθ, with condition sampled as c_i ∼ Uniform(0, 1); (b) Augmentation by Counterfactual Explanation (ACE); (c) Fine-tuning with ACE; (d) Improved classifier fθ+∆(·) with a reject option, using a copy of D to abstain.]
Figure 2. (a) Given a pre-trained classifier fθ, we learn a cGAN-based progressive counterfactual explainer (PCE), G(x, c), while keeping fθ fixed. (b) The trained PCE creates counterfactually augmented data. (c) A combination of original training data and augmented data is used to fine-tune the classifier, fθ+∆. (d) The discriminator from the PCE serves as a selection function to detect and reject OOD data.
Our goal is to improve the pre-trained classifier fθ such that the revised model provides better uncertainty estimates, while retaining its original predictive accuracy. To enable this, we follow a two-step approach. First, we fine-tune fθ on counterfactually augmented data. The fine-tuning helps in widening the classification boundary of fθ, resulting in improved uncertainty estimates on ambiguous and near-OOD samples. Second, we use a density estimator to identify and reject far-OOD samples.
We adapted the previously proposed PCE [58] to generate counterfactually augmented data. We improved the existing implementation of the PCE by adopting a StyleGANv2-based backbone for the conditional GAN (cGAN) in the PCE. This allows using the continuous vector fθ(x) as the condition for conditional generation. Further, we use the discriminator of the cGAN as a selection function so that the revised fθ+∆ abstains from making predictions on far-OOD samples (see Fig. 2).
Notation: The classification function is defined as fθ : R^d → R^K, where θ represents the model parameters. The training dataset for fθ is defined as D = {X, Y}, where x ∈ X represents the input space and y ∈ Y = {1, 2, · · · , K} is the label set over K classes. The classifier produces point estimates that approximate the posterior probability P(y|x, D).
2.1. Progressive Counterfactual Explainer (PCE)
We designed the PCE network to take a query image (x ∈ R^d) and a desired classification outcome (c ∈ R^K) as input, and create a perturbation of the query image (x̂) such that fθ(x̂) ≈ c. Our formulation, x̂ = G(x, c), allows us to use c to traverse the decision boundary of fθ from the original class to a counterfactual class. Following previous work [34, 58, 59], we design the PCE to satisfy the following three properties:
1. Data consistency: The perturbed image x̂ should be realistic and should resemble samples in X.
2. Classifier consistency: The perturbed image x̂ should produce the desired output from the classifier fθ, i.e., fθ(G(x, c)) ≈ c.
3. Self consistency: Using the original classification decision fθ(x) as the condition, the PCE should produce a perturbation that is very similar to the query image, i.e., G(G(x, c), fθ(x)) = x and G(x, fθ(x)) = x.
Data Consistency: We formulate the PCE as a cGAN that learns the underlying data distribution of the input space X without an explicit likelihood assumption. The GAN model comprises two networks: the generator G(·) and the discriminator D(·). G(·) learns to generate fake data, while D(·) is trained to distinguish between real and fake samples. We jointly train G and D to optimize the following logistic adversarial loss [12],

L_adv(D, G) = E_x[log D(x) + log(1 − D(G(x, c)))]    (1)
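As an illustration only, a minimal PyTorch-style sketch of this logistic adversarial objective is given below; the network interfaces (D returning a real/fake logit, G(x, c) returning an image) and the non-saturating generator form are our assumptions, not the authors' released code.

import torch.nn.functional as F

def adversarial_losses(D, G, x_real, c):
    # Sketch of Eq. 1 for one batch: D(x) returns a raw logit, G(x, c) a generated image.
    x_fake = G(x_real, c)
    # Discriminator: raise logits on real samples, lower them on (detached) fake samples.
    d_loss = F.softplus(-D(x_real)).mean() + F.softplus(D(x_fake.detach())).mean()
    # Generator: non-saturating logistic loss, tries to fool the discriminator.
    g_loss = F.softplus(-D(x_fake)).mean()
    return d_loss, g_loss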
The earlier implementations of the PCE [58] have a hard constraint of representing the condition c as discrete variables, whereas fθ(x) is a continuous variable in the range [0, 1]. We adapted StyleGANv2 [1] as the backbone of the cGAN. This formulation allows us to use c ∈ R^K as the condition.
We formulate the generator as G(x, c) = g(e(x), c), a composite of two functions: an image encoder e(·) and a conditional decoder g(·) [1]. The encoder function e : X → W+ learns a mapping from the input space X to an extended latent space W+. The detailed architecture is provided in Fig. 3. Further, we also extended the discriminator network D(·) to have auxiliary information from the classifier fθ. Specifically, we concatenate the penultimate activations of fθ(x) with the penultimate activations of D(x) to obtain a revised representation before the final fully-connected layer of the discriminator. The detailed architecture is summarized in the supplementary material (SM).
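To make the auxiliary-information pathway concrete, a sketch of such a discriminator head is shown below; the feature dimensions, module names, and the assumption that both backbones return flat penultimate features are illustrative, not the exact architecture from the SM.

import torch
import torch.nn as nn

class AuxDiscriminator(nn.Module):
    # Concatenates the discriminator's penultimate features with the frozen
    # classifier's penultimate activations before the final real/fake layer.
    def __init__(self, disc_backbone, clf_features, d_dim=512, f_dim=512):
        super().__init__()
        self.disc_backbone = disc_backbone      # image -> (B, d_dim)
        self.clf_features = clf_features        # f_theta up to its penultimate layer (kept frozen)
        for p in self.clf_features.parameters():
            p.requires_grad_(False)
        self.final = nn.Linear(d_dim + f_dim, 1)

    def forward(self, x):
        d = self.disc_backbone(x)
        with torch.no_grad():
            f = self.clf_features(x)
        return self.final(torch.cat([d, f], dim=1))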
[Figure 3 schematic: encoder, affine transformation, style space, constant input, and discriminator D.]
Figure 3. PCE: The encoder-decoder architecture to create counterfactual augmentations for a given query image. ACE: Given a query image, the trained PCE generates a series of perturbations that gradually traverse the decision boundary of fθ from the original class to a counterfactual class, while still remaining plausible and realistic-looking.
We also borrow the concept of path-length regularization, L_reg(G), from StyleGANv2 to enforce smoother latent-space interpolations for the generator: L_reg(G) = E_{w ∼ e(x), x ∼ X} (||J_w^T x||_2 − a)^2, where x denotes random images from the training data, J_w is the Jacobian matrix, and a is a constant that is set dynamically during optimization.
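In common StyleGANv2 implementations this penalty is computed with a Jacobian-vector product rather than an explicit Jacobian; the sketch below follows that pattern and is an assumption about how L_reg could be realized here (with a tracked as an exponential moving average), not the authors' exact code.

import math
import torch

def path_length_penalty(fake_img, w, pl_mean, decay=0.01):
    # fake_img: generator output that depends on the latent codes w (requires grad).
    # pl_mean:  running 0-dim tensor estimate of the constant a, updated in place.
    _, _, h, width = fake_img.shape
    y = torch.randn_like(fake_img) / math.sqrt(h * width)
    # The gradient of <fake_img, y> with respect to w equals J_w^T y.
    (jtv,) = torch.autograd.grad((fake_img * y).sum(), w, create_graph=True)
    lengths = jtv.flatten(1).norm(dim=1)
    pl_mean.data.lerp_(lengths.detach().mean(), decay)   # update the constant a
    return (lengths - pl_mean).pow(2).mean()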
Classifier consistency: By default, GAN training is independent of the classifier fθ. We add a classifier-consistency loss to regularize the generator and ensure that the actual classification outcome for the generated image x̂ is similar to the condition c used for generation. We enforce classifier consistency via a Kullback-Leibler (KL) divergence loss as follows [58, 59],

L_f(G) = D_KL(fθ(x̂) || c)    (2)
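A minimal sketch of this consistency term, assuming fθ returns class probabilities and c is the soft condition vector:

def classifier_consistency_loss(f_theta, x_hat, c, eps=1e-8):
    # D_KL( f_theta(x_hat) || c ), Eq. 2: the prediction on the counterfactual
    # should match the condition. x_hat and c are PyTorch tensors of shape (B, d) and (B, K).
    p = f_theta(x_hat).clamp_min(eps)
    q = c.clamp_min(eps)
    return (p * (p / q).log()).sum(dim=1).mean()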
Self consistency: We define the following reconstruction loss to regularize and constrain the generator to preserve maximum information between the original image x and its reconstruction x̄,

L(x, x̄) = ||x − x̄||_1 + ||e(x) − e(x̄)||_1    (3)

Here, the first term is an L1 distance between the input and the reconstructed image, and the second term is a style-reconstruction L1 loss adapted from StyleGANv2 [1]. We minimize this loss to satisfy the identity constraint on self-reconstruction using x̄_self = G(x, fθ(x)). We further ensure that the PCE learns a reversible perturbation by recovering the original image from a given perturbed image x̂ as x̄_cyclic = G(x̂, fθ(x)), where x̂ = G(x, c) for some condition c. Our final reconstruction loss is defined as,

L_rec(G) = L(x, x̄_self) + L(x, x̄_cyclic)    (4)
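For concreteness, the two reconstruction terms could be sketched as below; the use of mean L1 distances and the function signatures are our assumptions.

def l1_recon(x, x_bar, encoder):
    # Eq. 3: L1 distance between images plus L1 distance between their encoder embeddings.
    return (x - x_bar).abs().mean() + (encoder(x) - encoder(x_bar)).abs().mean()

def reconstruction_loss(G, encoder, f_theta, x, c):
    # Eq. 4: self reconstruction (condition = original prediction) plus cyclic reconstruction.
    cond = f_theta(x)
    x_self = G(x, cond)            # should reproduce x
    x_hat = G(x, c)                # counterfactual for some condition c
    x_cyclic = G(x_hat, cond)      # recover x from the perturbed image
    return l1_recon(x, x_self, encoder) + l1_recon(x, x_cyclic, encoder)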
Objective function: Finally, we trained our model in an end-to-end fashion to learn the parameters of the two networks, while keeping the classifier fθ fixed. Our overall objective function is

min_G max_D  λ_adv (L_adv(D, G) + L_reg(G)) + λ_f L_f(G) + λ_rec L_rec(G),    (5)

where the λ's are hyper-parameters that balance the loss terms.
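Putting the terms together, one end-to-end training step could be sketched as follows; the helper functions are the sketches above, the λ values, optimizers, and the way the soft condition c is sampled are placeholders, and the path-length penalty is omitted for brevity.

import torch

def pce_training_step(G, D, encoder, f_theta, x, opt_g, opt_d,
                      lam_adv=1.0, lam_f=1.0, lam_rec=10.0):
    # One update of Eq. 5 with the classifier f_theta kept frozen throughout.
    num_classes = f_theta(x).size(1)
    c = torch.rand(x.size(0), num_classes, device=x.device)
    c = c / c.sum(dim=1, keepdim=True)      # simplified soft condition

    # Discriminator update.
    d_loss, _ = adversarial_losses(D, G, x, c)
    opt_d.zero_grad()
    (lam_adv * d_loss).backward()
    opt_d.step()

    # Generator update.
    _, g_adv = adversarial_losses(D, G, x, c)
    g_loss = (lam_adv * g_adv
              + lam_f * classifier_consistency_loss(f_theta, G(x, c), c)
              + lam_rec * reconstruction_loss(G, encoder, f_theta, x, c))
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()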
2.2. Augmentation by Counterfactual Explanation
Given a query image x, the trained PCE generates a series of perturbations of x that gradually traverse the decision boundary of fθ from the original class to a counterfactual class, while still remaining plausible and realistic-looking. We modify c to represent different steps in this traversal. We start from a high data-likelihood region for the original class k (c[k] ∈ [0.8, 1.0]), walk towards the decision hyper-plane (c[k] ∈ [0.5, 0.8)), and eventually cross the decision boundary (c[k] ∈ [0.2, 0.5)) to end the traversal in a high data-likelihood region for the counterfactual class k_c (c[k] ∈ [0.0, 0.2)). Accordingly, we set c[k_c] = 1 − c[k].
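To make the traversal concrete, the condition vectors for one query image could be constructed as in the sketch below; the specific step values and the helper name are illustrative, not the exact grid used in the paper.

import torch

def traversal_conditions(num_classes, k, k_c, steps=(0.9, 0.7, 0.4, 0.1)):
    # Walk c[k] from a high data-likelihood region for class k, across the decision
    # boundary, into the counterfactual class k_c, keeping c[k_c] = 1 - c[k].
    conditions = []
    for v in steps:
        c = torch.zeros(num_classes)
        c[k] = v
        c[k_c] = 1.0 - v
        conditions.append(c)
    return conditions

# Example: x_hats = [G(x, c.unsqueeze(0)) for c in traversal_conditions(K, k=0, k_c=1)]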
Ideally, the predicted confidence from the NN should be indicative of the distance from the decision boundary. Samples that lie close to the decision boundary should have low confidence, and confidence should increase as we move away from the decision boundary. We use c as a pseudo indicator of confidence to generate synthetic augmentations. Our augmentations essentially show how the query image x should be modified to have low or high confidence.
To generate counterfactual augmentations, we randomly sample a subset of real training data, X_r ⊂ X. Next, for each x ∈ X_r we generate multiple augmentations (x̂ = G(x, c)) by randomly sampling c[k] ∈ [0, 1]. We use c as the soft label for the generated sample while fine-tuning fθ. X_c denotes our pool of generated augmentation images. Finally, we create a new dataset by randomly sampling images from X and X_c. We fine-tune fθ on this new dataset, for only a few epochs, to obtain a revised classifier, fθ+∆. In our experiments, we show that the revised decision function fθ+∆ provides improved confidence estimates for AiD and near-OOD samples and demonstrates robustness to adversarial attacks, as compared to the given classifier fθ.
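A minimal sketch of this fine-tuning stage is given below; it assumes the classifier returns logits, that the augmentation pool is exposed as a loader of (x̂, c) pairs, and uses a KL term against the soft label c, which are our simplifications of the procedure described above.

import torch
import torch.nn.functional as F

def finetune_with_ace(f_theta, real_loader, aug_loader, epochs=2, lr=1e-4):
    # Fine-tune f_theta for a few epochs on a mix of real (hard-label) and
    # counterfactually augmented (soft-label) batches to obtain f_{theta+Delta}.
    opt = torch.optim.Adam(f_theta.parameters(), lr=lr)
    for _ in range(epochs):
        for (x, y), (x_hat, c) in zip(real_loader, aug_loader):
            loss = F.cross_entropy(f_theta(x), y) \
                 + F.kl_div(F.log_softmax(f_theta(x_hat), dim=1), c, reduction='batchmean')
            opt.zero_grad()
            loss.backward()
            opt.step()
    return f_theta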
2.3. Discriminator as a Selection Function
A selection function g : X → {0, 1} is an additional head on top of a classifier that decides when the classifier should abstain from making a prediction. We propose to use the discriminator network D(x) as a selection function for fθ. Upon convergence of the PCE training, the generated samples resemble the in-distribution training data. Far-OOD samples are previously unseen samples which are
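As a sketch of the resulting reject option, assuming the discriminator returns a scalar realness (density) score and a threshold τ chosen on held-out in-distribution data:

import torch

@torch.no_grad()
def predict_with_rejection(f_theta, D, x, tau):
    # Abstain whenever the discriminator's density score falls below tau (likely far-OOD).
    scores = D(x).squeeze(-1)          # higher score = more in-distribution
    preds = f_theta(x).argmax(dim=1)
    preds[scores < tau] = -1           # -1 denotes "abstain"
    return preds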