Intuitive understanding of the transformer model’s secret sauce — Self-attention
- Thomas Benham
- Dec 27, 2024
- 6 min read
See original Medium post here.
What’s driving the constant and accelerating AI innovation we are seeing almost weekly now, particularly in large language models (LLMs)? Or more simply, how on earth did OpenAI create their new ChatGPT bot? The primary answer is the innovation of the transformer architecture.
Whether you’re familiar with deep learning to some degree but new to transformers, or familiar with plain old machine learning and “transformer curious”, this post is for you.
The goal of this post is to provide an intuitive, less “mathy” understanding of the major innovation and secret sauce of transformers: self-attention.
Specifically, I am going to look at this from the perspective of language transformers.
I’ll also provide a range of amazing references for deeper dives into the entire transformer architecture, but this post will focus on building intuition for self-attention.
Why write this? Much has been written about how transformer models work, but a lot of it just restates the original paper. That doesn’t really offer an alternative perspective for understanding.
The way I, and I assume most people, learn complex quantitative things like models is not sequentially but iteratively. That is, gain some intuitive insight, play with the math, run some code, go back to the intuition, back to the math, etc. The good thing about ML and DL is that, as one of my professors once said, it’s a contact sport. So there’s also the key element of building the model, i.e. writing some code, which gives you hands-on experience. So read, calc, build, calc, read, calc, build…
As I said, my explicit goal here is to provide the intuitive (read) piece, to go with the quantitative (calc) and create (build) pieces.
Big picture. Lay of the land
First the calc and build parts:
Here is the original transformer paper if you want to read it.
Here is the best restatement of that paper in a more user-friendly format, Peter Bloem’s post. It is a really good, somewhat less formal deep dive.
And here is a walkthrough of building a transformer model from scratch by the Hugging Face team, along with the accompanying book.
I still think, however, that there is a lot of room for, and value in, conceptual or intuitive approaches to understanding transformers, which can really help on the journey to mastering these models.
So to back up, there are three flavors of transformer models: encoder-only models, decoder-only models, and encoder-decoder models.
But if you can wrap your head around the encoder piece, you’re almost all the way there to understanding them all. So what do you need to understand encoders? You need to understand:
Encodings & embeddings (prior innovations)
Self-attention: the novel piece of transformer models
Feed-forward neural networks, or FF NN (a prior innovation)
Classification & prediction heads (prior innovations)
Plus, of course, a whole bunch of prerequisite concepts and math, e.g. softmax, normalization, etc. (prior innovations)
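Two of those prerequisites, softmax and normalization, come up repeatedly below. Here is a minimal NumPy sketch of both, assuming NumPy is available. The function names `softmax` and `layer_norm` are just illustrative, and the layer normalization used in real transformers also has learned scale and shift parameters, which are omitted here.

```python
import numpy as np

def softmax(scores, axis=-1):
    """Turn raw scores into positive weights that sum to 1 along `axis`."""
    shifted = scores - np.max(scores, axis=axis, keepdims=True)  # subtracting the max avoids overflow
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    """Rescale each row (one token's embedding) to mean 0 and variance of approximately 1."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

weights = softmax(np.array([2.0, 1.0, 0.1]))
print(weights.round(3))   # larger scores get larger weights, and the weights sum to 1

normed = layer_norm(np.array([[1.0, 2.0, 3.0, 4.0]]))
print(normed.round(3))    # same shape as the input, values centred on 0
```

Softmax shows up inside self-attention itself, while normalization is one of the “standard processes” wrapped around it.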
At a high or simplified level, deep learning models can generally be broken down into the body and the head. The body transforms all the input data and the head makes the prediction based on that data.
So the body is made up of everything but the classification & prediction piece.
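For encoders, the body is generally a sequence of encodings & embeddings > self-attention > FF NN, with some normalization (and residual connections) around the self-attention and FF NN steps; in practice this self-attention + FF NN block is stacked several times. Then the data is fed into the head.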
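To make the body/head split concrete, here is a toy sketch of one encoder layer followed by a simple classification head. Everything in it is an assumption chosen for illustration: the sizes, the random weights standing in for trained ones, the parameter-free attention, the normalize-after-residual ordering, and the mean-pooling head (models such as BERT use a special classification token instead). The point is only the shapes: the body keeps one vector per token, and the head turns those vectors into a prediction.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_model, d_ff, n_classes = 5, 8, 32, 2  # toy sizes, for illustration only

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def self_attention(x):
    # simplest possible version: no learned projections, a single head
    weights = softmax(x @ x.T / np.sqrt(x.shape[-1]))
    return weights @ x

# random stand-ins for weights a real model would learn
W1 = rng.normal(size=(d_model, d_ff)) * 0.1
W2 = rng.normal(size=(d_ff, d_model)) * 0.1
W_head = rng.normal(size=(d_model, n_classes)) * 0.1

def feed_forward(x):
    return np.maximum(0, x @ W1) @ W2  # two linear layers with a ReLU in between

def encoder_body(x):
    x = layer_norm(x + self_attention(x))  # self-attention, residual connection, normalization
    x = layer_norm(x + feed_forward(x))    # FF NN, residual connection, normalization
    return x

def classification_head(x):
    pooled = x.mean(axis=0)            # average the token vectors into one sentence vector
    return softmax(pooled @ W_head)    # class probabilities

embeddings = rng.normal(size=(n_tokens, d_model))  # stand-in for real token embeddings
body_out = encoder_body(embeddings)
probs = classification_head(body_out)
print(body_out.shape, probs.shape)  # (5, 8) (2,)
```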
As I said, the primary innovation of transformers is this self-attention piece. But it took me some time to understand this process.
So the deep dive here is going to be on self-attention. I am just going to assume you understand embeddings and FF NN. But:
If you need a refresher on, or even an introduction to, embeddings, this is the best reference I have found (link). Embeddings are essential background for understanding self-attention.
If you need a refresher on, or even an introduction to, FF NN, this is the best reference I have found (link). You don’t need FF NN to understand self-attention, but you do if you want to understand deep learning in general.
Diving into self-attention
Now that you hopefully understand embeddings, let’s dive into self-attention.
What does self-attention mean? For language models, it fundamentally means the model is going to learn the context of the words in the input sequence, e.g. a sentence.
Learning the context here means understanding which words inform the meaning of other words, and how those words relate to each other, in a given sentence or body of text. The sort of thing we humans do without even thinking about it.
As the model learns this context, it pays more “attention” (or, quantitatively speaking, gives more weight) to words that have some meaningful relationship. How does it do this? It simply adjusts each word’s embedding values.
At its most basic level, the whole self-attention process just takes an input matrix X, with one row per token and one column per embedding dimension, and outputs X’, which has exactly the same dimensions as X; only the values in the matrix have been adjusted.
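Here is a bare-bones NumPy sketch of that idea, assuming the simplest possible form of self-attention: no learned weights, a single head, and the raw embeddings compared directly with each other. It is not how a trained transformer is parameterized, but it shows the input/output contract: an n × d matrix goes in and an n × d matrix comes out.

```python
import numpy as np

def simple_self_attention(X):
    """Bare-bones self-attention: no learned weights, one head.

    X has one row per token and one column per embedding dimension (n x d).
    """
    scores = X @ X.T                                        # (n x n): every token compared with every token
    scores = scores - scores.max(axis=1, keepdims=True)     # for numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)  # softmax: each row sums to 1
    return weights @ X                                      # (n x d): each row is a weighted mix of all rows

X = np.random.default_rng(42).normal(size=(5, 4))  # 5 tokens, 4-dimensional toy embeddings
X_prime = simple_self_attention(X)
print(X.shape, X_prime.shape)  # (5, 4) (5, 4)
```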
Let’s look at an example. The one used by the Hugging Face team is a pair of sentences that share words but use them with different meanings.
“Time flies like an arrow”
“Fruit flies like a banana”
Clearly the word “flies” means different things in each sentence.
Given what we (hopefully) know about embeddings, when we feed these token embeddings into the model, the embedding for every word, including “flies”, effectively represents that word’s position in some high-dimensional “space”, relative to ALL other words in an entire vocabulary of, say, 30,000 words. So the position of “flies” is defined by its potential for usage in multiple contexts, and it is the same vector whichever sentence “flies” appears in. Remember, embeddings are just vectors of numbers, and these vectors have meaning only by way of their relationship to each other.
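A small sketch of the embedding lookup shows the key point. It assumes a hypothetical eight-word vocabulary and random numbers in place of trained embedding values: before any self-attention happens, “flies” is mapped to exactly the same vector in both sentences.

```python
import numpy as np

# Hypothetical toy vocabulary; real models use tens of thousands of tokens
vocab = ["time", "flies", "like", "an", "arrow", "fruit", "a", "banana"]
token_to_id = {tok: i for i, tok in enumerate(vocab)}

d_model = 4
# One row per token; random stand-in for trained embedding values
embedding_table = np.random.default_rng(0).normal(size=(len(vocab), d_model))

def embed(sentence):
    ids = [token_to_id[tok] for tok in sentence.lower().split()]
    return embedding_table[ids]  # (n_tokens x d_model)

X1 = embed("Time flies like an arrow")
X2 = embed("Fruit flies like a banana")
# "flies" is the second token in both sentences and gets the identical row
print(np.array_equal(X1[1], X2[1]))  # True
```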
Back to our example and the goal of self-attention. Now we have the word “flies” in a very specific context, i.e. a five-word sentence. So, when looking at the first sentence, we want, in effect, to drag the position of “flies” in our vocabulary space closer to the words “time” and “arrow”, because these words help contextualize the usage of “flies”.
In fact, we’re going to drag the position of “flies” in space, i.e. adjust its embedding vector values, in relation to all the words in this specific sentence, according to their importance. “Like” and “an” are less important than “time” and “arrow”, so “time” and “arrow” will be weighted more highly, i.e. get more attention, when we adjust the embedding for “flies”.
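Adjusting “flies” in proportion to each word’s importance amounts to taking a weighted average of all the embeddings in the sentence, including “flies” itself. In the sketch below, both the toy embeddings and the attention weights are made up purely for illustration; in a real model the weights are computed from the embeddings rather than hand-picked.

```python
import numpy as np

tokens = ["time", "flies", "like", "an", "arrow"]
X = np.random.default_rng(1).normal(size=(len(tokens), 4))  # toy embeddings, one row per token

# Hypothetical attention weights for "flies", made up to mirror the text:
# "time" and "arrow" get more weight than "like" and "an". They sum to 1.
weights_for_flies = np.array([0.35, 0.15, 0.05, 0.05, 0.40])

# The adjusted "flies" embedding is a weighted average of all five embeddings
new_flies = weights_for_flies @ X
print(new_flies)
```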
In effect, that’s all that self-attention means: the process of adjusting the embedding values of a word, given a specific input. That’s it. As I said, the actual output from the self-attention process has exactly the same dimensions as the input. We have just altered the embedding contents, i.e. changed the vectors of numbers, through the process described above.
Obviously self-attention is an equal opportunity process, so it adjusts all embeddings, not just that of “flies”. Every embedding is adjusted in some way, to reflect each word’s relationship to the other words in the given input.
How it does this requires a deeper dive into the calc and build parts. Here, Peter Bloem’s post (above) starts to become very useful again. But for just a taste, if we put our “calc” hat on, one easy way to measure the relationship between two vectors (in our case, word embeddings) is their dot product, i.e. multiply them element by element and add up the results. At its core, this is what the self-attention process does, for every pair of words in the input; the scores are then passed through a softmax so that they become weights that sum to 1. Larger dot products imply greater importance.
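Here is that taste in code: the dot product as an element-by-element multiply-and-sum, applied to hand-picked 3-dimensional vectors. The values are hypothetical, chosen so that “time” and “arrow” point roughly the same way as “flies”. The softmax at the end turns the scores into weights. In a real transformer the dot products are taken between learned query and key projections of the embeddings rather than the raw vectors, which this sketch skips.

```python
import numpy as np

def dot(u, v):
    """Dot product: multiply the vectors element by element, then add up the products."""
    return sum(a * b for a, b in zip(u, v))

# Hypothetical hand-picked 3-d embeddings (real ones are learned and much larger)
emb = {
    "time":  np.array([0.9, 0.1, 0.3]),
    "flies": np.array([0.8, 0.2, 0.4]),
    "like":  np.array([0.1, 0.9, 0.0]),
    "an":    np.array([0.1, 0.1, 0.1]),
    "arrow": np.array([0.7, 0.0, 0.5]),
}

# Score "flies" against every word in the sentence, itself included
scores = np.array([dot(emb["flies"], vec) for vec in emb.values()])

# Softmax turns the scores into attention weights that sum to 1
weights = np.exp(scores) / np.exp(scores).sum()

for word, s, w in zip(emb, scores, weights):
    print(f"{word:>6}  score={s:.2f}  weight={w:.2f}")
```

Note that “flies” also scores highly against itself: self-attention always includes a word’s own embedding in the mix.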
It also rescales the scores (dividing them by the square root of the vector dimension, which reduces their magnitude) and creates multiple versions of each embedding through learned projections (multiple “heads”), so it can learn multiple kinds of relationships. But these are either standard processes not unique to self-attention, or replications of self-attention.
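For readers who want to see those extra pieces, here is a sketch of multi-head scaled dot-product attention, i.e. softmax(Q·Kᵀ / sqrt(d_head))·V computed separately for each head. It assumes NumPy, toy sizes, and random matrices standing in for the learned projection weights (the names W_q, W_k, W_v and W_o are illustrative). It is an illustration of the technique, not any particular library’s implementation.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(X, W_q, W_k, W_v, W_o, n_heads):
    n, d_model = X.shape
    d_head = d_model // n_heads

    def split(M):
        # (n, d_model) -> (n_heads, n, d_head): one slice of columns per head
        return M.reshape(n, n_heads, d_head).transpose(1, 0, 2)

    # "Multiple versions of each embedding": project X into queries, keys and values per head
    Q, K, V = split(X @ W_q), split(X @ W_k), split(X @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)  # rescaling shrinks the dot products
    weights = softmax(scores, axis=-1)                   # (n_heads, n, n), each row sums to 1
    heads = weights @ V                                  # (n_heads, n, d_head)
    concat = heads.transpose(1, 0, 2).reshape(n, d_model)  # glue the heads back together
    return concat @ W_o                                  # (n, d_model): same shape as X

rng = np.random.default_rng(0)
n_tokens, d_model, n_heads = 5, 8, 2
X = rng.normal(size=(n_tokens, d_model))
# Random stand-ins for weight matrices a real model would learn during training
W_q, W_k, W_v, W_o = (rng.normal(size=(d_model, d_model)) * 0.1 for _ in range(4))
out = multi_head_self_attention(X, W_q, W_k, W_v, W_o, n_heads)
print(out.shape)  # (5, 8)
```

Each head sees its own projected version of the embeddings, so different heads are free to pick up different relationships; the output still has the same shape as the input.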
The fundamental point to intuit is that the self-attention process simply adjusts the embedding values based on the specific sentence input to the model.
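To see that point directly, the sketch below feeds both example sentences through self-attention, again assuming a hypothetical vocabulary, random stand-in embeddings, and the simplest parameter-free attention. The vector for “flies” is identical going in, but comes out different, because its neighbors differ.

```python
import numpy as np

vocab = ["time", "flies", "like", "an", "arrow", "fruit", "a", "banana"]
token_to_id = {tok: i for i, tok in enumerate(vocab)}
table = np.random.default_rng(7).normal(size=(len(vocab), 4))  # random stand-in for trained embeddings

def embed(sentence):
    return table[[token_to_id[t] for t in sentence.lower().split()]]

def simple_self_attention(X):
    # parameter-free, single-head, scaled dot-product attention
    scores = X @ X.T / np.sqrt(X.shape[1])
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    w = w / w.sum(axis=1, keepdims=True)
    return w @ X

X1, X2 = embed("Time flies like an arrow"), embed("Fruit flies like a banana")
before_same = np.array_equal(X1[1], X2[1])  # same lookup row for "flies"
after_same = np.allclose(simple_self_attention(X1)[1], simple_self_attention(X2)[1])  # neighbors differ
print(before_same, after_same)  # True False
```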
Then the adjusted embedding continues on its merry way into the FF NN and the classification task, or whatever the goal of the transformer is.
The transformer architecture is actually much simpler than some prior deep learning models, such as RNNs and LSTMs. Wrap your head, and some sample code, around self-attention, and you’re well on your way to building or tuning your own transformers.
Now go and have some fun with ChatGPT https://chat.openai.com