Summary by Chin-Wei 3 months ago
In this note, I'll implement the [Stochastically Unbiased Marginalization Objective (SUMO)](https://openreview.net/forum?id=SylkYeHtwr) to estimate the log-partition function of an energy function.
Estimating the log-partition function has many important applications in machine learning, for example in latent variable models and Bayesian inference. The log-partition function of the posterior distribution $$p(z|x)=\frac{1}{Z}p(x|z)p(z)$$ is the log-marginal likelihood of the data, $$\log Z = \log \int p(x|z)p(z)dz = \log p(x)$$.
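As a small numerical illustration of this identity, here is a sketch on a hypothetical conjugate toy model (not part of the original note): with $z\sim\mathcal{N}(0,1)$ and $x|z\sim\mathcal{N}(z,1)$, the marginal is $p(x)=\mathcal{N}(x;0,2)$, so integrating the unnormalized posterior $p(x|z)p(z)$ over $z$ on a grid should recover $\log p(x)$. The observed value $x=1.3$ and the grid are arbitrary choices.

```python
import numpy as np
from scipy.stats import norm

# Hypothetical toy model: z ~ N(0, 1), x | z ~ N(z, 1), hence p(x) = N(x; 0, 2)
x = 1.3
z = np.linspace(-15, 15, 30001)
dz = z[1] - z[0]

# Unnormalized posterior p(x|z) p(z) on the grid; its integral over z is Z
unnormalized_posterior = norm.pdf(x, loc=z, scale=1.0) * norm.pdf(z)
log_Z = np.log(unnormalized_posterior.sum() * dz)

# Closed-form log-marginal likelihood for comparison
log_px = norm.logpdf(x, loc=0.0, scale=np.sqrt(2.0))
print("log Z: %.6f / log p(x): %.6f" % (log_Z, log_px))
```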
More generally, let $U(x)$ be an energy function that induces the density $p(x)=\frac{e^{-U(x)}}{\int e^{-U(x)} dx}$. A common practice is to work with a variational form of the log-partition function,
$$
\log Z = \log \int e^{-U(x)}dx = \max_{q(x)}\mathbb{E}_{q(x)}[-U(x)-\log q(x)] \nonumber
$$
Plugging in any $q$ other than $p$ yields a strict lower bound, which means that the Monte Carlo average
$$
\frac{1}{n}\sum_{i=1}^n \left(-U(x_i) - \log q(x_i)\right) \nonumber
$$
for $x_i$ sampled *i.i.d.* from $q$, is a biased estimate of $\log Z$. In particular, it underestimates $\log Z$ in expectation.
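The size of the gap follows from writing $-U(x)=\log p(x)+\log Z$, which is just the definition of $p$ rearranged:

$$
\mathbb{E}_{q(x)}\left[-U(x)-\log q(x)\right] = \log Z + \mathbb{E}_{q(x)}\left[\log p(x) - \log q(x)\right] = \log Z - \mathrm{KL}(q\,\|\,p) \nonumber
$$

Since $\mathrm{KL}(q\|p)\geq 0$ with equality only when $q=p$, the maximum in the variational form is attained at $q=p$, and for any other $q$ the expected estimate falls short of $\log Z$ by exactly $\mathrm{KL}(q\|p)$.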
To see this, let's define the energy function $U$ as follows:
$$
U(x_1,x_2)= - \log \left(\frac{1}{2}\cdot e^{-\frac{(x_1+2)^2 + x_2^2}{2}} + \frac{1}{2}\cdot\frac{1}{4}e^{-\frac{(x_1-2)^2 + x_2^2}{8}}\right) \nonumber
$$
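It is not hard to see that $U$ is the energy function of the mixture of Gaussians $\frac{1}{2}\mathcal{N}([-2,0], I) + \frac{1}{2}\mathcal{N}([2,0], 4I)$: a 2-D Gaussian with covariance $\Sigma$ has normalizer $\frac{1}{2\pi\det(\Sigma)^{1/2}}$, which is $\frac{1}{2\pi}$ for $I$ and $\frac{1}{2\pi\cdot 4}$ for $4I$, so $e^{-U(x)}$ is exactly $2\pi$ times the mixture density. The normalizing constant is therefore $Z=2\pi\approx6.28$, with $\log Z\approx1.8379$.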
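```python
import numpy as np

def U(x):
    # Energy with exp(-U(x)) = 2*pi * p(x), where p = 0.5*N([-2,0], I) + 0.5*N([2,0], 4I)
    x1 = x[:, 0]
    x2 = x[:, 1]
    d2 = x2 ** 2
    return -np.log(np.exp(-((x1 + 2) ** 2 + d2) / 2) / 2 + np.exp(-((x1 - 2) ** 2 + d2) / 8) / 4 / 2)
```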
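Before sampling anything, the claim $Z=2\pi$ can be checked by brute-force quadrature. This is a sketch: it re-defines $U$ in an equivalent log-space form with `np.logaddexp` (an assumption on my part, to avoid underflow far from the modes) and uses a plain Riemann sum over $[-20,20]^2$ with spacing $0.05$, a range and resolution chosen so that the truncated tails and discretization error are negligible.

```python
import numpy as np

def U(x):
    # Same energy as above, computed in log space for numerical stability
    x1, x2 = x[:, 0], x[:, 1]
    log_a = -((x1 + 2) ** 2 + x2 ** 2) / 2 + np.log(1 / 2)   # 1/2 * N([-2,0], I) term (unnormalized)
    log_b = -((x1 - 2) ** 2 + x2 ** 2) / 8 + np.log(1 / 8)   # 1/2 * 1/4 * N([2,0], 4I) term
    return -np.logaddexp(log_a, log_b)

# Riemann sum of exp(-U) on a uniform grid over [-20, 20]^2
grid = np.linspace(-20, 20, 801)
dx = grid[1] - grid[0]
G1, G2 = np.meshgrid(grid, grid)
points = np.stack([G1.ravel(), G2.ravel()], axis=1)
Z_grid = np.exp(-U(points)).sum() * dx ** 2

print("Z (grid): %.4f / 2*pi: %.4f" % (Z_grid, 2 * np.pi))
print("log Z (grid): %.4f" % np.log(Z_grid))
```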
To visualize the density $p(x)\propto e^{-U(x)}$ corresponding to the energy:
```python
import matplotlib.pyplot as plt

# Evaluate exp(-U) on a 200 x 200 grid over [-5, 5]^2
xx = np.linspace(-5, 5, 200)
yy = np.linspace(-5, 5, 200)
X = np.meshgrid(xx, yy)
X = np.concatenate([X[0][:, :, None], X[1][:, :, None]], 2).reshape(-1, 2)  # (40000, 2) grid points
unnormalized_density = np.exp(-U(X)).reshape(200, 200)
plt.imshow(unnormalized_density)
plt.axis('off')
```
https://i.imgur.com/CZSyIQp.png
As a sanity check, let's also visualize the density of the mixture of Gaussians and check that $e^{-U(x)}$ divided by it equals $2\pi$ everywhere on the grid.
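```python
from scipy.stats import multivariate_normal as mvn

N1, N2 = mvn([-2, 0], 1), mvn([2, 0], 4)   # a scalar cov c is interpreted as c * I
density = (np.exp(N1.logpdf(X)) / 2 + np.exp(N2.logpdf(X)) / 2).reshape(200, 200)
plt.imshow(density)
plt.axis('off')
# The ratio exp(-U) / p should equal Z = 2*pi at every grid point
print(np.allclose(unnormalized_density / density - 2 * np.pi, 0))
```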
`True`
https://i.imgur.com/g4inQxB.png
Now if we estimate the log-partition function with a Monte Carlo average of the variational lower bound, using the proposal $q=\mathcal{N}([0,0], 5I)$, we get
```python
q = mvn([0, 0], 5)              # proposal q = N([0,0], 5I)
xs = q.rvs(10000 * 5)           # 50,000 samples from q
elbo = -U(xs) - q.logpdf(xs)    # single-sample ELBO estimates
plt.hist(elbo, range(-5, 10))   # integer bin edges from -5 to 9
print("Estimate: %.4f / Ground truth: %.4f" % (elbo.mean(), np.log(2 * np.pi)))
print("Empirical variance: %.4f" % elbo.var())
```
`Estimate: 1.4595 / Ground truth: 1.8379`
`Empirical variance: 0.9921`
https://i.imgur.com/vFzutuY.png
The lower bound can be tightened via importance sampling, using $K$ samples $x_1,\dots,x_K$ drawn *i.i.d.* from $q$:
$$
\log \int e^{-U(x)} dx \geq \mathbb{E}_{q^K}\left[\log\left(\frac{1}{K}\sum_{j=1}^K \frac{e^{-U(x_j)}}{q(x_j)}\right)\right] \nonumber
$$
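The inequality itself is Jensen's inequality for the concave $\log$, combined with the fact that each importance weight has expectation $Z$ under $q$ (assuming $q(x)>0$ wherever $e^{-U(x)}>0$):

$$
\mathbb{E}_{q^K}\left[\log\left(\frac{1}{K}\sum_{j=1}^K \frac{e^{-U(x_j)}}{q(x_j)}\right)\right] \leq \log \mathbb{E}_{q^K}\left[\frac{1}{K}\sum_{j=1}^K \frac{e^{-U(x_j)}}{q(x_j)}\right] = \log \int q(x)\frac{e^{-U(x)}}{q(x)}dx = \log Z \nonumber
$$

For $K=1$ this reduces to the ELBO above.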
> This bound gets tighter as $K$ grows, partly due to the [concentration of the average](https://arxiv.org/pdf/1906.03708.pdf) inside the $\log$: the less the average fluctuates, the more accurate a local linear approximation of the $\log$ around its mean becomes, because less probability mass lies outside a neighborhood of the mean.
Now let's use this new estimator with $K=5$:
```python
k = 5
xs = q.rvs(10000 * k)                   # 10,000 groups of k samples
elbo = -U(xs) - q.logpdf(xs)            # log importance weights
iwlb = elbo.reshape(10000, k)
iwlb = np.log(np.exp(iwlb).mean(1))     # IWAE_k for each group
plt.hist(iwlb, range(-5, 10))
print("Estimate: %.4f / Ground truth: %.4f" % (iwlb.mean(), np.log(2 * np.pi)))
print("Empirical variance: %.4f" % iwlb.var())
```
`Estimate: 1.7616 / Ground truth: 1.8379`
`Empirical variance: 0.1544`
https://i.imgur.com/sCcsQd4.png
We see that both the bias (about $0.38$ for the ELBO versus $0.08$ here) and the variance ($0.99$ versus $0.15$) decrease.
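To see the trend beyond $K=5$, here is a sketch that repeats the experiment for several values of $K$. It is self-contained, so it re-defines $U$ in an equivalent log-space form and draws from $q=\mathcal{N}([0,0],5I)$ with NumPy's `default_rng`; the seed, the list of $K$ values and the 2,000 replicates per $K$ are arbitrary choices, and `scipy.special.logsumexp` replaces `np.log(np.exp(...).mean(1))` for numerical stability.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal as mvn

def U(x):
    # Same mixture energy as above, computed in log space
    x1, x2 = x[:, 0], x[:, 1]
    log_a = -((x1 + 2) ** 2 + x2 ** 2) / 2 + np.log(1 / 2)
    log_b = -((x1 - 2) ** 2 + x2 ** 2) / 8 + np.log(1 / 8)
    return -np.logaddexp(log_a, log_b)

rng = np.random.default_rng(0)
q = mvn([0, 0], 5)      # proposal q = N([0,0], 5I), used here only for logpdf
n_rep = 2000            # number of independent IWAE_K estimates per K

results = {}
for K in [1, 5, 20, 100]:
    # n_rep * K samples from q = N(0, 5I): each coordinate has standard deviation sqrt(5)
    xs = np.sqrt(5) * rng.standard_normal((n_rep * K, 2))
    logw = (-U(xs) - q.logpdf(xs)).reshape(n_rep, K)
    # IWAE_K = log(mean of the K importance weights), computed in log space
    iwae = logsumexp(logw, axis=1) - np.log(K)
    results[K] = (iwae.mean(), iwae.var())
    print("K = %3d  estimate: %.4f  variance: %.4f" % (K, *results[K]))
```

The estimate should approach $\log Z\approx1.8379$ as $K$ grows (from below, up to Monte Carlo noise), and the variance should shrink.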
Finally, we use the [Stochastically Unbiased Marginalization Objective](https://openreview.net/pdf?id=SylkYeHtwr) (SUMO), which uses the *Russian roulette* estimator to randomly truncate a telescoping series that converges to the log-partition function, while keeping the estimate unbiased. Let $\text{IWAE}_K = \log\left(\frac{1}{K}\sum_{j=1}^K \frac{e^{-U(x_j)}}{q(x_j)}\right)$ be the importance-weighted estimator computed from the first $K$ samples of a sequence $x_1, x_2, \dots$ drawn *i.i.d.* from $q$, and let $\Delta_k = \text{IWAE}_{k+1} - \text{IWAE}_k$ be the difference (which can be thought of as a correction term). The SUMO estimator is defined as
$$
\text{SUMO} = \text{IWAE}_1 + \sum_{k=1}^K \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \nonumber
$$
where $K\sim p(K)=\mathbb{P}(\mathcal{K}=K)$ is a random positive integer drawn independently of the samples $x_j$. To see why this estimator is unbiased, first condition on the samples:
$$
\begin{align*}
\mathbb{E}[\text{SUMO}] &= \mathbb{E}\left[\text{IWAE}_1 + \sum_{k=1}^K \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \right] \nonumber\\
&= \mathbb{E}_{x_1,x_2,\dots}\left[\text{IWAE}_1 + \mathbb{E}_{K}\left[\sum_{k=1}^K \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \right]\right] \nonumber
\end{align*}
$$
The inner expectation can be further expanded
$$
\begin{align*}
\mathbb{E}_{K}\left[\sum_{k=1}^K \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \right]
&= \sum_{K=1}^\infty p(K)\sum_{k=1}^K \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \\
&= \sum_{k=1}^\infty \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \sum_{K=k}^\infty p(K) \\
&= \sum_{k=1}^\infty \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \mathbb{P}(\mathcal{K}\geq k) \\
&= \sum_{k=1}^\infty\Delta_k \\
&= \text{IWAE}_{2} - \text{IWAE}_1 + \text{IWAE}_{3} - \text{IWAE}_2 + \dots = \lim_{k\rightarrow\infty}\text{IWAE}_{k}-\text{IWAE}_1
\end{align*}
$$
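which, plugged back into the outer expectation, shows $\mathbb{E}[\text{SUMO}] = \mathbb{E}[\lim_{k\rightarrow\infty}\text{IWAE}_k] = \log Z$: by the law of large numbers, the average of the importance weights converges to $\mathbb{E}_q\left[\frac{e^{-U(x)}}{q(x)}\right]=Z$, hence $\text{IWAE}_k\rightarrow\log Z$.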
> (N.B.) Some care is needed when exchanging the limit, the infinite sum and the expectation. See the paper for a more formal derivation.
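The key step above, swapping the order of summation so that the $1/\mathbb{P}(\mathcal{K}\geq k)$ weights cancel, can be checked numerically on a toy series. In this sketch the corrections are hypothetical deterministic values $\Delta_k=0.5^k$ (which sum to exactly $1$), the tail is $\mathbb{P}(\mathcal{K}\geq k)=1/k$, and the expectation over $K$ is computed exactly up to a truncation at $K=10^6$, which leaves an error of roughly $2\times10^{-6}$.

```python
import numpy as np

# Hypothetical deterministic corrections Delta_k = 0.5**k, k = 1, 2, ...; their sum is 1
k_max = 10**6                       # truncation of the sum over K
ks = np.arange(1, k_max + 1)
deltas = 0.5 ** ks                  # underflows harmlessly to 0 for large k
tail = 1.0 / ks                     # P(script K >= k) = 1/k
pmf = tail - 1.0 / (ks + 1)         # P(script K = K) = 1/K - 1/(K+1)

# Russian-roulette sum for each K: sum_{k <= K} Delta_k / P(script K >= k)
inner = np.cumsum(deltas / tail)

# E_K[...] = sum_K p(K) * inner[K], truncated at K = k_max
expectation = np.sum(pmf * inner)
print("E_K[randomly truncated sum] = %.6f (exact sum of Delta_k = 1)" % expectation)
```

Each individual truncated sum is random, but its expectation matches the full series.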
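One choice of $p(K)$ proposed in the paper satisfies $\mathbb{P}(\mathcal{K}\geq k)=\frac{1}{k}$ for $k=1,2,\dots$, so the weights $1/\mathbb{P}(\mathcal{K}\geq k)$ are simply $k$. We can sample such a $K$ easily using the [inverse CDF](https://en.wikipedia.org/wiki/Inverse_transform_sampling), $K=\lceil\frac{u}{1-u}\rceil$ where $u$ is sampled uniformly from the interval $(0,1)$: indeed, $\lceil\frac{u}{1-u}\rceil\geq k$ if and only if $u>\frac{k-1}{k}$, which happens with probability $\frac{1}{k}$. This matches the `np.ceil` in the code below.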
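A quick empirical check of this sampler, as a sketch: `sample_K` is a hypothetical helper name, the seed and sample size are arbitrary, and the `np.maximum(..., 1)` guard only matters for the probability-$2^{-53}$ event $u=0$ (NumPy's `random()` samples from $[0,1)$).

```python
import numpy as np

def sample_K(rng, size):
    # Inverse-CDF sampling for P(K >= k) = 1/k, k = 1, 2, ...
    # ceil(u / (1 - u)) >= k  <=>  u > (k - 1) / k, which has probability 1/k
    u = rng.random(size)                         # uniform on [0, 1)
    K = np.ceil(u / (1 - u)).astype(np.int64)
    return np.maximum(K, 1)                      # guard against u == 0

rng = np.random.default_rng(0)
K = sample_K(rng, 200_000)
for k in [1, 2, 5, 10]:
    print("k = %2d  empirical P(K >= k): %.4f  target 1/k: %.4f" % (k, (K >= k).mean(), 1 / k))
```

Note that $\mathbb{E}[K]=\sum_k \mathbb{P}(\mathcal{K}\geq k)$ diverges for this tail, so occasional very large draws of $K$ are expected.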
Putting everything together, we can estimate the log-partition function using SUMO:
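```python
count = 0          # total number of samples drawn from q so far
bs = 10            # number of SUMO estimates per draw of K
iwlb = list()
while count <= 1000000:
    # Sample K with P(K >= k) = 1/k via the inverse CDF
    u = np.random.rand(1)
    k = np.ceil(u / (1 - u)).astype(int)[0]
    # K + 1 log importance weights for each of the bs estimates
    xs = q.rvs(bs * (k + 1))
    elbo = -U(xs) - q.logpdf(xs)
    iwlb_ = elbo.reshape(bs, k + 1)
    # IWAE_1, ..., IWAE_{K+1} from running means of the weights
    iwlb_ = np.log(np.cumsum(np.exp(iwlb_), 1) / np.arange(1, k + 2))
    # SUMO = IWAE_1 + sum_k Delta_k / P(K >= k), with 1 / P(K >= k) = k
    iwlb_ = iwlb_[:, 0] + ((iwlb_[:, 1:k + 1] - iwlb_[:, 0:k]) * np.arange(1, k + 1)).sum(1)
    count += bs * (k + 1)
    iwlb.append(iwlb_)
iwlb = np.concatenate(iwlb)
plt.hist(iwlb, range(-5, 10))
print("Estimate: %.4f / Ground truth: %.4f" % (iwlb.mean(), np.log(2 * np.pi)))
print("Empirical variance: %.4f" % iwlb.var())
```
`Estimate: 1.8359 / Ground truth: 1.8379`
`Empirical variance: 4.1794`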
https://i.imgur.com/04kPKo5.png
Indeed, the empirical average is very close to the true log-partition function. However, the estimator is much more spread out (empirical variance $4.18$, compared with $0.15$ for IWAE with $K=5$); in fact, its distribution is heavy-tailed. Note that I did not tune the proposal distribution $q$ with respect to the ELBO, IWAE, or SUMO. In the paper, the authors propose tuning $q$ to minimize the variance of the $\text{SUMO}$ estimator, which might be an interesting trick to look at next.
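For reuse, the per-draw computation inside the loop above can be written as a standalone function. This is a sketch, not the author's code: `sumo_from_logw` is a hypothetical name, it works directly on log-weights with `np.logaddexp.accumulate` (my substitution for `np.log(np.cumsum(np.exp(...)))`, for numerical stability), and it takes the tail probability as a callable so that $\mathbb{P}(\mathcal{K}\geq k)=1/k$ is just one option.

```python
import numpy as np

def sumo_from_logw(logw, tail_prob):
    """SUMO estimate from K + 1 log importance weights (hypothetical helper).

    logw: 1-D array of log(exp(-U(x_j)) / q(x_j)), j = 1, ..., K + 1
    tail_prob: callable with tail_prob(k) = P(script K >= k), assumed > 0 for k = 1..K
    """
    n = len(logw)                                    # n = K + 1
    # IWAE_1, ..., IWAE_{K+1}: log of the running means of the weights
    iwae = np.logaddexp.accumulate(logw) - np.log(np.arange(1, n + 1))
    deltas = np.diff(iwae)                           # Delta_k = IWAE_{k+1} - IWAE_k, k = 1..K
    inv_tail = np.array([1.0 / tail_prob(k) for k in range(1, n)])
    return iwae[0] + np.sum(deltas * inv_tail)

# Toy check with weights (1, 3, 5): IWAE_1 = log 1, IWAE_2 = log 2, IWAE_3 = log 3
logw = np.log([1.0, 3.0, 5.0])
print(sumo_from_logw(logw, lambda k: 1.0 / k))       # log 2 + 2 * log 1.5 = log 4.5
print(sumo_from_logw(logw, lambda k: 1.0))           # no reweighting: telescopes to IWAE_3 = log 3
```

To mirror the loop, draw $K$ as before, compute the $K+1$ log-weights `-U(xs) - q.logpdf(xs)`, and call `sumo_from_logw(logw, lambda k: 1.0 / k)`.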
(Reposted; see more details and code at https://www.chinweihuang.com/pages/sumo)
