
features. We formulate test-time adaptation as unsupervised knowledge distillation [34] that learns knowledge from the MoE. Therefore, we treat M as the teacher and aim to distill its knowledge into a student prediction network f(·; θ) to achieve adaptation. To do so, we sample a batch of unlabeled data x from a target domain and pass it to M to query the domain-specific knowledge {M_i(x)}_{i=1}^N. That knowledge is then forwarded to a knowledge aggregator A(·; φ), which is learned to capture the interconnection among domain knowledge and to aggregate the information from the MoE. The output of A(·; φ) is treated as the supervision signal to update f(x; θ). Once the adapted θ′ is obtained, f(·; θ′) is used to make predictions for the rest of the data in that domain. The overall framework follows the effective few-shot learning paradigm, where x is treated as an unlabeled support set [74, 65, 25].
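One plausible instantiation of the knowledge aggregator A(·; φ), sketched below, is a set encoder that lets the N expert outputs attend to one another before being pooled into a single supervision signal per sample; the transformer-based design and all hyper-parameters here are illustrative assumptions rather than details fixed by this section.

```python
# Sketch of a knowledge aggregator A(.; phi): it ingests the N experts' features
# for a batch and returns one aggregated feature vector per sample.
# The transformer encoder is an assumed design choice for illustration.
import torch.nn as nn

class KnowledgeAggregator(nn.Module):
    def __init__(self, feat_dim=512, n_heads=8, n_layers=2):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=feat_dim, nhead=n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, expert_feats):
        # expert_feats: (batch, N_experts, feat_dim) -- one feature vector per expert
        mixed = self.encoder(expert_feats)   # expert features attend to each other
        return mixed.mean(dim=1)             # (batch, feat_dim): aggregated supervision signal
```

Any permutation-aware architecture over the N expert features would fit this interface; what matters for the rest of the section is only that A(·; φ) maps the set of expert outputs to one supervision signal per sample.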
Algorithm 1 Training for Meta-DMoE
Require: {D_S^i}_{i=1}^N: data of source domains; α, β: learning rates; B: meta batch size
 1: // Pretrain domain-specific MoE models
 2: for i = 1, ..., N do
 3:     Train the domain-specific model M_i using D_S^i.
 4: end for
 5: // Meta-train aggregator A(·; φ) and student model f(·; θ_e, θ_c)
 6: Initialize: φ, θ_e, θ_c
 7: while not converged do
 8:     Sample a batch of B source domains {D_S^b}_{b=1}^B; reset batch loss L_B = 0
 9:     for each D_S^b do
10:         Sample support and query sets: (x^SU), (x^Q, y^Q) ∼ D_S^b
11:         M′_e(x^SU) = {M^i_e(x^SU)}_{i=1}^N, masking M^i_e(x^SU) with 0 if b = i
12:         Perform adaptation via knowledge distillation from MoE:
13:             θ′_e = θ_e − α ∇_{θ_e} ‖A(M′_e(x^SU); φ) − f(x^SU; θ_e)‖_2
14:         Evaluate the adapted θ′_e using the query set and accumulate the loss:
15:             L_B = L_B + L_CE(y^Q, f(x^Q; θ′_e, θ_c))
16:     end for
17:     Update φ, θ_e, θ_c for the current meta batch:
18:         (φ, θ_e, θ_c) ← (φ, θ_e, θ_c) − β ∇_{(φ, θ_e, θ_c)} L_B
19: end while
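For concreteness, the following PyTorch-style sketch mirrors one inner episode of Alg. 1 (L10–L15). It assumes PyTorch ≥ 2.0 for torch.func.functional_call, and the names experts, aggregator, feat_net (θ_e), and classifier (θ_c) are illustrative placeholders rather than the released implementation.

```python
# Sketch of one meta-training episode (Alg. 1, L10-L15), under the assumptions above.
import torch
import torch.nn.functional as F
from torch.func import functional_call

def episode_loss(experts, aggregator, feat_net, classifier,
                 x_su, x_q, y_q, held_out_idx, alpha):
    # Query the frozen experts' feature extractors on the unlabeled support set.
    with torch.no_grad():
        feats = [m(x_su) for m in experts]                        # each: (B, D)
    feats[held_out_idx] = torch.zeros_like(feats[held_out_idx])   # mask expert b (Alg. 1, L11)
    target = aggregator(torch.stack(feats, dim=1))                # A(M'_e(x^SU); phi): (B, D)

    # Inner loop (Eq. 2): one distillation step on theta_e, keeping the graph so
    # that the meta-gradient can later flow through the adaptation.
    theta_e = dict(feat_net.named_parameters())
    dist_loss = torch.norm(target - functional_call(feat_net, theta_e, (x_su,)), p=2)
    grads = torch.autograd.grad(dist_loss, list(theta_e.values()), create_graph=True)
    theta_e_prime = {k: p - alpha * g for (k, p), g in zip(theta_e.items(), grads)}

    # Outer objective: cross-entropy of the adapted student on the labeled query set.
    logits = classifier(functional_call(feat_net, theta_e_prime, (x_q,)))
    return F.cross_entropy(logits, y_q)
```

Summing B such episode losses into L_B and taking one optimizer step with learning rate β over (φ, θ_e, θ_c) reproduces L16–L18; create_graph=True is what allows the meta-gradient to flow back through the inner adaptation step.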
Training Meta-DMoE. Properly training (θ, φ) is critical to improve generalization on unseen domains. In our framework, A(·; φ) acts as a mechanism that explores and mixes the knowledge from multiple source domains. The conventional knowledge distillation process requires large numbers of data samples and learning iterations [34, 2]; such repetitive large-scale training is inapplicable in real-world applications. To mitigate these challenges, we follow the meta-learning paradigm [25]. Such bi-level optimization enforces A(·; φ) to learn beyond any specific knowledge [85] and allows the student prediction network f(·; θ) to achieve fast adaptation. Specifically, we first split the data samples in each source domain D_S^i into disjoint support and query sets. The unlabeled support set (x^SU) is used to perform adaptation via knowledge distillation, while the labeled query set (x^Q, y^Q) is used to evaluate the adapted parameters and thus explicitly test the generalization on unseen data.
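One simple way to construct such an episode from a source-domain dataset is sketched below; the dataset interface and split sizes are assumptions for illustration only.

```python
# Sketch: sample a disjoint unlabeled support set and labeled query set from one
# source domain, assumed to be a torch.utils.data.Dataset of (image, label) pairs.
import torch
from torch.utils.data import Subset

def sample_episode(domain_dataset, n_support=64, n_query=64):
    # The random permutation guarantees the two sets are disjoint.
    idx = torch.randperm(len(domain_dataset)).tolist()
    support = Subset(domain_dataset, idx[:n_support])
    query = Subset(domain_dataset, idx[n_support:n_support + n_query])

    x_su = torch.stack([x for x, _ in support])   # support labels are discarded (unlabeled)
    x_q = torch.stack([x for x, _ in query])
    y_q = torch.tensor([y for _, y in query])
    return x_su, (x_q, y_q)
```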
The student prediction network f(·; θ) can be decoupled into a feature extractor θ_e and a classifier θ_c. Unsupervised knowledge distillation can be achieved via the softened outputs [34] or intermediate features [84] from M. The former allows the whole student network θ = (θ_e, θ_c) to be adapted, while the latter allows part or all of θ_e to adapt to x, depending on the features utilized. We follow [56] and only adapt θ_e in the inner loop while keeping θ_c fixed. Thus, the adaptation process is achieved by distilling the knowledge via the aggregated features:
\mathrm{DIST}(x^{SU}, M_e, \phi, \theta_e) = \theta'_e = \theta_e - \alpha \nabla_{\theta_e} \left\| A(M_e(x^{SU}); \phi) - f(x^{SU}; \theta_e) \right\|_2, \qquad (2)
where α denotes the adaptation learning rate, M_e denotes the feature extractors of the MoE models, which extract the features before the classifier, and ‖·‖_2 measures the L2 distance. The goal is to obtain an updated θ′_e such that the features extracted by f(x^SU; θ′_e) are closer to the aggregated features.
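At test time (or as a first-order reading of Eq. 2), this update can be applied in place, with only θ_e receiving the gradient step and θ_c left untouched. A minimal sketch, where moe_extractors, aggregator, and feat_net are placeholder PyTorch modules rather than the authors' code:

```python
# Sketch of the DIST step (Eq. 2) applied in place at test time.
import torch

def dist(x_su, moe_extractors, aggregator, feat_net, alpha=1e-3):
    # Aggregated teacher features A(M_e(x^SU); phi); the experts stay frozen.
    with torch.no_grad():
        expert_feats = torch.stack([m_e(x_su) for m_e in moe_extractors], dim=1)
        target = aggregator(expert_feats)

    # L2 feature-matching loss between the aggregated features and f(x^SU; theta_e).
    loss = torch.norm(target - feat_net(x_su), p=2)
    grads = torch.autograd.grad(loss, list(feat_net.parameters()))

    # One gradient step on theta_e only; the classifier theta_c is left untouched.
    with torch.no_grad():
        for p, g in zip(feat_net.parameters(), grads):
            p -= alpha * g                    # theta_e' = theta_e - alpha * grad
    return feat_net                           # now parameterized by theta_e'
```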
The overall learning objective of Meta-DMoE is to minimize the following expected loss:

\arg\min_{\theta_e, \theta_c, \phi} \sum_{D_S^j \in D_S} \; \sum_{\substack{(x^{SU}) \in D_S^j \\ (x^Q, y^Q) \in D_S^j}} L_{CE}\big(y^Q, f(x^Q; \theta'_e, \theta_c)\big), \quad \text{where } \theta'_e = \mathrm{DIST}(x^{SU}, M_e, \phi, \theta_e), \qquad (3)
where L_CE is the cross-entropy loss. Alg. 1 presents our full training procedure. To smooth the meta-gradient and stabilize training, we process a batch of episodes before each meta-update. Since the training domains overlap between MoE pretraining and meta-training, we simulate test-time out-of-distribution conditions by excluding the corresponding expert model in each episode. To do so, we multiply its features by 0 to mask them out; M′_e in L11 of Alg. 1 denotes this operation. Therefore, the adaptation is enforced to use the knowledge aggregated from the other domains.
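Concretely, the masking in L11 only zeroes the held-out expert's slice before aggregation; a minimal sketch, assuming the expert features are stacked along a dedicated expert dimension as in the earlier sketches:

```python
# Sketch: exclude the expert pretrained on the sampled domain b, so adaptation
# must rely on knowledge aggregated from the other domains.
import torch

def mask_in_domain_expert(expert_feats, b):
    # expert_feats: (batch, N_experts, feat_dim); b: index of the sampled source domain.
    masked = expert_feats.clone()
    masked[:, b, :] = 0.0         # the in-domain expert contributes nothing to A(.; phi)
    return masked
```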