Meta-Learning with
Self-Improving Momentum Target
Jihoon Tack¹, Jongjin Park¹, Hankook Lee¹, Jaeho Lee², Jinwoo Shin¹
¹Korea Advanced Institute of Science and Technology (KAIST)
²Pohang University of Science and Technology (POSTECH)
{jihoontack,jongjin.park,hankook.lee,jinwoos}@kaist.ac.kr
jaeho.lee@postech.ac.kr
Abstract
The idea of using a separately trained target model (or teacher) to improve the performance of the student model has become increasingly popular in various machine learning domains, and meta-learning is no exception; a recent discovery shows that utilizing task-wise target models can significantly boost the generalization performance. However, obtaining a target model for each task can be highly expensive, especially when the number of tasks for meta-learning is large. To tackle this issue, we propose a simple yet effective method, coined Self-improving Momentum Target (SiMT). SiMT generates the target model by adapting from the temporal ensemble of the meta-learner, i.e., the momentum network. This momentum network and its task-specific adaptations enjoy a favorable generalization performance, enabling self-improvement of the meta-learner through knowledge distillation. Moreover, we found that perturbing the parameters of the meta-learner, e.g., with dropout, further stabilizes this self-improving process by preventing fast convergence of the distillation loss during meta-training. Our experimental results demonstrate that SiMT brings a significant performance gain when combined with a wide range of meta-learning methods under various applications, including few-shot regression, few-shot classification, and meta-reinforcement learning. Code is available at https://github.com/jihoontack/SiMT.
1 Introduction
Meta-learning [51] is the art of extracting and utilizing knowledge from a distribution of tasks to better solve a relevant task. This problem is typically approached by training a meta-model that can transfer its knowledge to a task-specific solver, where the performance of the meta-model is evaluated on the basis of how well each solver performs on the corresponding task. To learn such a meta-model, one should be able to (a) train an appropriate solver for each task utilizing the knowledge transferred from the meta-model, and (b) accurately evaluate the performance of the solver. A standard way to do this is the so-called S/Q (support/query) protocol [55, 34]: for (a), use a set of support set samples to train the solver; for (b), use another set of samples, called query set samples, to evaluate the solver.¹
Recently, however, an alternative paradigm, called the S/T (support/target) protocol, has received much attention [58, 62, 32]. The approach assumes that the meta-learner has access to task-specific target models, i.e., an expert model for each given task, and uses these models to evaluate task-specific solvers by measuring the discrepancy of the solvers from the target models. Intriguingly, it has been observed that such a knowledge distillation procedure [43, 21] helps to improve the meta-generalization performance [62], in a similar way that the teacher-student framework helps to avoid overfitting in non-meta-learning contexts [30, 24].
¹ We give an overview of terminologies used in the paper to guide readers new to this field (see Appendix A).
36th Conference on Neural Information Processing Systems (NeurIPS 2022).
arXiv:2210.05185v1 [cs.LG] 11 Oct 2022
Figure 1: An overview of the proposed Self-improving Momentum Target (SiMT): the momentum network efficiently generates the target model, and by distilling knowledge to the task-specific solver, it forms a self-improving process. S and Q denote the support and query datasets, respectively.
Despite such an advantage, the S/T protocol is difficult to use in practice, as training target models for each task usually requires excessive computation, especially when the number of tasks is large. Prior works aim to alleviate this issue by generating target models in a compute-efficient manner. For instance, Lu et al. [32] consider the case where the learner has access to a model pre-trained on a global data domain that covers most tasks (to be meta-trained upon), and propose to generate task-wise target models by simply fine-tuning the model for each task. However, the method still requires fine-tuning on a large number of tasks, and more importantly, is hard to use when no effective pre-trained model is available, e.g., a globally pre-trained model is usually not available in reinforcement learning, as collecting "global" data is a nontrivial task [9].
In this paper, we ask whether we can generate the task-specific target models by (somewhat ironically) using meta-learning. We draw inspiration from recent observations in the semi/self-supervised learning literature [50, 16, 5] that the temporal ensemble of a model, i.e., the momentum network [27], can be an effective teacher of the original model. It turns out that a similar phenomenon happens in the meta-learning scenario: one can construct a momentum network of the meta-model, whose task-specific adaptation is an effective target model from which the task-specific knowledge can be distilled to train the original meta-model.
Contribution. We establish a novel framework, coined Meta-Learning with Self-improving Momentum Target (SiMT), which brings the benefit of the S/T protocol to S/Q-like scenarios where task-specific target models are not available (but query data are). An overview of SiMT is illustrated in Figure 1. In a nutshell, SiMT is comprised of two (iterative) steps:

• Momentum target: We generate the target model by adapting from the momentum network, which shows better adaptation performance than the meta-model itself. In this regard, generating the target model becomes highly efficient, e.g., a single forward pass is required to obtain the momentum target for ProtoNet [45].

• Self-improving process: The meta-model improves through knowledge distillation from the momentum target, and this in turn recursively improves the momentum network via the temporal ensemble. Furthermore, we find that perturbing the parameters of the task-specific solver of the meta-model, e.g., with dropout [47], further stabilizes this self-improving process by preventing fast convergence of the distillation loss during meta-training.
We verify the effectiveness of SiMT under various applications of meta-learning, including few-shot regression, few-shot classification, and meta-reinforcement learning (meta-RL). Overall, our experimental results show that incorporating the proposed method consistently and significantly improves the baseline meta-learning methods [10, 31, 36, 45]. In particular, our method improves the few-shot classification accuracy of Conv4 [55] trained with MAML [10] on mini-ImageNet [55] from 47.33% to 51.49% for 1-shot, and from 63.27% to 68.74% for 5-shot. Moreover, our framework also yields notable improvements on few-shot regression and meta-RL tasks, which supports that the proposed method is indeed domain-agnostic.
2 Related work
Learning from target models. Learning from an expert model, i.e., the target model, has shown its effectiveness across various domains [30, 35, 65, 52]. As a follow-up, recent papers demonstrate that this is also the case for meta-learning [58, 62]. However, training independent task-specific target models is highly expensive due to the large space of the task distribution in meta-learning. To this end, recent work suggests pre-training a global encoder on the whole meta-training set and fine-tuning target models on each task [32]; however, this approach is limited to specific domains and still requires substantial computation, e.g., it takes more than 6.5 GPU hours to pre-train only 10% of the target models, while ours requires 2 GPU hours for the entire meta-learning process (ProtoNet [45] with ResNet-12 [34]) on the same GPU. Another recent relevant work is bootstrapped meta-learning [11], which generates the target model from the meta-model by further updating the parameters of the task-specific solver for some number of steps with the query dataset. While the bootstrapped target models can be obtained efficiently, the approach is specialized to gradient-based meta-learning schemes, e.g., MAML [10]. In this paper, we suggest an efficient and more generic way to generate the target model during meta-training.
Learning with momentum networks. The idea of temporal ensembling, i.e., the momentum network, has become an essential component of recent semi/self-supervised learning algorithms [3, 5]. For example, Mean Teacher [50] first showed that the momentum network improves the performance of semi-supervised image classification, and recent advanced approaches [2, 46] adopted this idea to achieve state-of-the-art performance. Also, in self-supervised learning methods that enforce invariance to data augmentation, such momentum networks are widely utilized as a target network [19, 16] to prevent collapse by providing smoother changes in the representations. In meta-learning, a concurrent work [6] used stochastic weight averaging [23] (a similar approach to the momentum network) to learn a low-rank representation. In this paper, we empirically demonstrate that the momentum network shows better adaptation performance compared to the original meta-model, which motivates us to utilize it for generating the target model in a compute-efficient manner.
3 Problem setup and evaluation protocols
In this section, we formally describe the meta-learning setup under consideration, and the S/Q and S/T protocols studied in prior works.
Problem setup: Meta-learning. Let p(τ) be a distribution of tasks. The goal of meta-learning is to train a meta-model f_θ, parameterized by the meta-model parameter θ, which can transfer its knowledge to help train a solver for a new task. More formally, we consider some adaptation subroutine Adapt(·,·) which uses both the information transferred from θ and the task-specific dataset (which we call the support set) S_τ to output a task-specific solver as φ_τ = Adapt(θ, S_τ). For example, the model-agnostic meta-learning algorithm (MAML; [10]) uses the adaptation subroutine of taking a fixed number of SGD steps on S_τ, starting from the initial parameter θ. In this paper, we aim to give a general meta-learning framework that can be used in conjunction with any adaptation subroutine, instead of designing a method specialized for a specific one.

The objective is to learn a good meta-model parameter θ from a set of tasks sampled from p(τ) (or sometimes the task distribution itself), such that the expected loss of the task-specific adaptations is small, i.e., $\min_\theta \mathbb{E}_{\tau \sim p(\tau)}[\ell_\tau(\mathrm{Adapt}(\theta, S_\tau))]$, where ℓ_τ(·) denotes the test loss on task τ. To train such a meta-model, we need a mechanism to evaluate and optimize θ (e.g., via gradient descent). For this purpose, existing approaches take one of two routes: the S/Q protocol or the S/T protocol.
S/Q protocol. The majority of existing meta-learning frameworks (e.g., [55, 34]) split the task-specific training data into two sets and use them for different purposes. One is the support set S_τ, which is used to perform the adaptation subroutine. The other is the query set Q_τ, which is used to evaluate the performance of the adapted parameter and compute the gradient with respect to θ. In other words, given the task datasets (S_1, Q_1), (S_2, Q_2), ..., (S_N, Q_N),² the S/Q protocol solves

$$\min_{\theta} \; \frac{1}{N} \sum_{i=1}^{N} \mathcal{L}\big(\mathrm{Adapt}(\theta, S_{\tau_i}),\, Q_{\tau_i}\big), \qquad (1)$$

where L(φ, Q) denotes the empirical loss of a solver φ on the dataset Q.
² Here, while we assumed a static batch of tasks for notational simplicity, the expression is readily extendible to the case of a stream of tasks drawn from p(τ).
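Continuing the sketch above, the S/Q objective (1) amounts to averaging the query-set loss of each support-adapted solver over a task batch; the task-batch format used here is an assumption for illustration.

```python
def sq_meta_loss(model, theta, tasks):
    """Eq. (1): average query loss of support-adapted solvers.
    `tasks` is a list of (support_x, support_y, query_x, query_y) tuples."""
    total = 0.0
    for sx, sy, qx, qy in tasks:
        phi = adapt(model, theta, sx, sy)            # train the solver on S
        logits = functional_call(model, phi, (qx,))  # evaluate it on Q
        total = total + F.cross_entropy(logits, qy)
    return total / len(tasks)  # backpropagate this to update theta
```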
S/T protocol. Another line of work considers the scenario where the meta-learner additionally has access to a set of target models φ_target for each training task [58, 32]. In such a case, one can use a teacher-student framework to regularize the adapted solver to behave similarly (or, equivalently, have low prediction discrepancy) to the target model. Here, a typical practice is to not split each task dataset and to measure the discrepancy using the support dataset that is used for the adaptation [32]. In other words, given the task datasets S_1, S_2, ..., S_N and the corresponding target models $\phi^{\tau_1}_{\text{target}}, \phi^{\tau_2}_{\text{target}}, \ldots, \phi^{\tau_N}_{\text{target}}$, the S/T protocol updates the meta-model by solving

$$\min_{\theta} \; \frac{1}{N} \sum_{i=1}^{N} \mathcal{L}_{\text{teach}}\big(\mathrm{Adapt}(\theta, S_{\tau_i}),\, \phi^{\tau_i}_{\text{target}},\, S_{\tau_i}\big), \qquad (2)$$

where L_teach(φ, φ_target, S) denotes a discrepancy measure between the adapted model φ and the target model φ_target, measured using the dataset S.
4 Meta-learning with self-improving momentum target
In this section, we develop a compute-efficient framework which brings the benefits of the S/T protocol to settings where we have access neither to task-specific target models nor to a general pre-trained model, as in common S/Q-like setups. In a nutshell, our framework iteratively generates a meta-target model, which generalizes well when adapted to the target tasks, by constructing a momentum network [50] of the meta-model itself. The meta-model is then trained using both the knowledge transferred from the momentum target and the knowledge freshly learned from the query sets. We first briefly describe our meta-model update protocol (Section 4.1), and then the core component, coined Self-Improving Momentum Target (SiMT), which efficiently generates the target model for each task (Section 4.2).
4.1 Meta-model update with an S/Q-S/T hybrid loss
To update the meta-model, we use a hybrid loss function of the S/Q protocol (1) and the S/T protocol (2). Formally, let (S_1, Q_1), (S_2, Q_2), ..., (S_N, Q_N) be given task datasets with a support-query split, and let $\phi^{\tau_1}_{\text{target}}, \phi^{\tau_2}_{\text{target}}, \ldots, \phi^{\tau_N}_{\text{target}}$ be task-specific target models generated by our target generation procedure (which will be explained in more detail in Section 4.2). We train the meta-model as

$$\min_{\theta} \; \frac{1}{N} \sum_{i=1}^{N} \Big[ (1-\lambda)\cdot \mathcal{L}\big(\mathrm{Adapt}(\theta, S_{\tau_i}),\, Q_{\tau_i}\big) + \lambda \cdot \mathcal{L}_{\text{teach}}\big(\mathrm{Adapt}(\theta, S_{\tau_i}),\, \phi^{\tau_i}_{\text{target}},\, Q_{\tau_i}\big) \Big], \qquad (3)$$

where λ ∈ [0, 1) is the weight hyperparameter. We note two things about Eq. (3). First, while we are training using the target model, we also use an S/Q loss term. This is because our method trains the meta-target model and the meta-model simultaneously from scratch, instead of requiring fully-trained target models. Second, unlike in the S/T protocol, we evaluate the discrepancy L_teach using the query set Q_{τ_i} instead of the support set, to improve the generalization performance of the student model. In particular, the predictions of adapted models on query set samples are softer (i.e., less confident) than on support set samples, and such soft predictions are known to benefit the generalization performance of the student model in the knowledge distillation literature [64, 49].
4.2 SiMT: Self-improving momentum target
We now describe our proposed algorithm, SiMT (Algorithm 1), which generates the target model in a compute-efficient manner. In a nutshell, SiMT is comprised of two iterative steps: the momentum target and the self-improving process. To efficiently generate a target model, SiMT utilizes the temporal ensemble of the network, i.e., the momentum network, and then distills the knowledge of the generated target model into the task-specific solver of the meta-model to form a self-improving process.

Momentum target. For the compute-efficient generation of target models, we utilize the momentum network θ_moment of the meta-model. Specifically, after every meta-model training iteration, we compute the exponential moving average of the meta-model parameter θ as

$$\theta_{\text{moment}} \leftarrow \eta \cdot \theta_{\text{moment}} + (1-\eta)\cdot \theta, \qquad (4)$$

where η ∈ [0, 1) is the momentum coefficient. We find that θ_moment can adapt better than the meta-model θ itself and observe that its loss landscape has flatter minima (see Section 5.5), which can help explain its favorable generalization performance.