
the learning problem is simply to fit the Gaussian's mean. Ho et al. (2020) parameterize this mean as an affine function of $\mathbb{E}[x_0 \mid x_t, y]$ and, by doing so, reduce the problem to fitting an estimator of $x_0$ from $x_t$ and $y$ with the loss
$$L(\theta) = \sum_{t=1}^{T} \mathbb{E}_{q(x_0, x_t, y)}\, \lambda(t) \cdot \left\| \hat{x}_\theta(x_t, y, t) - x_0 \right\|_2^2. \tag{4}$$
Ho et al. (2020); Song et al. (2021b) show that there exists a weighting function $\lambda(t)$ such that this loss is (the negative of) a lower bound on the marginal log-likelihood $\log p_\theta(x_0 \mid y)$. We instead use uniform weights $\lambda(t) = 1$, which has been shown to give better results in practice (Ho et al., 2020).
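As a concrete illustration, the following is a minimal sketch of one Monte-Carlo estimate of the loss in Eq. (4) with $\lambda(t) = 1$. It assumes a variance-preserving noising process with precomputed cumulative products `alpha_bar`; the function and argument names and the $x_0$-prediction interface `model(xt, y, t)` are illustrative, not a prescribed implementation.

```python
import torch

def diffusion_loss(model, x0, y, alpha_bar, T):
    """One Monte-Carlo sample of Eq. (4) with uniform weights lambda(t) = 1.

    Names and the variance-preserving noising step are illustrative assumptions.
    """
    b = x0.shape[0]
    # Sample a diffusion timestep uniformly for each example in the batch.
    t = torch.randint(1, T + 1, (b,), device=x0.device)
    # Noising step (variance-preserving assumption):
    #   x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps,  eps ~ N(0, I).
    abar = alpha_bar[t - 1].view(b, *([1] * (x0.dim() - 1)))
    eps = torch.randn_like(x0)
    xt = abar.sqrt() * x0 + (1.0 - abar).sqrt() * eps
    # The network estimates x_0 from (x_t, y, t); the loss is a plain MSE.
    return ((model(xt, y, t) - x0) ** 2).mean()
```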
2.2. Transformer Architecture
Figure 2 outlines our neural architecture for $\hat{x}_\theta$, which is run once for every diffusion time step $t$. Its inputs are $x_t$, $y$, and the diffusion timestep $t$. The (noisy) value of each latent graphical model node is represented in $x_t$ and, similarly, $y$ contains the value of each observed graphical model node.
We use linear projections to embed all of these values, concatenating the embeddings of all latent and observed nodes to obtain, for an $n$-node graphical model and $d$-dimensional embedding space, an $n \times d$ array of embedding “tokens”. We add learned “node embeddings” to these tokens to identify which graphical model node each corresponds to, and also add learned observation embedding vectors for tokens corresponding to observed nodes. The resulting $n \times d$ array is passed through a stack of self-attention (Vaswani et al., 2017) and ResNet (He et al., 2016) blocks, as summarized in Figure 2, with the ResNet blocks taking an embedding of the diffusion timestep $t$ as an additional input. The timestep embedder, ResNet blocks, and self-attention modules are all identical to those of Song et al. (2021a), except that we replace convolutions with per-token linear projections. The tokens corresponding to non-observed graphical model nodes are then fed through a learned linear projection to produce an output for each.
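The PyTorch sketch below summarizes this token construction and block stack. It is a simplified stand-in: the scalar-per-node values, the generic ResNet-style block, and `nn.MultiheadAttention` are assumptions for illustration, whereas the actual modules follow Song et al. (2021a) as stated above; all names are illustrative.

```python
import torch
import torch.nn as nn

class DenoiserSketch(nn.Module):
    """Simplified sketch of the denoiser: per-node value embeddings, learned node
    and observation embeddings, and a stack of timestep-conditioned ResNet +
    self-attention blocks with per-token linear projections."""

    def __init__(self, n_nodes, d, n_blocks=4):
        super().__init__()
        self.value_embed = nn.Linear(1, d)                        # embed each node's value
        self.node_embed = nn.Parameter(torch.randn(n_nodes, d))   # identifies the node
        self.obs_embed = nn.Parameter(torch.randn(n_nodes, d))    # added for observed nodes
        self.t_embed = nn.Sequential(nn.Linear(1, d), nn.SiLU(), nn.Linear(d, d))
        self.blocks = nn.ModuleList([
            nn.ModuleDict({
                "resnet": nn.Sequential(nn.LayerNorm(d), nn.Linear(d, d),
                                        nn.SiLU(), nn.Linear(d, d)),
                "attn": nn.MultiheadAttention(d, num_heads=1, batch_first=True),
            }) for _ in range(n_blocks)
        ])
        self.out_proj = nn.Linear(d, 1)                           # per-token output head

    def forward(self, values, observed_mask, t):
        # values: (batch, n_nodes, 1) holding x_t for latent and y for observed nodes
        # observed_mask: (n_nodes,) bool; t: (batch,) integer diffusion timesteps
        e = self.value_embed(values) + self.node_embed
        e = e + observed_mask.float().unsqueeze(-1) * self.obs_embed
        te = self.t_embed(t.float().unsqueeze(-1)).unsqueeze(1)   # (batch, 1, d)
        for blk in self.blocks:
            e = e + blk["resnet"](e + te)                         # ResNet block, conditioned on t
            attn_out, _ = blk["attn"](e, e, e)                    # self-attention over node tokens
            e = e + attn_out
        return self.out_proj(e)[:, ~observed_mask]                # outputs for latent nodes only
```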
The self-attention layers are solely responsible for controlling interactions between embeddings, and therefore correlations between variables in the modelled distribution $p_\theta(x_0 \mid y)$. Inside the self-attention, each embedding is projected into a query vector, a key vector, and a value vector, all in $\mathbb{R}^d$. Stacking these vectors for all embeddings yields the matrices $Q, K, V \in \mathbb{R}^{n \times d}$ (given an $n$-node graphical model). The output of the self-attention is calculated as
$$e_{\mathrm{out}} = e_{\mathrm{in}} + WV = e_{\mathrm{in}} + \operatorname{softmax}\!\left(QK^T\right)V \tag{5}$$
where the addition of the self-attention layer's input $e_{\mathrm{in}} \in \mathbb{R}^{n \times d}$ corresponds to a residual connection. We note that $QK^T$ yields a pairwise interaction matrix, which lets us impose an additional attention mask $M$ before calculating the output $e_{\mathrm{out}} = e_{\mathrm{in}} + \operatorname{softmax}\!\left(M \odot QK^T\right)V$. This masking interface is central to structuring the flow of information between graphical model nodes in Section 3.1.
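The sketch below implements the masked self-attention of Eq. (5) for a single head. Eq. (5) writes the mask as an elementwise product $M \odot QK^T$; the code uses the common additive realization (setting blocked entries to $-\infty$) so that masked pairs receive exactly zero attention weight, and it omits multi-head and scaling details. The function and argument names are illustrative.

```python
import torch

def masked_self_attention(e_in, Wq, Wk, Wv, mask=None):
    """Single-head self-attention of Eq. (5) with an optional structure mask.

    e_in: (n, d) token embeddings; Wq, Wk, Wv: (d, d) projection matrices;
    mask: (n, n) bool, True where the query token may attend to the key token.
    """
    Q, K, V = e_in @ Wq, e_in @ Wk, e_in @ Wv      # (n, d) each
    scores = Q @ K.T                                # pairwise interaction matrix QK^T
    if mask is not None:
        # Additive realization of the mask: blocked pairs get -inf, hence
        # exactly zero weight after the softmax.
        scores = scores.masked_fill(~mask, float("-inf"))
    W = torch.softmax(scores, dim=-1)               # attention weights
    return e_in + W @ V                             # residual connection, Eq. (5)
```

Setting an entry of the mask to False blocks any information flow from the corresponding key token to the query token within that layer, which is the interface the structural masks of Section 3.1 build on.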
2.3. Graphical Models
GSDM leverages problem structure described in the
form of a graphical model. There is considerable flex-
ibility in the specification of this graphical structure
and we allow for both directed and undirected graph-
ical models. A directed graphical model describes a
joint distribution over
x= [x1, . . . , xn]
with the density
p(x) = Qn
i=1 pi(xi|parents(xi))
. This may be natural to
use if the problem of interest can be described by a causal
model. This is the case in the BCMF example in Figure 1,
where the forward model is a matrix multiplication and
we can use the matrix multiplication’s compute graph as a
graphical model. If the data is simulated and source code is
available then we can automatically extract the simulator’s
compute graph as a graphical model (Appendix L).
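To make this concrete, the snippet below sketches one way to record such a directed structure for a small BCMF-style problem in which the observed matrix is the product of two latent factor matrices, keeping only the parent sets implied by the compute graph of $X = UV^{\top}$ (entry $X_{ij}$ depends on row $i$ of $U$ and row $j$ of $V$). The dictionary-of-parents representation and the node names are hypothetical, not a prescribed interface.

```python
def bcmf_parents(n, m, k):
    """Directed graphical model structure for X = U @ V.T, U: (n, k), V: (m, k).

    Returns a dict mapping each node name to its list of parents; latent
    factor entries have no parents, and each observed X[i, j] has the
    corresponding rows of U and V as parents.
    """
    parents = {}
    for i in range(n):
        for l in range(k):
            parents[f"U_{i}_{l}"] = []
    for j in range(m):
        for l in range(k):
            parents[f"V_{j}_{l}"] = []
    for i in range(n):
        for j in range(m):
            parents[f"X_{i}_{j}"] = (
                [f"U_{i}_{l}" for l in range(k)] + [f"V_{j}_{l}" for l in range(k)]
            )
    return parents
```

The joint density then factorizes as $p(x) = \prod_{i} p_i(x_i \mid \mathrm{parents}(x_i))$, with each $X_{ij}$ node's conditional determined by the matrix multiplication forward model.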
Alternatively, an undirected graphical model uses the density $p(x) \propto \prod_{j=1}^{m} f_j(\mathrm{vertices}(j))$, where $\mathrm{vertices}(j)$ are the vertices connected to factor $j$ and $f_j$ maps their values to a scalar. This is a natural formulation if the problem is defined by constraints on groups of nodes, e.g. for Sudoku with row, column and block constraints (Appendix C). Finally, the graphical model can combine directed and undirected components, using a density $p(x) \propto \prod_{i=1}^{n} p_i(x_i \mid \mathrm{parents}(x_i)) \prod_{j=1}^{m} f_j(\mathrm{vertices}(j))$. We use this in our graphical model for sorting (Appendix C), which combines a causal forward model with constraints.
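As an illustration, one plausible way to write down such a combined structure for a small sorting problem is sketched below: directed edges from every unsorted input to every sorted output (the causal forward model), plus pairwise ordering factors between neighbouring outputs (the constraints). Only the structure is recorded, no link functions. The representation and naming are hypothetical and not necessarily the exact model of Appendix C.

```python
def sorting_graph(n):
    """Combined directed + undirected structure for sorting n numbers.

    Returns (parents, factors): `parents` maps each node to its parent list
    (directed part); `factors` is a list of node groups, one per undirected
    constraint factor (here, ordering constraints on neighbouring outputs).
    """
    inputs = [f"x_{i}" for i in range(n)]      # latent unsorted values
    outputs = [f"y_{i}" for i in range(n)]     # observed sorted values
    parents = {x: [] for x in inputs}
    # Causal forward model: each sorted output may depend on every input.
    parents.update({y: list(inputs) for y in outputs})
    # Constraint factors: y_0 <= y_1 <= ... <= y_{n-1}, one factor per pair.
    factors = [[outputs[i], outputs[i + 1]] for i in range(n - 1)]
    return parents, factors
```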
We emphasise that GSDM does not need the link functions (i.e. the form of each $p_i$ and $f_j$) to be specified as long as data is available, which is desirable as they are often intractable or Dirac in practice. Also, while the selection of a graphical model for data can be subjective, we find in Section 4.3 that GSDM is not sensitive to small changes in the specification of the graphical model and that there can be multiple modeling perspectives yielding similar GSDM performance. In general, we use the most intuitive graphical model that we can come up with for each problem, whether it is directed, undirected, or a combination.
2.4. Permutation Invariance
Large probabilistic models often contain permutation invari-
ance, in the sense that the joint probability density
q(x0)
is
invariant to permutations of certain indices (Bloem-Reddy
& Teh,2020). For example the matrix multiplication in Fig-
ure 1 is invariant with respect to permutations of any of the
plate indices.
1
If the joint probability density is invariant to a
particular permutation, this can be enforced in a distribution
1
In general, plate notation implies permutation invariance as
long as no link functions depend on the plate indices themselves.