Arithmetic Sampling: Parallel Diverse Decoding for Large Language Models
Luke Vilnis *1  Yury Zemlyanskiy *1  Patrick Murray 1  Alexandre Passos 1  Sumit Sanghai 1
Abstract
Decoding methods for large language models
often trade off between diversity of outputs
and parallelism of computation. Methods such
as beam search and Gumbel top-k sampling
can guarantee a different output for each
element of the beam, but are not easy to
parallelize. Alternatively, methods such as
temperature sampling and its modifications (top-k
sampling, nucleus sampling, typical decoding,
and others), are embarrassingly parallel, but have
no guarantees about duplicate samples. We
present a framework for sampling according
to an arithmetic code book implicitly defined
by a large language model, compatible with
common sampling variations, with provable beam
diversity under certain conditions, as well as being
embarrassingly parallel and providing unbiased
and consistent expectations from the original
model. We demonstrate the effectiveness of
our approach on WMT machine translation,
more than halving the standard deviation when
estimating expected BLEU score reward, and
closing the BLEU score gap between independent
sampling and beam search by up to 63%.
1. Introduction
Large language models (LLMs) based on transformers
are crucial to modern natural language processing. The
ability of LLMs to capture knowledge from massive
pretraining datasets is useful for applications such as
machine translation and predictive text (Raffel et al.,2020;
Brown et al.,2020;Radford et al.,2019) as well as for
automated speech recognition (Martinez et al.,2021) and
image captioning (Devlin et al.,2015). However, because
of the powerful nonlinear dependencies in the architecture,
*Equal contribution. 1Work done while all authors were at Google Research. Correspondence to: Luke Vilnis <lvilnis@google.com>, Yury Zemlyanskiy <urikz@google.com>.
Proceedings of the 40th International Conference on Machine Learning, Honolulu, Hawaii, USA. PMLR 202, 2023. Copyright 2023 by the author(s).
Figure 1: Sequence model over sequences of length two
and a vocabulary of three symbols mapping points in the
unit interval to each sequence. An even lattice of code
points parallelizes decoding into diverse high-probability
sequences.
options for inference are limited.
While LLM inference can be performed exactly for the case
of drawing independent samples, practical systems often
use inexact search—often modifications to beam search—to
guarantee high-quality and diverse (either in n-gram overlap
or semantic difference) samples. Search-based approaches
including beam search, stochastic beam search (Gumbel
top-k) (Kool et al.,2019), determinantal beam search
(Meister et al.,2021b), and others, can produce diverse
samples by construction, at the cost of being difficult to
efficiently parallelize, as they must examine the entire set of
partial predictions (known as the beam) at each time step.
Intuitively, there seems to be a trade-off between
parallelizability of a sampling algorithm and ability
to guarantee non-duplicate samples. Methods based
on ancestral sampling parallelize very well as they
turn a random number generation seed into a sample
independently. However, high probability sequences are
often generated multiple times. Conversely, search-style
methods inherently avoid generating duplicate samples but
when parallelized require synchronizing across replicas to
sort the candidates at each step.
In response to this, we introduce arithmetic sampling,
a technique for sampling from large language models
that produces a set of non-independent samples from the
model, based on a coding scheme implicitly defined by the
model. With this coding scheme, distant codes in code
space represent different sequences. Further, decoding
each code can be done independently from all other codes.
Arithmetic sampling boasts provable beam diversity under
certain conditions, produces unbiased expectations from the
original model, and is embarrassingly parallel and as simple
to implement as normal sampling.
In addition to analyzing bias and consistency of estimators
based on arithmetic sampling, we present results on the
metric properties of the codebook space and the conditions
under which it improves sample diversity, as well as an
analysis of the estimator variance.
Comparing against equivalent hyperparameters for standard
sampling, we show improvements of nearly 1 point of BLEU
in oracle experiments for En/Fr WMT translation, closing
the gap between independent sampling and beam search
by up to 63%, reducing estimation variance by more than
half and improving beam diversity. We see comparable
improvements in En/Ro translation and variance reduction
for ROUGE score on CNN/DailyMail summarization. We
release an open-source implementation of our algorithm in the popular T5X transformer library (Roberts et al., 2022); the code is available at https://github.com/google-research/google-research/tree/master/arithmetic_sampling.
2. Related Work
This paper draws on three main threads of related work: coding theory, diverse sampling, and quasi-Monte Carlo.
The use of latent “codes” to represent data has a long
history in the neural network and representation learning
literature, from autoencoders (LeCun,1987) to sparse
coding algorithms like K-SVD (Aharon et al.,2006). Rather
than using a high dimensional code learned from data using
an iterative algorithm or backpropagation, we design a
simple one dimensional arithmetic code (Cover & Thomas,
2006) post-hoc from a trained language model.
Diverse sampling inference techniques for large language
models fall into two categories: techniques for producing
a diverse beam (sample of sequences), and techniques
for discouraging overlap (n-gram repetition) within a
single sequence. The former encompasses methods like
determinantal beam search (Meister et al.,2021b), parallel
approximate decoding (Cho,2016), stochastic beam search
(Kool et al.,2019), and conditional poisson stochastic
beam search (Meister et al.,2021a). Our method differs
from (determinantal) beam search or parallel approximate
decoding in that it is designed to faithfully sample from
the underlying probability model, so that sample means can
be used to compute unbiased expectations. Unlike beam
search or sampling-without-replacement based variants, our
algorithm is embarrassingly parallel.
Methods such as temperature sampling, top-k sampling (Fan
et al.,2018), nucleus sampling (Holtzman et al.,2019),
Mirostat (Basu et al.,2020), and typical decoding (Meister
et al., 2022) are useful for increasing diversity over standard beam search, both across the beam and within a
single long generation. These methods, and any others
that modify conditional logits, are fully compatible with
and complementary to our algorithm, and we provide
experiments with temperature sampling and top-k sampling
demonstrating improvements.
There is also work on changing the training objective using
unlikelihood (Welleck et al.,2019) or reinforcement learning
(Lagutin et al.,2021) so that standard generation schemes
produce more diverse outputs, which is also orthogonal to
our methods.
The final thread of related work is quasi-Monte Carlo
integration. (Randomized) Quasi-Monte Carlo techniques
(l’Ecuyer,2016) combine the adaptive anytime properties of
Monte Carlo estimation with the reduced variance of lattice
integration methods (L’Ecuyer & Lemieux,2000), and have
been used in machine learning applications such as lowering
the variance of randomized kernel approximations (Yang
et al.,2014) and neural latent variable models (Buchholz
et al.,2018). Quasi-Monte Carlo has not been applied to
the standard neural autoregressive discrete distributions we
describe here, to our knowledge.
3. Background
3.1. Arithmetic Coding
An arithmetic code is an optimal lossless coding
scheme—that is, a coding scheme with minimal expected
code length—for when the exact joint distribution of the
data is known. Given a total ordering of the items being
encoded and defining $w_i = \sum_{j<i} P(X = x_j)$, the cumulative probability of item $x_i$, an arithmetic code for $x_i$ is a number in the interval $(w_i, w_{i+1})$. To represent this code as a sequence of bits it's usual to pick a rational number in the interval $(w_i, w_{i+1})$ whose binary fraction representation requires a small number of digits. Larger intervals then tend to contain more numbers with short
representations. Decoding an arithmetic code $c$ requires finding the unique value of $i$ such that $w_i < c < w_{i+1}$. In this way, codewords are assigned a length which is inversely (logarithmically) proportional to the probability of an outcome, providing an optimal compression rate for the average message, roughly equivalent to the entropy of the distribution.
Definition 1 (Arithmetic codebook). By a slight abuse of terminology, we will use the term arithmetic codebook to refer not only to the map from symbols to binary fractional representations, but also to the map from symbols to subintervals of the unit interval, $f: V \to 2^{[0,1]}$.
An example of an arithmetic codebook is the one underlying the most common method for sampling from categorical distributions in practice. First one constructs a codebook assigning each symbol to a subinterval of the unit interval, and then samples uniformly at random from the unit interval and inverts the map to find the symbol sampled.
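As a concrete illustration (our own sketch, not code from the paper), the following Python snippet builds an explicit codebook for a small made-up categorical distribution and shows that decoding uniform random code points is exactly inverse-CDF sampling; the vocabulary and probabilities are arbitrary.

```python
import bisect
import random

# Toy categorical distribution over an ordered vocabulary (made-up numbers).
vocab = ["a", "b", "c"]
probs = [0.5, 0.3, 0.2]

# Codebook: symbol i owns the subinterval [w_i, w_{i+1}) of the unit interval,
# where w_i is the cumulative probability of the symbols ordered before it.
cum = [0.0]
for p in probs:
    cum.append(cum[-1] + p)
# cum == [0.0, 0.5, 0.8, 1.0]; intervals: a -> [0, 0.5), b -> [0.5, 0.8), c -> [0.8, 1.0)

def decode(code):
    """Invert the codebook: return the symbol whose interval contains `code`."""
    i = bisect.bisect_right(cum, code) - 1
    return vocab[min(i, len(vocab) - 1)]

# Decoding uniform random code points recovers the original distribution
# (this is ordinary inverse-CDF sampling).
samples = [decode(random.random()) for _ in range(10_000)]
print({v: samples.count(v) / len(samples) for v in vocab})
```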
3.2. Randomized Quasi-Monte Carlo
A common problem when working with probability
distributions is to compute the expectation of functions
under the distribution. The family of Monte Carlo
algorithms is commonly used for this purpose. A simple
Monte Carlo algorithm for estimating the expectation of a function $s$ under a probability distribution is to first obtain $n$ i.i.d. samples $x_i$ from the distribution and then approximate $E[s(X)] \approx \frac{1}{n}\sum_i s(x_i)$
. Without loss of generality, these
methods are often formulated in terms of evaluating an
expectation of a function defined on the unit hypercube.
$$ E[s(X)] = \int_{x \in [0,1]^d} s(x)\,dx \approx \frac{1}{N}\sum_{i=1}^{N} s(u_i), \qquad u_i \sim \mathrm{Uniform}([0,1]^d) \quad (1) $$
Because many integrals can be interpreted as expectations
of functions of a uniform distribution on the unit hypercube,
Monte Carlo algorithms have been fairly useful for
numerical integration as alternatives to quadrature methods
which approximate functions using grids and other regular
structures, often providing higher expected accuracy when
controlling for computational cost.
When dependent random variables are used in Monte Carlo
estimation of probabilistic quantities, it is commonly called
Randomized Quasi-Monte Carlo (RQMC). These methods
include the use of low-discrepancy sequences, lattice
rules, antithetic sampling, and stratification, among others
(l’Ecuyer,2016). In this work, we are most concerned with
lattice-based RQMC (L’Ecuyer & Lemieux,2000) methods,
which use perturbed lattices. A simple lattice-based RQMC
rule replaces the uniform sampling in Equation 1 with a regular lattice of points that has been randomly shifted by a single uniform random vector:

$$ E[s(X)] \approx \frac{1}{N}\sum_{i=1}^{N} s(l_i + u), \qquad u \sim \mathrm{Uniform}([0,1]^d). \quad (2) $$
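The sketch below (not from the paper; the integrand and constants are arbitrary) contrasts plain Monte Carlo with the randomly shifted lattice rule of Equation 2 on a one-dimensional integral, illustrating the variance reduction that motivates lattice-based RQMC.

```python
import math
import random
from statistics import mean, pvariance

def s(x):
    # Arbitrary test integrand on [0, 1]; its exact integral is 1 - cos(1).
    return math.sin(x)

N = 64

def mc_estimate():
    """Plain Monte Carlo: N i.i.d. Uniform([0, 1]) points."""
    return sum(s(random.random()) for _ in range(N)) / N

def shifted_lattice_estimate():
    """Lattice-based RQMC as in Equation 2: the regular lattice l_i = i/N shifted
    by a single uniform draw u and wrapped back into the unit interval."""
    u = random.random()
    return sum(s((i / N + u) % 1.0) for i in range(N)) / N

# Both estimators are unbiased; the shifted lattice typically has much lower variance.
mc = [mc_estimate() for _ in range(1_000)]
ql = [shifted_lattice_estimate() for _ in range(1_000)]
print("exact value:    ", 1 - math.cos(1))
print("MC   mean / var:", mean(mc), pvariance(mc))
print("RQMC mean / var:", mean(ql), pvariance(ql))
```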
3.3. Ancestral Sampling
While Section 3.1 makes it clear that, given a codebook for a discrete distribution, sampling from that distribution can be done by simply generating a uniform random number, in practice explicitly constructing such a codebook is often not feasible: in probabilistic sequence models, for example, the domain includes all sequences of symbols up to some large length. Instead, one models the joint probability of a sequence of tokens as the product of the conditional probabilities of each token given all of the preceding tokens,

$$ P(X_1, \ldots, X_T) = \prod_{t=0}^{T-1} P(X_{t+1} \mid X_1, \ldots, X_t). \quad (3) $$
Each of these conditional probability functions can then be
modeled using a neural network. Analogously, sampling in
large language models is done through ancestral sampling,
wherein each token is sampled successively from the
conditional probability after conditioning on all previous
tokens,
$$ x_T \sim P(X_T \mid X_1, \ldots, X_{T-1}). \quad (4) $$
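For concreteness, here is a minimal Python sketch of ancestral sampling as in Equation 4. The function conditional_probs is a hypothetical stand-in for a learned model's conditional distribution, not anything from the paper.

```python
import random

VOCAB = ["a", "b", "EOS"]

def conditional_probs(prefix):
    """Hypothetical stand-in for a learned model P(X_t | X_1..X_{t-1}): a toy
    distribution over VOCAB that makes EOS more likely as the prefix grows."""
    p_eos = min(0.9, 0.2 + 0.2 * len(prefix))
    rest = (1.0 - p_eos) / 2.0
    return [rest, rest, p_eos]

def ancestral_sample(max_len=10):
    """Sample a sequence token by token from the conditionals, as in Equation 4."""
    prefix = []
    for _ in range(max_len):
        probs = conditional_probs(prefix)
        token = random.choices(VOCAB, weights=probs, k=1)[0]
        prefix.append(token)
        if token == "EOS":
            break
    return prefix

print(ancestral_sample())  # e.g. ['b', 'a', 'EOS']
```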
Definition 2 (Prefix of a sequence). For a sequence of symbols $x_1, \ldots, x_T$, we call a contiguous subsequence $x_1, \ldots, x_t$ for $t < T$ a prefix of the sequence.
When working with these probabilistic sequence models, it
is natural to think in terms of prefixes. In fact, implicit
in our definition of the ancestral sampling scheme and
product-of-conditionals architecture is that it allows us to
compute probabilities not only over complete sequences,
but also over partial prefixes, i.e. the probability $P(X_1 = x_1, \ldots, X_T = x_T)$ is the sum of probabilities of every sequence longer than $T$ sharing that prefix. There are
two ways that practical neural sequence models distinguish
between a prefix and a complete sequence. The first,
common in decoder-only models, is to decode every
sequence to some maximum length, and define the
distribution as applying only to sequences of that length.
The second is to include a special EOS (end-of-sentence) token, and define a sequence as complete if it ends with EOS. So the prefix $(x_1, \ldots, x_T)$ is distinguished from the sequence $(x_1, \ldots, x_T, \mathrm{EOS})$. Padding tokens are added to the end of the sequence after EOS to make every sequence of a uniform length.
In our work we exploit this prefix structure in order to
construct an alternative to ancestral sampling for these
neural sequence models.
4. Method
The core idea of our method is to improve the diversity of
samples by (1) reinterpreting an ancestral sampling scheme
as defining an arithmetic codebook, where distance in code
space correlates (in a sense to be made precise later) with
prefix distance in sentence space, and (2) using non-IID
random numbers to sample from the codebook. This allows
us to guarantee that the codewords are “far apart” in code
space, while preserving unbiasedness of our estimation.
4.1. Constructing the codebook
The algorithm has a geometric flavor and will be easiest to
follow while making reference to the toy example in Figure
1. As noted in Section 3.3, in real-world sequence models over a vocabulary $V$, it is impractical to explicitly construct an arithmetic codebook (a mapping from sequences to disjoint subintervals of $[0,1]$). What we will demonstrate here is that it is possible to implicitly define an arithmetic codebook for a given sequence model such that we can (1) given a sequence or prefix, compute the corresponding interval in the codebook, and (2) given a point (a "code") in the unit interval, compute the corresponding sequence. Further, these computations can be done with complexity no greater than that of normal likelihood computation or ancestral sampling.
Without loss of generality, we can assume the sequences all have uniform length $L$ as described in Section 3.3. Given an ordered vocabulary $V$, we use the standard dictionary ordering on $V^L$. Given two sequences $(a_1, a_2, \ldots, a_L)$ and $(b_1, b_2, \ldots, b_L)$, their ordering depends on the order between the two symbols in the first place $i$ on which the two sequences differ.
Because our dictionary ordering puts all sequences sharing
a given prefix into contiguous blocks, we can define the
codebook in terms of prefixes and only lazily materialize
the codes for longer prefixes as we need them.
Concretely, we compute the CDF of the first token in the sequence, $w_{i_1} = \sum_{j < i_1} P(X_1 = v_j)$, and assign to each choice of prefix $X_1 = v_{i_1}$ the subinterval $(w_{i_1}, w_{i_1+1})$. All codes for sequences starting with $v_{i_1}$ will lie in this interval. We recursively define the codebook for prefixes of length two, computing

$$ w_{i_1 i_2} = w_{i_1} + P(X_1 = v_{i_1}) \sum_{j < i_2} P(X_2 = v_j \mid X_1 = v_{i_1}). $$

This gives the subinterval corresponding to sequences that start with $v_{i_1} v_{i_2}$. We extrapolate the following formula:

$$ w_{i_1 \ldots i_L} = w_{i_1 \ldots i_{L-1}} + \sum_{j < i_L} P(X_1 = v_{i_1}, \ldots, X_{L-1} = v_{i_{L-1}}, X_L = v_j). \quad (5) $$
For a given sequence or prefix $i_1 \ldots i_L$, we assign it to the subinterval $(w_{i_1 \ldots i_L}, w_{i_1 \ldots i_L+1})$. By inspecting this equation we can see several things:

• The intervals defined by the $w_i$'s are a valid codebook for the space of sequences. The subintervals corresponding to any given prefix are disjoint from those that do not share that prefix, so every sequence ends up in a disjoint interval, and the length of each interval is exactly the probability of the sequence.

• The computation of a code for a given sequence has the same FLOPS as evaluating the probability of that sequence under the model using Equation 3. The only probabilities involved in computing the $w_i$'s are the conditional probabilities used to calculate a single prefix, and they can be computed step by step.

• Given a code point in the unit interval, discovering its subinterval requires the same FLOPS as standard ancestral sampling using Equation 4. This is described in Algorithm 1 and follows the same recursive construction as used to define the codebook in Equation 5.
An immediate corollary of the third point is

Proposition 1. If the code point $c$ is chosen uniformly at random from the unit interval, Algorithm 1 samples from the distribution $P(X) = P(X_1, \ldots, X_T)$.

Remark. Conditioned on a set of codes $c_1, \ldots, c_N$, sampling using Algorithm 1 is embarrassingly parallel across sequences. Because it uses the same FLOPS and follows the same structure, we find that this has no appreciable computational overhead in practice compared to standard sampling. We further discuss practicalities of parallel LLM inference in Section 5.3 and Appendix A.
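To make the recursion concrete, here is a toy Python sketch on the Figure 1 setting (our own illustration with made-up probabilities, not the paper's reference implementation): sequence_interval builds a sequence's codebook interval following Equation 5, and decode maps a code point back to a sequence in the spirit of Algorithm 1.

```python
# Toy length-2 sequence model over an ordered three-symbol vocabulary, as in
# Figure 1. All probabilities below are made up for illustration only.
VOCAB = ["a", "b", "c"]
P_FIRST = {"a": 0.5, "b": 0.3, "c": 0.2}   # P(X1)
P_SECOND = {                               # P(X2 | X1)
    "a": {"a": 0.6, "b": 0.3, "c": 0.1},
    "b": {"a": 0.2, "b": 0.5, "c": 0.3},
    "c": {"a": 0.3, "b": 0.3, "c": 0.4},
}

def conditionals(prefix):
    """Stand-in for the model's conditional distribution given a prefix."""
    return P_FIRST if not prefix else P_SECOND[prefix[-1]]

def sequence_interval(sequence):
    """Codebook interval for `sequence`, built token by token as in Equation 5."""
    lo, width = 0.0, 1.0
    prefix = []
    for token in sequence:
        probs = conditionals(prefix)
        for v in VOCAB:               # shift past symbols ordered before `token`
            if v == token:
                break
            lo += width * probs[v]
        width *= probs[token]         # interval shrinks by P(token | prefix)
        prefix.append(token)
    return lo, lo + width             # interval length equals P(sequence)

def decode(code, length=2):
    """Algorithm-1-style decoding: map a code point in [0, 1) to a sequence."""
    sequence = []
    lo, width = 0.0, 1.0
    for _ in range(length):
        probs = conditionals(sequence)
        for v in VOCAB:               # find the child subinterval containing `code`
            hi = lo + width * probs[v]
            if code < hi:
                sequence.append(v)
                width *= probs[v]
                break
            lo = hi
    return sequence

print(sequence_interval(["b", "a"]))  # (0.5, 0.56): starts after P("a")=0.5, length 0.3*0.2
print(decode(0.53))                   # ['b', 'a'], since 0.5 <= 0.53 < 0.56
```

Note that decode consumes one conditional distribution per step, mirroring the cost of ancestral sampling, and that distinct codes falling in disjoint intervals necessarily decode to distinct sequences.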
4.2. Sampling consistency and bias
So far we have only reproduced the standard algorithm
for ancestral sampling using an additional latent
uniform variable (which is similar to most practical
implementations). This latent variable however lets us
introduce some structure in how we pick our codes.
A naive codebook of maximal diversity can be obtained by dividing the unit interval into a regular lattice. That is, for $N$ codes, we pick $c_i$ to be the $i$'th quantile, $i/N$. Since this is deterministic, Proposition 1 does not apply, but this gives a consistent estimator.
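As a sketch of how such a lattice of codes might be laid out (our own illustration; the shifted variant is an assumption modeled on the shifted lattice of Equation 2, not necessarily the paper's randomization):

```python
import random

def quantile_codes(n):
    """Deterministic lattice described in the text: c_i = i / n for i = 0..n-1."""
    return [i / n for i in range(n)]

def shifted_quantile_codes(n):
    """One possible randomization (an assumption here, analogous to Equation 2):
    add a single shared Uniform([0, 1)) shift and wrap, so each code is
    marginally uniform while the codes stay evenly spread."""
    u = random.random()
    return [(i / n + u) % 1.0 for i in range(n)]

print(quantile_codes(4))          # [0.0, 0.25, 0.5, 0.75]
print(shifted_quantile_codes(4))  # e.g. [0.37, 0.62, 0.87, 0.12]
```

Each code would then be decoded independently, which is what makes the scheme embarrassingly parallel across the beam.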