
What LayerNorm really does for Attention in Transformers

Less Wright
3 min read · May 14, 2023


LayerNorm is more than scaling… (image credit)

2 things, not 1…

Normalization via LayerNorm has been part and parcel of the Transformer architecture for some time. If you asked most AI practitioners why we have LayerNorm, the generic answer would be that we use LayerNorm to normalize the activations on the forward pass and the gradients on the backward pass.
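
As a quick refresher, here is a minimal sketch of what LayerNorm computes per token. It is a simplification (the learned gamma/beta affine parameters that nn.LayerNorm adds are omitted, and epsilon is applied slightly differently than in PyTorch), but it makes the two halves of the operation — centering and rescaling — explicit, which is exactly the split the paper analyzes:

```python
import torch

def layer_norm(x, eps=1e-5):
    # center: subtract the per-token mean across the feature dimension
    mu = x.mean(dim=-1, keepdim=True)
    # rescale: divide by the per-token standard deviation
    sigma = x.std(dim=-1, keepdim=True, unbiased=False)
    return (x - mu) / (sigma + eps)

x = torch.randn(2, 4, 16)  # (batch, tokens, features) — illustrative shapes only
out = layer_norm(x)
print(out.mean(-1))                      # ~0 per token
print(out.std(-1, unbiased=False))       # ~1 per token
```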

But that default response turns out to only be partially correct.

A new paper by Brody, Alon, and Yahav, titled “On the Expressivity Role of LayerNorm in Transformers’ Attention,” shows that LayerNorm’s role goes much deeper.

LayerNorm actually provides two functions for the Transformer’s attention mechanism:

A — Projection: LayerNorm helps the attention component craft a query such that all keys are equally accessible.
It does this by projecting the key vectors onto the same hyperplane (the one orthogonal to the all-ones vector), which lets the model represent a query that is orthogonal to every key and therefore attends to all of them equally.
In doing so, it obviates the need for the attention component to learn that projection on its own (a small code sketch follows the figure below).

The paper contains the finer details, but one image from it lets you immediately grasp what is happening.

The benefits of LayerNorm projection in organizing key vectors (image from paper)
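
To make the projection half concrete, here is a minimal sketch (PyTorch, made-up shapes and values, not code from the paper). The mean-subtraction step of LayerNorm places every key on the hyperplane orthogonal to the all-ones vector, so a query aligned with that vector scores all keys identically:

```python
import torch

d = 8                                   # head dimension (illustrative)
keys = torch.randn(5, d)                # 5 arbitrary key vectors

# Mean subtraction = projection onto the hyperplane orthogonal to the 1-vector
keys_projected = keys - keys.mean(dim=-1, keepdim=True)

ones = torch.ones(d)                    # the all-ones direction
print(keys_projected @ ones)            # ~0 for every key: all lie on the hyperplane

# A query pointing along the 1-vector is orthogonal to every projected key,
# so every key receives the same dot-product score -> uniform attention.
query = ones
scores = keys_projected @ query
print(torch.softmax(scores, dim=-1))    # ~uniform over all 5 keys
```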

B — Scaling: This is the more obvious part: LayerNorm rescales the input. But what does that rescaling really accomplish? According to the paper, it guarantees two properties:
1 — Every key has the potential to receive the ‘highest’ attention score.
2 — No key can end up in an ‘un-selectable’ zone, i.e. inside the convex hull of the other keys, where no query can make it the most-attended key.

A second image from the paper again helps build a mental model of this (and another short code sketch follows below):

The benefits of scaling (image from paper)
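
Here is a minimal sketch of the scaling effect, again with made-up 2D vectors rather than anything from the paper, and with simple unit-norm rescaling standing in for LayerNorm’s std division (the paper’s point is that all keys end up with the same norm). A small key sitting inside the convex hull of the larger keys can never win the argmax score, but once every key has the same norm, it can:

```python
import torch

keys = torch.tensor([[ 4.0,  0.0],
                     [-4.0,  0.0],
                     [ 0.0,  4.0],
                     [ 0.0, -4.0],
                     [ 0.1,  0.1]])     # last key sits in the 'un-selectable' zone

def top_key(keys, query):
    # index of the key with the highest attention logit for this query
    return torch.argmax(keys @ query).item()

# No query direction makes key 4 the highest-scoring key...
probe_queries = torch.randn(1000, 2)
print(any(top_key(keys, q) == 4 for q in probe_queries))   # False

# ...but after rescaling all keys to the same norm, a query aligned with key 4
# does select it.
keys_scaled = keys / keys.norm(dim=-1, keepdim=True)
print(top_key(keys_scaled, keys_scaled[4]))                # 4
```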

The paper of course goes into much greater detail, but the point of this article is to showcase the two key findings in an intuitive format, since the paper’s visuals drive the point home.
