Action Matching: Learning Stochastic Dynamics from Samples

Kirill Neklyudov¹ Rob Brekelmans¹ Daniel Severo¹ ² Alireza Makhzani¹ ²

¹Vector Institute ²University of Toronto. Correspondence to: <k.necludov@gmail.com>, <makhzani@vectorinstitute.ai>.

Proceedings of the 40th International Conference on Machine Learning, Honolulu, Hawaii, USA. PMLR 202, 2023. Copyright 2023 by the author(s). arXiv:2210.06662v3 [cs.LG] 8 Jun 2023.
Abstract
Learning the continuous dynamics of a system from snapshots of its temporal marginals is a problem which appears throughout natural sciences and machine learning, including in quantum systems, single-cell biological data, and generative modeling. In these settings, we assume access to cross-sectional samples that are uncorrelated over time, rather than full trajectories of samples. In order to better understand the systems under observation, we would like to learn a model of the underlying process that allows us to propagate samples in time and thereby simulate entire individual trajectories. In this work, we propose Action Matching, a method for learning a rich family of dynamics using only independent samples from its time evolution. We derive a tractable training objective, which does not rely on explicit assumptions about the underlying dynamics and does not require back-propagation through differential equations or optimal transport solvers. Inspired by connections with optimal transport, we derive extensions of Action Matching to learn stochastic differential equations and dynamics involving creation and destruction of probability mass. Finally, we showcase applications of Action Matching by achieving competitive performance in a diverse set of experiments from biology, physics, and generative modeling.
1. Introduction
Understanding the time evolution of systems of particles or individuals is a fundamental problem appearing across machine learning and the natural sciences. In many scenarios, it is expensive or even physically impossible to observe entire individual trajectories. For example, in quantum mechanics, the act of measurement at a given point collapses the wave function (Griffiths & Schroeter, 2018), while in biological applications, single-cell RNA- or ATAC-sequencing techniques destroy the cell in question (Macosko et al., 2015; Klein et al., 2015; Buenrostro et al., 2015).
Instead, from 'cross-sectional' or independent samples at various points in time, we would like to learn a model which simulates particles such that their density matches that of the observed samples. The problem of learning stochastic dynamics from marginal samples is variously referred to as learning population dynamics (Hashimoto et al., 2016) or as trajectory inference (Lavenant et al., 2021), in contrast to time series modeling, where entire trajectories are assumed to be available. Learning such models to predict entire trajectories holds the promise of facilitating simulation of complex chemical or physical systems (Vázquez, 2007; Noé et al., 2020) and understanding developmental processes or treatment effects in biology (Schiebinger et al., 2019; Tong et al., 2020; Schiebinger, 2021; Bunne et al., 2021).
Furthermore, recent advances in generative modeling have been built upon learning stochastic dynamics which interpolate between the data distribution and a prior distribution. In particular, score-based diffusion models (Song et al., 2020b; Ho et al., 2020) construct a stochastic differential equation (SDE) to move samples from the data distribution to a prior distribution, while score matching (Hyvärinen & Dayan, 2005) is used to learn a reverse SDE which models the gradients of intermediate distributions. However, these methods rely on analytical forms of the SDEs and/or the tractability of intermediate Gaussian distributions (Lipman et al., 2022). Since our proposed method can learn dynamics which simulate an arbitrary path of marginal distributions, it can also be applied in the context of generative modeling. Namely, we can approach generative modeling by constructing an interpolating path between the data and an arbitrary prior distribution, and learning to model the resulting dynamics.
Figure 1. Action Matching, and its entropic (eAM) and unbalanced (uAM) variants, can learn to trace any arbitrary distributional path. For a given path, AM learns deterministic trajectories, eAM learns stochastic trajectories, and uAM learns weighted trajectories.

In this work, we propose Action Matching, a method for learning population dynamics from samples of their temporal marginals $q_t$. Our contributions are as follows:

• In Theorem 2.1, we establish the existence of a unique gradient field $\nabla s^*_t$ which traces any given time-continuous distributional path $q_t$. Notably, our restriction to gradient fields is without loss of expressivity for this class of $q_t$. To learn this gradient field, we propose the tractable Action Matching training objective in Theorem 2.2.
• In Secs. 3.1–3.3, we extend the above approach in several ways: an 'entropic' version which can approximate ground-truth dynamics involving stochasticity, an 'unbalanced' version which allows for creation and destruction of probability mass, and a version which can minimize an arbitrary convex cost function in the Action Matching objective.

• We discuss the close relationship between Action Matching and dynamical optimal transport, along with other related works, in Sec. 5 and App. B.

• Since Action Matching relies only on samples and does not require tractable intermediate densities or knowledge of the underlying stochastic dynamics, it is applicable in a wide variety of problem settings. In particular, we demonstrate competitive performance of Action Matching in a number of experiments, including trajectory inference in biological data (Sec. 4.1), evolution of quantum systems (Sec. 4.2), and a variety of tasks in generative modeling (Sec. 4.3).
2. Action Matching
2.1. Continuity Equation
Suppose we have a set of particles in space $\mathcal{X} \subseteq \mathbb{R}^d$, initially distributed as $q_{t=0}$. Let each particle follow a time-dependent ODE (continuous flow) with the velocity field $v : [0,1] \times \mathcal{X} \to \mathbb{R}^d$ as follows:

$$\frac{d}{dt} x(t) = v_t(x(t)), \quad x(t=0) = x. \quad (1)$$

The continuity equation describes how the density of the particles $q_t$ evolves in time $t$, i.e.,

$$\frac{\partial}{\partial t} q_t = -\nabla \cdot (q_t v_t), \quad (2)$$

which holds in the distributional sense, where $\nabla \cdot$ denotes the divergence operator. Under mild conditions, the following theorem shows that any continuous dynamics can be modeled by the continuity equation, and moreover any continuity equation results in a continuous dynamics.
Theorem 2.1 (Adapted from Theorem 8.3.1 of Ambrosio et al. (2008)). Consider a continuous dynamic with the density evolution of $q_t$, which satisfies mild conditions (absolute continuity in the 2-Wasserstein space of distributions $\mathcal{P}_2(\mathcal{X})$). Then, there exists a unique (up to a constant) function $s^*_t(x)$, called the "action",¹ such that the vector field $v^*_t(x) = \nabla s^*_t(x)$ together with $q_t$ satisfies the continuity equation

$$\frac{\partial}{\partial t} q_t = -\nabla \cdot (q_t \nabla s^*_t(x)). \quad (3)$$

In other words, the ODE $\frac{d}{dt} x(t) = \nabla s^*_t(x)$ can be used to move samples in time such that the marginals are $q_t$.
Using Theorem 2.1, the problem of learning the dynamics boils down to learning the unique vector field $\nabla s^*_t$ using only samples from $q_t$. Motivated by this, we restrict our search space to the family of gradient vector fields

$$\mathcal{S}_t = \{\nabla s_t \mid s_t : \mathcal{X} \to \mathbb{R}\}. \quad (4)$$

We use a neural network to parameterize the set of functions $s_t(x; \theta)$, and propose the Action Matching objective in Sec. 2.2 to learn parameters $\theta$ such that $s_t(x; \theta)$ approximates $s^*_t(x)$. Once we have learned the vector field, we can move samples forward or backward in time by simulating the ODE in Eq. (1) with the velocity $\nabla s_t$. The continuity equation ensures that for $\nabla s^*_t$, samples at any given time $t \in [0,1]$ are distributed according to $q_t$.
Note that, even though we arrived at the continuity equation and the ground-truth vector field $\nabla s^*_t(x)$ using ODEs, the continuity equation can describe a rich family of density evolutions, including the diffusion equation (see Equation 37 of Song et al. (2020b)), or even more general evolutions such as porous medium equations (Otto, 2001) in fluid mechanics. Since these processes also define an absolutely continuous curve in the density space, Theorem 2.1 applies. Thus, for the task of modeling the marginal evolution of $q_t$, our restriction to ODEs using gradient vector fields does not sacrifice expressivity.

¹The Hamilton–Jacobi formulation of classical mechanics describes the velocity of particles using a gradient field of the "action", which matches our usage throughout this work.
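To make the gradient-field parameterization concrete, the following is a minimal sketch assuming PyTorch; the architecture of `ActionNet` is illustrative, not the authors' implementation. A single scalar network $s_t(x; \theta)$ yields the velocity field $\nabla_x s_t$ via automatic differentiation:

```python
import torch
import torch.nn as nn

class ActionNet(nn.Module):
    """Scalar action s_t(x; theta): maps (x, t) to a real number."""
    def __init__(self, dim, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x, t):
        # x: (batch, dim), t: (batch, 1) -> scalar action per sample, shape (batch,)
        return self.net(torch.cat([x, t], dim=-1)).squeeze(-1)

def velocity(s, x, t):
    """Velocity field v_t(x) = grad_x s_t(x), obtained by autograd."""
    x = x.requires_grad_(True)  # assumes x is a leaf tensor
    return torch.autograd.grad(s(x, t).sum(), x, create_graph=True)[0]
```

Parameterizing the scalar action rather than a $d$-dimensional vector field is exactly what makes the search space in Eq. (4) a family of gradient fields by construction.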
2.2. Action Matching Loss
The main development of this paper is the Action Matching method, which allows us to recover the true action $s^*_t$ while having access only to samples from $q_t$. With this action in hand, we can simulate the continuous dynamics whose evolution matches $q_t$ using the vector field $\nabla s^*_t$ (see Fig. 1). In order to do so, we define the variational action $s_t(x)$, parameterized by a neural network, which approximates $s^*_t(x)$ by minimizing the ACTION-GAP objective

$$\text{ACTION-GAP}(s, s^*) := \frac{1}{2} \int_0^1 \mathbb{E}_{q_t(x)} \|\nabla s_t(x) - \nabla s^*_t(x)\|^2 \, dt. \quad (5)$$
Note that this objective is intractable, as we do not have access to $s^*$. However, as the following theorem shows, we can still derive a tractable objective for minimizing the action gap.

Theorem 2.2. For an arbitrary variational action $s$, the $\text{ACTION-GAP}(s, s^*)$ can be decomposed as the sum of an intractable constant $K$ and a tractable term $\mathcal{L}_{\text{AM}}(s)$:

$$\text{ACTION-GAP}(s_t, s^*_t) = K + \mathcal{L}_{\text{AM}}(s_t), \quad (6)$$

where $\mathcal{L}_{\text{AM}}(s)$ is the Action Matching objective, which we minimize:

$$\mathcal{L}_{\text{AM}}(s) := \mathbb{E}_{q_0(x)}[s_0(x)] - \mathbb{E}_{q_1(x)}[s_1(x)] + \int_0^1 \mathbb{E}_{q_t(x)} \left[ \frac{1}{2}\|\nabla s_t(x)\|^2 + \frac{\partial s_t}{\partial t}(x) \right] dt. \quad (7)$$
See App. A.1 for the proof. The term $\mathcal{L}_{\text{AM}}$ is tractable, since we can use the samples from the marginals $q_t$ to obtain an unbiased, low-variance estimate. We show in App. A.1 that the intractable constant $K$ is the kinetic energy of the distributional path, defined as $K(s^*_t) := \frac{1}{2}\int_0^1 \mathbb{E}_{q_t(x)} \|\nabla s^*_t(x)\|^2 \, dt$, and thus minimizing $\mathcal{L}_{\text{AM}}(s)$ can be viewed as maximizing a variational lower bound on the kinetic energy.
Connection with Optimal Transport In App. B.1, we show that the optimal dynamics of AM along the curve is also optimal in the sense of optimal transport with the 2-Wasserstein cost. More precisely, at any given time $t$, the optimal vector field in the AM objective defines a mapping between two infinitesimally close distributions $q_t$ and $q_{t+h}$, which is of the form $x \mapsto x + h\nabla s^*_t(x)$. This mapping is indeed the same as the Brenier map (Brenier, 1987) in optimal transport, which is of the form $x \mapsto x + \nabla \varphi_t(x)$, where $\varphi_t$ is the (c-convex) Kantorovich potential.
Algorithm 1 Action Matching²

Require: data $\{x_t^j\}_{j=1}^{N_t}$, $x_t^j \sim q_t(x)$
Require: parametric model $s_t(x, \theta)$
for learning iterations do
  get batch of samples from the boundaries: $\{x_0^i\}_{i=1}^n \sim q_0(x)$, $\{x_1^i\}_{i=1}^n \sim q_1(x)$
  sample times $\{t_i\}_{i=1}^n \sim \text{Uniform}[0,1]$
  get batch of intermediate samples $\{x_{t_i}^i\}_{i=1}^n \sim q_{t_i}(x)$
  $\mathcal{L}_{\text{AM}}(\theta) = \frac{1}{n}\sum_{i=1}^n \left[ s_0(x_0^i, \theta) - s_1(x_1^i, \theta) + \frac{1}{2}\|\nabla s_{t_i}(x_{t_i}^i, \theta)\|^2 + \frac{\partial s_{t_i}}{\partial t}(x_{t_i}^i, \theta) \right]$
  update the model: $\theta \leftarrow \text{Optimizer}(\theta, \nabla_\theta \mathcal{L}_{\text{AM}}(\theta))$
end for
output trained model $s_t(x, \theta)$

²Notebooks with pedagogical examples of AM are given at github.com/necludov/jam#tutorials.
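To make Algorithm 1 concrete, here is a minimal PyTorch sketch (our own illustration, not the authors' released code): `ActionNet` is the scalar network from the earlier sketch, and `sample_marginal` is a hypothetical helper that draws a batch from the empirical marginal at the given times.

```python
def am_loss(s, x0, x1, xt, t):
    """Monte Carlo estimate of the Action Matching objective in Eq. (7)."""
    xt, t = xt.requires_grad_(True), t.requires_grad_(True)
    s_t = s(xt, t)
    # spatial gradient and time derivative of the scalar action in one pass
    grad_x, grad_t = torch.autograd.grad(s_t.sum(), (xt, t), create_graph=True)
    boundary = s(x0, torch.zeros(len(x0), 1)).mean() - s(x1, torch.ones(len(x1), 1)).mean()
    interior = (0.5 * grad_x.pow(2).sum(dim=1) + grad_t.squeeze(-1)).mean()
    return boundary + interior

s = ActionNet(dim=2)
opt = torch.optim.Adam(s.parameters(), lr=1e-3)
for step in range(10_000):
    t = torch.rand(256, 1)                     # t_i ~ Uniform[0, 1]
    x0 = sample_marginal(torch.zeros(256, 1))  # batch from q_0 (hypothetical helper)
    x1 = sample_marginal(torch.ones(256, 1))   # batch from q_1
    xt = sample_marginal(t)                    # intermediate samples x ~ q_{t_i}
    opt.zero_grad()
    am_loss(s, x0, x1, xt, t).backward()
    opt.step()
```

Note that both $\nabla_x s_t$ and $\partial_t s_t$ come from a single `autograd.grad` call on the scalar output, so the whole objective is simulation-free.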
Finally, in App. A.1, we adapt reasoning from Albergo & Vanden-Eijnden (2022) to show that the 2-Wasserstein distance between the ground-truth marginals and those simulated using our learned $s_t(x)$ can be upper bounded in terms of $\text{ACTION-GAP}(s_t, s^*_t)$.
2.3. Learning, Sampling, and Likelihood Evaluation
Learning We provide pseudo-code for learning with the Action Matching objective in Algorithm 1. With our learned $s_t(x, \theta)$, we now describe how to simulate the dynamics and evaluate likelihoods when the initial density $q_0$ is known.
Sampling We sample from the target distribution via the trained function $s_t(x(t), \theta)$ by solving the following ODE forward in time:

$$\frac{d}{dt} x(t) = \nabla_x s_t(x(t), \theta), \quad x(t=0) \sim q_0(x). \quad (8)$$

Recall that this sampling process is justified by Eq. (3), where $\nabla s_t(x(t), \theta)$ approximates $\nabla s^*_t(x(t))$.
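As a sketch, Eq. (8) can be integrated with forward Euler at a fixed step size (any off-the-shelf ODE solver would do just as well), reusing the `velocity` helper from the earlier sketch:

```python
def sample(s, x0, n_steps=200):
    """Push x0 ~ q_0 forward through dx/dt = grad_x s_t(x; theta)."""
    x, dt = x0, 1.0 / n_steps
    for k in range(n_steps):
        t = torch.full((x.shape[0], 1), (k + 0.5) * dt)  # midpoint of the step
        x = (x + dt * velocity(s, x, t)).detach()        # detach: no graph needed across steps
    return x  # approximately distributed as q_1
```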
Evaluating the Log-Likelihood When the density of $q_0$ is available, we can evaluate the log-likelihood of a sample $x \sim q_1$ using the continuous change-of-variables formula (Chen et al., 2018). Integrating the ODE backward in time,

$$\log q_1(x) = \log q_0(x(0)) - \int_0^1 dt\, \Delta s^*_t(x(t)), \qquad \frac{d}{dt} x(t) = \nabla_x s^*_t(x(t)), \quad x(t=1) = x, \quad (9)$$

where $\frac{d}{dt} \log q_t(x(t)) = -\Delta s^*_t(x(t))$ along the trajectory can be confirmed by a simple calculation, and we approximate $s^*_t(x(t))$ by $s_t(x(t), \theta)$.
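A sketch of Eq. (9) with an exact Laplacian (one backward pass per dimension, so only practical for small $d$; in high dimensions one would switch to a stochastic trace estimator, as in the entropic-AM sketch below):

```python
def laplacian(s, x, t):
    """Exact Laplacian of s wrt x: trace of the Hessian, one pass per dimension."""
    x = x.requires_grad_(True)
    grad_x = torch.autograd.grad(s(x, t).sum(), x, create_graph=True)[0]
    lap = torch.zeros(x.shape[0])
    for i in range(x.shape[1]):
        lap = lap + torch.autograd.grad(grad_x[:, i].sum(), x, retain_graph=True)[0][:, i]
    return lap

def log_likelihood(s, x1, log_q0, n_steps=200):
    """log q_1(x) = log q_0(x(0)) - ∫_0^1 Δ s_t(x(t)) dt, Eq. (9), integrated backward."""
    x, dt = x1, 1.0 / n_steps
    integral = torch.zeros(x1.shape[0])
    for k in reversed(range(n_steps)):
        t = torch.full((x.shape[0], 1), (k + 0.5) * dt)
        integral = integral + dt * laplacian(s, x, t).detach()
        x = (x - dt * velocity(s, x, t)).detach()  # backward-in-time Euler step
    return log_q0(x) - integral  # x is now the preimage x(0) under the flow
```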
3. Extensions of Action Matching³

In this section, we propose several extensions of Action Matching, which can be used to learn dynamics which include stochasticity (Sec. 3.1), allow for teleportation of probability mass (Sec. 3.2), and minimize alternative kinetic energy costs (Sec. 3.3).

³This section can be skipped without loss of understanding of our core contributions, although our experiments also evaluate the entropic AM method from Sec. 3.1.
3.1. Entropic Action Matching
In this section, we propose entropic Action Matching (eAM), which can recover the ground-truth dynamics arising from diffusion processes with a gradient-field drift term and a known diffusion term. This setting arises in biological applications studying the Brownian motion of cells in a medium (Schiebinger et al., 2019; Tong et al., 2020). We will show in Prop. 3.1 that, at optimality, entropic AM can also learn to trace any absolutely continuous distributional path under mild conditions, so that the choice between entropic AM and deterministic AM should be made based on prior knowledge of the true underlying dynamics.
Consider the stochastic differential equation

$$dx(t) = v_t(x)\, dt + \sigma_t\, dW_t, \quad x(t=0) = x, \quad (10)$$

where $W_t$ is the Wiener process. We know that the evolution of the density of this diffusion process is described by the Fokker–Planck equation:

$$\frac{\partial}{\partial t} q_t = -\nabla \cdot (q_t v_t) + \frac{\sigma_t^2}{2} \Delta q_t. \quad (11)$$

In the following proposition, we extend Theorem 2.1 and prove that any continuous distributional path, regardless of the ground-truth generating dynamics, can be modeled with diffusion dynamics in the state space.
Proposition 3.1. Consider a continuous dynamic with the density evolution of $q_t$, and suppose $\sigma_t$ is given. Then, there exists a unique (up to a constant) function $\tilde{s}^*_t(x)$, called the "entropic action", such that the vector field $v^*_t(x) = \nabla \tilde{s}^*_t(x)$ together with $q_t$ satisfies the Fokker–Planck equation

$$\frac{\partial}{\partial t} q_t = -\nabla \cdot (q_t \nabla \tilde{s}^*_t) + \frac{\sigma_t^2}{2} \Delta q_t. \quad (12)$$
See App. A.2 for the proof. This proposition indicates that we can use the SDE $dx(t) = \nabla \tilde{s}^*_t\, dt + \sigma_t\, dW_t$ to move samples in time such that the marginals are $q_t$.
The entropic AM objective aims to recover the unique $\tilde{s}^*_t(x)$ described by the above proposition. In order to learn the diffusion velocity vector, we define the variational action $s_t(x)$, parameterized by a neural network, that approximates $\tilde{s}^*_t(x)$ by minimizing the E-ACTION-GAP objective

$$\text{E-ACTION-GAP}(s, \tilde{s}^*) := \frac{1}{2} \int_0^1 \mathbb{E}_{q_t(x)} \|\nabla s_t(x) - \nabla \tilde{s}^*_t(x)\|^2 \, dt.$$

Note that while the E-ACTION-GAP is similar to the original ACTION-GAP objective, it minimizes the distance to $\tilde{s}^*_t$, which is different from $s^*_t$. As in AM, this objective is intractable since we do not have access to $\tilde{s}^*_t$. However, we derive a tractable objective in the following proposition.
Proposition 3.2. For an arbitrary variational action $s$, the $\text{E-ACTION-GAP}(s, \tilde{s}^*)$ can be decomposed as the sum of an intractable constant $K_{\text{eAM}}$ and a tractable term $\mathcal{L}_{\text{eAM}}(s)$ which can be minimized:

$$\text{E-ACTION-GAP}(s, \tilde{s}^*) = \mathcal{L}_{\text{eAM}}(s) + K_{\text{eAM}},$$

where $\mathcal{L}_{\text{eAM}}(s)$ is the entropic Action Matching objective, which we minimize:

$$\mathcal{L}_{\text{eAM}}(s) := \mathbb{E}_{q_0(x)}[s_0(x)] - \mathbb{E}_{q_1(x)}[s_1(x)] + \int_0^1 \mathbb{E}_{q_t(x)} \left[ \frac{1}{2}\|\nabla s_t(x)\|^2 + \frac{\partial s_t}{\partial t}(x) + \frac{\sigma_t^2}{2} \Delta s_t \right] dt. \quad (13)$$
See App. A.2 for the proof. The constant $K_{\text{eAM}}$ is the entropic kinetic energy, discussed in App. A.2.
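As a sketch, the only change relative to the `am_loss` estimator above is the $\frac{\sigma_t^2}{2}\Delta s_t$ term in Eq. (13). Hutchinson's trace estimator (a standard trick, shown with a single Rademacher probe) keeps its cost independent of the dimension; the `sigma` callable for the known diffusion schedule is our own assumption:

```python
def hutchinson_laplacian(s, x, t):
    """Unbiased estimate of Δ s_t(x) = tr(Hess_x s) via E_v[v^T H v], v Rademacher."""
    x = x.requires_grad_(True)
    grad_x = torch.autograd.grad(s(x, t).sum(), x, create_graph=True)[0]
    v = torch.randint_like(x, 2) * 2.0 - 1.0  # entries in {-1, +1}
    Hv = torch.autograd.grad((grad_x * v).sum(), x, create_graph=True)[0]
    return (Hv * v).sum(dim=1)

def eam_loss(s, x0, x1, xt, t, sigma):
    """Eq. (13): AM terms plus the (sigma_t^2 / 2) * Laplacian term."""
    lap = hutchinson_laplacian(s, xt, t)
    return am_loss(s, x0, x1, xt, t) + (0.5 * sigma(t).squeeze(-1) ** 2 * lap).mean()
```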
Connection with Entropic Optimal Transport In App. B.3, we describe connections between the eAM objective and dynamical formulations of entropy-regularized optimal transport (Cuturi, 2013) or Schrödinger Bridge (Léonard, 2014; Chen et al., 2016; 2021) problems.
3.2. Unbalanced Action Matching
In this section, we further extend the scope of underlying dynamics which can be learned by Action Matching by allowing for the creation and destruction of probability mass via a growth rate $g_t(x)$. This term is useful for accounting for cell growth and death in trajectory inference for single-cell biological dynamics (Schiebinger et al., 2019; Tong et al., 2020; Baradat & Lavenant, 2021; Lübeck et al., 2022; Chizat et al., 2022), and is well studied in relation to unbalanced optimal transport problems (Chizat et al., 2018a;b;c; Liero et al., 2016; 2018; Kondratyev et al., 2016).
To introduce unbalanced Action Matching (uAM), consider the following ODE, which attaches an importance weight to each sample and updates the weights according to a growth rate $g_t(x)$ while transporting the samples in space:

$$\frac{d}{dt} x(t) = v_t(x(t)), \quad x(t=0) = x, \quad (14)$$
$$\frac{d}{dt} \log w_t(x(t)) = g_t(x(t)), \quad w(t=0) = w, \quad (15)$$

where $v_t$ is the vector field moving particles, as in the continuity equation, $w_t(x)$ is the importance weight of a particle, and $g_t(x(t))$ is its growth rate. These importance weights can grow or shrink over time, allowing probability mass to be created or destroyed without needing to transport the particles. The evolution of the density governed by this importance-weighted ODE is given by the following continuity equation:

$$\frac{\partial}{\partial t} q_t = -\nabla \cdot (q_t v_t) + q_t g_t. \quad (16)$$
In the following proposition, we extend Theorem 2.1 to show that any distributional path (under mild conditions), regardless of how it was generated in the state space, can be modeled with the importance-weighted ODE.

Proposition 3.3. Consider a continuous dynamic with density evolution $q_t$ satisfying mild conditions. Then, there exists a unique function $\hat{s}^*_t(x)$, called the "unbalanced action", such that the velocity field $v^*_t(x) = \nabla \hat{s}^*_t(x)$ and growth term $g^*_t(x) = \hat{s}^*_t(x)$ satisfy the importance-weighted continuity equation:

$$\frac{\partial}{\partial t} q_t = -\nabla \cdot (q_t \nabla \hat{s}^*_t) + q_t \hat{s}^*_t. \quad (17)$$
See App. A.3 for the proof. This proposition indicates that we can use the importance-weighted ODE

$$\frac{d}{dt} x(t) = \nabla \hat{s}^*_t(x(t)), \quad x(t=0) = x, \quad (18)$$
$$\frac{d}{dt} \log w_t(x(t)) = \hat{s}^*_t(x(t)), \quad w(t=0) = w, \quad (19)$$

to move the particles and update their weights in time, such that the marginals are $q_t$.
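A sketch of simulating Eqs. (18)–(19): each particle carries a log-weight whose growth rate is the action value itself, here with a trained approximation `s` of $\hat{s}^*_t$ and the same `velocity` helper as before:

```python
def sample_weighted(s, x0, n_steps=200):
    """Jointly integrate dx/dt = grad_x s_t(x) and d(log w)/dt = s_t(x)."""
    x, dt = x0, 1.0 / n_steps
    log_w = torch.zeros(x0.shape[0])  # w(t=0) = 1 for all particles
    for k in range(n_steps):
        t = torch.full((x.shape[0], 1), (k + 0.5) * dt)
        log_w = log_w + dt * s(x, t).detach()  # growth rate g_t = s_t, Eq. (19)
        x = (x + dt * velocity(s, x, t)).detach()
    return x, log_w  # weighted samples; self-normalize the weights to approximate q_1
```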
Remarkably, the optimal velocity field $v^*_t = \nabla \hat{s}^*_t$ and growth rate $g^*_t = \hat{s}^*_t$ in Prop. 3.3 are linked to a single action function $\hat{s}^*_t(x)$. Thus, for learning the variational action $s_t(x)$ in unbalanced AM, we add a term to the U-ACTION-GAP objective which encourages $s_t$ to match $\hat{s}^*_t$:

$$\text{U-ACTION-GAP}(s, \hat{s}^*) := \frac{1}{2}\int_0^1 \mathbb{E}_{q_t(x)} \|\nabla s_t(x) - \nabla \hat{s}^*_t(x)\|^2 \, dt + \frac{1}{2}\int_0^1 \mathbb{E}_{q_t(x)} \left( s_t(x) - \hat{s}^*_t(x) \right)^2 dt.$$
As before, the $\text{U-ACTION-GAP}(s, \hat{s}^*)$ objective is intractable since we do not have access to $\hat{s}^*_t$. However, as the following proposition shows, we can still derive a tractable objective.

Proposition 3.4. For an arbitrary variational action $s$, the $\text{U-ACTION-GAP}(s, \hat{s}^*)$ can be decomposed as the sum of intractable constants $K_{\text{uAM}}$ and $G_{\text{uAM}}$, and a tractable term $\mathcal{L}_{\text{uAM}}(s)$:

$$\text{U-ACTION-GAP}(s, \hat{s}^*) = K_{\text{uAM}} + G_{\text{uAM}} + \mathcal{L}_{\text{uAM}}(s),$$

where $\mathcal{L}_{\text{uAM}}(s)$ is the unbalanced Action Matching objective, which we minimize:

$$\mathcal{L}_{\text{uAM}}(s) := \mathbb{E}_{q_0(x)}[s_0(x)] - \mathbb{E}_{q_1(x)}[s_1(x)] + \int_0^1 \mathbb{E}_{q_t(x)} \left[ \frac{1}{2}\|\nabla s_t(x)\|^2 + \frac{\partial s_t}{\partial t}(x) + \frac{1}{2} s_t^2 \right] dt. \quad (20)$$
See App. A.3 for the proof. The constants $K_{\text{uAM}}$ and $G_{\text{uAM}}$ are the unbalanced kinetic and growth energies, defined in App. A.3 and B.4. We note that the entropic and unbalanced extensions of Action Matching can also be combined, as is common in biological applications (Schiebinger et al., 2019; Chizat et al., 2022). To showcase how uAM can handle creation and destruction of mass without transporting particles, we provide a mixture-of-Gaussians example in App. E.3.
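As a sketch, the uAM objective in Eq. (20) differs from the `am_loss` estimator above only by the growth-energy term $\frac{1}{2}s_t^2$ inside the time integral:

```python
def uam_loss(s, x0, x1, xt, t):
    """Eq. (20): AM terms plus the squared-action growth term."""
    return am_loss(s, x0, x1, xt, t) + 0.5 * s(xt, t).pow(2).mean()
```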
Connection with Unbalanced Optimal Transport In App. B.4, we show that at any given time $t$, the optimal dynamics of uAM along the curve is optimal in the sense of unbalanced optimal transport (Chizat et al., 2018a; Liero et al., 2016; Kondratyev et al., 2016) between two infinitesimally close distributions $q_t$ and $q_{t+h}$.
3.3. Action Matching with Convex Costs
In App. B.2, we further extend AM to minimize kinetic energies defined using an arbitrary strictly convex cost $c(v_t)$ (Villani, 2009, Ch. 7). For a given path $q_t$, consider

$$K_{\text{cAM}} := \inf_{v_t} \int_0^1 \mathbb{E}_{q_t(x)}[c(v_t)] \, dt \quad \text{s.t.} \quad \frac{\partial}{\partial t} q_t = -\nabla \cdot (q_t v_t).$$

In this case, the unique vector field tracing the density evolution of $q_t$ becomes $v^*_t = \nabla c^*(\nabla \bar{s}^*_t)$, where $c^*$ is the convex conjugate of $c$. The corresponding action gap becomes an integral of the Bregman divergence generated by $c^*$:

$$\text{ACTION-GAP}_c(s_t, \bar{s}^*_t) := \int_0^1 \mathbb{E}_{q_t(x)} D_{c^*}[\nabla s_t : \nabla \bar{s}^*_t] \, dt.$$

In practice, we can minimize $\text{ACTION-GAP}_c$ using the following c-Action Matching loss:

$$\mathcal{L}_{\text{cAM}}(s_t) := \int s_0(x_0)\, q_0(x_0)\, dx_0 - \int s_1(x_1)\, q_1(x_1)\, dx_1 + \int_0^1 \int \left[ c^*(\nabla s_t(x_t)) + \frac{\partial s_t(x_t)}{\partial t} \right] q_t(x_t)\, dx_t\, dt.$$

For $c(\cdot) = c^*(\cdot) = \frac{1}{2}\|\cdot\|^2$, we recover standard AM. Importantly, the continuity equation for this formulation is

$$\frac{\partial}{\partial t} q_t = -\nabla \cdot \left( q_t \nabla c^*(\nabla \bar{s}^*_t) \right). \quad (21)$$
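As a concrete instance (a worked example of ours, not taken from the paper's appendices): for the power cost $c(v) = \frac{1}{p}\|v\|^p$ with $p > 1$ and $\frac{1}{p} + \frac{1}{q} = 1$, the convex conjugate is $c^*(u) = \frac{1}{q}\|u\|^q$, so the optimal velocity field becomes

$$v^*_t = \nabla c^*(\nabla \bar{s}^*_t) = \|\nabla \bar{s}^*_t\|^{q-2} \, \nabla \bar{s}^*_t,$$

which reduces to $v^*_t = \nabla \bar{s}^*_t$ for $p = q = 2$, recovering standard AM.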