torchode: A Parallel ODE Solver for PyTorch
Marten Lienen & Stephan Günnemann
Department of Informatics & Munich Data Science Institute
Technical University of Munich, Germany
{m.lienen,s.guennemann}@tum.de
Abstract
We introduce an ODE solver for the PyTorch ecosystem that can solve multiple
ODEs in parallel independently from each other while achieving significant per-
formance gains. Our implementation tracks each ODE’s progress separately and
is carefully optimized for GPUs and compatibility with PyTorch’s JIT compiler.
Its design lets researchers easily augment any aspect of the solver and collect and
analyze internal solver statistics. In our experiments, our implementation is up to
4.3 times faster per step than other ODE solvers and it is robust against within-batch
interactions that lead other solvers to take up to 4 times as many steps.
1 Introduction
Ordinary differential equations (ODEs) are the natural framework to represent continuously evolving systems. They have been applied to the continuous transformation of probability distributions (Chen et al., 2018; Grathwohl et al., 2019), modeling irregularly-sampled time series (De Brouwer et al., 2019; Rubanova et al., 2019), and graph data (Poli et al., 2019), and connected to numerical methods for PDEs (Lienen & Günnemann, 2022). Various extensions (Dupont et al., 2019; Xia et al., 2021; Norcliffe et al., 2021) and regularization techniques (Pal et al., 2021; Ghosh et al., 2020; Finlay et al., 2020) have been proposed, and (Gholami et al., 2019; Massaroli et al., 2020; Ott et al., 2021) have analyzed the choice of hyperparameters and model structure. Despite the large interest in these methods, the performance of PyTorch (Paszke et al., 2019) ODE solvers has not been a focus point, and benchmarks indicate that solvers for PyTorch lag behind those in other ecosystems.1
torchode aims to demonstrate that faster model training and inference with ODEs is possible with PyTorch. Furthermore, parallel, independent solving of batched ODEs eliminates unintended interactions between batched instances that can dramatically increase the number of solver steps and introduce noise into model outputs and gradients.
2 Related Work
The most well-known ODE solver for PyTorch is torchdiffeq, which popularized training with the adjoint equation (Chen et al., 2018). Their implementation comes with many low- to medium-order explicit solvers and has been the basis for a differentiable solver for controlled differential equations (Kidger et al., 2020). Another option in the PyTorch ecosystem is TorchDyn, a collection of tools for implicit models that includes an ODE solver as well as utilities to plot and inspect the learned dynamics (Poli et al., 2021). torchode goes beyond their ODE solving capabilities by solving multiple independent problems in parallel, with separate initial conditions, integration ranges, and solver states such as accept/reject decisions and step sizes, and with a particular concern for performance, such as compatibility with PyTorch's just-in-time (JIT) compiler.
Code is available at github.com/martenlienen/torchode.
1 benchmarks.sciml.ai, github.com/patrick-kidger/diffrax/tree/main/benchmarks
DLDE Workshop in the 36th Conference on Neural Information Processing Systems (NeurIPS 2022).
arXiv:2210.12375v2 [cs.LG] 17 Jan 2023
Recently, Kidger (2022) released diffrax, a collection of solvers not only for ODEs but also for controlled, stochastic, and rough differential equations, for the up-and-coming deep learning framework JAX (Bradbury et al., 2018). They exploit the features of JAX to offer many of the same benefits that torchode makes available to the PyTorch community, and diffrax's internal design was an important inspiration for the structure of our own implementation.
Outside of Python, the Julia community has an impressive suite of solvers for all kinds of differential equations with DifferentialEquations.jl (Rackauckas & Nie, 2017). After a first evaluation of different types of sensitivity analysis in 2018 (Ma et al.), they released DiffEqFlux.jl, which combines their ODE solvers with a popular deep learning framework (Rackauckas et al., 2019).
3 Design & Features of torchode
We designed torchode to be correct, performant, extensible and introspectable. The former two aspects are, of course, always desirable, while the latter two are especially important to researchers who may want to extend the solver with, for example, learned stepping methods, or record solution statistics that the authors did not anticipate.
Table 1: Feature comparison with existing PyTorch ODE solvers.

                       torchode   torchdiffeq   TorchDyn
Parallel solving          ✓            ✗            ✗
JIT compilation           ✓            ✗            ✗
Extensible                ✓            ✗            ✓
Solver statistics         ✓            ✗            ✗
Step size controller     PID           I            I
The major architectural difference between torchode and existing ODE solvers for PyTorch is that we treat the batch dimension in batch training explicitly (Table 1). This means that the solver holds a separate state for each instance in a batch, such as initial condition, integration bounds and step size, and is able to accept or reject their steps independently. Batching instances together that need to be solved over different intervals, even of different lengths, requires no special handling in torchode, and even parameters such as tolerances could be specified separately for each problem. Most importantly, our parallel integration avoids unintended interactions between problems in a batch that we explore in Section 4.1.
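The mechanism behind such within-batch interactions can be sketched with a toy adaptive solver in plain Python (not torchode's implementation; all names here are illustrative). A hand-rolled embedded Euler/Heun pair is driven either per problem, or by a naive batched loop that controls one shared step size with the worst error in the batch, so an easy problem is forced onto the hard problem's step sequence:

```python
import math

def heun_step(f, t, y, h):
    # Embedded Euler/Heun pair: second-order step with a first-order error estimate.
    k1 = f(t, y)
    k2 = f(t + h, y + h * k1)
    y_new = y + 0.5 * h * (k1 + k2)
    err = abs(0.5 * h * (k2 - k1))  # distance to the embedded Euler step
    return y_new, err

def controller(h, err, tol):
    # Standard integral controller for a first-order error estimate.
    return h * min(5.0, max(0.2, 0.9 * math.sqrt(tol / max(err, 1e-16))))

def solve_single(f, y0, t0, t1, tol=1e-6):
    # Per-problem adaptive loop: each problem gets its own step size.
    t, y, h, steps = t0, y0, 1e-2, 0
    while t < t1:
        h = min(h, t1 - t)
        y_new, err = heun_step(f, t, y, h)
        steps += 1
        if err <= tol:  # accept
            t, y = t + h, y_new
        h = controller(h, err, tol)
    return y, steps

def solve_batched_shared(fs, y0s, t0, t1, tol=1e-6):
    # Naive batching: one shared step size, with accept/reject driven by the
    # worst error in the batch -- the problems interact through the controller.
    t, ys, h, steps = t0, list(y0s), 1e-2, 0
    while t < t1:
        h = min(h, t1 - t)
        trials = [heun_step(f, t, y, h) for f, y in zip(fs, ys)]
        err = max(e for _, e in trials)
        steps += 1
        if err <= tol:
            t, ys = t + h, [y for y, _ in trials]
        h = controller(h, err, tol)
    return ys, steps

easy = lambda t, y: -y          # slow decay, large steps suffice
hard = lambda t, y: -50.0 * y   # fast decay, needs small steps early on
_, steps_easy = solve_single(easy, 1.0, 0.0, 1.0)
_, steps_joint = solve_batched_shared([easy, hard], [1.0, 1.0], 0.0, 1.0)
print(steps_easy, steps_joint)  # the easy problem alone needs far fewer steps
```

In the shared loop, the easy problem pays for every step and rejection that the hard problem forces; tracking each problem's state separately, as torchode does, removes this coupling.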
Two other aspects of torchode's design that are of particular importance in research are extensibility and introspectability. Every component can be re-configured or easily replaced with a custom implementation. By default, torchode collects solver statistics such as the number of total and accepted steps. This mechanism is extensible as well and lets a custom step size controller, for example, return internal state to the user for further analysis without relying on global state.
The speed of model training and evaluation constrains computational resources as well as researcher productivity. Therefore, performance is a critical concern for ODE solvers, and torchode takes various implementation measures to optimize throughput, as detailed below and evaluated in Section 4.2. Another way to save time is the choice of time step. It needs to be small enough to control error accumulation but as large as possible to progress quickly. torchode includes a PID controller that is based on analyzing the step size problem in terms of control theory (Söderlind, 2002, 2003). These controllers generalize the integral (I) controllers used in torchdiffeq and TorchDyn and are included in DifferentialEquations.jl and diffrax. In our evaluation in Appendix C, these controllers can save up to 5% of steps if the step size changes quickly.
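The PID update can be sketched in a generic Söderlind-style form (coefficient names and values below are illustrative, not torchode's defaults): the next step size scales with powers of the tolerance-to-error ratios of the last three steps, and the plain I controller is recovered by zeroing the two history terms.

```python
class PIDController:
    # Söderlind-style step size control: h_new = h * r_n^(b1/k) * r_{n-1}^(b2/k) * r_{n-2}^(b3/k),
    # where r_i = tol / err_i and k is the order of the error estimate plus one.
    def __init__(self, tol, order, beta1, beta2=0.0, beta3=0.0):
        self.tol, self.k = tol, order + 1
        self.b1, self.b2, self.b3 = beta1, beta2, beta3
        self.prev = [1.0, 1.0]  # ratios from the two previous steps

    def update(self, h, err):
        r = self.tol / max(err, 1e-16)
        r1, r2 = self.prev
        factor = (r ** (self.b1 / self.k)
                  * r1 ** (self.b2 / self.k)
                  * r2 ** (self.b3 / self.k))
        self.prev = [r, r1]
        # Safety factor and clamping, as is conventional for step size control.
        return h * min(5.0, max(0.2, 0.9 * factor))

i_ctrl = PIDController(tol=1e-6, order=4, beta1=1.0)           # plain I controller
pid = PIDController(tol=1e-6, order=4, beta1=0.7, beta2=-0.4)  # a PI-type variant
h = 0.1
for err in (2e-7, 8e-7, 3e-6):  # a sequence of error estimates
    h = pid.update(h, err)
```

The history terms let the controller damp oscillating step size sequences, which is where the reported step savings over a pure I controller come from.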
What makes torchode fast? ODE solving is inherently sequential, except for efforts on parallel-in-time solving (Gander, 2015). Taking the evaluation time of the dynamics as fixed, performance of an ODE-based model can therefore only be improved through a more efficient implementation of the solver's looping code, so as to minimize the time between consecutive dynamics evaluations. In addition to the common FSAL and SSAL optimizations for Runge-Kutta methods to reuse intermediate results, torchode avoids expensive operations such as conditionals evaluated on the host that require a CPU-GPU synchronization as much as possible and seeks to minimize the number of PyTorch kernels launched. We rely extensively on operations that combine multiple instructions in one kernel, such as einsum and addcmul, in-place operations, pre-allocated buffers, and fast polynomial evaluation via Horner's rule, which saves half of the multiplications over the naive evaluation.
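For example, Horner's rule evaluates a degree-n polynomial with n multiplications by nesting them, where the naive term-by-term scheme uses roughly 2n (one per power and one per coefficient); a minimal illustration:

```python
def horner(coeffs, x):
    # coeffs[i] is the coefficient of x**i; n multiplications total.
    acc = 0.0
    for c in reversed(coeffs):
        acc = acc * x + c
    return acc

def naive(coeffs, x):
    # Forms each power separately: about twice as many multiplications.
    return sum(c * x ** i for i, c in enumerate(coeffs))

coeffs = [3.0, -2.0, 0.5, 1.0]  # 3 - 2x + 0.5x^2 + x^3
print(horner(coeffs, 2.0))       # → 9.0, matching the naive evaluation
```

In the solver, this applies to the polynomial interpolants used for dense output, where each saved multiplication is one fewer GPU kernel launch per evaluation.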