What LayerNorm really does for Attention in Transformers
2 things, not 1…
Normalization via LayerNorm has been part and parcel of the Transformer architecture for some time. If you asked most AI practitioners why we have LayerNorm, the generic answer would be that we use LayerNorm to normalize the activations on the forward pass and stabilize the gradients on the backward pass.
But that default response turns out to only be partially correct.
A new paper by Brody, Alon, and Yahav, titled “On the Expressivity Role of LayerNorm in Transformers’ Attention”, shows that LayerNorm’s role goes much deeper.
LayerNorm actually provides two functions for Transformer’s Attention:
A — Projection: LayerNorm helps the Attention component craft an attention query such that all keys are equally accessible.
It does this by projecting the key vectors onto the same hyperplane, which lets the model learn a query that is orthogonal to all of the keys at once and therefore attends to them equally.
In doing so, it spares the Attention component from having to learn this projection on its own.
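To make the projection concrete, here is a minimal NumPy sketch (not the paper’s code, and with toy dimensions chosen for illustration): subtracting the mean inside LayerNorm places every key on the hyperplane orthogonal to the all-ones vector, so a query pointing along that all-ones direction scores every key identically and the softmax spreads attention uniformly.

```python
import numpy as np

def layernorm(x, eps=1e-5):
    # Center (this projects x onto the hyperplane orthogonal to the all-ones vector)
    # and scale to unit variance. Learned gain/bias omitted for clarity.
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sigma + eps)

d = 8
rng = np.random.default_rng(0)
keys = layernorm(rng.normal(size=(5, d)))   # 5 hypothetical key vectors

# After LayerNorm every key sums to zero, i.e. it is orthogonal to the
# all-ones vector -- all keys lie on the same hyperplane.
print(np.allclose(keys.sum(axis=-1), 0.0))  # True

# A query aligned with the all-ones direction is orthogonal to that
# hyperplane, so it produces the same dot product (zero) with every key,
# and softmax over these scores gives equal attention to all keys.
query = np.ones(d)
scores = keys @ query
print(np.allclose(scores, scores[0]))       # True -> every key equally accessible
```

Without the projection, Attention would have to learn weights that reproduce this geometry by itself.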
The paper contains the finer details, but one image from it lets you grasp immediately what is happening.