
2 Generalizable Embedding Table Placement Problem
The embedding table placement problem seeks a device placement$^1$ of all the tables such that the overall cost (in terms of execution time) is minimized (we provide a background for the distributed training of recommendation models in Appendix A.1). Consider $M$ embedding tables $\{\mathbf{e}_1, \mathbf{e}_2, \ldots, \mathbf{e}_M\}$ and $D$ devices, where $\mathbf{e}_i \in \mathbb{R}^N$ denotes the table features that characterize the embedding lookup patterns. In our work, we use 21 table features, including hash size, dimension, table size, pooling factor, and distribution (their definitions are provided in Appendix A.2). A placement $\mathbf{a} = [a_1, a_2, \ldots, a_M]$, where $a_i \in \{1, 2, \ldots, D\}$, assigns each table to a device. Let $c(\mathbf{a})$ denote the cost measured on GPUs. The goal of embedding table placement is to find the $\mathbf{a}$ such that $c(\mathbf{a})$ is minimized. Due to the NP-hardness of the partition problem [12], identifying the exact solution demands extensive computational overhead. Thus, the state-of-the-art algorithms often approximate the optimal partition via sampling with RL [24, 13]. However, sampling remains expensive because obtaining $c(\mathbf{a})$ requires running operations on GPUs. Given that the embedding tables and the available devices can frequently change, we wish to approximate the best $\mathbf{a}$ without GPU execution.
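To make the combinatorial difficulty concrete: with $M$ tables and $D$ devices there are $D^M$ candidate placements, so exact enumeration is only feasible for toy sizes. The sketch below (illustrative only, not the paper's method) shows such a brute-force search; `cost_fn` is an assumed stand-in for measuring $c(\mathbf{a})$ on GPUs.

```python
import itertools
from typing import Callable, List, Sequence

def exhaustive_placement(num_tables: int,
                         num_devices: int,
                         cost_fn: Callable[[Sequence[int]], float]) -> List[int]:
    """Enumerate all D**M placements and return the cheapest one.

    Only feasible for tiny M, which is why practical approaches resort to
    sampling (e.g., with RL) instead of exact search.
    """
    best_a, best_cost = None, float("inf")
    for a in itertools.product(range(1, num_devices + 1), repeat=num_tables):
        cost = cost_fn(a)  # in the paper, c(a) is measured by running on GPUs
        if cost < best_cost:
            best_a, best_cost = list(a), cost
    return best_a
```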
Motivated by this, we study the generalizable embedding table placement (GETP) problem. Let $\mathcal{E}$ be the space of all the embedding tables. A placement task can be denoted as $T_i = (\mathcal{E}_i, D_i)$, where $\mathcal{E}_i \subseteq \mathcal{E}$ is a set of tables, and $D_i$ is the number of devices. Given $N_{\text{train}}$ training tasks $\mathcal{T}_{\text{train}} = \{T_1, T_2, \ldots, T_{N_{\text{train}}}\}$ and $N_{\text{test}}$ testing tasks $\mathcal{T}_{\text{test}} = \{T_1, T_2, \ldots, T_{N_{\text{test}}}\}$, the goal is to train a placement policy based on $\mathcal{T}_{\text{train}}$ (GPU execution is allowed during training) such that the learned policy can minimize the costs for the tasks in $\mathcal{T}_{\text{test}}$ without GPU execution.
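For concreteness, one possible in-code representation of a task $T_i = (\mathcal{E}_i, D_i)$ is sketched below; the `PlacementTask` dataclass and the randomly generated feature arrays are purely illustrative assumptions, not part of the paper.

```python
from dataclasses import dataclass
from typing import List
import numpy as np

@dataclass
class PlacementTask:
    """One GETP task T_i = (E_i, D_i): a set of tables and a device count."""
    table_features: np.ndarray  # shape (M_i, N), e.g., N = 21 features per table
    num_devices: int            # D_i

# Dummy train/test tasks: the policy is trained on the training tasks
# (GPU execution allowed) and must generalize to the unseen testing tasks
# without any GPU execution.
rng = np.random.default_rng(0)
train_tasks: List[PlacementTask] = [
    PlacementTask(rng.random((50, 21)), num_devices=4) for _ in range(8)
]
test_tasks: List[PlacementTask] = [
    PlacementTask(rng.random((80, 21)), num_devices=8) for _ in range(2)
]
```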
3 DreamShard Framework
[Figure 2: DreamShard framework. The agent interacts with the estimated MDP, which is trained with the cost data collected from GPUs.]
We present DreamShard, an RL framework based on an estimated MDP, to tackle the GETP problem. An overview of the framework is shown in Figure 2. The key idea is to formulate the table placement process as an MDP (Section 3.1) and train a cost network to estimate its states and rewards (Section 3.2). A policy network with a tailored generalizable network architecture is trained by efficiently interacting with the estimated MDP (Section 3.3). The two networks are updated iteratively to improve the state/reward estimation and the placement policy.
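The alternation between the two networks could be organized roughly as follows; all interfaces (`collect_costs_on_gpus`, `cost_net.fit`, `policy_net.update`) are hypothetical placeholders for the updates described in Sections 3.2 and 3.3.

```python
from typing import Callable, Sequence

def train_dreamshard(train_tasks: Sequence,
                     cost_net,
                     policy_net,
                     collect_costs_on_gpus: Callable,
                     num_iterations: int = 10):
    """Sketch of the iterative update of the cost and policy networks."""
    for _ in range(num_iterations):
        # Roll out the current policy on real hardware to gather cost data.
        cost_data = collect_costs_on_gpus(train_tasks, policy_net)
        # Fit the cost network so it can estimate per-step states and rewards.
        cost_net.fit(cost_data)
        # Improve the policy purely inside the estimated MDP (no GPU execution).
        policy_net.update(estimated_mdp=cost_net, tasks=train_tasks)
    return policy_net
```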
3.1 MDP Formulation
Given embedding tables $\{\mathbf{e}_1, \mathbf{e}_2, \ldots, \mathbf{e}_M\}$ and $D$ devices, we aim to generate a placement $\mathbf{a} = [a_1, a_2, \ldots, a_M]$. The key idea is to place the tables one by one at each step, where the state characterizes the tables that have been placed so far, the action is the device ID, and the reward represents the execution time on GPUs. Specifically, at a step $t$, the state $s_t = \{s_{t,d}\}_{d=1}^{D}$ is all the table features of the tables placed on all the devices, where $s_{t,d} = \{\mathbf{e}_i \mid i \in \mathcal{P}_d\}$ denotes all the table features corresponding to device $d$ ($\mathcal{P}_d$ is the set of table IDs that have been placed on device $d$). We further augment the raw features with cost features, which are obtained by collecting the operation computation and communication times from GPUs (Appendix A.3 provides a comprehensive analysis of the cost features). Formally, the augmented state is defined as $\widetilde{s}_t = \{s_t, \{q_{t,d}\}_{d=1}^{D}\}$, where $q_{t,d} \in \mathbb{R}^3$ has three elements representing the forward computation time, backward computation time, and backward communication time for the current operation on device $d$ (we provide detailed explanations of why forward communication time is excluded in Appendix A.4). We find that the augmented cost features can significantly boost the performance, evidenced by the ablations in Table 3. The action $a_t \in \mathcal{A}_t$ is an integer specifying the device ID, where $\mathcal{A}_t$ is the set of legal actions at step $t$. A device ID is considered legal if placing the current table on the corresponding device does not cause a memory explosion. The reward $r_t$ is $0$ for all the intermediate steps, and the reward at the final step $M$ is the negative of the cost, i.e., $r_M = -c(\mathbf{a})$, which encourages the agent to achieve a lower cost.
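A minimal sketch of this MDP is given below, omitting the augmented cost features $q_{t,d}$ and assuming simple helpers for per-table memory accounting and GPU cost measurement (both are illustrative assumptions, not the paper's implementation).

```python
class PlacementMDP:
    """Sketch of the table-placement MDP of Section 3.1.

    `table_memory` and `measure_cost` are assumed stand-ins for memory
    accounting and GPU cost measurement; devices are indexed 0..D-1 here.
    """

    def __init__(self, table_features, table_memory, num_devices,
                 device_capacity, measure_cost):
        self.features = table_features      # e_1, ..., e_M
        self.memory = table_memory          # per-table memory footprint
        self.D = num_devices
        self.capacity = device_capacity
        self.measure_cost = measure_cost    # callable: placement -> c(a) on GPUs
        self.placement = []                 # a_1, ..., a_t chosen so far

    def state(self):
        # s_t: table features grouped by the device each placed table sits on.
        return [[self.features[i] for i, d in enumerate(self.placement) if d == dev]
                for dev in range(self.D)]

    def legal_actions(self):
        # A device is legal if adding the current table fits in its memory.
        t = len(self.placement)  # index of the next table to place
        used = [sum(self.memory[i] for i, d in enumerate(self.placement) if d == dev)
                for dev in range(self.D)]
        return [dev for dev in range(self.D)
                if used[dev] + self.memory[t] <= self.capacity]

    def step(self, action):
        self.placement.append(action)
        done = len(self.placement) == len(self.features)
        # r_t = 0 for intermediate steps; r_M = -c(a) at the final step.
        reward = -self.measure_cost(self.placement) if done else 0.0
        return self.state(), reward, done
```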
$^1$In this work, we focus on GPU devices, where all the GPU devices are identical, which is the most common configuration in our production. We defer the mixed scenarios of both GPUs and CPUs to future work.