Neural Machine Translation by Jointly Learning to Align and Translate. Bahdanau, Dzmitry; Cho, Kyunghyun; Bengio, Yoshua (2014)
One core aspect of this attention approach is that it lets us debug the learned representation by visualizing the softmax output (later called $\alpha_{ij}$) over the input words for each output word, as shown below.
https://i.imgur.com/Kb7bk3e.png
In this approach, at each step the decoder RNN attends over the encoder's hidden states (one weight per state, so the input length can vary), applies a softmax, and uses the resulting probabilities to weight and sum those states. This weighted sum forms the memory each decoding step uses to make its prediction, bypassing the need for the network to encode everything in the single state passed between units.
Each hidden unit is computed as:
$$s_i = f(s_{i-1}, y_{i-1}, c_i)$$
Here $s_{i-1}$ is the previous decoder state and $y_{i-1}$ is the previous target word. The paper's contribution is $c_i$, the context vector, which carries the memory of the input phrase.
$$c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j$$
Here $\alpha_{ij}$ is the softmax output for the $j$th element of the input sequence, and $h_j$ is the encoder hidden state at input position $j$.
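The two equations above can be sketched numerically. This is a minimal illustration with made-up shapes and random values, not the paper's code: the context vector is just a softmax-weighted sum of encoder hidden states.

```python
import numpy as np

# Minimal sketch (not the paper's code): a context vector c_i as a
# softmax-weighted sum of encoder hidden states h_j.
rng = np.random.default_rng(0)
h = rng.normal(size=(4, 3))   # encoder hidden states h_j (4 input words, dim 3)
e = rng.normal(size=4)        # unnormalized alignment scores for output step i

alpha = np.exp(e) / np.exp(e).sum()   # softmax -> attention weights alpha_ij
c = (alpha[:, None] * h).sum(axis=0)  # context vector c_i = sum_j alpha_ij * h_j
```

Because `alpha` sums to one, `c` is a convex combination of the encoder states, which is what makes the weights directly interpretable as an alignment.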
TLDR; The authors propose a novel "attention" mechanism that they evaluate on a Machine Translation task, achieving a new state of the art (and large improvements in dealing with long sentences). Standard seq2seq models typically try to encode the input sequence into a fixed-length vector (the last hidden state), based on which the decoder generates the output sequence. However, it is unreasonable to assume that all necessary information can be encoded in this one vector. Thus, the authors let the decoder depend on an attention vector, which is based on the weighted sum (expectation) of the input hidden states. The attention weights are learned jointly, as part of the network architecture.
#### Data Sets and model performance
Bidirectional GRU, 1000 hidden units. Multilayer maxout to compute output probabilities in decoder.
WMT '14 BLEU: 36.15
#### Key Takeaways
- Attention mechanism is a weighted sum of the hidden states computed by the encoder. The weights come from a softmax-normalized attention function (a perceptron in this paper), which are learned during training.
- Attention can be expensive, because it must be evaluated for each encoder-decoder output pair, resulting in a len(x) * len(y) matrix.
- The attention mechanism improves performance across the board, but has a particularly large effect on long sentences, confirming the hypothesis that the fixed vector encoding is a bottleneck.
- The authors use a bidirectional-GRU, concatenating both hidden states into a final state at each time step.
- It is easy to visualize the attention matrix (for a single input-output sequence pair). The authors show that in the case of English to French translations the matrix has large values on the diagonal, showing that these two languages are well aligned in terms of word order.
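The len(x) * len(y) cost noted above can be sketched as follows. This is a hypothetical illustration with random placeholder tensors; a plain dot product stands in for the paper's perceptron scorer.

```python
import numpy as np

# Hypothetical sketch of the len(y) x len(x) attention matrix: one score
# per (decoder step, encoder step) pair. A dot product stands in for the
# paper's single-layer perceptron scorer; all tensors are random placeholders.
Tx, Ty, d = 5, 4, 3           # input length, output length, hidden size
rng = np.random.default_rng(1)
H = rng.normal(size=(Tx, d))  # encoder states
S = rng.normal(size=(Ty, d))  # decoder states

scores = S @ H.T                   # (Ty, Tx): len(x) * len(y) score evaluations
A = np.exp(scores)
A /= A.sum(axis=1, keepdims=True)  # softmax over input positions; rows sum to 1
```

Each row of `A` is the attention distribution for one output word, which is exactly the matrix the paper visualizes.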
#### Question/Notes
- The attention mechanism seems limited in that it computes a simple weighted average. What about more complex attention functions that allow input states to interact?
This paper introduces an attention mechanism (soft memory access) for the task of neural machine translation. Qualitative and quantitative results show that their model not only achieves state-of-the-art BLEU scores, but also performs well on long sentences, which was a drawback of earlier NMT work. Their motivation comes from the fact that encoding all information from an input sentence into a single fixed-length vector and using that in the decoder was probably a bottleneck. Instead, their decoder uses an attention vector, which is a weighted sum of the input hidden states, and is learned jointly. Main contributions:
- The encoder is a bidirectional RNN, in which they take the annotation
of each word to be the concatenation of the forward and backward RNN states.
The idea is that the hidden state should encode information from both the
previous and following words.
- The proposed attention mechanism is a weighted sum of the input hidden
states, the weights for which come from an attention function (a single-layer
perceptron, which takes as input the previous hidden state of the decoder and
the current word annotation from the encoder) and are softmax-normalized.
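The scoring function described in the bullet above can be sketched as additive attention: $e_{ij} = v^\top \tanh(W_a s_{i-1} + U_a h_j)$. The weights below are random placeholders rather than trained parameters, and the shapes are made up for illustration.

```python
import numpy as np

# Sketch of the additive (single-layer perceptron) scorer described above:
# e_ij = v^T tanh(W_a s_{i-1} + U_a h_j). Weights are random placeholders.
rng = np.random.default_rng(2)
d, n = 3, 5                   # hidden size, input length
W_a = rng.normal(size=(d, d))
U_a = rng.normal(size=(d, d))
v = rng.normal(size=d)

s_prev = rng.normal(size=d)   # previous decoder hidden state s_{i-1}
H = rng.normal(size=(n, d))   # encoder annotations h_j (forward+backward concat)

e = np.tanh(s_prev @ W_a.T + H @ U_a.T) @ v  # one score per input position
alpha = np.exp(e) / np.exp(e).sum()          # softmax-normalized weights
```

Because the scorer is a small learned network rather than a fixed similarity, the alignment itself is trained jointly with translation, as the summary notes.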
## Strengths
- Incorporating the attention mechanism shows large improvements on
longer sentences. The attention matrix is easily interpretable as well,
and visualizations in the paper show that higher weights are being assigned
to input words that correspond to output words irrespective of their order
in the sequence (unlike an attention model that uses a mixture of Gaussians
which is monotonic).
## Weaknesses / Notes
- Their model formulation for capturing long-term dependencies is far more principled than Sutskever et al.'s trick of inverting the input, but they should have included a comparative study against that approach as well.