A Unified Framework for Alternating Offline
Model Training and Policy Learning
Shentao Yang1, Shujian Zhang1, Yihao Feng2, Mingyuan Zhou1
1The University of Texas at Austin 2Salesforce Research
shentao.yang@mccombs.utexas.edu szhang19@utexas.edu
yihaof@salesforce.com mingyuan.zhou@mccombs.utexas.edu
Abstract
In offline model-based reinforcement learning (offline MBRL), we learn a dynamic
model from historically collected data, and subsequently utilize the learned model
and fixed dataset for policy learning, without further interacting with the envi-
ronment. Offline MBRL algorithms can improve the efficiency and stability of
policy learning over model-free algorithms. However, in most of the existing
offline MBRL algorithms, the learning objectives for the dynamic models and the
policies are isolated from each other. Such an objective mismatch may lead to
inferior performance of the learned agents. In this paper, we address this issue
by developing an iterative offline MBRL framework, where we maximize a lower
bound of the true expected return, by alternating between dynamic-model training
and policy learning. With the proposed unified model-policy learning framework,
we achieve competitive performance on a wide range of continuous-control offline
reinforcement learning datasets. Source code is publicly released.
1 Introduction
Offline reinforcement learning (offline RL) [1, 2], where the agents are trained from static and pre-collected datasets, avoids direct interactions with the underlying real-world environment during the learning process. Unlike traditional online RL, whose success largely depends on simulator-based trial-and-error [e.g., 3, 4], offline RL enables training policies for real-world applications where it is infeasible or even risky to collect online experimental data, such as robotics, advertisement, or dialog systems [e.g., 5–8]. Though promising, it remains challenging to train agents under the offline setting, due to the discrepancy between the distribution of the offline data and the state-action distribution induced by the current learning policy. With such a discrepancy, directly transferring standard online off-policy RL methods [e.g., 9–11] to the offline setting tends to be problematic [12, 13], especially when the offline data cannot sufficiently cover the state-action space [14]. To tackle this issue in offline RL, recent works [e.g., 15–17] propose to approximate the policy-induced state-action distribution by leveraging a learned dynamic model to draw imaginary rollouts. These additional synthetic rollouts help mitigate the distributional discrepancy and stabilize the policy-learning algorithms under the offline setting.
Most of the prior offline model-based RL (MBRL) methods [e.g., 16, 18–21], however, first pretrain a one-step forward dynamic model via maximum likelihood estimation (MLE) on the offline dataset, and then use the learned model to train the policy, without further improving the dynamic model during the policy learning process. As a result, the objective function used for model training (e.g., MLE) and the objective of model utilization are unrelated to each other. Specifically, the model is trained to be “simply a mimic of the world,” but is used to improve the performance of the learned policy [22–24]. Though such a training paradigm is historically rooted [25, 26], this issue of objective mismatch between model training and model utilization has been identified as problematic in recent
works [23, 24]. In offline MBRL, this issue is exacerbated, since the learned model can hardly be globally accurate, due to the limited amount of offline data and the complexity of the control tasks.
Motivated by the objective-mismatch issue, we develop an iterative offline MBRL method, alternating between training the dynamic model and the policy to maximize a lower bound of the true expected return. This lower bound, leading to a weighted MLE objective for the dynamic-model training, is relaxed to a tractable regularized objective for the policy learning. To train the dynamic model by the proposed objective, we need to estimate the marginal importance weights (MIW) between the offline-data distribution and the stationary state-action distribution of the current policy [27, 28]. This estimation tends to be unstable with standard approaches [e.g., 29, 30], which require saddle-point optimization. Instead, we propose a simple yet stable fixed-point-style method for MIW estimation, which can be directly incorporated into our alternating training framework. With these considerations, our method, offline Alternating Model-Policy Learning (AMPL), performs competitively on a wide range of continuous-control offline RL datasets in the D4RL benchmark [31]. These empirical results and ablation studies show the efficacy of our proposed algorithmic designs.
2 Background
Markov decision process and offline RL.
A Markov decision process (MDP) is denoted by $\mathcal{M} = (\mathcal{S}, \mathcal{A}, P, r, \gamma, \mu_0)$, where $\mathcal{S}$ is the state space, $\mathcal{A}$ the action space, $P(s' \mid s,a): \mathcal{S}\times\mathcal{A}\times\mathcal{S} \to [0,1]$ the environmental dynamic, $r(s,a): \mathcal{S}\times\mathcal{A} \to [-r_{\max}, r_{\max}]$ the reward function, $\gamma \in [0,1)$ the discount factor, and $\mu_0(s): \mathcal{S} \to [0,1]$ the initial state distribution.
For any policy $\pi(a \mid s)$, we denote its state-action distribution at timestep $t \ge 0$ as $d^P_{\pi,t}(s,a) \triangleq \Pr\left(s_t = s, a_t = a \mid s_0 \sim \mu_0, a_t \sim \pi, s_{t+1} \sim P, \forall\, t \ge 0\right)$. The (discounted) stationary state-action distribution of $\pi$ is denoted as $d^P_{\pi,\gamma}(s,a) \triangleq (1-\gamma)\sum_{t=0}^{\infty} \gamma^t d^P_{\pi,t}(s,a)$, abbreviated as $d^P_{\pi}$ when the discounting is clear from context.

Denote $Q^P_{\pi}(s,a) = \mathbb{E}_{\pi, P}\left[\sum_{t=0}^{\infty} \gamma^t r(s_t, a_t) \mid s_0 = s, a_0 = a\right]$ as the action-value function of policy $\pi$ under the dynamic $P$. The goal of RL is to find a policy $\pi$ maximizing the expected return
$$J(\pi, P) \triangleq (1-\gamma)\, \mathbb{E}_{s\sim\mu_0,\, a\sim\pi(\cdot \mid s)}\big[Q^P_{\pi}(s,a)\big] = \mathbb{E}_{(s,a)\sim d^P_{\pi,\gamma}}\big[r(s,a)\big]. \quad (1)$$
In offline RL, the policy $\pi$ and critic $Q^P_{\pi}$ are typically approximated by parametric functions $\pi_{\phi}$ and $Q_{\theta}$, respectively, with parameters $\phi$ and $\theta$. The critic $Q_{\theta}$ is trained by the Bellman backup
$$\arg\min_{\theta}\, \mathbb{E}_{(s,a,r,s')\sim \mathcal{D}_{\mathrm{env}}}\Big[\big(Q_{\theta}(s,a) - r(s,a) - \gamma\, \mathbb{E}_{a'\sim\pi_{\phi}(\cdot \mid s')}[Q_{\theta'}(s',a')]\big)^2\Big], \quad (2)$$
where $Q_{\theta'}$ is the target network [12, 13]. The actor $\pi_{\phi}$ is trained in the policy-improvement step by
$$\arg\max_{\phi}\, \mathbb{E}_{s\sim\mathcal{D}_{\mathrm{env}},\, a\sim\pi_{\phi}(\cdot \mid s)}\big[Q_{\theta}(s,a)\big], \quad (3)$$
where $\mathcal{D}_{\mathrm{env}}$ denotes the offline dataset drawn from $d^P_{\pi_b}$ [2, 32], with $\pi_b$ being the behavior policy.
Offline model-based RL.
In offline model-based RL algorithms, the true environmental dynamic $P$ is typically approximated by a parametric function $\widehat{P}(s' \mid s,a)$ in some function class $\mathcal{P}$. With the offline dataset $\mathcal{D}_{\mathrm{env}}$, $\widehat{P}$ is trained via MLE [15, 16, 18] as
$$\arg\max_{\widehat{P}\in\mathcal{P}}\, \mathbb{E}_{(s,a,s')\sim\mathcal{D}_{\mathrm{env}}}\big[\log \widehat{P}(s' \mid s,a)\big]. \quad (4)$$
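For concreteness, below is a minimal sketch of a single diagonal-Gaussian dynamic model trained by the MLE objective of Eq. (4); the architecture and clamping range are illustrative assumptions (the paper uses an ensemble of such networks, see Section 3.4).

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

class GaussianDynamics(nn.Module):
    """P_hat(s'|s,a) as a diagonal Gaussian over the next state."""
    def __init__(self, s_dim, a_dim, hidden=200):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(s_dim + a_dim, hidden), nn.SiLU(),
                                 nn.Linear(hidden, hidden), nn.SiLU())
        self.mu = nn.Linear(hidden, s_dim)
        self.log_std = nn.Linear(hidden, s_dim)

    def dist(self, s, a):
        h = self.net(torch.cat([s, a], dim=-1))
        log_std = self.log_std(h).clamp(-5.0, 2.0)   # keep the variance in a sane range
        return Normal(self.mu(h), log_std.exp())

def mle_loss(model, s, a, s_next):
    """Negative log-likelihood, i.e. the (negated) objective in Eq. (4)."""
    return -model.dist(s, a).log_prob(s_next).sum(-1).mean()
```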
Similarly, the reward function can be approximated by a parametric model $\widehat{r}$ if assumed unknown. With $\widehat{P}$ and $\widehat{r}$, the true MDP $\mathcal{M}$ can be approximated by $\widehat{\mathcal{M}} = (\mathcal{S}, \mathcal{A}, \widehat{P}, \widehat{r}, \gamma, \mu_0)$. We further define $d^P_{\pi}(s,a)$ as the stationary state-action distribution induced by $\pi$ on $P$ (or MDP $\mathcal{M}$), and $d^{\widehat{P}}_{\pi}(s,a)$ as that on the learned dynamic $\widehat{P}$ (or MDP $\widehat{\mathcal{M}}$). We approximate $d^P_{\pi_{\phi}}$ by simulating $\pi_{\phi}$ on $\widehat{\mathcal{M}}$ for a short horizon $h$, starting from states $s \in \mathcal{D}_{\mathrm{env}}$, as in prior work [e.g., 16, 18, 19, 21]. The resulting transitions are stored in a replay buffer $\mathcal{D}_{\mathrm{model}}$, constructed similarly to off-policy RL [33, 9]. To better approximate $d^P_{\pi_{\phi}}$, sampling from $\mathcal{D}_{\mathrm{env}}$ in Eqs. (2) and (3) is commonly replaced by sampling from the augmented dataset $\mathcal{D} = f\, \mathcal{D}_{\mathrm{env}} + (1-f)\, \mathcal{D}_{\mathrm{model}}$, $f \in [0,1]$, denoting sampling from $\mathcal{D}_{\mathrm{env}}$ and $\mathcal{D}_{\mathrm{model}}$ with probabilities $f$ and $1-f$, respectively. We follow Yu et al. [18] to use $f = 0.5$.
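The branched short-horizon rollout and the $f$-mixed sampling described above can be sketched as follows; the buffer interfaces (`sample_states`, `add`, `sample`) and the `n_starts`/`horizon` values are hypothetical placeholders rather than the paper's exact implementation.

```python
import numpy as np

def branched_rollout(model, reward_model, policy, d_env, d_model, n_starts=400, horizon=5):
    """Roll pi_phi on the learned model M_hat for h steps, starting from offline states."""
    s = d_env.sample_states(n_starts)                 # s drawn from D_env
    for _ in range(horizon):
        a = policy.act(s)
        s_next = model.sample(s, a)                   # s' ~ P_hat(.|s,a)
        r = reward_model(s, a)
        d_model.add(s, a, r, s_next)                  # store synthetic transitions in D_model
        s = s_next

def sample_mixed(d_env, d_model, batch_size, f=0.5):
    """Sample from D = f*D_env + (1-f)*D_model, assuming each buffer returns a list of transitions."""
    n_env = np.random.binomial(batch_size, f)         # each sample comes from D_env w.p. f
    return d_env.sample(n_env) + d_model.sample(batch_size - n_env)
```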
3 Offline alternating model-policy learning
Our goal is to derive the objectives for both dynamic-model training and policy learning from a principled perspective. A natural idea is to build a tractable lower bound for $J(\pi, P)$, the expected return of the policy $\pi$ under the true dynamic $P$, and then alternate between training the policy $\pi$ and the dynamic model $\widehat{P}$ to maximize this lower bound. Indeed, we can construct a lower bound as
$$J(\pi, P) \ge J(\pi, \widehat{P}) - \big|J(\pi, P) - J(\pi, \widehat{P})\big|, \quad (5)$$
where $J(\pi, \widehat{P})$ is the expected return of policy $\pi$ under the learned model $\widehat{P}$. From the right-hand side (RHS) of Eq. (5), if the policy evaluation error $|J(\pi, P) - J(\pi, \widehat{P})|$ is small, $J(\pi, \widehat{P})$ will be a good proxy for the true expected return $J(\pi, P)$. We can empirically estimate $J(\pi, \widehat{P})$ via $\widehat{P}$ and $\pi$.

Further, if a tractable upper bound for $|J(\pi, P) - J(\pi, \widehat{P})|$ can be constructed, it can serve as a unified training objective for both the dynamic model $\widehat{P}$ and the policy $\pi$. We can then alternate between optimizing the dynamic model $\widehat{P}$ and the policy $\pi$ to maximize the lower bound of $J(\pi, P)$, i.e., simultaneously minimizing the upper bound of the evaluation error $|J(\pi, P) - J(\pi, \widehat{P})|$. This gives us an iterative, maximization-maximization algorithm for model and policy learning.
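At a high level, the resulting maximization-maximization scheme can be summarized by the loop below; the subroutine names and update frequencies are illustrative assumptions, and the paper's precise schedule is given in its Algorithm 1.

```python
def alternating_training(n_epochs, model, policy, omega, d_env, d_model):
    """Alternate model training and policy learning to maximize the lower bound in Eq. (5)."""
    model.fit_mle(d_env)                           # warm start: standard MLE, Eq. (4)
    for epoch in range(n_epochs):
        omega.update(d_env, policy)                # estimate the MIW (Section 3.3)
        model.fit_weighted_mle(d_env, omega)       # model step: minimize Eq. (6)
        branched_rollout(model, model.reward, policy, d_env, d_model)  # refresh D_model
        policy.update(d_env, d_model, model)       # policy step: maximize the Eq. (7) proxy
    return policy
```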
The following theorem indicates a tractable upper bound for $|J(\pi, P) - J(\pi, \widehat{P})|$, which can be subsequently relaxed for model training and policy learning.

Theorem 1. Let $P$ be the true dynamic and $\widehat{P}$ be the approximate dynamic model. Suppose the reward function satisfies $|r(s,a)| \le r_{\max}$; then we have
$$\big|J(\pi, P) - J(\pi, \widehat{P})\big| \le \frac{\gamma\, r_{\max}}{2(1-\gamma)} \cdot \sqrt{D_{\pi}(P, \widehat{P})},$$
$$\text{with } D_{\pi}(P, \widehat{P}) \triangleq \mathbb{E}_{(s,a)\sim d^P_{\pi_b}}\Big[\omega(s,a)\, \mathrm{KL}\big(P(s' \mid s,a)\, \pi_b(a' \mid s')\, \big\|\, \widehat{P}(s' \mid s,a)\, \pi(a' \mid s')\big)\Big],$$
where $\pi_b$ is the behavior policy, $d^P_{\pi_b}$ is the offline-data distribution, and $\omega(s,a) \triangleq \frac{d^P_{\pi,\gamma}(s,a)}{d^P_{\pi_b}(s,a)}$ is the marginal importance weight (MIW) between the offline-data distribution and the stationary state-action distribution of the policy $\pi$ [27, 29].
Detailed proof of Theorem 1 can be found in Appendix B.2.
The KL term in $D_{\pi}(P, \widehat{P})$ indicates the following two principles for model and policy learning:

• For the dynamic model, $s' \sim \widehat{P}(\cdot \mid s,a)$ should be close to the true next state $\tilde{s}' \sim P(\cdot \mid s,a)$, with the $(s,a)$ pairs drawn from the stationary state-action distribution of the policy $\pi$. Since we cannot directly draw samples from $d^P_{\pi}$, we reweight the offline data with $\omega(s,a)$. This leads to a weighted KL-minimization objective for the model training.

• For the policy $\pi$, the KL term indicates a regularization term: the tuple $(s', a')$ from the joint conditional distribution $\widehat{P}(s' \mid s,a)\, \pi(a' \mid s')$ should be close to the tuple $(\tilde{s}', \tilde{a}')$ from $P(\tilde{s}' \mid s,a)\, \pi_b(\tilde{a}' \mid \tilde{s}')$; $(\tilde{s}', \tilde{a}')$ is simply a sample from the offline dataset.
Based on the above observations, we can fix $\pi$ and train the dynamic model $\widehat{P}$ by minimizing $D_{\pi}(P, \widehat{P})$ w.r.t. $\widehat{P}$. Similarly, we can fix the dynamic model $\widehat{P}$ and learn a policy $\pi$ to maximize the lower bound of $J(\pi, P)$. This alternating training scheme provides a unified approach for model and policy learning. In the following sections, we discuss how to optimize the dynamic model $\widehat{P}$, the policy $\pi$, and the MIW $\omega$ under our alternating training framework.
3.1 Dynamic model training
Expanding the KL term in $D_{\pi}(P, \widehat{P})$, we have
$$D_{\pi}(P, \widehat{P}) = \underbrace{\mathbb{E}_{(s,a,s',a')\sim d^P_{\pi_b}}\big[\omega(s,a)\big(\log P(s' \mid s,a) + \log \pi_b(a' \mid s') - \log \pi(a' \mid s')\big)\big]}_{\triangleq\, ①} \;-\; \mathbb{E}_{(s,a,s')\sim d^P_{\pi_b}}\big[\omega(s,a)\, \log \widehat{P}(s' \mid s,a)\big],$$
where the tuple $(s,a,s',a')$ is simply two consecutive state-action pairs in the offline dataset. Further, if the policy $\pi$ is fixed, the term ① is a constant w.r.t. $\widehat{P}$. Thus, given the MIW $\omega$, we can optimize $\widehat{P}$ by minimizing the following loss
$$\ell(\widehat{P}) \triangleq -\mathbb{E}_{(s,a,s')\sim d^P_{\pi_b}}\big[\omega(s,a)\, \log \widehat{P}(s' \mid s,a)\big], \quad (6)$$
which is an MLE objective weighted by $\omega(s,a)$. We discuss how to estimate $\omega(s,a)$ in Section 3.3.
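Given estimates of the MIW, Eq. (6) amounts to reweighting the per-transition log-likelihood. A minimal sketch, reusing the hypothetical `GaussianDynamics` model above and assuming an `omega_net` that outputs non-negative weights, is shown below; the self-normalization of the weights is an illustrative stabilization choice, not necessarily the paper's.

```python
import torch

def weighted_mle_loss(model, omega_net, s, a, s_next):
    """Eq. (6): omega-weighted negative log-likelihood of the observed next state."""
    with torch.no_grad():
        w = omega_net(s, a)                        # omega(s,a), treated as fixed during the model step
        w = w / w.mean().clamp(min=1e-8)           # optional self-normalization for stability
    log_prob = model.dist(s, a).log_prob(s_next).sum(-1)
    return -(w * log_prob).mean()
```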
3.2 Policy learning
The lower bound for $J(\pi, P)$ implied by Theorem 1 is
$$J(\pi, \widehat{P}) - \frac{\gamma\, r_{\max}}{2(1-\gamma)} \cdot \sqrt{D_{\pi}(P, \widehat{P})}, \quad (7)$$
where $J(\pi, \widehat{P})$ can be estimated via the action-value function, similar to standard offline MBRL algorithms [e.g., 16, 18, 19]. Thus, when the dynamic model $\widehat{P}$ is fixed, the main difficulty is to estimate the regularizer $D_{\pi}(P, \widehat{P})$ for the policy $\pi$.

When the policy $\pi$ is Gaussian, direct estimation of $D_{\pi}(P, \widehat{P})$ is possible. Empirically, however, it is helpful to learn the policy $\pi$ in the class of implicit distributions, which is a richer distribution class and can better maximize the action-value function. Specifically, given a noise distribution $p_z(z)$, an action is generated as $a = \pi_{\phi}(s, z)$ with $z \sim p_z(\cdot)$, where $\pi_{\phi}$ is a deterministic network.
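Such an implicit policy is simply a deterministic network applied to the state concatenated with noise; a minimal sketch (the noise dimension, architecture, and `max_action` scaling are illustrative assumptions):

```python
import torch
import torch.nn as nn

class ImplicitPolicy(nn.Module):
    """a = pi_phi(s, z) with z ~ p_z: a sample-only (implicit) action distribution."""
    def __init__(self, s_dim, a_dim, z_dim=32, hidden=256, max_action=1.0):
        super().__init__()
        self.z_dim, self.max_action = z_dim, max_action
        self.net = nn.Sequential(nn.Linear(s_dim + z_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, hidden), nn.ReLU(),
                                 nn.Linear(hidden, a_dim), nn.Tanh())

    def forward(self, s):
        z = torch.randn(s.shape[0], self.z_dim, device=s.device)   # z ~ N(0, I)
        return self.max_action * self.net(torch.cat([s, z], dim=-1))
```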
Unfortunately, we cannot directly estimate the KL term in $D_{\pi}(P, \widehat{P})$ if $\pi$ is an implicit policy, since we can only draw samples from $\pi$ but its density is unknown. A potential solution is to use the dual representation of the KL divergence, $\mathrm{KL}(p\,\|\,q) = \sup_{T} \mathbb{E}_{p}[T] - \log(\mathbb{E}_{q}[e^{T}])$ [34], which can be estimated with samples from the distributions $p$ and $q$. However, the exponential function therein makes the estimation unstable in practice [30]. We instead use the dual representation of the Jensen–Shannon divergence (JSD) to approximate $D_{\pi}(P, \widehat{P})$, which can be approximately minimized using the GAN structure [35, 36]. Our framework can thus utilize the many stabilization techniques developed in the GAN community (Appendix E.2.1).
Besides, we remove the MIW $\omega(s,a)$ during policy training, since we do not observe its empirical benefits, as will be discussed in Section 4.2. Further applying the replacement of KL with JSD and ignoring the $\sqrt{\cdot}$ for numerical stability, we get an approximate new regularizer for the policy $\pi$:
$$\widetilde{D}_{\pi}(P, \widehat{P}) \triangleq \mathrm{JSD}\big(P(s' \mid s,a)\, \pi_b(a' \mid s')\, d^P_{\pi_b}(s,a)\ \big\|\ \widehat{P}(s' \mid s,a)\, \pi(a' \mid s')\, d^P_{\pi_b}(s)\, \pi(a \mid s)\big). \quad (8)$$
Informally speaking, Eq. (8) regularizes the imaginary rollouts of $\pi$ on $\widehat{P}$ towards state-action pairs from the offline dataset. Intuitively, $\widetilde{D}_{\pi}(P, \widehat{P})$ is a more effective regularizer for policy training than the original $D_{\pi}(P, \widehat{P})$, since $\widetilde{D}_{\pi}(P, \widehat{P})$ regularizes action choices at both $s$ and $s'$. Appendix B.3 discusses how we move from $D_{\pi}(P, \widehat{P})$ to $\widetilde{D}_{\pi}(P, \widehat{P})$ in detail.
3.3 Marginal importance weight training
A number of methods have been recently proposed to estimate the marginal importance weight $\omega$ [27, 29, 30]. These methods typically require solving a complex saddle-point optimization, casting doubts on their training stability, especially when combined with policy learning on continuous-control offline MBRL problems. In this section, we mimic the Bellman backup to derive a fixed-point-style method for estimating the MIW.
Denote the true MIW as $\omega(s,a) \triangleq \frac{d^P_{\pi,\gamma}(s,a)}{d^P_{\pi_b}(s,a)}$; we have $d^P_{\pi_b}(s,a)\cdot\omega(s,a) = d^P_{\pi,\gamma}(s,a)$. Expanding the RHS, $\forall\, s', a'$,
$$d^P_{\pi_b}(s',a')\, \omega(s',a') = \gamma \sum_{s,a} \pi(a' \mid s')\, P(s' \mid s,a)\, \omega(s,a)\, d^P_{\pi_b}(s,a) + (1-\gamma)\, \mu_0(s')\, \pi(a' \mid s'). \quad (9)$$
The derivation is deferred to Appendix B.4. Therefore, a “Bellman equation” for $\omega(s',a')$ is $\omega(s',a') = \mathcal{T}\omega(s',a')$, with
$$\mathcal{T}\omega(s',a') \triangleq \frac{\gamma \sum_{s,a} \pi(a' \mid s')\, P(s' \mid s,a)\, \omega(s,a)\, d^P_{\pi_b}(s,a) + (1-\gamma)\, \mu_0(s')\, \pi(a' \mid s')}{d^P_{\pi_b}(s',a')}.$$
Here $\mathcal{T}$ can be viewed as the “Bellman operator” for $\omega$. The update iterate defined by $\mathcal{T}$ has the following convergence property, which is proved in Appendix B.5.

Proposition 2. On a finite state-action space, if the current policy $\pi$ is close to the behavior policy $\pi_b$, then the iterate for $\omega$ defined by $\mathcal{T}$ converges geometrically.

The assumption that $\pi$ is close to $\pi_b$ coincides with the regularization term in the policy-learning objective discussed in Section 3.2.
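On a small tabular example, the operator $\mathcal{T}$ can be iterated directly to illustrate the geometric convergence stated in Proposition 2; the randomly generated MDP and the mixture construction that keeps $\pi$ close to $\pi_b$ are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
nS, nA, gamma = 5, 3, 0.95

P = rng.dirichlet(np.ones(nS), size=(nS, nA))                 # P[s, a, s']
mu0 = rng.dirichlet(np.ones(nS))                              # initial state distribution
pi_b = rng.dirichlet(np.ones(nA), size=nS)                    # behavior policy
pi = 0.9 * pi_b + 0.1 * rng.dirichlet(np.ones(nA), size=nS)   # pi close to pi_b

# Occupancies d^P_{pi_b,gamma} and d^P_{pi,gamma} by matrix inversion (flattened (s,a) index).
M_b = np.einsum('sap,pb->sapb', P, pi_b).reshape(nS * nA, nS * nA)
M_pi = np.einsum('sap,pb->sapb', P, pi).reshape(nS * nA, nS * nA)
d_b = np.linalg.solve(np.eye(nS * nA) - gamma * M_b.T,
                      (1 - gamma) * (mu0[:, None] * pi_b).reshape(-1))
omega_star = np.linalg.solve(np.eye(nS * nA) - gamma * M_pi.T,
                             (1 - gamma) * (mu0[:, None] * pi).reshape(-1)) / d_b

omega = np.ones(nS * nA)                                      # initial guess
for k in range(50):
    # T omega: (gamma * sum_{s,a} pi P omega d_b + (1-gamma) mu0 pi) / d_b
    omega = (gamma * M_pi.T @ (omega * d_b)
             + (1 - gamma) * (mu0[:, None] * pi).reshape(-1)) / d_b
    if k % 10 == 0:
        print(k, np.max(np.abs(omega - omega_star)))          # error shrinks geometrically
```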
Unfortunately, the RHS of Eq. (9) is not estimable, since we do not know the density values therein. We therefore multiply both sides of Eq. (9) by some test function and subsequently sum over $(s',a')$ on both sides, to get a tractable objective that only requires samples from the offline dataset and the initial state-distribution $\mu_0$. It is desirable to choose a test function that can better distinguish the difference between the left-hand side (LHS) and the RHS of Eq. (9). A potential choice is the action-value function of the policy $\pi$, due to the primal-dual relationship between the stationary state-action density (ratio) ($d^P_{\pi}(s,a)$ or $\omega(s,a)$) and the action-value function [37–39]. A detailed discussion on the choice of the test function is provided in Appendix C.
Practically, we use $Q^{\widehat{P}}_{\pi}$ as the test function. Note that multiplying both sides of Eq. (9) by the same $Q^{\widehat{P}}_{\pi}$ does not undermine the convergence property, under mild conditions on $Q$. Mimicking the Bellman backup to sum over $(s',a')$ on both sides, with the notation $d^P_{\pi_b}(s,a,s') = d^P_{\pi_b}(s,a)\, P(s' \mid s,a)$, we have
$$\underbrace{\mathbb{E}_{(s,a)\sim d^P_{\pi_b}}\big[\omega(s,a)\cdot Q^{\widehat{P}}_{\pi}(s,a)\big]}_{\ell_1(\omega)} = \underbrace{\gamma\, \mathbb{E}_{(s,a,s')\sim d^P_{\pi_b},\, a'\sim\pi(\cdot \mid s')}\big[\omega(s,a)\cdot Q^{\widehat{P}}_{\pi}(s',a')\big] + (1-\gamma)\, \mathbb{E}_{s\sim\mu_0(\cdot),\, a\sim\pi(\cdot \mid s)}\big[Q^{\widehat{P}}_{\pi}(s,a)\big]}_{\ell_2(\omega)}. \quad (10)$$
Thus, given a current estimate of $\omega$, we can refine $\omega$ by minimizing the difference between the RHS and the LHS of Eq. (10). For training stability, we use a target network $\omega'(s,a)$ for the RHS, and the final objective for learning $\omega$ is
$$\big(\ell_1(\omega) - \ell_2(\omega')\big)^2, \quad (11)$$
where the target network $\omega'(s,a)$ is soft-updated after each gradient step, motivated by the target networks $Q_{\theta'_j}$ and $\pi_{\phi'}$.
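A sketch of the resulting MIW update is given below; it assumes an `omega_net` producing non-negative scalar weights, access to samples of initial states (e.g., trajectory start states in the dataset), and a soft-update coefficient `tau`, all of which are illustrative choices.

```python
import torch

def omega_update(omega_net, omega_target, critic, policy, env_batch, init_states,
                 optimizer, gamma=0.99, tau=0.005):
    """One gradient step on Eq. (11): (l1(omega) - l2(omega'))^2."""
    s, a, s_next = env_batch["s"], env_batch["a"], env_batch["s_next"]
    # l1(omega): E_{(s,a) ~ D_env}[omega(s,a) * Q(s,a)]
    l1 = (omega_net(s, a) * critic(s, a).detach()).mean()
    with torch.no_grad():
        a_next = policy(s_next)                      # a' ~ pi(.|s')
        a0 = policy(init_states)                     # a ~ pi(.|s), s ~ mu_0
        l2 = (gamma * (omega_target(s, a) * critic(s_next, a_next)).mean()
              + (1 - gamma) * critic(init_states, a0).mean())
    loss = (l1 - l2) ** 2
    optimizer.zero_grad(); loss.backward(); optimizer.step()
    # soft-update the target network omega'
    for p, p_t in zip(omega_net.parameters(), omega_target.parameters()):
        p_t.data.mul_(1 - tau).add_(tau * p.data)
    return loss.item()
```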
Our proposed training method is closely related to VPM [40]. By using the MIW $\omega(s,a)$ itself as the test function, VPM leverages variational power iteration to train the MIW iteratively. Instead, our approach uses the current action-value function as the test function, motivated by the primal-dual relationship between the MIW and the action-value function in off-policy evaluation. We compare the empirical performance of several alternative approaches in Section 4.2 and in Table 3 of Appendix A.
3.4 Practical implementation
In this section we briefly discuss some implementation details of our offline Alternating Model-Policy
Learning (AMPL) method, whose main steps are in Algorithm 1. Further details are in Appendix E.1.
Dynamic model training.
We adopt the common practice in offline MBRL [e.g., 16, 18] of using an ensemble of Gaussian probabilistic networks $\widehat{P}(\cdot \mid s,a)$ and $\widehat{r}(s,a)$ to parameterize the stochastic transition and reward. We initialize the dynamic model by standard MLE training, and periodically update the model by minimizing Eq. (6).
Critic training. We use the conservative target from the offline RL literature [e.g., 12, 13]:
$$\widetilde{Q}(s,a) \triangleq r(s,a) + \gamma\, \mathbb{E}_{a'\sim\pi_{\phi'}(\cdot \mid s')}\Big[c \min_{j=1,2} Q_{\theta'_j}(s',a') + (1-c) \max_{j=1,2} Q_{\theta'_j}(s',a')\Big], \quad (12)$$
where we set $c = 0.75$. With a mini-batch $\mathcal{B}$ sampled from the augmented dataset $\mathcal{D}$, both critic networks are trained as
$$\forall\, j = 1,2:\quad \arg\min_{\theta_j} \frac{1}{|\mathcal{B}|} \sum_{(s,a)\in\mathcal{B}} \mathrm{Huber}\big(Q_{\theta_j}(s,a),\, \widetilde{Q}(s,a)\big), \quad (13)$$
where the Huber loss $\mathrm{Huber}(\cdot)$ is used in lieu of the classical MSE for training stability [41].
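A sketch of the critic step with the conservative target of Eq. (12) and the Huber regression of Eq. (13); here `actor` stands in for the target policy $\pi_{\phi'}$, and the batch layout is an assumption.

```python
import torch
import torch.nn.functional as F

def conservative_critic_update(critics, critic_targets, actor, batch, optimizers,
                               gamma=0.99, c=0.75):
    """Eqs. (12)-(13): conservative target plus Huber regression for both critics."""
    s, a, r, s_next = batch["s"], batch["a"], batch["r"], batch["s_next"]
    with torch.no_grad():
        a_next = actor(s_next)                                  # a' ~ pi_phi'(.|s')
        q_next = torch.stack([qt(s_next, a_next) for qt in critic_targets])
        q_min, q_max = q_next.min(dim=0).values, q_next.max(dim=0).values
        target = r + gamma * (c * q_min + (1 - c) * q_max)      # Eq. (12)
    losses = []
    for q_net, opt in zip(critics, optimizers):
        loss = F.huber_loss(q_net(s, a), target)                # Eq. (13)
        opt.zero_grad(); loss.backward(); opt.step()
        losses.append(loss.item())
    return losses
```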
Estimating $\widetilde{D}_{\pi}(P, \widehat{P})$.
In $\widetilde{D}_{\pi}(P, \widehat{P})$, using the notation of GANs, we denote a sample from the left distribution of the JSD in Eq. (8) by $\mathcal{B}_{\mathrm{true}}$ (i.e., a “true” sample), and a sample from the right distribution of the JSD by $\mathcal{B}_{\mathrm{fake}}$ (i.e., a “fake” sample).
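One way to assemble the two batches from an offline mini-batch is sketched below, reusing the hypothetical `GaussianDynamics` and `ImplicitPolicy` interfaces from the earlier sketches; it assumes the mini-batch stores consecutive state-action pairs $(s, a, s', a')$.

```python
import torch

def build_gan_batches(env_batch, policy, model):
    """B_true: (s, a, s', a') from the offline data.  B_fake: (s, a, s', a') rolled on P_hat with pi."""
    s, a, s_next, a_next = (env_batch["s"], env_batch["a"],
                            env_batch["s_next"], env_batch["a_next"])
    b_true = torch.cat([s, a, s_next, a_next], dim=-1)
    a_fake = policy(s)                              # a ~ pi_phi(.|s)
    s_next_fake = model.dist(s, a_fake).rsample()   # s' ~ P_hat(.|s,a), reparameterized
    a_next_fake = policy(s_next_fake)               # a' ~ pi_phi(.|s')
    b_fake = torch.cat([s, a_fake, s_next_fake, a_next_fake], dim=-1)
    return b_true, b_fake
```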