
Dynamic Survival Transformers for Causal Inference
with Electronic Health Records
Prayag Chatha
Department of Statistics
University of Michigan
pchatha@umich.edu
Yixin Wang
Department of Statistics
University of Michigan
Zhenke Wu
Department of Biostatistics
University of Michigan
Jeffrey Regier
Department of Statistics
University of Michigan
Abstract
In medicine, researchers often seek to infer the effects of a given treatment on
patients’ outcomes. However, the standard methods for causal survival analysis
make simplistic assumptions about the data-generating process and cannot capture
complex interactions among patient covariates. We introduce the Dynamic Survival
Transformer (DynST), a deep survival model that trains on electronic health records
(EHRs). Unlike previous transformers used in survival analysis, DynST can make
use of time-varying information to predict evolving survival probabilities. We
derive a semi-synthetic EHR dataset from MIMIC-III to show that DynST can
accurately estimate the causal effect of a treatment intervention on restricted mean
survival time (RMST). We demonstrate that DynST achieves better predictive performance and
more accurate causal estimates than two alternative models.
1 Introduction
Medical practitioners are often interested in the effect of a treatment on a patient’s survival time
until an event of interest. For instance, if a patient is prescribed a certain antibiotic, how will that
affect their risk of experiencing sepsis in the next 24 hours? The field of causal survival analysis is
concerned with estimating treatment effects on time-to-event outcomes given incomplete (censored)
data; classical techniques such as Kaplan-Meier curves [1] and the Cox regression model [2]
are extensively used despite their limitations. Kaplan-Meier curves are a descriptive tool that does
not model individual survival trajectories, while the Cox model assumes proportionality of hazard
functions, which may be unrealistic. Meanwhile, the rise of electronic health records (EHRs) has led
to an abundance of multi-concept longitudinal data, which offers a setting for observational causal
inference when randomized controlled trials prove impractical or unethical.
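To make the classical baseline concrete, the following is a minimal sketch of a Kaplan-Meier estimate and the restricted mean survival time (RMST) computed as the area under that curve up to a horizon tau. It is an illustration only, not the DynST method: the function names `kaplan_meier` and `rmst`, the horizon `tau`, and the toy data are assumptions introduced here and do not come from the paper.

```python
import numpy as np


def kaplan_meier(times, events):
    """Kaplan-Meier estimate from right-censored data.

    times  : observed time for each subject (event or censoring time)
    events : 1 if the event was observed, 0 if the subject was censored
    Returns the distinct event times and the survival estimate just after each.
    """
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=bool)
    event_times = np.unique(times[events])
    surv = []
    s = 1.0
    for t in event_times:
        at_risk = np.sum(times >= t)             # subjects still under observation at t
        n_events = np.sum((times == t) & events)  # events occurring exactly at t
        s *= 1.0 - n_events / at_risk
        surv.append(s)
    return event_times, np.array(surv)


def rmst(event_times, surv, tau):
    """Area under the step-function survival curve on [0, tau]."""
    mask = event_times < tau
    grid = np.concatenate(([0.0], event_times[mask], [tau]))
    heights = np.concatenate(([1.0], surv[mask]))  # curve starts at S(0) = 1
    return float(np.sum(np.diff(grid) * heights))


# Hypothetical toy data: four subjects, the second one censored at time 2.
toy_times = [1, 2, 3, 4]
toy_events = [1, 0, 1, 1]
km_times, km_surv = kaplan_meier(toy_times, toy_events)
toy_rmst = rmst(km_times, km_surv, tau=4.0)
```

A difference in RMST between a treated and an untreated group, each computed this way, is one simple population-level contrast. Because the estimate is a single curve per group, it says nothing about how an individual patient's risk changes with their covariates, which is the descriptive limitation noted above.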
With this observational setting in mind, we propose the Dynamic Survival Transformer (DynST), a
deep-learning survival model that estimates individual survival probabilities over time from EHR
data. DynST is built on the Transformer [3], a recent neural network architecture that has achieved
state-of-the-art results in sequence-to-sequence learning, particularly in NLP [4]. Transformers can
flexibly model individual survival trajectories without making simplifying parametric assumptions
about the data-generating process. Unlike previous survival transformers [5, 6, 7], DynST exploits
both static and time-varying features to capture how a patient’s event risk evolves over time. Several
works have applied transformers to prediction problems in EHR data [8, 9, 10, 11], motivated by
similarities between EHRs and text, but DynST is the first transformer used to estimate the average
effect of a treatment intervention on survival outcomes. Using a semi-synthetic dataset derived from
Accepted to the NeurIPS 2022 Workshop on Learning from Time Series for Health.
arXiv:2210.15417v1 [cs.LG] 25 Oct 2022