This paper considers the problem of structured output prediction, in the specific case where the output is a sequence and the sequence is represented as a (conditional) directed graphical model that generates from the first token to the last.

The paper starts from the observation that training such models by maximum likelihood (ML) does not reflect well how the model is actually used at test time. Indeed, ML training means that the model is trained to predict each token conditioned on the previous tokens *from the ground truth* sequence (this is known as "teacher forcing"). Yet, when making a prediction for a new input, the model generates a sequence one token after another, conditioning on *its own predicted tokens* instead.

So the authors propose a different training procedure, where at training time each *conditioning* ground truth token is sometimes replaced by the model's previous prediction. The choice of replacing the ground truth by the model's prediction is made by "flipping a coin" with some probability, independently for each token. Importantly, the authors propose to start with a high probability of using the ground truth (i.e. start close to ML) and anneal that probability towards 0 according to some schedule (thus the name Scheduled Sampling).

Experiments on three tasks (image caption generation, constituency parsing and speech recognition), all based on neural networks with LSTM units, demonstrate that this approach improves over ML training on the performance metrics appropriate for each problem, and yields better sequence prediction models.

#### My two cents

Big fan of this paper. It both identifies an important flaw in how sequential prediction models are currently trained and, most importantly, suggests a solution that is simple yet effective.
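To make the training procedure summarized above concrete, here is a minimal sketch of the per-token coin flip and of an annealed probability. It is not the authors' implementation: `scheduled_inputs`, `toy_model` and the `<s>` start symbol are hypothetical names invented for this illustration, and the inverse sigmoid decay is just one possible schedule, picked here as an assumption.

```python
import math
import random


def inverse_sigmoid_decay(step, k=100.0):
    """Probability of feeding the ground-truth token at training step `step`.

    One possible annealing schedule (an assumption of this sketch): it starts
    close to 1 and decays towards 0 as `step` grows.
    """
    return k / (k + math.exp(step / k))


def scheduled_inputs(truth, predict_next, p_truth, rng, start="<s>"):
    """Choose the conditioning token for every position of one sequence.

    truth        -- list of ground-truth tokens
    predict_next -- hypothetical stand-in for the model: maps the list of
                    conditioning tokens so far to a predicted next token
    p_truth      -- probability of conditioning on the ground truth
    """
    inputs = [start]  # the first token is always predicted from the start symbol
    for t in range(len(truth) - 1):
        guess = predict_next(inputs)  # the model's own prediction for truth[t]
        # Flip a coin independently for each position.
        if rng.random() < p_truth:
            inputs.append(truth[t])  # teacher forcing
        else:
            inputs.append(guess)  # condition on the model's prediction
    return inputs


def toy_model(prefix):
    # Deliberately imperfect toy predictor: it always guesses "the".
    return "the"


truth = ["a", "cat", "sat", "down"]
rng = random.Random(0)
for step in (0, 300, 1000):
    p = inverse_sigmoid_decay(step)
    print(step, round(p, 3), scheduled_inputs(truth, toy_model, p, rng))
```

With `p_truth = 1.0` the function reduces to plain teacher forcing, and with `p_truth = 0.0` the model is conditioned only on its own predictions, which matches the two ends of the schedule described above. In a real system the prediction would come from the recurrent network's output at the previous step rather than from a standalone function.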
I also believe that this approach played a non-negligible role in Google's winning system for image caption generation in the Microsoft COCO competition.

My alternative interpretation of why Scheduled Sampling helps is that ML training does not inform the model about the relative quality of the errors it can make. In terms of ML, it is as bad to put high probability on an output sequence that has just 1 token wrong as it is to put the same amount of probability on a sequence that has all tokens wrong. Yet, say for image caption generation, outputting a sentence that is one word away from the ground truth is clearly preferable to making a mistake on all words (something that is also reflected in performance metrics such as BLEU). By training the model to be robust to its own mistakes, Scheduled Sampling makes it less likely that errors accumulate, and so makes predictions that are entirely off much less likely.

An alternative to Scheduled Sampling is DAgger (Dataset Aggregation: \cite{journals/jmlr/RossGB11}), which, briefly put, alternates between training the model and adding to the training set examples that mix model predictions and the ground truth. However, Scheduled Sampling has the advantage that there is no need to explicitly create and store that increasingly large dataset of sampled examples, something that isn't appealing for online learning or learning on large datasets.

I'm also very curious about one of the directions of future work mentioned in the conclusion: figuring out a way to backprop through the stochastic predictions made by the model. Indeed, as the authors point out, the current algorithm ignores the fact that sometimes taking the model's previous prediction as input induces an additional relationship between the model's parameters and its ultimate prediction, a relationship that isn't taken into account during training.
To take it into account, you'd need to somehow backpropagate through the stochastic process that generated the previous token prediction. While the work on variational autoencoders has shown that we can backprop through Gaussian samples, backpropagating through the sampling of a discrete multinomial distribution is essentially an open problem. I do believe, however, that there is work that tried to tackle propagating through stochastic binary units, so perhaps that's a start. Anyway, if the authors could make progress on that specific issue, it could be quite useful not just in the context of Scheduled Sampling, but possibly in the context of training networks with discrete stochastic units in general!
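The contrast between Gaussian and discrete samples can be seen numerically in a toy example. The sketch below is only an illustration of the difficulty, not a method from the paper: all function names are hypothetical, the noise is held fixed, and derivatives are estimated by finite differences. A reparameterized Gaussian sample varies smoothly with its mean, whereas a sampled token index is a piecewise-constant function of the logits, so its derivative is zero almost everywhere and carries no learning signal.

```python
import math


def gaussian_sample(mu, sigma, eps):
    # Reparameterized Gaussian sample: the noise eps is drawn separately,
    # so the sample is a smooth function of mu and sigma.
    return mu + sigma * eps


def categorical_sample(logits, u):
    # Inverse-CDF sample from softmax(logits) with fixed uniform noise u.
    m = max(logits)
    weights = [math.exp(l - m) for l in logits]
    total = sum(weights)
    cumulative = 0.0
    for index, w in enumerate(weights):
        cumulative += w / total
        if u < cumulative:
            return index
    return len(logits) - 1  # guard against floating-point round-off


def finite_difference(f, x, h=1e-4):
    # Central-difference estimate of df/dx.
    return (f(x + h) - f(x - h)) / (2 * h)


# Sensitivity of a Gaussian sample to its mean: approximately 1.
grad_gaussian = finite_difference(
    lambda mu: gaussian_sample(mu, 0.5, eps=0.3), 1.0
)

# Sensitivity of a sampled token index to one logit: the index does not
# change under a small perturbation, so the estimate is exactly zero.
grad_discrete = finite_difference(
    lambda a: categorical_sample([a, 0.0, -1.0], u=0.9), 0.2
)

print(grad_gaussian, grad_discrete)
```

The zero derivative in the discrete case is exactly why the sampled previous token is treated as a constant input during training, and why a usable gradient would require some additional estimator or relaxation.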
