Under this formulation, only the decoder parameters $\theta_{\mathrm{Dec}}$ are required at inference time. The retriever retrieves the indices of the relevant contexts and looks up their encodings. The context encodings are processed ahead of time and are completely decoupled from online operations. This is in contrast to previous works such as Realm [15], Rag [26], or Retro [4], where the interaction between the input $x$ and the context $c$ is bi-directional, which necessitates context encoding at inference time. In our model, information flows uni-directionally from $c$ to $x$ and $y$, and the encoding of each context $c$ is computed independently. On one hand, this is more restrictive than bi-directional interaction; on the other hand, such a design ensures complete decoupling of context processing from online language model inference. The exact mechanism is detailed in Figure 1.
Conceptually, the retriever can be an arbitrary black box. In practice, we use a dual encoder formulation [8, 16], which first represents $x$ as a query embedding $\mathrm{Emb}_Q(x)$ and performs vector similarity search over a database $D$ to find the indices of the documents whose document embeddings have the highest inner products with the query embedding. We then look up the context encoder outputs that correspond to the retrieved indices and return them as the retriever output.
$$l^* = \arg\max_{v \in D} \langle \mathrm{Emb}_Q(x), v \rangle; \qquad D = \{\mathrm{Emb}_D(c);\ c \in \mathcal{C}\}$$
$$H_{\mathrm{Enc}}[l^*] = \mathrm{Enc}(c_{l^*}; \theta_{\mathrm{Enc}})$$
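To make the lookup concrete, the following is a minimal NumPy sketch of this retrieval step under the assumption of a brute-force inner-product search; the names `retrieve`, `doc_embs`, and `enc_outputs` are illustrative and not from the paper, and a real system would use an approximate nearest neighbor index instead.

```python
import numpy as np

def retrieve(query_emb, doc_embs, enc_outputs, k=1):
    """Return the pre-computed encoder outputs of the top-k contexts.

    query_emb:   (d,)      Emb_Q(x), the query embedding of the input x.
    doc_embs:    (|C|, d)  D, one row Emb_D(c) per context in the database.
    enc_outputs: list of (L_c, h) arrays, H_Enc pre-computed offline per context.
    """
    scores = doc_embs @ query_emb            # inner products <Emb_Q(x), v>
    top_indices = np.argsort(-scores)[:k]    # k-arg max (arg max when k == 1)
    # Only a lookup happens online; the context encoder is never run here.
    return [enc_outputs[l] for l in top_indices]
```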
In the case of multiple supporting contexts, $k\text{-}\arg\max$ is used instead of $\arg\max$. The encoder outputs of the supporting contexts are then concatenated:
$$P(y_i \mid y_{<i}, x_1, x_2, \ldots, x_n, \mathrm{Concat}(H_{\mathrm{Enc}}[l_1], H_{\mathrm{Enc}}[l_2], \cdots, H_{\mathrm{Enc}}[l_k]); \theta_{\mathrm{Dec}}),$$
where $\mathrm{Concat}$ is simply vector concatenation:
$$\mathrm{Concat}([h_1, h_2, \cdots, h_n], [g_1, g_2, \ldots, g_m], \ldots) = [h_1, h_2, \ldots, h_n, g_1, g_2, \ldots, g_m, \ldots]$$
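As a small illustration, the retrieved encodings can be concatenated along the sequence axis to form the cross-attention memory for the decoder; `build_cross_attention_memory` is a hypothetical helper, and the decoder call in the comment is only schematic.

```python
import numpy as np

def build_cross_attention_memory(retrieved_encodings):
    """Concat(H_Enc[l_1], ..., H_Enc[l_k]) along the sequence axis.

    Each element has shape (L_i, h); the result has shape (sum_i L_i, h).
    """
    return np.concatenate(retrieved_encodings, axis=0)

# Usage sketch (the decoder itself is not shown):
#   memory = build_cross_attention_memory(retrieve(q, doc_embs, enc_outputs, k=3))
#   the decoder then predicts y_i from y_<i and x while cross-attending to memory.
```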
At training time, the encoder and decoder are jointly trained. We first perform offline retrieval to form triplets $(x, y, c)$, where $c$ is retrieved by some predefined retriever. The loss is masked and only defined on the targets $y$. Because the encoding of each context is independent and there is no interaction between contexts, the attention matrix of the encoder is block diagonal, and we process the contexts in a linear loop over the diagonal blocks. Thus, the computation cost of both the encoder and the decoder at each step is linear in the number of contexts.
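A minimal sketch of this independence assumption at training time; `encoder` is a hypothetical per-context encoder function, and the masked loss below only illustrates restricting the loss to target positions.

```python
import numpy as np

def encode_contexts_independently(encoder, contexts):
    """Encode each retrieved context in isolation.

    No attention crosses context boundaries, so the joint attention matrix is
    block diagonal; a loop over contexts is equivalent to it and costs time
    linear in the number of contexts.
    """
    return [encoder(c) for c in contexts]

def masked_target_loss(token_nll, target_mask):
    """Average per-token negative log-likelihood over target positions y only.

    token_nll:   (T,) per-token NLL from the decoder.
    target_mask: (T,) 1.0 where the token belongs to y, 0.0 elsewhere.
    """
    return float(np.sum(token_nll * target_mask) / np.sum(target_mask))
```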
For online language model inference, only the retriever and the decoder are involved. The retrieval embedding $\mathrm{Emb}_D(c)$ and the encoder output $H_{\mathrm{Enc}}(c)$ of each context are both pre-processed offline and stored in the retrieval database as associated key-value pairs. When a new input sequence arrives, the retriever is only responsible for the approximate nearest neighbor search and the lookup of the associated value, i.e., the pre-computed encoder output. The decoder then takes in the input sequence and cross-attends to the concatenation of pre-computed encoder outputs to generate the target tokens. Thanks to this decoupling, neither the retriever nor the decoder needs to store encoder parameters. Hence, such an approach is more parameter efficient than similar works such as Retro [4], saving the storage and computation budget of the encoder, which is helpful in a “client-server” scenario where the capacity of the “client” is limited. When comparing parameter counts with other models, we only need to count the decoder and cross-attention parameters. We also follow Retro’s [4] approach of excluding the embedding matrices from the parameter count.
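To illustrate the decoupling, the following is a sketch of the offline preparation step that produces the key-value database, assuming hypothetical `emb_d` and `encoder` functions; after this step the encoder parameters $\theta_{\mathrm{Enc}}$ are no longer needed by the online system.

```python
def build_retrieval_database(emb_d, encoder, corpus):
    """Offline: pre-compute one (key, value) pair per context c in the corpus C.

    key   = Emb_D(c), the document embedding used for nearest-neighbor search.
    value = H_Enc(c), the encoder output the decoder will cross-attend to.
    """
    keys, values = [], []
    for c in corpus:
        keys.append(emb_d(c))
        values.append(encoder(c))
    # Only (keys, values) are shipped to the online "client"; the encoder
    # itself can be discarded afterwards.
    return keys, values
```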
4 Auto-regressive Language Modeling
We experimented with the same “encoder-decoder” context incorporation mechanism for both auto-regressive language modeling and open-domain question answering. The only difference is that auto-regressive language modeling processes input sequences in a sliding-window fashion, while the question answering task receives the full input sequence (the question) at once.
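For concreteness, a minimal sliding-window chunking sketch; the chunk size and stride are illustrative placeholders rather than the paper's settings.

```python
def sliding_window_chunks(tokens, chunk_size, stride):
    """Yield successive, possibly overlapping chunks of a token sequence."""
    for start in range(0, len(tokens), stride):
        yield tokens[start:start + chunk_size]
```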
4.1 Experimental Setup
For auto-regressive language modeling, we use English C4 [32] version 2.2.1, the same as Retro. We train the language model and prepare the retrieval database using the train split and evaluate the results using the validation split. The language model target sequence is a sliding window (chunk)