Google's team developed scheduled sampling as an alternative training procedure to fit RNNs, and they used it in their competition-winning method for image captioning. While I can't argue with the empirical results (so I won't), I was a bit skeptical about the technique at a fundamental level, so I decided to do a bit of math that resulted in this blog post.
Overall, I have a suspicion that scheduled sampling is a flawed objective function for unsupervised/generative modelling, and I want to use this post to explain why I think so. I hope the comments section will work this time so people can comment and argue otherwise. Please also shoot me an email if you have more to say.
#### Summary of this note
- I take a critical look at scheduled sampling as an objective function for training RNNs
- I show it can lead to pathologies where the RNN learns marginal instead of conditional distributions
- I explain why I think adversarial training/generative moment matching offers a better alternative
- Lastly, I include a paragraph in which I apologise for being a di*k again.
#### Strictly proper scoring rules
I've mentioned scoring rules on this blog many times, and my PhD thesis was about them, so saying I'm obsessed with this topic would be a valid observation. But this is important stuff, particularly for unsupervised learning, and particularly as a framework to think about hard concepts like overfitting in generative models.
Scoring rules are essentially loss functions for probabilistic models/forecasts. A scoring rule $S(x,Q)$ simply measures how bad a probabilistic forecast $Q$ for a variable is in light of the actual observation $x$. In this notation, lower is better. A scoring rule is called strictly proper if, for any $P$, the following holds:
$$\underset{Q}{\operatorname{argmin}} \mathbb{E}_{x\sim P}S(x,Q) = P$$
In other words, if you repeatedly sample observations from some true underlying distribution $P$, then the model $Q$ which minimises expected score is $P$. This means that the scoring rule cannot be fooled and that minimising the expected score yields a consistent estimator for $P$. Because I mention consistency, people may dismiss this as a learning theory argument, but it is not. If you are a Bayesian or a deep learning person with no interest in consistency, a scoring rule being strictly proper simply means that it is safe to use it as a loss function. Anything that's not strictly proper is weird and wrong: it will lead to learning the wrong thing.
This concept is central in unsupervised learning and generative modelling. Unsupervised learning is all about modelling the probability distribution of data, so it's essential that we have loss functions that can measure the discrepancy between our model $Q$, and the true data distribution $P$ in a consistent way.
#### log-likelihood
One of the most frequently used strictly proper scoring rules is the logarithmic score:
$$S(x,Q) = - \log Q(x)$$
This quantity is also known as the negative log-likelihood. Minimising the expected score in an i.i.d. scenario yields maximum likelihood estimation, which is known to be a consistent estimator and has nice properties.
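To make "cannot be fooled" concrete, here is a minimal sketch on a single binary variable. It assumes a Bernoulli truth with parameter `p_true` and a Bernoulli model with parameter `q`, and simply searches a grid for the `q` with the lowest expected score. The function names and the grid search are mine, purely for illustration. For contrast I also include the so-called linear score $S(x,Q) = -Q(x)$, which is a standard example of a rule that is not proper.

```python
import math

def expected_log_score(p, q):
    """E_{x ~ Bernoulli(p)}[-log Q(x)] for the model Q = Bernoulli(q)."""
    return -(p * math.log(q) + (1 - p) * math.log(1 - q))

def expected_linear_score(p, q):
    """E_{x ~ Bernoulli(p)}[-Q(x)], the 'linear' score, which is not proper."""
    return -(p * q + (1 - p) * (1 - q))

def argmin_on_grid(score, p, steps=999):
    """Return the q in {0.001, ..., 0.999} with the lowest expected score."""
    grid = [(i + 1) / (steps + 1) for i in range(steps)]
    return min(grid, key=lambda q: score(p, q))

p_true = 0.7
q_log = argmin_on_grid(expected_log_score, p_true)     # lands on p_true
q_lin = argmin_on_grid(expected_linear_score, p_true)  # runs off to the grid edge
print(q_log, q_lin)
```

The log score is minimised at the true parameter, while the linear score is happiest putting as much mass as possible on the more likely outcome, which is exactly the kind of "learning the wrong thing" I am worried about.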
Often, the likelihood is impossible to evaluate. Luckily, it is not the only strictly proper scoring rule. In the context of generative models people have used the pseudolikelihood, score matching and moment matching, all of which are examples of strictly proper scoring rules.
To recap, any learning method that corresponds to minimising a strictly proper scoring rule is fine. Everything else can go horribly wrong: even if we feed it infinite data, it might just learn the wrong thing.
#### Scheduled Sampling
After successfully establishing myself as a proper-scoring-rule-nazi, let's talk about scheduled sampling (SS). I don't have space to explain SS in great detail here, only the basic idea. I encourage everyone to read the paper and Hugo's summary above.
SS is a new method to train recurrent neural networks (RNNs) to model sequences. I will use character-by-character models of text as an example. Typically, when you train an RNN, you aim to maximise the log predictive likelihood of the next character in each training sentence, given the prefix string of previous characters. This can be thought of as a special case of maximum likelihood learning, and it is all fine: you can actually do this properly without approximations.
After training, you use the RNN to generate sample sentences in a recursive fashion: assuming you've already generated $n$ characters, you feed that prefix into the RNN, and ask it to predict the $(n+1)$st character. The $(n+1)$st character is then added to the prefix to predict the $(n+2)$nd character, and so on.
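The recursive generation loop can be sketched in a few lines. This is only an illustration: `toy_next_char_probs` is a hypothetical stand-in for a trained RNN (it looks only at the last character), and the end-of-sentence marker `"."` and the length cap are my own assumptions.

```python
import random

def toy_next_char_probs(prefix):
    """Hypothetical stand-in for a trained RNN: P(next char | prefix)."""
    if prefix.endswith("a"):
        return {"a": 0.1, "b": 0.8, ".": 0.1}
    return {"a": 0.6, "b": 0.2, ".": 0.2}

def generate(next_char_probs, max_len=20, end=".", rng=None):
    """Sample a sentence one character at a time, feeding the prefix back in."""
    rng = rng or random.Random(0)
    prefix = ""
    while len(prefix) < max_len:
        probs = next_char_probs(prefix)
        chars, weights = zip(*probs.items())
        next_char = rng.choices(chars, weights=weights, k=1)[0]
        prefix += next_char  # the model's own output becomes its next input
        if next_char == end:
            break
    return prefix

sample = generate(toy_next_char_probs)
print(sample)
```

Note that every character the model conditions on here was produced by the model itself, which is the situation the next paragraph is concerned with.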
The authors say there is a disconnect between how the model is trained (it's always fed real data) and how it's used (it's always fed synthetic data generated by itself). This, they argue, leads to the RNN being unable to recover from its own mistakes early on in the sentence.
To address this, the authors propose an alternative training strategy, where every once in a while, the network is given its own synthetic data instead of real data at training time. More specifically, for each character in the training sentences, we flip a coin to decide whether we feed the character from the real training sentence, or whether to feed the model's own prediction as to what that character would have been. The authors claim this makes the model more robust to recovering from mistakes, which is probably true.
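As a rough sketch of the coin-flipping step (my own toy version, not the authors' implementation), the function below builds the prefix that is actually fed to the model for one training sentence. The names `scheduled_sampling_inputs`, `sample_from_model` and `epsilon` are assumptions for illustration; `epsilon` is the probability of feeding the real character. The training targets, which are not shown, remain the real characters.

```python
import random

def scheduled_sampling_inputs(real_chars, sample_from_model, epsilon, rng):
    """Return the input string fed to the model for one training sentence.

    For each position, feed the real character with probability epsilon,
    otherwise feed the model's own sample given what was fed so far.
    """
    fed = []
    for real in real_chars:
        if rng.random() < epsilon:
            fed.append(real)
        else:
            fed.append(sample_from_model("".join(fed)))
    return "".join(fed)

def toy_sampler(prefix):
    """Hypothetical model sample: always proposes 'z', whatever the prefix."""
    return "z"

rng = random.Random(0)
all_real = scheduled_sampling_inputs("hello", toy_sampler, 1.0, rng)
all_model = scheduled_sampling_inputs("hello", toy_sampler, 0.0, rng)
mixed = scheduled_sampling_inputs("hello", toy_sampler, 0.5, rng)
print(all_real, all_model, mixed)
```

With `epsilon = 1.0` this reduces to ordinary maximum likelihood training inputs, and with `epsilon = 0.0` the model never sees a real character as input.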
As far as I'm concerned, I'm happy as long as the new training procedure corresponds to a strictly proper scoring rule. But in this case, I have a strong feeling that it does not.
#### case study: sequence of two variables
For the sake of simplicity, let's consider using scheduled sampling to learn the joint distribution of a sequence of just two random variables. This is probably the simplest (shortest) time series I can think of. So SS in this case works as follows: for each datapoint, we train the network to predict the real $x_1$. Then we flip a coin to decide whether to keep $x_1$ from the datapoint, or to replace it with a sample from the model $Q_{x_1}$. Then we train $Q_{x_2\vert x_1}$ on the $(x_1,x_2)$ pair obtained this way.
The scoring rule for scheduled sampling looks something like this:
$$ S((x_1,x_2), Q_{x_1,x_2}) = - (1 - \epsilon) [ \mathbb{E}_{z \sim Q_{x_1}} \log Q_{x_2 \vert x_1}(x_2 \vert z) + \log Q_{x_1}(x_1)] - \epsilon \log Q_{x_1 , x_2}(x_1,x_2),$$
where $\epsilon$ is the probability with which the true $x_1$ is used.
The authors suggest starting training with $\epsilon=1$ and annealing it so that by the end of the training $\epsilon=0$. So as far as the eventual optimum of SS is concerned, we only have to focus on what the first term of the scoring rule does. The second term is the good old log-likelihood so we know that part works.
After some math, one can show that scheduled sampling with a fixed $\epsilon$ minimises the following divergence between the true $P$ and the model $Q$:
$$D_{SS}[P\|Q] = KL[P_{x_1}\|Q_{x_1}] + (1-\epsilon) \mathbb{E}_{z\sim Q_{x_1}} KL[P_{x_2}\|Q_{x_2\vert x_1=z}] + \epsilon KL[P_{x_2\vert x_1}\|Q_{x_2\vert x_1}]$$
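For anyone who wants the "some math" spelled out, here is my reconstruction of the step, as a sketch rather than anything taken from the paper. I write $H[\cdot]$ for the (conditional) entropy, which is notation I am introducing just for this step, and I read the last KL term in $D_{SS}$ as averaged over $x_1 \sim P_{x_1}$. Take the expectation of the scoring rule under $(x_1,x_2)\sim P$ and use $\log Q_{x_1,x_2}(x_1,x_2) = \log Q_{x_1}(x_1) + \log Q_{x_2\vert x_1}(x_2\vert x_1)$:

$$\mathbb{E}_{P} S = -\mathbb{E}_{x_1\sim P_{x_1}} \log Q_{x_1}(x_1) - (1-\epsilon)\, \mathbb{E}_{z\sim Q_{x_1}} \mathbb{E}_{x_2\sim P_{x_2}} \log Q_{x_2\vert x_1}(x_2\vert z) - \epsilon\, \mathbb{E}_{(x_1,x_2)\sim P} \log Q_{x_2\vert x_1}(x_2\vert x_1)$$

In the middle term $z$ is drawn independently of $x_2$, so $x_2$ only enters through its marginal $P_{x_2}$. Each of the three cross-entropies splits into an entropy plus a KL divergence, which gives

$$\mathbb{E}_{P} S = D_{SS}[P\|Q] + H[P_{x_1}] + (1-\epsilon)\, H[P_{x_2}] + \epsilon\, H[P_{x_2\vert x_1}]$$

The entropy terms do not depend on $Q$, so minimising the expected score is the same as minimising $D_{SS}[P\|Q]$.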
Now, if $\epsilon=1$, we recover the Kullback-Leibler divergence between the joint $P_{x_1,x_2}$ and $Q_{x_1,x_2}$, which is what we expect as it corresponds to maximum likelihood estimation. However, as $\epsilon$ is annealed to $0$, the objective becomes rather strange: the conditional distribution $Q_{x_2\vert x_1}$ is pushed to model the marginal distribution $P_{x_2}$, instead of $P_{x_2\vert x_1}$ as one would expect. One can therefore see that at $\epsilon=0$ the factorised $Q^{*} = P_{x_1}P_{x_2}$ minimises this objective function.
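A quick numerical sanity check of this claim, under assumptions of my own choosing: a toy joint $P$ over two correlated binary variables, and two candidate models, the true joint and the factorised product of marginals. The function `d_ss` evaluates the divergence term by term, with the last KL averaged over $x_1 \sim P_{x_1}$.

```python
import math

def kl(p, q):
    """KL divergence between two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def d_ss(P, Q1, Q2_given_1, eps):
    """D_SS for a discrete joint P[i][j] = P(x1=i, x2=j) and model (Q1, Q2|1)."""
    n1, n2 = len(P), len(P[0])
    P1 = [sum(P[i]) for i in range(n1)]
    P2 = [sum(P[i][j] for i in range(n1)) for j in range(n2)]
    P2_given_1 = [[P[i][j] / P1[i] for j in range(n2)] for i in range(n1)]
    first_variable = kl(P1, Q1)
    # Prefix z sampled from the model: the target x2 follows its marginal P2.
    marginal_term = sum(Q1[z] * kl(P2, Q2_given_1[z]) for z in range(n1))
    # Prefix taken from the data: the usual conditional KL, averaged over P1.
    conditional_term = sum(P1[i] * kl(P2_given_1[i], Q2_given_1[i]) for i in range(n1))
    return first_variable + (1 - eps) * marginal_term + eps * conditional_term

# Toy joint with strongly correlated x1 and x2 (an illustrative assumption).
P = [[0.4, 0.1], [0.1, 0.4]]
truth = ([0.5, 0.5], [[0.8, 0.2], [0.2, 0.8]])       # Q = P
factorised = ([0.5, 0.5], [[0.5, 0.5], [0.5, 0.5]])  # Q = P_x1 * P_x2

for eps in (1.0, 0.0):
    print(eps, d_ss(P, *truth, eps), d_ss(P, *factorised, eps))
```

At `eps = 1.0` the true joint scores zero and the factorised model is penalised, as it should be. At `eps = 0.0` the roles flip: the factorised model scores zero and the true joint is strictly worse.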
#### what this means for text modeling
Extrapolating from the two-variable case to longer sequences, one can see that the scheduled sampling objective would fail if minimised properly until convergence. Consider the case when the $\epsilon\approx 0$ stage is reached in the annealing schedule. Now consider what the RNN has to do to predict the $n$th character in a string during training. It is fed a random prefix string that it generated itself, without ever seeing any of the real data. Then the RNN has to give a probabilistic forecast of what the $n$th character in the training sentence is, having seen none of the previous characters in the sentence.
The optimal model that minimises this objective would completely ignore all the characters in the sentence so far, but keep a simple linear counter that indexes where it is within the sentence. Then it would emit a character from an index-specific marginal distribution of characters. This is the equivalent of the factorised trivial solution above.
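Here is a minimal sketch of what such a degenerate model looks like, assuming a tiny made-up dataset of equal-length sentences. The function names are mine. The "model" is nothing more than one table of character counts per position, and the sampler never looks at what it has emitted so far.

```python
import random
from collections import Counter

def fit_position_marginals(sentences):
    """Estimate one marginal character distribution per position index."""
    length = max(len(s) for s in sentences)
    counts = [Counter() for _ in range(length)]
    for sentence in sentences:
        for n, char in enumerate(sentence):
            counts[n][char] += 1
    return counts

def sample_ignoring_prefix(counts, rng):
    """Emit each character from its position-specific marginal only."""
    out = []
    for counter in counts:
        chars, weights = zip(*counter.items())
        out.append(rng.choices(chars, weights=weights, k=1)[0])
    return "".join(out)

data = ["cat", "dog"]  # toy training set, purely for illustration
marginals = fit_position_marginals(data)
samples = {sample_ignoring_prefix(marginals, random.Random(seed)) for seed in range(200)}
print(sorted(samples))
```

Every per-position marginal is matched perfectly, yet the sampler cheerfully mixes the two training sentences into strings that never occurred in the data.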
Yes, such a model would be better at "recovering from its own mistakes", because at every character it would start independently from what it has generated so far. But this is at the cost of paying no attention whatsoever as to what the prefix of the sentence was. I believe the reason why this trivial behaviour was not observed in the paper is that the authors did not run the optimisation until convergence, and did not implement the full gradient of the objective function, as they discuss in the paper.
#### Constructive part of criticism
#### What to do instead of SS?
So the observed problem was that RNNs trained via maximum likelihood are unable to recover from their own mistakes early on in a sentence, when they are used to generate.
The main reason for the observed problem is that the log-likelihood is a local scoring rule.
The local property of scoring rules means that at training time we only ever evaluate the model $Q$ on actually observed datapoints. So if the RNN is faced with a prefix subsequence that was not in the dataset, God knows what it's going to complete that sentence with.
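To illustrate locality, here is a small sketch with two hypothetical next-character models stored as lookup tables (the dataset, the prefixes and the probabilities are all made up). They agree on the only prefix that appears in the training data and disagree wildly on a prefix that never appears, and the training loss cannot tell them apart.

```python
import math

def avg_nll(model, dataset):
    """Average negative log-likelihood over observed (prefix, next_char) pairs."""
    return -sum(math.log(model[prefix][nxt]) for prefix, nxt in dataset) / len(dataset)

# Training data only ever contains the prefix "a".
dataset = [("a", "b"), ("a", "b"), ("a", "c")]

# Identical on the observed prefix "a", very different on the unseen prefix "z".
model_1 = {"a": {"b": 2 / 3, "c": 1 / 3}, "z": {"b": 0.5, "c": 0.5}}
model_2 = {"a": {"b": 2 / 3, "c": 1 / 3}, "z": {"b": 0.999, "c": 0.001}}

loss_1 = avg_nll(model_1, dataset)
loss_2 = avg_nll(model_2, dataset)
print(loss_1 == loss_2)
```

Whatever the model does on `"z"` is invisible to the log-likelihood, so nothing in training constrains it.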
The proper (shall I say strictly proper) way to fix this issue is to use