Efficiently Enhancing Zero-Shot Performance of Instruction Following
Model via Retrieval of Soft Prompt
Seonghyeon Ye Joel Jang Doyoung Kim Yongrae Jo Minjoon Seo
KAIST
{seonghyeon.ye,joeljang,ikevin98,yongrae,minjoon}@kaist.ac.kr
Abstract
Enhancing the zero-shot performance of instruction-following models requires heavy computation, either by scaling the total number of training datasets or the model size. In this work, we explore how retrieval of soft prompts obtained through prompt tuning can efficiently assist hard prompts in zero-shot task generalization. Specifically, we train soft prompt embeddings for each prompt through prompt tuning, store samples of the training instances mapped to their prompt embeddings, and, during inference, retrieve the prompt embedding corresponding to the training instance closest to the query instance. While adding only 0.007% additional parameters, retrieval of soft prompts enhances the performance of T0 on unseen tasks, outperforming it on 10 out of 11 datasets and improving the mean accuracy of T0 on the BIG-bench benchmark by 2.39% points. We also report an interesting finding that retrieving source embeddings trained on similar answer choice formats is more important than retrieving those trained on similar task types.¹
1 Introduction
Training Large Language Models (LLMs) on huge amounts of data has enabled LMs to perform downstream tasks without any fine-tuning, with the aid of natural prompts or the concatenation of a few demonstration instances (Brown et al., 2020; Rae et al., 2021; Kojima et al., 2022; Chowdhery et al., 2022). Additionally, recent works have shown that adding an instruction-tuning stage, an additional training step that helps pretrained LMs understand prompts and demonstrations, results in a significant performance boost on zero-shot task generalization even for moderate-sized LMs (Min et al., 2021; Sanh et al., 2021; Wei et al., 2021; Wang et al., 2022b; Ye et al., 2022; Chung et al., 2022).
¹ Model checkpoints and code implementation are available at github.com/seonghyeonye/RoSPr.
Figure 1: During zero-shot inference, RoSPr selects training instances similar to the given input from the Source Prompt Library and retrieves the prompt embeddings corresponding to the selected training instances.
This extra instruction-tuning stage involves explicit, multi-task prompted learning on various tasks, enabling LMs to quickly adapt to unseen tasks at inference.
To maximize the effect of instruction-tuning, two approaches have been widely explored: (1) scaling the number of training datasets, and (2) scaling the model size (Wang et al., 2022b; Chung et al., 2022). However, both approaches require heavy computation that is not feasible on an academic budget. Specifically, the first approach requires updating all parameters of the model every time a training dataset is added, showing limitations in terms of scalability. The second approach, on the other hand, imposes heavy memory requirements to load and train a massive LLM.
To enhance the zero-shot performance of instruction-following models efficiently, we introduce Retrieval of Soft Prompt (RoSPr), which is easily scalable and requires minimal computation, adding only 0.007% parameters to the main model during inference. As shown in Figure 1, by training prompt embeddings (soft prompts) for each given hard prompt through prompt tuning, we construct a Source Prompt Library consisting of samples of training instances mapped to their corresponding prompt embeddings. Then, during inference, using a simple, off-the-shelf dense retriever model, we search for training instances similar to the given query instances and retrieve their corresponding prompt embeddings. Because the backbone LM is frozen, the retrieved embeddings serve as adapters assisting the hard prompts. While RoSPr can be applied to any LM, in this work we use T0 (Sanh et al., 2021) as our backbone LM and perform prompt tuning on the tasks used during its instruction-tuning stage.
While adding only 0.007% additional parameters, RoSPr outperforms T0 on 10 out of 11 evaluation datasets and outperforms efficient fine-tuning baselines without any target task fine-tuning. RoSPr is also effective for challenging tasks such as those from BIG-bench (Srivastava et al., 2022), outperforming T0 by 2.39% points in mean accuracy. Furthermore, we provide several interesting findings: (1) variants of RoSPr that interpolate multiple prompt embeddings or use a scoring method that considers the answer choice distribution during retrieval further increase the effect of RoSPr; (2) we provide an analysis of which factors contribute to the performance of RoSPr and show that, similarly to the role of demonstrations in in-context learning (Min et al., 2022), heuristic features such as the answer choice format are more important than the similarity of the source task.
2 Related Work
2.1 Task Generalization with Instruction-Tuning
Prompts and demonstrations are essential for task generalization since proper explanations are required for LMs to understand an unseen task (Kojima et al., 2022; Wei et al., 2022; Lampinen et al., 2022). Instruction-tuning, which is explicit multi-task prompted training on various downstream tasks, is a simple but effective way to achieve this, resulting in improved zero-shot capabilities. Zhong et al. (2021) first introduced the method of instruction-tuning by converting various tasks into a question-answering format and fine-tuning the model on the aggregated dataset. Following works (Mishra et al., 2022; Min et al., 2021; Sanh et al., 2021; Wei et al., 2021; Wang et al., 2022b; Xu et al., 2022; Ouyang et al., 2022; Ye et al., 2022; Chung et al., 2022) extended this approach to a larger scale and showed that zero-shot task generalization can be enhanced with more diverse prompts, a larger number of training downstream tasks, and a larger LM.
2.2 Source Task Retrieval
Retrieving a source task relevant to the target task has been shown to result in faster and better task adaptation. For parameter-efficient fine-tuning, Vu et al. (2022); Su et al. (2022) retrieve source prompt embeddings that are similar to the target prompt embedding and obtain a better initialization point for prompt tuning. Instead of utilizing a single prompt embedding, recent works show that a mixture of multiple prompt embeddings is effective (Asai et al., 2022; Qin and Eisner, 2021).
For instruction-tuning, Lin et al. (2022) retrieve training instances that are similar to the query through a dense retriever and fine-tune the model on the retrieved examples. For in-context learning, Rubin et al. (2021); Liu et al. (2022b); Wang et al. (2023) retrieve training data that can be used as demonstrations. Wang et al. (2022c) show the effect of retrieving prompt embeddings in a continual learning setting. Although our proposed method is related to these works, the novelty of our work lies in applying source task retrieval in the zero-shot setting and in retrieving soft prompts instead of training instances.
3 Method
In this section, we introduce Retrieval of Soft Prompt (RoSPr) for zero-shot task generalization. A detailed overview is shown in Figure 2. We first train source prompt embeddings of the LM for each hard prompt of a given source task using prompt tuning (Section 3.1). Then, we save training instance samples along with their prompt embeddings in the Source Prompt Library and use it to retrieve embeddings at inference to perform tasks in a zero-shot manner (Section 3.2). We additionally introduce interpolation of multiple source prompt embeddings (Section 3.3) and variance-based ranking (Section 3.4) to increase robustness and accuracy.
3.1 Training Source Prompt Embeddings
Even though RoSPr may be used to augment any type of LM, we use T0 (Sanh et al., 2021) as the backbone LM in this paper. For training the soft prompts, we utilize the source tasks and prompts used in the instruction-tuning phase of T0. While T0 was trained in a multi-task learning manner, we freeze the initial T0 parameters and train only the soft prompts (source prompt embeddings), one for each hard prompt of each source task.
Prompt Tuning. Among various parameter-efficient fine-tuning methods, we follow the prompt tuning approach proposed by Lester et al. (2021) because the number of trainable parameters is extremely small (~204K parameters per prompt), which implies that the memory overhead of parameter retrieval at inference is negligible.
Figure 2: An overview of RoSPr. For each hard prompt of the source datasets, soft prompts are trained via prompt tuning. After storing training instances as keys and the corresponding prompt embeddings as values, RoSPr searches for training instances similar to the query set Q, retrieves the corresponding prompt embeddings, and selects the most frequently retrieved candidate for inference. Variants of the selection strategy are also shown: RoSPr+Inter interpolates between multiple related source embeddings, and RoSPr+Var ranks candidate embeddings considering both frequency and variance.
For each source training dataset $D_i$ ($i = 1, \dots, T$), where $T$ is the total number of source datasets, we train source embeddings $E_{ij}$ ($j = 1, \dots, M_i$), where $M_i$ is the number of hard prompts in $D_i$, yielding a soft prompt embedding for each individual hard prompt. Specifically, given a training instance $\{x_{ik}, y_{ik}\}$ ($k = 1, \dots, K$) from $D_i$, where $K$ is the number of sampled training instances per dataset, we first convert it into its hard-prompted version $\{h_j(x_{ik}), h_j(y_{ik})\}$, where $h_j(\cdot)$ denotes adding the $j$-th hard prompt.² Next, we train the LM with the following objective:

$$\max_{E_{ij}} \; P\big(h_j(y_{ik}) \mid E_{ij}; h_j(x_{ik})\big) \qquad (1)$$

where all the parameters of the underlying backbone LM are frozen and only $E_{ij}$ is trainable.
In short, given $D_i$, we perform $M_i$ prompt tunings, one for each hard prompt, resulting in $\sum_{i=1}^{T} M_i$ source prompt embeddings in total. For training efficiency, we only train on $K = 5000$ training instances for a single epoch for each source prompt embedding.
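To make the objective in Eq. (1) concrete, the following is a minimal PyTorch sketch of prompt tuning with a frozen seq2seq backbone, assuming a Hugging Face T5-style checkpoint stands in for T0; the 100-token prompt length, the initialization, the learning rate, and the helper name `prompt_tuning_step` are illustrative assumptions rather than the paper's exact configuration (see Appendix F).

```python
import torch
import torch.nn as nn
from transformers import AutoTokenizer, T5ForConditionalGeneration

MODEL_NAME = "bigscience/T0_3B"   # assumption: any T5-style seq2seq checkpoint works here
PROMPT_LEN = 100                  # assumption: 100 soft tokens (~204K parameters if d_model = 2048)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
for p in model.parameters():      # the backbone LM stays frozen; only E_ij is trained
    p.requires_grad = False

soft_prompt = nn.Parameter(torch.randn(PROMPT_LEN, model.config.d_model) * 0.5)  # E_ij
optimizer = torch.optim.Adam([soft_prompt], lr=0.3)  # illustrative learning rate

def prompt_tuning_step(prompted_input: str, prompted_target: str) -> float:
    """One gradient step on Eq. (1): maximize P(h_j(y) | E_ij; h_j(x))."""
    enc = tokenizer(prompted_input, return_tensors="pt")
    labels = tokenizer(prompted_target, return_tensors="pt").input_ids

    # Prepend the soft prompt to the embedded hard-prompted input.
    tok_embeds = model.get_input_embeddings()(enc.input_ids)                 # (1, L, d)
    inputs_embeds = torch.cat([soft_prompt.unsqueeze(0), tok_embeds], dim=1)
    attn_mask = torch.cat(
        [torch.ones(1, PROMPT_LEN, dtype=enc.attention_mask.dtype), enc.attention_mask],
        dim=1,
    )

    loss = model(inputs_embeds=inputs_embeds, attention_mask=attn_mask, labels=labels).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```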
3.2 Zero-Shot Embedding Retrieval
After source prompt embedding training, we retrieve the most related source embeddings and select one from the retrieved candidates to be used at inference (right part of Figure 2).
² For each instance, the input and output are converted into their prompted versions using the promptsource toolkit (Bach et al., 2022). An example is given in Appendix G.
We first construct a Source Prompt Library consisting of sentence-level representations of training instance inputs as keys and the corresponding source prompt embeddings as values. For each available source prompt embedding, $n$ samples are stored in the library. The sentence-level representations are obtained by taking the mean of the hidden states of the last layer of the dense retriever. We use a T0-small encoder as the dense retriever, replicated following Sanh et al. (2021) at a smaller model size.
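A minimal sketch of how the library could be built, assuming a small T5-style encoder as the dense retriever (the replicated T0-small encoder is not a public checkpoint, so the model name below is a stand-in) and a flat inner-product FAISS index for the MIPS search used at inference; `encode`, `build_library`, and the input format are illustrative.

```python
import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel

# Dense retriever: a small T5-style encoder standing in for the T0-small encoder.
retriever_tok = AutoTokenizer.from_pretrained("google/t5-v1_1-small")  # assumption
retriever = T5EncoderModel.from_pretrained("google/t5-v1_1-small").eval()

@torch.no_grad()
def encode(texts):
    """Sentence-level representation: mean of the last-layer hidden states."""
    batch = retriever_tok(texts, padding=True, truncation=True, return_tensors="pt")
    hidden = retriever(**batch).last_hidden_state            # (B, L, d)
    mask = batch.attention_mask.unsqueeze(-1)                # ignore padding positions
    pooled = (hidden * mask).sum(1) / mask.sum(1)
    return pooled.numpy().astype("float32")

def build_library(samples_per_prompt):
    """Build the Source Prompt Library.

    samples_per_prompt -- assumed dict mapping a prompt id (for E_ij) to the n raw
    input strings sampled from the instances that soft prompt was trained on.
    Keys are instance representations; values are the prompt ids.
    """
    keys, values = [], []
    for prompt_id, sampled_inputs in samples_per_prompt.items():
        keys.append(encode(sampled_inputs))
        values.extend([prompt_id] * len(sampled_inputs))
    keys = np.concatenate(keys, axis=0)
    index = faiss.IndexFlatIP(keys.shape[1])   # inner-product index for MIPS
    index.add(keys)
    return index, values
```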
At inference, we first randomly sample $Q$ query instances from the target task, following Lin et al. (2022). After obtaining sentence-level representations for each query through our T0-small encoder, we retrieve the top-$N$ examples for each query instance using a MIPS (maximum inner product search) operation³ on our Source Prompt Library, retrieving a total of $QN$ prompt embeddings. As the default methodology, among the retrieved embedding candidates, we select the most frequently retrieved prompt embedding as the designated soft prompt for the given target task and concatenate this embedding with each of the target task instances before feeding them to our backbone LM. In the next two subsections, we explain different strategies for calculating the target embedding from the $QN$ prompt embedding candidates.
³ For all indexing and searching, we use FAISS (Johnson et al., 2019) for fast source prompt embedding retrieval.
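The default frequency-based selection can then be sketched as follows, reusing `encode`, `index`, and `values` from the library sketch above; the query sampling and the final concatenation with each target instance follow the description in the text, and the on-disk prompt storage in the usage comment is hypothetical.

```python
from collections import Counter

def retrieve_soft_prompt(index, values, query_inputs, top_n):
    """Retrieve Q*N candidates via MIPS and return the most frequent prompt id."""
    query_vecs = encode(query_inputs)                 # Q sentence-level representations
    _, rows = index.search(query_vecs, top_n)         # (Q, N) nearest library rows
    retrieved = [values[r] for r in rows.flatten()]   # Q*N prompt-id candidates
    counts = Counter(retrieved)
    best_id, _ = counts.most_common(1)[0]
    return best_id, counts

# Usage (illustrative): sample Q query instances from the target task, pick the
# designated soft prompt, then prepend it to every prompted target-task input
# exactly as during prompt tuning (the backbone stays frozen).
# best_id, counts = retrieve_soft_prompt(index, values, sampled_queries, top_n=5)
# soft_prompt = torch.load(f"prompt_embeddings/{best_id}.pt")  # hypothetical path
```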
3.3 Interpolation of Prompt Embeddings
When retrieving only a single prompt embedding for a given task (Section 3.2), the result may show high variance across evaluation prompts when the selected prompt embedding does not fit the given task well. Recent works on prompt embedding retrieval have shown that interpolation of prompt embeddings transfers effectively to the target task (Asai et al., 2022; Vu et al., 2022). We therefore also explore calculating the target embedding through an interpolation of multiple source embeddings instead of using a single embedding. Among the $QN$ prompt candidates retrieved in Section 3.2, we select the top-$N$ candidate embeddings based on retrieval frequency. Then, we calculate a weighted sum of the candidate embeddings, where the interpolation weight of each source embedding is proportional to its retrieval frequency. While Asai et al. (2022); Vu et al. (2022) require fine-tuning the target embeddings on the target task to calculate the interpolation weights, our approach does not require any target task fine-tuning, enabling zero-shot task transfer.
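Under the same assumptions as the sketches above (a Counter of retrieved prompt ids and a loader from prompt id to embedding tensor, both illustrative helpers), the frequency-weighted interpolation could look like this:

```python
def interpolate_prompts(counts, load_prompt, top_n):
    """RoSPr + Inter: frequency-weighted sum of the top-N retrieved soft prompts.

    counts      -- Counter of retrieved prompt ids (from retrieve_soft_prompt)
    load_prompt -- assumed helper mapping a prompt id to its (PROMPT_LEN, d_model) tensor
    """
    top = counts.most_common(top_n)
    total = sum(freq for _, freq in top)
    mixed = None
    for prompt_id, freq in top:
        weight = freq / total                  # interpolation weight proportional to frequency
        term = weight * load_prompt(prompt_id)
        mixed = term if mixed is None else mixed + term
    return mixed  # used in place of a single soft prompt at inference
```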
3.4 Variance-based Ranking
Similar to the scoring and calibration methods of Lu et al. (2022); Zhao et al. (2021), we introduce a scoring method, applicable to zero-shot classification tasks, that ranks the $QN$ retrieved prompt embedding candidates by considering the answer choice distribution of the given target task as an extra cue alongside the original frequency cue. To accomplish this, we perform a forward pass with the concatenation of each candidate prompt embedding and the given hard prompt of the target task (including only the instruction, excluding the input instance) and give a higher score to embedding candidates that result in lower variance. Ideally, the combination of soft and hard prompts should yield equal probability among the answer choices because the actual context of the task is not included.
Specifically, given a target task with the $k$-th hard prompt $h_k$, we score each candidate embedding $E_{ij}$ as follows:

$$\text{Score}(h_k, E_{ij}) = \frac{\text{freq}(h_k, E_{ij})}{\sqrt{\text{Var}\big[P(y \mid E_{ij}, h_k)\big]}} \qquad (2)$$

where $y$ ranges over the available output options of the target task.
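A sketch of Eq. (2) under the same frozen-backbone setup as before: `choice_distribution` estimates P(y | E_ij, h_k) by scoring each answer choice's label log-likelihood when the model sees only the candidate soft prompt and the instruction, and `variance_score` combines the result with the retrieval frequency; the epsilon guard and the exact likelihood normalization are illustrative choices.

```python
import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def choice_distribution(model, tokenizer, soft_prompt, instruction, answer_choices):
    """Estimate P(y | E_ij, h_k): probabilities over the answer choices when the
    model is conditioned only on the soft prompt and the instruction (no input)."""
    enc = tokenizer(instruction, return_tensors="pt")
    tok_embeds = model.get_input_embeddings()(enc.input_ids)
    inputs_embeds = torch.cat([soft_prompt.unsqueeze(0), tok_embeds], dim=1)
    attn_mask = torch.ones(1, inputs_embeds.size(1), dtype=torch.long)

    log_likelihoods = []
    for choice in answer_choices:
        labels = tokenizer(choice, return_tensors="pt").input_ids
        out = model(inputs_embeds=inputs_embeds, attention_mask=attn_mask, labels=labels)
        # out.loss is the mean token NLL; scale by length to get the sequence log-likelihood.
        log_likelihoods.append(-out.loss.item() * labels.size(1))
    return F.softmax(torch.tensor(log_likelihoods), dim=0)

def variance_score(freq, choice_probs, eps=1e-12):
    """Eq. (2): Score(h_k, E_ij) = freq(h_k, E_ij) / sqrt(Var[P(y | E_ij, h_k)])."""
    var = choice_probs.var(unbiased=False).item()
    return freq / math.sqrt(var + eps)   # eps guards against a perfectly uniform distribution
```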
4 Experimental Settings
In this section, we explain the experimental settings for training the source prompt embeddings and constructing our Source Prompt Library. We also describe the evaluation setting for zero-shot inference and the baseline models. Detailed experiment configurations are provided in Appendix F.
4.1 Source Tasks
For training soft prompts through prompt tuning, we use a subset of the source tasks used for the initial T0 instruction-tuning (Sanh et al., 2021).⁴ For each source task, we use the prompts for each dataset in T0, resulting in a total of 230 prompts. For Source Prompt Library construction, we sample only $n = 100$ training instances per source embedding to minimize inference latency. We show variations of $n$ and different methods of sampling the $n$ training instances in Appendix D.
4.2 Evaluation Tasks
Following Sanh et al. (2021), we evaluate on the validation sets of 4 held-out tasks (natural language inference, sentence completion, coreference resolution, and word sense disambiguation), resulting in a total of 11 evaluation datasets. We also follow Sanh et al. (2021) and evaluate on 14 different datasets from the BIG-bench benchmark (Srivastava et al., 2022).⁵ We use the rank classification evaluation method, selecting the output option with the higher log-likelihood, following Brown et al. (2020); Sanh et al. (2021). For all evaluation tasks, we use accuracy as the evaluation metric and report the mean accuracy and standard deviation over all evaluation prompts per dataset (an average of about 10 prompts per evaluation dataset).⁶ For BIG-bench tasks, we do not report the standard deviation because only one prompt is provided per task.
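For illustration, rank classification in the T0+RoSPr setting can be sketched with the same likelihood scoring as in the Section 3.4 sketch, returning the answer choice with the highest sequence log-likelihood; for the plain T0 baseline the soft prompt concatenation is simply omitted, and the helper name is an assumption of this sketch.

```python
import torch

@torch.no_grad()
def rank_classify(model, tokenizer, soft_prompt, prompted_input, answer_choices):
    """Pick the answer choice whose label sequence has the highest log-likelihood."""
    enc = tokenizer(prompted_input, return_tensors="pt")
    tok_embeds = model.get_input_embeddings()(enc.input_ids)
    inputs_embeds = torch.cat([soft_prompt.unsqueeze(0), tok_embeds], dim=1)
    attn_mask = torch.ones(1, inputs_embeds.size(1), dtype=torch.long)

    scores = []
    for choice in answer_choices:
        labels = tokenizer(choice, return_tensors="pt").input_ids
        loss = model(inputs_embeds=inputs_embeds, attention_mask=attn_mask, labels=labels).loss
        scores.append(-loss.item() * labels.size(1))  # sequence log-likelihood
    return answer_choices[max(range(len(scores)), key=scores.__getitem__)]
```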
4.3 Baseline Models
Zero-shot Baseline. For zero-shot baseline models, we show the results of T0 (3B) together with the 4 times larger instruction-tuned T0 (11B). We also compare with the GPT-3 (175B) model, which is about 60 times larger than T0 (3B).
Fine-tuning Baseline. We also compare with efficient fine-tuning baseline models that utilize prompt tuning.
⁴ We use 29 out of 38 datasets that are used to train T0. We explain the training task selection rationale in Appendix H.1.
⁵ We provide the full list of evaluation datasets in Appendix H.
⁶ For methods based on RoSPr, we report the average performance of 3 runs with different random seeds for the sampling of the evaluation queries used for prompt retrieval.
Method | Params (Base/Trainable) | RTE | CB | ANLI R1 | ANLI R2 | ANLI R3 | COPA | HellaSwag | StoryCloze | Winogrande | WSC | WiC | Mean | STD
T0 | 3B / 0 | 64.55 | 45.36 | 33.81 | 33.11 | 33.33 | 75.88 | 26.60 | 84.03 | 50.97 | 65.10 | 50.69 | 51.22 | 3.62
PT (FT) | 3B / 204K | 63.14 | 44.52 | 33.07 | 31.44 | 32.94 | 73.00 | 26.39 | - | 50.67 | 46.73 | 50.02 | - | -
ATTEMPT (FT) | 3B / 614K | 68.95 | 44.88 | 36.19 | 34.73 | 34.81 | 74.38 | 26.76 | - | 51.33 | 64.90 | 51.05 | - | -
T0+RoSPr | 3B / 204K | 71.54 | 49.48 | 37.05 | 34.64 | 33.92 | 78.75 | 26.97 | 85.52 | 51.50 | 64.52 | 51.76 | 53.24 | 3.62
  w/ Inter | 3B / 204K | 70.71 | 52.30 | 37.30 | 34.34 | 33.89 | 78.25 | 26.94 | 85.62 | 51.10 | 64.52 | 50.73 | 53.24 | 3.30
  w/ Var | 3B / 204K | 71.78 | 50.36 | 37.07 | 34.58 | 33.90 | 78.88 | 27.01 | 85.52 | 51.45 | 64.94 | 51.94 | 53.38 | 3.38
  w/ Var & Inter | 3B / 204K | 72.60 | 51.98 | 37.25 | 34.31 | 33.95 | 77.83 | 26.84 | 85.58 | 50.93 | 64.97 | 51.18 | 53.40 | 3.47
  Oracle | 3B / 204K | 73.79 | 58.10 | 37.65 | 34.92 | 34.91 | 81.13 | 27.75 | 87.57 | 52.36 | 68.17 | 55.26 | 55.60 | 3.07
T0 | 11B / 0 | 80.83 | 70.12 | 43.56 | 38.68 | 41.26 | 90.02 | 33.58 | 92.40 | 59.94 | 61.45 | 56.58 | 60.76 | -
GPT-3 | 175B / 0 | 63.50 | 46.40 | 34.60 | 35.40 | 34.50 | 91.00 | 78.90 | 83.20 | 70.20 | 65.40 | 45.92 | 59.00 | -
Table 1: RoSPr refers to our main proposed method; w/ Inter applies interpolation of multiple source embedding candidates; w/ Var applies retrieval through variance-based ranking; w/ Var & Inter applies both interpolation and variance-based ranking, where the interpolation weight is based on the variance-based ranking score; and Oracle refers to the performance when the most optimal source embedding is retrieved from the candidates, acting as an upper bound for retrieval. FT refers to models fine-tuned on the target tasks. For FT models, we exclude StoryCloze due to the absence of training instances. The best and second-best performance are shown in bold and underline, respectively. Comparison with hard prompt optimization techniques and visualization of the results are shown in Appendix A and Appendix E, respectively.
These models require target task prompt tuning, meaning that zero-shot transfer is infeasible. Similar to our source prompt tuning process, we train each target prompt for a single epoch with a maximum of 5,000 training instances. The first baseline model is naive prompt tuning on the target tasks without any prompt retrieval, referred to as PT (Lester et al., 2021). The second is ATTEMPT (Asai et al., 2022), which trains the target soft prompts through attentional mixtures of source prompts. Because StoryCloze (Mostafazadeh et al., 2016) does not contain training instances, we exclude this dataset from fine-tuning. Further training details of the fine-tuning baselines are given in Appendix C.
5 Experimental Results
RoSPr enhances the performance of T0 efficiently. Table 1 shows the performance on the 11 evaluation datasets. T0+RoSPr outperforms T0 on 10 of the 11 evaluation datasets. Specifically, T0+RoSPr outperforms T0 on RTE (+6.99% points), CB (+4.12% points), ANLI R1 (+3.24% points), and COPA (+2.87% points). This shows that soft prompt retrieval assists hard prompts for zero-shot task generalization with a negligible number of additional parameters (0.007%).⁷ Also, while T0 outperforms GPT-3 on 3 datasets (RTE, StoryCloze, WiC), T0+RoSPr additionally outperforms GPT-3 on 2 more datasets (ANLI R1 and CB) while enlarging the score gap on RTE, StoryCloze, and WiC.
⁷ One exception is WSC, which is a binary classification task (yes/no) predicting whether the reference of the pronoun is correct. We observed that the evaluation data of this dataset has unbalanced labels, containing over 60% "No" labels. This might be the reason why T0-11B underperforms T0-3B only on this dataset (Sanh et al., 2021). Indeed, predicting only "No" on this dataset outperforms T0-11B (63.46 > 61.45).
RoSPr also outperforms fine-tuning baselines even without utilizing any training instances of the target task. We first observe that PT harms the performance of the backbone model, which aligns with the results of Liu et al. (2022a); Gu et al. (2022) that prompt tuning is unstable when the number of training instances or training steps is small. Comparing ATTEMPT with RoSPr, RoSPr outperforms ATTEMPT on 7 out of 10 tasks and by 1.21% points in mean accuracy over the 10 tasks. This shows that RoSPr is better suited for efficient adaptation because it requires 3 times fewer additional parameters than ATTEMPT and does not require any further fine-tuning on the target task.
Inter and Var enhance the performance of RoSPr. We also analyze the effect of the variants of RoSPr in Table 1: interpolation of soft prompts (Inter) and variance-based ranking (Var). First, applying Inter shows accuracy similar to RoSPr. However, as shown in the last column of Table 1, Inter reduces the standard deviation of T0 and RoSPr by 8.84% while improving the mean accuracy of T0, indicating increased robustness to different surface forms of the evaluation prompts. This suggests that interpolation of multiple source embeddings outperforms retrieval of a single source embedding, aligning with the results of Asai et al. (2022). Applying Var to T0+RoSPr improves both the zero-shot accuracy and the robustness of T0+RoSPr, showing that considering the answer choice distribution is beneficial in the zero-shot setting, in line with results from Zhao et al. (2021); Shi et al. (2022). Moreover, applying both Var and Inter results in the highest overall average accuracy, outperforming T0 by 2.18% points and largely reducing the gap with larger LLMs.