MOCODA: Model-based Counterfactual Data
Augmentation
Silviu Pitis¹, Elliot Creager¹, Ajay Mandlekar², Animesh Garg¹,²
1University of Toronto and Vector Institute, 2NVIDIA
Abstract
The number of states in a dynamic process is exponential in the number of objects,
making reinforcement learning (RL) difficult in complex, multi-object domains.
For agents to scale to the real world, they will need to react to and reason about
unseen combinations of objects. We argue that the ability to recognize and use
local factorization in transition dynamics is a key element in unlocking the power
of multi-object reasoning. To this end, we show that (1) known local structure in
the environment transitions is sufficient for an exponential reduction in the sample
complexity of training a dynamics model, and (2) a locally factored dynamics
model provably generalizes out-of-distribution to unseen states and actions. Know-
ing the local structure also allows us to predict which unseen states and actions this
dynamics model will generalize to. We propose to leverage these observations in a
novel Model-based Counterfactual Data Augmentation (MOCODA) framework.
MOCODA applies a learned locally factored dynamics model to an augmented
distribution of states and actions to generate counterfactual transitions for RL.
MOCODA works with a broader set of local structures than prior work and al-
lows for direct control over the augmented training distribution. We show that
MOCODA enables RL agents to learn policies that generalize to unseen states
and actions. We use MOCODA to train an offline RL agent to solve an out-of-
distribution robotics manipulation task on which standard offline RL algorithms
fail.¹
1 Introduction
Modern reinforcement learning (RL) algorithms have demonstrated remarkable success in several different domains such as games [42, 53] and robotic manipulation [23, 4]. By repeatedly attempting a single task through trial-and-error, these algorithms can learn to collect useful experience and
eventually solve the task of interest. However, designing agents that can generalize in off-task and
multi-task settings remains an open and challenging research question. This is especially true in the
offline and zero-shot settings, in which the training data might be unrelated to the target task, and
may lack sufficient coverage over possible states.
One way to enable generalization in such cases is through structured representations of states,
transition dynamics, or task spaces. These representations can be directly learned, sourced from
known or learned abstractions over the state space, or derived from causal knowledge of the world.
Symmetries present in such representations enable compositional generalization to new configurations
of states or tasks, either by building the structure into the function approximator or algorithm
[28, 58, 15, 43], or by using the structure for data augmentation [3, 33, 51].
In this paper, we extend past work on structure-driven data augmentation by using a locally factored model of the transition dynamics to generate counterfactual training distributions. This enables agents to generalize beyond the support of their original training distribution,
Correspondence to spitis@cs.toronto.edu
¹Visualizations & code available at https://sites.google.com/view/mocoda-neurips-22/
36th Conference on Neural Information Processing Systems (NeurIPS 2022).
Figure 1: Out-of-Distribution Generalization using MOCODA: A US driver can use MOCODA to quickly adapt to driving in the left lane during a UK trip. Their prior experience P_EMP(τ) (top left) contains mostly right-driving experience (e.g., 1, 2) and a limited amount of left-driving experience after renting the car in the UK (e.g., 3). A locally factored model that captures the transition structure (bottom left) allows the agent to accurately sample counterfactual experience from P_MOCODA(τ) (bottom center), including novel left-lane city driving maneuvers (e.g., 4). This enables fast adaptation when learning an optimal policy for the new task (UK driving). Our framework MOCODA draws single-step transition samples from P_MOCODA(τ) given P_EMP(τ) and knowledge of the causal structure; several realizations of this framework are described in Section 4.
including to novel tasks where learning the optimal policy requires access to states never seen in the experience buffer.
Our key insight is that a learned dynamics model that accurately captures local causal structure (a
“locally factored” dynamics model) will predictably exhibit good generalization performance outside
the empirical training distribution. We propose Model-based Counterfactual Data Augmentation
(MOCODA), which generates an augmented state-action distribution where its locally factored
dynamics model is likely to perform well, then applies its dynamics model to generate new transition
data. By training the agent’s policy and value modules on this augmented dataset, they too learn to
generalize well out-of-distribution. To ground this in an example, we consider how a US driver might
use MOCODA to adapt to driving on the left side of the road while on vacation in the UK (Figure 1).
Given knowledge of the target task, we can even focus the augmented distribution on relevant areas
of the state-action space (e.g., states with the car on the left side of the road).
Our main contributions are:

A. Our proposed method, MOCODA, leverages a masked dynamics model for data augmentation in locally factored settings, which relaxes strong assumptions made by prior work on factored MDPs and counterfactual data augmentation.

B. MOCODA allows for direct control of the state-action distribution on which the agent trains; we show that controlling this distribution in a task-relevant way can lead to improved performance.

C. We demonstrate "zero-shot" generalization of a policy trained with MOCODA to states that the agent has never seen. With MOCODA, we train an offline RL agent to solve an out-of-distribution robotics manipulation task on which standard offline RL algorithms fail.
2 Preliminaries
2.1 Background
We model the environment as an infinite-horizon, reward-free Markov Decision Process (MDP), described by a tuple ⟨S, A, P, γ⟩ consisting of the state space, action space, transition function, and discount factor, respectively [52, 57].
Figure 2: Locally Factored Dynamics: The state-action space S × A is divided into local subsets L_1, L_2, L_3, which each have their own factored causal structure G_L. The local transition model P_L is factored according to G_L; e.g., in the example shown, P_L(x_t, y_t, a_t) = [P_x(x_t), P_y(y_t, a_t)].
We use lowercase for generic instances and uppercase for variables (e.g., s ∈ range(S) ⊆ S, though we also abuse notation and write S ∈ S). A task is defined as a tuple ⟨r, P_0⟩, where r : S × A → ℝ is a reward function and P_0 is an initial distribution over S. The goal of the agent given a task is to learn a policy π : S → A that maximizes the value E_{P,π}[Σ_t γ^t r(s_t, a_t)]. Model-based RL is one approach to solving this problem, in which the agent learns a model P_θ of the transition dynamics P. The model is "rolled out" to generate "imagined" trajectories, which are used either for direct planning [11, 8], or as training data for the agent's policy and value functions [56, 20].
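As a minimal sketch of this rollout procedure (the `dynamics_model` and `policy` callables below are hypothetical stand-ins, not the paper's implementation):

```python
# A minimal sketch (not the paper's implementation) of how a learned model
# P_theta is "rolled out" to produce imagined transitions for policy/value
# training. `dynamics_model` and `policy` are hypothetical callables.

def rollout_model(dynamics_model, policy, start_state, horizon=5):
    """Unroll the learned one-step model from a real state for `horizon` steps."""
    imagined, s = [], start_state
    for _ in range(horizon):
        a = policy(s)                  # action from the current policy
        s_next = dynamics_model(s, a)  # one-step model prediction P_theta(s, a)
        imagined.append((s, a, s_next))
        s = s_next                     # continue the rollout inside the model
    return imagined

# Example usage with toy stand-ins:
# transitions = rollout_model(lambda s, a: s + a, lambda s: -0.1 * s, 1.0)
```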
Factored MDPs. A factored MDP (FMDP) is a type of MDP that assumes a globally factored transition model, which can be used to exponentially improve the sample complexity of RL [16, 24, 45]. In an FMDP, states and actions are described by a set of variables {X^i}, so that S × A = 𝒳^1 × 𝒳^2 × ... × 𝒳^n, and each state variable X^i ∈ 𝒳^i (where 𝒳^i is a subspace of S) is dependent on a subset of state-action variables (its "parents" Pa(X^i)) at the prior timestep, X^i ∼ P_i(Pa(X^i)). We call a set {X^j} of state-action variables a "parent set" if there exists a state variable X^i such that {X^j} = Pa(X^i). We say that X^i is a "child" of its parent set Pa(X^i). We refer to the tuple ⟨X^i, Pa(X^i), P_i(·)⟩ as a "causal mechanism".
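To make the FMDP notation concrete, here is a minimal sketch (the variables, parent sets, and conditional samplers are invented for illustration) that represents each causal mechanism ⟨X^i, Pa(X^i), P_i(·)⟩ as a (parent set, sampler) pair and samples a factored transition.

```python
# Minimal sketch of an FMDP-style factored transition. The variables (pos, vel,
# act) and their conditional distributions are invented for illustration.
import numpy as np

rng = np.random.default_rng(0)

# Each causal mechanism: child variable <- (parent set, conditional sampler P_i).
mechanisms = {
    "pos": (("pos", "vel"), lambda p: p["pos"] + p["vel"]),
    "vel": (("vel", "act"), lambda p: 0.9 * p["vel"] + 0.1 * p["act"] + 0.01 * rng.normal()),
}

def factored_step(state_action):
    """Sample each next-state variable X^i ~ P_i(Pa(X^i)) from its own parent set."""
    next_state = {}
    for child, (parents, sampler) in mechanisms.items():
        pa = {name: state_action[name] for name in parents}
        next_state[child] = sampler(pa)
    return next_state

print(factored_step({"pos": 0.0, "vel": 1.0, "act": 0.5}))
```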
Local Causal Models. Because the strict global factorization assumed by FMDPs is rare, recent work on data augmentation for RL and object-oriented RL suggests that transition dynamics might be better understood in a local sense, where all objects may interact with each other over time, but in a locally sparse manner [15, 28, 39]. Our work uses an abridged version of the Local Causal Model (LCM) framework [51], as follows: We assume the state-action space decomposes into a disjoint union of local neighborhoods: S × A = L_1 ⊔ L_2 ⊔ ··· ⊔ L_n. A neighborhood L is associated with its own transition function P_L, which is factored according to its graphical model G_L [29]. We assume no two graphical models share the same structure² (i.e., the structure of G_L uniquely identifies L). Then, analogously to FMDPs, if (s_t, a_t) ∈ L, each state variable X^i_{t+1} at the next time step is dependent on its parents Pa_L(X^i_{t+1}) at the prior timestep, X^i_{t+1} ∼ P^L_i(Pa_L(X^i_{t+1})). We define a mask function M : S × A → {L_i} that maps (s, a) ∈ L to the adjacency matrix of G_L. This formalism is summarized in Figure 2, and differs from FMDPs in that each L has its own factorization.
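The following is a minimal sketch, in the spirit of Figure 2, of a mask function and a locally factored transition; the toy dynamics and the rule deciding when the two components decouple are invented for illustration, not taken from the paper.

```python
# Minimal sketch of a locally factored transition with a mask function M.
# State (x, y), action a. In the sparse local subset, x_{t+1} depends only on
# x_t, and y_{t+1} depends only on (y_t, a_t). The mask returns the local
# adjacency matrix (rows: parents x, y, a; columns: children x', y').
import numpy as np

def mask_fn(x, y, a):
    # Illustrative rule: the two components are decoupled only while far apart.
    if abs(x - y) > 1.0:            # sparse local subset L
        return np.array([[1, 0],
                         [0, 1],
                         [0, 1]])
    return np.ones((3, 2))          # dense fallback: everything may affect everything

def local_transition(x, y, a):
    """Toy local model: P_L(x, y, a) = [P_x(x), P_y(y, a)] when the mask is sparse."""
    adj = mask_fn(x, y, a)
    if adj[1, 0] == 0 and adj[2, 0] == 0:   # x' depends on x alone
        x_next = 0.9 * x                     # P_x(x)
        y_next = y + 0.1 * a                 # P_y(y, a)
    else:                                    # entangled regime: joint model
        x_next = 0.9 * x + 0.05 * y
        y_next = y + 0.1 * a + 0.05 * x
    return x_next, y_next

print(local_transition(3.0, 0.0, 1.0))       # sparse subset
print(local_transition(0.5, 0.0, 1.0))       # entangled subset
```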
Given knowledge of M, the Counterfactual Data Augmentation (CoDA) framework [51] allowed agents to stitch together empirical samples from disconnected causal mechanisms to derive novel
agents to stitch together empirical samples from disconnected causal mechanisms to derive novel
transitions. It did this by swapping compatible components between the observed transitions to
create new ones, arguing that this procedure can generate exponentially more data samples as the
number of disconnected causal components grows. CoDA was shown to significantly improve sample
complexity in several settings, including the offline RL setting and a goal-conditioned robotics control
setting. Because CoDA relied on empirical samples of the causal mechanisms to generate data in
a model-free fashion, however, it required that the causal mechanisms be completely disentangled.
The proposed MOCODA leverages a dynamics model to improve upon model-free CoDA in several
respects: (a) by using a learned dynamics model, MOCODA works with overlapping parent sets,
(b) by modeling the parent distribution, MOCODA allows the agent to control the overall data
distribution, (c) MOCODA demonstrates zero-shot generalization to new areas of the state space,
allowing the agent to solve tasks that are entirely outside the original data distribution.
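For intuition, here is a minimal sketch of the CoDA-style swap described above (the two-component transitions are invented): when the shared local structure says two components do not interact, their (parent, child) samples can be recombined across observed transitions into a novel, dynamically consistent transition.

```python
# Minimal sketch of a CoDA-style swap (illustrative only). Each transition is a
# dict of per-component (component_state_action, component_next_state) pairs;
# mixing components across transitions is only justified when the shared local
# structure says the components do not interact.

def coda_swap(t1, t2, independent_components):
    """Build a counterfactual transition from component "A" of t1 and "B" of t2."""
    if not independent_components:
        return None  # locally entangled: swapping is not justified
    return {"A": t1["A"], "B": t2["B"]}

# Two observed transitions, each with independent components A and B:
t1 = {"A": ((0.0, 0.1), 0.1), "B": ((1.0, -0.2), 0.8)}
t2 = {"A": ((5.0, 0.0), 5.0), "B": ((2.0, 0.3), 2.3)}
print(coda_swap(t1, t2, independent_components=True))  # a never-observed transition
```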
² This assumption is a matter of convenience that makes counting local subspaces in Section 3 slightly easier and simplifies our implementation of the locally factored dynamics model in Section 4. To accommodate cases where subspaces with different dynamics share the same causal structure, one could identify local subspaces using a latent variable rather than the mask itself, which we leave for future work.
2.2 Related Work
RL with Structured Dynamics. A growing literature recognizes the advantages that structure can provide in RL, including both improved sample efficiency [37, 5, 19] and generalization performance [62, 59, 54]. Some of these works involve sparse interactions whose structure changes over time [15, 28], which is similar to and inspires the locally factored setup assumed by this paper. Most existing work focuses on leveraging structure to improve the architecture and generalization capability of the function approximator [62]. Although MOCODA also uses the structure for purposes of improving the dynamics model, our proposed method is among the few existing works that also use the structure for data augmentation [38, 40, 51].
Several past and concurrent works aim to tackle unsupervised object detection [36, 12] (i.e., learning an entity-oriented representation of states, which is a prerequisite for learning the dynamics factorization) and learning the dynamics factorization itself [27, 60]. These are both open problems that are orthogonal to MOCODA. We expect that as solutions for unsupervised object detection and factored dynamics discovery improve, MOCODA will find broader applicability.
RL with Causal Dynamics. Adopting this formalism allows one to cast several important problems within RL as questions of causal inference, such as off-policy evaluation [7, 44], learning baselines for model-free RL [41], and policy transfer [25]. Lu et al. [38] applied SCM dynamics to data augmentation in continuous sample spaces, and discussed the conditions under which the generated transitions are uniquely identifiable counterfactual samples. This approach models state and action variables as unstructured vectors, emphasizing benefit in modeling action interventions for settings such as clinical healthcare where exploratory policies cannot be directly deployed. We take a complementary approach by modeling structure within state and action variables, and our augmentation scheme involves sampling entire causal mechanisms (over multiple state or action dimensions) rather than action vectors only. See Appendix F for a more detailed discussion of how MOCODA sampling relates to causal inference and counterfactual reasoning.
3 Generalization Properties of Locally Factored Models
3.1 Sample Complexity of Training a Locally Factored Dynamics Model
In this subsection, we provide an original adaptation of an elementary result from model-based RL to
the locally factored setting, to show that factorization can exponentially improve sample complexity.
We note that several theoretical works have shown that the FMDP structure can be exploited to obtain
similarly strong sample complexity bounds in the FMDP setting. Our goal here is not to improve
upon these results, but to adapt a small part (model-based generalization) to the significantly more
general locally factored setting and show that local factorization is enough for (1) exponential gains
in sample complexity and (2) out-of-distribution generalization with respect to the empirical joint,
to a set of states and actions that may be exponentially larger than the empirical set. Note that the
following discussion applies to tabular RL, but we apply our method to continuous domains.
Notation. We work with finite state and action spaces (|S|, |A| < ∞) and assume that there are m local subspaces L of size |L|, such that m|L| = |S||A|. For each subspace L, we assume transitions factor into k causal mechanisms {P_i}, each with the same number of possible children, |c_i|, and the same number of possible parents, |Pa_i|. Note that m·Π_i|c_i| = |S| (child sets are mutually exclusive), but m·Π_i|Pa_i| ≥ |S||A| (parent sets may overlap).
Theorem 1. Let n be the number of empirical samples used to train the model of each local causal mechanism, P^L_{i,θ}, at each configuration of parents Pa_i = x. There exists a constant c such that, if

    n ≥ c k² |c_i| log(|S||A|/δ) / ε²,

then, with probability at least 1 − δ, we have:

    max_{(s,a)} ‖P(s, a) − P_θ(s, a)‖_1 ≤ ε.

Sketch of Proof. We apply a concentration inequality to bound the ℓ_1 error for fixed parents and extend this to a bound on the ℓ_1 error for a fixed (s, a) pair. The conclusion follows by a union bound across all states and actions. See Appendix A for details.
To compare to full-state dynamics modeling, we can translate the sample complexity from the per-parent count n to a total count N. Recall that m·Π_i|c_i| = |S|, so that |c_i| = (|S|/m)^{1/k}, and m·Π_i|Pa_i| ≥ |S||A|. We assume a small constant overlap factor v ≥ 1, so that |Pa_i| = v(|S||A|/m)^{1/k}. We need the total number of component visits to be n|Pa_i|km, for a total of n·v(|S||A|/m)^{1/k}·m state-action visits, assuming that parent set visits are allocated evenly, and noting that each state-action visit provides k parent set visits. This gives:

Corollary 1. To bound the error as above, we need

    N ≥ c m k² (|S|²|A|/m²)^{1/k} log(|S||A|/δ) / ε²

total train samples, where we have absorbed the overlap factor v into the constant c.
Comparing this to the analogous bound for full-state model learning (Agarwal et al. [1], Prop. 2.1):

    N ≥ c |S|²|A| log(|S||A|/δ) / ε²,

we see that we have gone from super-linear O(|S|²|A| log(|S||A|)) sample complexity in terms of |S||A|, to the exponentially smaller O(m k² (|S|²|A|/m²)^{1/k} log(|S||A|)).

This result implies that for large enough |S||A| our model must generalize to unseen states and actions, since the number of samples needed (N) is exponentially smaller than the size of the state-action space (|S||A|). In contrast, if it did not generalize, then the sample complexity would be Ω(|S||A|).
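As a rough numerical sanity check of the gap between these two bounds (the sizes |S|, |A|, m, and k below are arbitrary assumptions, and the constants c, ε, and δ are dropped since they appear in both bounds):

```python
# Back-of-the-envelope comparison of the factored vs. full-state bounds above.
# |S|, |A|, m, k are illustrative assumptions; c, epsilon, delta are dropped
# (delta is folded into the log term, which is common to both bounds).
import math

S, A = 10**6, 10**3           # assumed |S| and |A|
k, m = 4, 64                   # assumed number of mechanisms and local subspaces

log_term = math.log(S * A)

N_full = S**2 * A * log_term                                    # O(|S|^2 |A| log|S||A|)
N_fact = m * k**2 * (S**2 * A / m**2) ** (1.0 / k) * log_term   # factored bound

print(f"full-state bound ~ {N_full:.2e}, factored bound ~ {N_fact:.2e}")
print(f"|S||A| = {S * A:.2e}")   # the factored bound is far below |S||A| itself
```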
Remark 3.1. The global factorization property of FMDPs is a strict assumption that rarely holds in reality. Although local factorization is broadly applicable and significantly more realistic than the FMDP setting, it is not without cost. In FMDPs, we have a single subspace (m = 1). In the locally factored case, the number of subspaces m is likely to grow exponentially with the number of factors k, as there are exponentially many ways that k factors can interact. To be more precise, there are 2^{k²} possible bipartite graphs from k nodes to k nodes. Nevertheless, by comparing bases (2 ≪ |S||A|), we see that we still obtain exponential gains in sample complexity from the locally factored approach.
3.2 Training Value Functions and Policies for Out-of-Distribution Generalization
In the previous subsection, we saw that a locally factored dynamics model provably generalizes outside of the empirical joint distribution. A natural question is whether such local factorization can be leveraged to obtain similar results for value functions and policies.
We will show that the answer is yes, but, perhaps counter-intuitively, it is not achieved by directly training the value function and policy on the empirical distribution, as is the case for the dynamics model. The difference arises because learned value functions, and consequently learned policies, involve the long-horizon prediction E_{P,π}[Σ_{t=0}^∞ γ^t r(s_t, a_t)], which may not benefit from the local sparsity of G_L. When compounded over time, sparse local structures can quickly produce an entangled long-horizon structure (cf. the "butterfly effect"). Intuitively, even if several pool balls are far apart and locally disentangled, future collisions are central to planning, and the optimal policy depends on the relative positions of all balls. This applies even if rewards are factored (e.g., rewards in most pool variants) [54].
We note that, although temporal entanglement may be exponential in the branching factor of the unrolled causal graph, it is possible for the long-horizon structure to stay sparse (e.g., k independent factors that never interact, or long-horizon disentanglement between decision-relevant and decision-irrelevant variables [19]). It is also possible that other regularities in the data will allow for good out-of-distribution generalization. Thus, we cannot claim that value functions and policies will never generalize well out-of-distribution (see Veerapaneni et al. [58] for an example where they do). Nevertheless, we hypothesize that exponentially fast entanglement does occur in complex natural systems, making direct generalization of long-horizon predictions difficult.
Out-of-distribution generalization of the policy and value function can be achieved, however, by
leveraging the generalization properties of a locally factored dynamics model. We propose to do this
by generating out-of-distribution states and actions (the “parent distribution”), and then applying our
learned dynamics model to generate transitions that are used to train the policy and value function.
We call this process Model-based Counterfactual Data Augmentation (MOCODA).
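To make the procedure concrete, the following is a minimal sketch of this augmentation loop; `parent_dist`, `factored_model`, `reward_fn`, and `agent` are hypothetical stand-ins for a generative model over the augmented parent distribution, the learned locally factored dynamics model, a task reward relabeler, and an off-the-shelf (offline) RL learner, respectively. Concrete realizations are described in Section 4.

```python
# Minimal sketch of the MOCODA augmentation loop described above. All component
# models are hypothetical stand-ins, not the paper's implementation.

def mocoda_augment(parent_dist, factored_model, reward_fn, n_samples):
    """Generate counterfactual transitions from the augmented parent distribution."""
    augmented = []
    for _ in range(n_samples):
        s, a = parent_dist.sample()            # may lie outside the empirical joint
        s_next = factored_model.predict(s, a)  # locally factored one-step prediction
        r = reward_fn(s, a, s_next)            # relabel reward for the target task
        augmented.append((s, a, r, s_next))
    return augmented

def train_with_mocoda(agent, real_buffer, parent_dist, factored_model, reward_fn,
                      n_aug=100_000, n_steps=50_000):
    """Train the policy/value modules on real plus augmented transitions."""
    aug = mocoda_augment(parent_dist, factored_model, reward_fn, n_aug)
    dataset = list(real_buffer) + aug          # union of empirical and augmented data
    for _ in range(n_steps):
        agent.update(dataset)                  # e.g., an offline RL update step
    return agent
```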