Vision Transformers provably learn spatial structure
Samy Jelassi
Princeton University
sjelassi@princeton.edu
Michael E. Sander
Ecole Normale Supérieure
michael.sander@ens.fr
Yuanzhi Li
Carnegie Mellon University
yuanzhil@andrew.cmu.edu
Abstract
Vision Transformers (ViTs) have achieved performance comparable or superior to that of Convolutional Neural Networks (CNNs) in computer vision. This empirical breakthrough is all the more remarkable since, in contrast to CNNs, ViTs do not embed any visual inductive bias of spatial locality. Yet, recent works have shown that while minimizing their training loss, ViTs specifically learn spatially localized patterns. This raises a central question: how do ViTs learn these patterns by solely minimizing their training loss using gradient-based methods from random initialization? In this paper, we provide some theoretical justification of this phenomenon. We propose a spatially structured dataset and a simplified ViT model in which the attention matrix solely depends on the positional encodings. We call this mechanism the positional attention mechanism. On the theoretical side, we consider a binary classification task and show that, while the learning problem admits multiple solutions that generalize, our model implicitly learns the spatial structure of the dataset while generalizing: we call this phenomenon patch association. We prove that patch association helps to sample-efficiently transfer to downstream datasets that share the same structure as the pre-training one but differ in the features. Lastly, we empirically verify that a ViT with positional attention performs similarly to the original one on CIFAR-10/100, SVHN and ImageNet.
1 Introduction
Transformers are deep learning models built on self-attention [Vaswani et al.,2017], and in the past
several years they have increasingly formed the backbone for state-of-the-art models in domains
ranging from Natural Language Processing (NLP) [Vaswani et al.,2017,Devlin et al.,2018] to
computer vision [Dosovitskiy et al.,2020], reinforcement learning [Chen et al.,2021a,Janner
et al.,2021], program synthesis [Austin et al.,2021] and symbolic tasks [Lample and Charton,
2019]. Beyond their remarkable performance, several works reported the ability of transformers to
simultaneously minimize their training loss and learn inductive biases tailored to specific datasets
e.g. in computer vision [Raghu et al.,2021], in NLP [Brown et al.,2020,Warstadt and Bowman,
2020] or in mathematical reasoning [Wu et al.,2021]. In this paper, we focus on computer vision
where convolutions are considered to be an adequate and biologically plausible inductive bias since
they capture local spatial information [Fukushima,2003] by imposing a sparse local connectivity
pattern. This seems intuitively reasonable: nearby pixels encode the presence of small scale features,
whose patterns in turn determine more abstract features at longer and longer length scales. Several
seminal works [Cordonnier et al., 2019, Dosovitskiy et al., 2020, Raghu et al., 2021] empirically show that, although randomly initialized, the positional encodings in Vision Transformers (ViTs) [Dosovitskiy et al., 2020] actually learn this local connectivity: closer patches have more similar positional encodings, as shown in Figure 1a. A priori, learning such spatial structure is surprising.
Indeed, in contrast to convolutional neural networks (CNNs), ViTs are not built with the inductive bias of local connectivity and weight sharing. They start by replacing an image with a collection of $D$ patches $(X_1,\dots,X_D) \in \mathbb{R}^{d\times D}$, each of dimension $d$. While each $X_i$ represents (an embedding of) a spatially localized portion of the original image, the relative positions of the patches $X_i$ in the image are disregarded. Instead, relative spatial information is supplied through image-independent positional encodings $P = (p_1,\dots,p_D) \in \mathbb{R}^{d\times D}$. Unlike CNNs, each layer of a ViT then learns, via trainable self-attention, a non-local set of filters that depend non-linearly on both the values of all patches $X_j$ and their positional encodings $p_j$.
Figure 1: (a) Visualization of the positional encoding similarities $P^\top P = (\langle p_i, p_j\rangle)_{(i,j)\in[D]^2}$ at initialization (1) and after training on ImageNet (2) using a "ViT-small-patch32-224" [Dosovitskiy et al., 2020]. We normalise the values of $P^\top P$ between $-1$ and $1$ and apply a threshold of $0.55$. In contrast with the initial arrays, which are random, the final ones show local connectivity patterns: nearby patches have similar positional encodings. (b) Partition of the patches into sets $S_\ell$ as in Definition 2.1. Squares of the same color belong to the same set $S_\ell$. We refer to (1) as a "spatially localized set" since all the elements in a given $S_\ell$ are spatially contiguous. This is the type of sets appearing in Figure 1a at the end of training. Definition 2.1 also covers sets with non-contiguous elements, as in (2).
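To make the visualization of Figure 1a concrete, the following minimal sketch computes the normalised and thresholded similarity map $P^\top P$ from a given positional encoding matrix; the random `P` below is only a stand-in for the encodings of an actual trained checkpoint.

```python
import numpy as np

def positional_similarity(P: np.ndarray, threshold: float = 0.55) -> np.ndarray:
    """Normalised Gram matrix of the positional encodings, thresholded as in Figure 1a.

    P has shape (d, D): one d-dimensional encoding p_i per patch.
    """
    gram = P.T @ P                                 # (D, D) inner products <p_i, p_j>
    gram = gram / np.max(np.abs(gram))             # normalise to [-1, 1]
    return np.where(gram >= threshold, gram, 0.0)  # keep only strong similarities

# Stand-in for trained encodings; in practice P would come from a ViT checkpoint.
# A ViT-small/32 at 224x224 resolution has D = 49 patches of dimension d = 384.
rng = np.random.default_rng(0)
d, D = 384, 49
P = rng.standard_normal((d, D))
print(positional_similarity(P).shape)  # (49, 49); a random P gives a noisy-looking map
```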
Contributions. The empirical observation of Figure 1a raises a central question: from a theoretical perspective, how do ViTs manage to learn these local connectivity patterns by simply minimizing their training loss using gradient descent from random initialization? While it is known that attention can express local operations such as convolution [Cordonnier et al., 2019], it remains unclear how ViTs learn them. In this paper, we present a simple spatially structured classification dataset for which it is sufficient (but not necessary) to learn the structure in order to generalize. We also present a simplified ViT model which we prove implicitly learns sparse spatial connectivity patterns when it minimizes its training loss via gradient descent (GD). We name this implicit bias patch association (defined in Definition 2.2). We prove that our ViT model leverages this bias to generalize. More precisely, we make the following contributions:
- In Section 2, we formally define the concept of performing patch association, which refers to the ability to learn spatial connectivity patterns on a dataset.
- In Section 3, we introduce a structured classification dataset and a simplified ViT model. The model is simplified in the sense that its attention matrix only depends on the positional encodings. We then present the learning problems we are interested in: empirical risk (realistic setting) and population risk (idealized setting) minimization for binary classification.
- In Section 4, we prove that a one-layer, single-head ViT model trained with gradient descent on our synthetic dataset performs patch association and generalizes, in both the idealized (Theorem 4.1) and realistic (Theorem 4.2) settings. We present a detailed proof, based on invariances and symmetries of the coefficients of the attention matrix throughout the learning process.
- In Section 5, we show (Theorem 5.1) that after pre-training on our synthetic dataset, our model can be sample-efficiently fine-tuned to transfer to a downstream dataset that shares the same structure as the source dataset (and may have different features).
- On the experimental side, we validate in Section 6 that ViTs learn spatial structure in images from the CIFAR-100 dataset, even when the pixels of the images are permuted. This validates that, in contrast to CNNs, ViTs learn a more general form of spatial structure that is not limited to local patterns (Figure 5). We finally show that our ViT model, where the attention matrix only depends on the positional encodings, is competitive with the vanilla ViT on the ImageNet, CIFAR-10/100 and SVHN datasets (Figures 6 and 7).
Notation. We use lower-case letters for scalars, lower-case bold for vectors and upper-case bold for matrices. Given an integer $D$, we define $[D] = \{1, \dots, D\}$. Any statement made "with high probability" holds with probability at least $1 - 1/\mathrm{poly}(d)$. Given a vector $a \in \mathbb{R}^d$ and $k \le d$, we define $\mathrm{Top}_k\{a_j\}_{j=1}^{d} = \{a_{i_1}, \dots, a_{i_k}\}$, where $a_{i_1}, \dots, a_{i_k}$ are the $k$ largest elements of $a$. For a function $F$ that implicitly depends on parameters $A$ and $v$, we often write $F_{A,v}$ to highlight its parameters. We use asymptotic complexity notation when defining the different constants.
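For instance, a small sketch (ours, not the paper's) of the $\mathrm{Top}_k$ operator used throughout:

```python
import numpy as np

def top_k(a: np.ndarray, k: int) -> np.ndarray:
    """Return the k largest elements of a, i.e. Top_k{a_j}."""
    return a[np.argsort(a)[-k:]]

print(top_k(np.array([0.3, -1.2, 2.5, 0.9]), 2))  # [0.9 2.5]
```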
Related work
CNNs and ViTs.
Many computer vision architectures can be considered as a form of hybridization
between Transformers and CNNs. For example, DETR [Carion et al., 2020] uses a CNN to generate
features that are fed to a Transformer. [d’Ascoli et al.,2021] show that self-attention can be initialized
or regularized to behave like a convolution and [Dai et al.,2021,Guo et al.,2021] add convolution
operations to Transformers. Conversely, [Bello et al.,2019,Ramachandran et al.,2019,Bello,2021]
introduce self-attention or attention-like operations to supplement or replace convolution in ResNet-
like models. In contrast, our paper does not consider any form of hybridization with CNN, but rather
a simplification of the original ViT to explain how ViTs learn spatially structured patterns using GD.
Empirical understanding of ViTs.
A long line of work consists in analyzing the properties of
ViTs, such as robustness [Bhojanapalli et al.,2021,Paul and Chen,2021,Naseer et al.,2021] or the
effect of self-supervision [Caron et al.,2021,Chen et al.,2021b]. Closer to our work, some papers
investigate why ViTs perform so well. Raghu et al. [2021] compare the representations of ViTs
and CNNs and Melas-Kyriazi [2021], Trockman and Kolter [2022] argue that the patch embeddings
could explain the performance of ViTs. We empirically show in Section 6 that applying the attention
matrices to the positional encodings – which contain the structure of the dataset – approximately
recovers the baselines. Hence, our work rather suggests that the structural learning performed by the
attention matrices may explain the success of ViTs.
Theory for attention models.
Early theoretical works have focused on the expressivity of attention.
[Vuckovic et al.,2020,Edelman et al.,2021] addressed this question in the context of self-attention
blocks and [Dehghani et al.,2018,Wei et al.,2021,Hron et al.,2020] for Transformers. On the
optimization side, [Zhang et al.,2020] investigate the role of adaptive methods in attention models and
[Snell et al., 2021] analyze the dynamics of a single attention head to approximate the learning
of a Seq2Seq architecture. In our work, we also consider a single-head ViT trained with gradient
descent and exhibit a setting where it provably learns convolution-like patterns and generalizes.
Algorithmic regularization.
The question we address concerns algorithmic regularization, which characterizes the generalization of an optimization algorithm when multiple global solutions exist in over-parametrized models. This regularization arises in deep learning mainly due to the non-convexity of the objective function: the latter potentially creates multiple global minima, scattered in parameter space, that vastly differ in terms of generalization. Algorithmic regularization appears
in binary classification [Soudry et al.,2018,Lyu and Li,2019,Chizat and Bach,2020], matrix
factorization [Gunasekar et al.,2018,Arora et al.,2019], convolutional neural networks [Gunasekar
et al.,2018,Jagadeesan et al.,2022], generative adversarial networks [Allen-Zhu and Li,2021],
contrastive learning [Wen and Li,2021] and mixture of experts [Chen et al.,2022]. Algorithmic
regularization is induced by and depends on many factors such as learning rate and batch size [Goyal
et al.,2017,Hoffer et al.,2017,Keskar et al.,2016,Smith et al.,2018,Li et al.,2019], initialization
Allen-Zhu and Li [2020], momentum [Jelassi and Li,2022], adaptive step-size [Kingma and Ba,
2014,Neyshabur et al.,2015,Daniely,2017,Wilson et al.,2017,Zou et al.,2021,Jelassi et al.,2022],
batch normalization [Arora et al.,2018,Hoffer et al.,2019,Ioffe and Szegedy,2015] and dropout
[Srivastava et al., 2014, Wei et al., 2020]. However, all these works consider feed-forward neural networks, a setting that does not apply to ViTs.
2 Defining patch association
The goal of this section is to formalize the way ViTs learn sparse spatial connectivity patterns. We
thus introduce the concept of performing patch association for a spatially structured dataset.
Definition 2.1 (Data distribution with spatial structure). Let $\mathcal{D}$ be a distribution over $\mathbb{R}^{d\times D} \times \{-1,1\}$ where each data point $X = (X_1,\dots,X_D) \in \mathbb{R}^{d\times D}$ has label $y \in \{-1,1\}$. We say that $\mathcal{D}$ is spatially structured if
- there exists a partition of $[D]$ into $L$ disjoint subsets, i.e. $[D] = \bigcup_{\ell=1}^{L} S_\ell$ with $S_\ell \subset [D]$ and $|S_\ell| = C$;
- there exists a labeling function $f^*$ satisfying $\Pr[y f^*(X) > 0] = 1 - d^{-\omega(1)}$ and
$$f^*(X) := \sum_{\ell \in [L]} \phi\big((X_i)_{i \in S_\ell}\big), \quad \text{where } \phi : \mathbb{R}^{d\times C} \to \mathbb{R} \text{ is an arbitrary function.} \qquad (1)$$
Figure 2: Left: Test error of the ViT on the convolution-structured dataset over the number of gradient descent steps. Upper right: Grid displaying the input patches; yellow squares represent the spatially localized sets $S_\ell$, which are taken into account when computing the convolutional function $f^*$. Lower right: The learnt $P^\top P$ looks random compared to the upper one.
Examples. A particular case for the sets $S_\ell$ is the one of spatially localized sets, as in Figure 1b-(1). In this case, we have $D = 16$, $C = 4$ and $S_1 = \{1,2,5,6\}$, $S_2 = \{3,4,7,8\}$, $S_3 = \{9,10,13,14\}$, $S_4 = \{11,12,15,16\}$. We emphasize that Definition 2.1 is not limited to spatially localized sets and also covers non-contiguous sets, as in Figure 1b-(2).
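As a minimal illustration (our own sketch, not code from the paper), the spatially localized partition above can be generated for any square patch grid split into square blocks:

```python
import numpy as np

def block_partition(grid: int = 4, block: int = 2) -> list[set[int]]:
    """Partition the D = grid*grid patches (indexed 1..D, row-major) into spatially
    localized sets S_ell of size C = block*block, as in Figure 1b-(1)."""
    idx = np.arange(1, grid * grid + 1).reshape(grid, grid)
    sets = []
    for r in range(0, grid, block):
        for c in range(0, grid, block):
            sets.append(set(idx[r:r + block, c:c + block].ravel().tolist()))
    return sets

print(block_partition())
# [{1, 2, 5, 6}, {3, 4, 7, 8}, {9, 10, 13, 14}, {11, 12, 15, 16}]
```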
Labelling function. Definition 2.1 states that there exists a labeling function that preserves the underlying structure by applying the same function $\phi$ to each $S_\ell$ as in (1). For instance, when the sets $S_\ell$ are spatially localized, $f^*$ can be a one-hidden-layer convolutional network. In this paper, we are interested in patch association, which refers to the ability of an algorithm to identify the sets $S_\ell$, and is formally defined as follows.
Definition 2.2 (Patch association for ViTs). Let $\mathcal{D}$ be as in Definition 2.1. Let $\mathcal{M} : \mathbb{R}^{d\times D} \to \{-1,1\}$ be a transformer and $P^{(\mathcal{M})}$ its positional encoding matrix. We say that $\mathcal{M}$ performs patch association on $\mathcal{D}$ if for all $\ell \in [L]$ and $i \in S_\ell$, we have $\mathrm{Top}_C\{\langle p^{(\mathcal{M})}_i, p^{(\mathcal{M})}_j\rangle\}_{j=1}^{D} = S_\ell$.
Definition 2.2 states that patch association is learned when, for a given $i \in S_\ell$, its positional encoding mainly attends to those of the $j$ such that $i, j \in S_\ell$. In this way, the transformer groups the $X_i$ according to the $S_\ell$, just like the true labeling function. Definition 2.2 formally describes the empirical findings of Figure 1a-(2), where nearby patches have similar positional encodings. A natural question is then: would ViTs really learn those $S_\ell$ after training to match the labeling function $f^*$? Without further assumptions on the data distribution, we next show that the answer is no.
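To make Definition 2.2 operational, here is a small check (our own sketch) of the Top-$C$ condition for a given positional encoding matrix and partition:

```python
import numpy as np

def performs_patch_association(P: np.ndarray, partition: list[set[int]]) -> bool:
    """Check Definition 2.2: for every i in S_ell, the indices of the C largest
    inner products <p_i, p_j>, j = 1..D, must be exactly S_ell. P has shape (d, D)."""
    gram = P.T @ P  # (D, D) matrix of <p_i, p_j>; patches are indexed 1..D as in the text
    for S in partition:
        C = len(S)
        for i in S:
            top_c = np.argsort(gram[i - 1])[-C:] + 1  # 1-based indices of the top-C entries
            if set(top_c.tolist()) != S:
                return False
    return True
```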
ViTs do not always learn patch association. We give a negative answer through the following synthetic experiment. Consider the case where all the patches $X_j$ are i.i.d. standard Gaussian and $f^*$ is a one-hidden-layer CNN with cubic activation. The label $y$ of any $X$ is then given by $y = \mathrm{sign}(f^*(X))$. As shown in Figure 2, a one-layer ViT reaches a small test error on this binary classification task. However, $P^\top P$ does not match the convolution pattern encoded in $f^*$. This is not surprising, since the data distribution $\mathcal{D}$ is Gaussian and thus lacks spatial structure. Thus, in order to prove that ViTs learn patch association, we need additional assumptions on $\mathcal{D}$, which we discuss in the next section.
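A hedged sketch of this synthetic experiment follows; the sizes, the number of hidden filters and the random filters are illustrative choices on our part, not the paper's exact configuration.

```python
import numpy as np

rng = np.random.default_rng(0)
d, D, C = 16, 16, 4                                   # illustrative sizes
partition = [{1, 2, 5, 6}, {3, 4, 7, 8}, {9, 10, 13, 14}, {11, 12, 15, 16}]
W = rng.standard_normal((8, C * d))                   # hidden filters, shared across sets

def f_star_cnn(X: np.ndarray) -> float:
    """One-hidden-layer CNN with cubic activation applied to each localized set S_ell."""
    out = 0.0
    for S in partition:
        z = X[:, [i - 1 for i in sorted(S)]].ravel(order="F")  # stack the C patches of S
        out += float(np.sum((W @ z) ** 3))
    return out

X = rng.standard_normal((d, D))   # i.i.d. standard Gaussian patches, no spatial structure
y = np.sign(f_star_cnn(X))        # label given by the sign of the CNN output
```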
3 Setting to learn patch association
In this section, we introduce our theoretical setting to analyze how ViTs learn patch association. We
first define our binary classification dataset and then present the ViT model we use to classify it.
Assumption 1 (Data distribution with specific spatial structure). Let $\mathcal{D}$ be a distribution as in Definition 2.1 and let $w^* \in \mathbb{R}^d$ be an underlying feature. We suppose that each data point $X$ is defined as follows:
- Uniformly sample an index $\ell(X)$ from $[L]$ and, for $j \in S_{\ell(X)}$, set $X_j = y w^* + \xi_j$, where $y w^*$ is the informative feature and $\xi_j \overset{\text{i.i.d.}}{\sim} \mathcal{N}(0, \sigma^2 (I_d - w^* w^{*\top}))$ (signal set).
- For $\ell \in [L] \setminus \{\ell(X)\}$ and $j \in S_\ell$, set $X_j = \delta_j w^* + \xi_j$, where $\delta_j = 1$ with probability $q/2$, $-1$ with the same probability and $0$ otherwise, and $\xi_j \overset{\text{i.i.d.}}{\sim} \mathcal{N}(0, \sigma^2 (I_d - w^* w^{*\top}))$ (random sets).
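A minimal sampling sketch of Assumption 1; the concrete values of $D$, $C$, $q$ and $\sigma$, the contiguous partition and the random unit feature $w^*$ are illustrative assumptions on our part.

```python
import numpy as np

rng = np.random.default_rng(0)
d, D, C = 64, 16, 4
L = D // C
q, sigma = 0.25, 1.0 / np.sqrt(d)
partition = [set(range(l * C + 1, (l + 1) * C + 1)) for l in range(L)]

w_star = rng.standard_normal(d)
w_star /= np.linalg.norm(w_star)              # ||w*||_2 = 1
proj = np.eye(d) - np.outer(w_star, w_star)   # noise lives orthogonally to w*

def sample_point() -> tuple[np.ndarray, int]:
    y = int(rng.choice([-1, 1]))
    signal = int(rng.integers(L))             # index ell(X) of the signal set
    X = np.zeros((d, D))
    for l, S in enumerate(partition):
        for j in S:
            if l == signal:
                delta = y                     # informative feature y * w*
            else:                             # spurious feature with probability q
                delta = int(rng.choice([1, -1, 0], p=[q / 2, q / 2, 1 - q]))
            xi = proj @ (sigma * rng.standard_normal(d))   # N(0, sigma^2 (I_d - w* w*^T))
            X[:, j - 1] = delta * w_star + xi
    return X, y
```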
Figure 3: Visualization of a data point $X$ from $\mathcal{D}$ when the $S_\ell$ are spatially localized. Each square depicts a patch $X_j$ and squares of the same color belong to the same set $S_\ell$. "0" indicates that the patch does not carry a feature, "+1" stands for the feature $+1\cdot w^*$ and "-1" for the feature $-1\cdot w^*$. The large red square depicts the signal set $S_{\ell(X)}$. Although there are more "-1"s than "+1"s, the label of $X$ is $+1$ since there are only "+1"s inside the signal set.
To keep the analysis simple, the noisy patches are sampled from the orthogonal complement of $w^*$. Note that $\mathcal{D}$ admits the labeling function $f^*(X) = \sum_{\ell \in [L]} \mathrm{Threshold}_{0.9C}\big(\sum_{i \in S_\ell} \langle w^*, X_i\rangle\big)$, where $\mathrm{Threshold}_C(z) = z$ if $|z| > C$ and $0$ otherwise.
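A direct transcription of this labeling function (assuming a `partition` of index sets and the unit feature `w_star`, e.g. from the sampling sketch above):

```python
import numpy as np

def threshold(z: float, c: float) -> float:
    """Threshold_c(z) = z if |z| > c, and 0 otherwise."""
    return z if abs(z) > c else 0.0

def f_star(X: np.ndarray, w_star: np.ndarray, partition: list[set[int]], C: int) -> float:
    """Labeling function admitted by D: thresholded signal mass summed over the sets."""
    return sum(
        threshold(sum(float(w_star @ X[:, i - 1]) for i in S), 0.9 * C)
        for S in partition
    )

# Since the noise is orthogonal to w*, only the signal set passes the 0.9*C threshold
# with high probability, so sign(f_star(X, w_star, partition, C)) recovers the label y.
```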
We sketch a data point of $\mathcal{D}$ in Figure 3. Our dataset can be viewed as an extreme simplification of real-world image datasets in which a set of adjacent patches contains a useful feature (e.g. the nose of a dog) while many patches carry uninformative or spurious features, e.g. the background of the image. We make the following assumption on the parameters of the data distribution.
Assumption 2. We suppose that $d = \mathrm{poly}(D)$, $C = \mathrm{polylog}(d)$, $q = \mathrm{poly}(C)/D$, $\|w^*\|_2 = 1$ and $\sigma^2 = 1/d$. This implies $C \ll D$ and $q \ll 1$.
Assumption 2 may be justified by considering a "ViT-base-patch16-224" model [Dosovitskiy et al., 2020] on ImageNet. In this case, $d = 384$ and $D = 196$. $\sigma$ is set so that $\|\xi_j\|_2 \approx \|w^*\|_2$, and $q$ is chosen so that there are more spurious features than informative ones (low signal-to-noise regime), which makes the data non-linearly separable. Our dataset is non-trivial to learn, since generalized linear networks fail to generalize, as shown in the next theorem (see Appendix J for a proof).
Theorem 3.1. Let $\mathcal{D}$ be as in Assumption 1. Let $g(X) = \phi\big(\sum_{j=1}^{D} \langle w_j, X_j\rangle\big)$ be a generalized linear model. Then $g$ does not fit the labeling function, i.e. $\Pr[f^*(X) g(X) \le 0] \ge 1/8$.
Intuitively, $g$ fails to generalize because it has no knowledge of the underlying partition, and the number of random sets is much higher than the number of sets with signal. Thus, a model must have at least a minimal knowledge of the $S_\ell$ in order to generalize. In addition, the following Theorem 3.2 states the existence of a transformer that generalizes without learning spatial structure (see Appendix J for a proof), thus showing that the learning process has a priori no straightforward reason to lead to patch association.
Theorem 3.2. Let $\mathcal{D}$ be defined as in Assumption 1. There exists a (one-layer) transformer $\mathcal{M}$ such that $\Pr[f^*(X)\mathcal{M}(X) \le 0] = d^{-\omega(1)}$ but, for all $\ell \in [L]$ and $i \in S_\ell$, $\mathrm{Top}_C\{\langle p^{(\mathcal{M})}_i, p^{(\mathcal{M})}_j\rangle\}_{j=1}^{D} \cap S_\ell = \emptyset$.
Simplified ViT model. We now define our simplified ViT model, for which we show in Section 4 that it implicitly learns patch association by minimizing its training objective. We first recall the self-attention mechanism that is ubiquitously used in transformers.

Definition 3.1 (Self-attention [Bahdanau et al., 2014, Vaswani et al., 2017]). The attention mechanism in the single-head case is defined as follows. Let $X \in \mathbb{R}^{d\times D}$ be a data point and $P \in \mathbb{R}^{d\times D}$ its positional encoding. The self-attention mechanism computes
1. the sum of patches and positional encodings, i.e. $\widetilde{X} = X + P$;
2. the attention matrix $A = QK^\top$, where $Q = \widetilde{X}^\top W_Q$, $K = \widetilde{X}^\top W_K$ and $W_Q, W_K \in \mathbb{R}^{d\times d}$;
3. the score matrix $S \in \mathbb{R}^{D\times D}$ with coefficients $S_{i,j} = \exp(A_{i,j}/\sqrt{d}) / \sum_{r=1}^{D} \exp(A_{i,r}/\sqrt{d})$;
4. the matrix $V = \widetilde{X}^\top W_V$, where $W_V \in \mathbb{R}^{d\times d}$.
It finally outputs $\mathrm{SA}((X;P)) = SV \in \mathbb{R}^{D\times d}$.
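A literal numpy rendering of Definition 3.1 (single head only; the remaining components of a full ViT block are not shown):

```python
import numpy as np

def softmax_rows(A: np.ndarray) -> np.ndarray:
    E = np.exp(A - A.max(axis=1, keepdims=True))   # numerically stable row-wise softmax
    return E / E.sum(axis=1, keepdims=True)

def self_attention(X, P, W_Q, W_K, W_V):
    """Single-head self-attention of Definition 3.1; X and P have shape (d, D)."""
    d = X.shape[0]
    Xt = (X + P).T                                 # step 1: add positions, shape (D, d)
    Q, K, V = Xt @ W_Q, Xt @ W_K, Xt @ W_V         # steps 2 and 4
    S = softmax_rows(Q @ K.T / np.sqrt(d))         # step 3: score matrix, shape (D, D)
    return S @ V                                   # output SV, shape (D, d)
```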
In this paper, our ViT model relies on a different attention mechanism – the "positional attention" – that we define as follows.

Definition 3.2 (Positional attention). Let $X \in \mathbb{R}^{d\times D}$ be a data point and $P \in \mathbb{R}^{d\times D}$ its positional encoding. The positional attention mechanism takes as input the pair $(X;P)$ and computes:
1. the attention matrix $A = QK^\top$, where $Q = P^\top W_Q$, $K = P^\top W_K$ and $W_Q, W_K \in \mathbb{R}^{d\times d}$;
2. the score matrix $S \in \mathbb{R}^{D\times D}$ with coefficients $S_{i,j} = \exp(A_{i,j}/\sqrt{d}) / \sum_{r=1}^{D} \exp(A_{i,r}/\sqrt{d})$.
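For contrast with Definition 3.1, a sketch of the positional scores of Definition 3.2, where only the positional encodings enter the query/key products (the remaining steps of the model are truncated in this excerpt and are not shown):

```python
import numpy as np

def positional_attention_scores(P, W_Q, W_K):
    """Score matrix of Definition 3.2: the attention depends on P only, not on X."""
    d = P.shape[0]
    Q, K = P.T @ W_Q, P.T @ W_K                    # queries/keys from positions alone
    A = Q @ K.T / np.sqrt(d)                       # attention matrix, shape (D, D)
    E = np.exp(A - A.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)        # each row sums to one
```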