Lecture 11: Variational Inference
"Variational inference is an optimization-based framework for training models with latent variables."
Latent Variable Models
Introducing latent variables turns a probabilistic model into a latent variable model, whose distribution is represented as a mixture of multiple simpler distributions (one per value of the latent variable $z$), i.e.
$$
p(x) = \sum_z p(x \mid z)\, p(z), \qquad p(y \mid x) = \sum_z p(y \mid x, z)\, p(z)
$$
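As a minimal sketch of the discrete case, the Python below evaluates $p(x) = \sum_z p(x \mid z)\, p(z)$ for a hypothetical two-component Gaussian mixture. The mixing weights `P_Z` and the component parameters in `COMPONENTS` are made-up values for illustration; the Riemann sum is only a sanity check that the mixture is still a normalized density.

```python
import math

def normal_pdf(x, mean, std):
    """Density of the 1-D Gaussian N(mean, std^2) at x."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2 * math.pi))

# Hypothetical two-component model: z takes the values 0 or 1.
P_Z = [0.3, 0.7]                          # p(z)
COMPONENTS = [(-2.0, 0.5), (1.0, 1.0)]    # (mean, std) of p(x | z) for each z

def mixture_pdf(x):
    """p(x) = sum over z of p(x | z) * p(z)."""
    return sum(pz * normal_pdf(x, mean, std) for pz, (mean, std) in zip(P_Z, COMPONENTS))

# Sanity check: a Riemann sum over [-10, 10] should give (approximately) 1.
n = 20_000
dx = 20.0 / n
total = sum(mixture_pdf(-10.0 + k * dx) * dx for k in range(n))
print(f"integral of p(x) ~ {total:.4f}")
```

The density dips between the two components (`mixture_pdf(-0.5)` is lower than `mixture_pdf(-2.0)`), so a mixture of simple distributions can represent a multimodal $p(x)$.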
One possible implementation of a continuous latent variable model $p(x) = \int p(x \mid z)\, p(z)\, dz$ is the nonlinear composition of two "easy" distributions, e.g.
$$
p(x \mid z) = \mathcal{N}\big(\mu_{nn}(z), \sigma_{nn}(z)\big), \qquad p(z) = \mathcal{N}(0, 1)
$$
where the $nn$ subscript denotes a neural network. This is convenient: $p(z)$ and $p(x \mid z)$ are each easy to sample from and to evaluate, even though the resulting marginal $p(x)$ can be very complex.
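The sketch below draws samples from such a model by ancestral sampling: first $z \sim p(z)$, then $x \sim p(x \mid z)$. The functions `mu_nn` and `sigma_nn` are hypothetical hand-written stand-ins for trained networks (here `sigma_nn` returns a standard deviation), chosen so that the marginal $p(x)$ comes out bimodal even though both factors are Gaussian.

```python
import math
import random

def mu_nn(z):
    """Hypothetical stand-in for a trained mean network."""
    return 3.0 * math.tanh(4.0 * z)

def sigma_nn(z):
    """Hypothetical stand-in for a trained standard-deviation network (always > 0)."""
    return 0.3 + 0.2 * math.exp(-z * z)

def sample_x(rng):
    """Ancestral sampling: z ~ p(z) = N(0, 1), then x ~ p(x | z) = N(mu_nn(z), sigma_nn(z)^2)."""
    z = rng.gauss(0.0, 1.0)
    return rng.gauss(mu_nn(z), sigma_nn(z))

rng = random.Random(0)
xs = [sample_x(rng) for _ in range(10_000)]

# Fraction of samples within 1 of either mode at -3 or +3.
frac_near_modes = sum(abs(abs(x) - 3.0) < 1.0 for x in xs) / len(xs)
print(f"fraction near +/-3: {frac_near_modes:.2f}")
```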
Sidenote: this is applicable in RL for
- Conditional latent variable models for multi-modal policies
- Latent variable models for model-based RL (where you learn the MDP of the environment itself, more on this later in the course)
So how do we train latent variable models? Well, using maximum likelihood gives us
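$$
p_\theta(x) = \int p_\theta(x \mid z)\, p(z)\, dz, \qquad \theta \leftarrow \arg\max_\theta \frac{1}{N} \sum_i \log\left(\int p_\theta(x_i \mid z)\, p(z)\, dz\right)
$$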
This is intractable, though: the maximum likelihood objective is difficult to compute because of the integral! Moreover, we can't simply approximate the integral by sampling, because it sits inside a log, and the log of an expectation is not the expectation of the log D:
(Had we tried to approximate it with Monte Carlo sampling, the resulting estimate of the log-likelihood would be biased.)
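To see the bias concretely, here is a small sketch on a hypothetical toy model where everything is known in closed form: $p(z) = \mathcal{N}(0, 1)$ and $p(x \mid z) = \mathcal{N}(z, 1)$, so the exact marginal is $p(x) = \mathcal{N}(0, 2)$. It averages the log of a $K$-sample Monte Carlo estimate of $p(x)$ over many repetitions and compares the result with the exact $\log p(x)$.

```python
import math
import random

def log_normal_pdf(x, mean, var):
    """log N(x; mean, var) for a 1-D Gaussian with variance var."""
    return -0.5 * math.log(2 * math.pi * var) - 0.5 * (x - mean) ** 2 / var

# Hypothetical toy model: p(z) = N(0, 1), p(x | z) = N(z, 1), so p(x) = N(0, 2) exactly.
x = 0.0
exact = log_normal_pdf(x, 0.0, 2.0)

def log_of_mc_estimate(k, rng):
    """log of the K-sample estimate (1/K) * sum_k p(x | z_k), with z_k ~ p(z)."""
    zs = [rng.gauss(0.0, 1.0) for _ in range(k)]
    return math.log(sum(math.exp(log_normal_pdf(x, z, 1.0)) for z in zs) / k)

rng = random.Random(0)
averages = {}
for k in (1, 10, 100):
    # Average the log-estimate over many independent repetitions to expose its bias.
    averages[k] = sum(log_of_mc_estimate(k, rng) for _ in range(2000)) / 2000
    print(f"K={k}: exact log p(x) = {exact:.3f}, mean log-estimate = {averages[k]:.3f}")
```

The averaged log-estimate falls below the exact value (clearly so for $K = 1$); the bias shrinks as $K$ grows but does not vanish for any finite $K$.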
Instead, we can use expected log-likelihood.
$$
\theta \leftarrow \arg\max_\theta \frac{1}{N} \sum_i \mathbb{E}_{z \sim p(z \mid x_i)}\big[\log p_\theta(x_i, z)\big]
$$
In essence, we "guess" the likely values of the latent variable $z$ for $x_i$ by sampling from the posterior $p(z \mid x_i)$, and then pretend our guess is correct. But we still need to compute this posterior.
This problem of calculating $p(z \mid x_i)$ is known as probabilistic inference.
Sidenote: probabilistic inference is also seen in RL for
- Modeling human behavior
- Generative models + variational inference for exploration
- Reinforcement learning from human feedback (RLHF) (predict reward based on model output)
Variational Inference
Performing probabilistic inference exactly is extremely hard; however, for practical purposes, there are approximate methods that work well.
First, let's pretend our posterior is extremely simple. This assumption is incorrect, but very convenient! (We'll soon see how to quantify the extent to which it is wrong.) For instance, we can approximate $p(z \mid x_i) \approx q_i(z) = \mathcal{N}(\mu_i, \sigma_i)$. Critically, this approximation allows us to tractably compute a lower bound on $\log p(x_i)$.
We note that
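$$
\log p(x_i) = \log \int p(x_i \mid z)\, p(z)\, dz = \log \int p(x_i \mid z)\, p(z) \frac{q_i(z)}{q_i(z)}\, dz = \log \mathbb{E}_{z \sim q_i(z)}\left[\frac{p(x_i \mid z)\, p(z)}{q_i(z)}\right]
$$
(Really, this is just multiplying the integrand by $q_i(z)/q_i(z) = 1$ and reading the result as an expectation under $q_i$, the same trick used in importance sampling.)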
We can then apply what is known as Jensen's inequality, which states that
$$
f(\mathbb{E}[y]) \geq \mathbb{E}[f(y)]
$$
for any concave function f. Notably, log is a concave function. Using this, we may conclude
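$$
\begin{aligned}
\log p(x_i) &= \log \mathbb{E}_{z \sim q_i(z)}\left[\frac{p(x_i \mid z)\, p(z)}{q_i(z)}\right] \\
&\geq \mathbb{E}_{z \sim q_i(z)}\left[\log \frac{p(x_i \mid z)\, p(z)}{q_i(z)}\right] \\
&= \mathbb{E}_{z \sim q_i(z)}\big[\log p(x_i \mid z) + \log p(z)\big] - \mathbb{E}_{z \sim q_i(z)}\big[\log q_i(z)\big] \\
&= \mathbb{E}_{z \sim q_i(z)}\big[\log p(x_i, z)\big] + \mathcal{H}(q_i)
\end{aligned}
$$
where $\mathcal{H}$ is the entropy of the distribution. Notably, $\mathbb{E}_{z \sim q_i(z)}[\log p(x_i, z)]$ is precisely our expected log-likelihood term, with $q_i$ standing in for the posterior $p(z \mid x_i)$!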
Note that the definition of entropy is
$$
\mathcal{H}(p) = -\mathbb{E}_{x \sim p(x)}[\log p(x)] = -\int p(x) \log p(x)\, dx
$$
Intuitively, entropy describes...
- how random the random variable is?
- how small is the log probability in expectation under itself? (higher entropy means lower $\mathbb{E}_{x \sim p(x)}[\log p(x)]$)
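As a quick numeric check of the definition, the sketch below compares the closed-form entropy of a 1-D Gaussian, $\frac{1}{2}\log(2\pi e \sigma^2)$, with a Monte Carlo estimate of $-\mathbb{E}_{x \sim p(x)}[\log p(x)]$. The standard deviations tried are arbitrary.

```python
import math
import random

def gaussian_entropy(std):
    """Closed-form entropy of N(mu, std^2): 0.5 * log(2 * pi * e * std^2)."""
    return 0.5 * math.log(2 * math.pi * math.e * std ** 2)

def mc_entropy(mean, std, n, rng):
    """Monte Carlo estimate of H(p) = -E_{x~p}[log p(x)] for p = N(mean, std^2)."""
    total = 0.0
    for _ in range(n):
        x = rng.gauss(mean, std)
        total += -0.5 * math.log(2 * math.pi * std ** 2) - 0.5 * ((x - mean) / std) ** 2
    return -total / n

rng = random.Random(0)
estimates = {}
for std in (0.1, 1.0, 10.0):
    estimates[std] = mc_entropy(0.0, std, 20_000, rng)
    print(f"std={std}: closed form {gaussian_entropy(std):.3f}, Monte Carlo {estimates[std]:.3f}")
```

Wider Gaussians are "more random" and have higher entropy; a narrow one has a high log probability under itself and hence a low (for continuous distributions, possibly negative) entropy.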
So what exactly do the two terms in the expression $\mathbb{E}_{z \sim q_i(z)}[\log p(x_i, z)] + \mathcal{H}(q_i)$ do?
- The first term, our expected log-likelihood, induces $q_i$ to place its probability mass where $p(x_i, z)$ is large, i.e. at the peaks of the distribution.
- The second term, entropy, induces $q_i$ to spread its probability over as wide a region as possible, rather than collapsing onto a single point.
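The sketch below estimates the bound $\mathbb{E}_{z \sim q_i(z)}[\log p(x_i, z)] + \mathcal{H}(q_i)$ for a hypothetical toy model $p(z) = \mathcal{N}(0, 1)$, $p(x \mid z) = \mathcal{N}(z, 1)$, using Monte Carlo samples for the expected log-likelihood and the closed-form Gaussian entropy. The three Gaussian choices of $q_i$ are arbitrary, except that the last one is the exact posterior $\mathcal{N}(x/2, 1/2)$ for this model.

```python
import math
import random

def log_normal_pdf(x, mean, var):
    """log N(x; mean, var) for a 1-D Gaussian with variance var."""
    return -0.5 * math.log(2 * math.pi * var) - 0.5 * (x - mean) ** 2 / var

def log_joint(x, z):
    """log p(x, z) = log p(z) + log p(x | z) for the toy model p(z)=N(0,1), p(x|z)=N(z,1)."""
    return log_normal_pdf(z, 0.0, 1.0) + log_normal_pdf(x, z, 1.0)

def elbo_estimate(x, q_mean, q_std, n, rng):
    """Monte Carlo E_{z~q}[log p(x, z)] plus the closed-form entropy of q = N(q_mean, q_std^2)."""
    samples = (rng.gauss(q_mean, q_std) for _ in range(n))
    expected_log_joint = sum(log_joint(x, z) for z in samples) / n
    entropy = 0.5 * math.log(2 * math.pi * math.e * q_std ** 2)
    return expected_log_joint + entropy

rng = random.Random(0)
x = 1.5
log_px = log_normal_pdf(x, 0.0, 2.0)   # exact log p(x) for this toy model
results = []
for q_mean, q_std in [(0.0, 1.0), (2.0, 0.2), (x / 2, math.sqrt(0.5))]:
    results.append(elbo_estimate(x, q_mean, q_std, 50_000, rng))
    print(f"q=N({q_mean:.2f}, {q_std:.2f}^2): bound {results[-1]:.3f}, log p(x) {log_px:.3f}")
```

A poorly placed or overly narrow $q_i$ gives a loose bound; with the exact posterior the estimate matches $\log p(x_i)$ up to Monte Carlo noise.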
Now, note the following representation of KL divergence.
$$
D_{KL}(q \,\|\, p) = \mathbb{E}_{x \sim q(x)}\left[\log \frac{q(x)}{p(x)}\right] = -\mathbb{E}_{x \sim q(x)}[\log p(x)] - \mathcal{H}(q)
$$
That looks suspiciously similar to our expression above...
Intuitively, KL divergence describes...
- how different are the two distributions?
- how small is the expected log probability of one distribution under another, minus entropy?
Thus, our expression $\mathbb{E}_{z \sim q_i(z)}[\log p(x_i, z)] + \mathcal{H}(q_i)$ is exactly the negative of this KL divergence, with $q_i(z)$ in the role of $q$ and $p(x_i, z)$ (viewed as an unnormalized distribution over $z$) in the role of $p$. Maximizing the expression with respect to $q_i$ is therefore equivalent to minimizing this divergence, i.e. pulling $q_i(z)$ toward the shape of $p(x_i, z)$.
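A minimal sketch of the KL identity for two 1-D Gaussians: the closed-form $D_{KL}(\mathcal{N}(m_q, s_q^2) \,\|\, \mathcal{N}(m_p, s_p^2))$ is compared against a Monte Carlo estimate of $-\mathbb{E}_{x \sim q(x)}[\log p(x)] - \mathcal{H}(q)$. The particular means and standard deviations are arbitrary.

```python
import math
import random

def log_normal_pdf(x, mean, var):
    """log N(x; mean, var) for a 1-D Gaussian with variance var."""
    return -0.5 * math.log(2 * math.pi * var) - 0.5 * (x - mean) ** 2 / var

def kl_gaussians(m_q, s_q, m_p, s_p):
    """Closed-form D_KL(N(m_q, s_q^2) || N(m_p, s_p^2))."""
    return math.log(s_p / s_q) + (s_q ** 2 + (m_q - m_p) ** 2) / (2 * s_p ** 2) - 0.5

def kl_mc(m_q, s_q, m_p, s_p, n, rng):
    """Monte Carlo D_KL(q || p) = -E_q[log p(x)] - H(q), with H(q) in closed form."""
    cross = -sum(log_normal_pdf(rng.gauss(m_q, s_q), m_p, s_p ** 2) for _ in range(n)) / n
    entropy_q = 0.5 * math.log(2 * math.pi * math.e * s_q ** 2)
    return cross - entropy_q

rng = random.Random(0)
closed = kl_gaussians(0.0, 1.0, 1.0, 2.0)
estimate = kl_mc(0.0, 1.0, 1.0, 2.0, 50_000, rng)
print(f"closed form {closed:.4f}, Monte Carlo {estimate:.4f}")
# Swapping the arguments gives a different value: KL is not symmetric.
print(f"reverse direction {kl_gaussians(1.0, 2.0, 0.0, 1.0):.4f}")
```

The asymmetry is why the order of arguments, $D_{KL}(q_i \,\|\, \cdot)$ with $q_i$ first, matters in the derivation.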
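Let's now define $\mathcal{L}_i(p, q_i) = \mathbb{E}_{z \sim q_i(z)}[\log p(x_i, z)] + \mathcal{H}(q_i)$, commonly called the evidence lower bound (ELBO). We previously promised a way to quantify how incorrect our initial assumption was. So, how do we measure how good a choice of $q_i$ is? With the KL divergence $D_{KL}(q_i(z) \,\|\, p(z \mid x_i))$. Consider that
$$
\begin{aligned}
D_{KL}(q_i(z) \,\|\, p(z \mid x_i)) &= \mathbb{E}_{z \sim q_i(z)}\left[\log \frac{q_i(z)}{p(z \mid x_i)}\right] = \mathbb{E}_{z \sim q_i(z)}\left[\log \frac{q_i(z)\, p(x_i)}{p(x_i, z)}\right] \\
&= -\mathbb{E}_{z \sim q_i(z)}[\log p(x_i, z)] + \mathbb{E}_{z \sim q_i(z)}[\log q_i(z)] + \mathbb{E}_{z \sim q_i(z)}[\log p(x_i)] \\
&= -\mathbb{E}_{z \sim q_i(z)}[\log p(x_i, z)] - \mathcal{H}(q_i) + \log p(x_i) \\
&= -\mathcal{L}_i(p, q_i) + \log p(x_i)
\end{aligned}
$$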
and therefore
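$$
\log p(x_i) = D_{KL}(q_i(z) \,\|\, p(z \mid x_i)) + \mathcal{L}_i(p, q_i)
$$
In other words, $q_i$ produces a lower bound $\mathcal{L}_i(p, q_i) \leq \log p(x_i)$ (the bound holds because the KL divergence is always nonnegative), where the error (the gap between the lower bound and the actual value of $\log p(x_i)$) is precisely the KL divergence between $q_i(z)$ and the distribution it attempts to approximate, $p(z \mid x_i)$.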
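For the hypothetical toy model $p(z) = \mathcal{N}(0, 1)$, $p(x \mid z) = \mathcal{N}(z, 1)$, every quantity in $\log p(x_i) = D_{KL}(q_i(z) \,\|\, p(z \mid x_i)) + \mathcal{L}_i(p, q_i)$ has a closed form, so the identity can be checked exactly. A sketch, assuming a Gaussian $q_i(z) = \mathcal{N}(m, s^2)$ and using the fact that this model's exact posterior is $\mathcal{N}(x/2, 1/2)$:

```python
import math

def elbo(x, m, s):
    """Closed-form L(p, q) for the toy model p(z)=N(0,1), p(x|z)=N(z,1), q(z)=N(m, s^2)."""
    expected_log_joint = (-math.log(2 * math.pi)
                          - 0.5 * (m ** 2 + s ** 2)          # E_q[-z^2 / 2]
                          - 0.5 * ((x - m) ** 2 + s ** 2))   # E_q[-(x - z)^2 / 2]
    entropy = 0.5 * math.log(2 * math.pi * math.e * s ** 2)
    return expected_log_joint + entropy

def kl_to_posterior(x, m, s):
    """D_KL(q || p(z | x)), where the exact posterior is N(x/2, 1/2) for this model."""
    post_mean, post_var = x / 2, 0.5
    return 0.5 * math.log(post_var / s ** 2) + (s ** 2 + (m - post_mean) ** 2) / (2 * post_var) - 0.5

def log_evidence(x):
    """Exact log p(x) = log N(x; 0, 2)."""
    return -0.5 * math.log(4 * math.pi) - x ** 2 / 4

x = 1.5
for m, s in [(0.0, 1.0), (2.0, 0.2), (x / 2, math.sqrt(0.5))]:
    bound, gap = elbo(x, m, s), kl_to_posterior(x, m, s)
    print(f"q=N({m:.2f}, {s:.2f}^2): L={bound:.4f}, KL={gap:.4f}, "
          f"L+KL={bound + gap:.6f}, log p(x)={log_evidence(x):.6f}")
```

The sum $\mathcal{L}_i + D_{KL}$ equals $\log p(x)$ for every choice of $q$; only the split between the bound and the gap changes, and the gap is zero at the exact posterior.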
Now, we may note that, in the expression for the KL divergence,
$$
D_{KL}(q_i(z) \,\|\, p(z \mid x_i)) = -\mathcal{L}_i(p, q_i) + \log p(x_i)
$$
$\log p(x_i)$ is independent of $q_i$. Therefore, if we maximize $\mathcal{L}_i(p, q_i)$ with respect to $q_i$, we minimize the KL divergence! Thus,
- Maximizing $\mathcal{L}_i$ with respect to $p$ increases (a lower bound on) the log-likelihood
- Maximizing $\mathcal{L}_i$ with respect to $q_i$ tightens the lower bound
Thus, we can select $q_i$ intelligently by iteratively maximizing $\mathcal{L}_i$ with respect to $q_i$.
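Because $\mathcal{L}_i$ is available in closed form for the hypothetical toy model $p(z) = \mathcal{N}(0, 1)$, $p(x \mid z) = \mathcal{N}(z, 1)$, maximizing it with respect to a Gaussian $q(z) = \mathcal{N}(m, s^2)$ can be sketched as plain gradient ascent. Parameterizing by $\log s$ instead of $s$ (to keep the standard deviation positive), the starting point, and the step size `lr` are illustrative choices, not part of the lecture.

```python
import math

def elbo_grads(x, m, log_s):
    """Analytic gradients of L(p, q) for the toy model p(z)=N(0,1), p(x|z)=N(z,1),
    with q(z)=N(m, s^2) parameterized by log s so that s stays positive."""
    s2 = math.exp(2 * log_s)
    d_m = x - 2 * m            # derivative of -m^2/2 - (x - m)^2/2
    d_log_s = 1 - 2 * s2       # derivative of -s^2 (from both expectations) + log s (entropy)
    return d_m, d_log_s

x = 1.5
m, log_s = -3.0, math.log(2.0)   # deliberately poor initial guess for q
lr = 0.1
for _ in range(200):
    d_m, d_log_s = elbo_grads(x, m, log_s)
    m += lr * d_m                # gradient *ascent* on L
    log_s += lr * d_log_s

print(f"q = N({m:.4f}, {math.exp(2 * log_s):.4f})  # exact posterior is N(x/2, 1/2)")
```

Starting from a poor guess, $q$ converges to the exact posterior $\mathcal{N}(0.75, 0.5)$, i.e. the KL gap closes.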
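Here's the final algorithm that follows from the above deductions.
For each $x_i$ (or mini-batch):
- Sample $z \sim q_i(z)$ and approximate $\nabla_\theta \mathcal{L}_i(p, q_i) \approx \nabla_\theta \log p_\theta(x_i \mid z)$ (only the $\log p_\theta(x_i \mid z)$ part of $\mathcal{L}_i$ depends on $\theta$).
- $\theta \leftarrow \theta + \alpha \nabla_\theta \mathcal{L}_i(p, q_i)$.
- Update $q_i$ to maximize $\mathcal{L}_i(p, q_i)$.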
The optimization step with respect to $q_i$ may be done with gradient ascent or another optimization method. For instance, assuming $q_i(z) = \mathcal{N}(\mu_i, \sigma_i)$, we use the gradients $\nabla_{\mu_i} \mathcal{L}_i(p, q_i)$ and $\nabla_{\sigma_i} \mathcal{L}_i(p, q_i)$ to perform gradient ascent on $\mu_i$ and $\sigma_i$.
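Putting the pieces together, here is a minimal sketch of the loop above on a hypothetical toy model $p(z) = \mathcal{N}(0, 1)$, $p_\theta(x \mid z) = \mathcal{N}(z + \theta, 1)$, whose marginal $p_\theta(x) = \mathcal{N}(\theta, 2)$ makes the maximum-likelihood $\theta$ equal to the sample mean. The $\theta$ step uses the single-sample gradient from the algorithm; the $q_i$ step uses analytic gradients of $\mathcal{L}_i$ with respect to $\mu_i$ and $\log \sigma_i$, which happen to exist for this model. The step sizes, epoch count, and synthetic data are illustrative assumptions.

```python
import math
import random

# Hypothetical toy model: p(z) = N(0, 1), p_theta(x | z) = N(z + theta, 1).
rng = random.Random(0)
data = [rng.gauss(3.0, math.sqrt(2.0)) for _ in range(200)]

theta = 0.0
q_mean = [0.0] * len(data)      # mu_i, one per datapoint
q_log_std = [0.0] * len(data)   # log sigma_i, one per datapoint
alpha, beta = 0.01, 0.1         # step sizes for theta and for each q_i (hypothetical choices)

for _ in range(50):
    for i in rng.sample(range(len(data)), len(data)):   # shuffled single-example "mini-batches"
        x = data[i]
        s = math.exp(q_log_std[i])
        # Sample z ~ q_i(z); grad_theta log p_theta(x_i | z) = x_i - z - theta.
        z = rng.gauss(q_mean[i], s)
        theta += alpha * (x - z - theta)
        # Gradient ascent on L_i w.r.t. q_i (analytic gradients for this toy model).
        q_mean[i] += beta * (x - theta - 2 * q_mean[i])
        q_log_std[i] += beta * (1 - 2 * s ** 2)

sample_mean = sum(data) / len(data)
print(f"theta = {theta:.2f}, maximum-likelihood answer = {sample_mean:.2f}")
```

Note the per-datapoint lists `q_mean` and `q_log_std`: the variational parameters grow with the size of the dataset.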
There's just one problem... there are way too many parameters! In fact, their number scales with the size of the dataset: $|\theta| + (|\mu_i| + |\sigma_i|) \times N$. We'll address this issue next lecture; in particular, how deep learning makes this tractable :)
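For a rough sense of scale, take hypothetical sizes $|\theta| = 10^6$, a 32-dimensional latent space (so $|\mu_i| = |\sigma_i| = 32$ for a diagonal Gaussian), and $N = 10^6$ datapoints:

$$
|\theta| + (|\mu_i| + |\sigma_i|) \times N = 10^6 + (32 + 32) \times 10^6 = 6.5 \times 10^7
$$

With these numbers the per-datapoint variational parameters outnumber the model parameters by a factor of 64.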
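Tip: The Expectation-Maximization (EM) algorithm is actually a special case of variational inference! The E step corresponds to updating $q_i$, by setting it exactly to the posterior $p(z \mid x_i)$ so that the bound becomes tight, and the M step corresponds to updating $\theta$.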
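A sketch of EM on the hypothetical toy model $p(z) = \mathcal{N}(0, 1)$, $p_\theta(x \mid z) = \mathcal{N}(z + \theta, 1)$: the E step sets each $q_i$ to the exact posterior $p_\theta(z \mid x_i) = \mathcal{N}((x_i - \theta)/2, 1/2)$, and the M step maximizes $\sum_i \mathbb{E}_{z \sim q_i}[\log p_\theta(x_i \mid z)]$ over $\theta$, which for this model has a closed-form solution. The data are synthetic.

```python
import math
import random

# Hypothetical toy model: p(z) = N(0, 1), p_theta(x | z) = N(z + theta, 1).
rng = random.Random(0)
data = [rng.gauss(3.0, math.sqrt(2.0)) for _ in range(200)]

theta = 0.0
for _ in range(60):
    # E step: q_i = p_theta(z | x_i) = N((x_i - theta) / 2, 1/2), which makes the bound tight.
    # Only the posterior means are needed by the M step below.
    q_means = [(x - theta) / 2 for x in data]
    # M step: maximize sum_i E_{q_i}[-(x_i - z - theta)^2 / 2]; the maximizer is mean(x_i - E_{q_i}[z]).
    theta = sum(x - m for x, m in zip(data, q_means)) / len(data)

print(f"theta = {theta:.6f}, sample mean = {sum(data) / len(data):.6f}")
```

Here each iteration moves $\theta$ halfway toward the sample mean, so EM converges to the maximum-likelihood estimate without any per-datapoint gradient steps.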