(See also a more thorough summary in [a LaTeX PDF][1].)
This paper has some nice, clear theory that bridges maximum likelihood (supervised) learning and standard reinforcement learning. It focuses on *structured prediction* tasks, where we want to learn a conditional distribution $p_\theta(y \mid x)$ over outputs $y$ with complex internal structure.
We can agree on some deficiencies of maximum likelihood learning:
- ML training fails to assign **partial credit**. Models are trained to maximize the likelihood of the ground-truth outputs in the dataset, and all other outputs are equally wrong. This is an increasingly important problem as the space of possible solutions grows.
- ML training is potentially disconnected from **downstream task reward**. In machine translation, we usually want to optimize relatively complex metrics like BLEU or TER. Since these metrics are non-differentiable, we have to settle for optimizing proxy losses that we hope are related to the metric of interest.
Reinforcement learning offers an attractive alternative in theory. RL algorithms are designed to optimize non-differentiable (even stochastic) reward functions, which sounds like just what we want. But RL algorithms have their own problems with this sort of structured output space:
- Standard RL algorithms rely on samples from the model we are learning, $p_\theta(y \mid x)$. This becomes intractable when our output space is very complex (e.g. 80-token sequences where each word is drawn from a vocabulary of 80,000 words).
- The reward spaces for problems of interest are extremely sparse. Our metrics will assign 0 reward to most of the $80{,}000^{80}$ possible outputs in the translation problem in the paper.
- Vanilla RL doesn't take into account the ground-truth outputs available to us in structured prediction.
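To get a feel for the size of that output space, here is a quick back-of-the-envelope calculation. It takes the figures quoted above at face value (80-token sequences, 80,000-word vocabulary); the variable names are just for illustration.

```python
import math

vocab_size = 80_000   # words in the vocabulary (figure quoted above)
seq_len = 80          # tokens per output sequence (figure quoted above)

# The number of distinct length-80 sequences is vocab_size ** seq_len.
# Work in log10 to get a readable order of magnitude.
log10_outputs = seq_len * math.log10(vocab_size)
print(f"about 10^{log10_outputs:.0f} possible outputs")
```

Any method that has to stumble on rewarding outputs by sampling from an untrained model is searching a space of roughly this size.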
This paper designs a solution which combines supervised learning with a reinforcement learning-inspired smoothing method. Concretely, the authors design an **exponentiated payoff distribution** $q(y \mid y^*; \tau)$ which assigns high mass to high-reward outputs $y$ and low mass elsewhere. This distribution is used to effectively smooth the loss function established by the ground-truth outputs in the supervised data. We end up optimizing the following objective:
$$\mathcal L_\text{RML} = - \mathbb E_{x, y^* \sim \mathcal D}\left[ \sum_y q(y \mid y^*; \tau) \log p_\theta(y \mid x) \right]$$
This optimization depends on samples from our dataset $\mathcal D$ and, more importantly, the stationary payoff distribution $q$. This contrasts strongly with standard RL training, where the objective depends on samples from the non-stationary model distribution $p_\theta$. To make that clear, we can rewrite the above with another expectation:
$$\mathcal L_\text{RML} = - \mathbb E_{x, y^* \sim \mathcal D, y \sim q(y \mid y^*; \tau)}\left[ \log p_\theta(y \mid x) \right]$$
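The two forms of the objective can be checked against each other on a toy problem small enough to enumerate. This is only a sketch under assumptions: the output space is every length-3 sequence over a 3-symbol vocabulary, the reward is negative Hamming distance (a stand-in for BLEU or TER), $q$ is taken proportional to $\exp(r / \tau)$, and the "model" is a softmax over random logits rather than a trained network.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)

VOCAB, LENGTH, TAU = 3, 3, 0.5          # toy sizes and temperature (assumed)
outputs = np.array(list(itertools.product(range(VOCAB), repeat=LENGTH)))  # shape (27, 3)
y_star = np.array([0, 1, 2])            # hypothetical ground-truth output

# Stand-in reward: negative Hamming distance to the ground truth.
reward = -(outputs != y_star).sum(axis=1)

# Exponentiated payoff distribution: q(y | y*) proportional to exp(reward / tau).
q = np.exp(reward / TAU)
q /= q.sum()

# Stand-in model p_theta(y | x): a softmax over random logits.
logits = rng.normal(size=len(outputs))
log_p = logits - np.log(np.exp(logits).sum())

# Exact objective: -sum_y q(y | y*) log p_theta(y | x).
loss_exact = -(q * log_p).sum()

# Sampled form: draw y ~ q, then average -log p_theta(y | x).
samples = rng.choice(len(outputs), size=50_000, p=q)
loss_mc = -log_p[samples].mean()

print(loss_exact, loss_mc)
```

The samples come from $q$, which never changes during training, so nothing in the sampling step depends on the model parameters.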
### Model details
If you're interested in the low-level details, I wrote up the gist of the math in [this PDF][1].
### Analysis
#### Relationship to label smoothing
This training approach is mathematically equivalent to label smoothing, applied here to structured output problems. In next-word prediction language modeling, a popular trick involves smoothing the target distributions by combining the ground-truth output with some simple base model, e.g. a unigram word frequency distribution. (This just means we take a weighted sum of the one-hot vector from our supervised data and a normalized frequency vector calculated on some corpus.) Mathematically, the cross entropy with label smoothing is
$$\mathcal L_\text{ML-smooth} = - \mathbb E_{x, y^* \sim \mathcal D} \left[ \sum_y p_\text{smooth}(y; y^*) \log p_\theta(y \mid x) \right]$$
(The equation above leaves out a constant entropy term.)
The gradient of this objective looks exactly the same as the reward-augmented ML gradient from the paper:
$$\nabla_\theta \mathcal L_\text{ML-smooth} = - \mathbb E_{x, y^* \sim \mathcal D, y \sim p_\text{smooth}} \left[ \nabla_\theta \log p_\theta(y \mid x) \right]$$
So reward-augmented likelihood is equivalent to label smoothing in which the log-probability of the smoothing distribution is proportional to our downstream reward function (up to an additive normalizing constant).
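A small numerical check makes the correspondence concrete in the simplest case. The sketch below assumes a single-token output, a 0/1-style reward (0 for the gold label, -1 for anything else), a payoff distribution proportional to $\exp(r / \tau)$, and a uniform base distribution for the smoothing (simpler than the unigram base mentioned above). Under those assumptions the payoff distribution is exactly a uniformly smoothed one-hot target.

```python
import numpy as np

V, TAU, gold = 5, 0.5, 2               # toy vocabulary size, temperature, gold index (assumed)

# 0/1-style reward on a single-token output: 0 for the gold label, -1 otherwise.
reward = -np.ones(V)
reward[gold] = 0.0

# Payoff distribution q proportional to exp(reward / tau).
q = np.exp(reward / TAU)
q /= q.sum()

# Uniform label smoothing: (1 - eps) * one_hot + eps * uniform.
eps = V * q[(gold + 1) % V]            # pick eps so the off-gold mass matches q
one_hot = np.eye(V)[gold]
p_smooth = (1.0 - eps) * one_hot + eps / V

print(np.allclose(q, p_smooth))
```

The temperature $\tau$ plays the role of the smoothing weight here: as $\tau$ shrinks, `eps` goes to zero and the target collapses back to the one-hot vector.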
#### Relationship to distillation
Optimizing the reward-augmented maximum likelihood objective is equivalent to minimizing the KL divergence
$$D_\text{KL}(q(y \mid y^*; \tau) \mid\mid p_\theta(y \mid x))$$
This divergence reaches zero iff $q = p_\theta$. We can say, then, that the effect of optimizing $\mathcal L_\text{RML}$ is to **distill** the reward function (which parameterizes $q$) into the model parameters $\theta$ (which parameterize $p_\theta$).
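The equivalence follows from expanding the divergence. A short derivation, using only the symbols already defined plus $H(q)$ for the entropy of $q$ (notation introduced here):

$$D_\text{KL}(q \mid\mid p_\theta) = \sum_y q(y \mid y^*; \tau) \log q(y \mid y^*; \tau) - \sum_y q(y \mid y^*; \tau) \log p_\theta(y \mid x) = -H(q) - \sum_y q(y \mid y^*; \tau) \log p_\theta(y \mid x)$$

$$\mathcal L_\text{RML} = \mathbb E_{x, y^* \sim \mathcal D}\left[ D_\text{KL}(q(y \mid y^*; \tau) \mid\mid p_\theta(y \mid x)) + H(q(y \mid y^*; \tau)) \right]$$

Since $H(q)$ does not depend on $\theta$, the two objectives have the same gradient and the same minimizers.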
It's exciting to think about other sorts of more complex models that we might be able to distill in this framework. The unfortunate (?) restriction is that the "source" model of the distillation ($q$ in this paper) must admit efficient sampling.
#### Relationship to adversarial training
We can also view reward-augmented maximum likelihood training as a data augmentation technique: it synthesizes new "partially correct" examples using the reward function as a guide. We then train on all of the original and synthesized data, again weighting the gradients based on the reward function.
Adversarial training is a similar data augmentation technique which generates examples that force the model to be robust to changes in its input space (robust to changes of $x$). Both adversarial training and the RML objective encourage the model to be robust "near" the ground-truth supervised data. A high-level comparison:
- Adversarial training can be seen as data augmentation in the input space; RML training performs data augmentation in the output space.
- Adversarial training is a **model-based data augmentation**: the samples are generated from a process that depends on the current parameters during training. RML training performs **data-based augmentation**, which could in theory be done independently of the actual training process.
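To illustrate the "independent of training" point, here is a rough sketch of generating augmented targets offline. It is not the paper's sampling procedure: it assumes the reward is negative Hamming distance, which makes exact sampling easy (first pick how many positions to change, then pick which ones and what to change them to). The function name `sample_augmented_target` and all parameter values are hypothetical.

```python
import math
import numpy as np

def sample_augmented_target(y_star, vocab_size, tau, rng):
    """Draw y ~ q(y | y*) with q proportional to exp(-hamming(y, y*) / tau)."""
    T = len(y_star)
    # Weight of each distance e: (number of sequences at distance e) * exp(-e / tau).
    log_w = np.array([
        math.log(math.comb(T, e)) + e * math.log(vocab_size - 1) - e / tau
        for e in range(T + 1)
    ])
    p_dist = np.exp(log_w - log_w.max())
    p_dist /= p_dist.sum()
    e = rng.choice(T + 1, p=p_dist)
    # Pick e positions uniformly and replace each with a different token.
    y = np.array(y_star).copy()
    for pos in rng.choice(T, size=e, replace=False):
        offset = rng.integers(1, vocab_size)      # 1 .. vocab_size - 1
        y[pos] = (y[pos] + offset) % vocab_size   # guaranteed to differ from the original
    return y

rng = np.random.default_rng(0)
y_star = [4, 1, 3, 0, 2, 2]                       # hypothetical token ids
augmented = [
    sample_augmented_target(y_star, vocab_size=5, tau=0.7, rng=rng)
    for _ in range(1000)
]
```

Nothing here touches the model: the augmented targets could be written to disk once and then consumed by an ordinary maximum likelihood training loop.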
---
Thanks to Andrej Karpathy, Alec Radford, and Tim Salimans for interesting discussion which contributed to this summary.
[1]: https://drive.google.com/file/d/0B3Rdm_P3VbRDVUQ4SVBRYW82dU0/view