
Recently, Kidger (2022) released diffrax, a collection of solvers for ODEs as well as controlled, stochastic, and rough differential equations for the up-and-coming deep learning framework JAX (Bradbury et al., 2018). They exploit the features of JAX to offer many of the same benefits that torchode makes available to the PyTorch community, and diffrax's internal design was an important inspiration for the structure of our own implementation.
Outside of Python, the Julia community has an impressive suite of solvers for all kinds of differential equations with DifferentialEquations.jl (Rackauckas & Nie, 2017). After a first evaluation of different types of sensitivity analysis (Ma et al., 2018), they released DiffEqFlux.jl, which combines their ODE solvers with a popular deep learning framework (Rackauckas et al., 2019).
3 Design & Features of torchode
We designed torchode to be correct, performant, extensible and introspectable. The former two aspects are, of course, always desirable, while the latter two are especially important to researchers who may want to extend the solver with, for example, learned stepping methods, or record solution statistics that the authors did not anticipate.
Table 1: Feature comparison with existing PyTorch ODE solvers.

                        torchode   torchdiffeq   TorchDyn
  Parallel solving         ✓            ✗            ✗
  JIT compilation          ✓            ✗            ✗
  Extensible               ✓            ✗            ✓
  Solver statistics        ✓            ✗            ✗
  Step size controller    PID           I            I
The major architectural difference between torchode and existing ODE solvers for PyTorch is that we treat the batch dimension in batch training explicitly (Table 1). This means that the solver holds a separate state for each instance in a batch, such as initial condition, integration bounds and step size, and is able to accept or reject their steps independently. Batching instances together that need to be solved over different intervals, even of different lengths, requires no special handling in torchode, and even parameters such as tolerances could be specified separately for each problem. Most importantly, our parallel integration avoids unintended interactions between problems in a batch that we explore in Section 4.1.
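
To illustrate the idea of per-instance state, the following is a minimal sketch in plain PyTorch, not torchode's actual internals; all names and shapes are ours. A batched controller can mask the state update so that each problem in the batch advances only if its own step was accepted:

    import torch

    def masked_step_update(y, y_candidate, t, dt, error_norm, safety=0.9, order=5):
        # Sketch only: y, y_candidate have shape (batch, dim); t, dt, error_norm have shape (batch,).
        accept = error_norm <= 1.0                                  # per-instance accept/reject
        y_new = torch.where(accept.unsqueeze(-1), y_candidate, y)   # rejected rows keep their old state
        t_new = torch.where(accept, t + dt, t)
        # Each instance gets its own next step size from its own error estimate.
        dt_new = dt * safety * error_norm.clamp(min=1e-10) ** (-1.0 / (order + 1))
        return y_new, t_new, dt_new, accept

In this sketch, an instance whose step is rejected simply retries with a smaller step on the next iteration while the rest of the batch moves on.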
Two other aspects of torchode's design that are of particular importance in research are extensibility and introspectability. Every component can be re-configured or easily replaced with a custom implementation. By default, torchode collects solver statistics such as the number of total and accepted steps. This mechanism is extensible as well and lets a custom step size controller, for example, return internal state to the user for further analysis without relying on global state.
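
As a hypothetical illustration of this pattern (the class and method names below are ours and do not reflect torchode's actual interface), a custom controller can accumulate its internal decisions and hand them back as part of the returned statistics instead of writing to global state:

    import torch

    class RecordingController:
        # Hypothetical wrapper around an existing step size controller that
        # records every step-size factor so it can be returned to the user.
        def __init__(self, controller):
            self.controller = controller
            self.factors = []

        def adapt_step_size(self, dt, error_norm):
            dt_new, accept = self.controller.adapt_step_size(dt, error_norm)
            self.factors.append((dt_new / dt).detach())
            return dt_new, accept

        def build_stats(self):
            # Merged into the solver's own statistics and returned with the solution.
            return {"step_factors": torch.stack(self.factors)}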
The speed of model training and evaluation constrains computational resources as well as researcher productivity. Therefore, performance is a critical concern for ODE solvers, and torchode takes various implementation measures to optimize throughput, as detailed below and evaluated in Section 4.2. Another way to save time is the choice of time step. It needs to be small enough to control error accumulation but as large as possible to progress quickly. torchode includes a PID controller that is based on analyzing the step size problem in terms of control theory (Söderlind, 2002, 2003). These controllers generalize the integral (I) controllers used in torchdiffeq and TorchDyn and are included in DifferentialEquations.jl and diffrax. In our evaluation in Appendix C, these controllers can save up to 5% of steps if the step size changes quickly.
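
Schematically, such a controller scales the step size by a product of powers of the most recent error estimates. The sketch below is our own simplification in the style of Söderlind's digital filters, not torchode's exact code; the coefficients and clipping bounds are placeholders. With beta2 = beta3 = 0 it reduces to the plain integral controller:

    import torch

    def pid_step_factor(err, err_prev, err_prev2, order,
                        beta1=1.0, beta2=0.0, beta3=0.0, safety=0.9):
        # err, err_prev, err_prev2: scaled error norms of the last three steps, shape (batch,).
        # Non-zero beta2/beta3 blend in older errors (the proportional/derivative parts);
        # the concrete coefficients depend on the chosen filter (Söderlind, 2002).
        k = order + 1.0
        factor = (safety
                  * err.clamp(min=1e-10) ** (-beta1 / k)
                  * err_prev.clamp(min=1e-10) ** (-beta2 / k)
                  * err_prev2.clamp(min=1e-10) ** (-beta3 / k))
        # Bound how quickly the step size may grow or shrink.
        return factor.clamp(min=0.2, max=10.0)

The next step size is then dt_new = dt * factor, computed independently for every instance in the batch.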
What makes torchode fast?
ODE solving is inherently sequential except for efforts on parallel-in-time solving (Gander, 2015). Taking the evaluation time of the dynamics as fixed, the performance of an ODE-based model can therefore only be improved through a more efficient implementation of the solver's looping code, so as to minimize the time between consecutive dynamics evaluations. In addition to the common FSAL and SSAL optimizations for Runge-Kutta methods to reuse intermediate results, torchode avoids, as much as possible, expensive operations such as conditionals evaluated on the host that require a CPU-GPU synchronization, and seeks to minimize the number of PyTorch kernels launched. We rely extensively on operations that combine multiple instructions in one kernel, such as einsum and addcmul, in-place operations, pre-allocated buffers, and fast polynomial evaluation via Horner's rule, which saves half of the multiplications over the naive evaluation.
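
As an illustration of the last point (a sketch, not torchode's actual code), Horner's rule evaluates an interpolation polynomial with one fused multiply-add per coefficient, which maps directly onto kernels such as addcmul:

    import torch

    def horner(coeffs, t):
        # coeffs: coefficient tensors from highest to lowest degree,
        # t: evaluation times, broadcastable against the coefficients.
        result = coeffs[0]
        for c in coeffs[1:]:
            result = torch.addcmul(c, result, t)  # result = result * t + c in a single kernel
        return result

Compared to building up each power of t separately, this roughly halves the number of multiplications and launches only one kernel per coefficient.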