Meta-DMoE: Adapting to Domain Shift by
Meta-Distillation from Mixture-of-Experts
Tao Zhong1, Zhixiang Chi2, Li Gu2, Yang Wang2,3, Yuanhao Yu2, Jin Tang2
1University of Toronto, 2Huawei Noah’s Ark Lab, 3Concordia University
tao.zhong@mail.utoronto.ca yang.wang@concordia.ca
{zhixiang.chi, li.gu, yuanhao.yu, tangjin}@huawei.com
Abstract
In this paper, we tackle the problem of domain shift. Most existing methods perform training on multiple source domains using a single model, and the same trained model is used on all unseen target domains. Such solutions are sub-optimal because each target domain exhibits its own specialty, which the shared model never adapts to. Furthermore, expecting a single model to learn extensive knowledge from multiple source domains is counterintuitive: the model is biased toward learning only domain-invariant features, which may result in negative knowledge transfer. In this work, we propose a novel framework for unsupervised test-time adaptation, which is formulated as a knowledge distillation process to address domain shift. Specifically, we incorporate Mixture-of-Experts (MoE) as teachers, where each expert is separately trained on a different source domain to maximize its specialty. Given a test-time target domain, a small set of unlabeled data is sampled to query the knowledge from the MoE. As the source domains are correlated with the target domains, a transformer-based aggregator then combines the domain knowledge by examining the interconnections among them. The output is treated as a supervision signal to adapt a student prediction network toward the target domain. We further employ meta-learning to enforce the aggregator to distill positive knowledge and the student network to achieve fast adaptation. Extensive experiments demonstrate that the proposed method outperforms the state of the art and validate the effectiveness of each proposed component. Our code is available at https://github.com/n3il666/Meta-DMoE.
1 Introduction
The emergence of deep models has achieved superior performance [32, 40, 47]. Such unprecedented success is built on the strong assumption that the training and testing data are highly correlated (i.e., they are both sampled from the same data distribution). However, the assumption typically does not hold in real-world settings, as it is infeasible for the training data to cover all the ever-changing deployment environments [39]. Such a mismatch between training and deployment distributions is known as distribution shift, and it significantly hampers the performance of deep models. Humans are relatively robust against distribution shift, but artificial learning-based systems suffer severe performance degradation.
One line of research aims to mitigate the distribution shift by exploiting some unlabeled data from a target domain, which is known as unsupervised domain adaptation (UDA) [24, 51, 26]. The unlabeled data serves as an estimate of the target distribution [86]. Therefore, UDA normally adapts to the target domain by transferring the source knowledge through a common feature space that is less affected by domain discrepancy [79, 50]. However, UDA is less applicable to real-world scenarios
as repetitive large-scale training is required for every target domain. In addition, collecting data samples from a target domain in advance might not be possible, as the target distribution could be unknown during training. Domain generalization (DG) [54, 28, 6] is an alternative but more challenging line of research, as it assumes no prior knowledge of the target domains. DG methods leverage multiple source domains for training and directly apply the trained model to all unseen domains. As the domain-specific information of the target domains is never adapted to, such a generic model is sub-optimal [68, 17].
Test-time adaptation with DG allows the model to exploit unlabeled data during testing to overcome the limitation of using one flawed generic model for all unseen target domains. In ARM [86], meta-learning [25] is utilized to train the model as an initialization such that it can be adapted using the unlabeled data from the unseen target domain before making the final inference. However, we observe that ARM only trains a single model, which is counterintuitive for the multi-source domain setting. There is a certain amount of correlation among the source domains, while each of them also exhibits its own specific knowledge. When the number of source domains rises, data complexity dramatically increases, which impedes thorough exploration of the dataset. Furthermore, real-world domains are not always balanced in data scale [39]. Therefore, single-model training is biased toward domain-invariant features and dominant domains instead of domain-specific features [12].
In this work, we propose to formulate test-time adaptation as a process of knowledge distillation [34] from multiple source domains. Concretely, we incorporate the concept of Mixture-of-Experts (MoE), which is a natural fit for the multi-source domain setting. The MoE models are treated as teachers and are separately trained on their corresponding domains to maximize their domain specialty. Given a new target domain, a few unlabeled data samples are collected to query the features from the expert models. A transformer-based knowledge aggregator is proposed to examine the interconnection among the queried knowledge and aggregate the information that is correlated with the target domain. The output is then treated as a supervision signal to update a student prediction network so that it adapts to the target domain. The adapted student is then used for subsequent inference. We employ bi-level optimization as meta-learning to train the aggregator at the meta-level to improve generalization. The student network is also meta-trained to achieve fast adaptation from a few samples. Furthermore, we simulate test-time out-of-distribution scenarios during training to align the training objective with the evaluation protocol.
The proposed method also provides additional advantages over ARM: 1) our method provides a larger model capacity, which improves the generalization power; 2) despite the higher computational cost, only the adapted student network is kept for inference, while the MoE models are discarded after adaptation; our method is therefore more flexible in designing the architectures of the teacher or student models (e.g., designing compact models for power-constrained environments); 3) our method does not need to access the raw data of the source domains but only their trained models, so we can take advantage of private domains in real-world settings where the data is inaccessible. We name our method Meta-Distillation of MoE (Meta-DMoE). Our contributions are as follows:
• We propose a novel unsupervised test-time adaptation framework that is tailored for the multi-source domain setting. Our framework employs the concept of MoE to allow each expert model to explore its source domain thoroughly. We formulate the adaptation process as knowledge distillation by aggregating the positive knowledge retrieved from the MoE.
• The alignment between the training and evaluation objectives via meta-learning improves the adaptation, and hence the test-time generalization.
• We conduct extensive experiments to show the superiority of the proposed method over the state of the art and validate the effectiveness of each component of Meta-DMoE.
• We validate that our method is more flexible in real-world settings where computational power and data privacy are concerns.
2 Related work
Domain shift.
Unsupervised Domain Adaptation (UDA) has been a popular way to address domain shift by transferring knowledge from a labeled source domain to an unlabeled target domain [48, 41, 81]. It is achieved by learning domain-invariant features via minimizing statistical discrepancy across domains [5, 58, 70]. Adversarial learning is also applied to develop an indistinguishable feature space [26, 51, 57]. The first limitation of UDA is its assumption of the co-existence of source and target data, which is inapplicable when the target domain is unknown in advance. Furthermore, most algorithms focus on the unrealistic single-source-single-target adaptation, whereas source data normally comes from multiple domains. Splitting the source data into distinct domains and exploring the unique characteristics of each domain and the dependencies among them strengthens robustness [88, 76, 78]. Domain generalization (DG) is another line of research to alleviate domain shift. DG aims to train a model on multiple source domains without accessing any prior information about the target domain and expects it to perform well on unseen target domains. [28, 45, 53] aim to learn domain-invariant feature representations, while [63, 75] exploit data augmentation strategies in the data or feature space. A concurrent work proposes bidirectional learning to mitigate domain shift [14]. However, deploying a generic model to all unseen target domains fails to exploit domain specialty and yields sub-optimal solutions. In contrast, our method further exploits the unlabeled target data and updates the trained model for each specific unseen target domain at test time.
Test-time adaptation (TTA).
TTA constructs supervision signals from unlabeled data to update the generic model before inference. Sun et al. [68] use rotation prediction to update the model during inference. Chi et al. [17] and Li et al. [46] reconstruct the input images to achieve internal learning for restoring blurry images and estimating human pose. ARM [86] incorporates test-time adaptation into DG by meta-learning a model that is capable of adapting to unseen target domains before making an inference. Instead of adapting to every data sample, our method only updates once for each target domain using a fixed number of examples.
Meta-learning.
Existing meta-learning methods can be categorized as model-based [62, 59, 8], metric-based [65, 30], and optimization-based [25]. Meta-learning aims to learn the learning process itself through episodic training, which is based on bi-level optimization ([13] provides a comprehensive survey). One advantage of bi-level optimization is that it improves training when learning objectives conflict. Utilizing such a paradigm, [16, 85] successfully reduce the forgetting issue and improve adaptation for continual learning [49]. In our method, we incorporate meta-learning with knowledge distillation by jointly learning a student model initialization and a knowledge aggregator for fast adaptation.
Mixture-of-experts.
The goal of MoE [37] is to decompose the whole training set into many subsets, which are independently learned by different models. It has been successfully applied to image recognition models to improve accuracy [1]. MoE is also popular for scaling up architectures: as each expert is independently trained, sparse selection methods are developed to select a subset of the MoE during inference to increase network capacity [42, 23, 29]. In contrast, our method utilizes all the experts to extract and combine their knowledge for positive knowledge transfer.
3 Preliminaries
In this section, we describe the problem setting and discuss the adaptive model. We mainly follow the test-time unsupervised adaptation setting in [86]. Specifically, we define a set of $N$ source domains $\mathcal{D}_S = \{\mathcal{D}_{S_i}\}_{i=1}^{N}$ and $M$ target domains $\mathcal{D}_T = \{\mathcal{D}_{T_j}\}_{j=1}^{M}$. The exact definition of a domain varies and depends on the application or the data collection method; it could be a specific dataset, user, or location. Let $x \in \mathcal{X}$ and $y \in \mathcal{Y}$ denote the input and the corresponding label, respectively. Each source domain contains data in the form of input-output pairs: $\mathcal{D}_{S_i} = \{(x^z_S, y^z_S)\}_{z=1}^{Z_i}$. In contrast, each target domain contains only unlabeled data: $\mathcal{D}_{T_j} = \{x^k_T\}_{k=1}^{K_j}$. For well-designed datasets (e.g., [33, 20]), all source or target domains have the same number of data samples. Such a condition is not ubiquitous in real-world scenarios (i.e., $Z_{i_1} \neq Z_{i_2}$ if $i_1 \neq i_2$ and $K_{j_1} \neq K_{j_2}$ if $j_1 \neq j_2$), where data imbalance always exists [39]. This further challenges generalization under a broader range of real-world distribution shifts rather than a finite set of synthetic ones. Generic domain shift tasks focus on the out-of-distribution setting where the source and target domains are non-overlapping (i.e., $\mathcal{D}_S \cap \mathcal{D}_T = \emptyset$) but the label spaces of both are the same (i.e., $\mathcal{Y}_S = \mathcal{Y}_T$).
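To make the notation concrete, the following toy Python snippet sketches the data layout under this setting; the domain names and sizes are entirely made up for illustration.

```python
from typing import Dict, List, Tuple

# Toy layout of the setting: N labeled source domains with Z_i examples each,
# M unlabeled target domains with K_j examples each; the sizes are generally unequal.
x_dim = 32
dummy_x = [0.0] * x_dim

source_domains: Dict[str, List[Tuple[List[float], int]]] = {
    "location_1": [(dummy_x, 3)] * 500,   # Z_1 = 500 labeled (x, y) pairs
    "location_2": [(dummy_x, 7)] * 120,   # Z_2 = 120: imbalanced w.r.t. Z_1
}
target_domains: Dict[str, List[List[float]]] = {
    "location_3": [dummy_x] * 80,         # K_1 = 80 unlabeled inputs, unseen during training
}

# Out-of-distribution setting: no target domain appears among the sources,
# but every domain shares the same label space Y.
assert set(source_domains).isdisjoint(target_domains)
```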
Conventional DG methods perform training on $\mathcal{D}_S$ and make minimal assumptions about the testing scenarios [67, 3, 35]. Therefore, the same generic model is directly applied to all target domains $\mathcal{D}_T$, which leads to sub-optimal solutions [68]. In fact, for each $\mathcal{D}_{T_j}$, some unlabeled data is readily available and provides certain prior knowledge about that target distribution. Adaptive Risk Minimization (ARM) [86] assumes that a batch of unlabeled input data $x$ approximates the input distribution $p_x$, which in turn provides useful information about $p_{y|x}$.
Figure 1: Overview of the training of Meta-DMoE. We first sample a disjoint support set $x_{SU}$ and query set $(x_Q, y_Q)$ from a training domain. $x_{SU}$ is sent to the expert models $\mathcal{M}$ to query their domain-specific knowledge. An aggregator $\mathcal{A}(\cdot;\phi)$ then combines the information and generates a supervision signal to update $f(\cdot;\theta)$ via knowledge distillation. The updated $f(\cdot;\theta')$ is evaluated on the labeled query set to update the meta-parameters. The inner loop performs the adaptation; the outer loop updates the meta-parameters.
Based on this assumption, an unsupervised test-time adaptation [59, 27] is proposed. The fundamental concept is to adapt the model to the specific domain using $x$. Overall, ARM aims to minimize the following objective $\mathcal{L}(\cdot,\cdot)$ over all training domains:
$$\sum_{\mathcal{D}_{S_j} \in \mathcal{D}_S} \; \sum_{(x, y) \in \mathcal{D}_{S_j}} \mathcal{L}\big(y, f(x; \theta')\big), \quad \text{where } \theta' = h(x, \theta; \phi). \tag{1}$$
Here, $y$ denotes the labels for $x$, $f(x;\theta)$ denotes the prediction model parameterized by $\theta$, and $h(\cdot;\phi)$ is an adaptation function parameterized by $\phi$. It receives the original $\theta$ of $f$ and the unlabeled data $x$, and adapts $\theta$ to $\theta'$. The goal of ARM is to learn both $(\theta, \phi)$. To mimic test-time adaptation (i.e., adapt before prediction), it follows episodic learning as in meta-learning [25]. Specifically, each episode processes one domain by performing unsupervised adaptation using $x$ and $h(\cdot;\phi)$ in the inner loop to obtain $f(\cdot;\theta')$. The outer loop evaluates the adapted $f(\cdot;\theta')$ using the true labels to perform a meta-update. ARM is a general framework that can be incorporated with existing meta-learning approaches with different forms of the adaptation module $h(\cdot;\cdot)$ [25, 27].
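To make the adapt-then-predict protocol concrete, below is a minimal PyTorch-style sketch of one such episode with a gradient-based adaptation module (similar in flavor to ARM's learned-loss variant). The function names, the single inner gradient step, and the use of `torch.func` are our own illustrative assumptions rather than the reference implementation of [86]; `meta_opt` is assumed to hold both the model parameters $\theta$ and the adaptation parameters $\phi$.

```python
import torch
import torch.nn.functional as F
from torch.func import functional_call

def arm_style_episode(model, h_phi, x_unlabeled, x_query, y_query, meta_opt, alpha=0.1):
    """One episode: adapt theta -> theta' on unlabeled data, then meta-update.

    h_phi maps the model's outputs on unlabeled data to a scalar unsupervised
    loss whose gradient defines the adaptation (one flavor of h(x, theta; phi)).
    """
    theta = dict(model.named_parameters())

    # Inner loop: theta' = theta - alpha * grad of the unsupervised objective.
    unsup_loss = h_phi(functional_call(model, theta, (x_unlabeled,)))
    grads = torch.autograd.grad(unsup_loss, list(theta.values()), create_graph=True)
    theta_prime = {k: p - alpha * g for (k, p), g in zip(theta.items(), grads)}

    # Outer loop: evaluate f(.; theta') with the true labels; the meta-gradient
    # flows back through the adaptation step to update both theta and phi.
    loss = F.cross_entropy(functional_call(model, theta_prime, (x_query,)), y_query)
    meta_opt.zero_grad()
    loss.backward()
    meta_opt.step()
    return float(loss)
```

Because the meta-gradient flows through the inner update, the initialization $\theta$ and the adaptation module $\phi$ are optimized jointly, which is exactly the bi-level structure described above.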
However, several shortcomings remain with respect to generalization. The episodic learning processes one domain at a time, which imposes clear boundaries among the domains. The overall setting is equivalent to the multi-source domain setting, which is proven to be more effective than learning from a single domain [53, 87], as most domains are correlated with each other [2]. However, it is counterintuitive to learn all the domain knowledge in one single model, since each domain has specialized semantics or low-level features [64]. Therefore, the single-model method in ARM is sub-optimal because: 1) some domains may contain competing information, which leads to negative knowledge transfer [66]; the model may learn ambiguous feature representations instead of capturing all the domain-specific information [80]; 2) not all domains are equally important [76], and the learning may be biased, as data across domains are imbalanced in real-world applications [39].
4 Proposed approach
In this section, we explicitly formulate the test-time adaptation as a knowledge transfer process to
distill the knowledge from MoE. The proposed method is learned via meta-learning to mimic the
test-time out-of-distribution scenarios and ensure positive knowledge transfer.
4.1 Meta-distillation from mixture-of-experts
Overview.
Fig. 1 shows an overview of the method. We wish to explicitly transfer useful knowledge from various source domains to achieve generalization on unseen target domains. Concretely, we define the MoE as $\mathcal{M} = \{\mathcal{M}_i\}_{i=1}^{N}$ to represent the domain-specific models. Each $\mathcal{M}_i$ is separately trained with standard supervised learning on the source domain $\mathcal{D}_{S_i}$ to learn its discriminative features. We propose to view test-time adaptation as unsupervised knowledge distillation [34] that learns from the MoE. Therefore, we treat $\mathcal{M}$ as the teacher and aim to distill its knowledge into a student prediction network $f(\cdot;\theta)$ to achieve adaptation. To do so, we sample a batch of unlabeled data $x$ from a target domain and pass it to $\mathcal{M}$ to query the domain-specific knowledge $\{\mathcal{M}_i(x)\}_{i=1}^{N}$. This knowledge is then forwarded to a knowledge aggregator $\mathcal{A}(\cdot;\phi)$, which is learned to capture the interconnection among the domain knowledge and aggregate the information from the MoE. The output of $\mathcal{A}(\cdot;\phi)$ is treated as the supervision signal to update $f(x;\theta)$. Once the adapted $\theta'$ is obtained, $f(\cdot;\theta')$ is used to make predictions on the rest of the data in that domain. The overall framework follows the effective few-shot learning paradigm, where $x$ is treated as an unlabeled support set [74, 65, 25].
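As a reference for how such an aggregator might look, the snippet below sketches a transformer encoder that treats the $N$ expert features of each support image as a token sequence; the depth, width, and mean-pooling of the tokens are illustrative assumptions and may differ from the exact architecture used in our experiments.

```python
import torch
import torch.nn as nn

class KnowledgeAggregator(nn.Module):
    """Minimal transformer-based aggregator A(.; phi): the N expert features of
    each image form a token sequence, self-attention models the interconnection
    among domains, and the tokens are pooled into one supervision feature."""

    def __init__(self, feat_dim: int, num_heads: int = 4, num_layers: int = 2):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=feat_dim, nhead=num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, expert_feats: torch.Tensor) -> torch.Tensor:
        # expert_feats: [batch, num_experts, feat_dim], one token per source expert
        tokens = self.encoder(expert_feats)
        return tokens.mean(dim=1)   # aggregated feature used as the distillation target

# Example: aggregate the knowledge queried from N = 5 domain experts.
if __name__ == "__main__":
    agg = KnowledgeAggregator(feat_dim=512)
    expert_feats = torch.randn(8, 5, 512)   # 8 support images, 5 experts
    target = agg(expert_feats)              # [8, 512]
```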
Algorithm 1 Training for Meta-DMoE
Require: $\{\mathcal{D}_{S_i}\}_{i=1}^{N}$: data of the source domains; $\alpha, \beta$: learning rates; $B$: meta batch size
 1: // Pretrain the domain-specific MoE models
 2: for $i = 1, \dots, N$ do
 3:     Train the domain-specific model $\mathcal{M}_i$ using $\mathcal{D}_{S_i}$
 4: end for
 5: // Meta-train the aggregator $\mathcal{A}(\cdot;\phi)$ and the student model $f(\cdot;\theta_e, \theta_c)$
 6: Initialize: $\phi, \theta_e, \theta_c$
 7: while not converged do
 8:     Sample a batch of $B$ source domains $\{\mathcal{D}_{S_b}\}^{B}$; reset the batch loss $\mathcal{L}_B = 0$
 9:     for each $\mathcal{D}_{S_b}$ do
10:         Sample support and query sets: $x_{SU}$, $(x_Q, y_Q) \sim \mathcal{D}_{S_b}$
11:         $\mathcal{M}'_e(x_{SU}) = \{\mathcal{M}^i_e(x_{SU})\}_{i=1}^{N}$, masking $\mathcal{M}^i_e(x_{SU})$ with $0$ if $b = i$
12:         Perform adaptation via knowledge distillation from the MoE:
13:         $\theta'_e = \theta_e - \alpha \nabla_{\theta_e} \|\mathcal{A}(\mathcal{M}'_e(x_{SU}); \phi) - f(x_{SU}; \theta_e)\|_2$
14:         Evaluate the adapted $\theta'_e$ on the query set and accumulate the loss:
15:         $\mathcal{L}_B = \mathcal{L}_B + \mathcal{L}_{CE}(y_Q, f(x_Q; \theta'_e, \theta_c))$
16:     end for
17:     Update $\phi, \theta_e, \theta_c$ for the current meta batch:
18:     $(\phi, \theta_e, \theta_c) \leftarrow (\phi, \theta_e, \theta_c) - \beta \nabla_{(\phi, \theta_e, \theta_c)} \mathcal{L}_B$
19: end while
Training Meta-DMoE.
Properly training $(\theta, \phi)$ is critical to improving the generalization on unseen domains. In our framework, $\mathcal{A}(\cdot;\phi)$ acts as a mechanism that explores and mixes the knowledge from multiple source domains. A conventional knowledge distillation process requires large numbers of data samples and learning iterations [34, 2], and such repetitive large-scale training is inapplicable in real-world applications. To mitigate the aforementioned challenges, we follow the meta-learning paradigm [25]. Such bi-level optimization enforces $\mathcal{A}(\cdot;\phi)$ to learn beyond any specific knowledge [85] and allows the student prediction network $f(\cdot;\theta)$ to achieve fast adaptation.
Specifically, we first split the data samples in each source domain $\mathcal{D}_{S_i}$ into disjoint support and query sets. The unlabeled support set ($x_{SU}$) is used to perform adaptation via knowledge distillation, while the labeled query set ($x_Q$, $y_Q$) is used to evaluate the adapted parameters and explicitly test the generalization on unseen data.
The student prediction network $f(\cdot;\theta)$ can be decoupled into a feature extractor $\theta_e$ and a classifier $\theta_c$. Unsupervised knowledge distillation can be achieved via the softened output [34] or the intermediate features [84] of $\mathcal{M}$. The former allows the whole student network $\theta = (\theta_e, \theta_c)$ to be adaptive, while the latter allows part or all of $\theta_e$ to adapt to $x$, depending on the features utilized. We follow [56] and only adapt $\theta_e$ in the inner loop while keeping $\theta_c$ fixed. Thus, the adaptation process is achieved by distilling the knowledge via the aggregated features:
$$\mathrm{DIST}(x_{SU}, \mathcal{M}_e, \phi, \theta_e) = \theta'_e = \theta_e - \alpha \nabla_{\theta_e} \big\| \mathcal{A}(\mathcal{M}_e(x_{SU}); \phi) - f(x_{SU}; \theta_e) \big\|_2, \tag{2}$$
where $\alpha$ denotes the adaptation learning rate, $\mathcal{M}_e$ denotes the feature extractors of the MoE models (i.e., the features taken before each classifier), and $\|\cdot\|_2$ measures the $L_2$ distance. The goal is to obtain an updated $\theta'_e$ such that the features extracted by $f(x_{SU}; \theta'_e)$ are closer to the aggregated features.
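For concreteness, a minimal PyTorch-style sketch of this distillation step is given below. The interfaces (`student_feat` for $f(\cdot;\theta_e)$, `experts` for the MoE feature extractors $\mathcal{M}_e$, `aggregator` for $\mathcal{A}(\cdot;\phi)$), the use of `torch.func`, and the squared-distance surrogate for the $L_2$ objective are illustrative assumptions, not the released implementation.

```python
import torch
import torch.nn.functional as F
from torch.func import functional_call

def dist_step(student_feat, theta_e, experts, aggregator, x_su, alpha=1e-4):
    """One inner-loop adaptation (Eq. 2): pull the student's features toward
    the aggregated expert features computed on the unlabeled support set x_su."""
    # Query the (frozen) domain experts; their parameters are never updated.
    with torch.no_grad():
        expert_feats = torch.stack([m(x_su) for m in experts], dim=1)   # [B, N, D]

    # The aggregator stays in the graph so that phi receives meta-gradients.
    target = aggregator(expert_feats)                                   # [B, D]

    student_out = functional_call(student_feat, theta_e, (x_su,))       # [B, D]
    loss = F.mse_loss(student_out, target)   # squared-L2 surrogate of the L2 distance
    grads = torch.autograd.grad(loss, list(theta_e.values()), create_graph=True)
    return {k: p - alpha * g for (k, p), g in zip(theta_e.items(), grads)}
```

During meta-training this step is differentiated through (hence `create_graph=True`); at test time the same update can be taken without the higher-order graph, and the returned $\theta'_e$ is loaded into the student for inference on the remaining target data.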
The overall learning objective of Meta-DMoE is to minimize the following expected loss:
$$\arg\min_{\theta_e, \theta_c} \sum_{\mathcal{D}_{S_j} \in \mathcal{D}_S} \;\; \sum_{\substack{x_{SU} \in \mathcal{D}_{S_j} \\ (x_Q, y_Q) \in \mathcal{D}_{S_j}}} \mathcal{L}_{CE}\big(y_Q, f(x_Q; \theta'_e, \theta_c)\big), \quad \text{where } \theta'_e = \mathrm{DIST}(x_{SU}, \mathcal{M}_e, \phi, \theta_e), \tag{3}$$
where $\mathcal{L}_{CE}$ is the cross-entropy loss. Alg. 1 shows our full training procedure. To smooth the meta-gradient and stabilize the training, we process a batch of episodes before each meta-update. Since the training domains overlap between MoE pretraining and meta-training, we simulate the test-time out-of-distribution scenario by excluding the corresponding expert model in each episode. To do so, we multiply its features by $0$ to mask them out; $\mathcal{M}'_e$ in line 11 of Alg. 1 denotes this operation. Therefore, the adaptation is enforced to use only the knowledge aggregated from the other domains.
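Putting the pieces together, the sketch below corresponds to one meta-update of Alg. 1 (lines 8-18), including the expert masking of line 11. The episode format, the single inner step, and the assumption that `meta_opt` holds the parameters of the aggregator ($\phi$), the feature extractor ($\theta_e$), and the classifier ($\theta_c$) are our own simplifications rather than the released code.

```python
import torch
import torch.nn.functional as F
from torch.func import functional_call

def meta_train_step(student_feat, classifier, experts, aggregator,
                    episodes, meta_opt, alpha=1e-4):
    """One meta-update over a batch of episodes (lines 8-18 of Alg. 1).
    `episodes` yields (b, x_su, x_q, y_q): the index of the sampled source
    domain, its unlabeled support batch, and its labeled query batch."""
    meta_opt.zero_grad()
    batch_loss = 0.0
    theta_e = dict(student_feat.named_parameters())

    for b, x_su, x_q, y_q in episodes:
        # Query all experts, then zero out the one trained on domain b itself to
        # simulate the test-time out-of-distribution condition (line 11).
        with torch.no_grad():
            feats = [m(x_su) for m in experts]
            feats[b] = torch.zeros_like(feats[b])
            expert_feats = torch.stack(feats, dim=1)                    # [B, N, D]
        target = aggregator(expert_feats)

        # Inner loop (Eq. 2): one distillation step on the feature extractor.
        kd_loss = F.mse_loss(functional_call(student_feat, theta_e, (x_su,)), target)
        grads = torch.autograd.grad(kd_loss, list(theta_e.values()), create_graph=True)
        theta_e_prime = {k: p - alpha * g for (k, p), g in zip(theta_e.items(), grads)}

        # Outer loss (Eq. 3): evaluate the adapted extractor on the labeled query set.
        q_feat = functional_call(student_feat, theta_e_prime, (x_q,))
        batch_loss = batch_loss + F.cross_entropy(classifier(q_feat), y_q)

    batch_loss.backward()   # meta-gradients update phi, theta_e, and theta_c jointly
    meta_opt.step()
```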