
We approach automated parallelism with a different philosophy. Our automated SPMD partitioner is
closely integrated with our PartIR compiler stack. PartIR [30] provides our partitioner with the ability
to validate partitioning decisions, expose a host of static analyses about the module, automatically
propagate sharding decisions across a module, and produce analytical cost-model estimates. The latter
are obtained using a simulator of a lowered and communication-optimized partitioned module at the
compiler level. By contrast, ILP-based encodings [37] attempt to address scalability issues in large
graphs by, e.g., heuristically grouping operators to prune the search space effectively; furthermore,
communication optimizations applied after the ILP solve cannot be captured when deciding the
parallel plan. FlexFlow [13] addresses scalability by making the simulator in its MCMC search
incremental and relies on pre-defined layers/block-operators (e.g. Multi-Head Attention). On the
other hand, our approach can partition any array operation of the low-level tensor IR (XLA) that
high-level NN libraries compile to. As we show, our tight compiler integration allows our search to be
data- and size-efficient and to run routinely as part of the iterative workflows of our users (typically researchers).
In this work, we extend our automated partitioner, Automap [27], to support partitioning a model
across multi-dimensional device topologies and to discover expert-level composite SPMD partitioning
strategies. Our contribution is the design of a goal-oriented Monte Carlo tree search (MCTS) [3] that
decomposes the problem into smaller sub-problems. Furthermore, we incorporate a partitioning-specific
compiler analysis into the MCTS that reduces the number of nodes and edges in the tree, improving the
search’s robustness. We show that our partitioner discovers composite expert-level SPMD strategies
for common models, such as Transformers and GraphNets. Moreover, it produces strategies significantly
better than human-written ones for models such as UNet, for which no established strategy is available.
2 Background
2.1 Logical Meshes and Composite Strategies
Figure 1: The composite strategy of batch and model parallelism over mesh {batch:N, model:M}. We also show communication patterns that may emerge: on the left, possible communication along the model axis (e.g., Megatron activation reductions), and on the right, communication along the batch axis (e.g., gradient reductions). The color coding denotes the unique parameter shards that each device holds; e.g., all devices along the batch axis hold the same shard of parameters.
Leveraging composite partitioning techniques has enabled training of recent large models [2, 22, 10].
The main idea is to structure the available accelerator devices into an n-dimensional logical mesh
(which will typically, but not necessarily, correspond to the physical communication topology). For
instance, we may view 32 TPU devices as a 4x8 mesh or a system of 2 GPU servers with 16 GPUs
each as a 2x16 mesh. Given such a logical mesh of available devices, e.g., a 2D mesh, a conventional
strategy is to perform batch parallelism over one axis and parameter sharding (model parallelism)
over the other. Figure 1 graphically depicts this strategy. ZeRO-style sharding of
the optimizer [23] (on top of batch parallelism, and possibly also parameter sharding) is simply a
different stage that shards the optimizer parameters along the axis used for batch parallelism.
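For concreteness, below is a minimal JAX sketch of the Figure 1 strategy, assuming 32 devices arranged as a 4x8 {batch, model} mesh; the axis names, array shapes, and the two-layer function are illustrative only, not part of our system.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange 32 devices as a 4x8 logical mesh with named axes (assumes
# jax.devices() returns 32 devices, e.g. on a TPU pod slice).
devices = np.array(jax.devices()).reshape(4, 8)
mesh = Mesh(devices, axis_names=("batch", "model"))

# Activations: sharded along `batch` (batch parallelism).
x = jax.device_put(np.ones((32, 1024), np.float32),
                   NamedSharding(mesh, P("batch", None)))
# Parameters: replicated along `batch`, sharded along `model` (model
# parallelism) -- the color coding of Figure 1.
w1 = jax.device_put(np.ones((1024, 4096), np.float32),
                    NamedSharding(mesh, P(None, "model")))
w2 = jax.device_put(np.ones((4096, 1024), np.float32),
                    NamedSharding(mesh, P("model", None)))

@jax.jit
def layer(x, w1, w2):
    y = x @ w1          # output sharded over (batch, model)
    # Contracting y's model-sharded dimension with w2 yields partial sums,
    # so the compiler inserts a reduction along `model` -- the Megatron-style
    # activation reduction on the left of Figure 1.
    return y @ w2
```

Under this placement, gradient reductions for w1 and w2 during training become reductions along the batch axis, matching the right side of Figure 1.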
Conceptually, each stage of a composite strategy, like the above, optimizes for specific objectives. For
example, batch parallelism and parameter sharding typically improve runtime (while also reducing
memory requirements); whereas ZeRO-style optimizer sharding aims mainly towards improving
memory (but may improve runtime too, as the typically memory-bound vector operations of the
optimizer update are sharded, as was already observed in precursor work [33]). ZeRO “stage-2”
optimizer sharding may not increase communication cost since it replaces AllReduce operations with
pairs of parameter AllGather and gradient ReduceScatter. ZeRO “stage-3” aims to further improve
memory by introducing separate parameter AllGathers for the forward and the backward pass of a
model; hence it may slightly increase runtime in exchange for keeping smaller tensors live across the forward
and backward computation. Note that, in our setting, the logical mesh will be given ahead of time by
the user; the partitioner’s task is to discover composite strategies like those described above based on
user-provided objectives.
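To illustrate why stage-2 need not increase communication cost, here is a toy NumPy simulation (no real collectives; the device count and gradient shapes are arbitrary) of the identity it relies on: an AllReduce is equivalent to a ReduceScatter followed by an AllGather, which is in fact how ring AllReduce is commonly implemented.

```python
import numpy as np

# Toy simulation of the collective identity behind ZeRO stage-2:
# AllReduce(g_1..g_n) == AllGather(ReduceScatter(g_1..g_n)).
n_dev, dim = 4, 8
rng = np.random.default_rng(0)
grads = [rng.standard_normal(dim) for _ in range(n_dev)]  # per-device gradients

# Plain batch parallelism: AllReduce -- every device holds the full summed gradient.
all_reduced = np.sum(grads, axis=0)

# ZeRO stage-2: ReduceScatter -- device i keeps only shard i of the summed
# gradient and updates its slice of the optimizer state, ...
shards = np.split(np.sum(grads, axis=0), n_dev)
# ... then an AllGather of the updated parameter shards restores full parameters.
gathered = np.concatenate(shards)

assert np.allclose(gathered, all_reduced)
```

Stage-2 thus performs only the ReduceScatter on gradients and defers the AllGather to the updated parameters, moving the same total bytes as the single AllReduce it replaces.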