Decoupled Context Processing for
Context Augmented Language Modeling
Zonglin Li
Google Research, New York
lizonglin@google.com
Ruiqi Guo
Google Research, New York
guorq@google.com
Sanjiv Kumar
Google Research, New York
sanjivk@google.com
Abstract
Language models can be augmented with a context retriever to incorporate knowl-
edge from large external databases. By leveraging retrieved context, the neural net-
work does not have to memorize the massive amount of world knowledge within its
internal parameters, leading to better parameter efficiency, interpretability and mod-
ularity. In this paper we examined a simple yet effective architecture for incorporat-
ing external context into language models based on a decoupled Encoder-Decoder architecture. We showed that such a simple architecture achieves competitive
results on auto-regressive language modeling and open domain question answer-
ing tasks. We also analyzed the behavior of the proposed model which performs
grounded context transfer. Finally we discussed the computational implications of
such retrieval augmented models.
1 Introduction
Transformers have proven to be powerful language models that capture an impressive amount of
world knowledge in their internal parameters and generalize to a variety of downstream tasks [39, 31].
Recently, there has been a lot of success in improving language model quality by increasing the
number of parameters in transformers, often on the order of hundreds of billions [10, 30, 38, 6].
However, scaling up the model size also contributes to an exponential rise in computation costs, both in terms of the number of accelerators needed and energy consumption [29].
To overcome the exponential increase in the number of parameters, one natural idea is to utilize
information retrieved from an external source such as a massive external database, thereby freeing the neural network from having to memorize world knowledge. To this end, researchers have proposed multiple context augmented language model architectures [21, 15, 4, 44, 26]. Such an architecture typically has two components: a retriever that embeds the input sequence and retrieves relevant context from an external source through vector similarity search, and a neural network that integrates both the input and the retrieved external context into the prediction of the target sequence, formally:
$$P(y \mid x, C = \mathrm{Retrieve}(x, D);\ \theta) \quad \text{vs.} \quad P(y \mid x;\ \theta_0) \qquad (1)$$
Here, C = {c} is a set of context retrieved from the external database D. θ0 is a self-contained language model which predicts the target sequence y based solely on the input x, whereas θ corresponds to the context augmented language model which incorporates both the input x and the retrieved context C.
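To make the distinction concrete, a minimal sketch of the two formulations in Eq. (1) follows; the function names (retrieve, lm, augmented_lm) are hypothetical stand-ins and not an API from the paper:

```python
# Minimal sketch of Eq. (1); all names here are illustrative placeholders.

def predict_no_retrieval(x, lm):
    # Self-contained model: P(y | x; theta_0), knowledge lives in the weights.
    return lm(x)

def predict_with_retrieval(x, database, retrieve, augmented_lm, k=2):
    # Context augmented model: P(y | x, C = Retrieve(x, D); theta).
    contexts = retrieve(x, database, k)      # C = Retrieve(x, D)
    return augmented_lm(x, contexts)         # conditions on both x and C
```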
One of the challenges for such context augmented language models is the computational cost of context retrieval and incorporation, especially when multiple pieces of context are present or the context sequence is long. In this paper, we propose a computationally efficient architecture for
incorporating context based on the vanilla Encoder-Decoder architecture, which decouples the encoding of context
and the prediction of the target sequence. We show that a model with such a simple architecture is competitive when compared with customized mechanisms such as Chunked-Cross-Attention [4] on language modeling score (as measured by bits-per-byte, BPB), while being more efficient in terms of parameter count and computation cost. Then, we define metrics to measure the utility of the retrieved context and use them to guide the training of the retriever. We further show competitive results on the downstream task of question answering, and demonstrate that the model takes advantage of the retrieved context without memorizing facts within its internal parameters. Finally, we study the implications of context retrieval in terms of retrieval latency, accuracy and computation cost.
To summarize the main contributions of this article:
• Proposed a novel Encoder-Decoder based architecture for incorporating retrieved external context, which decouples context encoding from language model inference.
• Demonstrated the competitive results of the proposed model on both the auto-regressive language modeling task and the open domain question answering task.
• Analyzed model behavior by understanding how context improves language modeling prediction on tokens with different linguistic properties and how the model performs grounded context transfer.
• Discussed computational cost and retrieval efficiency in context augmentation.
2 Related Works
Large language models, typically in the form of big neural networks, are trained with a huge amount
of training data rich in unstructured knowledge. Researchers have found that after model training, the
neural networks often end up storing a surprisingly large amount of memorized information within their weights [2, 7], which are then leveraged as a knowledge base. Multiple hypotheses have been developed on how components such as fully-connected layers [13] and attention layers [5] may be responsible for such memorization behavior. While the capability of storing world knowledge is desirable, memorization also contributes to huge model sizes and the lack of explicit control over the knowledge base, such as performing selection or updates.
An alternative strategy is to enable language models to incorporate world knowledge in the form of retrieved context from external sources, instead of having to memorize it. Multiple works have proposed architectures that support external retrieval, usually composed of a context retriever that searches a large external key value store and a method of integrating the retrieved information.
There are various ways to construct the key value store. The keys are primarily used for similarity
matching, and they can be sparse vectors such as BM25 [34], or dense embeddings extracted from part of the model [21, 25, 40], or from pretrained embedders [12, 28], or embedders trained for specific downstream tasks [15, 20, 41, 35, 24, 11]. The values also have various different forms. For example, TOME [11] stores a dense embedding about the contextual information of an entity mention, while Realm [15], RAG [26], FID [18], Retro [4], MARGE [25], DenSPI [35] and DensePhrases [24] store the raw text as the value. Works such as kNN-LM [21] and Spalm [44] store one token as a value. Finally, the key value store is searched over using vector similarity techniques, typically with some off-the-shelf nearest neighbor search implementations such as FAISS [19], ScaNN [14], HNSW [27] or SPTAG [9].
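As a concrete sketch of such a store, the snippet below builds dense keys and searches them with FAISS [19]; the embed function is a hypothetical stand-in for whichever embedder a given method uses, and the stored values here are raw passages (other methods store tokens or dense encodings instead):

```python
import numpy as np
import faiss  # off-the-shelf nearest neighbor search library

d = 128  # embedding dimension (illustrative)
contexts = ["passage one ...", "passage two ...", "passage three ..."]

def embed(texts, dim=d):
    # Hypothetical stand-in embedder producing unit-norm dense vectors;
    # a sparse BM25 representation or a trained dual encoder fills the same role.
    rng = np.random.default_rng(0)
    vecs = rng.standard_normal((len(texts), dim)).astype("float32")
    return vecs / np.linalg.norm(vecs, axis=1, keepdims=True)

keys = embed(contexts)              # keys: used only for similarity matching
values = contexts                   # values: here, the raw text of each passage

index = faiss.IndexFlatIP(d)        # exact maximum inner product search
index.add(keys)

query = embed(["what does passage two say ..."])
scores, ids = index.search(query, 2)          # top-2 most similar keys
retrieved = [values[i] for i in ids[0]]       # look up the associated values
```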
There are also many different ways to integrate the retrieved context. A popular approach is to
concatenate the retrieval results with the original input and jointly process them. It has been adopted
by works such as Realm [15], RAG [26], and FiD [18]. Other works utilize some form of cross attention for the context integration, such as the Chunked-Cross-Attention with input conditioning in Retro [4], kNN Attention in Memorizing Transformer [40], Memory Attention in TOME [11] and Cross-Attention in MARGE [25]. For token level integration, kNN-LM [21] uses simple linear interpolation while Spalm [44] uses a learned gate based on the last layer embedding. There are also works that directly utilize the retrieval results without any integration, such as DenSPI [35] and DensePhrases [24]. Most of the works use retrieval as a way to augment tasks such as language modeling or question answering, with the exception of MARGE [25], where retrieval is treated as an autoencoder bottleneck for multilingual pretraining and is not strictly necessary for inference. We compare representative previous works and contrast them with our proposal in Table 1.
| Method | Retrieval Granularity | Retrieval Encoding | Context Integration | Decoupled Context Encoding | Tasks |
|---|---|---|---|---|---|
| kNN-LM [21] | Token | Last layer | Interpolation | Yes | LM |
| Spalm [44] | Token | Last layer | Gating | Yes | LM |
| Realm [15], RAG [26], FID [18] | Input | Trained | Concat | No | OpenQA |
| Retro [4] | Chunk | Frozen | Chunked-Cross-Attention | No | LM, OpenQA |
| Proposed | Chunk | Frozen / trained | Encoder-Decoder Cross-Attention | Yes | LM, OpenQA |

Table 1: Architectural differences between previous retrieval augmented models and ours in (i) context retrieval, (ii) context integration and (iii) targeted applications.
[Figure 1 diagram: a client running the context augmented LM (decoder) sends a request with input x to a server hosting the retriever and context encoder; the server queries a key value store (Key: EmbD(c), Value: HEnc(c), both precomputed) using the query retrieval embedder EmbQ(x), and responds with the encoded context used during auto-regressive decoding of y.]
Figure 1: Architecture of the context augmented language model. We opt to use the standard Encoder-Decoder cross attention mechanism for context incorporation, which allows us to decouple context encoding from LM inference. c, x, y serve as context, decoder input and decoder target respectively. In other words, the client sends over the input x, and the server conducts retrieval to find relevant context and returns the encoded representation HEnc(c; θEnc). The encoded representation HEnc is pre-computed offline and returned as “metadata” of the retrieval. Note that training of the Encoder and Decoder is joint, while they are decoupled at inference time: the client does not need to store parameters or run inference on the Encoder component.
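The precompute path in the figure can be sketched as follows; this is our own illustration, with emb_d and context_encoder as hypothetical stand-ins for EmbD and Enc:

```python
import numpy as np

def build_retrieval_store(contexts, emb_d, context_encoder):
    """Offline precompute of the key-value store on the server side.

    Keys are retrieval embeddings EmbD(c); values are the encoder outputs
    HEnc(c) that the decoder will later cross-attend over. At serving time
    the server only searches the keys and returns the matching precomputed
    values, so no encoder inference is needed online.
    """
    keys, values = [], []
    for c in contexts:
        keys.append(emb_d(c))              # key:   EmbD(c), for similarity search
        values.append(context_encoder(c))  # value: HEnc(c), shape [m, d_model]
    return np.stack(keys), values
```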
3 Architecture
We use the Encoder-Decoder Transformer architecture [39] to integrate the language model input and the retrieved context. We denote the context encoder and LM decoder as Enc and Dec respectively. Given an input token sequence x = (x1, x2, ..., xn), the task is to predict the next tokens y = (y1, y2, ..., ys). Without external context, we concatenate x before y and the task becomes traditional auto-regressive language modeling that predicts targets y following the input (or “prompt”) x. In this setting, only the decoder is involved (denoted as “No-retrieval”). To incorporate external context, we use c, x, y to serve as encoder input, decoder input and decoder target respectively. We first use a retriever to identify the context c = (c1, c2, ..., cm) given the input x, then fetch the pre-computed encoder output of the corresponding context tokens HEnc(c; θEnc). In this setting, the encoder output HEnc is directly used by the decoder through Encoder-Decoder cross-attention to influence the final prediction. The decoder does not have to know the exact tokens of c that are retrieved.
$$\text{No-retrieval:}\quad P(y_i \mid y_{<i}, x_1, x_2, \dots, x_n;\ \theta^0_{\mathrm{Dec}})$$
$$\text{Retrieval:}\quad P(y_i \mid y_{<i}, x_1, x_2, \dots, x_n, \{H_{\mathrm{Enc}}(c)\};\ \theta_{\mathrm{Dec}})$$
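The Retrieval case above relies on standard cross-attention over the pre-computed encoder states; a simplified single-head sketch (our own parameter names, whereas the actual model uses the multi-head Encoder-Decoder attention of [39]) is:

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attend(decoder_states, h_enc, w_q, w_k, w_v):
    """Decoder positions attend over the pre-computed context encodings.

    decoder_states: [t, d]   hidden states for the decoder input (x, y_<i)
    h_enc:          [m, d]   pre-computed encoder output HEnc(c)
    """
    q = decoder_states @ w_q                           # queries from the decoder
    k = h_enc @ w_k                                    # keys from HEnc(c)
    v = h_enc @ w_v                                    # values from HEnc(c)
    attn = softmax(q @ k.T / np.sqrt(k.shape[-1]))     # [t, m] attention weights
    return attn @ v                                    # context-informed update
```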
Under this formulation, only the decoder parameters θDec are required at inference time. The retriever retrieves the indices of the relevant context and looks up their encodings. The context encodings are processed ahead of time, and are completely decoupled from online operations. This is in contrast to previous works such as Realm [15], RAG [26] or Retro [4], where the interaction between the input x and the context c is bi-directional, which necessitates context encoding at inference time. In our model, information flows uni-directionally from c to x and y, and the encoding of each context c is processed independently. On one hand, this is more restrictive than bi-directional interaction; on the other hand, such a design ensures complete decoupling of context processing from online language model inference. The exact mechanism is detailed in Figure 1.
Conceptually, the retriever can be an arbitrary black box. In practice, we use a dual encoder formulation [8, 16], which first represents x as a query embedding EmbQ(x) and performs vector similarity search over a database D to find the indices of the documents whose document embeddings have the highest inner products with the query embedding. We then look up the context encoder outputs that correspond to the retrieved indices and return them as the retriever output.
$$l = \arg\max_{v \in D} \langle \mathrm{Emb}_Q(x), v \rangle; \qquad D = \{\mathrm{Emb}_D(c) : c \in \mathcal{C}\}$$
$$H_{\mathrm{Enc}}[l] = \mathrm{Enc}(c_l;\ \theta_{\mathrm{Enc}})$$
In the case of multiple supporting contexts, k-arg max is used instead of arg max. The encoder outputs of the supporting contexts are then concatenated:
$$P(y_i \mid y_{<i}, x_1, x_2, \dots, x_n, \mathrm{Concat}(H_{\mathrm{Enc}}[l_1], H_{\mathrm{Enc}}[l_2], \dots, H_{\mathrm{Enc}}[l_k]);\ \theta_{\mathrm{Dec}}),$$
where Concat is simply vector concatenation:
$$\mathrm{Concat}([h_1, h_2, \dots, h_n], [g_1, g_2, \dots, g_m], \dots) = [h_1, h_2, \dots, h_n, g_1, g_2, \dots, g_m, \dots]$$
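A brute-force sketch of the k-arg max retrieval and the concatenation (our own illustration; at scale the full matrix product is replaced by an approximate nearest neighbor index):

```python
import numpy as np

def retrieve_topk(query_emb, doc_embs, k):
    # k-arg max over inner products <EmbQ(x), EmbD(c)> -> indices l_1 ... l_k
    scores = doc_embs @ query_emb
    return np.argsort(-scores)[:k]

def concat_contexts(h_enc_store, indices):
    # Concat(HEnc[l_1], ..., HEnc[l_k]) along the sequence (token) axis
    return np.concatenate([h_enc_store[l] for l in indices], axis=0)
```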
At training time, the encoder and decoder are jointly trained. We first perform offline retrieval to form triplets of (x, y, c), where c is retrieved by some predefined retriever. The loss is masked and only defined on the targets y. Because the encoding of each context is independent and there is no interaction between contexts, the attention matrix of the encoder is block diagonal and we process the contexts in a linear loop over each diagonal block. Thus, the computation cost of both the encoder and the decoder at each step is linear in the number of contexts.
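Since contexts never attend to each other, encoding them one at a time is equivalent to a single fused pass under a block-diagonal attention mask; a sketch of both views (encoder is a hypothetical callable):

```python
import numpy as np

def encode_contexts(contexts, encoder):
    # Encode each retrieved context independently: cost is linear in the
    # number of contexts, and each HEnc(c) can be cached and reused.
    return [encoder(c) for c in contexts]

def block_diagonal_mask(lengths):
    # Attention mask of the equivalent fused pass: tokens of context i may
    # only attend to tokens of the same context i.
    total = sum(lengths)
    mask = np.zeros((total, total), dtype=bool)
    start = 0
    for m in lengths:
        mask[start:start + m, start:start + m] = True
        start += m
    return mask
```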
For online language model inference, only the retriever and the decoder are involved. The retrieval embedding EmbD(c) and the encoder output HEnc(c) of each context are both pre-processed offline and prepared into the retrieval database as associated key-value pairs. When a new input sequence arrives, the retriever is only responsible for the approximate nearest neighbor search and the lookup of the associated value, which is the pre-computed encoder output. The decoder then takes in the input sequence and cross attends on the concatenation of the pre-computed encoder outputs to generate the target tokens. Thanks to the decoupling, neither the retriever nor the decoder needs to store encoder parameters. Hence, such an approach is more parameter efficient compared to similar works such as Retro [4], saving the storage and computation budget of the encoder, which is helpful in a “client-server” scenario where the capacity of the “client” can be limited. When accounting for the parameter count in comparison with other models, we only need to count the decoder and cross attention parameters. We also follow Retro's [4] approach of excluding the embedding matrices from the parameter count.
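Putting the pieces together, the online path touches only the retrieval keys, the pre-computed values and the decoder; the sketch below is our own end-to-end illustration (emb_q and decoder are hypothetical callables), with no encoder parameters needed on the client:

```python
import numpy as np

def generate(x_tokens, emb_q, doc_embs, h_enc_values, decoder, k=2, max_new=32):
    """Client-side inference: ANN lookup plus decoding, no encoder involved.

    emb_q:        query embedder EmbQ (callable)
    doc_embs:     [num_docs, d] matrix of pre-computed keys EmbD(c)
    h_enc_values: list of pre-computed encoder outputs HEnc(c), one per document
    decoder:      callable mapping (tokens so far, encoded context) -> next token
    """
    # 1) Embed the input and find the k most similar keys (brute force here).
    scores = doc_embs @ emb_q(x_tokens)
    ids = np.argsort(-scores)[:k]
    # 2) Look up and concatenate the pre-computed encoder outputs HEnc(c).
    h_enc = np.concatenate([h_enc_values[i] for i in ids], axis=0)
    # 3) Decode auto-regressively, cross-attending on the retrieved encodings.
    y = list(x_tokens)
    for _ in range(max_new):
        y.append(decoder(y, h_enc))
    return y[len(x_tokens):]
```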
4 Auto-regressive Language Modeling
We experimented with the same “encoder-decoder” context incorporation mechanism for both auto-regressive language modeling and open domain question answering. The only difference is that auto-regressive language modeling processes input sequences in a sliding window fashion, while the question answering task receives the full input sequence (the question) at once.
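For the language modeling task, a long document is consumed as consecutive chunks, each serving as a target with retrieval issued per chunk; a minimal chunking sketch (window and stride values are illustrative, not the paper's configuration):

```python
def sliding_window_chunks(tokens, window=64, stride=64):
    # Split a long token sequence into chunks; each chunk is the LM target,
    # with the preceding tokens as input. Values here are illustrative only.
    return [tokens[i:i + window] for i in range(0, len(tokens), stride)]
```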
4.1 Experimental Setup
For auto-regressive language modeling, we use English C4 [32] version 2.2.1, the same as Retro. We train the language model and prepare the retrieval database using the train split and evaluate the results using the validation split. The language model target sequence is a sliding window (chunk)