Fine-Tuning Pre-trained Transformers into Decaying Fast Weights

Huanru Henry Mao
Jenni
henry@jenni.ai
Abstract
Autoregressive Transformers are strong language models but incur O(T) complexity during per-token generation due to the self-attention mechanism. Recent work proposes kernel-based methods to approximate causal self-attention by replacing it with recurrent formulations with various update rules and feature maps to achieve O(1) time and memory complexity. We explore these approaches and find that they are unnecessarily complex, and propose a simple alternative - decaying fast weights - that runs fast on GPU, outperforms prior methods, and retains 99% of attention's performance for GPT-2. We also show competitive performance on WikiText-103 against more complex attention substitutes.
1 Introduction
Autoregressive Transformers (Vaswani et al., 2017) have demonstrated strong performance on text generation (Brown et al., 2020). The success of self-attention in Transformers over recurrent models (Hochreiter and Schmidhuber, 1997) can be attributed to its parallelizability (Hooker, 2021) and its effective gradient propagation over many time steps (Ke et al., 2018). However, self-attention has a high computation and memory cost: during inference sampling, it consumes O(T) time and memory per generated token, a cost that grows linearly with the sequence length. These drawbacks motivated recent work to convert or fine-tune attention into recurrent formulations with O(1) memory and time complexity for autoregressive generation.
Figure 1: Memory usage and execution time of our decay rule, the delta rule, and attention when generating the next token at various sequence lengths on a Quadro RTX 4000. The decay and delta rules use approximately the same peak memory (overlapped in the plot).

Kernel-based methods for self-attention (Tay et al., 2021) learn approximations of the exponential similarity function using m-dimensional feature maps to reformulate attention as a recurrent computation. They replace attention's "unlimited capacity" with fixed-capacity fast weights (Schmidhuber, 1992; Peng et al., 2022), where the memory-accuracy trade-off (Kerg et al., 2020) is controlled by m. Several works explored different feature maps and recurrent formulations (i.e., update rules). Katharopoulos et al. (2020) propose feature maps that maintain positive outputs, while Choromanski et al. (2021); Peng et al. (2021) carefully ensure their random feature maps are unbiased estimates of the softmax attention kernel. Schlag et al. (2021a); Peng et al. (2021) propose more sophisticated update rules that forget information in the recurrent state to improve performance. Recently, Kasai et al. (2021) showed that pre-trained Transformers can be fine-tuned into a recurrent formulation using learned ReLU feature maps with minor degradation. While promising, it is unclear which update rules or feature maps are critical for successful fine-tuning.
In this work, we investigate various update rule configurations to fine-tune pre-trained Transformers into RNNs for fast inference. We find that prior proposals contain unnecessary operations, leading us to propose a simple element-wise decay update rule with no feature map. We fine-tune GPT-2 (Radford et al., 2019) into our recurrent formulation to demonstrate that our rule outperforms prior methods and recovers 99% of self-attention's performance. We also show competitive performance on WikiText-103 (Merity et al., 2017) compared to more complex attention alternatives. Our results support the idea (Merity, 2019; Zhai et al., 2021) that it is unnecessary for attention alternatives to maintain a close analogy to self-attention, and that it is more important to focus on designing an expressive update rule.
2 Background and Related Work
2.1 Kernel-based Self-Attention
Kernel-based approximations (Katharopoulos et al., 2020; Choromanski et al., 2021; Kasai et al., 2021) to self-attention reorder the computation so that attention's typical O(Td) per-token memory and time complexity becomes O(dm) for T time steps, dimension d, and feature size m. Given an input to the attention layer $x_t \in \mathbb{R}^{d \times 1}$ and learned weight matrices $W$, the causal self-attention (Vaswani et al., 2017) for query $q_t = W_q x_t \in \mathbb{R}^{d \times 1}$, key $k_t = W_k x_t \in \mathbb{R}^{d \times 1}$, and value $v_t = W_v x_t \in \mathbb{R}^{d \times 1}$ is defined as:
$$y_t = \sum_{j}^{t} \frac{\mathrm{sim}(k_j, q_t)}{\sum_{i}^{t} \mathrm{sim}(k_i, q_t)} \, v_j, \qquad \mathrm{sim}(x, y) = \exp\!\left(x^\top y / \sqrt{d}\right) \tag{1}$$
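To make the per-token cost concrete, here is a minimal NumPy sketch of one decoding step of Eq. 1 with an explicit key/value cache; the function and variable names are ours, for illustration only. The cache gains one row per generated token, which is where the O(T) time and memory comes from.

    import numpy as np

    def attention_step(q_t, k_t, v_t, K_cache, V_cache):
        """One decoding step of causal self-attention (Eq. 1).

        K_cache and V_cache hold all past keys and values and grow by one row
        per generated token, so each step costs O(T) time and memory.
        """
        K_cache = np.vstack([K_cache, k_t[None, :]])      # (t, d)
        V_cache = np.vstack([V_cache, v_t[None, :]])      # (t, d)
        d = q_t.shape[0]
        scores = np.exp(K_cache @ q_t / np.sqrt(d))       # sim(k_j, q_t) for all j <= t
        y_t = (scores[:, None] * V_cache).sum(axis=0) / scores.sum()
        return y_t, K_cache, V_cache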
Kernel-based methods propose an approximation to the exponential similarity function, $\widetilde{\mathrm{sim}}(x, y) = \phi(x)^\top \phi(y)$, via an m-dimensional kernel feature map $\phi: \mathbb{R}^d \rightarrow \mathbb{R}^m$. This approximation enables us to rewrite Eq. 1 as

$$y_t = \frac{\sum_{j}^{t} v_j \phi(k_j)^\top \phi(q_t)}{\sum_{i}^{t} \phi(k_i)^\top \phi(q_t)} \tag{2}$$
due to the associative property of matrix multiplication. This lends itself to a recurrent formulation with state $S_t \in \mathbb{R}^{d \times m}$ and normalizer $z_t \in \mathbb{R}^{1 \times m}$ that can be computed at every time step:

$$S_t = S_{t-1} + v_t \phi(k_t)^\top, \qquad z_t = z_{t-1} + \phi(k_t)^\top \tag{3}$$

The state recurrence resembles fast weight additive outer products (Schmidhuber, 1992). Finally, the output is computed by normalizing against $z_t \phi(q_t)$, which we refer to as attention normalization:

$$y_t = \frac{S_t \phi(q_t)}{z_t \phi(q_t)} \tag{4}$$
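For comparison, here is a minimal sketch of one decoding step under Eqs. 3 and 4, using the ELU(x) + 1 feature map discussed in Section 2.3 as an example; the names are illustrative and not taken from any released implementation.

    import numpy as np

    def elu_plus_one(x):
        """phi(x) = ELU(x) + 1, one of the positive feature maps in Section 2.3."""
        return np.where(x > 0, x + 1.0, np.exp(x))

    def kernel_attention_step(S, z, q_t, k_t, v_t, phi=elu_plus_one):
        """One decoding step of kernelized attention in recurrent form (Eqs. 3-4).

        S has shape (d, m) and z has shape (m,); both are fixed-size, so each
        step costs O(dm) time and memory regardless of sequence length.
        """
        fk = phi(k_t)                      # feature-mapped key, shape (m,)
        S = S + np.outer(v_t, fk)          # S_t = S_{t-1} + v_t phi(k_t)^T
        z = z + fk                         # z_t = z_{t-1} + phi(k_t)^T
        fq = phi(q_t)                      # feature-mapped query
        y_t = (S @ fq) / (z @ fq)          # attention normalization (Eq. 4)
        return y_t, S, z

Only the (d x m) state and the m-dimensional normalizer persist between steps, which is the source of the constant per-token generation cost.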
2.2 Update Rules
Because Eq. 3 is an additive update rule, it is unable to forget past memories. This can overwhelm the fixed state capacity and lead to poor performance. Peng et al. (2021) propose a gated rule, similar to gated RNNs (Chung et al., 2014), to decay old information and induce a recency bias:

$$S_t = g_t S_{t-1} + (1 - g_t) v_t \phi(k_t)^\top, \qquad z_t = g_t z_{t-1} + (1 - g_t) \phi(k_t)^\top \tag{5}$$

where $g_t = \sigma(W_g x_t) \in \mathbb{R}$ is a learned scalar gate that determines how much new information overwrites existing information. They also analogously modify the attention normalizer to incorporate $g_t$. This rule is problematic because it overwrites all state elements equally, without fine-grained control.
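A sketch of one step of this gated rule follows; it assumes W_g is a learned weight vector yielding the scalar gate and phi is a feature map such as the one sketched above (illustrative names only).

    import numpy as np

    def gated_rule_step(S, z, x_t, q_t, k_t, v_t, W_g, phi):
        """One step of the scalar gated rule (Eq. 5).

        A single scalar gate g_t in (0, 1) interpolates between the old state
        and the new outer-product write, so every state element decays by the
        same amount -- the lack of fine-grained control noted above.
        """
        g_t = 1.0 / (1.0 + np.exp(-(W_g @ x_t)))          # g_t = sigma(W_g x_t), a scalar
        fk = phi(k_t)
        S = g_t * S + (1.0 - g_t) * np.outer(v_t, fk)
        z = g_t * z + (1.0 - g_t) * fk
        fq = phi(q_t)
        y_t = (S @ fq) / (z @ fq)                         # gated attention normalization
        return y_t, S, z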
Schlag et al. (2021a) propose improving Eq. 5 with a Fast Weight Programmer (Schmidhuber, 1992) delta rule that forgets the value associated with the current write key, removing that value before adding the new one:

$$S_t = S_{t-1} - g_t S_{t-1} \phi'(k_t) \phi'(k_t)^\top + g_t v_t \phi'(k_t)^\top \tag{6}$$

where $g_t$ is a scalar that defines the extent to which the new value replaces the old value. To stabilize training, Schlag et al. (2021a) apply sum normalization to the feature maps, enforcing outputs whose components sum to 1 (i.e., $\phi'(k_t) = \phi(k_t) / \sum_{j} \phi(k_t)_j$). This normalization is applied to both key and query, and the output is computed as $y_t = S_t \phi'(q_t) / z_t \phi'(q_t)$. Schlag et al. (2021a) showed that attention normalization is redundant when combined with sum normalization, and that dropping it (i.e., $y_t = S_t \phi'(q_t)$) works just as well.
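Below is a sketch of one delta-rule step with sum normalization, written in the equivalent "read the old value, write the corrected value" form of Eq. 6; it assumes a positive feature map phi and takes the scalar g_t in (0, 1) as given.

    import numpy as np

    def delta_rule_step(S, q_t, k_t, v_t, g_t, phi):
        """One step of the delta rule with sum normalization (Eq. 6)."""
        fk = phi(k_t)
        fk = fk / fk.sum()                         # sum-normalized key features
        fq = phi(q_t)
        fq = fq / fq.sum()                         # sum-normalized query features

        v_old = S @ fk                             # value currently stored under k_t
        S = S + g_t * np.outer(v_t - v_old, fk)    # expands to Eq. 6
        y_t = S @ fq                               # attention normalization dropped
        return y_t, S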
2.3 Kernel Feature Map
One motivation for kernel-based methods is to
closely approximate self-attention. Peng et al.
(2021) proposes Random Feature Attention (RFA),
which uses random feature maps to produce unbi-
ased estimates of the exponential
sim(x, y)
func-
tion. Choromanski et al. (2021) proposes FAVOR+,
a random feature map with lower variance. In-
stead of rigorously approximating self-attention,
several proposals aim simply to maintain positive
outputs motivated by the positivity of attention
weights. Katharopoulos et al. (2020) proposes
the
φ(x) = ELU(x)+1
(Clevert et al.,2016)
feature map. Schlag et al. (2021a) proposes De-
terministic Parameter-Free Projection (DPFP), a