This paper approaches the problem of optimizing the parameters of a discrete distribution with respect to a loss function that is an expectation over that distribution. The obvious example is a variational autoencoder with discrete latent variables, but there are many real applications:

$$
\mathcal{L}(\eta) := \mathbb{E}_{z \sim q_{\eta}(z)} \left[ f_{\eta}(z) \right]
$$

Using the [product rule of differentiation][product], the derivative of this loss function can be computed by enumerating all $K$ possible values of $z$:

$$
\nabla_\eta \mathbb{E}_{z \sim q_{\eta}(z)} \left[ f_{\eta}(z) \right] = \nabla_\eta \sum_{k=1}^{K} q_\eta(k) f_\eta(k) \\
= \sum_{k=1}^{K} f_\eta(k) \nabla_\eta q_\eta(k) + q_\eta(k) \nabla_\eta f_\eta(k)
$$

Multiplying the first term by $q_\eta(k)/q_\eta(k)$ and using $\nabla_\eta \log q_\eta(k) = \nabla_\eta q_\eta(k) / q_\eta(k)$ turns this sum back into an expectation, which gives the score function estimator (aka the REINFORCE estimator):

$$
\nabla_\eta \mathbb{E}_{z \sim q_{\eta}(z)} \left[ f_{\eta}(z) \right] = \sum_{k=1}^{K} f_\eta(k) \nabla_\eta q_\eta(k) \frac{q_\eta(k)}{q_\eta(k)} + q_\eta(k) \nabla_\eta f_\eta(k) \\
= \sum_{k=1}^{K} q_\eta(k) f_\eta(k) \nabla_\eta \log q_\eta(k) + q_\eta(k) \nabla_\eta f_\eta(k) \\
= \mathbb{E}_{z \sim q_{\eta}(z)} \left[ f_\eta(z) \nabla_\eta \log q_\eta(z) + \nabla_\eta f_\eta(z) \right] = \mathbb{E}_{z \sim q_{\eta}(z)} \left[ g(z) \right]
$$

In other words, both forms are the expectation of a single-sample gradient estimator, written $g(z)$. The authors note that this expectation can be computed exactly over a subset of the $k$ most probable states (overloading $k$, which above indexed the possible values of $z$).
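As a sanity check on the score function identity, here is a minimal pure-Python sketch. It assumes $q_\eta = \mathrm{softmax}(\eta)$ and an $f$ that does not depend on $\eta$, so the $\nabla_\eta f_\eta$ term vanishes; the parameter values and helper names are hypothetical. It compares a finite-difference gradient of $\mathcal{L}(\eta)$ with the exact expectation of $g(z)$ and with a Monte Carlo average of $g(z)$.

```python
import math
import random

def softmax(eta):
    m = max(eta)
    exps = [math.exp(e - m) for e in eta]
    total = sum(exps)
    return [e / total for e in exps]

def loss(eta, f):
    # L(eta) = E_{z ~ q_eta}[f(z)], computed by enumerating all K values
    return sum(qk * fk for qk, fk in zip(softmax(eta), f))

def finite_diff_grad(eta, f, h=1e-6):
    # Independent reference gradient via central differences
    grad = []
    for j in range(len(eta)):
        up = list(eta); up[j] += h
        dn = list(eta); dn[j] -= h
        grad.append((loss(up, f) - loss(dn, f)) / (2 * h))
    return grad

def g(q, f, z):
    # Score function estimator for one sample z: f(z) * grad log q_eta(z).
    # For a softmax, d/d eta_j log q(z) = 1{j == z} - q(j).
    return [f[z] * ((1.0 if j == z else 0.0) - q[j]) for j in range(len(q))]

def expected_g(eta, f):
    # E_{z ~ q}[g(z)] by exact enumeration over z
    q = softmax(eta)
    out = [0.0] * len(q)
    for z, qz in enumerate(q):
        for j, val in enumerate(g(q, f, z)):
            out[j] += qz * val
    return out

def reinforce_mc(eta, f, n, seed=0):
    # Monte Carlo average of g over n samples z ~ q_eta
    rng = random.Random(seed)
    q = softmax(eta)
    samples = rng.choices(range(len(q)), weights=q, k=n)
    out = [0.0] * len(q)
    for z in samples:
        for j, val in enumerate(g(q, f, z)):
            out[j] += val / n
    return out

eta = [0.5, -0.3, 1.2, 0.0]   # hypothetical logits
f = [1.0, 2.0, 5.0, -1.0]     # hypothetical f(z), independent of eta
print(finite_diff_grad(eta, f))
print(expected_g(eta, f))
print(reinforce_mc(eta, f, 100_000))
```

The first two lines should agree to numerical precision, while the Monte Carlo average only agrees up to sampling noise, which is the variance the rest of the paper tries to reduce.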
Let $C_k$ denote this set of the $k$ most probable states:

$$
\nabla_\eta \mathbb{E}_{z \sim q_{\eta}(z)} \left[ f_{\eta}(z) \right] = \mathbb{E}_{z \sim q_{\eta}(z)} \left[ g(z) \right] \\
= \mathbb{E}_{z \sim q_{\eta}(z)} \left[ g(z) \mathbb{1}\{ z \in C_k \} + g(z) \mathbb{1}\{ z \notin C_k \} \right] \\
= \sum_{z \in C_k} q_\eta(z) g(z) + \mathbb{E}_{z \sim q_{\eta}(z)} \left[ g(z) \mathbb{1}\{ z \notin C_k \} \right]
$$

As long as $k$ is small, the first term is cheap to compute, and if most of the probability mass lies in $C_k$, it shouldn't matter much how well we approximate the second term. The authors use importance sampling for the second term, and this is where I get confused. They denote their importance weight $q_\eta(z \notin C_k)$, which I think means the total probability mass *outside* $C_k$, i.e. $\sum_{j \notin C_k} q_\eta(j)$. Later, they define a decision variable $b$ that indicates whether we are outside this set, sampled with probability $q_\eta(z \notin C_k)$, so I think this interpretation is correct. The gradient estimator then becomes:

$$
\hat{g}(v) = \sum_{z \in C_k} q_\eta(z) g(z) + q_\eta(z \notin C_k) \, g(v), \qquad v \sim q_\eta(v \mid v \notin C_k)
$$

[product]: https://en.wikipedia.org/wiki/Product_rule

### Showing this is Rao-Blackwellization

Another way to sample $z$ is to draw a Bernoulli random variable $b$ with probability $\sum_{j \notin C_k} q_\eta(j)$; if $b = 1$, sample from the states outside $C_k$, and if $b = 0$, sample from the states in $C_k$. As long as both of those samples are drawn from $q_\eta$ restricted to the corresponding set, then:

$$
T(u, v, b) \stackrel{d}{=} z, \qquad T(u, v, b) := u^{1-b} v^{b}
$$

where $u \sim q_\eta(u \mid u \in C_k)$, $v \sim q_\eta(v \mid v \notin C_k)$ and $b \sim \text{Bernoulli}\left(\sum_{j \notin C_k} q_\eta(j)\right)$.
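The estimator $\hat{g}(v)$ and the construction $T(u, v, b)$ can be sketched in a few lines of Python. This assumes a scalar $g$ (think of one coordinate of the gradient) given as a lookup table over a small categorical $q$, with $0 < k < K$ so that both sets are non-empty; all names and numbers are hypothetical, chosen only for illustration.

```python
import random

def split(q, k):
    # C_k = the k most probable states; returns states inside/outside C_k and their masses
    order = sorted(range(len(q)), key=lambda i: q[i], reverse=True)
    chosen = set(order[:k])
    inside = [z for z in range(len(q)) if z in chosen]
    outside = [z for z in range(len(q)) if z not in chosen]
    return inside, outside, sum(q[z] for z in inside), sum(q[z] for z in outside)

def rb_estimate(q, g, k, v):
    # g_hat(v): exact sum over C_k plus the outside mass times g at one sample v outside C_k
    inside, _, _, mass_out = split(q, k)
    return sum(q[z] * g(z) for z in inside) + mass_out * g(v)

def sample_T(q, k, rng):
    # T(u, v, b) = u^(1-b) v^b, with b ~ Bernoulli(mass outside C_k)
    inside, outside, _, mass_out = split(q, k)
    u = rng.choices(inside, weights=[q[z] for z in inside])[0]    # u ~ q(. | u in C_k)
    v = rng.choices(outside, weights=[q[z] for z in outside])[0]  # v ~ q(. | v not in C_k)
    b = rng.random() < mass_out
    return v if b else u

def cond_expect_g_T(q, g, k, v):
    # E[g(T(u, v, b)) | v], enumerating b and u exactly
    inside, _, mass_in, mass_out = split(q, k)
    mean_g_u = sum(q[z] * g(z) for z in inside) / mass_in
    return mass_in * mean_g_u + mass_out * g(v)

q = [0.5, 0.25, 0.15, 0.06, 0.04]   # hypothetical distribution
table = [3.0, -1.0, 2.0, 10.0, -4.0]  # hypothetical values of g(z)

def g(z):
    return table[z]

k = 2
full = sum(qz * g(z) for z, qz in enumerate(q))  # E[g(z)]
_, outside, _, mass_out = split(q, k)
# E_v[g_hat(v)], enumerating v over the states outside C_k
mean_rb = sum(q[v] / mass_out * rb_estimate(q, g, k, v) for v in outside)
print(full, mean_rb)

rng = random.Random(0)
n = 100_000
counts = [0] * len(q)
for _ in range(n):
    counts[sample_T(q, k, rng)] += 1
print([c / n for c in counts])  # empirical distribution of T, close to q
```

Averaging $\hat{g}(v)$ exactly over $v$ recovers $\mathbb{E}[g(z)]$, `cond_expect_g_T` agrees with `rb_estimate` for every $v$ outside $C_k$, and the empirical frequencies of `sample_T` approach $q$.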
Expressing $z$ in this way means the gradient estimator from before can be written as:

$$
\hat{g}(v) = \mathbb{E}\left[ g(T(u, v, b)) \mid v \right]
$$

They left it as an exercise for the reader to expand that out and show it's the same as equation 6. Since $T = u$ when $b = 0$ and $T = v$ when $b = 1$, and $u$ and $b$ are independent of $v$:

$$
\mathbb{E}\left[ g(T(u, v, b)) \mid v \right] = \mathbb{E}\left[ g(u) \mathbb{1}\{ b = 0 \} + g(v) \mathbb{1}\{ b = 1 \} \mid v \right] \\
= q_\eta(z \in C_k) \sum_{z \in C_k} \frac{q_\eta(z)}{q_\eta(z \in C_k)} g(z) + q_\eta(z \notin C_k) \, g(v) \\
= \sum_{z \in C_k} q_\eta(z) g(z) + q_\eta(z \notin C_k) \, g(v) = \hat{g}(v)
$$

Taking the expectation over $v$ as well recovers the split expectation from before, so the estimator is unbiased:

$$
\mathbb{E}\left[ \hat{g}(v) \right] = \mathbb{E}\left[ g(z) \mathbb{1}\{ z \in C_k \} + g(z) \mathbb{1}\{ z \notin C_k \} \right] = \mathbb{E}\left[ g(z) \right]
$$

Writing the estimator as a conditional expectation of a statistic of random variables that together have the same distribution as $z$ is enough to show that this is an instance of Rao-Blackwellization. To be safe, the authors also apply the [conditional variance decomposition][eve] to confirm that Rao-Blackwellized estimators never have higher variance:

$$
\mathrm{Var}(Y) = \mathbb{E}\left[ \mathrm{Var}(Y \mid X) \right] + \mathrm{Var}\left( \mathbb{E}\left[ Y \mid X \right] \right) \\
\mathrm{Var}(g(z)) = \mathrm{Var}\left( \mathbb{E}\left[ g(T(u, v, b)) \mid v \right] \right) + \mathbb{E}\left[ \mathrm{Var}\left( g(T(u, v, b)) \mid v \right) \right] \\
\mathrm{Var}\left( \mathbb{E}\left[ g(T(u, v, b)) \mid v \right] \right) = \mathrm{Var}(\hat{g}(v)) = \mathrm{Var}(g(z)) - \mathbb{E}\left[ \mathrm{Var}\left( g(T(u, v, b)) \mid v \right) \right]
$$

They go on to show that the variance is at most $\mathrm{Var}(g(z)) \sum_{j \notin C_k} q_\eta(j)$. Finally, they note that the variance of the naive estimator can also be reduced by taking multiple samples and averaging. They then provide an equation for the optimal number $k$ of states to sum exactly, depending on how concentrated the distribution is, along with a proof that this has lower variance than the naive estimator:

$$
\hat{k} = \underset{k \in \{0, \dots, N\}}{\operatorname{argmin}} \frac{\sum_{j \notin C_k} q_\eta(j)}{N - k}
$$

I'm not very interested in the experiments right now, but skimming through them, it's interesting to see that this method performs very well on a high-dimensional hard attention task on MNIST.
This stands out because a Gumbel-softmax estimator falls apart in the same experiment. It would be nice to see results on RL problems like those shown in the [RELAX][] paper.

[eve]: https://en.wikipedia.org/wiki/Law_of_total_variance
[relax]: https://arxiv.org/abs/1711.00123
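The variance bound $\mathrm{Var}(\hat{g}(v)) \le \mathrm{Var}(g(z)) \sum_{j \notin C_k} q_\eta(j)$, the conditional variance decomposition, and the rule for $\hat{k}$ discussed above can all be checked exactly on a small example. This is a minimal sketch with a hypothetical toy $q$ and a lookup-table $g$. It assumes $N$ is the total number of evaluations of $g$, with $k$ spent on the exact sum and $N - k$ on samples from outside $C_k$, and it skips $k = N$ because that leaves no samples for the remainder.

```python
def split(q, k):
    # C_k = the k most probable states; returns states inside/outside C_k and their masses
    order = sorted(range(len(q)), key=lambda i: q[i], reverse=True)
    chosen = set(order[:k])
    inside = [z for z in range(len(q)) if z in chosen]
    outside = [z for z in range(len(q)) if z not in chosen]
    return inside, outside, sum(q[z] for z in inside), sum(q[z] for z in outside)

def weighted_var(values, weights):
    # Variance of values under the normalized weights
    total = sum(weights)
    mean = sum(w * x for w, x in zip(weights, values)) / total
    return sum(w * (x - mean) ** 2 for w, x in zip(weights, values)) / total

def variance_terms(q, gvals, k):
    # Assumes k < len(q), so at least one state lies outside C_k
    inside, outside, _, mass_out = split(q, k)
    var_g = weighted_var(gvals, q)
    # g_hat(v) = const + mass_out * g(v), so Var(g_hat) = mass_out^2 * Var(g(v) | v outside C_k)
    var_out = weighted_var([gvals[z] for z in outside], [q[z] for z in outside])
    var_rb = mass_out ** 2 * var_out
    # E_v[Var(g(T) | v)]: given v, T = u (weights q over C_k) or T = v (weight mass_out)
    e_cond_var = 0.0
    for v in outside:
        vals = [gvals[z] for z in inside] + [gvals[v]]
        wts = [q[z] for z in inside] + [mass_out]
        e_cond_var += (q[v] / mass_out) * weighted_var(vals, wts)
    return var_g, var_rb, e_cond_var, mass_out

def choose_k(q, N):
    # k_hat = argmin_k mass_out(k) / (N - k), over k < N (and k <= K, beyond which nothing changes)
    candidates = range(min(N, len(q) + 1))
    return min(candidates, key=lambda k: split(q, k)[3] / (N - k))

q = [0.5, 0.25, 0.15, 0.06, 0.04]      # hypothetical distribution
gvals = [3.0, -1.0, 2.0, 10.0, -4.0]   # hypothetical values of g(z)
for k in range(len(q)):
    var_g, var_rb, e_cond_var, mass_out = variance_terms(q, gvals, k)
    # Var(g_hat), the bound mass_out * Var(g), and Var(g) - E[Var(g(T) | v)]
    print(k, var_rb, mass_out * var_g, var_g - e_cond_var)
print(choose_k(q, N=4))
```

Because $k = 0$ is always a candidate and its ratio $1/N$ matches the naive $N$-sample average, the chosen $\hat{k}$ never gives a worse variance bound than plain averaging.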
