Attention and Transformers¶
What This Is¶
Attention lets a model look at all positions in the input at once and decide which ones matter for each output. Transformers are architectures built entirely on attention — no recurrence, no convolutions — and they dominate modern NLP, vision, and multimodal AI.
The core idea: instead of processing sequences left-to-right, let every position attend to every other position in parallel.
Two things make attention a workflow skill, not only a theory topic:
- the attention weights are an inspectable artifact — you can look at what the model is using
- the same block powers encoders (bidirectional) and decoders (causal), and the mask is what separates them
Most real mistakes with transformers come from misreading the mask, misreading the positional encoding, or training from scratch when fine-tuning was the right call.
When You Use It¶
- processing text where long-range dependencies matter
- building or fine-tuning language models (BERT, GPT, T5)
- applying transformers to vision (ViT) or audio
- any task where the relationship between distant elements is important
- multi-modal work where attention bridges modalities (CLIP, Flamingo)
- inspecting which input tokens a model used for its output
The Attention Mechanism¶
Attention computes a weighted sum of values, where the weights come from comparing queries to keys:
Attention(Q, K, V) = softmax(QKᵀ / √d_k) × V
In plain terms:
- Each position creates a query ("what am I looking for?")
- Each position creates a key ("what do I contain?")
- Each position creates a value ("what information do I carry?")
- Queries compare against all keys to produce attention weights
- Values are combined using those weights
The scaling by √d_k is not cosmetic. For queries and keys with roughly unit-variance entries, the variance of each dot product grows linearly with d_k, so without the scaling the softmax saturates and gradients vanish through it. Forgetting it is one of the most common from-scratch bugs.
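To make the scaling argument concrete, here is a minimal sketch that compares the softmax over raw and scaled dot products for random unit-variance vectors. The helper name `max_weights`, the dimension 512, the key count, and the seed are illustrative assumptions, not part of the original text.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # arbitrary seed, only for repeatability

def max_weights(d_k, n_keys=64):
    """Return the largest softmax weight without and with 1/sqrt(d_k) scaling."""
    q = torch.randn(d_k)            # one query with unit-variance entries
    K = torch.randn(n_keys, d_k)    # n_keys keys with unit-variance entries
    raw = K @ q                     # dot products, variance roughly d_k
    unscaled = F.softmax(raw, dim=-1).max().item()
    scaled = F.softmax(raw / d_k ** 0.5, dim=-1).max().item()
    return unscaled, scaled

unscaled, scaled = max_weights(d_k=512)
print(f"max weight unscaled: {unscaled:.3f}, scaled: {scaled:.3f}")
```

The unscaled distribution concentrates more of its mass on a single key, because dividing the scores by a constant greater than 1 always flattens a softmax.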
Self-Attention Step By Step¶
import torch
import torch.nn.functional as F

def self_attention(X, W_q, W_k, W_v):
    Q = X @ W_q  # queries
    K = X @ W_k  # keys
    V = X @ W_v  # values
    d_k = K.shape[-1]
    scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)  # scaled dot-product
    weights = F.softmax(scores, dim=-1)              # attention weights
    output = weights @ V                             # weighted combination
    return output, weights
What to inspect at this step:
- the weights tensor: every row should sum to 1.0 and be non-negative
- how the pattern changes when you swap two rows of X (should permute rows and columns of weights consistently)
- the scale of scores before softmax: if it is saturating the softmax, your scaling is off
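The first two checks in the list can be sketched as follows. The function is repeated in compact form so the block runs on its own; the tensor sizes, the seed, and the weight initialization are arbitrary assumptions.

```python
import torch
import torch.nn.functional as F

def self_attention(X, W_q, W_k, W_v):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.transpose(-2, -1) / (K.shape[-1] ** 0.5)
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights

torch.manual_seed(0)
T, d = 5, 8
X = torch.randn(T, d)
W_q, W_k, W_v = (torch.randn(d, d) / d ** 0.5 for _ in range(3))

_, weights = self_attention(X, W_q, W_k, W_v)

# Check 1: every row is a probability distribution.
rows_ok = (
    torch.allclose(weights.sum(dim=-1), torch.ones(T), atol=1e-6)
    and bool((weights >= 0).all())
)

# Check 2: swapping rows 0 and 1 of X permutes rows and columns of weights.
perm = torch.tensor([1, 0, 2, 3, 4])
_, weights_perm = self_attention(X[perm], W_q, W_k, W_v)
perm_ok = torch.allclose(weights_perm, weights[perm][:, perm], atol=1e-6)

print(rows_ok, perm_ok)
```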
Multi-Head Attention¶
Instead of one attention operation, split Q/K/V into multiple "heads" that each attend to different aspects:
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_out = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, T, C = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        scores = (q @ k.transpose(-2, -1)) / (self.d_head ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.W_out(out), attn
Why multiple heads: one head might attend to syntax, another to semantics, another to position. The model learns to specialize.
Rules of thumb:
- d_head should usually be 32 to 128: too small and heads can't specialize, too large wastes parameters
- n_heads = d_model // d_head: tune d_head and derive n_heads, not the other way around
- number of heads is not a quality knob: more heads with the same d_model means each head is smaller
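A small arithmetic sketch of these rules of thumb. The helper names `derive_heads` and `mha_param_count` are hypothetical, and the parameter count assumes four d_model × d_model linear layers with bias, as in the class above.

```python
def derive_heads(d_model, d_head):
    """Pick d_head first, then derive the number of heads."""
    if d_model % d_head != 0:
        raise ValueError("d_model must be divisible by d_head")
    return d_model // d_head

def mha_param_count(d_model):
    """Parameters in W_q, W_k, W_v, W_out (weights plus biases)."""
    return 4 * (d_model * d_model + d_model)

d_model = 768
for d_head in (32, 64, 128):
    n_heads = derive_heads(d_model, d_head)
    # The parameter count does not depend on n_heads, only on d_model.
    print(d_head, n_heads, mha_param_count(d_model))
```

Changing the head count at fixed d_model only changes how the same parameters are partitioned, which is why it is not a quality knob on its own.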
Causal Masking For Decoders¶
A decoder must not look at future tokens. The causal mask is a lower-triangular matrix that zeros out attention to future positions:
import torch

def causal_mask(T, device):
    return torch.tril(torch.ones(T, T, device=device)).bool()

# in the forward:
mask = causal_mask(T, x.device)  # (T, T)
scores = scores.masked_fill(~mask, float("-inf"))
If you forget the causal mask in a decoder, the model will "cheat" by looking at future tokens during training. Training loss will drop beautifully. Generation will then collapse, because at inference time the future tokens are gone.
Inspection habit: print attn[0, 0] and make sure all values above the diagonal are zero.
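The inspection habit can be automated, together with a second check that perturbing the last token leaves earlier outputs unchanged. This is a single-head sketch under assumed sizes; `causal_attention` is a hypothetical helper, not a function from the text above.

```python
import torch
import torch.nn.functional as F

def causal_mask(T, device=None):
    return torch.tril(torch.ones(T, T, device=device)).bool()

def causal_attention(q, k, v):
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    scores = scores.masked_fill(~causal_mask(q.shape[0]), float("-inf"))
    attn = F.softmax(scores, dim=-1)
    return attn @ v, attn

torch.manual_seed(0)
T, d = 6, 16
q, k, v = torch.randn(3, T, d).unbind(0)
out, attn = causal_attention(q, k, v)

# Habit 1: no attention weight above the diagonal.
future_weight = torch.triu(attn, diagonal=1).abs().max().item()

# Habit 2: changing the last key/value must not change earlier outputs.
k2, v2 = k.clone(), v.clone()
k2[-1] += 10.0
v2[-1] += 10.0
out2, _ = causal_attention(q, k2, v2)
leak = (out[:-1] - out2[:-1]).abs().max().item()

print(future_weight, leak)
```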
Transformer Block¶
A transformer block combines:
- Multi-head self-attention — look at all positions
- Add & Norm — residual connection + layer normalization
- Feed-forward network — two linear layers with a non-linearity
- Add & Norm — another residual connection
Input → LayerNorm → Self-Attention → + → LayerNorm → FFN → + → Output
  |                                  ^ |                   ^
  |__________________________________| |___________________|
               (residual)                   (residual)
Two conventions exist:
- post-norm (original paper): LN(x + Attn(x)) then LN(x + FFN(x)). Classic, but harder to train deep
- pre-norm (modern default): x + Attn(LN(x)) then x + FFN(LN(x)). More stable gradients at depth, because the residual path is never normalized
Pre-norm is the standard in most modern implementations. If a paper trained to 32+ layers, it is almost certainly pre-norm.
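A minimal sketch of the two conventions as sublayer wrappers, where post-norm computes LN(x + f(x)) and pre-norm computes x + f(LN(x)). The class names and the zero-initialized sublayer are illustrative assumptions, chosen to expose the difference in the residual path.

```python
import torch
import torch.nn as nn

class PostNormSublayer(nn.Module):
    """LN(x + f(x)): the residual sum itself is normalized."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer

    def forward(self, x):
        return self.norm(x + self.sublayer(x))

class PreNormSublayer(nn.Module):
    """x + f(LN(x)): the residual path stays an identity."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer

    def forward(self, x):
        return x + self.sublayer(self.norm(x))

torch.manual_seed(0)
d_model = 16
zero = nn.Linear(d_model, d_model)   # sublayer that outputs zeros
nn.init.zeros_(zero.weight)
nn.init.zeros_(zero.bias)

x = torch.randn(2, 4, d_model) * 5 + 3
with torch.no_grad():
    pre_out = PreNormSublayer(d_model, zero)(x)
    post_out = PostNormSublayer(d_model, zero)(x)

pre_is_identity = torch.allclose(pre_out, x)
post_token_mean = post_out.mean(dim=-1).abs().max().item()
print(pre_is_identity, post_token_mean)
```

With a sublayer that contributes nothing, the pre-norm wrapper passes the input through untouched, while the post-norm wrapper still rescales it. That untouched identity path is one common explanation for why pre-norm stacks are easier to train at depth.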
Full pre-norm block in PyTorch:
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        h, _ = self.attn(self.ln1(x), mask=mask)
        x = x + self.drop(h)
        x = x + self.drop(self.ff(self.ln2(x)))
        return x
Positional Encoding¶
Attention has no notion of position — "the cat sat" and "sat the cat" look the same. Positional encoding injects order.
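The claim can be checked directly. In this sketch, a mean-pooled self-attention output is identical for the two word orders until a positional signal is added. The three-word vocabulary, the random stand-in positional vectors, and the use of `nn.MultiheadAttention` are assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab = {"the": 0, "cat": 1, "sat": 2}
emb = nn.Embedding(len(vocab), 16)
attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)

def encode(words, pos_emb=None):
    ids = torch.tensor([[vocab[w] for w in words]])   # (1, T)
    x = emb(ids)                                      # (1, T, 16)
    if pos_emb is not None:
        x = x + pos_emb[: x.shape[1]]                 # inject order
    out, _ = attn(x, x, x)
    return out.mean(dim=1)                            # mean-pool over positions

with torch.no_grad():
    a = encode(["the", "cat", "sat"])
    b = encode(["sat", "the", "cat"])
    same_without_pe = torch.allclose(a, b, atol=1e-5)

    pe = torch.randn(3, 16)   # stand-in for any positional encoding
    a_pe = encode(["the", "cat", "sat"], pe)
    b_pe = encode(["sat", "the", "cat"], pe)
    differ_with_pe = not torch.allclose(a_pe, b_pe, atol=1e-5)

print(same_without_pe, differ_with_pe)
```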
Sinusoidal (original)¶
import math, torch

def sinusoidal_pe(seq_len, d_model):
    pe = torch.zeros(seq_len, d_model)
    position = torch.arange(seq_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe
Fixed, non-learned, extrapolates to longer sequences than seen in training (approximately).
Learned positional embeddings¶
# in __init__:
self.pos_emb = nn.Embedding(max_len, d_model)
# in forward, with token_emb of shape (B, T, d_model):
x = token_emb + self.pos_emb(torch.arange(T, device=token_emb.device))
Used in BERT, GPT-2. Does not extrapolate past max_len.
Rotary position embedding (RoPE)¶
The modern default in LLaMA, Mistral, and many others. Instead of adding a positional vector, RoPE rotates the query/key vectors in 2D subspaces by an angle that depends on position. The dot product q·k then only depends on the relative position, which is what most downstream tasks actually need.
Key property: RoPE extrapolates better than learned embeddings and is simpler to implement than relative-attention schemes.
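A minimal sketch of the rotation, assuming the interleaved-pair layout (dimensions 2i and 2i+1 form one 2D subspace) and the usual base of 10000. Real implementations differ in layout and caching, so treat the function `rope` as illustrative rather than as any library's API. The check at the end verifies the relative-position property stated above.

```python
import torch

def rope(x, positions, base=10000.0):
    """Rotate each pair (x[..., 2i], x[..., 2i+1]) by positions * base^(-2i/d)."""
    d = x.shape[-1]
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)  # (d/2,)
    angles = positions[:, None].float() * inv_freq[None, :]               # (T, d/2)
    cos, sin = angles.cos(), angles.sin()
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

torch.manual_seed(0)
q = torch.randn(1, 8)
k = torch.randn(1, 8)

def score(m, n):
    """Dot product of q placed at position m and k placed at position n."""
    q_m = rope(q, torch.tensor([m]))
    k_n = rope(k, torch.tensor([n]))
    return (q_m * k_n).sum().item()

s_near = score(3, 1)      # relative offset of 2
s_shifted = score(10, 8)  # same offset, different absolute positions
print(s_near, s_shifted)
```

The two scores agree up to float rounding because each 2D rotation contributes only the angle difference between the two positions to the dot product.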
ALiBi¶
A bias added directly to the attention scores that penalizes distance. Cheap, extrapolates well, used in BLOOM and some MPT models.
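As a sketch of what that bias looks like for causal attention: each head gets a fixed slope, and the penalty grows linearly with the distance to the key. The slope formula below assumes the head count is a power of two; the function name `alibi_bias` is hypothetical.

```python
import torch

def alibi_bias(n_heads, T):
    """Return an (n_heads, T, T) additive bias for causal attention scores."""
    # Geometric sequence of slopes; this form assumes n_heads is a power of two.
    slopes = torch.tensor([2 ** (-8 * (h + 1) / n_heads) for h in range(n_heads)])
    pos = torch.arange(T)
    # distance[i, j] = i - j for past keys; future entries are clamped to 0
    # and are expected to be removed by the causal mask anyway.
    distance = (pos[:, None] - pos[None, :]).clamp(min=0)
    return -slopes[:, None, None] * distance

bias = alibi_bias(n_heads=8, T=4)
# Usage: scores = scores + bias, applied before the causal mask and softmax.
print(bias.shape, bias[0, 3, 0].item())
```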
Which to use:
- small from-scratch project: sinusoidal or learned — fastest to implement
- any serious LLM work: RoPE — it is now the default
- if extrapolation to long contexts at inference is a hard requirement: ALiBi or RoPE with position interpolation
Encoder vs Decoder vs Encoder-Decoder¶
| Component | Sees | Mask | Used By |
|---|---|---|---|
| Encoder | all positions | none (bidirectional) | BERT, RoBERTa, ViT, sentence embeddings, classification |
| Decoder | only past positions | causal | GPT family, LLaMA, text generation |
| Encoder-Decoder | encoder: all; decoder: past + cross-attention to encoder | none on encoder; causal on decoder self-attention; cross-attention to the encoder output is unmasked apart from padding | T5, BART, translation, summarization |
Cross-attention is the same operation as self-attention, but queries come from the decoder and keys/values come from the encoder output.
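A single-head sketch of that difference. The function name `cross_attention`, the sequence lengths, and the weight initialization are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def cross_attention(dec_x, enc_out, W_q, W_k, W_v):
    Q = dec_x @ W_q      # queries from the decoder:      (T_dec, d)
    K = enc_out @ W_k    # keys from the encoder output:  (T_enc, d)
    V = enc_out @ W_v    # values from the encoder output
    scores = Q @ K.transpose(-2, -1) / (K.shape[-1] ** 0.5)  # (T_dec, T_enc)
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights

torch.manual_seed(0)
d = 8
dec_x = torch.randn(3, d)     # 3 decoder positions
enc_out = torch.randn(7, d)   # 7 encoder positions
W_q, W_k, W_v = (torch.randn(d, d) / d ** 0.5 for _ in range(3))

out, weights = cross_attention(dec_x, enc_out, W_q, W_k, W_v)
print(out.shape, weights.shape)
```

The weight matrix is rectangular: one row per decoder position, one column per encoder position.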
Decision tree for which to build / fine-tune:
- classification, embedding extraction, token-level tagging → encoder-only (BERT-family)
- open-ended generation, chat, code completion → decoder-only (GPT-family)
- input-to-output text tasks with a strong structural contract (translation, summarization with tight constraints) → encoder-decoder (T5-family)
Most modern LLM work uses decoder-only because the architecture is simpler to scale and the same model handles classification via prompting.
KV-Cache For Inference¶
At inference time, a naive decoder recomputes keys and values for every past token each time it generates a new token. The fix is to cache them:
# pseudocode: embed, blocks, project_kv, forward_with_kv, lm_head, and sample are placeholders
kv_cache = [None] * len(blocks)  # one (k, v) pair per layer
for step in range(max_new_tokens):
    x = embed(last_token)
    for layer_idx, block in enumerate(blocks):
        k_new, v_new = block.attn.project_kv(x)
        if kv_cache[layer_idx] is not None:
            k_prev, v_prev = kv_cache[layer_idx]
            k_new = torch.cat([k_prev, k_new], dim=-2)
            v_new = torch.cat([v_prev, v_new], dim=-2)
        kv_cache[layer_idx] = (k_new, v_new)
        x = block.forward_with_kv(x, k_new, v_new)
    next_token = sample(lm_head(x))
    last_token = next_token  # feed the sampled token back in
Without a KV-cache, attention costs O(T²) per generated token, because the whole prefix is reprocessed. With it, each step costs O(T). The cache is what makes modern LLM inference affordable.
What to inspect:
- memory usage — KV-cache can dominate GPU memory for long contexts
- whether your attention is masked correctly after concatenation (only the new token attends to everything; past tokens see their own past)
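A runnable single-layer, single-head sketch of the correctness check: cached step-by-step attention should match full causal attention on the same prefix. The function names, sizes, and the missing batch dimension are simplifying assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 16
W_q, W_k, W_v = (torch.randn(d, d) / d ** 0.5 for _ in range(3))

def full_causal_attention(x):
    """Process the whole sequence at once with a causal mask."""
    T = x.shape[0]
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ k.T / d ** 0.5
    mask = torch.tril(torch.ones(T, T)).bool()
    scores = scores.masked_fill(~mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ v

def cached_attention(x):
    """Process one token at a time, reusing cached keys and values."""
    k_cache = torch.empty(0, d)
    v_cache = torch.empty(0, d)
    outputs = []
    for t in range(x.shape[0]):
        x_t = x[t : t + 1]                                  # (1, d)
        k_cache = torch.cat([k_cache, x_t @ W_k], dim=0)
        v_cache = torch.cat([v_cache, x_t @ W_v], dim=0)
        scores = (x_t @ W_q) @ k_cache.T / d ** 0.5         # (1, t + 1)
        # No mask needed: the cache only ever contains past and current tokens.
        outputs.append(F.softmax(scores, dim=-1) @ v_cache)
    return torch.cat(outputs, dim=0)

x = torch.randn(10, d)
max_diff = (full_causal_attention(x) - cached_attention(x)).abs().max().item()
print(max_diff)
```

The comparison uses a tolerance rather than exact equality, since the two paths accumulate floating-point sums in different shapes.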
Vision Transformer (ViT)¶
A transformer works on images by treating the image as a sequence of patches:
class ViTPatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_ch=3, d_model=768):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, d_model, kernel_size=patch_size, stride=patch_size)
        n_patches = (img_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_emb = nn.Parameter(torch.zeros(1, n_patches + 1, d_model))

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)  # (B, n_patches, d_model)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)
        return x + self.pos_emb
Key design calls:
- patch size 16 → 196 patches for a 224×224 image, which is enough for most tasks
- the [CLS] token is read out for classification after the transformer stack
- ViT is data-hungry: fine-tuning a pretrained backbone beats from-scratch for anything except very large datasets
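The patch arithmetic behind the first point, as a small sketch. The helper name `vit_sequence_length` is hypothetical; it assumes square images and square, non-overlapping patches.

```python
def vit_sequence_length(img_size=224, patch_size=16, use_cls=True):
    """Number of tokens the transformer sees for one image."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    n_patches = (img_size // patch_size) ** 2
    return n_patches + (1 if use_cls else 0)

for patch_size in (32, 16, 8):
    seq_len = vit_sequence_length(224, patch_size)
    # Attention cost grows with seq_len squared, so halving the patch size
    # multiplies the attention matrix size by roughly 16.
    print(patch_size, seq_len, seq_len ** 2)
```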
Efficiency Trade-Offs¶
Self-attention is O(n²) in sequence length. Three scaling responses exist:
- FlashAttention — same O(n²) math, but I/O-aware implementation that is 2–4× faster and uses less memory in practice
- Sliding-window attention: each token only attends to the nearest w tokens (used in Longformer, Mistral)
- Linear attention: rewrite the softmax to get O(n) (used in Performer, RWKV); works but sacrifices some quality
Decision order:
- can you just use FlashAttention? (almost always: yes, via torch.nn.functional.scaled_dot_product_attention in recent PyTorch)
- do you actually need longer context? if no, stop
- can sliding window cover your task? try that before linear attention
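For the first step in that order, here is a sketch showing that the fused call computes the same result as hand-written causal attention. It assumes PyTorch 2.0 or later, where `torch.nn.functional.scaled_dot_product_attention` is available. Whether a FlashAttention kernel is actually selected depends on device, dtype, and build, which this example does not check.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, d_head = 2, 4, 8, 16
q, k, v = (torch.randn(B, H, T, d_head) for _ in range(3))

# Hand-written causal attention.
scores = q @ k.transpose(-2, -1) / d_head ** 0.5
mask = torch.tril(torch.ones(T, T)).bool()
manual = F.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1) @ v

# Fused call: scaling and the causal mask are handled internally.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

max_diff = (manual - fused).abs().max().item()
print(max_diff)
```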
Common Transformer Models To Know¶
| Model | Architecture | Main Use |
|---|---|---|
| BERT | encoder-only | classification, NER, embeddings |
| RoBERTa | encoder-only (better training) | stronger BERT replacement |
| GPT-2 / GPT-3 / GPT-4 | decoder-only | text generation, chat |
| LLaMA / Mistral | decoder-only, RoPE, RMSNorm | modern open-weight LLMs |
| T5 / Flan-T5 | encoder-decoder | text-to-text tasks |
| BART | encoder-decoder | summarization, translation |
| ViT | encoder-only on patches | image classification |
| CLIP | dual encoder | image-text retrieval |
What To Inspect¶
- attention weights on a real sample — pick a meaningful token and look at which inputs it attends to; incoherent patterns usually mean the model has not learned the task
- causal mask shape for decoders: should be lower-triangular with the diagonal included, so each token can attend to itself and its past
- the scale of pre-softmax scores: saturating scores point to missing √d_k scaling or bad initialization
- gradient norms: attention blocks are usually where exploding gradients first show up
- KV-cache correctness: the output with cache should match the output without cache on the same prefix within float tolerance (exact bitwise equality is not guaranteed)
Failure Pattern¶
Training a transformer from scratch on a small dataset. Transformers are data-hungry: for most practical tasks, fine-tuning a pretrained model is far more effective than training from scratch. Rough rule: if you have fewer than ~10M tokens of domain data for NLP, fine-tune rather than train from scratch.
Another failure: ignoring the quadratic cost of attention. Self-attention scales as O(n²) with sequence length, which becomes a bottleneck for long documents. If your sequences are longer than about 2k tokens and training or inference becomes slow or runs out of memory, reach for FlashAttention or sliding-window attention before trying to shrink the model.
A third: treating attention weights as explanation. They show what the model attended to, but that is correlation, not causation. Attribution methods (integrated gradients, attention rollout) are more trustworthy.
Common Mistakes¶
- forgetting the causal mask in decoder models, which lets the model "cheat" by looking at future tokens
- not scaling the dot products by √d_k, leading to saturating softmax and vanishing gradients
- using too few attention heads for the model dimension (or too many; the right question is d_head)
- ignoring the positional encoding and wondering why word order does not matter
- mixing post-norm and pre-norm in the same stack by accident
- training with batch_first=False on a new model and getting confused by tensor shapes
- forgetting to adjust the causal mask length when the sequence length changes batch to batch
- using torch.nn.MultiheadAttention with the wrong key_padding_mask vs attn_mask: one is per-key, the other is per-query-key pair
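The last point is easiest to see with both masks in use at once. This sketch assumes boolean masks, for which `torch.nn.MultiheadAttention` treats True as "do not attend". That is the opposite of the tril-style mask used earlier in this document, where True means "allowed". The sizes and the padding pattern are made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, d_model, n_heads = 2, 5, 16, 4
mha = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
x = torch.randn(B, T, d_model)

# key_padding_mask: (B, T). True marks padded keys that no query may attend to.
key_padding_mask = torch.zeros(B, T, dtype=torch.bool)
key_padding_mask[1, 3:] = True   # second sequence has real length 3

# attn_mask: (T, T). True marks disallowed (query, key) pairs; here, the future.
attn_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

with torch.no_grad():
    _, w = mha(x, x, x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)

# w has shape (B, T, T), averaged over heads by default.
padded_weight = w[1, :, 3:].abs().max().item()
future_weight = torch.triu(w, diagonal=1).abs().max().item()
print(w.shape, padded_weight, future_weight)
```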
Practice¶
- Implement single-head self-attention from scratch and inspect the attention weights. Show that rows of the weight matrix sum to 1.
- Show what happens when you remove positional encoding from a sequence-classification task. Compare accuracy.
- Compare a 2-head and 8-head attention with the same d_model. Explain whether more heads help, in terms of d_head.
- Explain the difference between encoder attention (bidirectional) and decoder attention (causal) by printing both masks.
- Fine-tune a pretrained BERT model and compare against training from scratch on the same dataset.
- Implement a KV-cache for a small decoder and verify that cached generation matches uncached generation on the same prefix.
- Switch a ViT block from post-norm to pre-norm at 12 layers and report which one trains more stably.
- Replace standard attention with torch.nn.functional.scaled_dot_product_attention and measure the speedup.
Runnable Example¶
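As one possible end-to-end check, the sketch below assembles a compact pre-norm block with causal multi-head attention and runs the inspection habits from this page on it. It is a condensed variant of the classes shown earlier (dropout omitted, a boolean mask where True means "allowed"); the sizes and seed are arbitrary assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_out = nn.Linear(d_model, d_model)

    def split(self, t, B, T):
        # (B, T, d_model) -> (B, n_heads, T, d_head)
        return t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, x, mask=None):
        B, T, C = x.shape
        q, k, v = (self.split(W(x), B, T) for W in (self.W_q, self.W_k, self.W_v))
        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5
        if mask is not None:
            scores = scores.masked_fill(~mask, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.W_out(out), attn

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, x, mask=None):
        h, attn = self.attn(self.ln1(x), mask=mask)   # pre-norm attention
        x = x + h
        return x + self.ff(self.ln2(x)), attn         # pre-norm feed-forward

torch.manual_seed(0)
B, T, d_model, n_heads = 2, 6, 32, 4
block = TransformerBlock(d_model, n_heads, d_ff=64)
mask = torch.tril(torch.ones(T, T)).bool()            # True = allowed
x = torch.randn(B, T, d_model)

with torch.no_grad():
    y, attn = block(x, mask=mask)
    x_changed = x.clone()
    x_changed[:, -1] += 1.0                           # perturb only the last position
    y_changed, _ = block(x_changed, mask=mask)

rows_sum_to_one = torch.allclose(attn.sum(-1), torch.ones(B, n_heads, T), atol=1e-5)
future_weight = torch.triu(attn, diagonal=1).abs().max().item()
leak = (y[:, :-1] - y_changed[:, :-1]).abs().max().item()
print(y.shape, rows_sum_to_one, future_weight, leak)
```

If the causal mask were dropped, the `leak` value would become nonzero, which is the training-time "cheating" described in the causal masking section.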
Longer Connection¶
Continue with Transfer and Fine-Tuning for practical fine-tuning of transformer backbones, Text Representations and Order for the simpler text representations that serve as baselines before transformers, and Text Generation and Language Models for the workflow side of using these architectures.