MuProp: Unbiased Backpropagation for Stochastic Neural Networks
Gu, Shixiang and Levine, Sergey and Sutskever, Ilya and Mnih, Andriy, 2015
Paper summary by hlarochelle
This paper presents a method for training feed-forward neural networks with stochastic hidden units (e.g. sigmoid belief networks) so as to optimize the expectation, over the stochastic units, of an arbitrary loss function. While the proposed method applies to any type of stochastic unit, it is most interesting for discrete stochastic units, for which the reparametrization trick of variational autoencoders cannot be used to backpropagate through the sampling step.
In short, the method builds on the likelihood ratio method (of which REINFORCE is a special case) and proposes a baseline (also known as a control variate) that, according to the authors, still yields an unbiased gradient estimate. Specifically, the baseline is the first-order Taylor expansion of the loss function around some deterministic value of the hidden units (x̄) that does not depend on the sampled stochastic hidden units (denoted x in the paper).
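To make the baseline concrete, here is a sketch of the resulting estimator for a single layer of stochastic units. The symbols θ (parameters), f (loss), p(x | θ) (distribution of the stochastic units) and h (the Taylor baseline) are notation assumed here for illustration; only x, x̄ and μ come from the summary above.

```latex
% Assumed notation: \theta = parameters, f = loss, p(x \mid \theta) = distribution
% of the stochastic units x, \bar{x} = deterministic expansion point,
% \mu_\theta = \mathbb{E}_{p(x \mid \theta)}[x].
\begin{align}
  h(x) &= f(\bar{x}) + f'(\bar{x})^{\top} (x - \bar{x})
  && \text{(first-order Taylor baseline)} \\
  \hat{g} &= \nabla_\theta \log p(x \mid \theta)\,\bigl[f(x) - h(x)\bigr]
           + \mathbb{E}_{p(x \mid \theta)}\bigl[\nabla_\theta \log p(x \mid \theta)\, h(x)\bigr]
  && \text{(baseline subtracted, its mean added back)}
\end{align}
```

Because the expectation of ∇_θ log p(x | θ) is zero and the expectation of ∇_θ log p(x | θ) xᵀ equals the transposed Jacobian of μ_θ, the constant terms of h drop out and the added-back term reduces to (∂μ_θ/∂θ)ᵀ f'(x̄) in this single-layer setting.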
For such a likelihood ratio estimator to remain unbiased, the expectation of the baseline (times the gradient of the model's log distribution) with respect to the model's distribution must be tractable, so that it can be added back as a correction. For the proposed baseline, it can be shown that computing this expectation requires the gradient of the mean (μ) of each stochastic unit in the network with respect to each parameter. The key idea behind the proposed method is that (1) an estimate of this expectation can be obtained simply using mean-field, and (2) since the mean-field values are computed by a deterministic feedforward pass over the network, the gradients of μ can be computed by backpropagating through that mean-field pass (hence the name of the method, MuProp).
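The argument above can be checked numerically on a toy problem. The sketch below is not the authors' implementation: it assumes a single layer of independent Bernoulli units with logits `theta`, a hypothetical quadratic loss, and x̄ set to the mean μ = sigmoid(theta). In this one-layer case the mean is exact, so the correction term is simply f'(x̄) · μ(1 − μ); in deeper networks the summary describes the mean-field pass as providing an estimate. All function and variable names are illustrative.

```python
import itertools
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Hypothetical toy loss f(x) = (w . x - t)^2 and its gradient w.r.t. x.
W = np.array([1.0, -2.0, 0.5])
T = 0.3

def loss(x):
    return (W @ x - T) ** 2

def loss_grad(x):
    return 2.0 * (W @ x - T) * W

def muprop_sample_grad(theta, x):
    """Single-sample gradient estimate with a first-order Taylor baseline."""
    mu = sigmoid(theta)
    x_bar = mu                      # deterministic expansion point (assumed: the mean)
    score = x - mu                  # d/dtheta log p(x | theta) for Bernoulli(sigmoid(theta))
    g_bar = loss_grad(x_bar)        # f'(x_bar)
    residual = loss(x) - loss(x_bar) - g_bar @ (x - x_bar)
    # Likelihood-ratio term on the residual, plus the analytic mean of the baseline term:
    # d mu_i / d theta_i = mu_i * (1 - mu_i), and the Jacobian is diagonal here.
    return score * residual + g_bar * mu * (1.0 - mu)

def all_states(theta):
    """Yield every binary configuration x together with its probability."""
    mu = sigmoid(theta)
    for bits in itertools.product([0.0, 1.0], repeat=len(theta)):
        x = np.array(bits)
        yield x, np.prod(np.where(x == 1.0, mu, 1.0 - mu))

def expected_loss(theta):
    return sum(p * loss(x) for x, p in all_states(theta))

def expected_muprop_grad(theta):
    """Exact expectation of the estimator, by enumerating all 2^d states."""
    return sum(p * muprop_sample_grad(theta, x) for x, p in all_states(theta))

def finite_diff_grad(theta, eps=1e-5):
    """Central-difference gradient of the exact expected loss."""
    grad = np.zeros_like(theta)
    for i in range(len(theta)):
        step = np.zeros_like(theta)
        step[i] = eps
        grad[i] = (expected_loss(theta + step) - expected_loss(theta - step)) / (2 * eps)
    return grad

theta = np.array([0.2, -1.0, 0.7])
print(expected_muprop_grad(theta))
print(finite_diff_grad(theta))
```

The two printed vectors should agree up to finite-difference error, which is what unbiasedness means in this setting. Enumeration is used instead of sampling so that the check does not depend on Monte Carlo noise.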
Experiments show that this method converges much faster than previously proposed unbiased methods and often performs better. Experiments also show that the method is competitive with biased methods (such as the "straight through" method).