Graphically Structured Diffusion Models

Christian Weilbach 1   William Harvey 1   Frank Wood 1
Abstract
We introduce a framework for automatically defining and learning deep generative models with problem-specific structure. We tackle problem domains that are more traditionally solved by algorithms such as sorting, constraint satisfaction for Sudoku, and matrix factorization. Concretely, we train diffusion models with an architecture tailored to the problem specification. This problem specification should contain a graphical model describing relationships between variables, and often benefits from explicit representation of subcomputations. Permutation invariances can also be exploited. Across a diverse set of experiments we improve the scaling relationship between problem dimension and our model's performance, in terms of both training time and final accuracy. Our code can be found at https://github.com/plai-group/gsdm.
1. Introduction
A future in which algorithm development is fully transformed from a challenging and labour-intensive task (Cormen, 2009; Marsland, 2009; Russell & Norvig, 2010; Williamson & Shmoys, 2011) into a fully automatable process is seemingly close at hand. With prompt engineering, large language models like GPT (Brown et al., 2020) and now ChatGPT have been shown to be capable of code completion and even full algorithm development from natural language task descriptions (Chen et al., 2021; Ouyang et al., 2022). At the same time, significant advances in generative deep learning (Ho et al., 2020; Song et al., 2021c; Ho et al., 2022), AutoML (Hutter et al., 2018), and few-shot learning (Brown et al., 2020) have made it possible to learn, from data, flexible input-output mappings that generalize from ever smaller amounts of data. This approach has spawned modern aphorisms from Karpathy (2017) like “Gradient descent can write better software than you. Sorry!”, appropriate attempts, in our opinion, to re-brand deep learning as differentiable programming (Baydin et al., 2017), and arguably even a new industry called “Software 2.0” in which one “programs by example” (Karpathy, 2017).

1 Department of Computer Science, University of British Columbia, Vancouver, Canada. Correspondence to: Christian Weilbach <weilbach@cs.ubc.ca>.

Proceedings of the 40th International Conference on Machine Learning, Honolulu, Hawaii, USA. PMLR 202, 2023. Copyright 2023 by the author(s).
There remains, however, a chasm between these two approaches, roughly delineated along the symbolic vs. connectionist divide. Symbolically expressed algorithms can and often do generalize perfectly across all inputs and exhibit runtimes that are typically input “size” dependent. Software 2.0 algorithms struggle to generalize outside of their training data and thus are most often deployed in settings where copious training data is available, the so-called “big-data” regime. Most such “neural-network algorithms” have runtimes that are not size dependent and, as a result, cannot generalize in the same fashion as symbolically expressed algorithms.

Efforts to bring these two approaches closer together (Chaudhuri et al., 2021) often get lumped together under the moniker “neuro-symbolic” methods. The general shape of these methods, so to speak, is to impose some aspect of symbolic reasoning on either the structure or computation performed by a connectionist architecture. Our work can be seen as a significantly novel methodological contribution to this body of work.
We contribute a generic methodology for advantageously imposing task-specific symbolic structure into diffusion models and use it to demonstrate algorithm learning from data in several small-scale but foundational tasks across the algorithmic complexity spectrum. Specifically, our approach consumes a graphical model “sketch” that putatively could describe the joint data generative process. This sketch consists only of nodes for variables, edges between them, and optionally permutation invariances. We combine this information with an otherwise generic diffusion process (Ho et al., 2020), using the edges to advantageously constrain the transformer attention mechanisms (Vaswani et al., 2017) and the permutation invariances to determine when parameters within our architecture can be shared. Compared to our neural baselines we improve the scaling of computational cost with problem dimension in most cases, and the scaling of problem performance with dimension in all cases.
Figure 1: An application of our framework to binary-continuous matrix factorization. In the first panel the computational graph of the multiplication of the continuous matrix $A \in \mathbb{R}^{3 \times 2}$ and the binary matrix $R \in \mathbb{R}^{2 \times 3}$ is expanded as a probabilistic graphical model in which intermediate products $C$ are summed to give $E = AR$. This graph is used to create a structured attention mask $M$, in which we highlight 1's with the color of the corresponding graphical model edge (or white for self-edges). In the third panel the projection into the sparsely-structured neural network guiding the diffusion process is illustrated. The bottom shows the translation of permutation invariances of the probability distribution into shared embeddings, as detailed in Section 2.4.
As a running example to keep in mind throughout the paper, consider being given the task of developing a novel matrix factorization algorithm, one which takes a real non-negative valued matrix as input and outputs a distribution over two matrix factors, one constrained to be binary valued, the other constrained to be non-negative. The most traditional approach is to painstakingly hand-develop, through intellectual willpower, some new algorithm like Gram-Schmidt, which may not exist and might take an entire career to develop. A more modern approach, to which we compare, is to symbolically specify a model describing a joint data generating process and employ a generic inference algorithm like MCMC (Wingate et al., 2011). Such a model is usually much easier to specify, but the resulting “inversion algorithm,” running a generic inference algorithm at test time, trades sure generalization for worst-case infinite runtime. Alternatively, one could generate a large training dataset from such a generative description, then hand-architect and train a deep neural network to learn the desired inversion algorithm, software 2.0 style (Le et al., 2017b). This is slow to develop and train, usually requiring architectural innovation, but constant-time fast at test time, albeit with likely poor algorithm-style generalization. Our approach, most like that of Weilbach et al. (2020), strikes a middle ground. We adopt the software 2.0 approach but provide a generic recipe for specializing a generic and powerful diffusion-based network architecture that trains quickly, generalizes reliably, and whose runtime scales with problem size.
2. Background
2.1. Conditional Diffusion Models
Defining $x_0$ to be data sampled from a data distribution $q(x_0)$, a diffusion process constructs a chain $x_{0:T}$ with noise added at each stage by the transition distribution

$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\, \sqrt{1-\beta_t}\, x_{t-1},\, \beta_t I\big) \qquad (1)$$

leading to the joint distribution

$$q(x_{0:T}) = q(x_0) \prod_{t=1}^{T} q(x_t \mid x_{t-1}). \qquad (2)$$
The sequence $\beta_{1:T}$ controls the amount of noise added at each step and, along with $T$ itself, is chosen to be large enough that the marginal $q(x_T \mid x_0)$ resulting from Equation (1) is approximately a unit Gaussian for any $x_0$.
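To make the forward process concrete, the following sketch samples a noising chain step by step according to Equation (1). The linear schedule for $\beta_{1:T}$ is an assumption for illustration only and is not necessarily the schedule used in the paper.

```python
import torch

def forward_diffusion_chain(x0, betas):
    """Sample x_1, ..., x_T from the forward process in Equation (1):
    q(x_t | x_{t-1}) = N(x_t; sqrt(1 - beta_t) * x_{t-1}, beta_t * I)."""
    chain = [x0]
    x = x0
    for beta_t in betas:
        x = torch.sqrt(1.0 - beta_t) * x + torch.sqrt(beta_t) * torch.randn_like(x)
        chain.append(x)
    return chain  # [x_0, x_1, ..., x_T]

# With T and beta_{1:T} chosen large enough, x_T is close to a unit Gaussian.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)  # assumed linear schedule, for illustration
x0 = torch.randn(39)                   # e.g. a flattened 39-node BCMF instance
xT = forward_diffusion_chain(x0, betas)[-1]
```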
This diffusion process inspires a diffusion model (Sohl-Dickstein et al., 2015; Ho et al., 2020; Song et al., 2021c), or DM, which approximately “inverts” it using a neural network that outputs $p_\theta(x_{t-1} \mid x_t) \approx q(x_{t-1} \mid x_t)$. We can sample from a diffusion model by first sampling $x_T \sim p(x_T) = \mathcal{N}(0, I)$ and then sampling $x_{t-1} \sim p_\theta(\cdot \mid x_t)$ for each $t = T, \ldots, 1$, eventually sampling $x_0$. In the conditional DM variant (Tashiro et al., 2021) the neural network is additionally conditioned on information $y$ so that the modelled distribution is

$$p_\theta(x_{0:T} \mid y) = p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t, y). \qquad (3)$$
The transitions $p_\theta(x_{t-1} \mid x_t, y)$ are typically approximated by a Gaussian with non-learned diagonal covariance, and so the learning problem is simply to fit the Gaussian's mean. Ho et al. (2020) parameterize this mean as an affine function of $\mathbb{E}[x_0 \mid x_t, y]$ and, by doing so, reduce the problem to fitting an estimator of $x_0$ from $x_t$ and $y$ with the loss

$$\mathcal{L}(\theta) = \sum_{t=1}^{T} \mathbb{E}_{q(x_0, x_t, y)}\big[\lambda(t) \cdot \|\hat{x}_\theta(x_t, y, t) - x_0\|_2^2\big]. \qquad (4)$$
Ho et al. (2020) and Song et al. (2021b) show that there exists a weighting function $\lambda(t)$ such that this loss is (the negative of) a lower bound on the marginal log-likelihood $\log p_\theta(x_0 \mid y)$. We instead use uniform weights $\lambda(t) = 1$, which has been shown to give better results in practice (Ho et al., 2020).
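A minimal sketch of this training objective, with uniform weights $\lambda(t) = 1$ and a generic denoising network standing in for $\hat{x}_\theta(x_t, y, t)$, is given below. The one-step sampling of $x_t$ from $x_0$ uses the closed-form marginal implied by Equation (1); that closed form is standard for Gaussian diffusions but is derived here rather than stated in the excerpt, and the function names are ours.

```python
import torch

def diffusion_loss(x0_hat_net, x0, y, betas):
    """Monte Carlo estimate of the loss in Equation (4) with lambda(t) = 1,
    using one uniformly sampled timestep per example."""
    T = betas.shape[0]
    alpha_bars = torch.cumprod(1.0 - betas, dim=0)        # \bar{alpha}_t for t = 1..T
    t = torch.randint(0, T, (x0.shape[0],))               # random timestep index per example
    a = alpha_bars[t].view(-1, *([1] * (x0.dim() - 1)))
    eps = torch.randn_like(x0)
    xt = a.sqrt() * x0 + (1.0 - a).sqrt() * eps           # sample from q(x_t | x_0)
    x0_hat = x0_hat_net(xt, y, t)                         # estimate of E[x_0 | x_t, y]
    return ((x0_hat - x0) ** 2).sum(dim=list(range(1, x0.dim()))).mean()
```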
2.2. Transformer Architecture
Figure 2 outlines our neural architecture for $\hat{x}_\theta$, which is run once for every diffusion time step $t$. Its inputs are $x_t$, $y$, and the diffusion timestep $t$. The (noisy) value of each latent graphical model node is represented in $x_t$ and, similarly, $y$ contains the value of each observed graphical model node. We use linear projections to embed all of these values, concatenating the embeddings of all latent and observed nodes to obtain, for an $n$-node graphical model and $d$-dimensional embedding space, an $n \times d$ array of embedding “tokens”. We add learned “node embeddings” to these tokens to identify which graphical model node each corresponds to, and also add learned observation embedding vectors for tokens corresponding to observed nodes. The resulting $n \times d$ array is passed through a stack of self-attention (Vaswani et al., 2017) and ResNet (He et al., 2016) blocks, as summarized in Figure 2, with the ResNet blocks taking an embedding of the diffusion timestep $t$ as an additional input. All of the timestep embedder, ResNet blocks, and self-attention modules are identical to those of Song et al. (2021a), except that we replace convolutions with per-token linear projections. The tokens corresponding to non-observed graphical model nodes are then fed through a learned linear projection to produce an output for each.
The self-attention layers are solely responsible for controlling interactions between embeddings, and therefore correlations between variables in the modelled distribution $p_\theta(x_0 \mid y)$. Inside the self-attention, each embedding is projected into a query vector, a key vector, and a value vector, all in $\mathbb{R}^d$. Stacking these values for all embeddings yields the matrices $Q, K, V \in \mathbb{R}^{n \times d}$ (given an $n$-node graphical model). The output of the self-attention is calculated as

$$e_{\text{out}} = e_{\text{in}} + W V = e_{\text{in}} + \mathrm{softmax}\big(Q K^\top\big) V \qquad (5)$$

where the addition of the self-attention layer's input $e_{\text{in}} \in \mathbb{R}^{n \times d}$ corresponds to a residual connection. We note that $Q K^\top$ yields a pairwise interaction matrix which lets us impose an additional attention mask $M$ before calculating the output $e_{\text{out}} = e_{\text{in}} + \mathrm{softmax}\big(M \odot Q K^\top\big) V$. This masking interface is central to structuring the flow of information between graphical model nodes in Section 3.1.
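The sketch below shows one common way to realise a binary attention mask $M$ in a dense implementation: disallowed entries of $QK^\top$ are set to $-\infty$ before the softmax so that they receive zero attention weight. This treatment of the mask and the absence of a $1/\sqrt{d}$ scaling follow Equation (5) as written; the paper's released code uses a sparse implementation instead (Section 3.1).

```python
import torch
import torch.nn.functional as F

def masked_self_attention(e_in, W_q, W_k, W_v, M):
    """e_in: (n, d) embeddings; M: (n, n) binary attention mask (1 = may attend)."""
    Q, K, V = e_in @ W_q, e_in @ W_k, e_in @ W_v           # each (n, d)
    scores = Q @ K.T                                       # pairwise interaction matrix QK^T
    scores = scores.masked_fill(M == 0, float("-inf"))     # impose the structure of M
    return e_in + F.softmax(scores, dim=-1) @ V            # residual connection, as in Eq. (5)

n, d = 39, 64
e = torch.randn(n, d)
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))
M = torch.ones(n, n)   # an all-ones mask recovers ordinary dense self-attention
out = masked_self_attention(e, Wq, Wk, Wv, M)
```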
2.3. Graphical Models
GSDM leverages problem structure described in the form of a graphical model. There is considerable flexibility in the specification of this graphical structure and we allow for both directed and undirected graphical models. A directed graphical model describes a joint distribution over $x = [x_1, \ldots, x_n]$ with the density $p(x) = \prod_{i=1}^{n} p_i(x_i \mid \mathrm{parents}(x_i))$. This may be natural to use if the problem of interest can be described by a causal model. This is the case in the BCMF example in Figure 1, where the forward model is a matrix multiplication and we can use the matrix multiplication's compute graph as a graphical model. If the data is simulated and source code is available then we can automatically extract the simulator's compute graph as a graphical model (Appendix L).
Alternatively, an undirected graphical model uses the density $p(x) \propto \prod_{j=1}^{m} f_j(\mathrm{vertices}(j))$ where $\mathrm{vertices}(j)$ are the vertices connected to factor $j$ and $f_j$ maps their values to a scalar. This is a natural formulation if the problem is defined by constraints on groups of nodes, e.g. for Sudoku with row, column and block constraints (Appendix C). Finally, the graphical model can combine directed and undirected components, using a density $p(x) \propto \prod_{i=1}^{n} p(x_i \mid \mathrm{parents}(x_i)) \prod_{j=1}^{m} f_j(\mathrm{vertices}(j))$. We use this in our graphical model for sorting (Appendix C), which combines a causal forward model with constraints.

We emphasise that GSDM does not need the link functions (i.e. the form of each $p_i$ and $f_j$) to be specified as long as data is available, which is desirable as they are often intractable or Dirac in practice. Also, while the selection of a graphical model for data can be subjective, we find in Section 4.3 that GSDM is not sensitive to small changes in the specification of the graphical model and that there can be multiple modeling perspectives yielding similar GSDM performance. In general, we use the most intuitive graphical model that we can come up with for each problem, whether it is directed, undirected, or a combination.
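As a concrete illustration of such a specification, the sketch below writes a hypothetical plain-Python container for nodes, directed edges, and undirected factors, and instantiates it with the BCMF compute graph from Figure 1. The class and field names are ours for illustration and are not the released code's API.

```python
from itertools import product

class GraphicalModel:
    """Hypothetical container for a graphical model 'sketch': nodes, edges, factors."""
    def __init__(self):
        self.nodes, self.edges, self.factors = [], [], []

    def add_node(self, name):
        self.nodes.append(name)

    def add_edge(self, parent, child):
        self.edges.append((parent, child))        # directed edge parent -> child

    def add_factor(self, members):
        self.factors.append(list(members))        # undirected factor over a group of nodes

def bcmf_graph(I=3, K=2, J=3):
    """Directed compute graph for E = A @ R with intermediate products C (cf. Figure 1)."""
    gm = GraphicalModel()
    for i, k in product(range(I), range(K)):
        gm.add_node(f"A_{i}{k}")
    for k, j in product(range(K), range(J)):
        gm.add_node(f"R_{k}{j}")
    for i, k, j in product(range(I), range(K), range(J)):
        gm.add_node(f"C_{i}{k}{j}")
        gm.add_edge(f"A_{i}{k}", f"C_{i}{k}{j}")          # C_ikj = A_ik * R_kj
        gm.add_edge(f"R_{k}{j}", f"C_{i}{k}{j}")
    for i, j in product(range(I), range(J)):
        gm.add_node(f"E_{i}{j}")
        for k in range(K):
            gm.add_edge(f"C_{i}{k}{j}", f"E_{i}{j}")      # E_ij = sum_k C_ikj
    return gm

gm = bcmf_graph()
assert len(gm.nodes) == 3 * 2 + 2 * 3 + 3 * 2 * 3 + 3 * 3  # 39 nodes, as in Figure 1
```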
2.4. Permutation Invariance
Large probabilistic models often contain permutation invariance, in the sense that the joint probability density $q(x_0)$ is invariant to permutations of certain indices (Bloem-Reddy & Teh, 2020). For example, the matrix multiplication in Figure 1 is invariant with respect to permutations of any of the plate indices.¹

¹ In general, plate notation implies permutation invariance as long as no link functions depend on the plate indices themselves.
Figure 2: An example GSDM architecture for a graphical model with one observed and three latent variables. The components within the dashed lines are repeated multiple times. Arrows represent information flow. For clarity, we leave out simple operations including linear transformations; see the appendix for full detail.
If the joint probability density is invariant to a particular permutation, this can be enforced in a distribution modelled by a DM by making the neural network architecture equivariant to the same permutation (Hoogeboom et al., 2022). We show how to encode such equivariances in GSDM in Section 3.2.
3. Method
The first stage in using GSDM is to define a graphical model as discussed previously. This section focuses on how to map from a graphical model to the corresponding GSDM architecture, an example of which is shown in Figure 2. The backbone of the architecture is a stack of transformer modules operating on a set of embeddings, with each embedding corresponding to one graphical model node. The colors in Figure 2 outline this section. Section 3.1 describes how we derive a sparse attention mechanism from a graphical model. Section 3.2 explains the node embeddings. Section 3.3 motivates our decision to model “intermediate” variables jointly with the variables of interest. Section 3.4 describes how we handle observed variables, and Section 3.5 describes how we handle discrete variables.
3.1. Faithful Structured Attention
Our architecture in Figure 2 runs a self-attention mechanism over a set of embeddings, each of which corresponds to a graphical model node. To add structural information, we use the graphical model's edges to construct attention masks $M$ for the self-attention layers. Precisely, we allow variable $i$ to attend to variable $j$ if there is an edge between node $i$ and node $j$, irrespective of the direction of the edge. If the graphical model contains factors, we additionally allow attention between any node pairs $(i, j)$ which connect to the same factor. These heuristics typically keep the graph sparse while, importantly, ensuring that given enough attention layers all output nodes can depend on all input nodes. We show in Appendix E that this is necessary to faithfully capture the dependencies in $q(x \mid y)$.
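A minimal sketch of this masking heuristic, reusing the hypothetical GraphicalModel container from Section 2.3: self-edges are kept, directed edges are used in both directions, and all nodes sharing a factor may attend to each other. The mask compiler in the released code may differ in its details.

```python
import torch

def build_attention_mask(nodes, edges, factors):
    """Binary mask M with M[i, j] = 1 iff node i may attend to node j."""
    idx = {name: i for i, name in enumerate(nodes)}
    n = len(nodes)
    M = torch.eye(n)                                   # every node attends to itself
    for a, b in edges:                                 # edges count in both directions
        M[idx[a], idx[b]] = M[idx[b], idx[a]] = 1
    for members in factors:                            # fully connect nodes sharing a factor
        for a in members:
            for b in members:
                M[idx[a], idx[b]] = 1
    return M

M = build_attention_mask(gm.nodes, gm.edges, gm.factors)   # gm from the BCMF sketch above
```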
To reduce our memory usage and computation compared to a dense matrix multiplication of the masked matrix, we provide an efficient sparse attention implementation as described in Appendix C and the released code. Its computational cost is $O(nm)$, where $n$ is the number of dimensions and $m$ is the number of entries in the densest row of $M$. We show in Appendix H that, even after accounting for the cost of modeling additional intermediate variables as described later, the sparse attention mechanism gives GSDM a reduction in computational complexity of $O(n)$ relative to a naive baseline in three of our four experiments.
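One way such $O(nm)$ behaviour can be realised is a gather-based attention in which each query only scores the keys of its allowed neighbours. The sketch below assumes, for simplicity, that every row of $M$ has exactly $m$ ones (neighbors[i] lists the column indices where row $i$ of $M$ is 1); the actual implementation in Appendix C and the released code may differ.

```python
import torch
import torch.nn.functional as F

def sparse_masked_attention(Q, K, V, neighbors):
    """Attention where query i attends only to the indices in neighbors[i].

    Q, K, V:   (n, d) matrices.
    neighbors: (n, m) integer tensor of allowed indices per query.
    Cost is O(n * m) rather than O(n^2)."""
    K_sel = K[neighbors]                                  # (n, m, d): keys allowed per query
    V_sel = V[neighbors]                                  # (n, m, d)
    scores = torch.einsum("nd,nmd->nm", Q, K_sel)         # per-query scores over m neighbours
    return torch.einsum("nm,nmd->nd", F.softmax(scores, dim=-1), V_sel)
```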
3.2. Node Embeddings
GSDM's architecture contains positional embeddings to let the neural network distinguish which inputs correspond to which graphical model nodes. The simplest variation of GSDM learns one embedding per graphical model node independently, and we call this approach independent embeddings, or IE. An issue with IE is that it cannot generally be adapted to changing problem dimension. A generic solution to this involves noting that graphical model nodes can often be grouped together into “arrays”. For instance, the BCMF example in Figure 1 contains 39 nodes but these belong to just 4 multi-dimensional arrays: $A$, $R$, $C$, and $E$. We suggest array embeddings, or AE, which can be automatically constructed for such problems with (potentially variable-size) ordered datatypes. With AE, we compute the embedding for each node as the sum of a shared array embedding, learned independently for every array, and a sinusoidal positional embedding (Vaswani et al., 2017) for its position within an array. Scalars can be treated as arrays of size 1. AEs work well in our experiments and are a sensible default.
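A sketch of the AE construction, assuming the standard sinusoidal embedding of Vaswani et al. (2017); the exact dimensionality, initialisation, and normalisation used in the released code may differ, and the class name is ours.

```python
import math
import torch
import torch.nn as nn

def sinusoidal_embedding(position, d):
    """Standard sinusoidal positional embedding (Vaswani et al., 2017)."""
    half = d // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
    angles = position * freqs
    return torch.cat([torch.sin(angles), torch.cos(angles)])

class ArrayEmbeddings(nn.Module):
    """One learned embedding per array, shared by all of its nodes, plus a
    sinusoidal embedding of each node's position within its array."""
    def __init__(self, array_names, d):
        super().__init__()
        self.d = d
        self.array_emb = nn.ParameterDict(
            {name: nn.Parameter(torch.randn(d)) for name in array_names})

    def forward(self, array_name, position):
        return self.array_emb[array_name] + sinusoidal_embedding(position, self.d)

# For BCMF the 39 nodes require only 4 learned vectors plus per-array positions:
ae = ArrayEmbeddings(["A", "R", "C", "E"], d=64)
emb = ae("C", position=7.0)   # embedding for the node at flat index 7 within array C
```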
For graphical models exhibiting permutation invariances we can optionally enforce these invariances exactly using exchangeable embeddings, or EE. We do so according to the following result.
Theorem 3.1 (Permutation invariance in GSDM). Let $A$ represent the indices of a subset of the dimensions of data $x$ and $\Pi_A$ be the class of permutations that permute only dimensions indexed by $A$. Assume we have a GSDM parameterised with neural network $\hat{x}_\theta(\cdot\,; M)$, where $M$ is the structured attention mask. If the node embeddings used by $\hat{x}_\theta$ are shared across all nodes indexed by $A$, then the distribution modelled by GSDM will be invariant to all permutations $\pi$ satisfying

$$M = \pi M \quad \text{and} \quad \pi \in \Pi_A \qquad (6)$$

where $\pi M$ is a permutation of both the rows and columns of $M$ by $\pi$.
Proof. See Appendix F.
One implication of the $M = \pi M$ condition is that it holds trivially for a DM without sparse attention, in which $M$ is a matrix of all ones. The modeled distribution would therefore be invariant to any permutation of $A$ (Hoogeboom et al., 2022). This may be a useful permutation invariance to encode for some problems but, for the structured problems considered in this paper, it is too simple and not valid. In none of our experiments are there two variables whose values can be swapped without changing the density under the data distribution. For example, in BCMF the data density is invariant to reordering any of the plate indices, but not to swapping any single pair of nodes in them.
When $M$ is a structured matrix as we propose for GSDM, Theorem 3.1 suggests a way to incorporate invariances that are closely tied to the problem structure. On the BCMF problem shown in Figure 1 we use only four embeddings, sharing a single embedding between all nodes in $A$; another between all nodes in $R$; and so on for $C$ and $E$. Without imposing an attention mask, this would make the network invariant to any permutation of the variables within each of $A$, $R$, $C$ and $E$. With GSDM's attention mask, it is only invariant to permutations which lead to the same mask. This means that the learned distribution is invariant only to the ordering of the indices $i$, $j$, and $k$. As represented by the plate notation in Figure 1, this is a desired invariance that matches the data distribution.
In general, Theorem 3.1 suggests a simple heuristic for checking when problem invariances can be enforced. If permuting the ordering of an index does not affect the sparsity mask for a given problem, then sharing embeddings across instances of this index will enforce a permutation invariance with respect to this index in GSDM. We utilise this property in three of our four experiments. Along with our compiler, which generates an attention mask for a programmatically-defined graphical model, checking when Theorem 3.1 holds for a given class of permutations can reasonably be automated. We do, however, still require human knowledge to propose invariances suitable for the graphical model.
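This check can be sketched as follows: given the attention mask $M$ and a candidate permutation $\pi$ of the node ordering (e.g. one induced by swapping two values of a plate index), we verify that permuting both rows and columns of $M$ leaves it unchanged. Constructing the candidate permutation from a plate index is problem specific and omitted here; the function name is ours.

```python
import torch

def mask_is_invariant(M, pi):
    """Check the M = pi.M condition of Theorem 3.1 for a candidate permutation.

    M:  (n, n) binary attention mask.
    pi: length-n integer tensor with pi[i] = index of the node moved into slot i."""
    permuted = M[pi][:, pi]          # permute both rows and columns of M by pi
    return bool(torch.equal(M, permuted))

# Example: in BCMF, swapping two values of plate index i induces a permutation of
# the node ordering (it moves the corresponding A, C and E nodes together); if the
# mask is unchanged, sharing embeddings across that index enforces the invariance.
```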
3.3. Intermediate Variables
When translating a generative model into a graphical model, any observed variables or latent variables of particular interest to the modeler should clearly be included as nodes. There will be other latent variables, which we call “intermediate variables”, which are not directly of interest but may be included to make the graphical model more interpretable, more sparse, or otherwise preferable.
Figure 3: Time to fit the Boolean circuit with (solid lines) and without (dashed) intermediate variables. Accuracy is computed on 16 validation examples every 500 iterations, to a maximum of 20 000.
Whether or not these are included has implications for GSDM, as it will either be trained to model them jointly with the variables of interest if they are included, or trained without this signal if they are not. There is no model-agnostic “right” answer to whether or not intermediate variables will be helpful, but we point out that including them is often beneficial for GSDMs because of (1) GSDM's empirical success at utilising the learning signal from these intermediate variables as described below, and (2) the reduced computational cost of GSDM's sparse attention, which is related to the number of graphical model edges more so than the number of nodes, and so is not necessarily increased by adding intermediate variables.
As an illustrative example, consider a Boolean logic circuit which takes an input of size $2^n$. The input is split into pairs and each pair is mapped through a logic gate to give an output of size $2^{n-1}$. After $n$ layers and a total of $2^n - 1$ logic gates, there is a single output. Suppose that you know that each gate is randomly assigned to be either an OR gate or an AND gate, and you wish to infer which from data. If the data contains only the inputs and the single output, it contains only 1 bit of information. Identifying the function computed by each of the $O(2^n)$ gates will therefore require at least $O(2^n)$ data points. On the other hand, if the data contains intermediate variables in the form of the output of every logic gate, each data point contains $O(2^n)$ bits of information, so the task may be solvable with only a few data points. Figure 3 shows that this reasoning holds up when we train a DM on this example. Without intermediate variables, the number of training iterations needed scales exponentially with $n$. With the combination of intermediate variables and structured attention, however, the training behaviour is fundamentally changed to scale more gracefully with $n$.
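A sketch of how training data for this example might be generated, including the intermediate gate outputs; the uniform input distribution and the dictionary layout are assumptions for illustration, not the paper's exact data pipeline.

```python
import random

def sample_circuit_example(n):
    """One data point for the depth-n Boolean circuit: random gate types (the
    latent variables to infer), the input bits, and every intermediate gate output."""
    inputs = [random.randint(0, 1) for _ in range(2 ** n)]
    gates, intermediates = [], []
    layer = inputs
    for _ in range(n):                                    # n layers of gates
        next_layer = []
        for a, b in zip(layer[0::2], layer[1::2]):        # pair up the current layer
            gate = random.choice(["AND", "OR"])           # each gate is AND or OR
            gates.append(gate)
            next_layer.append(a & b if gate == "AND" else a | b)
        intermediates.extend(next_layer)                  # expose gate outputs as data
        layer = next_layer
    return {"inputs": inputs, "gates": gates,
            "intermediates": intermediates, "output": layer[0]}

example = sample_circuit_example(n=4)
assert len(example["gates"]) == 2 ** 4 - 1                # 2^n - 1 gates in total
```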
3.4. Flexible Conditioning
Optimizing the DM loss in Equation (4) requires a partitioning of data into latent variables (outputs) $x_0$, and observed variables (inputs) $y$. Our neural network distinguishes between variables in $x_t$ and $y$ via a learned observation embedding (Section 2.2).