Variational Inference

This article contains assorted notes on VI-related methods (not all of them strictly about VI).

Amortized Variational Inference (VAE example)

Suppose we have a generative model $p_\theta(z,x)=p_\theta(z)p_\theta(x|z)$ with $N$ observed instances $x_1,x_2,...,x_N$. Instead of maintaining $N$ separate variational distributions $q_1(z),q_2(z),...,q_N(z)$, one per instance, we can use amortized inference: a single $q_\phi(z|x)$ (e.g. a neural network parameterized by $\phi$ that takes $x$ as input and produces the distribution parameters for $z$).
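As a concrete sketch (assuming, hypothetically, a diagonal Gaussian $q_\phi(z|x)$ and PyTorch; the layer sizes are arbitrary), the amortized inference network could look like this:

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Amortized inference network: maps x to the parameters of q_phi(z|x)."""
    def __init__(self, x_dim=784, h_dim=256, z_dim=20):  # dimensions are arbitrary assumptions
        super().__init__()
        self.hidden = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU())
        self.mu = nn.Linear(h_dim, z_dim)       # mean of the diagonal Gaussian
        self.logvar = nn.Linear(h_dim, z_dim)   # log-variance of the diagonal Gaussian

    def forward(self, x):
        h = self.hidden(x)
        return self.mu(h), self.logvar(h)
```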

The goal of optimization is, given the true data distribution $p(x)$, for $p_\theta(x)$ to approximate $p(x)$ and for $q_\phi(z|x)$ to approximate the posterior $p_\theta(z|x)$.

The ELBO of $\log p_\theta(x)$ in this model is: \begin{equation} ELBO(x) = \log p_\theta(x)-\mathrm{KL}[q_\phi(z|x)||p_\theta(z|x)]= \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \mathrm{KL}[q_\phi(z|x)||p_\theta(z)] \end{equation}

Maximizing $\mathbb{E}_{p(x)}[ELBO(x)]$ achieves our goal, since \begin{equation} \begin{split} \mathbb{E}_{p(x)}[ELBO(x)]&=\mathbb{E}_{p(x)}[\log p_\theta(x)-\log p(x) + \log p(x) -\mathrm{KL}[q_\phi(z|x)||p_\theta(z|x)]] \\ &=\mathbb{E}_{p(x)}[\log p(x)]-\mathrm{KL}[p(x)||p_\theta(x)] - \mathbb{E}_{p(x)}[\mathrm{KL}[q_\phi(z|x)||p_\theta(z|x)]] \end{split} \end{equation} The first term $\mathbb{E}_{p(x)}[\log p(x)]$ does not depend on $\theta$ or $\phi$, so maximizing the expected ELBO simultaneously pushes both KL terms toward zero.

Therefore, the optimization problem is to maximize $\mathbb{E}_{p(x)}\left[\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \mathrm{KL}[q_\phi(z|x)||p_\theta(z)]\right]$ w.r.t. $\theta$ and $\phi$. In practice, we use minibatches of data to approximate the outer expectation.
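A minimal sketch of a single-sample minibatch ELBO estimate, assuming a Gaussian encoder as above, a standard normal prior, and a Bernoulli decoder; `encoder` and `decoder` are hypothetical modules returning $(\mu,\log\sigma^2)$ and Bernoulli logits respectively:

```python
import torch
import torch.nn.functional as F

def minibatch_elbo(x, encoder, decoder):
    """MC estimate of E_{q_phi(z|x)}[log p_theta(x|z)] - KL[q_phi(z|x) || p_theta(z)]."""
    mu, logvar = encoder(x)
    z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)   # one reparameterized sample
    logits = decoder(z)                                       # Bernoulli decoder parameters
    rec = -F.binary_cross_entropy_with_logits(logits, x, reduction='none').sum(-1)
    # KL between a diagonal Gaussian and the standard normal prior, in closed form
    kl = 0.5 * (mu**2 + logvar.exp() - 1.0 - logvar).sum(-1)
    return (rec - kl).mean()   # averaged over the minibatch; maximize this (minimize its negative)
```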

After training, we can use the generative model $p_\theta(z,x)$ to generate new data (e.g. images) by sampling $z \sim p_\theta(z)$ and then $x \sim p_\theta(x|z)$.
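For instance (a sketch only; the decoder below is a hypothetical, untrained stand-in used just to show the shapes), generation amounts to decoding prior samples:

```python
import torch
import torch.nn as nn

decoder = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 784))  # hypothetical decoder
z = torch.randn(16, 20)                  # 16 samples from the standard normal prior p_theta(z)
x_new = torch.sigmoid(decoder(z))        # Bernoulli means, e.g. 16 flattened 28x28 images
```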

Gradient estimation

In order to train the model described above, we need to estimate the gradient (w.r.t. $\phi$) of terms of the form $\mathbb{E}_{q_\phi(z)}[f(z)]$.

1. If $q_\phi(z)$ is reparameterizable, i.e. we can write $z=g(\epsilon,\phi)$ with $\epsilon \sim r(\epsilon)$ and $g$ differentiable in $\phi$:

Then we can rewrite $\mathbb{E}_{q_\phi(z)}[f(z)]$ as $\mathbb{E}_{r(\epsilon)}[f(g(\epsilon,\phi))]$. Since the integrand is now a deterministic, differentiable function of $\phi$, the gradient can be moved inside the expectation, $\nabla_\phi\mathbb{E}_{r(\epsilon)}[f(g(\epsilon,\phi))]=\mathbb{E}_{r(\epsilon)}[\nabla_\phi f(g(\epsilon,\phi))]$, and estimated easily by MC sampling, as in the sketch below.
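A minimal sketch of this pathwise (reparameterization) estimator for a diagonal Gaussian $q_\phi(z)$ with an arbitrary test function $f$; the dimensions and sample count are arbitrary assumptions:

```python
import torch

mu = torch.zeros(5, requires_grad=True)          # phi = (mu, log_sigma) of a 5-dim Gaussian
log_sigma = torch.zeros(5, requires_grad=True)

def f(z):                                        # arbitrary test function
    return (z ** 2).sum()

eps = torch.randn(1000, 5)                       # eps ~ r(eps) = N(0, I)
z = mu + log_sigma.exp() * eps                   # z = g(eps, phi), differentiable in phi
loss = f(z) / 1000                               # MC estimate of E_{q_phi}[f(z)]
loss.backward()                                  # autodiff gives the pathwise gradient estimate
print(mu.grad, log_sigma.grad)
```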

2. If $q_\phi(z)$ is not reparameterizable, we can use the REINFORCE (score-function) estimator.

Using the log-derivative trick, we have $\frac{\partial q_\phi(z)}{\partial \phi} = q_\phi(z)\frac{\partial \log q_\phi(z)}{\partial \phi}$

\begin{equation} \begin{split} \frac{\partial \mathbb{E}_{q_\phi(z)}[f(z)]}{\partial \phi} &= \frac{\partial \int q_\phi(z)f(z)dz}{\partial \phi}\\ &= \int \frac{\partial q_\phi(z)f(z)}{\partial \phi}dz\\ &= \int f(z)q_\phi(z)\frac{\partial \log q_\phi(z)}{\partial \phi}dz\\ &= \mathbb{E}_{q_\phi(z)}\left[f(z)\frac{\partial \log q_\phi(z)}{\partial \phi}\right] \end{split} \end{equation}

And this can be approximated with MC sampling.
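A sketch of the REINFORCE estimator for a case where reparameterization is not available, here a categorical $q_\phi(z)$ with an arbitrary $f$ (both are assumptions for illustration):

```python
import torch

logits = torch.zeros(4, requires_grad=True)      # phi: logits of a 4-way categorical q_phi(z)
f_values = torch.tensor([1.0, 2.0, 3.0, 4.0])    # arbitrary f(z) for each category

q = torch.distributions.Categorical(logits=logits)
z = q.sample((5000,))                            # MC samples from q_phi(z)
log_q = q.log_prob(z)                            # log q_phi(z)
# surrogate whose gradient is the REINFORCE estimate of E_q[f(z) * d log q_phi(z) / d phi]
surrogate = (f_values[z].detach() * log_q).mean()
surrogate.backward()
print(logits.grad)                               # estimates d E_{q_phi}[f(z)] / d phi
```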

Tighter bound with Importance Sampling

Start from an unbiased estimator of $p_\theta(x)=\mathbb{E}_{p_\theta(z)}[p_\theta(x|z)]$, obtained by importance sampling with proposal $q_\phi(z|x)$: \begin{equation} \frac{1}{K}\sum_{i=1}^K \frac{p_\theta(x,z^i)}{q_\phi(z^i|x)}, \quad z^1,z^2,...,z^K \sim q_\phi(z|x) \end{equation}

Since this is an unbiased estimator of $p_\theta(x)$, we have: \begin{equation} p_\theta(x) = \mathbb{E}_{z^{1..K}\sim q_\phi(z|x)}\left[\frac{1}{K}\sum_{i=1}^K \frac{p_\theta(x,z^i)}{q_\phi(z^i|x)}\right] \end{equation}

Using Jensen's inequality, we have: \begin{equation} \begin{split} \log p_\theta(x) &= \log \mathbb{E}_{z^{1..K}\sim q_\phi(z|x)}\left[\frac{1}{K}\sum_{i=1}^K \frac{p_\theta(x,z^i)}{q_\phi(z^i|x)}\right]\\ &\geq \mathbb{E}_{z^{1..K}\sim q_\phi(z|x)}\left[\log\frac{1}{K}\sum_{i=1}^K \frac{p_\theta(x,z^i)}{q_\phi(z^i|x)}\right]\\ &= ELBO \end{split} \end{equation}

Setting $K=1$ recovers the standard ELBO of variational inference described in the section above, and the bound becomes tighter as $K$ increases.
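A numerically stable sketch of this $K$-sample bound, assuming `log_p_xz` and `log_q_z` are tensors of shape (batch, K) holding $\log p_\theta(x,z^i)$ and $\log q_\phi(z^i|x)$ (hypothetical names):

```python
import math
import torch

def iwae_bound(log_p_xz, log_q_z):
    """MC estimate of log (1/K) sum_i p_theta(x,z^i)/q_phi(z^i|x), shape (batch, K) -> (batch,)."""
    log_w = log_p_xz - log_q_z                    # log importance weights
    K = log_w.shape[-1]
    return torch.logsumexp(log_w, dim=-1) - math.log(K)
```

With K=1 this reduces to the single-sample ELBO estimate from the previous section.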

Now we will estimate the gradient w.r.t. $\phi$ of this ELBO (note that the gradient w.r.t. $\theta$ is easy to compute). \begin{equation} \begin{split} &\nabla_\phi\mathbb{E}_{z^{1..K}\sim q_\phi(z|x)}\left[\log\frac{1}{K}\sum_{i=1}^K \frac{p_\theta(x,z^i)}{q_\phi(z^i|x)}\right]\\ &= \nabla_\phi\int \left(\log\frac{1}{K}\sum_{i=1}^K \frac{p_\theta(x,z^i)}{q_\phi(z^i|x)}\right) \prod_{i=1}^K q_\phi(z^i|x)\; dz^{1..K}\\ &= \int \nabla_\phi\left[\left(\log\frac{1}{K}\sum_{i=1}^K \frac{p_\theta(x,z^i)}{q_\phi(z^i|x)}\right) \prod_{i=1}^K q_\phi(z^i|x)\right] dz^{1..K}\\ &= \int \nabla_\phi\left[\log\frac{1}{K}\sum_{i=1}^K \frac{p_\theta(x,z^i)}{q_\phi(z^i|x)}\right] \prod_{i=1}^K q_\phi(z^i|x) + \left(\log\frac{1}{K}\sum_{i=1}^K \frac{p_\theta(x,z^i)}{q_\phi(z^i|x)}\right) \nabla_\phi\prod_{i=1}^K q_\phi(z^i|x) \; dz^{1..K}\\ &= \int \nabla_\phi\left[\log\frac{1}{K}\sum_{i=1}^K \frac{p_\theta(x,z^i)}{q_\phi(z^i|x)}\right] \prod_{i=1}^K q_\phi(z^i|x) + \left(\log\frac{1}{K}\sum_{i=1}^K \frac{p_\theta(x,z^i)}{q_\phi(z^i|x)}\right) \prod_{i=1}^K q_\phi(z^i|x)\, \nabla_\phi\log\prod_{i=1}^K q_\phi(z^i|x) \; dz^{1..K}\\ &= \mathbb{E}_{z^{1..K}\sim q_\phi(z|x)}\left[\nabla_\phi\log\frac{1}{K}\sum_{i=1}^K \frac{p_\theta(x,z^i)}{q_\phi(z^i|x)} + \log\frac{1}{K}\sum_{i=1}^K \frac{p_\theta(x,z^i)}{q_\phi(z^i|x)}\, \nabla_\phi\log\prod_{i=1}^K q_\phi(z^i|x)\right] \end{split} \end{equation}

Then, the REINFORCE estimator of the gradient is \begin{equation} \nabla_\phi\log\frac{1}{K}\sum_{i=1}^K f_{\phi,\theta}(x,z^i) + \log\frac{1}{K}\sum_{i=1}^K f_{\phi,\theta}(x,z^i) \nabla_\phi\log\prod_{i=1}^K q_\phi(z^i|x) \end{equation} where $f_{\phi,\theta}(x,z^i) = \frac{p_\theta(x,z^i)}{q_\phi(z^i|x)}$

The first term: \begin{equation} \begin{split} &\nabla_\phi\log\frac{1}{K}\sum_{i=1}^K f_{\phi,\theta}(x,z^i)\\ &=\frac{1}{\sum_{i=1}^K f_{\phi,\theta}(x,z^i)}\sum_{i=1}^K \nabla_\phi f_{\phi,\theta}(x,z^i)\\ &=\sum_{i=1}^K w^i\nabla_\phi \log f_{\phi,\theta}(x,z^i) \end{split} \end{equation} where $w^i=\frac{f_{\phi,\theta}(x,z^i)}{\sum_{k=1}^K f_{\phi,\theta}(x,z^k)}$
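In practice the normalized weights $w^i$ are computed as a softmax over the log importance weights, which avoids overflow; a small sketch with arbitrary log-weights:

```python
import torch

log_w = torch.randn(8)                # hypothetical log f_{phi,theta}(x, z^i) for K = 8 samples
w = torch.softmax(log_w, dim=-1)      # w^i = f^i / sum_k f^k, computed stably in log space
assert torch.isclose(w.sum(), torch.tensor(1.0))
```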

The second term: \begin{equation} \begin{split} &\log\frac{1}{K}\sum_{i=1}^K f_{\phi,\theta}(x,z^i) \sum_{i=1}^K\nabla_\phi\log q_\phi(z^i|x)\\ &=\sum_{k=1}^K\left[\log\frac{1}{K}\sum_{i=1}^K f_{\phi,\theta}(x,z^i) \nabla_\phi\log q_\phi(z^k|x)\right] \end{split} \end{equation}

The first term behaves nicely: it is a convex combination of the per-sample gradients, with non-negative weights $w^i$ that sum to 1, so it is bounded by the largest (and smallest) gradient component. In the second term, however, every per-sample gradient $\nabla_\phi\log q_\phi(z^k|x)$ is multiplied by the same unbounded scalar $\log\frac{1}{K}\sum_i f_{\phi,\theta}(x,z^i)$. This term can therefore have high variance, and additional measures (e.g. control variates) should be taken to reduce it.
