SAGDA: Achieving $O(\epsilon^{-2})$ Communication Complexity in Federated Min-Max Learning
Haibo Yang
Dept. of ECE
The Ohio State University
Columbus, OH 43210
yang.5952@osu.edu
Zhuqing Liu
Dept. of ECE
The Ohio State University
Columbus, OH 43210
liu.9384@osu.edu
Xin Zhang
Dept. of Statistics
Iowa State University
Ames, IA 50010
xinzhang@iastate.edu
Jia Liu
Dept. of ECE
The Ohio State University
Columbus, OH 43210
liu@ece.osu.edu
Abstract
Federated min-max learning has received increasing attention in recent years thanks to its wide range of applications in various learning paradigms. Similar to conventional federated learning for empirical risk minimization problems, communication complexity also emerges as one of the most critical concerns affecting the future prospects of federated min-max learning. To lower the communication complexity of federated min-max learning, a natural approach is to utilize the idea of infrequent communications (through multiple local updates), the same as in conventional federated learning. However, due to the more complicated inner-outer problem structure in federated min-max learning, theoretical understanding of the communication complexity of federated min-max learning with infrequent communications remains very limited in the literature. This is particularly true for settings with non-i.i.d. datasets and partial client participation. To address this challenge, in this paper, we propose a new algorithmic framework called stochastic sampling averaging gradient descent ascent (SAGDA), which i) assembles stochastic gradient estimators from randomly sampled clients as control variates and ii) leverages two learning rates on both server and client sides. We show that SAGDA achieves a linear speedup in terms of both the number of clients and local update steps, which yields an $O(\epsilon^{-2})$ communication complexity that is orders of magnitude lower than the state of the art. Interestingly, by noting that the standard federated stochastic gradient descent ascent (FSGDA) is in fact a control-variate-free special version of SAGDA, we immediately arrive at an $O(\epsilon^{-2})$ communication complexity result for FSGDA. Therefore, through the lens of SAGDA, we also advance the current understanding on the communication complexity of the standard FSGDA method for federated min-max learning.
1 Introduction
Recently, min-max optimization has drawn considerable attention from the machine learning com-
munity. Compared with conventional minimization problems (e.g., empirical risk minimization),
min-max optimization has a richer mathematical structure, and is thus able to model more sophisticated learning problems arising from newly emerging applications. In particular, the subclass
of nonconvex-concave and nonconvex-PL (Polyak-Łojasiewicz) min-max problems has important
applications in, e.g., AUC (area under the ROC curve) maximization [1,2], adversarial and robust
learning [3, 4], and generative adversarial network (GAN) [5]. The versatility of min-max optimiza-
tion thus sparks intense research on developing efficient min-max algorithms. In the literature, the
family of primal-dual stochastic gradient methods is one of the most popular and efficient approaches.
For example, the stochastic gradient descent ascent (SGDA) method in this family has been shown
effective in centralized (single-machine) learning, both theoretically and empirically. However, as over-parameterized models (e.g., deep neural networks) become increasingly prevalent, learning on a single machine becomes increasingly inefficient. To address this challenge, large-scale distributed learning has emerged as an effective mechanism to accelerate training and has achieved astonishing successes in recent years. Moreover, as more stringent data privacy requirements have arisen in recent years, centralized learning becomes increasingly infeasible due to the prohibition of data collection.
This also motivates the need for distributed learning without sharing raw data. Consequently, there
is a growing need for distributed/federated min-max optimization, such as federated deep AUC
maximization [6, 7], federated adversarial training [8] and distributed/federated GAN [9–11].
Similar to conventional federated learning for minimization problems, federated min-max learning
enjoys benefits of parallelism and privacy, but suffers from high communication costs. One effective
approach to reduce communication costs is to utilize infrequent communications. For example, in
conventional federated learning for minimization problems, the FedAvg algorithm [12] allows each client to perform multiple stochastic gradient descent (SGD) steps to update the local model between two successive communication rounds. Then, local models are sent to and averaged periodically at the server through communications. Although infrequent communication may introduce extra noise due to data heterogeneity, FedAvg can still achieve the same convergence rate as distributed SGD, while having a significantly lower communication complexity. Inspired by the theoretical and empirical success of FedAvg, a natural idea to lower the communication costs of federated min-max optimization is to utilize infrequent communication in the federated version of SGDA. Despite the simplicity of this idea, existing works can only show unsatisfactory convergence rates ($O(1/\sqrt{mT})$ [13] and $O(1/(mKT)^{1/3})$ [14]) for solving non-convex-strongly-concave or non-convex-PL problems by federated SGDA with infrequent communication (here, $m$ is the number of clients, $K$ is the number of local steps, and $T$ is the number of communication rounds). These convergence rates do not match that of the FedAvg method. These unsatisfactory results are due to the fact that federated min-max optimization not only needs to address the same challenges as in conventional federated learning (e.g., data heterogeneity and partial client participation), but also has to handle the more complicated inner-outer problem structure. Thus, a fundamental question in federated min-max optimization is: Can a federated SGDA-type method with infrequent communication provably achieve the same convergence rate and even the highly desirable linear speedup effect for federated min-max problems?
In this paper, we answer this question affirmatively. The main contributions of this paper are
summarized as follows:
• We propose a new algorithmic framework called SAGDA (stochastic sampling averaging gradient descent ascent), which assembles stochastic gradient estimators as control variates and leverages two learning rates on both server and client sides. With these techniques, SAGDA relaxes the restrictive "bounded gradient dissimilarity" assumption, while still achieving the same convergence rate with low communication complexity. We show that SAGDA achieves the highly desirable linear speedup in terms of both the number of clients (even with partial client participation) and local update steps, which yields an $O(\epsilon^{-2})$ communication complexity that is orders of magnitude lower than the state of the art in the literature of federated min-max optimization.
• Interestingly, by noting that the standard federated stochastic gradient descent ascent (FSGDA) is in fact a "control-variate-free" special version of our SAGDA algorithm, we can conclude from our theoretical analysis of SAGDA that FSGDA achieves an $O(1/\sqrt{mKT})$ convergence rate for non-convex-PL problems with full client participation, which further implies the highly desirable linear speedup effect. This improves the state-of-the-art result for FSGDA by a factor of $O(1/(mKT)^{1/6})$ [14] and matches the optimal convergence rate of non-convex FL. Therefore, through the lens of SAGDA, we also advance the current understanding on the communication complexity of the standard FSGDA method for federated min-max learning.
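To see how such a rate translates into the claimed communication complexity, the following back-of-the-envelope derivation may help (we spell it out here for intuition; the particular choice of $K$ below is illustrative, and constants are omitted). A rate of $O(1/\sqrt{mKT})$ on the average squared gradient norm reaches an $\epsilon$-stationary point once

$$\frac{1}{\sqrt{mKT}} \le \epsilon^2 \quad\Longleftrightarrow\quad mKT \ge \epsilon^{-4},$$

so the per-client sample complexity is $KT = O(\epsilon^{-4}/m)$, and taking $K = \epsilon^{-2}/m$ local steps per round yields $T = O(\epsilon^{-2})$ communication rounds, consistent with Table 1.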
Table 1: Number of communication rounds and stochastic gradients per client to reach an $\epsilon$-stationary point ($\|\nabla\Phi\| \le \epsilon$) for federated non-convex-PL min-max learning, denoted as communication and per-client sample complexity, respectively. We omit higher-order terms. Here, $m$ is the number of clients. BGD means bounded gradient dissimilarity, which requires bounded data heterogeneity. SAGDA supports client sampling and does not require the BGD assumption.

| Methods | BGD Assumption | Client Sampling? | Per-Client Sample Complexity | Communication Complexity |
|---|---|---|---|---|
| SGDA | – | – | $O(\epsilon^{-4})$ | – |
| Local SGDA [14] | ✓ | ✗ | $\max\{\epsilon^{-4}, \epsilon^{-6}/m^{2/3}\}$ | $O(\epsilon^{-6}/m)$ |
| (Momentum) Local SGDA [15] | ✓ | ✗ | $O(\epsilon^{-4}/m)$ | $O(\epsilon^{-3})$ |
| CD-MAGE [13] | ✓ | ✓ | $O(\epsilon^{-4}/m)$ | $O(\epsilon^{-4}/m)$ |
| SAGDA (Cor. 1) | Not needed | ✓ | $O(\epsilon^{-4}/m)$ | $O(\epsilon^{-2})$ |
| FSGDA (Cor. 2 & 3) | ✓ | ✓ | $O(\epsilon^{-4}/m)$ | $O(\epsilon^{-2})$ |
The rest of the paper is organized as follows. In Section 2, we review related work. In Section 3, we first introduce SAGDA and its convergence analysis, and then build the connection between SAGDA and FSGDA. We present numerical results in Section 4 and conclude the work in Section 5. Due to space limitation, we relegate all proofs and some experiments to the supplementary material.
2 Related work
1) Federated Learning: In federated learning (FL), the seminal federated averaging (FedAvg) [16] algorithm was first proposed as a heuristic to improve communication efficiency and data privacy, but was later theoretically confirmed to achieve a highly desirable $O(1/\sqrt{mKT})$ convergence rate in FL (implying linear convergence speedup as the number of clients $m$ increases). Since then, many follow-up works have been proposed to achieve the $O(1/\sqrt{mKT})$ convergence rate for i.i.d. datasets [17–23] and non-i.i.d. datasets [24–33]. For a comprehensive survey on FL convergence rates, we refer readers to Section 3 in [34].
2) Min-max Optimization: Min-max optimization has a long history dating back to at least [35, 36]. For non-convex-strongly-concave min-max problems, a simple approach is stochastic gradient descent ascent (SGDA), which performs stochastic gradient descent on the primal variables and stochastic gradient ascent on the dual variables, respectively. It is well known that SGDA achieves an $O(1/\sqrt{T})$ convergence rate [37, 38] for non-convex-strongly-concave min-max problems, matching that of SGD in non-convex optimization. However, in the federated non-convex-strongly-concave setting, studies in [13] and [14] only proved $O(1/\sqrt{mT})$ and $O(1/(mKT)^{1/3})$ convergence rates, respectively. So far, it remains unknown whether federated SGDA could achieve the same desirable convergence rate of $O(1/\sqrt{mKT})$ as FedAvg. In this paper, we show that our SAGDA algorithm and FSGDA (implied by SAGDA) indeed achieve the $O(1/\sqrt{mKT})$ convergence rate, matching that of FedAvg.
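As a point of reference for the federated variants discussed above, a minimal single-machine SGDA step can be sketched as follows. This is a toy illustration only (the objective, which is strongly concave in $y$, as well as the noise level and step sizes, are our own choices, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy objective that is strongly concave in y: f(x, y) = x^T A y - ||y||^2 / 2.
A = rng.standard_normal((5, 5))

def stoch_grads(x, y):
    """Unbiased stochastic gradients of f w.r.t. x and y (additive noise)."""
    gx = A @ y + 0.01 * rng.standard_normal(5)        # noisy grad_x f(x, y)
    gy = A.T @ x - y + 0.01 * rng.standard_normal(5)  # noisy grad_y f(x, y)
    return gx, gy

x, y = rng.standard_normal(5), rng.standard_normal(5)
eta_x, eta_y = 0.05, 0.1
for _ in range(1000):
    gx, gy = stoch_grads(x, y)
    x -= eta_x * gx  # stochastic gradient descent on the primal variable
    y += eta_y * gy  # stochastic gradient ascent on the dual variable
```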
3 Problem statement and algorithm design
We consider a general min-max optimization problem in the federated learning setting as follows:

$$\min_{x \in \mathbb{R}^d} \max_{y \in \mathbb{R}^d} f(x,y) := \min_{x \in \mathbb{R}^d} \max_{y \in \mathbb{R}^d} \frac{1}{M} \sum_{i \in [M]} f_i(x,y), \qquad (1)$$

where $f_i(x,y) := \mathbb{E}_{\xi_i \sim \mathcal{D}_i}[f(x,y,\xi_i)]$ is the local loss function associated with a local data distribution $\mathcal{D}_i$ and $M$ is the number of workers. As in FL, there exist two main challenges in federated min-max optimization: 1) datasets are generated locally at the clients and are generally non-i.i.d., i.e., $\mathcal{D}_i \neq \mathcal{D}_j$ for $i \neq j$; 2) potentially only a subset of clients may participate in each communication round, leading to partial client participation.
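In code, the global objective in (1) is simply an average of the clients' local losses. The following minimal sketch (with hypothetical local-loss callables; the names and signature are ours, purely for illustration) makes this structure explicit:

```python
from typing import Callable, Sequence
import numpy as np

# Hypothetical per-client loss f_i(x, y, xi), where xi is a sample drawn from
# the client's local distribution D_i.
LocalLoss = Callable[[np.ndarray, np.ndarray, np.ndarray], float]

def global_objective(x: np.ndarray, y: np.ndarray,
                     local_losses: Sequence[LocalLoss],
                     samples: Sequence[np.ndarray]) -> float:
    """One-sample Monte Carlo estimate of f(x, y) = (1/M) sum_i f_i(x, y)."""
    M = len(local_losses)
    return sum(f_i(x, y, xi) for f_i, xi in zip(local_losses, samples)) / M
```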
In this paper, we focus on general non-convex-PL min-max problems. Before presenting the algorithms and their convergence analysis, we first state several assumptions.
Assumption 1 (Lipschitz Smoothness). $f_i(x,y)$ is $L_f$-smooth, i.e., there exists a constant $L_f > 0$ such that $\|\nabla f_i(x_1,y_1) - \nabla f_i(x_2,y_2)\|^2 \le L_f^2 (\|x_1 - x_2\|^2 + \|y_1 - y_2\|^2)$, $\forall x_1, x_2, y_1, y_2 \in \mathbb{R}^d$, $\forall i \in [M]$.
Assumption 2 (Polyak-Łojasiewicz (PL) Condition). There exists a constant $\mu > 0$ such that $\forall x, y$, $\|\nabla_y f(x,y)\|^2 \ge 2\mu \big[\max_z f(x,z) - f(x,y)\big]$.
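As an illustrative example (ours, not from the original text): for smooth functions $g$ and $h$, the objective $f(x,y) = g(x) - \frac{\mu}{2}\|y - h(x)\|^2$ satisfies Assumption 2 with constant $\mu$, since

$$\|\nabla_y f(x,y)\|^2 = \mu^2 \|y - h(x)\|^2 = 2\mu \big[\max_z f(x,z) - f(x,y)\big],$$

even though $f$ may be non-convex in $x$.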
Further, we assume that the stochastic gradients with respect to $x$ and $y$ in each local update step at each client are unbiased and have bounded variances.

Assumption 3 (Unbiased Local Stochastic Gradient). Let $\xi_i$ be a random local data sample at client $i$. The local stochastic gradients with respect to $x$ and $y$ are unbiased and have bounded variances:

$$\mathbb{E}[\nabla f_i(x,y,\xi_i)] = \nabla f_i(x,y), \quad \mathbb{E}\|\nabla_x f_i(x,y,\xi_i) - \nabla_x f_i(x,y)\|^2 \le \sigma_x^2,$$
$$\mathbb{E}\|\nabla_y f_i(x,y,\xi_i) - \nabla_y f_i(x,y)\|^2 \le \sigma_y^2,$$

where the expectation is taken over the local distribution $\mathcal{D}_i$.
To analyze the convergence performance of min-max algorithms, we define a surrogate function $\Phi$ for the global minimization as follows: $\Phi(x) := \max_y f(x,y)$. We will use $\Phi$ as a metric to measure the performance of an algorithm on min-max problems, and the goal is to find an approximate stationary point of $\Phi$ efficiently. Then, we can conclude from previous works (see Lemma A.5 in [39] or Lemma 4.3 in [38]) that $\Phi$ is $L$-smooth, where $L := L_f + L_f^2/\mu$.
Definition 1 (Stationarity). For a differentiable function $\Phi$, $z$ is an $\epsilon$-stationary point if $\|\nabla\Phi(z)\| \le \epsilon$.

Definition 2 (Complexity). The communication and per-client sample complexity are defined as the total number of communication rounds and stochastic gradients per client, respectively, needed to achieve an $\epsilon$-stationary point.
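In this setting, $\nabla\Phi(x) = \nabla_x f(x, y^*(x))$ at a maximizer $y^*(x)$ of the inner problem (a standard Danskin-type fact in this literature), so $\|\nabla\Phi\|$ can be estimated numerically by first solving the inner maximization approximately. A minimal sketch under this assumption (the toy bilinear-quadratic instance and step sizes are our own):

```python
import numpy as np

def grad_Phi_estimate(x, grad_x, grad_y, ascent_steps=200, lr=0.1):
    """Estimate grad Phi(x) = grad_x f(x, y*(x)) via inner gradient ascent."""
    y = np.zeros_like(x)
    for _ in range(ascent_steps):
        y += lr * grad_y(x, y)  # approach the inner maximizer y*(x)
    return grad_x(x, y)

# Example: f(x, y) = x @ y - 0.5 * y @ y, so y*(x) = x and Phi(x) = 0.5 ||x||^2.
gx = lambda x, y: y      # grad_x f
gy = lambda x, y: x - y  # grad_y f
x = np.array([3.0, -4.0])
print(np.linalg.norm(grad_Phi_estimate(x, gx, gy)))  # ~5.0 = ||grad Phi(x)||
```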
3.1 The Stochastic Averaging Gradient Descent Ascent (SAGDA) Algorithm
To solve Problem (1), FedAvg could be naturally extended to federated min-max problems by applying SGDA with multiple local update steps on the primal and dual variables, respectively. However, current results [13–15] show that there exist two limitations: 1) limited data heterogeneity is often assumed, e.g., via a bounded gradient dissimilarity assumption; 2) the communication complexity is unsatisfactory. In this paper, we propose the SAGDA (stochastic sampling averaging gradient descent ascent) algorithm, which utilizes the assembly of stochastic gradients from (randomly sampled) clients as control variates to mitigate the effect of data heterogeneity in federated min-max problems. As will be shown later, SAGDA is able to achieve better communication complexity under arbitrary data heterogeneity.
As illustrated in Algorithm 1, SAGDA contains the following two stages:
1. On the Server Side: In each communication round, the server initializes the global model $(x_t, y_t)$ at $t = 0$ or updates the global model accordingly when $t > 0$ (Line 3). Specifically, for $t > 0$, upon the reception of all returned parameters from round $t-1$, the server aggregates them using global learning rates $\eta_{x,g}$ and $\eta_{y,g}$ for $x$ and $y$, respectively (see the sketch after this list). Then, the server samples a subset of clients $S_t$ to participate in the training and broadcasts the current global model $(x_t, y_t)$ to these clients (Line 4). Here, we follow the same common assumption on client participation as in FL: the clients are uniformly sampled without replacement, and a fixed-size subset (i.e., $|S_t| = m$) is chosen in each communication round. A key step here is to construct the control variates ($\bar{v}_x, \bar{v}_y, v_{x,i}, v_{y,i}$) for the server and clients. Afterwards, the primal and dual variables alongside their control variates are transmitted to each participating client $i \in S_t$ (Line 7).
2. On the Client Side: Upon receiving the latest global model $(x_t, y_t)$, each client synchronizes its local model (Line 10). Then, each client performs $K$ local updates for $x$ and $y$ simultaneously (Line 11). Upon the completion of the local computations, the new local model is sent to the server.
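The server-side aggregation with two global learning rates (Line 3 of Algorithm 1 below) can be sketched as follows; this is a minimal illustration of the update rule only, with our own variable names, not the code used in the paper's experiments:

```python
import numpy as np

def server_aggregate(x_prev, y_prev, client_x, client_y, eta_xg, eta_yg):
    """One global update (Line 3): move (x, y) along the average client drift,
    scaled by separate global learning rates for primal and dual variables.

    client_x, client_y: final local iterates x_{t-1,i}^{K+1}, y_{t-1,i}^{K+1}
    returned by the m sampled clients of the previous round (numpy arrays).
    """
    m = len(client_x)
    dx = sum(xi - x_prev for xi in client_x) / m
    dy = sum(yi - y_prev for yi in client_y) / m
    return x_prev + eta_xg * dx, y_prev + eta_yg * dy
```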
We provide two options in SAGDA. First, in each communication round, the clients and server need to respectively obtain control variates ($v_{x,i}, v_{y,i}$) and ($\bar{v}_x, \bar{v}_y$) for "variance reduction" purposes in the primal variable $x$ and dual variable $y$ (Lines 5 and 6). Option I requires each client to maintain the control variates ($v_{x,i}, v_{y,i}$) across rounds locally (Line 12). As a result, ($\bar{v}_x, \bar{v}_y$) are constructed iteratively (Line 5). In Option II, ($v_{x,i}, v_{y,i}$) are instantly calculated via another round of communication, and then ($\bar{v}_x, \bar{v}_y$) are constructed accordingly (Line 6).
Algorithm 1 The Stochastic Averaging Gradient Descent Ascent (SAGDA) Algorithm.
1: for $t = 0, \cdots, T-1$ do
2:   for Server do
3:     Initialize $x_0, y_0$ for $t = 0$, or update the global model from the previous round for $t > 0$:
       $x_t = x_{t-1} + \eta_{x,g} \frac{1}{m} \sum_{i \in S_{t-1}} (x^{K+1}_{t-1,i} - x_{t-1})$,
       $y_t = y_{t-1} + \eta_{y,g} \frac{1}{m} \sum_{i \in S_{t-1}} (y^{K+1}_{t-1,i} - y_{t-1})$.
4:     Randomly sample a subset $S_t$ of clients with $|S_t| = m$.
5:     Option I: Construct the sampling averages $\bar{v}_x, \bar{v}_y$ from the returns of the previous round:
       $\bar{v}_x = \bar{v}_x + \frac{1}{M} \sum_{i \in S_{t-1}} \Delta v_{x,i}$, $\bar{v}_y = \bar{v}_y + \frac{1}{M} \sum_{i \in S_{t-1}} \Delta v_{y,i}$.
6:     Option II: The server sends the current parameters $z_t := (x_t, y_t)$ to the clients in $S_t$ and collects stochastic gradients:
       $v_{x,i} = \nabla_x f_i(z_t, \xi_{t,i})$, $v_{y,i} = \nabla_y f_i(z_t, \xi_{t,i})$,
       $\bar{v}_x = \frac{1}{m} \sum_{i \in S_t} v_{x,i}$, $\bar{v}_y = \frac{1}{m} \sum_{i \in S_t} v_{y,i}$.
7:     Send $(x_t, y_t)$ and $(\bar{v}_x, \bar{v}_y)$ to each client $i \in S_t$.
8:   end for
9:   for each client $i \in S_t$ do
10:    Synchronization: $x^1_{t,i} = x_t$, $y^1_{t,i} = y_t$; receive $\bar{v}_{x,t}, \bar{v}_{y,t}$.
11:    Local updates ($k \in [K]$):
       $x^{k+1}_{t,i} = x^k_{t,i} - \eta_{x,l} v^k_{x,i}$ (cf. Eq. (2) for $v^k_{x,i}$);
       $y^{k+1}_{t,i} = y^k_{t,i} + \eta_{y,l} v^k_{y,i}$ (cf. Eq. (3) for $v^k_{y,i}$).
12:    Option I:
       Calculate $v^0_{x,i} = \nabla_x f_i(z_t, \xi_{t,i})$, $v^0_{y,i} = \nabla_y f_i(z_t, \xi_{t,i})$.
       Send $x^{K+1}_{t,i}$, $y^{K+1}_{t,i}$ and $(\Delta v_{x,i}, \Delta v_{y,i}) = (v^0_{x,i} - v_{x,i}, v^0_{y,i} - v_{y,i})$ to the server.
       Assign $v_{x,i} = v^0_{x,i}$, $v_{y,i} = v^0_{y,i}$.
13:    Option II: Send $x^{K+1}_{t,i}$, $y^{K+1}_{t,i}$ to the server.
14:  end for
15: end for
We note that Option I requires the clients to be stateful and is thus more challenging to implement in cross-device FL [34], while Option II may incur extra communication overhead due to the need for one more communication session, although the total communication size remains the same. In the local computation phase, each participating client performs its local steps (Line 11) based on Eqs. (2) and (3), which can be interpreted as "variance reduction." Here, we use $z^k_{t,i} := (x^k_{t,i}, y^k_{t,i})$ for notational simplicity.
$$v^k_{x,i} = \nabla_x f_i(z^k_{t,i}, \xi^k_{t,i}) - v_{x,i} + \bar{v}_x, \qquad (2)$$
$$v^k_{y,i} = \nabla_y f_i(z^k_{t,i}, \xi^k_{t,i}) - v_{y,i} + \bar{v}_y. \qquad (3)$$
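A minimal sketch of one client's local loop under Eqs. (2) and (3) follows (illustrative Python with our own function names; the stochastic gradient oracles grad_x, grad_y are assumed to be supplied by the client's data pipeline):

```python
def sagda_local_updates(x, y, v_xi, v_yi, vbar_x, vbar_y,
                        grad_x, grad_y, K, eta_xl, eta_yl):
    """K local SAGDA steps at one client (Line 11, Eqs. (2)-(3)).

    v_xi, v_yi: this client's control variates; vbar_x, vbar_y: server averages.
    grad_x(x, y), grad_y(x, y): unbiased stochastic gradient oracles of f_i.
    """
    for _ in range(K):
        vk_x = grad_x(x, y) - v_xi + vbar_x  # Eq. (2): corrected primal direction
        vk_y = grad_y(x, y) - v_yi + vbar_y  # Eq. (3): corrected dual direction
        x = x - eta_xl * vk_x                # descent step on x
        y = y + eta_yl * vk_y                # ascent step on y
    return x, y
```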
In classic variance reduction methods, the key idea is to utilize a full gradient (or an approximation thereof) to reduce the stochastic gradient variance, at the expense of higher computation complexity compared to SGD. Note that, in federated learning, the gradient dissimilarity (due to data heterogeneity) is a critical challenge and more problematic than the stochastic gradient variance. Therefore, we calculate a 2-tuple ($\bar{v}_{x,t}, \bar{v}_{y,t}$) of stochastic gradients from all clients as control variates to mitigate the potential gradient deviation due to data heterogeneity. Note that SAGDA does not require a full gradient calculation at each client. With the help of the local steps in (2) and (3), each client no longer generates large deviations in its local updates, even with arbitrary data heterogeneity. The reason is that the local steps at each client satisfy

$$\nabla_x f_i(z^k_{t,i}, \xi^k_{t,i}) - v_{x,i} = v^k_{x,i} - \bar{v}_x, \qquad \nabla_y f_i(z^k_{t,i}, \xi^k_{t,i}) - v_{y,i} = v^k_{y,i} - \bar{v}_y,$$

and, for small local learning rates, $z^k_{t,i}$ stays close to $z_t$, so the left-hand sides remain small and $v^k_{x,i} \approx \bar{v}_x$, $v^k_{y,i} \approx \bar{v}_y$. In other words, SAGDA mimics mini-batch SGDA in centralized learning by using an approximation of the mini-batch stochastic gradient for the updates. As a result, SAGDA is able to provide a