Deep Counterfactual Estimation with Categorical
Background Variables
Edward De Brouwer
ESAT-STADIUS
KU Leuven
edward.debrouwer@esat.kuleuven.be
Abstract
Referred to as the third rung of the causal inference ladder, counterfactual queries
typically ask the "What if ?" question retrospectively. The standard approach to
estimate counterfactuals resides in using a structural equation model that accurately
reflects the underlying data generating process. However, such models are seldom available in practice and one usually wishes to infer them from observational
data alone. Unfortunately, the correct structural equation model is in general not
identifiable from the observed factual distribution. Nevertheless, in this work, we
show that under the assumption that the main latent contributors to the treatment responses are categorical, the counterfactuals can still be reliably predicted. Building upon this assumption, we introduce CounterFactual Query Prediction (CFQP),
a novel method to infer counterfactuals from continuous observations when the
background variables are categorical. We show that our method significantly outperforms previously available deep-learning-based counterfactual methods, both
theoretically and empirically on time series and image data. Our code is available
at https://github.com/edebrouwer/cfqp.
1 Introduction
Counterfactual queries aim at inferring the impact of a treatment conditioned on another observed treatment outcome. Typically, given an individual, a treatment assignment, and a treatment outcome, the counterfactual question asks what would have happened to that individual had they been given another treatment, everything else being equal. An illustrative and motivating example is the case of clinical time series. Based on the observation of the outcome of treatment A on a particular patient, counterfactual queries ask what the outcome would have been for this patient, had they been given treatment B instead. Notably, counterfactual prediction differs from interventional prediction, which is also referred to as counterfactual potential outcomes [23] and constitutes the second rung of the causation ladder [19]. Counterfactual predictions are retrospective, as they condition on an observed treatment outcome. In contrast, interventional predictions are prospective, as they only condition on observations obtained before treatment assignment.
Much more than a statistical curiosity, counterfactual reasoning reflects complex cognitive abilities that are deeply ingrained in the human brain [22] and emerges in the early stages of cognitive development [7]. The ability to reason counterfactually can indeed help to identify causes of outcomes retrospectively, has been suggested to be central in the formation of rational intention, and supports key theories of human cognition [28]. The importance of counterfactual reasoning in the human cognitive process has thus motivated researchers to endow artificially intelligent systems with the same ability [21]. However, counterfactual inference is not possible from observational and interventional data alone [20].
36th Conference on Neural Information Processing Systems (NeurIPS 2022).
arXiv:2210.05811v4 [cs.LG] 16 Jan 2023
Counterfactual reasoning therefore requires making several assumptions to overcome this limitation. The most popular one assumes knowledge of the underlying structural equation model that describes the data generating process [21], or of a specific functional form thereof [1, 5, 21]. Unfortunately, this assumption is rarely met in practice, especially for high-dimensional data such as time series or images. This has led to the development of deep structural equation models that attempt to model the structural equations with neural networks [17, 24]. However, despite their ability to model high-dimensional data, these approaches fail to provide theoretical guarantees for the reconstruction of counterfactuals. Indeed, they focus on modeling the factual distribution which, without further assumptions, can unfortunately lead to erroneous counterfactual distributions.
In this work, we bridge the gap between classical structural equation model assumptions and deep-learning-based architectures. By assuming that the treatment and observables are continuous and that the hidden variables that contribute most to the treatment response are categorical, we can rely on recent results on the identifiability of mixture distributions [3] to show that we can approximately recover the counterfactuals using arbitrary parametric functions (i.e. deep neural networks) to model the causal dependence between variables. This allows us to infer counterfactuals on high-dimensional data such as time series and images. More generally, this work explores the assumptions that can lead to approximate counterfactual reconstructions while controlling the discrepancy between the recovered and the true counterfactual distributions.
Besides the general appeal of endowing machine learning architectures with counterfactual reasoning
abilities, an important motivation for our work is the counterfactual estimation of treatment effects
in clinical patient trajectories. In this motivational example, one wishes to predict the individual
treatment effect retrospectively. Based on the observed treatment outcome of a particular patient,
we want to predict what would have been the outcome under a different treatment assignment. The
ability to perform counterfactual inference on patient trajectories has indeed been identified as a
potential tool for improving long and costly randomized clinical trials [15].
Contributions
- We provide a new set of assumptions under which the counterfactuals are identifiable using arbitrary neural network architectures, bridging the gap between structural equation models and deep learning architectures.
- We derive a new counterfactual identifiability result that motivates a novel counterfactual reconstruction architecture.
- We evaluate our architecture on three different datasets with different high-dimensional modalities (images and time series) and demonstrate accurate counterfactual estimation.
2 Background
2.1 Problem Setup: Counterfactual Estimation
We consider the general causal model $M = \langle U, V, F \rangle$ depicted in Figure 1a, consisting of background variables $U$, endogenous variables $X$, $T$ and $Y$, and the set of structural functions $F$. Background variables $U = \{U_X, U_T, U, W\}$ are hidden exogenous random variables that determine the values of the observed variables $V = \{X, T, Y\}$. Covariates $X \in \mathcal{X}$ represent the information available before treatment assignment, $T \in \mathcal{T}$ is the treatment assignment, and $Y \in \mathcal{Y}$ is the observed response to the treatment. We refer to the space of probability measures on $\mathcal{Y}$ as $P(\mathcal{Y})$. Observed variables $V$ are generated following the structural equations $F = \{f_X, f_T, f_Y, f\}$, such that $X = f_X(U_X, W)$, $T = f_T(U_T, X)$, $U = f(W)$, $Y = f_Y(X, T, U)$. We further assume strong ignorability (i.e. no hidden confounders between $T$ and $Y$).
Using notations introduced in Pearl et al. [21], we define the potential response of a variable $Y$ to an action $do(T = t)$ for a particular realization of $U = u$ as $Y_t(u)$. Our goal is to predict the counterfactual response for a new treatment assignment ($T = t'$), conditioned on an observed initial treatment response. That is, the probability of observing a different treatment response under treatment $t'$, after observing treatment response $y$ for covariate $x$ and treatment $t$. The probability density function of the counterfactual $y'$ then writes:
$$p(Y_{t'} = y' \mid X = x, Y = y, T = t) = \frac{p(Y_{t'} = y', X = x, T = t, Y = y)}{p(X = x, T = t, Y = y)} = \int_u p(Y_{t'}(u) = y')\, p(U = u \mid X = x, T = t, Y = y)\, du \quad (1)$$
and we refer to the counterfactual probability measure as $\nu_{t'}(x, y, t)$. Equation 1 suggests a natural three-step procedure for computing the probability of the counterfactual. First, the abduction step infers the density of $U$ conditioned on the observed treatment outcomes, covariates and treatments: $p(U = u \mid X = x, T = t, Y = y)$. Second, in the action step, one sets the new treatment in the causal model ($do(T = t')$). Lastly, in the prediction step, one propagates the values of $U = u$ and $T = t'$ in the causal graph, using $F$, to compute $p(Y_{t'}(u) = y')$.
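To make this abduction-action-prediction recipe concrete, here is a minimal sketch on a toy structural model with a discrete background variable, so that abduction reduces to Bayes' rule over finitely many values of $U$. The response function, prior and noise scale are illustrative assumptions, not the model studied in this paper.

```python
import numpy as np

# Toy SCM with a categorical background variable U, so the abduction step
# is exact Bayes' rule. All functional forms here are illustrative
# assumptions, not the paper's model.
prior = np.array([0.5, 0.5])   # p(U = u)
sigma = 0.1                    # std of the additive Gaussian response noise

def f_y(x, t, u):
    # Hypothetical response function: the hidden group flips the effect.
    return x + (1.0 if u == 0 else -1.0) * t

def counterfactual(x, t, y, t_new):
    # 1) Abduction: p(U = u | X = x, T = t, Y = y) by Bayes' rule.
    lik = np.array([np.exp(-0.5 * ((y - f_y(x, t, u)) / sigma) ** 2)
                    for u in range(len(prior))])
    post = lik * prior
    post /= post.sum()
    # 2) Action: set do(T = t_new) in the causal model.
    # 3) Prediction: propagate each u through f_Y under the new treatment.
    outcomes = np.array([f_y(x, t_new, u) for u in range(len(prior))])
    return outcomes, post

# Observing y close to x - t identifies the hidden group as u = 1.
outcomes, post = counterfactual(x=1.0, t=2.0, y=-1.02, t_new=3.0)
print(post)      # posterior concentrates on u = 1
print(outcomes)  # counterfactual means [4.0, -2.0] under t' = 3.0
```

The counterfactual density of Equation 1 is then the posterior-weighted mixture of the per-$u$ response densities.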
In practice, we only have access to a set of $N$ observations of the variables $X$, $Y$ and $T$. We refer to this dataset as $D = (\mathbf{X}, \mathbf{T}, \mathbf{Y})$ where $\mathbf{X} = \{x_i : i = 1, \ldots, N\}$, $\mathbf{Y} = \{y_i : i = 1, \ldots, N\}$ and $\mathbf{T} = \{t_i : i = 1, \ldots, N\}$. Importantly, we do not have access to counterfactual examples (i.e. a tuple $(x, y, t, t', y')$), so directly learning a map $(x, y, t, t') \mapsto y_{t'}$ is excluded.
2.2 General Non-identifiability of Counterfactuals
Because the background variables $U$ are hidden, the above three-step procedure requires knowledge of the structural functions $F = \{f_X, f_T, f_Y\}$. Indeed, one can show that there exist multiple structural functions $F$ that lead to the same observed joint density $p(X, Y, T)$ but to incorrect counterfactual probabilities [19, 20]. The correct causal model is thus in general non-identifiable, leading to non-identifiability of the counterfactual probability. We specify what is meant by identifiability of counterfactuals in the following definition.
Definition 1 (Identifiability of Counterfactuals). Let $\rho$ be a metric on $P(\mathcal{Y})$, $\nu_{t'}(X, Y, T)$ the true counterfactual probability measure, and $\hat{\nu}_{t'}(X, Y, T)$ the estimator of the counterfactual probability measure with $N$ data points. Counterfactuals are $\rho$-identifiable at threshold $\delta$ if, for all $t, t' \in \mathcal{T}$, $x \in \mathcal{X}$, $y \in \mathcal{Y}$,
$$\lim_{N \to \infty} \rho(\nu_{t'}(x, y, t), \hat{\nu}_{t'}(x, y, t)) \le \delta.$$
[Figure 1: Graphical model representations of the causal model $M$. (a) General Bayesian network for the treatment counterfactual problem: $X$, $Y$ and $T$ are observed, while $U_X$, $U_T$, $W$ and $U$ are hidden background variables. (b) Bayesian network embodying the hidden categorical background variable assumption: $U$ is split into a background categorical variable $U_Z$ and a continuous background variable $U_\eta$. We assume strong ignorability and continuous treatments $T$, observables $X$ and responses $Y$.]
2.3 Causal Model Assumptions for Counterfactual Identifiability
Despite the general non-identifiability of structural equation models laid out above, we propose plausible assumptions that one can build upon to identify counterfactuals reliably. We first assume $X$ and $Y$ are continuous (potentially high-dimensional) variables (such as images or time series). The treatment assignment $T$ is also assumed continuous, and $\mathcal{X} \times \mathcal{T}$ is a connected space. Our first central assumption posits that the hidden variable $U$ factorizes into a categorical and a continuous variable.
Assumption 1 (Categorical Background Variable). The background variable $U$ decomposes into a categorical latent variable $U_Z \in [K] = \{1, \ldots, K\}$ and an independent exogenous continuous variable $U_\eta$.
This assumption is depicted in the graphical model of Figure 1b and embodies the intuition of different hidden groups that drive the treatment response. For instance, a treatment could have different responses depending on the stage of the disease a patient finds themselves in. The disease stage is unobserved yet correlated with the observed covariates $X$ (through $W$).
Due to the categorical nature of the variable $U_Z$, one can write the conditional density of $Y$ as a mixture model:
$$p(Y = y \mid X = x, T = t) = \sum_{u_Z \in \{1, \ldots, K\}} P(U_Z = u_Z) \int p(U_\eta = u_\eta)\, \mathbb{I}[f_Y(x, t, u_Z, u_\eta) = y]\, du_\eta$$
We define $\gamma$ as the probability density function of the conditional treatment response generated by $f_Y$, $U_Z$ and $U_\eta$. $\gamma$ is thus a mixture probability density function with mixture components¹ $\gamma_k \in P(\mathcal{Y})$ and weights $\omega_k$:
$$Y \mid X, T \sim \gamma(X, T) = \sum_{k=1}^{K} \omega_k\, \gamma_k(X, T) \quad (2)$$
Without loss of generality, we assume that $U_\eta \sim \mathcal{N}(0, \Sigma^2)$. In the case of additive noise, the conditional distribution of $Y$ becomes a mixture of Gaussians: $\gamma(X, T) = \sum_{k=1}^{K} \omega_k\, \mathcal{N}(\mu_k(X, T), \Sigma_k^2)$, where the $\mu_k$ are functions mapping $X$ and $T$ to the means of the mixture components, and we consider different variances for each $k$. We now proceed with the next assumptions.
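To illustrate Equation 2 in this additive-Gaussian case, the sketch below draws treatment responses from $\gamma(x, t) = \sum_k \omega_k\, \mathcal{N}(\mu_k(x, t), \Sigma_k^2)$: the hidden category $U_Z$ is sampled first, then the corresponding mean function is evaluated and Gaussian noise is added. The specific mean functions, weights and scales are assumptions made for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
weights = np.array([0.3, 0.7])   # omega_k = P(U_Z = k), illustrative
scales = np.array([0.2, 0.5])    # per-component noise scales Sigma_k

def mu(k, x, t):
    # Hypothetical continuous mean functions mu_k(x, t) (cf. Assumption 2).
    return (x + t) if k == 0 else (x - t)

def sample_conditional_response(x, t, n):
    # Draw the hidden category U_Z first, then add Gaussian noise U_eta.
    ks = rng.choice(len(weights), size=n, p=weights)
    means = np.array([mu(k, x, t) for k in ks])
    return means + rng.normal(0.0, scales[ks])

ys = sample_conditional_response(x=0.0, t=1.0, n=100_000)
# Mixture mean at (x, t) = (0, 1): 0.3 * (x + t) + 0.7 * (x - t) = -0.4.
print(ys.mean())
```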
Assumption 2 (Continuity). The moments of the probability density functions $\gamma_k(x, t)$ exist and are continuous functions of $X$ and $T$:
$$\mu_k^r(x, t) = \mathbb{E}_{Y \sim \gamma_k(x, t)}[Y^r] \in C(\mathcal{X} \times \mathcal{T}) \quad \forall r \in \mathbb{N},\; k \in [K] \quad (3)$$
Assumption 3 (Clusterability). For each $(x, t) \in (\mathcal{X} \times \mathcal{T})$, the density $\gamma(x, t)$ is clusterable and the expected deviation of each $\gamma_k(x, t)$ is bounded by a constant $\delta \in \mathbb{R}$. That is, $\forall k \in [K]$ and $\forall (x, t) \in (\mathcal{X} \times \mathcal{T})$, with $\mu_k(x, t) = \mathbb{E}_{Y \sim \gamma_k(x, t)}[Y]$:
$$\mathbb{E}_{Y \sim \gamma_k(x, t)} \left\| Y - \mu_k(x, t) \right\|_2 \le \delta \quad (4)$$
In our motivating clinical example, Assumption 2 reflects that the probability of a specific treatment response changes continuously over the set of observed covariates and treatments. In particular, the expected treatment outcome for a particular patient varies continuously with the treatment assignment, which is a common assumption, e.g. in clinical practice [14].
¹ The mixture components are defined such that for any subset $A \subset \mathcal{Y}$, we have $\int_A \gamma_k(X, T)(y)\, dy = \int \mathbb{I}[f_Y(X, T, U_Z = k, U_\eta) \in A]\, dP(U_\eta)$. The mixture weights are defined as $\omega_k = P(U_Z = k)$.
Assumption 3 posits clusterability of the mixture components $\gamma_k$, for which a rigorous mathematical definition is given in Appendix C. It is motivated by recent results on the identifiability of mixture models [3]. Intuitively, it supposes that patients with the same observed covariates and treatment assignment but different hidden groups will show different treatment outcomes. We also bound the expected deviation of the mixture components $\gamma_k$, which characterizes the inter-group variability in the treatment response for a particular patient and treatment outcome.
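To illustrate what the bound $\delta$ in Assumption 3 controls, the following sketch estimates the expected deviation $\mathbb{E}_{Y \sim \gamma_k}\|Y - \mu_k\|$ of each component by Monte Carlo, for hypothetical Gaussian components at a fixed $(x, t)$; all numerical values are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical components gamma_k at one fixed (x, t): Gaussians with
# these means and scales (illustrative values only).
means = np.array([2.0, -1.0])
scales = np.array([0.2, 0.5])

# Monte Carlo estimate of E_{Y ~ gamma_k} |Y - mu_k(x, t)| per component.
deviations = []
for m, s in zip(means, scales):
    samples = rng.normal(m, s, size=200_000)
    deviations.append(np.abs(samples - m).mean())

# For a Gaussian, E|Y - mu| = sigma * sqrt(2 / pi), so the tightest delta
# satisfying Assumption 3 here is 0.5 * sqrt(2 / pi) ~= 0.399.
delta = max(deviations)
print(delta)
```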
3 Methods
3.1 Identifiability and Counterfactual Reconstruction
For a fixed point $(X = x, T = t)$, Equation 2 is a finite mixture model, for which identifiability results are available [3]. Notably, these results guarantee identifiability up to a permutation of the latent class assignment $\sigma(\cdot) : [K] \to [K]$. That is, there exists some permutation $\sigma(\cdot) : [K] \to [K]$ such that $\hat{\gamma}_{\sigma(k)}(x, t) \to \gamma_k(x, t)$ and $\hat{\omega}_{\sigma(k)} \to \omega_k$, where $\hat{\gamma}$ and $\hat{\omega}$ are the estimated density functions and weights. However, this does not entail identifiability of the counterfactuals in the sense of Definition 1. Indeed, the action step of the counterfactual strategy from Section 2.1 requires a consistent permutation $\sigma$ across the whole domain $(\mathcal{X} \times \mathcal{T})$ in order to reuse the class assignments $\hat{U}_Z$ inferred at a specific point $(X = x, T = t)$ to predict the counterfactual at another point $(X = x, T = t')$, with a different treatment assignment. Nevertheless, using the assumptions from the previous section, we can still ensure the identifiability of the counterfactuals, as the following result confirms:
Result 3.1 (Identifiability of Counterfactuals with Categorical Background Variables). Let $X$, $T$ and $Y$ be continuous random variables generated according to the graphical model of Figure 1b, with the domain of $X$ and $T$ being connected. Let $W_1(\cdot, \cdot)$ be the first Wasserstein distance on $P(\mathcal{Y})$, $\nu_{t'}(X, Y, T)$ the probability distribution of $Y_{t'} \mid X, Y, T$ and $\hat{\nu}^N_{t'}(X, Y, T)$ its estimator from $N$ observed data points. If Assumptions 1, 2 and 3 hold, then for each $(x, t)$ the counterfactual distribution is $W_1$-identifiable in expectation at threshold $\delta$:
$$\lim_{N \to \infty} \mathbb{E}_{Y \sim \gamma(x, t)}\left[ W_1(\nu_{t'}(x, Y, t), \hat{\nu}^N_{t'}(x, Y, t)) \right] \le \delta$$
In the special case where the response noise is additive, we have
$$\lim_{N \to \infty} W_1(\nu_{t'}(X, Y, T), \hat{\nu}^N_{t'}(X, Y, T)) = 0$$
The proof is given in Appendix C. This result gives us a bound on the distance between the inferred and the true counterfactual distributions in the asymptotic regime. Importantly, it does not restrict the dimension of $X$ and $Y$, and is thus valid on challenging data modalities such as time series or images.
Continuity of distribution and complexity. The result above holds asymptotically in the number of available samples. In the additive Gaussian case, the sample complexity for learning a $K$-mixture model with $Y \in \mathbb{R}^d$ within total variation distance $\epsilon$ is $\tilde{O}(K d^2 / \epsilon)$ [4]. Fortunately, the continuity assumption (Assumption 2) saves us from having to learn an individual mixture at each point $(X = x, T = t)$, by jointly learning the continuous moment functions $\mu_k^r(X, T)$. A better sample complexity bound can then be derived with further assumptions on $\mu_k^r(X, T)$.
3.2 CFQP : CounterFactual Query Prediction
Equipped with those theoretical results, we introduce CFQP, a counterfactual prediction model based
on a neural Expectation-Maximization mechanism. The basic building block of CFQP is a base-model
m(x, t)
, that predicts the treatment response
y
based on covariates and treatment assignment. For
each latent category
k
, we learn a base-model that approximates the individual treatment response
in that category :
mk(x, t)µk(x, t)
. Our theoretical results require the true number of classes of
K0
to be known in advance. Yet, this is rarely the case in practice, and we describe our architecture
for an arbitrary number of classes
K
. The learning of the base-models follows three steps: a joint
initialization, an expectation phase, and a maximization phase. The overall process is depicted in
Figure 2. We also present a pseudo-code description of the procedure in Algorithm 1.
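The following is a minimal hard-EM sketch of the three phases described above, with linear least-squares regressors standing in for the neural base-models $m_k$; it is not the authors' implementation (see the linked repository for that), and the synthetic data-generating process is an assumption for illustration.

```python
import numpy as np

rng = np.random.default_rng(4)

# Synthetic factual data: a hidden group U_Z flips the sign of the
# treatment effect (illustrative, not the paper's datasets).
n = 2000
x = rng.normal(size=n)
t = rng.uniform(-1, 1, size=n)
u = rng.integers(0, 2, size=n)            # hidden; never shown to the model
y = x + np.where(u == 0, 1.0, -1.0) * t + 0.1 * rng.normal(size=n)

feats = np.stack([x, t, np.ones(n)], axis=1)   # design matrix for m_k(x, t)

def fit(F, targets):
    # Least-squares stand-in for training one base-model m_k.
    return np.linalg.lstsq(F, targets, rcond=None)[0]

K = 2
assign = rng.integers(0, K, size=n)       # joint initialization
for _ in range(25):
    # Maximization phase: refit each base-model on its assigned samples.
    models = [fit(feats[assign == k], y[assign == k]) for k in range(K)]
    # Expectation phase: reassign each sample to its best-fitting model.
    errors = np.stack([(y - feats @ w) ** 2 for w in models], axis=1)
    assign = errors.argmin(axis=1)

# Counterfactual query: infer the class from the factual (x, y, t), then
# predict the response under a new treatment t' with that class's model.
x0, t0, t_new = 0.5, 0.8, -0.8
y0 = x0 - t0                              # factual response of a u = 1 patient
f0 = np.array([x0, t0, 1.0])
k_hat = int(np.argmin([(y0 - f0 @ w) ** 2 for w in models]))
y_cf = np.array([x0, t_new, 1.0]) @ models[k_hat]
print(y_cf)  # should approach x0 - t_new = 1.3 when the groups separate
```

Hard assignment via `argmin` is one choice of expectation phase; a soft (responsibility-weighted) variant is equally natural under the Gaussian mixture view of Equation 2.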