Summary by CodyWild 1 month ago
The Transformer architecture - which uses a structure entirely based on query-key-value attention mechanisms to process sequences such as text - has taken over the worlds of language modeling and NLP in the past three years. However, Transformers at the scale used for large language models have huge computational and memory requirements.
This is largely driven by the fact that information at every step in the sequence (or, during generation, in the so-far-generated sequence) is used to inform the representation at every other step. Although the same *parameters* are used for each of these pairwise calculations between keys and queries, the calculation itself is pairwise, and thus O(L^2) in the sequence length L, which can get very costly when processing long sequences on the scale of tens of thousands of tokens. This cost comes from both computation and memory, with memory being the primary focus of this paper, because the maximum memory requirement of a network dictates the hardware it can run on, in a way that the pure amount of computation that needs to be performed doesn't. An L^2 attention calculation, as naively implemented in a set of matrix multiplies, not only has to perform L^2 calculations, but has to hold L^2 values in memory while performing the softmax and weighted sum that make up the attention aggregation process. Memory requirements in Transformers are also driven by:
- The high parameter counts of the dense feed-forward layers within the network, which perform less computation per parameter than attention does, and
- The fact that passing forward one representation per position at each layer necessitates cumulatively keeping all the activations from every layer in the forward pass, so that you can use them to calculate derivatives in the backward pass.
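To make the L^2 memory term above concrete, here is a minimal NumPy sketch (function and variable names are mine, not the paper's) of naive single-head attention: the full (L, L) score matrix has to be materialized before the softmax and weighted sum can be applied.

```python
import numpy as np

def naive_attention(q, k, v):
    """Naive single-head attention over q, k, v of shape (L, d).

    The (L, L) `scores` array below is the quadratic memory cost:
    it must be held in full while computing the softmax and the
    weighted sum over values.
    """
    L, d = q.shape
    scores = q @ k.T / np.sqrt(d)                    # shape (L, L)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax
    return weights @ v                               # shape (L, d)
```

For a sequence of 64k tokens, that intermediate array alone is 64k x 64k floats per head, which is what motivates the approximations below.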
This paper, and the "Reformer" architecture it proposes, is less a single idea and more a suite of solutions targeted at making Transformers more efficient in their use of both compute and memory.
1. The substitution of Locality Sensitive Hashing (LSH) for normal key-query attention is a strategy for reducing the L^2 compute and memory requirements of the raw attention calculation. The essential idea of attention is "calculate how well the query at position i is matched by the key at every other position, and then use those softmax matching weights to calculate a weighted sum of the representations at those positions". If you consider keys and queries to live in the same space, you can think of this as a similarity calculation between positions, where you want the positions most similar to you to be weighted most highly in the calculation of your own next-layer value. In this spirit, the authors argue that the weighted sum will be mostly influenced by the highest-similarity positions within the softmax, and so, instead of performing attention over the whole sequence, we can first sort positions into buckets based on the similarity of their key/query vectors for a given head, and perform attention weighting only within those buckets.
https://i.imgur.com/tQJkfGe.png
This has the advantage that the first step, of assigning a position's key/query vector to a bucket, can be done for each position individually, rather than with respect to the value at another position. Here, the bucketing is performed by a Locality Sensitive Hashing algorithm, which works by projecting each position's vector into a lower-dimensional space and then taking the index at which the projected vector has its maximum value. That index is then used as a bucket ID, and full attention is performed within each bucket. This shifts the time complexity of attention from O(L^2) to O(L log L), since for each of the L positions you only need to calculate explicit pairwise similarities against the roughly log(L) other elements in its bucket.
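The bucketing step can be sketched in a few lines of NumPy. This follows the random-rotation hashing scheme the paper describes (project with a random matrix, then take the argmax over the projection concatenated with its negation); the function name and shapes are my own illustration.

```python
import numpy as np

def lsh_buckets(x, n_buckets, rng):
    """Assign each position's key/query vector to a bucket via LSH.

    x: (L, d) array of per-position vectors; returns (L,) integer
    bucket ids in [0, n_buckets). Each position is hashed
    independently - no pairwise computation across positions.
    """
    L, d = x.shape
    R = rng.standard_normal((d, n_buckets // 2))   # random projection
    projected = x @ R                              # (L, n_buckets // 2)
    # Concatenating with the negation lets the argmax distinguish
    # both half-spaces of each random direction.
    return np.argmax(np.concatenate([projected, -projected], axis=-1),
                     axis=-1)
```

Nearby vectors tend to land in the same bucket (identical vectors always do), so restricting attention to within-bucket pairs approximates full attention while touching far fewer pairs.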
2. Reversible layers. This addresses the problem of needing to keep activations from each layer around for computing the backward-pass derivatives. It takes an idea from RevNets, which propose a reversible alternative to the commonly used ResNet architecture. In a normal ResNet scenario, Y = X + F(X), where F(X) is the computation of a single layer or block, and Y is the resulting set of activations passed to the next layer. In this setup, you can't recover the value of X from Y if you discard X for memory reasons, because the difference between the two is a function of X itself, which you no longer have. As an alternative, RevNets define a crosswise residual structure that starts by partitioning X into two components, X1 and X2, and the output Y into Y1 and Y2, and performing the calculation shown below.
https://i.imgur.com/EK2vBkK.png
This allows you to work backward, getting X2 from Y1 and Y2 (both of which you have as outputs), and then get X1 from knowing the other three parts.
https://i.imgur.com/uLTrdyf.png
This means that as soon as you have the activations at a given layer, you can discard the earlier layers' activations, which makes things a lot more memory efficient.
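The forward and inverse passes above can be sketched directly. This assumes the standard RevNet form Y1 = X1 + F(X2), Y2 = X2 + G(Y1) (as in the images); the function names are mine.

```python
import numpy as np

def rev_forward(x1, x2, f, g):
    """Reversible residual block: y1 = x1 + f(x2), y2 = x2 + g(y1)."""
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def rev_inverse(y1, y2, f, g):
    """Recover the inputs from the outputs alone, so x1 and x2
    never need to be stored for the backward pass."""
    x2 = y2 - g(y1)   # undo the second residual step first
    x1 = y1 - f(x2)   # then undo the first, now that x2 is known
    return x1, x2
```

Because each subtraction exactly undoes the corresponding addition, the roundtrip is exact (up to floating-point error) for any F and G, with no invertibility requirement on F or G themselves.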
3. There's also a proposal to do (what I *think* is) a fairly basic chunking of the feed-forward calculations across the sequence length: performing the feed-forward calculation on chunks of the sequence rather than the whole thing at once. Processing the whole sequence at once is faster under vectorized computation, for parallelization reasons, but chunking is more memory efficient, since only one chunk's intermediate activations need to be held at a time.
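Since the feed-forward layer acts on each position independently, chunking it changes peak memory but not the result. A minimal sketch, with names of my own choosing:

```python
import numpy as np

def chunked_ffn(x, w1, w2, chunk_size):
    """Position-wise feed-forward layer applied in chunks along length.

    x: (L, d_model); w1: (d_model, d_ff); w2: (d_ff, d_model).
    Only one chunk's (chunk_size, d_ff) intermediate activation is
    alive at a time, instead of the full (L, d_ff) array.
    """
    outputs = []
    for start in range(0, x.shape[0], chunk_size):
        chunk = x[start:start + chunk_size]
        hidden = np.maximum(chunk @ w1, 0.0)   # ReLU on the chunk only
        outputs.append(hidden @ w2)
    return np.concatenate(outputs, axis=0)
```

The output is identical to applying the layer to the whole sequence at once; the trade is Python-loop overhead and reduced vectorization for a d_ff-fold smaller peak intermediate.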
