
Lecture 13: Control as Variational Inference

Theory

Recall that our "control as inference" derivation from Lecture 12 used the following value functions.

$$
\begin{align*}
Q_{t}(\mathbf{s}_{t},\mathbf{a}_{t}) &= r(\mathbf{s}_{t},\mathbf{a}_{t})+\log \mathbb{E}[\exp(V_{t+1}(\mathbf{s}_{t+1}))] \\
V_{t}(\mathbf{s}_{t}) &= \log \int \exp(Q_{t}(\mathbf{s}_{t},\mathbf{a}_{t})) \,\mathrm{d}\mathbf{a}_{t}
\end{align*}
$$

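As we saw last lecture, with stochastic environment dynamics, the second term of the $Q$-function expression, $\log \mathbb{E}[\exp(V_{t+1}(\mathbf{s}_{t+1}))]$, encourages excessive optimism about states with very large rewards or values, even if they are reached with very small state transition probabilities: the expectation of $\exp(V_{t+1})$ is dominated by the largest values of $V_{t+1}$. We'd like to eliminate this excessive optimism.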

Let's first take a step back and consider: why does this occur in the first place?

Our inference problem is $p(\mathbf{s}_{1:T},\mathbf{a}_{1:T}\mid \mathcal{O}_{1:T})$, i.e. what is the probability of the trajectory $\tau$ given that the monkey is being optimal? By marginalizing over the other time steps and conditioning on the state $\mathbf{s}_{t}$, we obtain the policy $p(\mathbf{a}_{t}\mid \mathbf{s}_{t},\mathcal{O}_{1:T})$, i.e. what is the probability of choosing action $\mathbf{a}_{t}$ in state $\mathbf{s}_{t}$, given that the monkey is optimal? Note that $\mathcal{O}_{t}$ is really just an indicator (evidence) of whether $\mathbf{a}_{t}$ was high-reward; thus, this is really asking, "given that you obtained high reward, what was your action probability?" Intuitively, this is a good question to ask: once we know that a high reward was obtained, we'd like to raise the probability of the actions that led to it.

However, when performing inference over the graphical model, the action probabilities are not the only factors; critically, the state transition probabilities are involved too. We can similarly marginalize and condition to obtain $p(\mathbf{s}_{t+1}\mid \mathbf{s}_{t},\mathbf{a}_{t},\mathcal{O}_{1:T})$, which in general is not equal to $p(\mathbf{s}_{t+1}\mid \mathbf{s}_{t},\mathbf{a}_{t})$. That is, "given that you obtained high reward, what was your transition probability?" And this is a bad question to ask! It is like training a model to decide whether to buy a lottery ticket after it has already observed that the ticket won. Just because the ticket won this time does not mean that buying a lottery ticket is a good action in general. In essence, the model infers action optimality from the fact that it "lucked out" on rare state transitions, and believes it can make those transitions happen every time.

Still confused?

Courtesy of Gemini, who clarified the lecture's idea for me. Standard inference asks: "Given that I succeeded, what must have happened?" If you condition a probabilistic graphical model on achieving high reward ($\mathcal{O}_{1:T}=1$), the model uses Bayes' rule to update the probability of all latent variables that caused that reward. Because both the actions $\mathbf{a}_{t}$ and the state transitions $\mathbf{s}_{t+1}$ are latent variables leading to the reward, the model updates both.

This leads to the model becoming "delusional." If it is in a stochastic environment where taking a risky action usually results in death but has a 1% chance of hitting a jackpot, standard inference looks at the jackpot and assumes the 1% chance is actually a certainty. It essentially assumes it can control the environment's dice rolls.
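A quick numerical sketch of the jackpot example makes the problem concrete. The numbers are hypothetical: the risky action reaches a jackpot state with soft value $10$ with probability $0.01$ and a "death" state with value $-10$ otherwise, while a safe action deterministically reaches a state with value $0$. Conditioning on optimality reweights next states as $p(\mathbf{s}_{t+1}\mid \mathbf{s}_{t},\mathbf{a}_{t},\mathcal{O}_{1:T})\propto p(\mathbf{s}_{t+1}\mid \mathbf{s}_{t},\mathbf{a}_{t})\exp(V_{t+1}(\mathbf{s}_{t+1}))$, which is where the $\log\mathbb{E}[\exp(V_{t+1})]$ backup comes from; the sketch compares that backup with a plain expectation.

```python
import math

def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def posterior_transition(probs, next_values):
    """p(s' | s, a, O) proportional to p(s' | s, a) * exp(V(s'))."""
    log_w = [math.log(p) + v for p, v in zip(probs, next_values)]
    log_z = log_sum_exp(log_w)
    return [math.exp(lw - log_z) for lw in log_w]

def optimistic_backup(probs, next_values):
    """log E[exp(V(s'))]: the backup from the exact posterior (Lecture 12)."""
    return log_sum_exp([math.log(p) + v for p, v in zip(probs, next_values)])

def expected_backup(probs, next_values):
    """E[V(s')]: the ordinary expectation under the true dynamics."""
    return sum(p * v for p, v in zip(probs, next_values))

# Hypothetical numbers: the risky action leads to [jackpot, death]
# with probabilities [0.01, 0.99]; the safe action always leads to a value-0 state.
risky_probs, risky_values = [0.01, 0.99], [10.0, -10.0]
safe_probs, safe_values = [1.0], [0.0]

print(posterior_transition(risky_probs, risky_values))  # jackpot probability is nearly 1
print(optimistic_backup(risky_probs, risky_values))     # about 5.39, beats the safe action's 0.0
print(expected_backup(risky_probs, risky_values))       # about -9.8, loses to the safe action
```

The posterior puts almost all of its mass on the 1% jackpot, and the optimistic backup ranks the risky action above the safe one even though its expected value is far worse.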

In short, we'd like to phrase the question as "given that you obtained high reward, what was your action probability, given that your transition probabilities did not change?" More generally, we'd like to tell the model to maximize the probability of success without changing the environment dynamics. Or, mathematically, we'd like to find another distribution $q(\mathbf{s}_{1:T},\mathbf{a}_{1:T})$ that is close to the posterior $p(\mathbf{s}_{1:T},\mathbf{a}_{1:T}\mid \mathcal{O}_{1:T})$ but keeps the true dynamics $p(\mathbf{s}_{t+1}\mid \mathbf{s}_{t},\mathbf{a}_{t})$ (i.e. the model should not believe it can raise state transition probabilities the way it can raise action probabilities).

Well, we can actually apply variational inference to this problem! Let our observed variables be $\mathbf{x}=\mathcal{O}_{1:T}$ and our latent variables be $\mathbf{z}=(\mathbf{s}_{1:T},\mathbf{a}_{1:T})$. We'd like to find a $q(\mathbf{z})$ that approximates $p(\mathbf{z}\mid \mathbf{x})$; in particular, we will choose to define

$$q(\mathbf{s}_{1:T},\mathbf{a}_{1:T})=p(\mathbf{s}_{1})\prod_{t}p(\mathbf{s}_{t+1}\mid \mathbf{s}_{t},\mathbf{a}_{t})\,q(\mathbf{a}_{t}\mid \mathbf{s}_{t})$$

Notably, it's very unusual to include $p(\mathbf{s}_{1})$ and $p(\mathbf{s}_{t+1}\mid \mathbf{s}_{t},\mathbf{a}_{t})$ in the definition of $q$: in variational inference, we typically want $q$ to be a simple distribution (for a tractable approximation), and we have no guarantee that these distributions from the graphical model are simple. We will immediately see why this choice is useful, though: it induces a lot of cancellation below.

[Figure: control-vi-graphs.png]

Then, according to the ELBO (evidence lower bound)

$$
\begin{align*}
\log p(\mathbf{x}) &\geq \mathbb{E}_{\mathbf{z}\sim q(\mathbf{z})}[\log p(\mathbf{x},\mathbf{z})-\log q(\mathbf{z})] \\
\log p(\mathcal{O}_{1:T}) &\geq \mathbb{E}_{(\mathbf{s}_{1:T},\mathbf{a}_{1:T})\sim q}\Bigg[ \left( \log p(\mathbf{s}_{1})+\sum_{t=1}^{T} \log p(\mathbf{s}_{t+1}\mid \mathbf{s}_{t},\mathbf{a}_{t})+\sum_{t=1}^{T} \log p(\mathcal{O}_{t}\mid \mathbf{s}_{t},\mathbf{a}_{t}) \right) \\
&\qquad \qquad \qquad- \left(\log p(\mathbf{s}_{1})+\sum_{t=1}^{T}\log p(\mathbf{s}_{t+1}\mid \mathbf{s}_{t},\mathbf{a}_{t})+\sum_{t=1}^{T}\log q(\mathbf{a}_{t}\mid \mathbf{s}_{t}) \right)\Bigg] \\
&= \mathbb{E}_{(\mathbf{s}_{1:T},\mathbf{a}_{1:T})\sim q}\left[\sum_{t=1}^{T}\log p(\mathcal{O}_{t}\mid \mathbf{s}_{t},\mathbf{a}_{t})-\log q(\mathbf{a}_{t}\mid \mathbf{s}_{t})\right] \\
&= \mathbb{E}_{(\mathbf{s}_{1:T},\mathbf{a}_{1:T})\sim q}\left[\sum_{t}r(\mathbf{s}_{t},\mathbf{a}_{t})-\log q(\mathbf{a}_{t}\mid \mathbf{s}_{t})\right]
\end{align*}
$$

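For the last step, recall that we defined $p(\mathcal{O}_{t}\mid \mathbf{s}_{t},\mathbf{a}_{t})=\exp(r(\mathbf{s}_{t},\mathbf{a}_{t}))$ last lecture. Since $\mathbb{E}_{\mathbf{a}_{t}\sim q(\mathbf{a}_{t}\mid \mathbf{s}_{t})}[-\log q(\mathbf{a}_{t}\mid \mathbf{s}_{t})]=\mathcal{H}(q(\mathbf{a}_{t}\mid \mathbf{s}_{t}))$, we can conclude that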

$$\log p(\mathcal{O}_{1:T}) \geq \sum_{t}\mathbb{E}_{(\mathbf{s}_{t},\mathbf{a}_{t})\sim q}[r(\mathbf{s}_{t},\mathbf{a}_{t})+\mathcal{H}(q(\mathbf{a}_{t}\mid \mathbf{s}_{t}))]$$

This looks pretty familiar; in fact, it's just our regular RL policy gradient objective plus the entropy of the action distribution! Crucially, though, the only part of $q$ that we are free to change is the action distribution $q(\mathbf{a}_{t}\mid \mathbf{s}_{t})$. Our choice to hardcode the environment dynamics $p(\mathbf{s}_{1})$ and $p(\mathbf{s}_{t+1}\mid \mathbf{s}_{t},\mathbf{a}_{t})$ into $q(\tau)$ makes them cancel with the corresponding factors in $p(\tau,\mathcal{O}_{1:T})$. Thus, our model won't try to modify the state transition probabilities to maximize $p(\mathcal{O}_{1:T})$; it's forced to optimize only $q(\mathbf{a}_{t}\mid \mathbf{s}_{t})$, which is the only thing the agent actually controls.
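To make the cancellation concrete, here is a minimal sketch on a hypothetical two-step MDP (a safe action and a risky action with stochastic outcomes), small enough to enumerate every trajectory. `elbo` evaluates the final objective exactly, so only $r$ and $\log q(\mathbf{a}_t\mid\mathbf{s}_t)$ appear, with the dynamics acting purely as weights. `log_evidence` computes $\log p(\mathcal{O}_{1:T})$ with the same joint as in the derivation (with $p(\mathbf{s}_1)=1$ and no separate action prior, an assumption of this sketch), and random choices of $q$ never exceed it.

```python
import math
import random

# Hypothetical two-step MDP: s1 -> a1 -> s2 (stochastic) -> a2.
ACTIONS = [0, 1]                                     # 0 = safe, 1 = risky
R1 = [0.0, 0.0]                                      # r(s1, a1)
DYN = [{"ok": 1.0}, {"jackpot": 0.1, "bust": 0.9}]   # p(s2 | s1, a1)
R2 = {"ok": [0.0, 0.0], "jackpot": [5.0, 5.0], "bust": [-5.0, -5.0]}  # r(s2, a2)

def elbo(q1, q2):
    """Exact E_q[sum_t r(s_t, a_t) - log q(a_t | s_t)], enumerating all trajectories.

    q1[a1] = q(a1 | s1) and q2[s2][a2] = q(a2 | s2). The dynamics DYN only act as
    weights on trajectories; no log p(s2 | s1, a1) term appears in the objective.
    """
    total = 0.0
    for a1 in ACTIONS:
        inner = R1[a1] - math.log(q1[a1])
        for s2, p in DYN[a1].items():
            inner += p * sum(q2[s2][a2] * (R2[s2][a2] - math.log(q2[s2][a2]))
                             for a2 in ACTIONS)
        total += q1[a1] * inner
    return total

def log_evidence():
    """log p(O_1, O_2) = log sum_tau p(s2 | s1, a1) * exp(r1 + r2), using the same
    joint as the derivation above (p(s1) = 1, no separate action prior)."""
    total = 0.0
    for a1 in ACTIONS:
        for s2, p in DYN[a1].items():
            for a2 in ACTIONS:
                total += p * math.exp(R1[a1] + R2[s2][a2])
    return math.log(total)

def random_dist(rng, n):
    w = [rng.random() + 1e-3 for _ in range(n)]
    total = sum(w)
    return [x / total for x in w]

rng = random.Random(0)
bound = log_evidence()
for _ in range(1000):
    q1 = random_dist(rng, len(ACTIONS))
    q2 = {s: random_dist(rng, len(ACTIONS)) for s in R2}
    assert elbo(q1, q2) <= bound   # the ELBO never exceeds log p(O)
print(bound)
```

In this example no $q$ of this form can close the gap, because the true posterior changes the transition probabilities and $q$ is not allowed to.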

Just one final note to close this off: we still need to solve for $q(\mathbf{a}_{t}\mid \mathbf{s}_{t})$ for all $t$. We can recurse backwards, as we did for the backward messages. We first compute the base case

$$
\begin{align*}
q(\mathbf{a}_{T}\mid \mathbf{s}_{T}) &= \underset{q}{\arg\max}\ \mathbb{E}_{\mathbf{s}_{T}\sim q(\mathbf{s}_{T})}\left[\mathbb{E}_{\mathbf{a}_{T}\sim q(\mathbf{a}_{T}\mid \mathbf{s}_{T})}[r(\mathbf{s}_{T},\mathbf{a}_{T})]+\mathcal{H}(q(\mathbf{a}_{T}\mid \mathbf{s}_{T}))\right] \\
&= \frac{\exp(r(\mathbf{s}_{T},\mathbf{a}_{T}))}{\int \exp(r(\mathbf{s}_{T},\mathbf{a})) \,\mathrm{d}\mathbf{a}} = \exp(Q(\mathbf{s}_{T},\mathbf{a}_{T})-V(\mathbf{s}_{T}))
\end{align*}
$$

where, at the final time step, $Q(\mathbf{s}_{T},\mathbf{a}_{T})=r(\mathbf{s}_{T},\mathbf{a}_{T})$ and $V(\mathbf{s}_{T})=\log\int\exp(Q(\mathbf{s}_{T},\mathbf{a}))\,\mathrm{d}\mathbf{a}$,

and then, with a longer derivation along the same lines, we obtain

$$q(\mathbf{a}_{t}\mid \mathbf{s}_{t})=\exp(Q(\mathbf{s}_{t},\mathbf{a}_{t})-V(\mathbf{s}_{t}))$$

where $Q(\mathbf{s}_{t},\mathbf{a}_{t})=r(\mathbf{s}_{t},\mathbf{a}_{t})+\mathbb{E}_{\mathbf{s}_{t+1}\sim p(\mathbf{s}_{t+1}\mid \mathbf{s}_{t},\mathbf{a}_{t})}[V_{t+1}(\mathbf{s}_{t+1})]$. Notably, this is now a regular Bellman backup: the expectation is taken under the true dynamics, so it is no longer optimistic.
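The base case can be checked numerically. A minimal sketch, assuming a discrete action set so the integral becomes a sum: $q^{\star}(\mathbf{a}\mid\mathbf{s}_T)=\exp(r(\mathbf{s}_T,\mathbf{a})-V(\mathbf{s}_T))$ attains $V(\mathbf{s}_T)$, and no other distribution does better because $\mathbb{E}_{q}[r]+\mathcal{H}(q)=V(\mathbf{s}_T)-D_{\mathrm{KL}}(q\,\|\,q^{\star})$. The reward values are hypothetical.

```python
import math
import random

def soft_optimal_policy(rewards):
    """Base case: q(a | s_T) = exp(r(s_T, a) - V(s_T)), V(s_T) = log sum_a exp(r(s_T, a))."""
    m = max(rewards)
    v = m + math.log(sum(math.exp(r - m) for r in rewards))
    return [math.exp(r - v) for r in rewards], v

def objective(q, rewards):
    """E_{a ~ q}[r(s_T, a)] + H(q) for a discrete action set (0 log 0 treated as 0)."""
    return sum(p * (r - math.log(p)) for p, r in zip(q, rewards) if p > 0)

rewards = [1.0, 2.0, -0.5]                 # hypothetical r(s_T, a) for three actions
q_star, v = soft_optimal_policy(rewards)
print(q_star, v, objective(q_star, rewards))   # the objective at q_star equals V(s_T)

# No other distribution does better: objective(q) = V(s_T) - KL(q || q_star).
rng = random.Random(0)
for _ in range(1000):
    w = [rng.random() + 1e-9 for _ in rewards]
    q = [x / sum(w) for x in w]
    assert objective(q, rewards) <= v + 1e-12
```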

And thus, we can compute our value functions backwards through time with these recursive expressions. For $t=T-1$ down to $1$,

$$
\begin{align*}
Q_{t}(\mathbf{s}_{t},\mathbf{a}_{t}) &= r(\mathbf{s}_{t},\mathbf{a}_{t})+\mathbb{E}[V_{t+1}(\mathbf{s}_{t+1})] \\
V_{t}(\mathbf{s}_{t}) &= \log \int \exp(Q_{t}(\mathbf{s}_{t},\mathbf{a}_{t})) \,\mathrm{d}\mathbf{a}_{t}
\end{align*}
$$

and we now have a soft value iteration algorithm:

  1. Set $Q(\mathbf{s},\mathbf{a})\leftarrow r(\mathbf{s},\mathbf{a})+\mathbb{E}[V(\mathbf{s}')]$.
  2. Set $V(\mathbf{s})\leftarrow\text{soft max}_{\mathbf{a}}\ Q(\mathbf{s},\mathbf{a})$.
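Here is a minimal sketch of the finite-horizon backward recursion on a hypothetical tabular MDP with discrete actions (so the soft max is a log-sum-exp). The `optimistic` flag switches between the backup derived in this lecture, $\mathbb{E}[V_{t+1}]$, and the Lecture 12 backup, $\log\mathbb{E}[\exp(V_{t+1})]$, on a small jackpot-style example; the state names, rewards, and probabilities are illustrative only.

```python
import math

def log_sum_exp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def soft_backward_pass(states, actions, reward, dynamics, horizon, optimistic=False):
    """Return lists Qs, Vs with Qs[t-1] = Q_t and Vs[t-1] = V_t for t = 1..horizon.

    reward[s][a] = r(s, a); dynamics[s][a] maps s' -> p(s' | s, a) (nonzero entries only).
    optimistic=False: Q_t = r + E[V_{t+1}(s')]          (variational, this lecture)
    optimistic=True:  Q_t = r + log E[exp V_{t+1}(s')]  (exact posterior, Lecture 12)
    """
    v_next = {s: 0.0 for s in states}          # V_{T+1} = 0, so Q_T = r
    Qs, Vs = [], []
    for _ in range(horizon):
        Q = {}
        for s in states:
            Q[s] = {}
            for a in actions:
                nxt = dynamics[s][a]
                if optimistic:
                    backup = log_sum_exp([math.log(p) + v_next[s2] for s2, p in nxt.items()])
                else:
                    backup = sum(p * v_next[s2] for s2, p in nxt.items())
                Q[s][a] = reward[s][a] + backup
        V = {s: log_sum_exp(list(Q[s].values())) for s in states}   # soft max over actions
        Qs.append(Q)
        Vs.append(V)
        v_next = V
    Qs.reverse()
    Vs.reverse()
    return Qs, Vs

def soft_policy(Q, V, s):
    """q(a | s) = exp(Q(s, a) - V(s))."""
    return {a: math.exp(q - V[s]) for a, q in Q[s].items()}

# Hypothetical toy MDP: from "start", "risky" gambles on a jackpot; other states self-loop.
states = ["start", "ok", "jackpot", "bust"]
actions = ["safe", "risky"]
reward = {"start": {"safe": 0.0, "risky": 0.0}, "ok": {"safe": 0.0, "risky": 0.0},
          "jackpot": {"safe": 5.0, "risky": 5.0}, "bust": {"safe": -5.0, "risky": -5.0}}
dynamics = {s: {a: {s: 1.0} for a in actions} for s in states}
dynamics["start"] = {"safe": {"ok": 1.0}, "risky": {"jackpot": 0.1, "bust": 0.9}}

Qs, Vs = soft_backward_pass(states, actions, reward, dynamics, horizon=2)
Qo, Vo = soft_backward_pass(states, actions, reward, dynamics, horizon=2, optimistic=True)
print(soft_policy(Qs[0], Vs[0], "start"))   # variational backup: mostly "safe"
print(soft_policy(Qo[0], Vo[0], "start"))   # optimistic backup: mostly "risky"
```

With the expected backup the policy at `"start"` strongly prefers the safe action; with the optimistic backup it prefers the risky gamble.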

Various modifications may also be added, such as discounting the $\mathbb{E}[V(\mathbf{s}')]$ term by a factor $\gamma$, or a temperature that scales the soft max.
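A sketch of the infinite-horizon variant with both modifications, assuming a discount $\gamma$ applied to $\mathbb{E}[V(\mathbf{s}')]$ and a temperature $\alpha$ (`temperature` in the code) in $V(\mathbf{s})=\alpha\log\sum_{\mathbf{a}}\exp(Q(\mathbf{s},\mathbf{a})/\alpha)$. The chain MDP is hypothetical. Since $\max_{\mathbf{a}} Q\le\alpha\log\sum_{\mathbf{a}}\exp(Q/\alpha)\le\max_{\mathbf{a}} Q+\alpha\log|\mathcal{A}|$, the soft values lie between the hard values and the hard values plus $\alpha\log|\mathcal{A}|/(1-\gamma)$, so they approach ordinary value iteration as $\alpha\to 0$.

```python
import math

def log_sum_exp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def soft_value_iteration(P, R, gamma=0.9, temperature=1.0, iters=1000):
    """Discounted soft value iteration on a tabular MDP.

    P[s][a] is a list of (s_next, prob) pairs and R[s][a] is r(s, a).
      Q(s, a) <- r(s, a) + gamma * E[V(s')]
      V(s)    <- temperature * log sum_a exp(Q(s, a) / temperature)
    temperature=None falls back to the hard max (ordinary value iteration).
    """
    V = [0.0] * len(P)
    for _ in range(iters):
        Q = [[R[s][a] + gamma * sum(p * V[s2] for s2, p in P[s][a])
              for a in range(len(P[s]))] for s in range(len(P))]
        if temperature is None:
            V = [max(q) for q in Q]
        else:
            V = [temperature * log_sum_exp([x / temperature for x in q]) for q in Q]
    return V, Q

def chain_mdp(n=3, slip=0.2):
    """Hypothetical chain: action 0 moves left, action 1 moves right, and each move
    fails (the agent stays put) with probability `slip`. Reward 1 in the last state."""
    P, R = [], []
    for s in range(n):
        left, right = max(s - 1, 0), min(s + 1, n - 1)
        P.append([[(left, 1 - slip), (s, slip)], [(right, 1 - slip), (s, slip)]])
        R.append([1.0 if s == n - 1 else 0.0] * 2)
    return P, R

P, R = chain_mdp()
V_hard, _ = soft_value_iteration(P, R, temperature=None)
for temp in [1.0, 0.1, 0.01]:
    V_soft, _ = soft_value_iteration(P, R, temperature=temp)
    print(temp, [round(vs - vh, 4) for vs, vh in zip(V_soft, V_hard)])  # gap shrinks with temp
```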

Maximum Entropy RL Algorithms

Why use soft max?

The easy answer is that it encourages more exploration, though this does not capture its entire impact. The true answer is that it encourages robustness to inaccurately specified MDPs: the entropy term $\mathcal{H}(q)$ encourages the model to seek high rewards while acting as randomly as possible, so the learned policy keeps probability on every near-optimal action instead of committing to a single one, which makes it less brittle when the environment is perturbed. In reality, though, the real answer is that it just works well in practice ;)

$Q$-Learning with soft optimality

Standard $Q$-learning uses the hard max over the $Q$ values to determine the value function, i.e. $V(\mathbf{s}')=\max_{\mathbf{a}'}Q_{\phi}(\mathbf{s}',\mathbf{a}')$. For soft $Q$-learning, we just change the hard max to the soft max, $V(\mathbf{s}')=\log\int\exp(Q_{\phi}(\mathbf{s}',\mathbf{a}'))\,\mathrm{d}\mathbf{a}'$. Ultimately, this is not a very popular algorithm because, for small action spaces, the hard max works just fine.
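A minimal tabular sketch of the one-line change, assuming a discrete action set so the soft max is a log-sum-exp over $Q_\phi(\mathbf{s}',\cdot)$; `hard_target`, `soft_target`, and `q_update` are illustrative names, not a library API. The soft target is never below the hard one and exceeds it by at most $\gamma\log|\mathcal{A}|$.

```python
import math

def hard_target(r, next_q, gamma=0.99, done=False):
    """Standard Q-learning target: r + gamma * max_a' Q(s', a')."""
    return r if done else r + gamma * max(next_q)

def soft_target(r, next_q, gamma=0.99, done=False):
    """Soft Q-learning target: r + gamma * log sum_a' exp(Q(s', a'))."""
    if done:
        return r
    m = max(next_q)
    return r + gamma * (m + math.log(sum(math.exp(q - m) for q in next_q)))

def q_update(Q, s, a, target, lr=0.1):
    """Tabular update: move Q[s][a] a step of size lr toward the target."""
    Q[s][a] += lr * (target - Q[s][a])

next_q = [1.0, 2.0, 3.0]                  # hypothetical Q(s', a') for three actions
print(hard_target(0.5, next_q), soft_target(0.5, next_q))
Q = {"s": [0.0, 0.0, 0.0]}
q_update(Q, "s", 1, soft_target(0.5, next_q), lr=0.5)
```

When one action's $Q$ value dominates the others, the two targets become nearly identical, which is one way to see why the hard max is often good enough.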

Policy Gradient with soft optimality

Recall that the control as variational inference derivation ended up with a policy gradient objective with an added action entropy term. The modification to policy gradient is precisely that: just define $J(\theta)=\sum_{t}\mathbb{E}_{(\mathbf{s}_{t},\mathbf{a}_{t})\sim \pi_{\theta}}[r(\mathbf{s}_{t},\mathbf{a}_{t})+\mathcal{H}(\pi_{\theta}(\mathbf{a}_{t}\mid \mathbf{s}_{t}))]$. This is very commonly applied, because it's simple to add (the entropy has a closed-form expression for common policy classes such as categorical and Gaussian policies) and because policy gradient otherwise often collapses to a deterministic, non-exploratory policy too early in training.
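A minimal sketch for a one-step problem (a bandit) with a softmax policy over hypothetical rewards. To keep it deterministic it uses the exact gradient of $J(\theta)$ rather than a sampled REINFORCE estimate (an assumption for clarity); for softmax logits this gradient is $\pi_k\,(f_k-\sum_a\pi_a f_a)$ with $f_a=r(a)-\log\pi_a$. Gradient ascent then settles on $\pi\propto\exp(r)$ rather than collapsing to the single best action.

```python
import math

def softmax(logits):
    m = max(logits)
    e = [math.exp(x - m) for x in logits]
    z = sum(e)
    return [x / z for x in e]

def entropy(probs):
    """Closed-form entropy of a categorical distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def objective(theta, rewards):
    """J(theta) = E_{a ~ pi_theta}[r(a)] + H(pi_theta) for a one-step problem."""
    pi = softmax(theta)
    return sum(p * r for p, r in zip(pi, rewards)) + entropy(pi)

def grad_objective(theta, rewards):
    """Exact gradient for softmax logits: pi_k * (f_k - sum_a pi_a f_a), f_a = r(a) - log pi_a."""
    pi = softmax(theta)
    f = [r - math.log(p) for r, p in zip(rewards, pi)]
    baseline = sum(p * fa for p, fa in zip(pi, f))
    return [p * (fa - baseline) for p, fa in zip(pi, f)]

rewards = [1.0, 0.0, 0.5]                  # hypothetical per-action rewards
theta = [0.0, 0.0, 0.0]
for _ in range(5000):
    g = grad_objective(theta, rewards)
    theta = [t + 0.5 * gk for t, gk in zip(theta, g)]
print(softmax(theta))   # converges to softmax(rewards), not to a one-hot policy
```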

Soft Actor-Critic

This is the most well-known application of soft optimality. We modify two things:

$$
\begin{align*}
y_{i} &= r(\mathbf{s}_{i},\mathbf{a}_{i})+\gamma \mathbb{E}_{\mathbf{a}'\sim \pi_{\theta}(\mathbf{a}'\mid \mathbf{s}_{i}')}[\hat{Q}_{\phi}^{\pi}(\mathbf{s}_{i}',\mathbf{a}')+\underbrace{ \mathcal{H}(\pi_{\theta}(\mathbf{a}'\mid \mathbf{s}_{i}')) }_{ \text{new} }] \\
J(\theta) &= \sum_{i}\mathbb{E}_{\mathbf{a}\sim \pi_{\theta}(\mathbf{a}\mid \mathbf{s}_{i})}[\hat{Q}_{\phi}^{\pi}(\mathbf{s}_{i},\mathbf{a})+\underbrace{ \mathcal{H}(\pi_{\theta}(\mathbf{a}\mid \mathbf{s}_{i})) }_{ \text{new} }]
\end{align*}
$$

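where you may recall that the $y_{i}$ are the target values used to train the critic $\hat{Q}_{\phi}^{\pi}$. Note that, in practice, the entropy term in the critic target is usually replaced by a single-sample estimate, $-\log \pi_{\theta}(\mathbf{a}'\mid \mathbf{s}_{i}')$ evaluated at the sampled action $\mathbf{a}'$, i.e.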

$$y_{i}\approx r(\mathbf{s}_{i},\mathbf{a}_{i})+\gamma \mathbb{E}_{\mathbf{a}'\sim \pi_{\theta}(\mathbf{a}'\mid \mathbf{s}_{i}')}[\hat{Q}_{\phi}^{\pi}(\mathbf{s}_{i}',\mathbf{a}')-\log \pi_{\theta}(\mathbf{a}'\mid \mathbf{s}_{i}')]$$

Also, it's very common to add a "temperature" $\beta$ as a coefficient to the entropy terms.
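A minimal sketch of both modified quantities, with the temperature $\beta$ scaling the entropy. It assumes a small discrete action space, where the expectation over $\mathbf{a}'$ can be summed exactly; with continuous actions one would instead rely on the sampled $-\log\pi_\theta$ estimate. The critic values and policy probabilities are hypothetical stand-ins for $\hat{Q}_{\phi}^{\pi}$ and $\pi_\theta$, and the function names are illustrative.

```python
import math
import random

def soft_target_exact(r, next_q, next_pi, gamma=0.99, beta=0.2, done=False):
    """y = r + gamma * E_{a' ~ pi}[Q(s', a') - beta * log pi(a' | s')], summed exactly."""
    if done:
        return r
    soft_v = sum(p * (q - beta * math.log(p)) for p, q in zip(next_pi, next_q) if p > 0)
    return r + gamma * soft_v

def soft_target_sampled(r, next_q, next_pi, rng, gamma=0.99, beta=0.2, done=False):
    """Single-sample estimate: draw a' ~ pi(. | s') and use -log pi(a' | s') for the entropy."""
    if done:
        return r
    a = rng.choices(range(len(next_pi)), weights=next_pi)[0]
    return r + gamma * (next_q[a] - beta * math.log(next_pi[a]))

def actor_objective(q_values, pi, beta=0.2):
    """Per-state actor objective E_{a ~ pi}[Q(s, a)] + beta * H(pi(. | s)), to be maximized."""
    return sum(p * (q - beta * math.log(p)) for p, q in zip(pi, q_values) if p > 0)

# Hypothetical critic outputs and policy probabilities at s' for three actions.
next_q, next_pi = [1.0, 2.0, 0.5], [0.2, 0.5, 0.3]
rng = random.Random(0)
exact = soft_target_exact(0.5, next_q, next_pi)
estimate = sum(soft_target_sampled(0.5, next_q, next_pi, rng) for _ in range(100000)) / 100000
print(exact, estimate)   # the sampled estimate is unbiased, so the average approaches the exact value
```

Averaging many single-sample targets recovers the exact target, which is why the $-\log\pi_\theta$ substitution is reasonable.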

Inverse Reinforcement Learning

In standard imitation learning, e.g. behavioral cloning, the model attempts to copy the actions the expert takes, without any reasoning about the action outcomes. In contrast, humans, when learning, attempt to copy the intent of the expert, and consequently may take very different actions due to their reasoning about the expert's intent. Thus, it's potentially very useful to learn reward functions from demonstrations, rather than have the developer define a reward function for the model to optimize.

Our previous soft optimality or control as inference framework is very helpful for this. Let's use the same probabilistic model with the same optimality variable; but now, instead of learning a policy based on a reward function, we will attempt to learn a reward function given trajectories sampled from a soft-optimal policy. In other words, we have

$$p(\mathcal{O}_{t}\mid \mathbf{s}_{t},\mathbf{a}_{t},\psi)=\exp(r_{\psi}(\mathbf{s}_{t},\mathbf{a}_{t}))$$

and consequently

$$p(\tau \mid \mathcal{O}_{1:T},\psi) \propto p(\tau)\exp\left( \sum_{t}r_{\psi}(\mathbf{s}_{t},\mathbf{a}_{t}) \right)$$

Note that the proportionality hides a normalizing denominator: the partition function $Z=\int p(\tau)\exp(r_{\psi}(\tau)) \,\mathrm{d}\tau$, where $r_{\psi}(\tau)=\sum_{t}r_{\psi}(\mathbf{s}_{t},\mathbf{a}_{t})$.
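For a toy problem with a handful of trajectories, $Z$ and the posterior can be computed by brute force. A minimal sketch, assuming a hypothetical linear reward $r_\psi(\mathbf{s},\mathbf{a})=\psi^\top\phi(\mathbf{s},\mathbf{a})$ with made-up features and $p(\tau)$ values; real problems have far too many trajectories for this.

```python
import math

def trajectory_posterior(trajs, reward_fn):
    """Return p(tau | O_{1:T}, psi) for each trajectory, and the partition function Z.

    trajs is a list of (p_tau, tau) pairs, with p_tau = p(tau) from the dynamics;
    reward_fn(tau) returns r_psi(tau) = sum_t r_psi(s_t, a_t).
    """
    weights = [p_tau * math.exp(reward_fn(tau)) for p_tau, tau in trajs]
    Z = sum(weights)                        # Z = sum_tau p(tau) exp(r_psi(tau))
    return [w / Z for w in weights], Z

def linear_reward(psi):
    """Hypothetical reward r_psi(s, a) = psi . phi(s, a), summed over a trajectory."""
    return lambda tau: sum(sum(w * f for w, f in zip(psi, phi)) for phi in tau)

# Hypothetical enumerable set: each trajectory is a list of per-step features phi(s_t, a_t).
trajs = [(0.5, [[1.0, 0.0], [1.0, 0.0]]),
         (0.3, [[0.0, 1.0], [1.0, 0.0]]),
         (0.2, [[0.0, 1.0], [0.0, 1.0]])]
post, Z = trajectory_posterior(trajs, linear_reward([1.0, 0.0]))
print(post, Z)   # reward shifts mass toward the first trajectory, relative to p(tau)
```

With $\psi=0$ the reward is zero everywhere, $Z=1$, and the posterior reduces to $p(\tau)$ itself.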

Maximum likelihood learning tells us that the optimal $\psi$ is simply

$$\hat{\psi}=\underset{\psi}{\arg\max}\ \frac{1}{N}\sum_{i=1}^{N} \log p(\tau_{i}\mid \mathcal{O}_{1:T},\psi)$$

Also, note that when taking the gradient of $\log p(\tau \mid \mathcal{O}_{1:T},\psi)$ with respect to $\psi$, the $\log p(\tau)$ term vanishes: it comes from the environment dynamics and is independent of the reward parameters. Therefore, we're really maximizing

$$\mathcal{L}=\frac{1}{N}\sum_{i=1}^{N} r_{\psi}(\tau_{i})-\underbrace{ \log Z }_{ \text{normalizer} }$$

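where $\log Z$ comes from the hidden denominator $Z$ mentioned above. Differentiating, and using $\nabla_{\psi}\log Z=\frac{1}{Z}\int p(\tau)\exp(r_{\psi}(\tau))\nabla_{\psi}r_{\psi}(\tau)\,\mathrm{d}\tau$, gives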

$$
\begin{align*}
\nabla_{\psi}\mathcal{L} &= \frac{1}{N}\sum_{i=1}^{N} \nabla_{\psi}r_{\psi}(\tau_{i})-\int \underbrace{ \frac{1}{Z}p(\tau)\exp(r_{\psi}(\tau)) }_{ p(\tau \mid \mathcal{O}_{1:T},\psi) }\nabla_{\psi}r_{\psi}(\tau) \,\mathrm{d}\tau \\
&= \mathbb{E}_{\tau \sim \pi^{*}(\tau)}[\nabla_{\psi}r_{\psi}(\tau)]-\mathbb{E}_{\tau \sim p(\tau \mid \mathcal{O}_{1:T},\psi)}[\nabla_{\psi}r_{\psi}(\tau)]
\end{align*}
$$

Notably, the gradient is zero when $\pi^{*}(\tau)$, the expert's trajectory distribution (which we estimate with samples), matches $p(\tau \mid \mathcal{O}_{1:T},\psi)$, the soft-optimal trajectory distribution under our current estimated reward function.
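For a toy problem where every trajectory can be enumerated, the gradient can be computed exactly. A minimal sketch of maximum likelihood IRL with a hypothetical linear reward $r_\psi(\tau)=\psi^\top\sum_t\phi(\mathbf{s}_t,\mathbf{a}_t)$, so $\nabla_\psi r_\psi(\tau)$ is just the trajectory's summed features. The expert expectation is the average over made-up demonstrations, and the model expectation is computed by enumeration. Gradient ascent drives the gradient to zero, i.e. the model's expected features match the expert's.

```python
import math

def features(tau):
    """Trajectory features: the per-step feature vectors summed over time."""
    return [sum(col) for col in zip(*tau)]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def log_likelihood(psi, trajs, expert):
    """L(psi) = (1/N) sum_i r_psi(tau_i) - log Z, with r_psi(tau) = psi . features(tau)."""
    Z = sum(p * math.exp(dot(psi, features(tau))) for p, tau in trajs)
    return sum(dot(psi, features(tau)) for tau in expert) / len(expert) - math.log(Z)

def grad_log_likelihood(psi, trajs, expert):
    """Expert feature mean minus the feature mean under p(tau | O_{1:T}, psi)."""
    weights = [p * math.exp(dot(psi, features(tau))) for p, tau in trajs]
    Z = sum(weights)
    model = [sum(w * features(tau)[k] for w, (_, tau) in zip(weights, trajs)) / Z
             for k in range(len(psi))]
    data = [sum(features(tau)[k] for tau in expert) / len(expert) for k in range(len(psi))]
    return [d - m for d, m in zip(data, model)]

# Hypothetical enumerable trajectory set (p(tau), per-step features) and expert demos.
trajs = [(0.5, [[1.0, 0.0], [1.0, 0.0]]),
         (0.3, [[0.0, 1.0], [1.0, 0.0]]),
         (0.2, [[0.0, 1.0], [0.0, 1.0]])]
expert = [trajs[0][1], trajs[0][1], trajs[1][1]]

psi = [0.0, 0.0]
for _ in range(500):
    psi = [p + 0.5 * g for p, g in zip(psi, grad_log_likelihood(psi, trajs, expert))]
print(psi, grad_log_likelihood(psi, trajs, expert))  # gradient near zero: features matched
```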

Additionally, because the partition function $Z$ integrates over all possible trajectories, it is generally intractable to compute directly, and so is the second expectation in the gradient. We must therefore estimate it instead. Following the variational inference view from earlier in this lecture, we train a maximum entropy RL agent to approximate the soft-optimal policy for the current reward function $r_{\psi}$; only then can we sample trajectories $\tau$ from this policy to estimate the second term $\mathbb{E}_{\tau \sim p(\tau \mid \mathcal{O}_{1:T},\psi)}[\nabla_{\psi}r_{\psi}(\tau)]$.