# Attention Is All You Need: A Walkthrough
Breaking down the core ideas behind the transformer architecture — self-attention, positional encoding, and multi-head attention — with equations and implementation snippets.
## Table of Contents

- Self-Attention Mechanism
- Multi-Head Attention
- Implementation
- Architecture Overview
- Positional Encoding
- Key Takeaways
## Self-Attention Mechanism
The key innovation of transformers is the self-attention mechanism. Given an input sequence, self-attention computes a weighted sum of all positions, where the weights are determined by the compatibility of each pair of positions.
For an input matrix $X \in \mathbb{R}^{n \times d_{\text{model}}}$, we compute queries, keys, and values via learned projections:

$$Q = XW^Q, \qquad K = XW^K, \qquad V = XW^V$$
The attention output is then:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
The scaling factor $1/\sqrt{d_k}$ prevents the dot products from growing too large, which would push the softmax into regions with extremely small gradients.
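This effect is easy to see numerically. The sketch below (my own illustration, not from the paper) compares the softmax over raw dot products against the scaled version for a large $d_k$:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 512

# Dot products of unit-variance random vectors have variance ~d_k,
# so their typical magnitude grows like sqrt(d_k).
q = torch.randn(d_k)
keys = torch.randn(8, d_k)
scores = keys @ q  # raw dot products, std around sqrt(512) ~ 22.6

unscaled = F.softmax(scores, dim=-1)
scaled = F.softmax(scores / d_k ** 0.5, dim=-1)

# The unscaled softmax typically collapses onto a single key (near
# one-hot), while the scaled version keeps a smoother distribution.
print(unscaled.max().item())
print(scaled.max().item())
```

A near-one-hot softmax has vanishingly small gradients with respect to all but one score, which is exactly what the $1/\sqrt{d_k}$ factor avoids.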
## Multi-Head Attention
Rather than performing a single attention function, transformers use multi-head attention to jointly attend to information from different representation subspaces:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)W^O$$

where each head is computed as:

$$\mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
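A common way to implement this is with one projection per role and a reshape to split heads, rather than $h$ separate projections. Here is a minimal sketch in PyTorch; the batch-first shapes and the reshaping strategy are my own implementation choices, not prescribed by the paper:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Minimal multi-head self-attention (no dropout, no cross-attention)."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        # One projection per role; heads are split by reshaping below.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, T, _ = x.shape

        # (B, T, d_model) -> (B, num_heads, T, d_k)
        def split(t):
            return t.view(B, T, self.num_heads, self.d_k).transpose(1, 2)

        Q, K, V = split(self.w_q(x)), split(self.w_k(x)), split(self.w_v(x))
        scores = Q @ K.transpose(-2, -1) / self.d_k ** 0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        out = F.softmax(scores, dim=-1) @ V
        # Merge heads back: (B, num_heads, T, d_k) -> (B, T, d_model)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.w_o(out)
```

Usage: `MultiHeadAttention(512, 8)(torch.randn(2, 10, 512))` returns a tensor of shape `(2, 10, 512)` — the output projection $W^O$ maps the concatenated heads back to $d_{\text{model}}$.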
## Implementation
Here’s a minimal PyTorch implementation of scaled dot-product attention:
```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    # Compatibility scores between every query and every key, scaled by sqrt(d_k).
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Masked positions get -inf so they receive zero attention weight.
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V)
```
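The `mask` argument is how decoder self-attention enforces causality. A quick self-contained sketch (the lower-triangular mask construction is a standard idiom, not code from the paper) of what a causal mask does to the attention weights:

```python
import torch
import torch.nn.functional as F

T = 4
# Lower-triangular causal mask: position t may attend only to positions <= t.
mask = torch.tril(torch.ones(T, T))

scores = torch.randn(T, T)
scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)

# Entries above the diagonal are exactly zero: no attention to the future.
assert torch.all(weights.triu(1) == 0)
# Each row is still a valid distribution over the visible prefix.
print(weights)
```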
## Architecture Overview
```mermaid
flowchart TB
    IN["Input Tokens"] --> EMB["Embedding"]
    EMB --> PE["+ Positional Encoding"]
    PE --> Q["Q"] & K["K"] & V["V"]
    Q & K & V --> H1["Head 1"]
    Q & K & V --> H2["Head 2"]
    Q & K & V --> Hh["Head h"]
    H1 & H2 & Hh --> CAT["Concat + Project"]
    CAT --> ADD1(("+"))
    PE -.->|residual| ADD1
    ADD1 --> LN1["Layer Norm"]
    LN1 --> FF["Feed-Forward Network"]
    FF --> ADD2(("+"))
    LN1 -.->|residual| ADD2
    ADD2 --> LN2["Layer Norm"]
    LN2 --> OUT["Output"]
```
## Positional Encoding
Since transformers have no inherent notion of position, we add positional encodings to the input embeddings. The original paper uses sinusoidal functions:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$
The beauty of this encoding is that it allows the model to learn to attend by relative positions, since $PE_{pos+k}$ can be represented as a linear function of $PE_{pos}$.
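The sinusoidal scheme is straightforward to implement. A vectorized sketch (the function name and vectorization are my own; the formula is the paper's):

```python
import torch

def sinusoidal_positional_encoding(max_len, d_model):
    """Sinusoidal positional encodings: sin on even dims, cos on odd dims."""
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    i = torch.arange(0, d_model, 2, dtype=torch.float32)           # even dim indices
    angle = pos / (10000 ** (i / d_model))                         # (max_len, d_model/2)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(angle)
    pe[:, 1::2] = torch.cos(angle)
    return pe

pe = sinusoidal_positional_encoding(100, 64)
# Every entry lies in [-1, 1], so the encoding can be added to
# embeddings without dominating them.
assert pe.abs().max() <= 1.0
```

Because each dimension pair is a sinusoid of a fixed frequency, $PE_{pos+k}$ is a rotation of $PE_{pos}$ in each pair, which is the linear relationship mentioned above.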
## Key Takeaways
- Self-attention has $O(n^2)$ complexity in the sequence length $n$, since every position attends to every other
- Multi-head attention allows the model to focus on different aspects simultaneously
- Positional encodings inject sequence order information
- Layer normalization and residual connections are critical for training stability
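The quadratic cost in the first takeaway is visible directly in the shape of the attention-score matrix. A small illustration of my own (not from the paper):

```python
import torch

d_k = 64
for n in (128, 256, 512):
    Q = torch.randn(n, d_k)
    K = torch.randn(n, d_k)
    scores = Q @ K.T
    # The score matrix is n x n: doubling the sequence length
    # quadruples the number of attention weights to compute and store.
    print(n, tuple(scores.shape), scores.numel())
```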
This post is a simplified walkthrough. For the full details, see Vaswani et al., 2017.