On-Demand Sampling: Learning Optimally from Multiple Distributions
Nika Haghtalab, Michael I. Jordan, and Eric Zhao
University of California, Berkeley
{nika,jordan,eric.zh}@berkeley.edu
Abstract
Social and real-world considerations such as robustness, fairness, social welfare and multi-agent tradeoffs have given rise to multi-distribution learning paradigms, such as collaborative [9], group distributionally robust [50], and fair federated learning [39]. In each of these settings, a learner seeks to uniformly minimize its expected loss over n predefined data distributions, while using as few samples as possible. In this paper, we establish the optimal sample complexity of these learning paradigms and give algorithms that meet this sample complexity. Importantly, our sample complexity bounds exceed that of learning a single distribution by only an additive factor of n log(n)/ε². This improves upon the best known sample complexity bounds for fair federated learning (by Mohri et al. [39]) and collaborative learning (by Nguyen and Zakynthinou [42]) by multiplicative factors of n and log(n)/ε³, respectively. We also provide the first sample complexity bounds for the group DRO objective of Sagawa et al. [50]. To guarantee these optimal sample complexity bounds, our algorithms learn to sample from data distributions on demand. Our algorithm design and analysis are enabled by our extensions of online learning techniques for solving stochastic zero-sum games. In particular, we contribute stochastic variants of no-regret dynamics that can trade off between players' differing sampling costs.
1 Introduction
Pervasive needs for robustness, fairness, and multi-agent collaboration in learning have given rise to multi-distribution learning paradigms (e.g., [9, 50, 39, 18]). In these settings, we seek to learn a model that performs well on any distribution in a predefined set of interest. For fairness considerations, these distributions
well on any distribution in a predefined set of interest. For fairness considerations, these distributions
may represent heterogeneous populations of different protected or socioeconomic attributes; in robustness
applications, they may capture a learner’s uncertainty regarding the true underlying task; and in multi-agent
collaborative or federated applications, they may represent agent-specific learning tasks. In these applications,
the performance and optimality of a model is measured by its worst test-time performance on a distribution
in the set. We are concerned with this fundamental problem of designing sample-efficient multi-distribution
learning algorithms.
The sample complexity of multi-distribution learning differs from that of learning a single distribution
in several ways. On one hand, varying numbers of samples are required when learning tasks of varying
difficulty. On the other hand, similarity or overlap among learning tasks may obviate the need to sample
from some distributions. This makes the use of a fixed per-distribution sample budget highly inefficient and
suggests that optimal multi-distribution learning algorithms should sample on demand. That is, algorithms
should take additional samples whenever they need them and from whichever data distribution they want
them. On-demand sampling is especially appropriate when some population data is scarce (as in fairness mechanisms in which samples are amended [46]); when the designer can actively perturb datasets towards rare or atypical instances (such as in robustness applications [29, 59]); or when sample sets represent agents' contributions to an interactive multi-agent system [39, 10].
Authors are ordered alphabetically. Correspondence to eric.zh@berkeley.edu.
Problem                      | Sample Complexity            | Thm | Best Previous Result
Collab. Learning UB          | ε⁻²(log(|H|) + n log(n/δ))   | 5.1 | ε⁻⁵ log(1/ε) log(n/δ)(log(|H|) + n) [42]
Collab. Learning LB          | ε⁻²(log(|H|) + n log(n/δ))   | 5.3 | ε⁻¹(log(|H|) + n log(n/δ)) [9]
GDRO/AFL UB                  | ε⁻²(log(|H|) + n log(n/δ))   | 5.1 | ε⁻²(n log(|H|) + n log(n/δ)) [39]
GDRO/AFL UB                  | ε⁻²(D_H + n log(n/δ))        | 6.1 | N/A
  (training error convg.)    | ε⁻²(D_H + n log(n/δ))        | 6.2 | ε⁻² n(log(n) + D_H) (expected convergence only) [50]

Table 1: This table lists upper (UB) and lower bounds (LB) on the sample complexity of learning a model class H on n distributions. For the collaborative learning and agnostic federated learning (AFL) settings, the sample complexity upper bounds refer to the problem of learning a (potentially randomized) model whose expected loss on each distribution is at most OPT + ε, where OPT is the best possible such guarantee. For the GDRO setting, sample complexity refers to learning a deterministic model with expected losses of at most OPT + ε, from a convex compact model space H with a Bregman radius of D_H. Sample complexity bounds for collaborative and agnostic federated learning in existing works extend to VC dimension and Rademacher complexity. Our results also extend to VC dimension under some assumptions.
Blum et al. [9] demonstrated the benefit of on-demand sampling in the collaborative learning setting, when all data distributions are realizable with respect to the same target classifier. This line of work established that learning n distributions with on-demand sampling requires a factor of Õ(log(n)) times the sample complexity of learning a single realizable distribution [9, 13, 42], whereas relying on batched uniform convergence takes Ω̃(n) times more samples than learning a single distribution [9]. However, beyond the realizable setting, the best known multi-distribution learning results fall short of this promise: existing on-demand sample complexity bounds for agnostic collaborative learning have highly suboptimal dependence on ε, requiring Õ(log(n)/ε³) times the sample complexity of agnostically learning a single distribution [42]. On the other hand, agnostic fair federated learning bounds [39] have been studied only for algorithms that sample in one large batch and thus require Ω̃(n) times the sample complexity of a single learning task. Moreover, the test-time performance of some key multi-distribution learning methods, such as group distributionally robust optimization [50], has not been studied from a provable or mathematical perspective before.
In this paper, we give a general framework for obtaining optimal and on-demand sample complexity for three multi-distribution learning settings. Table 1 summarizes our results. All three of these settings consider a set D of n data distributions and a model class H, evaluating the performance of a model h by its worst-case expected loss, max_{D∈D} R_D(h). As a benchmark, they consider the worst-case expected loss of the best model, i.e., OPT = min_{h∈H} max_{D∈D} R_D(h). Notably, all of our sample complexity upper bounds demonstrate only an additive increase of ε⁻² n log(n/δ) over the sample complexity of a single learning task, compared to the multiplicative factor increase required by existing works.
- Collaborative learning of Blum et al. [9]: For agnostic collaborative learning, our Theorem 5.1 gives a randomized and a deterministic model that achieve performance guarantees of OPT + ε and 2·OPT + ε, respectively. Our algorithms have an optimal sample complexity of O(ε⁻²(log(|H|) + n log(n/δ))). This improves upon the work of Nguyen and Zakynthinou [42] in two ways. First, it provides risk bounds of OPT + ε for randomized classifiers, where only 2·OPT + ε was established previously. Second, it improves the upper bound of Nguyen and Zakynthinou [42] by a multiplicative factor of log(n)/ε³. In Theorem 5.3, we give a matching lower bound on this sample complexity, thereby establishing the optimality of our algorithms.
- Group distributionally robust learning (group DRO) of Sagawa et al. [50]: For group DRO, we consider a convex and compact model space H. Our Theorem 6.1 studies a model that achieves an OPT + ε guarantee on the worst-case test-time performance of the model with an on-demand sample complexity of O(ε⁻²(D_H + n log(n/δ))). Our results also imply a high-probability bound for the convergence of group DRO training error that improves upon the (expected) convergence guarantees of Sagawa et al. [50] by a factor of n.
- Agnostic federated learning of Mohri et al. [39]: For agnostic federated learning, we consider a finite class of hypotheses. Our Theorems 5.1 and 6.1 show that on-demand sampling can accelerate the generalization of agnostic federated learning by a factor of n compared to the batch results established by Mohri et al. [39]. Our results also imply matching high-probability bounds with respect to Mohri et al. [39] on the convergence of the training error in the batched setting.
To achieve these results, we frame multi-distribution learning as a stochastic zero-sum game: a maximizing player chooses a weight vector over the data distributions D and a minimizing player chooses a weight vector over the hypotheses H. These two players require different numbers of datapoints in order to estimate their respective payoff vectors. We therefore solve the game using no-regret dynamics, utilizing stochastic mirror descent to optimally trade off the players' asymmetric needs for datapoints. In Section 3, we give an overview of this approach and its technical challenges and contributions. Our results also extend directly to settings with not only multiple data distributions but also multiple loss functions.
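To make these dynamics concrete, the following is a minimal sketch, not the paper's exact algorithm: both players run multiplicative-weights (Hedge) updates on payoffs estimated from single on-demand samples. The oracle interface, step size, and the crudely capped importance-weighted estimate for the distribution player are our own simplifications; the paper's algorithms instead use carefully tuned stochastic mirror descent (and Exp3/ELP-style estimators) to balance the players' asymmetric sample costs.

```python
import numpy as np

def hedge(weights, payoff, eta):
    """Multiplicative-weights update toward larger `payoff` entries."""
    w = weights * np.exp(eta * payoff)
    return w / w.sum()

def no_regret_dynamics(oracles, loss, hypotheses, T, eta=0.1, seed=0):
    """Sketch of stochastic no-regret dynamics for the zero-sum game
    min_{p in Delta(H)} max_{q in Delta(D)} sum_i q_i E_{z ~ D_i}[loss(h, z)].

    oracles[i]() draws one sample z ~ D_i; loss(h, z) lies in [0, 1].
    Returns the time-averaged mixed strategy over `hypotheses`."""
    rng = np.random.default_rng(seed)
    n, m = len(oracles), len(hypotheses)
    q = np.full(n, 1.0 / n)   # max player: weights over distributions
    p = np.full(m, 1.0 / m)   # min player: weights over hypotheses
    p_avg = np.zeros(m)
    for _ in range(T):
        i = rng.choice(n, p=q)        # on-demand sample from the mixture q
        z = oracles[i]()
        ell = np.array([loss(h, z) for h in hypotheses])
        # Min player shifts weight away from hypotheses with high sampled loss.
        p = hedge(p, -ell, eta)
        # Max player shifts weight toward the sampled distribution in
        # proportion to its estimated risk (importance-weighted, crudely
        # capped here; the paper's estimators control this variance properly).
        g = np.zeros(n)
        g[i] = min(float(p @ ell) / q[i], 10.0)
        q = hedge(q, g, eta)
        p_avg += p / T
    return p_avg
```

Sampling a hypothesis from the returned average strategy plays the role of a randomized model; the exact regularizers, step sizes, and update rates used in the paper differ, and are what yield the optimal bounds.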
1.1 Related Work
Many lines of work study multi-distribution learning, but they have evolved independently in separate communities.
Collaborative and agnostic federated learning. Blum, Haghtalab, Procaccia, and Qiao [9] posed the first fully general description of multi-distribution learning, motivated by the application of collaborative PAC learning. The field of collaborative learning is concerned with the learning of a shared machine learning model by multiple stakeholders who each desire a model with low error on their own data distribution. This line of work studies on-demand sample complexity bounds for the setting where stakeholders collect data so as to minimize the error of the worst-off stakeholder [9, 42, 13, 11]. This setting, stated in its full generality, yields the multi-distribution learning problem as presented in this paper. Blum et al. [9] established a log(n) factor blowup for the realizable case. For the general agnostic setting, the best existing sample complexity requires a factor log(n)/ε³ blowup [42]. In comparison, our work establishes a tight additive increase in the sample complexity (which is comparable to a log(n) multiplicative factor blowup with no dependence on ε). A related line of work concerns the strategic considerations of collaborative learning and seeks incentive-aware mechanisms for collecting data in the collaborative learning setting [10].
The field of federated learning focuses on a related motivating application where the goal is to learn a model from data dispersed across multiple devices, but where querying data from each device is expensive [38]. The agnostic federated learning framework of Mohri, Sivek, and Suresh [39] poses (a variant of) the multi-distribution learning objective as a target for federated learning algorithms, and studies it in the offline setting with a data-dependent analysis. Their results involve a blowup of the sample complexity by a factor of n.
Group distributionally robust optimization (Group DRO). Multi-distribution learning also arises in distributionally robust optimization [8] under the name of Group DRO, a class of DRO problems where the distributional uncertainty set is finite [24]. The group DRO literature is motivated by applications where the distributions correspond to deployment domains or protected demographics that a machine learning model should avoid spuriously linking to labels [24, 50, 51]. Although Group DRO, like collaborative learning, is mathematically an instance of multi-distribution learning, prior work on Group DRO focuses on the convergence of training error in offline settings, with a particular focus on deep learning applications. As we discuss later, theoretical aspects of on-demand multi-distribution learning can translate into actionable insights for Group DRO applications.
Multi-group fairness. Multi-distribution learning is also related to the fields of multi-group learning [49, 53] and multi-group fairness [19, 27]. These works study offline learning settings with a single distribution D and implicitly consider distribution D_i to be the conditional distribution on a subset of the support representing group i. In these settings, the learner does not have explicit access to oracles that sample from distributions D_1, ..., D_n and instead uses rejection sampling to collect data from D_1, ..., D_n. As a result, they experience a sub-optimal sample complexity blowup by a factor of n. This blowup may not be obvious at first glance, as these works provide theoretical guarantees for each group in terms of the number of datapoints from that group. Multi-group learning [49, 53] considers a similar problem to multi-distribution learning; by assuming that there exists a hypothesis that is simultaneously ε-optimal on every distribution (an assumption not made in our setting), these works compare their learned hypothesis against the best hypothesis for each individual distribution.
Multi-source domain adaptation. Multi-source domain adaptation, or multi-task learning, is another related line of work that is concerned with using data from multiple training distributions to learn some target distribution, under the assumption that the training and target distributions share some task relatedness [7, 36]. Multi-distribution learning can be framed similarly as using a finite set of training distributions to simultaneously learn the convex hull of the training distributions. Interestingly, the requirement in the multi-distribution setting of learning the entire convex hull obviates the need for the task-relatedness assumptions of multi-source learning.
Stochastic game equilibria. Our approach relates to a line of research on finding min-max equilibria by playing no-regret online algorithms against one another [48, 21, 45, 14, 15]. Online mirror descent (OMD) is a well-studied family of methods that can find approximate minima of convex functions, and also approximate min-max equilibria of convex-concave games, with high probability, using noisy first-order information [47, 40, 23, 6]. We bring these online learning tools to bear on the problem of finding saddle points in robust optimization formulations. The primary technical difference between multi-distribution learning and traditional saddle-point optimization problems is that we have sample access to data distributions instead of noisy local gradients.
Other paradigms. Several other machine learning paradigms also consider learning from multiple distributions. Notably, distributed learning (e.g., [44, 12, 5, 16, 52]) and federated learning (e.g., [32, 31, 38]) consider learning from data that is spread across multiple sources or devices. Classically, both of these settings have focused on minimizing the training or testing error averaged over these devices, using communication-efficient, private, and robust-to-dropout training methods. However, optimizing average performance produces models that can significantly underperform on some data sources, especially when the data is heterogeneously spread across the sources. In comparison, multi-distribution learning paradigms such as collaborative learning [9], agnostic federated learning [39], and Group DRO [50] learn models that perform well on every one of the data sources.
Subsequent work. Haghtalab et al. [22] formalized multicalibration as a type of multi-distribution learning, building on the framework presented in this manuscript. Their work improves upon state-of-the-art multicalibration algorithms by implementing multi-distribution learning game dynamics with online learning algorithms that leverage the structure of calibration losses. Zhang et al. [61] extended the discussion of the sample complexity of Group DRO to settings with data budgets. They also noted an erroneous bandit-to-full-information reduction in an earlier version of this manuscript, which we corrected in a previous version (arXiv v2) with a minor change that employs Exp3 [41] or ELP [1] in place of our earlier reduction. Awasthi et al. [4] presented steps towards characterizing the sample complexity of multi-distribution learning with VC classes; this open problem was recently settled up to log factors by Zhang et al. [62] and Peng [43].
2 Preliminaries
Throughout this manuscript, we use the shorthands x^(1:T) := x^(1), ..., x^(T) and f(·, b) := a ↦ f(a, b). We write ∆(A) to denote the set of probability distributions supported on a set A and ∆_d to denote the probability simplex in R^d. We use ∥·∥_∗ to denote the dual of the norm ∥·∥ and e_i ∈ R^n to denote the i-th standard basis vector. Given a data distribution D supported on the space of datapoints Z, a hypothesis class H, and a loss function ℓ : H × Z → [0, 1], we denote the expected loss (risk) of a hypothesis h ∈ H by R_{D,ℓ}(h) := E_{z∼D}[ℓ(h, z)], writing R_D(h) if ℓ is clear from context.
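For concreteness, the empirical counterpart of this risk, estimated from i.i.d. samples, can be written as the following small helper (our own illustration, not notation from the paper):

```python
def empirical_risk(h, loss, samples):
    """Monte Carlo estimate of R_{D,l}(h) = E_{z ~ D}[l(h, z)],
    given i.i.d. samples z_1, ..., z_k drawn from D."""
    return sum(loss(h, z) for z in samples) / len(samples)
```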
2.1 Multi-Distribution Learning
The goal of multi-distribution learning is to find a hypothesis that uniformly minimizes expected loss across multiple data distributions and loss functions. Importantly, we make no assumptions on the relationships between the data distributions; for example, we do not assume the existence of a hypothesis that is simultaneously optimal for every distribution. Formally, given a set of data distributions D = {D_i}_{i=1}^n, losses L = {ℓ_j}_{j=1}^m, and a hypothesis class H, we say a hypothesis h is ε-optimal for the multi-distribution learning problem (D, L, H) if

    max_{D∈D} max_{ℓ∈L} R_{D,ℓ}(h) ≤ OPT + ε,   where   OPT := min_{h∈H} max_{D∈D} max_{ℓ∈L} R_{D,ℓ}(h).   (1)
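Read operationally, definition (1) compares h against the best worst-case hypothesis. The sketch below spells this out for small finite instances, assuming a callable `risk(h, D, l)` that returns (an estimate of) R_{D,ℓ}(h); enumerating H this way is only feasible for illustration, not part of the paper's algorithms.

```python
def worst_case_risk(h, distributions, losses, risk):
    """max_{D in distributions} max_{l in losses} R_{D,l}(h)."""
    return max(risk(h, D, l) for D in distributions for l in losses)

def is_eps_optimal(h, hypotheses, distributions, losses, risk, eps):
    """Check definition (1) by brute force over a finite hypothesis class."""
    opt = min(worst_case_risk(g, distributions, losses, risk)
              for g in hypotheses)
    return worst_case_risk(h, distributions, losses, risk) <= opt + eps
```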
Throughout this manuscript, we will often assume we are working with smooth and convex loss functions. Formally, we say a multi-distribution learning problem (D, L, H) has smooth convex losses if two conditions are met. First, H is parameterized by a convex compact Euclidean parameter space Θ such that H = {h_θ}_{θ∈Θ}. Second, for the same parameter space Θ, for every loss function ℓ ∈ L and datapoint z ∈ Z, the mapping f : Θ → [0, 1] defined as f(θ) = ℓ(h_θ, z) is convex and 1-smooth; i.e., ∥∇_θ f(θ)∥ ≤ 1 for all θ ∈ Θ. We remark that the assumption of smooth convex losses is a weak one. In fact, we will observe that our results on smooth convex losses easily extend to bounded non-smooth non-convex losses when the hypothesis class H is finite or combinatorially bounded, such as when H has finite VC dimension or Littlestone dimension [33].
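As a hedged illustration of these two conditions, consider a linear predictor with a rescaled squared loss; the specific class and scaling are our own example rather than one used in the paper.

```python
import numpy as np

def f(theta, x, y):
    """f(theta) = ell(h_theta, z) for z = (x, y) and h_theta(x) = <theta, x>.
    Assuming ||theta|| <= 1, ||x|| <= 1, and |y| <= 1, f is convex in theta,
    takes values in [0, 1], and its gradient (theta @ x - y) * x / 2 has
    norm at most 1, as the definition above requires."""
    return (theta @ x - y) ** 2 / 4.0

# Numeric spot-check of the range and gradient-norm bounds.
rng = np.random.default_rng(0)
theta = rng.normal(size=5); theta /= max(np.linalg.norm(theta), 1.0)
x = rng.normal(size=5); x /= max(np.linalg.norm(x), 1.0)
y = 0.5
grad = (theta @ x - y) * x / 2.0
assert 0.0 <= f(theta, x, y) <= 1.0 and np.linalg.norm(grad) <= 1.0
```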
Sample complexity. We are interested in the design of multi-distribution learning algorithms that have sample access to the distributions D_1, ..., D_n and only take a small number of samples from these distributions overall. We formalize this access by defining a set of example oracles, EX(D_1), ..., EX(D_n), where each EX(D_i) returns i.i.d. samples from D_i. We then define the sample complexity of a multi-distribution learning algorithm as the cumulative number of calls it makes to these example oracles in order to find a solution.

We note that a multi-distribution learning algorithm may make these example oracle calls in an adaptive fashion; i.e., choosing which example oracle to call based on the datapoints it received from previous oracle calls. As first noted by Blum et al. [9], this ability to query for samples on demand is critical for achieving efficient multi-distribution learning sample complexities. We also note that multi-distribution learning algorithms can use a set of example oracles to sample from any mixture distribution q ∈ ∆(D); e.g., by first sampling a supporting distribution D_i from the mixture distribution and then calling its example oracle EX(D_i).
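A minimal sketch of this oracle access model follows, with names of our own choosing, including the two-stage trick for sampling from a mixture q ∈ ∆(D):

```python
import random

class ExampleOracle:
    """EX(D_i): returns i.i.d. samples from D_i, counting calls so that
    the cumulative sample complexity can be tracked."""
    def __init__(self, sampler):
        self.sampler = sampler    # zero-argument callable returning z ~ D_i
        self.num_calls = 0
    def __call__(self):
        self.num_calls += 1
        return self.sampler()

def sample_from_mixture(q, oracles):
    """Draw z from the mixture sum_i q_i * D_i: first sample a supporting
    distribution D_i with probability q_i, then call EX(D_i)."""
    i = random.choices(range(len(oracles)), weights=q, k=1)[0]
    return oracles[i]()
```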
2.2 Instances of Multi-Distribution Learning
Multi-distribution learning unifies the problem formulations of collaborative learning [9], agnostic federated learning [39], and group distributionally robust optimization (group DRO) [50]. These problems have each spawned a line of highly influential works but were previously not recognized to be equivalent. We emphasize our view that multi-distribution learning is a particularly useful level of generality at which to study these problems, as it allows for their unified treatment both conceptually and algorithmically.

Collaborative learning. In the collaborative PAC learning model of Blum et al. [9], and its agnostic extensions by Nguyen and Zakynthinou [42], the goal is to learn a hypothesis that guarantees small risk for every distribution in a collection of distributions. These data distributions are usually interpreted as the heterogeneous problem domains faced by multiple participants who are collaborating on data collection; the goal of collaborative learning is to learn a machine learning model that all participants are satisfied with. Collaborative learning is usually studied in a supervised learning setting where datapoints consist of a feature-label pair, i.e., Z = X × Y, and where hypothesis classes H ⊂ Y^X are either finite or combinatorially bounded. Importantly, loss functions are assumed to be bounded in [0, 1], but may be non-smooth and non-convex. Formally, given a set of data distributions, D := {D_1, ..., D_n}, supported on X × Y, a loss