Counterfactual reasoning, therefore, requires making several assumptions to overcome this limitation.
The most popular one assumes knowledge of the underlying structural equation model that describes
the data generating process [21] or a specific functional form thereof [1, 5, 21]. Unfortunately, this
assumption is rarely met in practice, especially in high-dimensional data such as time series or images.
This led to the development of deep structural equation models that attempt to model the structural
equations with neural networks [17, 24]. However, despite their ability to model high-dimensional
data, these approaches fail to provide theoretical guarantees for the reconstruction of counterfactuals.
Indeed, they focus on modeling the factual distribution, which, without further assumptions, can
lead to erroneous counterfactual distributions.
In this work, we bridge the gap between the classical structural equation model assumptions and
deep-learning-based architectures. By assuming that the treatment and observables are continuous
and that the hidden variables that contribute most to the treatment response are categorical, we can
rely on recent results on the identifiability of mixture distributions [3] to show that we can approximately
recover the counterfactuals using arbitrary parametric functions (i.e., deep neural networks) to model
the causal dependence between variables. This allows us to infer counterfactuals on high-dimensional
data such as time series and images. Generally, this work explores the assumptions that can lead to
approximate counterfactual reconstructions while controlling the discrepancy between the recovered
and the true counterfactual distributions.
Besides the general appeal of endowing machine learning architectures with counterfactual reasoning
abilities, an important motivation for our work is the counterfactual estimation of treatment effects
in clinical patient trajectories. In this motivational example, one wishes to predict the individual
treatment effect retrospectively. Based on the observed treatment outcome of a particular patient,
we want to predict what would have been the outcome under a different treatment assignment. The
ability to perform counterfactual inference on patient trajectories has indeed been identified as a
potential tool for improving long and costly randomized clinical trials [15].
Contributions
• We provide a new set of assumptions under which counterfactuals are identifiable using arbitrary neural network architectures, bridging the gap between structural equation models and deep learning architectures.
• We derive a new counterfactual identifiability result that motivates a novel counterfactual reconstruction architecture.
• We evaluate our construction on three different datasets with different high-dimensional modalities (images and time series) and demonstrate accurate counterfactual estimation.
2 Background
2.1 Problem Setup: Counterfactual Estimation
We consider the general causal model $\mathcal{M} = \langle U, V, F \rangle$ depicted in Figure 1a, consisting of background
variables $U$, endogenous variables $X, T$ and $Y$, and the set of structural functions $F$. Background
variables $U = \{U_X, U_T, U, W\}$ are hidden exogenous random variables that determine the values
of the observed variables $V = \{X, T, Y\}$. Covariates $X \in \mathcal{X}$ represent the information available
before treatment assignment, $T \in \mathcal{T}$ is the treatment assignment, and $Y \in \mathcal{Y}$ is the observed response
to the treatment. We refer to the space of probability measures on $\mathcal{Y}$ as $\mathcal{P}(\mathcal{Y})$. Observed variables $V$
are generated following the structural equations $F = \{f_X, f_T, f_Y, f\}$, such that $X = f_X(U_X, W)$,
$T = f_T(U_T, X)$, $U = f(W)$, $Y = f_Y(X, T, U)$. We further assume strong ignorability (i.e., no
hidden confounders between $T$ and $Y$).
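As a concrete illustration, the structural equations above can be simulated end to end. The sketch below picks toy functional forms; every specific function and coefficient is a hypothetical placeholder, since the paper treats $f_X$, $f_T$, $f_Y$ and $f$ as unknown. It also shows the counterfactual $Y_{t'}(u)$: holding the realization of the background variables fixed, we re-evaluate $f_Y$ under a different treatment.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy structural equations (hypothetical forms for illustration only).
def f(w):                # U = f(W): categorical hidden variable, as assumed in the text
    return int(w > 0)    # two mixture components

def f_X(u_x, w):         # covariates X = f_X(U_X, W)
    return 0.5 * w + u_x

def f_T(u_t, x):         # continuous treatment T = f_T(U_T, X)
    return 0.3 * x + u_t

def f_Y(x, t, u):        # response Y = f_Y(X, T, U); treatment effect depends on category u
    slope = [1.0, -1.0][u]
    return x + slope * t

# Sample background variables, then generate a factual observation (x, t, y).
w, u_x, u_t = rng.normal(), rng.normal(), rng.normal()
u = f(w)
x = f_X(u_x, w)
t = f_T(u_t, x)
y = f_Y(x, t, u)

# Counterfactual Y_{t'}(u): keep the same realization u of the background
# variables, apply the action do(T = t'), and push it through the same f_Y.
t_prime = t + 1.0
y_cf = f_Y(x, t_prime, u)
```

In practice $u$ is hidden, so the counterfactual step requires inferring (abducting) it from the observed response; the categorical-hidden-variable assumption is what makes this inference tractable in the approach described here.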
Using notations introduced in Pearl et al. [21], we define the potential response of a variable $Y$
to an action $do(T = t)$ for a particular realization of $U = u$ as $Y_t(u)$. Our goal is to predict the
counterfactual response, for a new treatment assignment ($T = t'$), conditioned on an observed
initial treatment response. That is, the probability of observing a different treatment response under
treatment $t'$, after observing treatment response $y$ for covariate $x$ and treatment $t$. The probability
density function of the counterfactual $y'$ then writes: