Bayesian Inference (in Deep Learning)

In this note we explore several types of Bayesian inference, and in particular how they are applied in deep learning.

Assume that we have a dataset $D=\{x_i,y_i\}_{i=1}^{N}$ and the following generative model (denoted by $p$ below), where $\theta$ is typically the parameter vector of a neural network.

(Figure: graphical model of the generative process)

Assume that our data is generated from this model and we want to infer the posterior $p(\theta|D)$. By Bayes' rule, $p(\theta|D)=\frac{p(D|\theta)p(\theta)}{p(D)}$, which is intractable because the marginal likelihood $p(D)=\int p(D|\theta)p(\theta)\,d\theta$ cannot be computed for a neural network. Some (but not all) of the available inference methods are listed below.

Variational Inference

The first method is variational inference. Suppose that we want to approximate $p(\theta|D)$ by $q_\phi(\theta)$, where $\phi$ denotes the parameters of the variational distribution. We need to minimize the divergence from $q_\phi(\theta)$ to $p(\theta|D)$, and a popular choice is the KL divergence from $q$ to $p$.

\begin{equation} \begin{split} &\mathrm{KL}\left[q_\phi(\theta)||p(\theta|D)\right] = \mathbb{E}_{q_\phi(\theta)}\left[\log q_\phi(\theta)-\log p(\theta|D)\right]\\ &=\mathbb{E}_{q_\phi(\theta)}\left[\log q_\phi(\theta)-\log p(D|\theta)-\log p(\theta)+\log p(D)\right]\\ &=\mathbb{E}_{q_\phi(\theta)}\left[\log q_\phi(\theta)-\log p(\theta)\right] - \mathbb{E}_{q_\phi(\theta)}\left[\log p(D|\theta)\right] + constant \quad \text{since }p(D) \text{ is fixed w.r.t. }\phi\\ &=\mathrm{KL}\left[q_\phi(\theta)||p(\theta)\right] - \mathbb{E}_{q_\phi(\theta)}\left[\log p(D|\theta)\right] + constant \end{split} \end{equation}

Therefore, we need to maximize $\mathbb{E}_{q_\phi(\theta)}\left[\log p(D|\theta)\right]-\mathrm{KL}\left[q_\phi(\theta)||p(\theta)\right]$ (the evidence lower bound, or ELBO). The first term ensures that $q_\phi(\theta)$ explains the data well, while the second term encourages $q_\phi(\theta)$ to stay close to the prior.

If the KL term can be computed analytically (e.g. when both $q_\phi$ and the prior are Gaussian), we can use it directly in our computation. Otherwise, we can resort to MC sampling to approximate $\mathbb{E}_{q_\phi(\theta)}\left[\log q_\phi(\theta)-\log p(\theta)\right]$.

If $q_\phi$ is reparameterizable, we can use the reparameterization trick to compute low-variance gradients w.r.t. $\phi$. Otherwise, we can use other estimators such as REINFORCE. For details, see my note on Variational Inference.
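As a concrete illustration, here is a minimal sketch of maximizing the ELBO with the reparameterization trick, assuming PyTorch, a toy 1-D linear model, and a mean-field Gaussian $q_\phi$ (all names and hyperparameters are illustrative, not part of the original note):

```python
import torch

# Toy data: y = 2x + noise
torch.manual_seed(0)
x = torch.linspace(-1, 1, 100).unsqueeze(1)
y = 2 * x + 0.1 * torch.randn_like(x)

# Mean-field Gaussian q_phi(theta) over a single weight and bias
mu = torch.zeros(2, requires_grad=True)         # variational means
log_sigma = torch.zeros(2, requires_grad=True)  # variational log std-devs
prior_std = 1.0                                 # p(theta) = N(0, prior_std^2 I)
noise_std = 0.1                                 # likelihood noise

opt = torch.optim.Adam([mu, log_sigma], lr=0.05)

for step in range(2000):
    opt.zero_grad()
    sigma = log_sigma.exp()
    # Reparameterization trick: theta = mu + sigma * eps, eps ~ N(0, I)
    eps = torch.randn(2)
    theta = mu + sigma * eps
    w, b = theta[0], theta[1]

    # Expected log-likelihood, estimated with a single MC sample of theta
    pred = w * x + b
    log_lik = torch.distributions.Normal(pred, noise_std).log_prob(y).sum()

    # KL[q_phi(theta) || p(theta)] between diagonal Gaussians (closed form)
    kl = (torch.log(prior_std / sigma)
          + (sigma**2 + mu**2) / (2 * prior_std**2) - 0.5).sum()

    loss = -(log_lik - kl)   # negative ELBO
    loss.backward()
    opt.step()

print("variational mean of (w, b):", mu.detach())
```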

Monte Carlo Sampling

Another method is to use MC sampling to draw samples from $p(\theta|D)$ (even though we may not know its exact pdf). Say we are interested in the prediction for a new point $x^*$, i.e. $\mathbb{E}_{p(\theta|D)}\left[p(y^*|x^*,\theta)\right]$; if we can somehow sample from $p(\theta|D)$, we can use these samples to approximate the expectation.

Note: These MC methods typically suffer from the curse of dimensionality

Direct Sampling

If we can do direct sampling efficiently and effectively from the desired distribution (not applicable to the posterior of a neural network, though), then we can use these samples to approximate the expectation: \begin{equation} \mathbb{E}_{p(x)}[f(x)] \approx \frac{1}{S}\sum_{i=1}^Sf(x_i), \quad x_i \sim p(x) \end{equation}
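For example, a minimal numpy sketch with $p(x)$ a standard normal and $f(x)=x^2$ (both chosen purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
S = 100_000
x = rng.standard_normal(S)   # direct samples from p(x) = N(0, 1)
estimate = np.mean(x**2)     # MC estimate of E_p[f(x)] with f(x) = x^2
print(estimate)              # ~1.0, the true second moment
```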

Rejection Sampling

Suppose $p(x)=p'(x)/Z$, where $Z$ is the normalizing constant and is hard to compute, while $p'(x)$ is easy to compute (for the posterior in deep learning, $p'(x)$ is the product of likelihood and prior, and $Z$ is $p(D)$).

Let $q(x)$ be an "arbitrary" proposal distribution whose pdf we can evaluate and which is easy to sample from.

Choose a constant $M$ s.t. $\frac{p'(x)}{q(x)}\le M \quad \forall x$.

Then the rejection sampling algorithm is:

  • Sample $x$ from $q(x)$
  • With probability $\frac{p'(x)}{Mq(x)}$, accept $x$.

Let $h(x)$ denote the distribution of the accepted samples. We have: \begin{align} h(x) = \frac{q(x)\left(p'(x)/Mq(x)\right)}{\int q(x')\left(p'(x')/Mq(x')\right) dx'} =\frac{p'(x)}{\int p'(x') dx'} =p(x) \end{align}

Note: In practice, rejection sampling can have a very high rejection rate if $q(x)$ differs too much from $p(x)$. In high dimensions this makes the method very inefficient.
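A minimal numpy sketch, assuming a 1-D unnormalized target $p'(x)$ (a bimodal mixture of Gaussian bumps, chosen only for illustration) and a broad Gaussian proposal; the constant $M$ is picked by hand so that $p'(x) \le M q(x)$:

```python
import numpy as np

rng = np.random.default_rng(0)

def p_unnorm(x):   # unnormalized target p'(x): two Gaussian bumps at +/- 2
    return np.exp(-0.5 * (x - 2)**2) + np.exp(-0.5 * (x + 2)**2)

def q_pdf(x):      # proposal q(x) = N(0, 3^2), easy to sample and evaluate
    return np.exp(-0.5 * (x / 3)**2) / (3 * np.sqrt(2 * np.pi))

M = 12.0           # chosen so that p'(x) <= M * q(x) for all x

samples = []
for _ in range(50_000):
    x = 3.0 * rng.standard_normal()                    # draw x ~ q
    if rng.uniform() < p_unnorm(x) / (M * q_pdf(x)):   # accept w.p. p'(x) / (M q(x))
        samples.append(x)

print(f"accepted {len(samples)} of 50000")  # acceptance rate drops as q and p' mismatch grows
```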

Importance Sampling

Again, suppose $p(x)=p'(x)/Z$ where $Z$ is the normalizing constant. Let $q(x)$ be a proposal distribution that dominates $p$ (i.e. $p(x)>0 \Rightarrow q(x)>0$).

\begin{equation} \begin{split} \mathbb{E}_{p(x)}[f(x)] = \int f(x)p(x)dx =\int f(x)\frac{p(x)}{q(x)}q(x)dx =\mathbb{E}_{q(x)}\left[f(x)\frac{p(x)}{q(x)}\right] =\frac{1}{Z}\mathbb{E}_{q(x)}\left[f(x)\frac{p'(x)}{q(x)}\right] \end{split} \end{equation}

Let $x_1,x_2,\dots,x_S$ be samples from $q(x)$ and define the weight $m_i=\frac{p'(x_i)}{q(x_i)}$. Then, since $Z = \int p'(x)dx = \mathbb{E}_{q(x)} \left[\frac{p'(x)}{q(x)}\right]$,

\begin{equation} \begin{split} Z &\approx \frac{\sum_{i=1}^{S} m_i}{S}\\ \mathbb{E}_{q(x)}\left[f(x)\frac{p'(x)}{q(x)}\right] &\approx \frac{\sum_{i=1}^S f(x_i)m_i}{S}\\ \\ \implies\mathbb{E}_{p(x)}[f(x)] &\approx \sum_{i=1}^S f(x_i)w_i \end{split} \end{equation}

where $w_i=\frac{m_i}{\sum_{j=1}^{S} m_j}$ is the normalized version of $m_i$.

Note: When $q(x)$ and $p(x)$ are very different (their mass is concentrated in different regions), the approximation can be highly erroneous.
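A minimal self-normalized importance sampling sketch in numpy, reusing the same illustrative unnormalized target and Gaussian proposal as in the rejection sampling sketch, and estimating $\mathbb{E}_{p(x)}[x^2]$:

```python
import numpy as np

rng = np.random.default_rng(0)

def p_unnorm(x):   # unnormalized target p'(x)
    return np.exp(-0.5 * (x - 2)**2) + np.exp(-0.5 * (x + 2)**2)

def q_pdf(x):      # proposal q(x) = N(0, 3^2)
    return np.exp(-0.5 * (x / 3)**2) / (3 * np.sqrt(2 * np.pi))

S = 100_000
x = 3.0 * rng.standard_normal(S)   # x_i ~ q
m = p_unnorm(x) / q_pdf(x)         # unnormalized weights m_i = p'(x_i) / q(x_i)
w = m / m.sum()                    # self-normalized weights w_i
estimate = np.sum(w * x**2)        # estimate of E_p[f(x)] with f(x) = x^2
print(estimate)                    # ~5.0 for this bimodal target (1 + 2^2)
```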

Weighted Resampling

Follow the same procedure as importance sampling to obtain $x_1,x_2,\dots,x_S$ with normalized weights $w_1,w_2,\dots,w_S$, then resample $x$ from $\{x_1,x_2,\dots,x_S\}$ with probabilities $w_1,w_2,\dots,w_S$.
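Continuing from the importance sampling sketch above (reusing its `rng`, `x`, and `w`), resampling turns the weighted sample into an unweighted one:

```python
# Resample points from {x_i} with probabilities {w_i}
resampled = rng.choice(x, size=10_000, replace=True, p=w)
print(resampled.mean(), (resampled**2).mean())   # approximately 0 and 5 for this target
```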

Markov Chain Monte Carlo Sampling

The idea behind MCMC sampling is that (consider a discrete variable for now), if we have a Markov chain whose equilibrium distribution is the distribution of interest $p(x)$, then by repeatedly sampling transitions of that chain we will eventually be drawing samples from the true distribution.

Suppose that we have $N$ states and the transition matrix of the Markov chain is $T$ ($N \times N$). If the chain is ergodic (irreducible, aperiodic, and positive recurrent), it is guaranteed to converge to a unique stationary distribution.

Detailed balance condition: $p(x)T(x'|x)=p(x')T(x|x')$, i.e. the probability of being in state $x$ and transitioning to state $x'$ must equal the probability of being in state $x'$ and transitioning to state $x$. This is a sufficient (but not necessary) condition for $p$ to be a stationary distribution of the chain.

For continuous variables, we instead use a transition kernel $T$ and $p^{t+1}(x)=\int T(x',x)p^t(x')dx'$.

High-level idea: MCMC algorithms construct a Markov chain whose stationary distribution is the desired distribution. One of the most popular algorithms is Metropolis-Hastings.

Metropolis-Hastings algorithm

  1. Initialization: an arbitrary starting point $x_1$ and a proposal $q(x'|x)$
  2. At step $t$:
    • Sample a point $x'$ from $q(x'|x^t)$.
    • $a = \min\left(1,\frac{p'(x')q(x^t|x')}{p'(x^t)q(x'|x^t)}\right)$
    • With probability $a$ accept $x'$ ($x^{t+1}=x'$) or else reject it ($x^{t+1}=x^t$)

Note that this algorithm only requires the unnormalized pdf $p'(x)$. Also, if $q$ is symmetric ($q(x'|x)=q(x|x')$), we only need to compute the ratio $p'(x')/p'(x^t)$.

We have: \begin{equation} \begin{split} p(x)T(x'|x)&=p(x)q(x'|x)\min\left(1,\frac{p'(x')q(x|x')}{p'(x)q(x'|x)}\right)\\ &=\frac{1}{Z}\min\left(p'(x)q(x'|x),p'(x')q(x|x')\right)\\ &=p(x')T(x|x')\\ \implies &\text{this Markov chain satisfies the detailed balance condition}\\ \implies &\text{the sampling distribution will converge to } p(x) \end{split} \end{equation}
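A minimal random-walk Metropolis sketch in numpy (a symmetric Gaussian proposal, so the $q$ ratio cancels), targeting the same illustrative unnormalized bimodal density used in the earlier sampling sketches:

```python
import numpy as np

rng = np.random.default_rng(0)

def log_p_unnorm(x):   # log of unnormalized target p'(x)
    return np.logaddexp(-0.5 * (x - 2)**2, -0.5 * (x + 2)**2)

x = 0.0                # arbitrary starting point
step = 1.0             # proposal std for q(x'|x) = N(x, step^2), symmetric
samples = []
for t in range(50_000):
    x_prop = x + step * rng.standard_normal()        # propose x' ~ q(x'|x)
    log_a = log_p_unnorm(x_prop) - log_p_unnorm(x)   # log acceptance ratio (q cancels)
    if np.log(rng.uniform()) < log_a:                # accept with prob. min(1, a)
        x = x_prop
    samples.append(x)

burned = np.array(samples[10_000:])   # discard burn-in before using the samples
print(burned.mean(), (burned**2).mean())
```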

Note: For advanced sampling methods such as HMC and Langevin dynamics, see my note on Advanced MCMC methods.

Laplace approximation

The high-level idea of the Laplace approximation is to approximate the posterior $p(\theta|D)$ around its MAP (maximum a posteriori) estimate by a Gaussian obtained from a second-order Taylor expansion of the log posterior.

Recall that $\theta_{MAP}=\arg\max_\theta \log p(\theta|D) = \arg\max_\theta \left(\log p(D|\theta) + \log p(\theta)\right)$

This MAP estimate can be obtained by training the deterministic network with an appropriate regularizer, using gradient descent. For example, if $p(\theta)=\mathcal{N}\left(\theta|0,\delta^{-1}\mathbf{I}\right)$, the corresponding regularizer is L2 (weight decay) with coefficient $\delta/2$.

The Laplace approximation of $p(\theta|D)$ around $\theta_{MAP}$ is: \begin{equation} \begin{split} \log p(\theta|D) &\approx \log p(\theta_{MAP}|D) + \frac{1}{2}(\theta-\theta_{MAP})^T\left[\nabla^2_{\theta}\log p(\theta|D)\right]_{\theta_{MAP}}(\theta-\theta_{MAP}) \quad\text{since }\left[\nabla_\theta\log p(\theta|D)\right]_{\theta_{MAP}}=0\\ &=\log p(\theta_{MAP}|D) + \frac{1}{2}(\theta-\theta_{MAP})^T\left[\nabla^2_{\theta}\left(\log p(D|\theta)+\log p(\theta)\right)\right]_{\theta_{MAP}}(\theta-\theta_{MAP})\\ &=\log p(\theta_{MAP}|D) + \frac{1}{2}(\theta-\theta_{MAP})^T\left[\nabla^2_{\theta}\left(\log p(D|\theta)-\frac{1}{2}\theta^T\delta\mathbf{I}\theta\right)\right]_{\theta_{MAP}}(\theta-\theta_{MAP})\\ &=\log p(\theta_{MAP}|D) - \frac{1}{2}(\theta-\theta_{MAP})^T\left(\left[\nabla^2_{\theta}\left(-\log p(D|\theta)\right)\right]_{\theta_{MAP}}+\delta\mathbf{I}\right)(\theta-\theta_{MAP}) \end{split} \end{equation}

Up to an additive constant, this is the log-pdf of $\mathcal{N}\left(\theta_{MAP},\mathbf{\Sigma}\right)$ where $\mathbf{\Sigma}^{-1}=\left[\nabla^2_{\theta}\left(-\log p(D|\theta)\right)\right]_{\theta_{MAP}}+\delta\mathbf{I}$

Therefore, $\mathcal{N}\left(\theta_{MAP},\mathbf{\Sigma}\right)$ is a Gaussian approximation of $p(\theta|D)$.
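A minimal sketch of the Laplace approximation for a small Bayesian logistic regression (numpy only; the data, learning rate, and iteration count are illustrative, and for a real neural network the Hessian would typically be replaced by a diagonal or Kronecker-factored approximation):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy 2-D logistic regression data
N, delta = 200, 1.0                      # delta = prior precision, p(theta) = N(0, delta^{-1} I)
X = rng.standard_normal((N, 2))
true_theta = np.array([1.5, -2.0])
y = (rng.uniform(size=N) < 1 / (1 + np.exp(-X @ true_theta))).astype(float)

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# 1) MAP estimate via gradient descent on the L2-regularized negative log-likelihood
theta = np.zeros(2)
for _ in range(500):
    p = sigmoid(X @ theta)
    grad = X.T @ (p - y) + delta * theta     # gradient of -log p(D|theta) - log p(theta)
    theta -= 0.1 * grad / N
theta_map = theta

# 2) Hessian of -log p(D|theta) at the MAP, plus the prior precision
p = sigmoid(X @ theta_map)
H_lik = X.T @ (X * (p * (1 - p))[:, None])   # Hessian of the negative log-likelihood
Sigma = np.linalg.inv(H_lik + delta * np.eye(2))

print("theta_MAP:", theta_map)
print("posterior covariance:\n", Sigma)      # p(theta|D) ~= N(theta_MAP, Sigma)
```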
