
Lecture 9: Advanced Policy Gradients: Part 1

Return to Policy Gradients

So, why policy gradients? They converge, unlike $Q$-function methods, and on-policy methods allow Monte Carlo advantage estimators (GAE with $\lambda=1$), which are unbiased. Thus, policy gradients are stable, reliable methods when sample efficiency is not a concern (samples are cheap to generate).

Let's consider our usual policy gradient algorithm with GAE.

  1. Sample $\{ \tau^{(i)} \}$ from $\pi_{\theta}(\mathbf{a}\mid \mathbf{s})$ (run policy).
  2. Evaluate $y_{t}^{(i)}=r(\mathbf{s}_{t}^{(i)},\mathbf{a}_t^{(i)})+\gamma \hat{V}_{\phi}^{\pi}(\mathbf{s}_{t+1}^{(i)})$.
  3. Refit $\hat{V}_{\phi}^{\pi}$ to targets $\{ y_{t}^{(i)} \}$, minimizing $\mathcal{L}(\phi)$.
  4. Evaluate $\hat{A}_{\text{GAE}}^{\pi}(\mathbf{s}_{t}^{(i)},\mathbf{a}_{t}^{(i)})=\sum_{t'=t}^{\infty}(\gamma\lambda)^{t'-t}\delta_{t'}^{(i)}$.
  5. Compute $\nabla_{\theta}J(\theta)\approx \sum_{i}\sum_{t}\nabla_{\theta}\log \pi_{\theta}(\mathbf{a}_{t}^{(i)}\mid \mathbf{s}_{t}^{(i)})\hat{A}^{\pi}(\mathbf{s}_{t}^{(i)},\mathbf{a}_{t}^{(i)})$.
  6. Update $\theta\leftarrow\theta+\alpha \nabla_{\theta}J(\theta)$.
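The GAE step of the algorithm above can be sketched in a few lines. This is only an illustration under stated assumptions: the function name `gae_advantages` is hypothetical, $\delta_{t}$ is taken to be the usual TD residual $r_{t}+\gamma \hat{V}(\mathbf{s}_{t+1})-\hat{V}(\mathbf{s}_{t})$, and the infinite sum is truncated at the end of a finite trajectory.

```python
import numpy as np

def gae_advantages(rewards, values, gamma=0.99, lam=0.95):
    """Hypothetical helper: GAE for a single finite trajectory.

    rewards: length-T sequence of r(s_t, a_t).
    values:  length-(T+1) sequence of V(s_t), including a bootstrap
             value for the state after the last action (assumption).
    """
    rewards = np.asarray(rewards, dtype=float)
    values = np.asarray(values, dtype=float)
    T = len(rewards)
    # TD residuals: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    deltas = rewards + gamma * values[1:] - values[:-1]
    advantages = np.zeros(T)
    running = 0.0
    # Backward recursion: A_t = delta_t + gamma * lam * A_{t+1}
    for t in reversed(range(T)):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    return advantages
```

With `lam=1.0` and a zero value function this reduces to the Monte Carlo reward-to-go, and with `lam=0.0` it reduces to the one-step TD residual.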

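As written, this algorithm is too expensive for practical purposes: every batch of sampled trajectories supports only a single gradient step before new samples must be collected.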

Ideally, we would take multiple gradient steps for every round of sampling. In other words, repeat steps 5-6 some $K$ times for every execution of steps 1-4. However, this is problematic: after the first update, the trajectories being used were produced by an old policy, not the current $\pi_{\theta}$.

The primary goal of this lecture will be to derive an algorithm that can achieve this goal without issue.

Importance Sampling

Importance sampling is a technique for estimating an expected value under one distribution using samples drawn from a different distribution. This is exactly our problem :)

It's described mathematically as follows.

$$\begin{align*} \mathbb{E}_{x\sim p(x)}[f(x)] &= \int p(x)f(x) \,\mathrm{d}x \\ &= \int \frac{q(x)}{q(x)}p(x)f(x) \,\mathrm{d}x \\ &= \int q(x) \frac{p(x)}{q(x)}f(x) \,\mathrm{d}x \\ &= \mathbb{E}_{x\sim q(x)}\left[ \frac{p(x)}{q(x)}f(x) \right] \end{align*}$$
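A small numerical check of this identity may help. The sketch below uses a made-up three-outcome example (the distributions `p`, `q` and the values `f_values` are illustrative assumptions, not from the lecture): it samples from `q` only and reweights by `p / q` to estimate an expectation under `p`.

```python
import numpy as np

def importance_sampling_estimate(f_values, p, q, samples):
    """Estimate E_{x~p}[f(x)] from samples drawn from q (hypothetical helper)."""
    weights = p[samples] / q[samples]  # importance weights p(x) / q(x)
    return float(np.mean(weights * f_values[samples]))

rng = np.random.default_rng(0)
p = np.array([0.1, 0.2, 0.7])         # target distribution (assumed)
q = np.array([0.4, 0.4, 0.2])         # sampling distribution (assumed)
f_values = np.array([1.0, 2.0, 3.0])  # f(x) for x in {0, 1, 2}

samples = rng.choice(3, size=200_000, p=q)  # draw from q, never from p
estimate = importance_sampling_estimate(f_values, p, q, samples)
exact = float(np.dot(p, f_values))          # E_p[f] computed directly
```

Here `estimate` should land close to `exact`, even though no sample was drawn from `p`. Note that `q` must be nonzero wherever `p` is nonzero for the ratio to be defined.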

Thus, our objective function

$$J(\theta)=\mathbb{E}_{\tau \sim p_{\theta}(\tau)}[r(\tau)]$$

may be written using samples from a different policy $\bar{p}$ as

$$J(\theta)=\mathbb{E}_{\tau \sim \bar{p}(\tau)}\left[ \frac{p_{\theta}(\tau)}{\bar{p}(\tau)}r(\tau) \right]$$

However, there's one critical issue with this formula: we cannot evaluate $p_{\theta}(\tau)$ directly, since $p_{\theta}(\tau)=p(\mathbf{s}_{1})\prod_{t=1}^{H}\pi_{\theta}(\mathbf{a}_{t}\mid \mathbf{s}_{t})p(\mathbf{s}_{t+1}\mid \mathbf{s}_{t},\mathbf{a}_{t})$. In other words, it depends on the state transition probabilities, which we may not know. However, consider that

$$\frac{p_{\theta}(\tau)}{\bar{p}(\tau)} = \frac{\cancel{ p(\mathbf{s}_{1}) }\prod_{t=1}^{H} \pi_{\theta}(\mathbf{a}_{t}\mid \mathbf{s}_{t})\cancel{ p(\mathbf{s}_{t+1}\mid \mathbf{s}_{t},\mathbf{a}_{t}) }}{\cancel{ p(\mathbf{s}_{1}) }\prod_{t=1}^{H} \bar{\pi}(\mathbf{a}_{t}\mid \mathbf{s}_{t})\cancel{ p(\mathbf{s}_{t+1}\mid \mathbf{s}_{t},\mathbf{a}_{t}) }} = \prod_{t=1}^{H} \frac{\pi_{\theta}(\mathbf{a}_{t}\mid \mathbf{s}_{t})}{\bar{\pi}(\mathbf{a}_{t}\mid \mathbf{s}_{t})}$$

This ratio may be computed using only knowledge of our policies!
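As a minimal sketch of this point, the trajectory ratio can be computed from per-step action log-probabilities alone, with no dynamics model. The helper name `trajectory_ratio` and the choice to work in log space (for numerical stability) are assumptions of this example.

```python
import numpy as np

def trajectory_ratio(logp_new, logp_old):
    """Ratio p_theta(tau) / p_bar(tau) from per-step action log-probabilities.

    logp_new[t] = log pi_theta(a_t | s_t), logp_old[t] = log pi_bar(a_t | s_t).
    The initial-state and transition terms cancel, so they never appear.
    """
    logp_new = np.asarray(logp_new, dtype=float)
    logp_old = np.asarray(logp_old, dtype=float)
    # Product of ratios = exp(sum of log-ratio differences)
    return float(np.exp(np.sum(logp_new - logp_old)))

# Two-step toy trajectory: (0.5 / 0.25) * (0.5 / 1.0)
ratio = trajectory_ratio(np.log([0.5, 0.5]), np.log([0.25, 1.0]))
```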

Let's look at the gradients for a policy gradient algorithm with importance sampling. Define

$$J(\theta')=\mathbb{E}_{\tau \sim p_{\theta}(\tau)}\left[ \frac{p_{\theta'}(\tau)}{p_{\theta}(\tau)}r(\tau) \right]$$

Then,

$$\nabla_{\theta'}J(\theta')=\mathbb{E}_{\tau \sim p_{\theta}(\tau)}\left[ \frac{\nabla_{\theta'}p_{\theta'}(\tau)}{p_{\theta}(\tau)}r(\tau) \right]=\mathbb{E}_{\tau \sim p_{\theta}(\tau)}\left[ \frac{p_{\theta'}(\tau)}{p_{\theta}(\tau)}\nabla_{\theta'}\log p_{\theta'}(\tau)r(\tau) \right]$$

where we once again use the identity $p_{\theta}(\tau)\nabla_{\theta}\log p_{\theta}(\tau)=\nabla_{\theta}p_{\theta}(\tau)$, this time applied at $\theta'$ instead of $\theta$. Note that this is similar to the on-policy policy gradient expression, except with an extra factor $\frac{p_{\theta'}(\tau)}{p_{\theta}(\tau)}$ to account for the difference in policy.

We can substitute some terms in now (recall from policy gradient) to produce

$$\nabla_{\theta'}J(\theta') = \mathbb{E}_{\tau \sim p_{\theta}(\tau)}\left[ \left( \prod_{t=1}^{H} \frac{\pi_{\theta'}(\mathbf{a}_{t}\mid \mathbf{s}_{t})}{\pi_{\theta}(\mathbf{a}_{t}\mid \mathbf{s}_{t})} \right)\left( \sum_{t=1}^{H} \nabla_{\theta'}\log \pi_{\theta'}(\mathbf{a}_{t}\mid \mathbf{s}_{t}) \right)\left( \sum_{t=1}^{H} r(\mathbf{s}_{t}, \mathbf{a}_{t}) \right) \right]$$

which may be rewritten with causality as

$$\nabla_{\theta'}J(\theta') = \mathbb{E}_{\tau \sim p_{\theta}(\tau)} \left[ \sum_{t=1}^{H} \nabla_{\theta'}\log \pi_{\theta'} (\mathbf{a}_{t}\mid \mathbf{s}_{t}) \underbrace{ \left( \prod_{t'=1}^{t} \frac{\pi_{\theta'}(\mathbf{a}_{t'}\mid \mathbf{s}_{t'})}{\pi_{\theta}(\mathbf{a}_{t'}\mid \mathbf{s}_{t'})} \right) }_{ A } \left( \sum_{t'=t}^{H} r( \mathbf{s}_{t'},\mathbf{a}_{t'}) \underbrace{ \cancel{ \left( \prod_{t''=t}^{t'} \frac{\pi_{\theta'}(\mathbf{a}_{t''}\mid \mathbf{s}_{t''})}{\pi_{\theta}(\mathbf{a}_{t''}\mid \mathbf{s}_{t''})} \right) } }_{ B } \right) \right]$$

where term $A$ implements causality, ensuring that future actions don't affect the current weight, and term $B$ reweights each reward $r(\mathbf{s}_{t'},\mathbf{a}_{t'})$ by how representative it is of the current policy. These are essentially the "weights" assigned to each reward.


Term $B$ is typically eliminated for, intuitively, the same motivation as the ideas from the [[Lecture 8#Double $Q$-Learning|double Q-learning]] section, i.e. maximizing over an old policy is actually more effective. We explore this in more detail next lecture.

We'll focus on term $A$ for the above reason. It turns out that term $A$ can be very problematic. It is a product over a long time interval, and each factor is likely $<1$, since a trajectory produced by the old policy is likely to have greater probability under the old policy than under the new one. As a result, the product vanishes at an exponential rate as $t$ increases.
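To make the exponential decay concrete, here is a toy sketch. It assumes, purely for illustration, that every per-step ratio equals $0.45/0.5=0.9$; the helper name `cumulative_importance_weights` is hypothetical.

```python
import numpy as np

def cumulative_importance_weights(logp_new, logp_old):
    """Term A for every t: prod_{t'=1}^{t} pi_new(a_t'|s_t') / pi_old(a_t'|s_t')."""
    log_ratios = np.asarray(logp_new, dtype=float) - np.asarray(logp_old, dtype=float)
    # Cumulative sum in log space, then exponentiate, to avoid underflow
    # in the intermediate products.
    return np.exp(np.cumsum(log_ratios))

H = 100
logp_old = np.full(H, np.log(0.5))   # old policy assigns 0.5 to each taken action
logp_new = np.full(H, np.log(0.45))  # new policy assigns slightly less (assumed)
weights = cumulative_importance_weights(logp_new, logp_old)
# weights[0] is 0.9, while weights[-1] is 0.9**100, several orders of magnitude smaller.
```

Even though each step's ratio is close to 1, the weight on late timesteps is negligible, so those timesteps contribute almost nothing to the gradient estimate.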

Now, let's rewrite this expectation in terms of sampling.

$$\nabla_{\theta'}J(\theta') \approx \frac{1}{N}\sum_{i=1}^{N} \sum_{t=1}^{H} \frac{\pi_{\theta'}(\mathbf{s}_{t}^{(i)},\mathbf{a}_{t}^{(i)})}{\pi_{\theta}(\mathbf{s}_{t}^{(i)},\mathbf{a}_{t}^{(i)})}\nabla_{\theta'}\log \pi_{\theta'}(\mathbf{a}_{t}^{(i)}\mid \mathbf{s}_{t}^{(i)})\hat{A}_{t}^{(i)}$$

where $\pi_{\theta}(\mathbf{s}_{t}^{(i)},\mathbf{a}_{t}^{(i)})$ denotes the probability of the state-action pair $(\mathbf{s}_{t}^{(i)},\mathbf{a}_{t}^{(i)})$ occurring under the policy $\pi_{\theta}$, over the entire state-action space. This is known as a state-action marginal distribution, and the ratio of marginals is, in expectation, essentially the same quantity as the previous product of importance sampling ratios.

Subsequently, we can expand the state-action marginal

$$\nabla_{\theta'}J(\theta') \approx \frac{1}{N}\sum_{i=1}^{N} \sum_{t=1}^{H} \cancel{ \frac{\pi_{\theta'}(\mathbf{s}_{t}^{(i)})}{\pi_{\theta}(\mathbf{s}_{t}^{(i)})} } \cdot\frac{\pi_{\theta'}(\mathbf{a}_{t}^{(i)}\mid\mathbf{s}_{t}^{(i)})}{\pi_{\theta}(\mathbf{a}_{t}^{(i)}\mid\mathbf{s}_{t}^{(i)})}\nabla_{\theta'}\log \pi_{\theta'}(\mathbf{a}_{t}^{(i)}\mid \mathbf{s}_{t}^{(i)})\hat{A}_{t}^{(i)}$$

and then essentially ignore the relative probabilities of state $\mathbf{s}_{t}^{(i)}$ occurring under the different policies because we assume that $\theta'$ is very similar to $\theta$, and therefore the state distribution will have shifted very little. In other words, we ignore the importance sampling ratios of the previous steps taken on the trajectory to arrive at this state because we assume the probability of arriving at this state, in expectation, should not change much between policies. Therefore, we only care about the importance sampling ratio of the action probabilities for the current state. We will explore this detail more precisely next lecture. Thus,

$$\nabla_{\theta'}J(\theta')\approx \frac{1}{N}\sum_{i=1}^{N} \sum_{t=1}^{H} \frac{\pi_{\theta'}(\mathbf{a}_{t}^{(i)}\mid \mathbf{s}_{t}^{(i)})}{\pi_{\theta}(\mathbf{a}_{t}^{(i)}\mid \mathbf{s}_{t}^{(i)})}\nabla_{\theta'}\log \pi_{\theta'}(\mathbf{a}_{t}^{(i)}\mid \mathbf{s}_{t}^{(i)})\hat{A}_{t}^{(i)}$$

Now, finally, we can reconstruct our multi-step policy gradient algorithm with importance sampling.

  1. Sample $\{ \tau^{(i)} \}$ from $\pi_{\theta}(\mathbf{a}\mid \mathbf{s})$ (run policy).
  2. Evaluate $y_{t}^{(i)}=r(\mathbf{s}_{t}^{(i)},\mathbf{a}_t^{(i)})+\gamma \hat{V}_{\phi}^{\pi}(\mathbf{s}_{t+1}^{(i)})$.
  3. Refit $\hat{V}_{\phi}^{\pi}$ to targets $\{ y_{t}^{(i)} \}$, minimizing $\mathcal{L}(\phi)$.
  4. Evaluate $\hat{A}_{\text{GAE}}^{\pi}(\mathbf{s}_{t}^{(i)},\mathbf{a}_{t}^{(i)})=\sum_{t'=t}^{\infty}(\gamma\lambda)^{t'-t}\delta_{t'}^{(i)}$.
  5. $\theta'\leftarrow\theta$.
    1. $\nabla_{\theta'}J(\theta')\approx \frac{1}{N}\sum_{i=1}^{N} \sum_{t=1}^{H} \frac{\pi_{\theta'}(\mathbf{a}_{t}^{(i)}\mid \mathbf{s}_{t}^{(i)})}{\pi_{\theta}(\mathbf{a}_{t}^{(i)}\mid \mathbf{s}_{t}^{(i)})}\nabla_{\theta'}\log \pi_{\theta'}(\mathbf{a}_{t}^{(i)}\mid \mathbf{s}_{t}^{(i)})\hat{A}_{t}^{(i)}$.
    2. $\theta'\leftarrow\theta'+\alpha \nabla_{\theta'}J(\theta')$.
  6. $\theta\leftarrow\theta'$.

where the inner loop is repeated some $K$ times.
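The inner loop can be sketched for a toy setting. Everything specific here is an assumption made for illustration: a tabular softmax policy with logits `theta[s, a]`, samples flattened into arrays of `(state, action, advantage)`, and the hypothetical helper names below. For a softmax policy, $\nabla \log \pi(\mathbf{a}\mid \mathbf{s})$ with respect to the logits of $\mathbf{s}$ is the one-hot vector of $\mathbf{a}$ minus $\pi(\cdot\mid \mathbf{s})$.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # stabilize the exponent
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def is_policy_gradient(theta_new, theta_old, states, actions, advantages):
    """Importance-sampled gradient estimate for a tabular softmax policy.

    Averages over all flattened (i, t) samples, which differs from the
    1/N convention in the text only by a constant factor.
    """
    states = np.asarray(states)
    actions = np.asarray(actions)
    advantages = np.asarray(advantages, dtype=float)
    pi_new, pi_old = softmax(theta_new), softmax(theta_old)
    idx = np.arange(len(states))
    ratio = pi_new[states, actions] / pi_old[states, actions]
    # grad of log pi_new(a|s) w.r.t. the logits of s: onehot(a) - pi_new(.|s)
    grad_logp = -pi_new[states]
    grad_logp[idx, actions] += 1.0
    grad = np.zeros_like(theta_new)
    # Accumulate per-sample contributions into the rows of the visited states
    np.add.at(grad, states, (ratio * advantages)[:, None] * grad_logp)
    return grad / len(states)

def multi_step_update(theta, states, actions, advantages, alpha=0.1, K=5):
    """K gradient steps on theta' using samples collected under theta."""
    theta_new = theta.copy()
    for _ in range(K):
        g = is_policy_gradient(theta_new, theta, states, actions, advantages)
        theta_new = theta_new + alpha * g
    return theta_new

theta = np.zeros((1, 2))  # one state, two actions, uniform initial policy
theta_next = multi_step_update(theta, [0, 0], [0, 1], [1.0, -1.0])
```

In this toy run, action 0 has a positive advantage and action 1 a negative one, so the updated policy should put more than half of its probability on action 0.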

Practical Importance Sampling

Clipped Importance Weights

One issue is that, once the importance weights $\frac{\pi_{\theta'}(\mathbf{a}_{t}^{(i)}\mid \mathbf{s}_{t}^{(i)})}{\pi_{\theta}(\mathbf{a}_{t}^{(i)}\mid \mathbf{s}_{t}^{(i)})}$ stray too far from $1.0$, the estimator is encouraged to over-prioritize good samples and over-penalize bad samples, causing high variance. The figure below shows what might happen.

importance-sampling-variance.png

We can fix this by clipping the importance weights. Let $w(\tau)=\frac{p_{\theta'}(\tau)}{p_{\theta}(\tau)}$. We clip by defining

$$w(\tau)=\max \left\{ 1-\epsilon,\min \left\{ 1+\epsilon, \frac{p_{\theta'}(\tau)}{p_{\theta}(\tau)} \right\} \right\}$$

in other words, the weight is effectively constrained to $\left\lvert 1-\frac{p_{\theta'}(\tau)}{p_{\theta}(\tau)} \right\rvert\leq\epsilon$, which prevents the algorithm from weighting good samples too heavily.

There's one more detail to consider, though. Consider a good sample with $w(\tau_{\text{g}})<1+\epsilon$ and a bad sample with $w(\tau_{\text{b}})\geq 1+\epsilon$, where the two samples are similar enough that increasing $w(\tau_{\text{g}})$ will also increase $w(\tau_{\text{b}})$, and vice versa. For these samples, the estimator is only ever encouraged to increase $w(\tau)$ (and thus the probability of choosing $\tau$ under the new policy $\pi_{\theta'}$): because of the clipping, $w(\tau_{\text{b}})$ stays constant while $w(\tau_{\text{g}})$ increases. However, this may result in the algorithm making a change that substantially increases the probability of $\tau_{\text{b}}$ under the new policy for only marginal improvements in the probability of $\tau_{\text{g}}$, because from the perspective of the algorithm with importance weight clipping, this is always an improvement. Thus, the clipping is modified slightly to be

$$\begin{align*} J(\theta') &= \mathbb{E}_{\tau \sim p_{\theta}(\tau)}[\min \{ w(\tau)r(\tau),w_{c}(\tau)r(\tau) \}] \\ w(\tau) &= \frac{p_{\theta'}(\tau)}{p_{\theta}(\tau)} \\ w_{c}(\tau) &= \max \left\{ 1-\epsilon,\min \left\{ 1+\epsilon, \frac{p_{\theta'}(\tau)}{p_{\theta}(\tau)} \right\} \right\} \end{align*}$$

where, essentially, $w(\tau)$ is clipped for good samples, but is not clipped for bad samples. In other words, the bad samples can become arbitrarily bad, but the good samples are always bounded by how good they can be.
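The asymmetry is easy to see numerically. The sketch below evaluates the term inside the expectation for a single sample; the function name `clipped_objective_term` and the default `eps=0.2` are assumptions of this example.

```python
import numpy as np

def clipped_objective_term(w, r, eps=0.2):
    """min{ w * r, w_c * r } for one sample, with w_c = clip(w, 1 - eps, 1 + eps)."""
    w_c = np.clip(w, 1.0 - eps, 1.0 + eps)
    return float(np.minimum(w * r, w_c * r))

# Good sample (r > 0) with a large weight: the clipped term wins, capping the gain.
good_high = clipped_objective_term(2.0, 1.0)   # min(2.0, 1.2)
# Bad sample (r < 0) with a large weight: the unclipped term wins, so the
# penalty keeps growing with w.
bad_high = clipped_objective_term(2.0, -1.0)   # min(-2.0, -1.2)
```

Taking the minimum makes the objective a pessimistic bound: pushing up the weight of a bad sample is always penalized in full, while pushing up the weight of a good sample stops paying off past $1+\epsilon$.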

Proximal Policy Optimization

Subsequently, we can now implement proximal policy optimization (PPO), or policy gradient with clipped importance sampling.

  1. Sample $\{ \tau^{(i)} \}$ from $\pi_{\theta}(\mathbf{a}\mid \mathbf{s})$ (run policy).
  2. Evaluate $y_{t}^{(i)}=r(\mathbf{s}_{t}^{(i)},\mathbf{a}_t^{(i)})+\gamma \hat{V}_{\phi}^{\pi}(\mathbf{s}_{t+1}^{(i)})$.
  3. Refit $\hat{V}_{\phi}^{\pi}$ to targets $\{ y_{t}^{(i)} \}$, minimizing $\mathcal{L}(\phi)$.
  4. Evaluate $\hat{A}_{\text{GAE}}^{\pi}(\mathbf{s}_{t}^{(i)},\mathbf{a}_{t}^{(i)})=\sum_{t'=t}^{\infty}(\gamma\lambda)^{t'-t}\delta_{t'}^{(i)}$.
  5. $\theta'\leftarrow\theta$.
    1. $\nabla_{\theta'}J(\theta')\approx \nabla_{\theta'}\mathcal{L}_{\text{CLIP}}(\theta')$.
    2. $\theta'\leftarrow\theta'+\alpha \nabla_{\theta'}J(\theta')$.
  6. $\theta\leftarrow\theta'$.

where we define

$$\mathcal{L}_{\text{CLIP}}(\theta')=\sum_{i=1}^{N} \sum_{t=1}^{H} \min \left\{ \frac{\pi_{\theta'}(\mathbf{a}_{t}^{(i)}\mid \mathbf{s}_{t}^{(i)})}{\pi_{\theta}(\mathbf{a}_{t}^{(i)}\mid \mathbf{s}_{t}^{(i)})}\hat{A}_{t}^{(i)},\text{clip}\left( \frac{\pi_{\theta'}(\mathbf{a}_{t}^{(i)}\mid \mathbf{s}_{t}^{(i)})}{\pi_{\theta}(\mathbf{a}_{t}^{(i)}\mid \mathbf{s}_{t}^{(i)})},1-\epsilon,1+\epsilon \right)\hat{A}_{t}^{(i)} \right\}$$
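A minimal sketch of evaluating $\mathcal{L}_{\text{CLIP}}$ on a batch is given below. It assumes the samples have been flattened over $(i, t)$ and that per-sample log-probabilities under both policies are already available; the function name `ppo_clip_objective` and the default `eps=0.2` are illustrative choices, not part of the lecture. In practice the gradient would come from an autodiff library; this only computes the objective value.

```python
import numpy as np

def ppo_clip_objective(logp_new, logp_old, advantages, eps=0.2):
    """Sum over samples of min{ ratio * A, clip(ratio, 1 - eps, 1 + eps) * A }."""
    logp_new = np.asarray(logp_new, dtype=float)
    logp_old = np.asarray(logp_old, dtype=float)
    advantages = np.asarray(advantages, dtype=float)
    # Per-step importance ratio pi_theta'(a|s) / pi_theta(a|s), via log space
    ratio = np.exp(logp_new - logp_old)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - eps, 1.0 + eps) * advantages
    return float(np.sum(np.minimum(unclipped, clipped)))

# Two samples with ratios 1.5 and 0.5 and advantages +1 and -1.
value = ppo_clip_objective(
    logp_new=np.log([0.75, 0.25]),
    logp_old=np.log([0.5, 0.5]),
    advantages=[1.0, -1.0],
)
```

In this example the first sample contributes $1.2$ (clipped at $1+\epsilon$) and the second contributes $-0.8$ (the clipped term is the smaller one), so further movement of either ratio in the same direction no longer changes the objective.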

Other Practical Details