arXiv:2210.02374v1 [cs.PL] 5 Oct 2022
Axon: A Language for Dynamic Shapes
in Deep Learning Graphs
Alexander Collins
NVIDIA
acollins@nvidia.com
Vinod Grover
NVIDIA
vgrover@nvidia.com
Abstract
Axon is a language that enables shape and rank inference for
tensors in Deep Learning graphs. It aims to make shapes
implicit and inferred, in a similar manner to how types are
implicit and inferred in many functional programming lan-
guages. Tensor dimensions are represented by expressions
consisting of symbolic variables, constants, and arithmetic
operators. A tensor shape can be expressed as a sequence of
these dimension expressions, as a symbolic variable, or as
several other shapes appended together. This allows complex
constraints on shapes to be expressed.
Axon is functional in style, with a type system similar to
that of Standard ML, extended to include shape information. It
provides a suite of built-in operators over tensors, including
pointwise arithmetic operators, maps, reductions, loops and
user-defined functions.
We describe a shape inference algorithm based on con-
straint solving which infers information about shapes, from
both shape information provided by the programmer and
the structure of the program. This allows fully automatic in-
ference of the shapes of tensors for complex Deep Learning
graphs.
This approach reduces programmer effort when specify-
ing graphs, as tensor shapes are not explicit, allows compo-
sition of Deep Learning graphs while maintaining input and
output tensor shape compatibility, and aids in automated er-
ror detection by identifying shape mismatches at runtime.
1 Introduction
Deep Learning models can be viewed as constrained func-
tional programs on tensor domains, which only permit side
effects or updates for certain types of models, and usually
only during training. Tensors have a type and a shape: they
are rectangular domains of simple element types. Requiring
the programmer to deal with these shapes can add signifi-
cant complexity to a language.
This paper describes shape and rank inference for tensors
in a language we call Axon, which aims to make shapes im-
plicit and inferred, in a similar manner to how types are
implicit and inferred in many functional programming lan-
guages. Axon allows the individual dimensions of a tensor
to be expressions consisting of symbolic variables, constants,
and arithmetic operators, allowing complex constraints on
shapes to be expressed. Furthermore, a shape can be expressed
as a sequence of these dimension expressions, as a symbolic
variable, or as the result of appending the dimensions of
several other shapes. Rank inference is expressed through
this shape appending.
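To make the three forms of shape concrete, the following is a minimal sketch in Python of one possible representation. It is not Axon's concrete syntax or implementation; every class and function name here (`Const`, `Sym`, `BinOp`, `Dims`, `ShapeVar`, `Append`, `eval_dim`, `eval_shape`) is a hypothetical choice made for illustration.

```python
import operator
from dataclasses import dataclass
from typing import Dict, Tuple, Union

# All names below are hypothetical; this is not Axon's concrete syntax.

@dataclass(frozen=True)
class Const:      # constant dimension, e.g. 224
    value: int

@dataclass(frozen=True)
class Sym:        # symbolic dimension variable, e.g. n
    name: str

@dataclass(frozen=True)
class BinOp:      # arithmetic over dimensions, e.g. 2 * n
    op: str
    lhs: "Dim"
    rhs: "Dim"

Dim = Union[Const, Sym, BinOp]

@dataclass(frozen=True)
class Dims:       # a shape given as a sequence of dimension expressions
    dims: Tuple[Dim, ...]

@dataclass(frozen=True)
class ShapeVar:   # a whole shape of unknown rank, named by a variable
    name: str

@dataclass(frozen=True)
class Append:     # several shapes appended end to end
    parts: Tuple["Shape", ...]

Shape = Union[Dims, ShapeVar, Append]

OPS = {"+": operator.add, "-": operator.sub,
       "*": operator.mul, "/": operator.floordiv}

def eval_dim(d: Dim, dims: Dict[str, int]) -> int:
    """Evaluate a dimension expression once its symbols are bound."""
    if isinstance(d, Const):
        return d.value
    if isinstance(d, Sym):
        return dims[d.name]
    return OPS[d.op](eval_dim(d.lhs, dims), eval_dim(d.rhs, dims))

def eval_shape(s: Shape, dims: Dict[str, int],
               shapes: Dict[str, Tuple[int, ...]]) -> Tuple[int, ...]:
    """Evaluate a shape; its rank is the sum of the ranks of its parts."""
    if isinstance(s, Dims):
        return tuple(eval_dim(d, dims) for d in s.dims)
    if isinstance(s, ShapeVar):
        return tuple(shapes[s.name])
    result: Tuple[int, ...] = ()
    for part in s.parts:
        result += eval_shape(part, dims, shapes)
    return result

# A shape variable "batch" of unknown rank, followed by dimensions n and 2 * n.
example = Append((
    ShapeVar("batch"),
    Dims((Sym("n"), BinOp("*", Const(2), Sym("n")))),
))
concrete = eval_shape(example, {"n": 4}, {"batch": (8, 16)})
```

In this sketch the rank of `example` is not fixed by its definition: it depends on the rank of whatever shape is bound to `batch`, which is the sense in which appending lets rank be inferred rather than declared.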
An inference algorithm is also presented which infers in-
formation about shapes, from both shape information pro-
vided by the programmer and the structure of the program,
via a set of rules for built-in operators. Our system allows
the shapes involved in complex Deep Learning problems to
be inferred automatically, without the programmer having
to express the shape constraints or give concrete shapes,
which may be unknown until the shapes of the program's
inputs are known. Shape information can be inferred from
known or partial shapes of program inputs and from the
structure of the program itself. Automated error detection
also aids programmers by identifying mismatches between
the shapes an operation allows and the shapes inferred for it.
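As an illustration of how a rule for a built-in operator can both propagate partial shape information and detect mismatches, the following Python sketch gives a rule for a two-dimensional matrix multiply. It is an assumption made for exposition, not one of Axon's actual rules: the names `matmul_rule`, `merge_dim` and `ShapeMismatch` are hypothetical, and an unknown dimension is represented simply as `None` rather than as a symbolic expression.

```python
from typing import Optional, Tuple

# Hypothetical sketch: None marks a dimension that is not yet known.
Dim = Optional[int]
Shape2 = Tuple[Dim, Dim]

class ShapeMismatch(Exception):
    """Raised when two known dimensions that must agree do not."""

def merge_dim(a: Dim, b: Dim, what: str) -> Dim:
    """Combine two views of the same dimension, keeping whatever is known."""
    if a is None:
        return b
    if b is None:
        return a
    if a != b:
        raise ShapeMismatch(f"{what}: {a} != {b}")
    return a

def matmul_rule(lhs: Shape2, rhs: Shape2) -> Tuple[Shape2, Shape2, Shape2]:
    """Rule for [m, k] x [k, n] -> [m, n].

    Returns the refined shapes of both inputs and the output shape.
    """
    m, k_left = lhs
    k_right, n = rhs
    k = merge_dim(k_left, k_right, "inner dimension of matmul")
    return (m, k), (k, n), (m, n)

# The inner dimension of the left operand is unknown and is inferred as 128.
refined_lhs, refined_rhs, out = matmul_rule((32, None), (128, 10))

# Two known inner dimensions that disagree are reported as an error.
try:
    matmul_rule((32, 64), (128, 10))
    mismatch_message = None
except ShapeMismatch as err:
    mismatch_message = str(err)
```

Information flows in both directions here: the known inner dimension of the right operand refines the left operand's shape, and the output shape follows from the two inputs.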
The rest of the paper is structured as follows. Section 2
outlines the Axon language. Section 3 presents the syntax of
shape expressions. Section 4 describes how standard Hindley-Milner
type inference [4, 7] is used to generate a set of shape
constraints for the program, and Section 5 describes the
algorithm used to solve sets of these shape constraints. Section 6
provides examples of shape inference in action, and Section 7
provides larger examples for commonly seen Deep Learning
graphs. Section 8 discusses related work, and Section 9
presents our conclusions.
This paper makes the following contributions:
• A taxonomy of the kinds of shapes encountered in
Deep Learning graphs, which we use to guide the design
of the language.
• A functional programming language, which we call
Axon. It allows a programmer to specify symbolic shapes
for the input and output tensors of a graph, permits
arithmetic expressions for the dimensions, and allows
for rank inference by expressing shapes as the composition
of sub-shapes.
• A constraint-based solver for inferring shape information
throughout a program, based on a set of shape
inference rules. Initial constraints are generated from
the structure of a graph, and the rules are used to
reduce these constraints to a fixed point, at which
all shapes either have known rank with constant dimensions
or remain partially unknown, allowing for shapes that
vary at runtime.
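The idea of reducing constraints to a fixed point can be sketched in a few lines of Python. This is a deliberately simplified illustration under stated assumptions, not the solver described in this paper: it handles only dimension variables (no shape variables or appending), supports just two hypothetical constraint forms, `("eq", a, b)` and `("sum", a, b, c)` meaning a + b = c, and the function name `solve` is our own.

```python
from typing import Callable, Dict, List, Tuple, Union

# Hypothetical sketch: an int is a constant dimension, a str a dimension variable.
Term = Union[int, str]
Constraint = Tuple  # ("eq", a, b) or ("sum", a, b, c), meaning a + b == c

def solve(constraints: List[Constraint]) -> Callable[[Term], Term]:
    """Apply the rules until a full pass over the constraints changes nothing."""
    subst: Dict[str, Term] = {}

    def resolve(t: Term) -> Term:
        # Follow bindings until reaching a constant or an unbound variable.
        while isinstance(t, str) and t in subst:
            t = subst[t]
        return t

    def unify(a: Term, b: Term) -> bool:
        # Returns True if a new binding was recorded.
        a, b = resolve(a), resolve(b)
        if a == b:
            return False
        if isinstance(a, str):
            subst[a] = b
        elif isinstance(b, str):
            subst[b] = a
        else:
            raise ValueError(f"shape mismatch: {a} != {b}")
        return True

    changed = True
    while changed:  # iterate to a fixed point
        changed = False
        for c in constraints:
            if c[0] == "eq":
                changed |= unify(c[1], c[2])
            elif c[0] == "sum":
                a, b, total = (resolve(t) for t in c[1:])
                known = [isinstance(t, int) for t in (a, b, total)]
                # The rule only fires once two of the three terms are constants.
                if known[0] and known[1]:
                    changed |= unify(total, a + b)
                elif known[0] and known[2]:
                    changed |= unify(b, total - a)
                elif known[1] and known[2]:
                    changed |= unify(a, total - b)
    return resolve

lookup = solve([
    ("sum", "h", "pad", "out"),  # cannot fire until h and pad are known
    ("eq", "h", 28),
    ("eq", "pad", 4),
    ("eq", "w", "w2"),           # relates two variables without fixing either
])
```

At the fixed point `out` has been reduced to a constant, while `w` and `w2` are known to be equal but remain symbolic, corresponding to a dimension whose value is only available at runtime.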