Domain Discrepancy Aware Distillation for Model Aggregation in Federated Learning
Shangchao Su, Bin Li, Xiangyang Xue
Fudan University
arXiv:2210.02190v1 [cs.LG] 4 Oct 2022
Abstract
Knowledge distillation has recently become popular as a method of model aggregation on the server for federated learning. It is generally assumed that abundant public unlabeled data are available on the server. In reality, however, there is a domain discrepancy between the server-side dataset and the client domains, which limits the performance of knowledge distillation. How to improve aggregation under such a domain discrepancy setting is still an open problem. In this paper, we first analyze the generalization bound, over the client domains, of the aggregated model produced by knowledge distillation, and then describe two challenges that domain discrepancies bring to the aggregated model: server-to-client discrepancy and client-to-client discrepancy. Following our analysis, we propose an adaptive knowledge aggregation algorithm, FedD3A, based on domain discrepancy aware distillation to lower the bound. FedD3A performs adaptive weighting at the sample level in each round of FL. For each sample in the server domain, only the client models from similar domains are selected to play the teacher role. To achieve this, we show that the discrepancy between a server-side sample and a client domain can be approximately measured using a subspace projection matrix calculated on each client without accessing its raw data. The server can thus leverage the projection matrices from multiple clients to assign weights to the corresponding teacher models for each server-side sample. We validate FedD3A on two popular cross-domain datasets and show that it outperforms the compared methods in both cross-silo and cross-device FL settings.
1 Introduction
Federated Learning (FL) (McMahan et al. 2017; Yang et al. 2019) was proposed for the purpose of privacy protection and has developed rapidly. It allows multiple clients to conduct joint machine learning training without sharing their local datasets. In each round of communication, the server sends the global model (a randomly initialized model in the first round) to the clients. Each client then uses its local data to train a local model starting from the global model, and the server gathers the updated local models from the clients for aggregation. After several rounds of communication, the aggregated model is eventually able to handle the data of multiple clients.
FedAvg (McMahan et al. 2017) is a classical FL algorithm that obtains the aggregated model by directly averaging the local model parameters on the server. However, aggregation methods based on parameter averaging can only be applied when the local models share exactly the same structure, which limits the application scenarios of FL. In some scenarios, such as autonomous driving, the server often holds a large amount of unlabeled data. Therefore, some studies (Guha, Talwalkar, and Smith 2019; Lin et al. 2020; Gong et al. 2021; Sturluson et al. 2021) replace parameter averaging with knowledge distillation (KD). They propose to use the public unlabeled data on the server and leverage the ensemble output of the local models as a teacher to guide the learning of the student (the aggregated model). This approach can produce an aggregated model from heterogeneous local models and also reduce the number of communication rounds required for federated learning, because the server-side training accelerates the convergence of the aggregated model.
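To make this mechanism concrete, the following is a minimal sketch of server-side ensemble distillation under the setting just described. It is an illustration rather than the exact procedure of any cited work; the model list, data loader, temperature T, and optimizer settings are hypothetical placeholders.

```python
import torch
import torch.nn.functional as F

def distill_aggregate(student, client_models, unlabeled_loader,
                      T=3.0, lr=1e-3, device="cuda"):
    """One server-side aggregation pass: the averaged soft predictions of the
    client (teacher) models guide the student (aggregated) model."""
    student.to(device).train()
    for m in client_models:
        m.to(device).eval()
    opt = torch.optim.Adam(student.parameters(), lr=lr)

    for x in unlabeled_loader:          # public unlabeled server data (batches of inputs)
        x = x.to(device)
        with torch.no_grad():
            # Ensemble teacher: average the temperature-softened client outputs.
            teacher_prob = torch.stack(
                [F.softmax(m(x) / T, dim=1) for m in client_models]).mean(0)
        student_logp = F.log_softmax(student(x) / T, dim=1)
        loss = F.kl_div(student_logp, teacher_prob,
                        reduction="batchmean") * (T ** 2)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return student
```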
However, existing KD-based work does not consider the impact of the discrepancy between different domains. As we will point out in Section 3, when there is a discrepancy between the public unlabeled data available on the server and the data in the client domains, the performance of the student (aggregated) model drops significantly. Some works (Zhang et al. 2021; Zhang and Yuan 2021) employ data-free distillation, reconstructing on the server pseudo samples that fit the original client data distributions from the local models. To apply data-free methods to real natural images, the running mean and variance of the batch normalization layers must be obtained, yet some studies (Yin et al. 2020, 2021) show that these statistics may leak the users' training data. In addition, data-free methods incur a large additional computational cost. Therefore, in this paper, we focus on the setting where the server has public unlabeled data. The problem then becomes: how can we extract as much knowledge as possible from the unlabeled data available on the server to maximize the performance of the aggregated model obtained through distillation?
To answer this question, we define a domain as a pair (D, f) consisting of a distribution D on the input data X and a labeling function f, and then analyze the generalization error of the aggregated model obtained by KD on the client domains. We find that when the teacher models (local models) originate from different domains, two main factors affect the performance of the aggregated model: 1) Server-to-Client (S2C) discrepancy. As the domain discrepancy between the server domain and a client domain increases, the performance of the aggregated model on that client decreases. 2) Client-to-Clients (C2C) discrepancy. There is a knowledge conflict among the client models. Since the teacher models are learned from different client domains, a sample in the server domain is likely to receive different predictions from different teacher models, which may severely interfere with each other. Based on these two factors, if we reasonably take into account the similarity between the server domain and the different client domains, and assign the most appropriate teacher models to each sample in the server domain, the upper bound on the generalization error can be reduced.
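As background for the S2C term, the classic single-source bound of (Ben-David et al. 2010) already shows how a domain discrepancy enters the target-domain error. Its standard form is reproduced below purely for reference (this is the textbook bound, not the multi-teacher bound derived in this paper), where ε_S and ε_T are the source and target errors of a hypothesis h, d_{HΔH} is the HΔH-divergence between the two distributions, and λ is the error of the ideal joint hypothesis:

```latex
\epsilon_T(h) \;\le\; \epsilon_S(h) \;+\; \tfrac{1}{2}\, d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T) \;+\; \lambda
```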
Motivated by our analysis, we propose an adaptive knowledge aggregation method based on domain discrepancy aware distillation, called FedD3A. During distillation, we assign independent teacher weights to each server-side sample based on how similar the sample is to the client domains, which reduces the knowledge conflict between different teacher models. In each round, each client extracts features of its local data using the backbone of the global model and then calculates the subspace projection matrix of its local feature space. The server collects the projection matrices from the clients and measures similarity by computing the angles between server-side samples and the local feature spaces, without accessing client data features.
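A minimal sketch of this projection-based similarity is given below. It assumes each client summarizes its local feature matrix by an orthonormal basis of its top-k singular directions; the rank k and the softmax temperature are hypothetical choices for illustration, not values reported in the paper. Only the projection matrix, never the raw features, would leave the client.

```python
import numpy as np

def client_projection_matrix(features, k=32):
    """Client side: features is an (n_samples, d) matrix of backbone features.
    Return the d x d projection matrix onto the span of the top-k right
    singular vectors (the local feature subspace); raw features stay local."""
    _, _, vt = np.linalg.svd(features, full_matrices=False)
    basis = vt[:k].T                      # d x k orthonormal basis
    return basis @ basis.T                # projection matrix P = U U^T

def teacher_weights(server_feature, projections, temperature=0.1):
    """Server side: weight each client (teacher) by how well the server-side
    sample's feature aligns with that client's feature subspace, i.e. the
    cosine of the angle ||P x|| / ||x||."""
    x = server_feature / (np.linalg.norm(server_feature) + 1e-12)
    sims = np.array([np.linalg.norm(P @ x) for P in projections])
    w = np.exp(sims / temperature)
    return w / w.sum()
```

In a full pipeline, such per-sample weights would rescale each teacher's soft predictions inside the distillation loss.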
Overall, our contributions are as follows:
• By analyzing the generalization error of the aggregated model on the client domains, we point out two possible reasons why the performance of distillation-based model aggregation drops when there is a discrepancy between the server's unlabeled data and the client domain data, namely Server-to-Client (S2C) discrepancy and Client-to-Clients (C2C) discrepancy.
• Motivated by our analysis, we propose an aggregation method, FedD3A, based on domain discrepancy aware distillation, which further exploits the potential of the abundant unlabeled data on the server.
• To validate our method, we conduct extensive experiments on several datasets. The results show that, compared with the baselines, our method achieves significant improvements in both cross-silo and cross-device FL settings.
2 Related Work
Federated Learning. FedAvg (McMahan et al. 2017) proposes the federated averaging method. In each round of communication, a group of clients is randomly selected, the current global model is sent to these clients for local training, and the models trained by the clients are then collected by the server. The aggregated model is obtained by averaging the clients' model parameters. Some works (Li et al. 2019; Sahu et al. 2018) analyze the convergence of FedAvg and point out that its performance degrades when the client datasets are non-iid. A lot of research has tried to address the non-iid problem encountered by FedAvg. One line of work (Li et al. 2020a; Karimireddy et al. 2020; Reddi et al. 2020; Wang et al. 2020b,a; Singh and Jaggi 2020; Su, Li, and Xue 2022) attempts to improve the fitting ability of the global aggregated model. Another line of work (Dinh, Tran, and Nguyen 2020; Fallah, Mokhtari, and Ozdaglar 2020; Hanzely et al. 2020; Li et al. 2021) pursues personalized federated learning (pFL), in which clients can train different models with different parameters.
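For reference, the parameter-averaging aggregation of FedAvg can be written in a few lines. The sketch below is a minimal illustration assuming homogeneous model architectures, with clients weighted by local dataset size; the helper name is ours, not from any cited codebase.

```python
import copy
import torch

def fedavg_aggregate(client_state_dicts, client_sizes):
    """Average client model parameters, weighted by local dataset size."""
    total = float(sum(client_sizes))
    avg_state = copy.deepcopy(client_state_dicts[0])
    for key in avg_state:
        avg_state[key] = sum(
            sd[key].float() * (n / total)
            for sd, n in zip(client_state_dicts, client_sizes))
    # Load into the global model with global_model.load_state_dict(avg_state).
    return avg_state
```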
Knowledge Distillation. Knowledge distillation (Hinton, Vinyals, and Dean 2015) is a knowledge transfer approach that was initially proposed for model compression. Usually, a larger model is used as the teacher, and its knowledge is transferred to a smaller student model by letting the student learn the output of the teacher. KD techniques are mainly divided into logits-based distillation (Hinton, Vinyals, and Dean 2015; Li et al. 2017), feature-based distillation (Romero et al. 2014; Huang and Wang 2017; Yim et al. 2017), and relation-based distillation (Park et al. 2019; Liu et al. 2019; Tung and Mori 2019). Some works (Du et al. 2020; Shen, He, and Xue 2019) have also studied multi-teacher distillation: (Du et al. 2020) makes the gradient of the student model close to that of all teacher models via multi-objective optimization, and (Shen, He, and Xue 2019) uses adversarial learning to force the student to learn intermediate features similar to those of multiple teachers. There are also studies (Yin et al. 2020; Lopes, Fenu, and Starner 2017; Chawla et al. 2021) on data-free distillation, which use pseudo samples generated from the teacher model in place of real data.
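The logits-based variant referenced above minimizes a temperature-scaled KL divergence between teacher and student outputs (Hinton, Vinyals, and Dean 2015). A standard form of this objective, with z_t and z_s the teacher and student logits and T the temperature, is:

```latex
\mathcal{L}_{\mathrm{KD}} \;=\; T^{2}\,\mathrm{KL}\!\Big(\mathrm{softmax}(z_t/T)\,\Big\|\,\mathrm{softmax}(z_s/T)\Big)
```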
Federated Learning with Knowledge Distillation. There are several ways of applying KD in FL: 1) The first is to perform distillation on the clients (Yao et al. 2021; Wu et al. 2022; Zhu, Hong, and Zhou 2021), treating the global aggregated model as the teacher. 2) In (Gong et al. 2021, 2022; Sun and Lyu 2020), the server and all the clients share a public unlabeled dataset, and the predictions of the different models on this dataset are transmitted among all parties to perform distillation; this is often used for personalized federated learning. 3) The third way is to use the client models directly as the teachers. (Lin et al. 2020; Guha, Talwalkar, and Smith 2019) assume the server has unlabeled data and use an ensemble of the client models as the teacher, with the average teacher output used to compute the distillation loss. (Sturluson et al. 2021) proposes using median-based scores instead of the average logits of the teacher outputs for distillation. In addition, data-free methods have also been applied on the server: (Zhang et al. 2021; Zhang and Yuan 2021) try to learn a generative model based on the ensemble of client models. However, high-quality pseudo samples rely on the running mean and variance carried by the BN layers of the client models, which may reveal private information (Yin et al. 2020, 2021). Data-free methods also require a large amount of computation, and there is no evidence that they can be applied to tasks other than image classification. Nevertheless, the data-free approach is orthogonal to FedD3A, and its pseudo samples can be used as our training data.
3 Proposed Method
Notations and Analysis
Following the domain adaptation field (Ben-David et al. 2010), in our analysis we consider the binary classification setting.