Generalization Properties of Retrieval-based Models
Soumya Basu*, Ankit Singh Rawat*, and Manzil Zaheer*
Google LLC, USA
{basusoumya,ankitsrawat,manzilzaheer}@google.com
Abstract
Many modern high-performing machine learning models such as GPT-3 primarily rely
on scaling up models, e.g., transformer networks. Simultaneously, a parallel line of work
aims to improve the model performance by augmenting an input instance with other
(labeled) instances during inference. Examples of such augmentations include task-specific
prompts and similar examples retrieved from the training data by a nonparametric
component. Remarkably, retrieval-based methods have enjoyed success on a wide range of
problems, ranging from standard natural language processing and vision tasks to protein
folding, as demonstrated by many recent efforts, including WebGPT and AlphaFold.
Despite a growing literature showcasing the promise of these models, their theoretical
underpinnings remain underexplored. In this paper, we present a
formal treatment of retrieval-based models to characterize their generalization ability.
In particular, we focus on two classes of retrieval-based classification approaches: First,
we analyze a local learning framework that employs an explicit local empirical risk
minimization based on retrieved examples for each input instance. Interestingly, we
show that breaking down the underlying learning task into local sub-tasks enables the
model to employ a low complexity parametric component to ensure good overall accuracy.
The second class of retrieval-based approaches we explore learns a global model using
kernel methods to directly map an input instance and retrieved examples to a prediction,
without explicitly solving a local learning task.
1 Introduction
As our world is complex, we need expressive machine learning models to make high-accuracy
predictions on real-world problems. There are multiple ways to increase the expressiveness of a
machine learning model. A popular way is to homogeneously scale up the size of a parametric
model, such as a neural network, which has been behind many recent high-performance models
such as GPT-3 [Brown et al., 2020] and ViT [Dosovitskiy et al., 2021]. Their performance
(accuracy) improves monotonically with increasing model size, as demonstrated by
“scaling laws” [Kaplan et al., 2020]. Such large models, however, have their own limitations,
including high computation cost, catastrophic forgetting (making them hard to adapt to changing
data), lack of provenance, and lack of explainability. Classical instance-based models
[Fix and Hodges, 1989], on the other hand, offer many desirable properties by design — efficient
data structures, incremental learning (easy addition and deletion of knowledge), and some
provenance for their predictions based on the nearest neighbors w.r.t. the input. However,
these models often suffer from weaker empirical performance as compared to deep parametric models.

* Equal contribution in alphabetical order

[Figure 1 panels: (a) classical learning setups, in which a parametric or a nonparametric model is learned from the training data; (b) the modern retrieval-based setup.]

Figure 1: An illustration of a retrieval-based classification model. Given an input instance $x$,
similar to an instance-based model, it retrieves similar (labeled) examples
$\mathcal{R}_x = \{(x'_j, y'_j)\}_j$ from the training data. Subsequently, it processes
(potentially via a nonparametric method) the input instance along with the retrieved examples
to make the final prediction $\hat{y} = f(x, \mathcal{R}_x)$.
Increasingly, a middle ground combining the two paradigms and retaining the best of both
worlds is becoming popular across various domains, ranging from natural language [Das
et al., 2021, Wang et al., 2022, Liu et al., 2022, Izacard et al., 2022], to vision [Liu et al.,
2015, 2019, Iscen et al., 2022, Long et al., 2022], to reinforcement learning [Blundell et al.,
2016, Pritzel et al., 2017, Ritter et al., 2020], to even protein structure prediction [Cramer,
2021]. In such approaches, given a test input, one first retrieves relevant entries from a
data index and then processes the retrieved entries along with the test input to make the
final predictions using a machine learning model. This process is visualized in Figure 1b.
For example, in semantic parsing, models that augment a parametric seq2seq model with
similar examples have not only outperformed much larger models but are also more robust
to changes in the data [Das et al., 2021].
While classical learning setups (cf. Figure 1a) have been studied extensively over decades,
even basic properties and trade-offs pertaining to retrieval-based models (cf. Figure 1b),
despite their aforementioned remarkable successes, remain highly under-explored. Most of
the existing efforts on retrieval-based machine learning models solely focus on developing end-
to-end domain-specific models, without identifying the key dataset properties or structures
that are critical in realizing performance gains with such models. Furthermore, at first glance,
due to the strong dependence between an input and its associated retrieved set, a direct
application of existing statistical learning techniques does not appear straightforward.
This prompts the natural question: What should be the right theoretical framework that can
help rigorously showcase the value of the retrieved set in ensuring superior performance of
modern retrieval-based models?
In this paper, we take the first step towards answering this question, while focusing on
the classification setting (Sec. 2.1). We begin with the hypothesis that the model implicitly
uses the retrieved set to perform local learning and then adapts its predictions to the
neighborhood of the test point. This idea is inspired by Bottou and Vapnik [1992].
Such local learning is potentially beneficial when the underlying task has a local
structure, where a much simpler function class suffices to explain the data in a given local
neighborhood even though the data can be complex overall (formally defined in Sec. 2.2). For
instance, looking at a few answers on Stack Overflow, even if they do not address the exact
same problem, may help us resolve our issue much faster than understanding the whole system.
We formally capture this effect.
We begin by analyzing an explicit local learning algorithm: For each test input, (1) we
retrieve a few training examples located in the vicinity of the test input; (2) we train a local
model by performing empirical risk minimization (ERM) with only these retrieved examples
(local ERM); and (3) we apply the resulting local model to make a prediction on the test input.
For the aforementioned retrieval-based local ERM, we derive, in Sec. 3, finite-sample
generalization bounds that highlight a trade-off between the complexity of the underlying
function class and the size of the neighborhood in which the local structure of the data
distribution holds. Under this assumption of local regularity, we show that by using a much
simpler function class for the local model, we can achieve a loss/error similar to that of a
complex global model (Thm. 3.4). Thus, breaking down the underlying learning task into local
sub-tasks enables the model to employ a low-complexity parametric component while ensuring
good overall accuracy. Note that the local ERM setup is reminiscent of semiparametric
polynomial regression [Fan and Gijbels, 2018] in statistics, which is a special case of our setup.
However, semiparametric polynomial regression has only been analyzed asymptotically under
the mean squared error loss [Ruppert and Wand, 1994], and its treatment under more general
losses remains unexplored.
We acknowledge that such local learning cannot be the complete picture behind the effec-
tiveness of retrieval-based models. As noted in Zakai and Ritov [2008], there always exists a
model with a global component that is “preferable” to a local-only model. In Sec. 3.2,
we extend local ERM to a two-stage setup: first learn a global representation using the entire
dataset, and then utilize this representation at test time while solving the local ERM
as previously defined. This enables local learning to benefit from good-quality global
representations, especially in sparse data regions.
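The following sketch illustrates this two-stage recipe under simple stand-in choices: PCA as the globally learnt representation and a multinomial logistic model as the local function class, with a Euclidean retrieval radius. These choices and the helper names are assumptions made for illustration only; the analysis in Sec. 3.2 does not depend on them.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression

# Stage 1: learn a global representation from the entire dataset.
# PCA is only a stand-in for a globally learnt representation.
def fit_global_representation(X_train, dim=16):
    return PCA(n_components=dim).fit(X_train)

# Stage 2: at test time, solve the local ERM in the learnt representation space.
def two_stage_predict(rep, X_train, y_train, x_test, r):
    Z = rep.transform(X_train)                      # representations of training points
    z = rep.transform(x_test.reshape(1, -1))[0]     # representation of the test point
    mask = np.linalg.norm(Z - z, axis=1) <= r       # retrieve neighbors within radius r
    Z_loc, y_loc = Z[mask], y_train[mask]
    if len(np.unique(y_loc)) == 1:                  # label-pure neighborhood: nothing to fit
        return int(y_loc[0])
    local_model = LogisticRegression().fit(Z_loc, y_loc)   # local ERM on retrieved examples
    return int(local_model.predict(z.reshape(1, -1))[0])
```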
Finally, we move beyond explicit local learning to a setting that more closely resembles
empirically successful systems such as REINA, WebGPT, and AlphaFold: a model
that directly learns, end-to-end, to predict from the input instance and the associated retrieved
similar examples. Towards this, we take a preliminary step in Sec. 4 by studying a novel
formulation of classification over an extended feature space (to account for the retrieved
examples) using kernel methods [Deshmukh et al., 2019].
To summarize, our main contributions include: 1) setting up a formal framework for
classification under local regularity; 2) a finite-sample analysis of the explicit local learning
framework; 3) extending the analysis to incorporate a globally learnt model; and 4) providing
the first rigorous treatment of end-to-end retrieval-based models to understand their
generalization using kernel-based learning.
2 Problem setup
We first provide a brief background on (multiclass) classification along with the necessary
notation. Subsequently, we discuss the problem setup considered in this paper, which deals
with designing retrieval-based classification models for data distributions with local
regularity.
2.1 Multiclass classification
In this work, we restrict ourselves to the (multi-class) classification setting, with access to $n$ training examples $S = \{(x_i, y_i)\}_{i \in [n]} \subseteq \mathcal{X} \times \mathcal{Y}$, sampled i.i.d. from the data distribution $\mathcal{D} := \mathcal{D}_{X,Y}$.
Given $S$, one is interested in learning a classifier $h : \mathcal{X} \to \mathcal{Y}$ that minimizes the misclassification error. It is common to define a classifier via a scorer $f : x \mapsto \big(f_1(x), \ldots, f_{|\mathcal{Y}|}(x)\big) \in \mathbb{R}^{|\mathcal{Y}|}$ that assigns a score to each class in $\mathcal{Y}$ for an instance $x$. For a scorer $f$, the corresponding classifier takes the form $h_f(x) = \arg\max_{y \in \mathcal{Y}} f_y(x)$. Furthermore, we define the margin of $f$ at a given label $y \in \mathcal{Y}$ as

$$\gamma_f(x, y) = f_y(x) - \max_{y' \neq y} f_{y'}(x). \qquad (1)$$
Let $\mathbb{P}_{\mathcal{D}}(A) := \mathbb{E}_{(X,Y) \sim \mathcal{D}}\,\mathbf{1}\{A\}$ for any event $A$. Given $S$ and a set of scorers $\mathcal{F} \subseteq \{f : \mathcal{X} \to \mathbb{R}^{|\mathcal{Y}|}\}$, learning a model amounts to finding a scorer in $\mathcal{F}$ that minimizes the misclassification error:

$$f^{\star} = \arg\min_{f \in \mathcal{F}} \mathbb{P}_{\mathcal{D}}\big(h_f(X) \neq Y\big). \qquad (2)$$

One typically employs a surrogate loss $\ell$ [Bartlett et al., 2006] for the misclassification loss $\mathbf{1}\{h_f(X) \neq Y\}$ and aims to minimize the associated risk:

$$R_{\ell}(f) = \mathbb{E}_{(X,Y) \sim \mathcal{D}}\,\ell\big(f(X), Y\big). \qquad (3)$$

Since the underlying data distribution $\mathcal{D}$ is only accessible via the examples in $S$, one learns a good scorer by minimizing the (global) empirical risk over a large function class $\mathcal{F}_{\mathrm{global}}$ as follows:

$$\hat{f} = \arg\min_{f \in \mathcal{F}_{\mathrm{global}}} \widehat{R}_{\ell}(f) := \frac{1}{n} \sum_{i \in [n]} \ell\big(f(x_i), y_i\big). \qquad (4)$$
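To make the notation above concrete, the following minimal sketch (in Python/NumPy) computes the margin in (1) and the empirical surrogate risk minimized in (4). The linear scorer class, the softmax cross-entropy standing in for the surrogate loss $\ell$, and the helper names are illustrative assumptions rather than choices made in our analysis.

```python
import numpy as np

def margin(scores: np.ndarray, y: int) -> float:
    """Margin gamma_f(x, y) = f_y(x) - max_{y' != y} f_{y'}(x), cf. Eq. (1)."""
    rest = np.delete(scores, y)
    return float(scores[y] - rest.max())

def empirical_surrogate_risk(W: np.ndarray, X: np.ndarray, y: np.ndarray) -> float:
    """Empirical risk (1/n) * sum_i ell(f(x_i), y_i) from Eq. (4) for a linear scorer
    f(x) = W x, with the softmax cross-entropy as the surrogate loss ell."""
    logits = X @ W.T                                   # scores f(x_i), shape (n, |Y|)
    logits -= logits.max(axis=1, keepdims=True)        # shift for numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(X.shape[0]), y].mean())
```

A global ERM solver would then simply minimize empirical_surrogate_risk over the weights $W$, i.e., over the chosen $\mathcal{F}_{\mathrm{global}}$.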
2.2 Data distributions with local regularity
In this work, we assume that the underlying data distribution $\mathcal{D}$ follows a local-regularity structure, where a much simpler (parametric) function class suffices to explain the data in each local neighborhood. Formally, for $x \in \mathcal{X}$ and $r > 0$, we define $B_{x,r} := \{x' \in \mathcal{X} : \mathrm{d}(x, x') \le r\}$, an $r$-radius ball around $x$, w.r.t. a metric $\mathrm{d} : \mathcal{X} \times \mathcal{X} \to \mathbb{R}$. Let $\mathcal{D}_{x,r}$ be the data distribution restricted to $B_{x,r}$, i.e.,

$$\mathcal{D}_{x,r}(A) = \mathcal{D}(A)/\mathcal{D}(B_{x,r} \times \mathcal{Y}), \qquad \forall\, A \subseteq B_{x,r} \times \mathcal{Y}. \qquad (5)$$

Now, the local regularity condition of the data distribution ensures that, for each $x \in \mathcal{X}$, there exists a low-complexity function class $\mathcal{F}_x$, with $|\mathcal{F}_x| \ll |\mathcal{F}_{\mathrm{global}}|$, that approximates the Bayes optimal (w.r.t. $\mathcal{F}_{\mathrm{global}}$) for the local classification problem defined by $\mathcal{D}_{x,r}$. That is, for a given $\varepsilon_{\mathcal{X}} > 0$, we have¹

$$\min_{f \in \mathcal{F}_x} \mathbb{E}_{\mathcal{D}_{x,r}}[\ell(f(X), Y)] \le \min_{f \in \mathcal{F}_{\mathrm{global}}} \mathbb{E}_{\mathcal{D}_{x,r}}[\ell(f(X), Y)] + \varepsilon_{\mathcal{X}}, \qquad \forall\, x \in \mathcal{X}. \qquad (6)$$
As an example, if $\mathcal{F}_{\mathrm{global}}$ is linear in $\mathbb{R}^d$ (possibly dense) with bounded norm $\tau$, then $\mathcal{F}_x$ can be a simpler function class, such as linear in $\mathbb{R}^d$ with sparsity $k \ll d$ and with bounded norm $\tau_x \le \tau$.
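As a purely illustrative companion to (5), the sketch below forms the empirical analogue of the restricted distribution $\mathcal{D}_{x,r}$ by retaining the training pairs whose features fall inside $B_{x,r}$. The Euclidean metric and the helper name restrict_to_ball are assumptions made for the example, not part of the formal setup.

```python
import numpy as np

def restrict_to_ball(X: np.ndarray, y: np.ndarray, x: np.ndarray, r: float):
    """Empirical analogue of D_{x,r} in Eq. (5): keep only the training pairs whose
    features lie in the ball B_{x,r} (Euclidean metric assumed); uniform weights over
    the retained pairs play the role of the renormalization by D(B_{x,r} x Y)."""
    dist = np.linalg.norm(X - x, axis=1)     # d(x, x_i) for every training point
    mask = dist <= r                         # membership in B_{x,r}
    X_loc, y_loc = X[mask], y[mask]
    weights = np.full(len(X_loc), 1.0 / max(len(X_loc), 1))
    return X_loc, y_loc, weights
```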
¹As stated, we require the local-regularity condition to hold for each $x$. This can be relaxed to hold with high probability, with increased complexity of exposition.
2.3 Retrieval-based classification model
This work focuses on retrieval-based methods that can leverage the aforementioned local
regularity structure of the data distribution. In particular, we focus on two such approaches:
Local empirical risk minimization. Given a (test) instance $x$, the local empirical risk minimization (ERM) approach first retrieves a neighboring set $\mathcal{R}_x = \{(x'_j, y'_j)\} \subseteq S$. Subsequently, it identifies a (local) scorer $\hat{f}_x$ from a ‘simple’ function class $\mathcal{F}_{\mathrm{loc}} \subset \{f : \mathcal{X} \to \mathbb{R}^{|\mathcal{Y}|}\}$ as follows:

$$\hat{f}_x = \arg\min_{f \in \mathcal{F}_{\mathrm{loc}}} \widehat{R}^{x}_{\ell}(f); \qquad \widehat{R}^{x}_{\ell}(f) := \frac{1}{|\mathcal{R}_x|} \sum_{(x', y') \in \mathcal{R}_x} \ell\big(f(x'), y'\big). \qquad (7)$$

Here, $\mathcal{R}_x$ corresponds to the samples in $S$ that belong to $B_{x,r}$; hence, it follows the distribution $\mathcal{D}_{x,r}$. We assume there exists $N(r, \delta)$ such that, for any $r \ge 0$ and $\delta > 0$,

$$\mathbb{P}_{(X,Y) \sim \mathcal{D}}\big(|\mathcal{R}_X| < N(r, \delta)\big) \le \delta, \quad \text{and} \quad \mathbb{P}_{(X,Y) \sim \mathcal{D}}\big(|\mathcal{R}_X| = 0\big) = 0. \qquad (8)$$
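A minimal sketch of this procedure is given below (Python with scikit-learn). The multinomial logistic model standing in for the ‘simple’ class $\mathcal{F}_{\mathrm{loc}}$, the Euclidean metric defining $B_{x,r}$, and the function name are assumptions made for illustration; the analysis in Sec. 3 does not depend on these choices.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def local_erm_predict(X: np.ndarray, y: np.ndarray, x_test: np.ndarray, r: float) -> int:
    """Retrieval-based local ERM, cf. Eq. (7): retrieve R_x from the ball B_{x,r},
    fit a scorer from F_loc on the retrieved pairs only, and predict on x_test."""
    dist = np.linalg.norm(X - x_test, axis=1)        # Euclidean metric assumed
    mask = dist <= r
    X_loc, y_loc = X[mask], y[mask]                  # retrieved set R_x, assumed non-empty (cf. Eq. (8))
    if len(np.unique(y_loc)) == 1:                   # label-pure neighborhood: no fitting needed
        return int(y_loc[0])
    f_x = LogisticRegression().fit(X_loc, y_loc)     # local ERM over the simple class F_loc
    return int(f_x.predict(x_test.reshape(1, -1))[0])
```

Note that, exactly as in (7), a fresh scorer $\hat{f}_x$ is fit for every test instance; the extended-feature-space approach below removes this per-instance optimization.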
Note that the local ERM approach requires solving a local learning task for each test instance. Such a local learning algorithm was introduced in Bottou and Vapnik [1992]. Another point worth mentioning here is that (7) employs the same function class $\mathcal{F}_{\mathrm{loc}}$ for each $x$, whereas the local regularity assumption (cf. (6)) allows for an instance-dependent function class $\mathcal{F}_x$. We consider an $\mathcal{F}_{\mathrm{loc}}$ that approximates $\cup_{x \in \mathcal{X}} \mathcal{F}_x$ closely. In particular, we assume that, for some $\varepsilon_{\mathrm{loc}} > 0$, we have

$$\min_{f \in \mathcal{F}_{\mathrm{loc}}} \mathbb{E}_{\mathcal{D}_{x,r}}[\ell(f(X), Y)] \le \min_{f \in \mathcal{F}_x} \mathbb{E}_{\mathcal{D}_{x,r}}[\ell(f(X), Y)] + \varepsilon_{\mathrm{loc}}, \qquad \forall\, x \in \mathcal{X}. \qquad (9)$$
Continuing with the example following (6), where $\mathcal{F}_x$ is linear with sparsity $k \ll d$ and bounded norm $\tau_x$, one can take $\mathcal{F}_{\mathrm{loc}}$ to be linear with the same sparsity $k$ and bounded norm $\tau' < \sup_{x \in \mathcal{X}} \tau_x$.
Classification with extended feature space. Another approach to leverage the retrieved neighboring labeled instances during classification is to directly learn a scorer that maps $(x, \mathcal{R}_x) \in \mathcal{X} \times (\mathcal{X} \times \mathcal{Y})^{\star}$ to per-class scores. One can learn such a scorer over the extended feature space $\mathcal{X} \times (\mathcal{X} \times \mathcal{Y})^{\star}$ as follows:

$$\hat{f}_{\mathrm{ex}} = \arg\min_{f \in \mathcal{F}_{\mathrm{ex}}} \widehat{R}^{\mathrm{ex}}_{\ell}(f); \qquad \widehat{R}^{\mathrm{ex}}_{\ell}(f) := \frac{1}{n} \sum_{i \in [n]} \ell\big(f(x_i, \mathcal{R}_{x_i}), y_i\big), \qquad (10)$$

where $\mathcal{F}_{\mathrm{ex}} \subseteq \{f : \mathcal{X} \times (\mathcal{X} \times \mathcal{Y})^{\star} \to \mathbb{R}^{|\mathcal{Y}|}\}$ denotes a function class over the extended space. Unlike the local ERM approach, (10) learns a common function over the extended space and does not require solving an optimization problem for each test instance. That said, since $\mathcal{F}_{\mathrm{ex}}$ operates on the extended feature space, it can be significantly more complex and computationally expensive to employ as compared to $\mathcal{F}_{\mathrm{loc}}$.
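The sketch below illustrates this second approach under deliberately simple stand-ins: a fixed-length featurization of the extended input $(x, \mathcal{R}_x)$ (the instance itself, the mean retrieved neighbor, and the retrieved label histogram) and an RBF-kernel SVM playing the role of $\mathcal{F}_{\mathrm{ex}}$. These choices and the helper names are assumptions for illustration only; in particular, this is not the kernel construction over the extended space analyzed in Sec. 4.

```python
import numpy as np
from sklearn.svm import SVC

def extended_features(x, X_ret, y_ret, num_classes):
    """A simple (assumed) fixed-length featurization of the extended input (x, R_x)."""
    label_hist = np.bincount(y_ret, minlength=num_classes) / max(len(y_ret), 1)
    return np.concatenate([x, X_ret.mean(axis=0), label_hist])

def fit_extended_model(X, y, retrieve, num_classes):
    """Build the extended training set {(phi(x_i, R_{x_i}), y_i)}_{i in [n]} and fit a single
    kernelized scorer over it, cf. Eq. (10). `retrieve(x)` is assumed to return the retrieved
    examples (X_ret, y_ret); in practice one may exclude x_i itself from R_{x_i}."""
    Phi = np.stack([extended_features(x_i, *retrieve(x_i), num_classes) for x_i in X])
    return SVC(kernel="rbf").fit(Phi, y)
```

At test time, one computes the same extended features for $(x, \mathcal{R}_x)$ and applies the fitted model once, with no per-instance optimization.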
Our goal is to develop a theoretical understanding of the generalization behavior of these
two retrieval-based methods for classification with locally regular data distributions. We
present our theoretical treatment of local ERM and classification with extended feature
space in Sec. 3 and 4, respectively.