Summary by CodyWild 3 months ago
This paper came on my radar after winning Best Paper recently at ICLR, and all in all I found it a clever way of engineering a somewhat complicated inductive bias into a differentiable structure. The empirical results weren’t compelling enough to suggest that this structural shift made a regime-change difference in performance, but it does seem to have a consistently stronger ability to do syntactic evaluation across large gaps in sentences.
The core premise of this paper is that, while language is to some extent sequence-like, it is in a more fundamental sense tree-like: a recursive structure of modified words, phrases, and clauses, aggregating up to a complete sentence. In practical terms, this cashes out to parse trees: labelings akin to the sentence diagrams that you or I perhaps did once upon a time in grade school.
https://i.imgur.com/GAJP7ji.png
Given this, if you want to effectively model language, it might be useful to have a neural network structure explicitly designed to track where you are in the tree. To do this, the authors of this paper use a clever activation function scheme based on the intuition that you can think of jumping between levels of the tree as adding information to the stack of local context, and then removing that information from the stack when you’ve reached the end of some local phrase. In the framework of an LSTM, which has explicit gating mechanisms for both “forgetting” (removing information from cell memory) and input (adding information to the representation within cell memory), this can be understood as forcing a certain structure on input and forgetting, where you have to sequentially “close out” or add nodes as you move up or down the tree.
To represent this mathematically, the authors use a new activation function they developed, termed cumulative max, or cumax. In the same way that the softmax is a differentiable (i.e. “soft”) version of an argmax, the cumax is a softened version of a vector that has zeros up to some switch point k, and ones thereafter. If you had such a vector as your forget mask, then “closing out” a layer in your tree would be equivalent to shifting the index where you switch from 0 to 1 up by one, so that a layer that previously had a “remember” value of 1.0 now has its content removed from the stack. However, since we need to differentiate, this notional 0/1 vector is instead represented as a cumulative sum of a softmax, which can be thought of as the probability that you’ve reached the switch point by any given position in the vector.
Setting aside the abstractions of what we’re imagining this cumax function to represent, in a practical sense it strictly enforces that you monotonically remember or input more as you move along the vector. The practical effect is that the network will be biased towards remembering information at one end of the representation vector for longer, which could be a useful inductive bias for storing information that has more long-term usefulness.
One advantage that this system has over a previous approach in which, for example, each layer of the LSTM operates on a different forgetting-decay timescale is that the approximation is soft: up to the number of neurons in the representation, the model can dynamically approximate whatever number of tree nodes it likes, rather than being tied explicitly to the number of layers.
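To make the cumax idea concrete, here is a minimal sketch in plain Python of the function as described above: a cumulative sum over a softmax, compared against the idealized 0/1 step vector it softens. The helper names (`softmax`, `cumax`, `hard_step`), the example logits, and the toy memory vector are all my own illustrative choices, not code from the paper, and a real implementation would use a tensor library and learned logits.

```python
import math
from itertools import accumulate


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cumax(logits):
    """Cumulative sum of a softmax: a soft version of [0, ..., 0, 1, ..., 1]."""
    return list(accumulate(softmax(logits)))


def hard_step(size, k):
    """The idealized mask: zeros before index k, ones from index k onward."""
    return [0.0 if i < k else 1.0 for i in range(size)]


# Hypothetical logits with a sharp peak at index 2 (illustrative values only).
logits = [0.0, 0.0, 20.0, 0.0, 0.0]
soft_mask = cumax(logits)
print([round(v, 3) for v in soft_mask])   # close to the hard mask below
print(hard_step(len(logits), 2))

# Softer logits give a gradual ramp, but the values still never decrease.
ramp = cumax([0.3, -1.0, 2.0, 0.5])
print([round(v, 3) for v in ramp])

# Used as a forget mask, the soft step scales each memory slot:
# slots before the switch point are (mostly) erased, later ones are kept.
memory = [0.5, -1.2, 0.8, 0.3, -0.7]  # toy cell contents, assumed for illustration
kept = [m * g for m, g in zip(memory, soft_mask)]
print([round(v, 3) for v in kept])
```

With a sharply peaked softmax the output is nearly indistinguishable from the hard step, and with flatter logits it becomes a smooth ramp. In both cases the mask is non-decreasing and ends at 1, which is the monotonicity property the paragraph above relies on.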
Beyond this being a mathematically clever idea, the evidence on whether it improves performance is a little mixed. It does consistently worse at tasks that require keeping track of short-term dependency information, but seems to do better at more long-term tasks, although not in a perfectly consistent or overly dramatic way. My overall read is that this is a neat idea, and I’m interested to see if it gets built on, as well as to see later papers that do some introspective work to validate whether the model is actually using this inductive bias in the tree-like way that we’re hoping and imagining it will.
