Learning Robust Dynamics through
Variational Sparse Gating
Arnav Kumar Jain1,2, Shivakanth Sujit2,3, Shruti Joshi2,3, Vincent Michalski1,2,
Danijar Hafner4,5, Samira Ebrahimi-Kahou2,3,6

1Université de Montréal, 2Mila - Quebec AI Institute, 3École de technologie supérieure, 4University of Toronto, 5Google Brain, 6CIFAR. Correspondence to Arnav Kumar Jain <arnav-kumar.jain@mila.quebec>.

36th Conference on Neural Information Processing Systems (NeurIPS 2022). arXiv:2210.11698v1 [cs.LG] 21 Oct 2022.
Abstract
Learning world models from their sensory inputs enables agents to plan for actions
by imagining their future outcomes. World models have previously been shown to
improve sample-efficiency in simulated environments with few objects, but have not
yet been applied successfully to environments with many objects. In environments
with many objects, often only a small number of them are moving or interacting
at the same time. In this paper, we investigate integrating this inductive bias of
sparse interactions into the latent dynamics of world models trained from pixels.
First, we introduce Variational Sparse Gating (VSG), a latent dynamics model that
updates its feature dimensions sparsely through stochastic binary gates. Moreover,
we propose a simplified architecture, Simple Variational Sparse Gating (SVSG),
that removes the deterministic pathway of previous models, resulting in a fully
stochastic transition function that leverages the VSG mechanism. We evaluate the
two model architectures in the BringBackShapes (BBS) environment that features
a large number of moving objects and partial observability, demonstrating clear
improvements over prior models.
1 Introduction
Latent dynamics models generate an agent's future states in a compact latent space
without feeding the high-dimensional observations back into the model. They have shown promising
results on various tasks like video prediction (Karl et al., 2016; Kalman, 1960; Krishnan et al., 2015),
model-based Reinforcement Learning (RL) (Hafner et al., 2020; 2021; 2019; Ha and Schmidhuber, 2018),
and robotics (Watter et al., 2015). Generating sequences in the compact latent space reduces
accumulating errors, leading to more accurate long-term predictions. Additionally, the lower
dimensionality leads to a lower memory footprint. Solving tasks in model-based RL involves learning
a world model (Ha and Schmidhuber, 2018) that can predict outcomes of actions, followed by
using these predictions to derive behaviors (Sutton, 1991). Motivated by these benefits, the recently proposed
DreamerV1 (Hafner et al., 2020) and DreamerV2 (Hafner et al., 2021) agents achieved state-of-the-art
results on a wide range of visual control tasks.
Many complex tasks require reliable long-term prediction of dynamics. This is especially true in
partially observable environments, where only a subspace is visible to the agent, and it is usually
necessary to accurately retain information over multiple time steps to solve the task. The Dreamer
agents (Hafner et al., 2020; 2021) employ a Recurrent State-Space Model (RSSM) (Hafner et al.,
2019) comprising a Recurrent Neural Network (RNN). Training RNNs on long sequences is challenging,
as they suffer from optimization problems like vanishing gradients (Hochreiter, 1991; Bengio
et al., 1994). Different ways of applying sparse updates in RNNs have been investigated (Campos
et al., 2017; Neil et al., 2016; Goyal et al., 2019), enabling a subset of state dimensions to be constant
during the update. A sparse update prior can also be motivated by the fact that in the real world, many
factors of variation are constant over extended periods of time. For instance, several objects in a
physical simulation may be stationary until some force acts upon them. Additionally, this is useful in
the partially observable setting where the agent observes a constrained viewpoint and has to keep
track of objects that are not visible for many time steps. In this work, we introduce Variational Sparse
Gating (VSG), a stochastic gating mechanism that sparsely updates the latent states at each step.
The Recurrent State-Space Model (RSSM) (Hafner et al., 2019) was introduced in PlaNet, where the
model state was composed of two paths: an image representation path and a recurrent path. DreamerV1
(Hafner et al., 2020) and DreamerV2 (Hafner et al., 2021) utilized it to achieve state-of-the-art
results in continuous and discrete control tasks (Hafner et al., 2019). While the image representation
path, which is stochastic, accounts for multiple possible future states, the recurrent path
is deterministic to retain information over multiple time steps and facilitate gradient-based
optimization. Hafner et al. (2019) showed that both components are important for solving tasks, where
the stochastic part is more important to account for partial observability of the initial states. By
leveraging the proposed gating mechanism (Variational Sparse Gating (VSG)), we demonstrate that a
purely stochastic model with a single component can achieve competitive results, and call it Simple
Variational Sparse Gating (SVSG). To the best of our knowledge, this is the first work to show
that purely stochastic models achieve competitive performance on continuous control tasks when
compared to leading agents.
Existing benchmarks (Bellemare et al., 2013; Chevalier-Boisvert et al., 2018; Tassa et al., 2018) for RL
do not test the capability of agents under both partial observability and stochasticity. The Atari (Bellemare
et al., 2013) benchmark comprises 55 games, but most of them are deterministic and a lot of
compute is required to train on them. Some tasks in the Atari and Minigrid benchmarks are partially
observable but either lack stochasticity or are hard exploration tasks. Moreover, these benchmarks do not
allow for controlling the factors of variation. We developed a new partially-observable and stochastic
environment, called BringBackShapes (BBS), where the task is to push objects to a predefined goal
area. Solving tasks in BBS requires agents to remember the states of previously observed objects and to
avoid noisy distractor objects. Furthermore, VSG and SVSG outperformed leading model-based and
model-free baselines. We also present studies with varying partial observability and stochasticity to
demonstrate that the proposed agents have better memory for tracking observed objects and are more
robust to increasing levels of noise. Lastly, the proposed methods were also evaluated on existing
benchmarks: DeepMind Control (DMC) (Tassa et al., 2018), DMC with Natural Background (Zhang
et al., 2021; Nguyen et al., 2021b), and Atari (Bellemare et al., 2013). On the existing benchmarks,
the proposed method performed better on tasks with changing viewpoints and sparse rewards.
Our key contributions are summarized as follows:

Variational Sparse Gating: We introduce Variational Sparse Gating (VSG), where the recurrent
states are sparsely updated through a stochastic gating mechanism. A comprehensive empirical
evaluation shows that VSG outperforms baselines on tasks requiring long-term memory.

Simple Variational Sparse Gating: We also propose Simple Variational Sparse Gating (SVSG),
which has a purely stochastic state and achieves competitive results on continuous control tasks
when compared with agents that also use a deterministic component.

BringBackShapes: We developed the BringBackShapes (BBS) environment to evaluate agents
in partially-observable and stochastic settings where these variations can be controlled. Our
experiments show that the proposed agents are more robust to such variations.
2 Variational Sparse Gating
Reinforcement Learning: The visual control task can be formulated as a Partially Observable
Markov Decision Process (POMDP) with discrete time steps $t \in [1, T]$. At each time step, the agent
selects an action $a_t \sim p(a_t \mid o_{\le t}, a_{<t})$ to interact with the environment and receives the next
observation and scalar reward $o_t, r_t \sim p(o_t, r_t \mid o_{<t}, a_{\le t})$. The goal is to learn a policy that
maximizes the expected discounted sum of rewards $\mathrm{E}_p\big[\sum_{t=1}^{T} \gamma^t r_t\big]$, where $\gamma$ is the discount factor.
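As a small illustration of this objective, a discounted return can be computed from a finite reward sequence as follows (a plain Python sketch; the function name and example values are ours, not from the paper):

```python
def discounted_return(rewards, gamma=0.99):
    """Compute sum_{t=1}^{T} gamma^t * r_t for a finite episode."""
    return sum(gamma ** t * r for t, r in enumerate(rewards, start=1))

# Example: three steps of reward with gamma = 0.9.
print(discounted_return([1.0, 0.0, 2.0], gamma=0.9))  # 0.9*1.0 + 0.81*0.0 + 0.729*2.0 = 2.358
```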
Figure 1: (a) World Model: The VSG block takes the previous model state $s_{t-1}$ and action $a_{t-1}$, and
outputs the updated model state $s_t$ at the next step, which is further used to reconstruct the image $\hat{x}_t$ and
reward $\hat{r}_t$. (b) Policy: Comprises an actor to select the optimal action $\hat{a}_t$ and a critic to predict the value $\hat{v}_t$
beyond the planning horizon. The world model is unrolled using the prior model state $\hat{s}_t$, which does
not contain information about the image $x_t$.

Agent: The agent is composed of a world model and a policy (Fig. 1). The world model (Sec. 2.1) encodes a
sequence of observations and actions into latent representations. The agent's behavior (Appendix B)
is derived to maximize expected returns on trajectories generated from the learned world model.
During training, the world model is learned from collected experience, the policy is improved on
trajectories unrolled using the world model, and new episodes are collected by deploying the policy
in the environment. An initial set of episodes is collected using a random policy. As training
progresses, new episodes are collected using the latest policy to further improve the world model.
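This schedule alternates between fitting the world model, improving the behavior in imagination, and collecting new experience. The toy loop below illustrates the schedule only; all components are stand-in stubs rather than the paper's actual world model, policy, or environment:

```python
import random

# Stand-in stubs so the schedule below runs; in practice these are the VSG world
# model, the actor-critic policy, and the environment.
def collect_episode(policy):
    return [(random.random(), random.random()) for _ in range(10)]  # dummy (obs, reward) pairs

class WorldModel:
    def train_step(self, batch): pass            # optimize the world-model loss on a batch of sequences
    def infer_states(self, batch): return batch  # posterior states to start imagination from

class Policy:
    def imagine_and_improve(self, model, states, horizon): pass  # actor-critic update in latent space

world_model, policy = WorldModel(), Policy()
replay_buffer = [collect_episode(policy) for _ in range(5)]      # seed episodes from a random policy

for step in range(100):
    batch = random.choice(replay_buffer)
    world_model.train_step(batch)                                # 1. fit the world model
    starts = world_model.infer_states(batch)
    policy.imagine_and_improve(world_model, starts, horizon=15)  # 2. improve behavior in imagination
    if step % 10 == 0:
        replay_buffer.append(collect_episode(policy))            # 3. collect new experience
```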
2.1 World Model
World Models (Ha and Schmidhuber, 2018) learn to mimic the environment using the collected
experience and facilitate deriving behaviours in the abstract latent space. Given an abstract state of the
world and an action, the model applies the learned transition dynamics to predict the resulting next
state and reward. RSSM (Hafner et al., 2019) was introduced in PlaNet, where the model state is
composed of two paths. The recurrent path consists of an RNN (see Figure 2[a]) and is motivated
by reliable long-term information preservation, while the image representation path samples from a
learned distribution to account for multiple possible futures (Babaeizadeh et al., 2017). In this work,
we introduce Variational Sparse Gating (VSG), where the recurrent path selectively updates a subset
of the latent state dimensions at each step using a stochastic gating network. Sparse updates enable the agent to
have long-term memory and learn robust representations to solve complex tasks.
Model Components: The world model comprises an image encoder, a VSG model, and predictors
for image, discount, and reward. The image encoder generates representations $o_t$ for the observation
$x_t$ using Convolutional Neural Networks (CNNs). The VSG model comprises a recurrent model
equipped with the stochastic gating mechanism to compute the recurrent state $h_t$, which is used to compute
two stochastic image representation states. The posterior representation state $z_t$ is obtained using
the representation model and contains information about the current observation $x_t$. The prior
state $\hat{z}_t$ is obtained from the transition predictor without observing the current observation $x_t$. This is
useful while planning, as sequences are generated in the compact latent space using the output of the
transition predictor. This also results in a lower memory footprint and enables predictions
of thousands of trajectories in parallel on a single GPU. The representation states are sampled from a
known distribution with learned parameters, such as a Gaussian (Hafner et al., 2020) or a Categorical (Hafner
et al., 2021). The concatenation of the outputs of the recurrent and image representation models gives
the compact model state ($s_t = [h_t, z_t]$). The posterior model state is further used to reconstruct the
original image $\hat{x}_t$, predict the reward $\hat{r}_t$, and the discount factor $\hat{\gamma}_t$. The discount factor predicts
the probability that an episode will end. The components of the world model are as follows:
Recurrent model:      $h_t = f_\phi(h_{t-1}, z_{t-1}, a_{t-1})$
Representation model: $z_t \sim q_\phi(z_t \mid h_t, x_t)$
Transition predictor: $\hat{z}_t \sim p_\phi(\hat{z}_t \mid h_t)$
Image predictor:      $\hat{x}_t \sim p_\phi(\hat{x}_t \mid h_t, z_t)$
Reward predictor:     $\hat{r}_t \sim p_\phi(\hat{r}_t \mid h_t, z_t)$
Discount predictor:   $\hat{\gamma}_t \sim p_\phi(\hat{\gamma}_t \mid h_t, z_t)$.     (1)
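A compact sketch of how these components fit together is shown below, using Gaussian latents and a standard GRU cell as a stand-in for the recurrent model (the sparse-gating variant is detailed later); all layer sizes and design details are illustrative assumptions, not the paper's exact configuration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LatentDynamics(nn.Module):
    """Illustrative latent dynamics in the spirit of Eq. (1)."""

    def __init__(self, embed_dim=1024, stoch_dim=32, deter_dim=200, action_dim=6):
        super().__init__()
        self.cell = nn.GRUCell(stoch_dim + action_dim, deter_dim)          # recurrent model f_phi
        self.posterior = nn.Linear(deter_dim + embed_dim, 2 * stoch_dim)   # representation model q_phi
        self.prior = nn.Linear(deter_dim, 2 * stoch_dim)                   # transition predictor p_phi

    def _sample(self, stats):
        mean, std = stats.chunk(2, dim=-1)
        return mean + F.softplus(std) * torch.randn_like(mean)

    def step(self, h, z, action, embed=None):
        h = self.cell(torch.cat([z, action], -1), h)      # h_t from h_{t-1}, z_{t-1}, a_{t-1}
        if embed is None:
            return h, self._sample(self.prior(h))          # prior z_t, no access to the image
        return h, self._sample(self.posterior(torch.cat([h, embed], -1)))  # posterior z_t

dyn = LatentDynamics()
h, z, a = torch.zeros(1, 200), torch.zeros(1, 32), torch.zeros(1, 6)
h, z_prior = dyn.step(h, z, a)  # imagination step using the prior (no image embedding)
```

Calling `step` without an image embedding corresponds to the transition predictor used during imagination, while passing the embedding corresponds to the representation model used when the observation is available.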
Figure 2: Architectures of (a) Recurrent State-Space Model (RSSM), (b) Variational Sparse Gating
(VSG), and (c) Simple Variational Sparse Gating (SVSG), respectively. $\sigma$ and $\tanh$ denote the
sigmoid and tanh non-linear activations, respectively. $W$ and $b$ are the corresponding weights
and biases. The remaining operator nodes denote sampling, vector concatenation, and element-wise multiplication.
M computes $x_t = u_t \odot \tilde{x}_t + (1 - u_t) \odot x_{t-1}$, where $x_t = h_t$ is used for RSSM and VSG,
and $x_t = s_t$ is used for SVSG. B denotes the Bernoulli distribution. $f_p$ and $f_q$ denote the prior and
posterior distributions with learned parameters, respectively (see Appendix I for more details).
Neural Networks: The representation model outputs the posterior image representation state $z_t$
conditioned on the encoding of the image $x_t$ and the recurrent state $h_t$. The transition predictor provides the
prior image representation state $\hat{z}_t$. The image encoding $o_t$ is obtained by passing the image $x_t$
through CNN (LeCun et al., 1989) and Multi-Layer Perceptron (MLP) layers. In VSG, we propose
to modify the Gated Recurrent Unit (GRU) used in RSSM to sparsely update the recurrent state at
each step. The model state $s_t$, which is a concatenation of the recurrent and image representation states,
is passed through several MLP layers to predict the discount and reward, and transposed CNN
layers are used to reconstruct the image. The Exponential Linear Unit (ELU) activation (Clevert et al., 2015) is used
in all components of the world model.
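For concreteness, a possible encoder/decoder pair for 64x64 RGB observations in this style is sketched below; the channel counts, kernel sizes, and the 232-dimensional model state are assumptions for illustration rather than the paper's reported hyperparameters:

```python
import torch
import torch.nn as nn

# Convolutional encoder producing the image embedding o_t, and a transposed-convolution
# decoder mapping the model state s_t = [h_t, z_t] back to an image. ELU activations
# are used throughout, as described above.
encoder = nn.Sequential(
    nn.Conv2d(3, 32, 4, stride=2), nn.ELU(),
    nn.Conv2d(32, 64, 4, stride=2), nn.ELU(),
    nn.Conv2d(64, 128, 4, stride=2), nn.ELU(),
    nn.Conv2d(128, 256, 4, stride=2), nn.ELU(),
    nn.Flatten(),                                 # 64x64x3 image -> 1024-dim embedding
)

decoder = nn.Sequential(
    nn.Linear(232, 1024), nn.Unflatten(1, (1024, 1, 1)),
    nn.ConvTranspose2d(1024, 128, 5, stride=2), nn.ELU(),
    nn.ConvTranspose2d(128, 64, 5, stride=2), nn.ELU(),
    nn.ConvTranspose2d(64, 32, 6, stride=2), nn.ELU(),
    nn.ConvTranspose2d(32, 3, 6, stride=2),       # mean of a unit-variance Gaussian over pixels
)

x = torch.zeros(1, 3, 64, 64)
print(encoder(x).shape)                           # torch.Size([1, 1024])
print(decoder(torch.zeros(1, 232)).shape)         # torch.Size([1, 3, 64, 64])
```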
Sparse Gating: In light of training RNNs to capture long-term dependencies, different ways of
applying sparse updates have been investigated (Campos et al., 2017; Neil et al., 2016; Goyal
et al., 2019), enabling a subset of state dimensions to remain constant during the update. They were
found to alleviate the vanishing gradient problem by effectively reducing the number of sequential
operations (Campos et al., 2017). Discrete gates may also improve long-term memory by avoiding
the gradual change of state values introduced by repeated multiplication with continuous gate values
in standard recurrent architectures. Previous works on sparsely updating hidden states (Campos et al.,
2017; Neil et al., 2016) use a separate layer applied over the outputs of the RNN and do not modify the
RNN itself. In this work, however, we modify the update gate of the GRU (Cho et al., 2014) to take
binary values by sampling from a Bernoulli distribution (Fig. 2[b] shows the architecture).

The input $i_t$ to the recurrent model contains information about the action and is obtained by concatenating
the previous image representation state $z_{t-1}$ and action $a_{t-1}$, followed by passing them through
an MLP layer. Similar to the GRU (Cho et al., 2014), there is a reset and an update gate. The reset gate $v_t$
decides the extent of information flow from the previous recurrent state and the inputs, and the update
gate $u_t$ determines which parts of the recurrent state will be updated. The update gate takes only binary
values, selecting whether each value will be updated or copied from the previous time step. Binary values
are obtained by sampling from a Bernoulli distribution whose probability is computed from the previous
recurrent state $h_{t-1}$ and input $i_t$. Straight-through estimators (Bengio et al., 2013) are used to
propagate gradients backwards during training. The update equations are:
$v_t = \sigma(W_v^T [h_{t-1}, i_t] + b_v)$
$\tilde{u}_t = \sigma(W_u^T [h_{t-1}, i_t] + b_u)$
$\tilde{h}_t = \tanh(v_t \odot (W_c^T [h_{t-1}, i_t] + b_c))$
$u_t \sim \mathrm{Bernoulli}(\tilde{u}_t)$
$h_t = u_t \odot \tilde{h}_t + (1 - u_t) \odot h_{t-1}$,     (2)
where $\odot$ denotes element-wise multiplication, $\sigma$ and $\tanh$ are the sigmoid and hyperbolic tangent
activation functions, and $W$ and $b$ denote the weights and biases, respectively. To control the
sparsity of updates, we use a KL divergence between the probability $\tilde{u}_t$ of sampling the update gate
and a fixed prior probability $\kappa$, where $\kappa$ is a tunable hyperparameter.
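A minimal PyTorch sketch of this cell is given below, assuming a single fused linear layer for the three projections and the standard straight-through trick; it is an illustration of Eq. (2), not the authors' released implementation:

```python
import torch
import torch.nn as nn

class VSGCell(nn.Module):
    """Sparse-gating recurrent cell of Eq. (2): the update gate is a sampled Bernoulli
    mask, with a straight-through estimator for the backward pass."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # One linear layer produces the reset gate, update-gate probabilities, and candidate.
        self.linear = nn.Linear(input_dim + hidden_dim, 3 * hidden_dim)

    def forward(self, i_t, h_prev):
        reset, update, cand = self.linear(torch.cat([h_prev, i_t], -1)).chunk(3, -1)
        v_t = torch.sigmoid(reset)                  # reset gate v_t
        u_prob = torch.sigmoid(update)              # per-dimension update probability u~_t
        h_tilde = torch.tanh(v_t * cand)            # candidate state h~_t (reset applied to candidate)
        u_hard = torch.bernoulli(u_prob)            # binary update decisions u_t
        # Straight-through estimator: forward uses the hard sample, gradients flow via u_prob.
        u_t = u_hard + u_prob - u_prob.detach()
        h_t = u_t * h_tilde + (1.0 - u_t) * h_prev  # untouched dimensions are copied from h_{t-1}
        return h_t, u_prob                          # u_prob is also needed for the sparsity loss

cell = VSGCell(input_dim=38, hidden_dim=200)        # i_t encodes z_{t-1} and a_{t-1} after an MLP
h, u_prob = cell(torch.zeros(1, 38), torch.zeros(1, 200))
```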
Loss function: The predictors for image and reward produce Gaussian distributions with unit
variance, whereas the discount predictor produces a Bernoulli likelihood. The image representation
states are sampled from a Gaussian (Hafner et al., 2020) or a Categorical (Hafner et al., 2021)
distribution, which are trained to maximize the likelihood of the targets. In addition, there is a KL
divergence term between the prior and posterior distributions and, similar to DreamerV2 (Hafner
et al., 2021), we also use KL balancing with a factor of 0.8. We further add a sparsity loss
to regularize the number of hidden state dimensions updated at each step. All components of the world
model are optimized jointly using the loss function given by:

$\mathcal{L}(\phi) \doteq \mathrm{E}_{q_\phi(z_{1:T} \mid a_{1:T}, x_{1:T})}\Big[\sum_{t=1}^{T}
\underbrace{-\ln p_\phi(x_t \mid h_t, z_t)}_{\text{image log loss}}
\underbrace{-\ln p_\phi(r_t \mid h_t, z_t)}_{\text{reward log loss}}
\underbrace{-\ln p_\phi(\gamma_t \mid h_t, z_t)}_{\text{discount log loss}}
+ \underbrace{\beta\,\mathrm{KL}\big[q_\phi(z_t \mid h_t, x_t) \,\|\, p_\phi(z_t \mid h_t)\big]}_{\text{KL loss}}
+ \underbrace{\alpha\,\mathrm{KL}\big[\tilde{u}_t \,\|\, \kappa\big]}_{\text{sparsity loss}}\Big],     (3)

where $\beta$ and $\alpha$ are the scales for the KL losses of the latent codes and the sparse update gates, respectively.
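The two KL terms can be computed as sketched below (a hedged PyTorch illustration: variable names, shapes, and the default scales are assumptions, a Gaussian latent parameterization is assumed, and only the KL terms with balancing are shown, not the full objective):

```python
import torch
import torch.distributions as td

def kl_losses(post_mean, post_std, prior_mean, prior_std, u_prob,
              kappa=0.1, beta=1.0, alpha=1.0, balance=0.8):
    def normal(mean, std):
        return td.Independent(td.Normal(mean, std), 1)

    # KL balancing: the prior is pulled toward the stopped-gradient posterior with weight
    # `balance`, and the posterior toward the stopped-gradient prior with weight 1 - balance.
    kl_prior = td.kl_divergence(normal(post_mean.detach(), post_std.detach()),
                                normal(prior_mean, prior_std))
    kl_post = td.kl_divergence(normal(post_mean, post_std),
                               normal(prior_mean.detach(), prior_std.detach()))
    kl_latent = balance * kl_prior + (1.0 - balance) * kl_post

    # Sparsity loss: KL between Bernoulli(u_prob) and a fixed Bernoulli prior with rate kappa,
    # summed over the gate dimensions.
    kl_sparsity = td.kl_divergence(td.Bernoulli(probs=u_prob),
                                   td.Bernoulli(probs=torch.full_like(u_prob, kappa))).sum(-1)

    return beta * kl_latent.mean() + alpha * kl_sparsity.mean()

# Example call: a batch of 8 latents of size 32 and 200 update-gate probabilities.
loss = kl_losses(torch.zeros(8, 32), torch.ones(8, 32),
                 torch.zeros(8, 32), torch.ones(8, 32),
                 torch.full((8, 200), 0.3))
```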
3 Simple Variational Sparse Gating
Stochastic State-Space Models (SSMs) were proposed in PlaNet (Hafner et al., 2019), where it was
observed that it is not trivial to achieve competitive results without the deterministic recurrent path.
The deterministic component was motivated by allowing the transition model to retain information
over multiple time steps, as the stochastic component induces variance (Hafner et al., 2019). In this
work, we show that a purely stochastic model achieves comparable performance to
DreamerV2 while significantly outperforming SSMs (refer to Appendix H for more details). We
introduce a simplified version of VSG, called Simple Variational Sparse Gating (SVSG), where the
world model has a single path that both preserves information over multiple steps and
accounts for partial observability of future states (Fig. 2[c] presents the SVSG architecture).
Model Components: In SVSG, there is no recurrent model, and the posterior state $s_t$ is obtained
using the representation model by conditioning on the previous state $s_{t-1}$, the input image $x_t$, and the
action $a_t$. Similar to VSG, there is a transition predictor that returns the prior state $\hat{s}_t$, which does
not use the current image observation, to imagine trajectories in the latent space. Both modules
sparsely update the model state at each step using the stochastic gating mechanism proposed in VSG.
We use a Gaussian distribution for the stochastic state with a learnable mean vector and a
learnable diagonal covariance matrix. Similar to VSG, the posterior state is used to reconstruct the
image and to predict the reward and discount factor. The components of the world model in SVSG are:
Representation model: $s_t \sim q_\phi(s_t \mid s_{t-1}, x_t, a_t)$
Transition predictor: $\hat{s}_t \sim p_\phi(\hat{s}_t \mid s_{t-1}, a_t)$
Image predictor:      $\hat{x}_t \sim p_\phi(\hat{x}_t \mid s_t)$
Reward predictor:     $\hat{r}_t \sim p_\phi(\hat{r}_t \mid s_t)$
Discount predictor:   $\hat{\gamma}_t \sim p_\phi(\hat{\gamma}_t \mid s_t)$.     (4)
The representation model $q_\phi$ and transition predictor $p_\phi$ are modified to output the posterior $s_t$ and
prior $\hat{s}_t$ states, respectively. The reset gate $v_t$ and the update gate $\tilde{u}_t$ are calculated using the previous
state $s_{t-1}$ and the input $i_t$, which contains information about the action $a_t$. The candidate state $\tilde{s}_t$ at each
step is obtained using the input $i_t$, reset gate $v_t$, and previous state $s_{t-1}$. Similar to VSG, the update gate $u_t$
is sampled from a Bernoulli distribution to sparsely update the latent states at each step, given by:
$v_t = \sigma(W_v^T [s_{t-1}, i_t] + b_v)$
$\tilde{u}_t = \sigma(W_u^T [s_{t-1}, i_t] + b_u)$
$\tilde{s}_t = \tanh(v_t \odot (W_c^T [s_{t-1}, i_t] + b_c))$
$u_t \sim \mathrm{Bernoulli}(\tilde{u}_t)$,     (5)

where $\odot$ denotes element-wise multiplication, $\sigma$ and $\tanh$ are the sigmoid and hyperbolic tangent
activation functions, and $W$ and $b$ denote the weights and biases, respectively.
The candidate state $\tilde{s}_t$ is fed through MLP layers to obtain the prior and posterior distributions. The
image encoding of $x_t$ is used to compute the posterior distribution, whereas the prior distribution is predicted
without it. The prior $\hat{s}_t$ and posterior $s_t$ states are then sampled from these distributions.
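A sketch of one plausible reading of this transition is shown below: the gated candidate is mapped to Gaussian prior or posterior parameters, a state is sampled, and the Bernoulli update gate decides per dimension whether to keep the previous state. Layer layout, sizes, and the exact composition of sampling and gating are assumptions for illustration, not the authors' implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SVSGCell(nn.Module):
    """Illustrative purely stochastic SVSG transition in the spirit of Eqs. (4)-(5)."""

    def __init__(self, input_dim, embed_dim, state_dim):
        super().__init__()
        self.gates = nn.Linear(input_dim + state_dim, 3 * state_dim)      # reset, update, candidate
        self.prior_head = nn.Linear(state_dim, 2 * state_dim)             # no image encoding
        self.post_head = nn.Linear(state_dim + embed_dim, 2 * state_dim)  # uses the image encoding

    def forward(self, s_prev, i_t, embed=None):
        reset, update, cand = self.gates(torch.cat([s_prev, i_t], -1)).chunk(3, -1)
        v_t = torch.sigmoid(reset)
        u_prob = torch.sigmoid(update)
        s_tilde = torch.tanh(v_t * cand)                                   # candidate state s~_t
        if embed is None:
            stats = self.prior_head(s_tilde)                               # transition predictor (prior)
        else:
            stats = self.post_head(torch.cat([s_tilde, embed], -1))       # representation model (posterior)
        mean, std = stats.chunk(2, -1)
        sample = mean + (F.softplus(std) + 0.1) * torch.randn_like(mean)   # Gaussian with diagonal covariance
        u_hard = torch.bernoulli(u_prob)
        u_t = u_hard + u_prob - u_prob.detach()                            # straight-through estimator
        s_t = u_t * sample + (1.0 - u_t) * s_prev                          # sparse update of the model state
        return s_t, u_prob

cell = SVSGCell(input_dim=38, embed_dim=1024, state_dim=230)
s, u_prob = cell(torch.zeros(1, 230), torch.zeros(1, 38))  # prior (imagination) step without the image
```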