Understanding Attention: Coherency in LLMs

Attention in LLMs: The Core Mechanism Behind Modern AI
Attention mechanisms form the backbone of modern Large Language Models (LLMs), enabling them to process and generate coherent text across long contexts. This blog post breaks down the attention mechanism for engineers who want to understand how these models work under the hood.
What is Attention?
At its core, attention allows a model to focus on relevant parts of the input sequence when producing an output. Unlike traditional RNNs that process sequences linearly, attention gives the model the ability to "look back" at any part of the input, weighing the importance of each token when generating the next one.
The fundamental question attention answers is: "Which parts of the input should I focus on to generate the current output?"
Self-Attention: The Basic Building Block
Self-attention, specifically the "Scaled Dot-Product Attention" introduced in the "Attention Is All You Need" paper, is the fundamental operation in transformer-based LLMs.
Here's how it works:
- Each input token is transformed into three vectors:
  - Query (Q): what the token is "looking for"
  - Key (K): what the token "offers" to other tokens
  - Value (V): the actual information the token carries
- The attention score between two tokens is calculated as the dot product of the query and key vectors
- These scores are scaled, softmaxed, and used to form a weighted sum of the values
The Math (Simplified)
import math
import torch
import torch.nn.functional as F

def self_attention(sequence):
    # Create Q, K, V from the input sequence (W_q, W_k, W_v are learned projection matrices)
    Q = sequence @ W_q  # shape: [seq_len, d_k]
    K = sequence @ W_k  # shape: [seq_len, d_k]
    V = sequence @ W_v  # shape: [seq_len, d_v]
    # Calculate attention scores
    scores = Q @ K.transpose(-2, -1)  # shape: [seq_len, seq_len]
    # Scale the scores
    scores = scores / math.sqrt(d_k)
    # Apply softmax to get attention weights
    weights = F.softmax(scores, dim=-1)
    # Get the weighted sum of values
    output = weights @ V  # shape: [seq_len, d_v]
    return output
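To make the shapes concrete, here is one way to exercise the function above; the dimensions and randomly initialized projection matrices below are made up purely for illustration:

import torch

# Hypothetical dimensions and projections, just to run the self_attention sketch above
d_model, d_k, d_v = 8, 4, 4
W_q = torch.randn(d_model, d_k)
W_k = torch.randn(d_model, d_k)
W_v = torch.randn(d_model, d_v)

sequence = torch.randn(3, d_model)   # three token embeddings
out = self_attention(sequence)
print(out.shape)                     # torch.Size([3, 4]), i.e. [seq_len, d_v]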
A Concrete Example
Let's see self-attention in action with a tiny example:
Input sequence: ["The", "cat", "sat"]
- Convert tokens to embeddings (simplified)
  "The" → [0.2, 0.3, 0.1]
  "cat" → [0.5, 0.2, 0.4]
  "sat" → [0.1, 0.7, 0.2]
- Project to Q, K, V (simplified with an identity projection)
  Q = K = V = [[0.2, 0.3, 0.1],
               [0.5, 0.2, 0.4],
               [0.1, 0.7, 0.2]]
- Calculate attention scores
  scores = Q @ K.T =
    [[0.14, 0.20, 0.25],
     [0.20, 0.45, 0.27],
     [0.25, 0.27, 0.54]]
- Scale and softmax to get weights
  weights = softmax(scores / sqrt(3)) =
    [[0.32, 0.33, 0.34],
     [0.31, 0.36, 0.33],
     [0.31, 0.32, 0.37]]
- Compute the weighted sum of values
  output = weights @ V =
    [[0.27, 0.40, 0.23],
     [0.28, 0.39, 0.24],
     [0.26, 0.42, 0.23]]
The output shows how each token now incorporates information from all other tokens, weighted by relevance.
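If you want to verify these numbers yourself, the whole example fits in a few lines of PyTorch (identity projections as above, values rounded to two decimals):

import torch

X = torch.tensor([[0.2, 0.3, 0.1],   # "The"
                  [0.5, 0.2, 0.4],   # "cat"
                  [0.1, 0.7, 0.2]])  # "sat"
Q = K = V = X                        # identity projection

scores = Q @ K.T                     # [[0.14, 0.20, 0.25], ...]
weights = torch.softmax(scores / 3 ** 0.5, dim=-1)
output = weights @ V

print(scores.round(decimals=2))
print(weights.round(decimals=2))
print(output.round(decimals=2))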
Multi-Head Attention
In practice, LLMs use multi-head attention, which runs multiple attention operations in parallel and concatenates the results:
def multi_head_attention(X, num_heads=8):
    head_dim = d_model // num_heads
    outputs = []
    for h in range(num_heads):
        # Each head has its own projection matrices W_q[h], W_k[h], W_v[h]
        Q = X @ W_q[h]  # shape: [seq_len, head_dim]
        K = X @ W_k[h]  # shape: [seq_len, head_dim]
        V = X @ W_v[h]  # shape: [seq_len, head_dim]
        # Compute scaled dot-product attention as before
        scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
        weights = F.softmax(scores, dim=-1)
        head_output = weights @ V
        outputs.append(head_output)
    # Concatenate the heads and apply the output projection
    return torch.cat(outputs, dim=-1) @ W_o
Multi-head attention allows the model to jointly attend to information from different representation subspaces, capturing different types of relationships between tokens.
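You rarely write that loop by hand in practice; PyTorch ships a batched implementation as torch.nn.MultiheadAttention. A minimal self-attention sketch with illustrative dimensions:

import torch
import torch.nn as nn

d_model, num_heads = 512, 8
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)

x = torch.randn(2, 16, d_model)            # [batch, seq_len, d_model]
# Self-attention: the same tensor is passed as query, key, and value
out, attn_weights = mha(x, x, x, need_weights=True)
print(out.shape)            # torch.Size([2, 16, 512])
print(attn_weights.shape)   # torch.Size([2, 16, 16]), averaged over heads by default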
Causal (Masked) Attention
In language generation, we use causal attention to ensure the model only attends to previous tokens:
def causal_self_attention(sequence):
    seq_len = sequence.shape[0]
    # Create Q, K, V from the input sequence
    Q = sequence @ W_q
    K = sequence @ W_k
    V = sequence @ W_v
    # Calculate scaled attention scores
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    # Create the causal mask (lower triangular)
    mask = torch.tril(torch.ones(seq_len, seq_len))
    # Mask out future positions by setting their scores to a large negative value
    scores = scores.masked_fill(mask == 0, -1e10)
    # Softmax and weighted sum as before
    weights = F.softmax(scores, dim=-1)
    output = weights @ V
    return output
The causal mask looks like this for a sequence of length 4:
[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]]
This ensures that token 2 can only attend to tokens 0, 1, and 2 (itself), but not to token 3, maintaining the autoregressive property during generation.
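With built-in modules the mask is usually passed in rather than built inside the attention function. A small sketch using nn.MultiheadAttention's boolean mask convention, where True marks positions that must not be attended to:

import torch
import torch.nn as nn

seq_len, d_model = 4, 8
# Mask everything strictly above the diagonal (the future tokens)
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=2, batch_first=True)
x = torch.randn(1, seq_len, d_model)
out, _ = mha(x, x, x, attn_mask=causal_mask)
print(out.shape)  # torch.Size([1, 4, 8])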
Attention Visualization
Let's visualize attention to understand how it works in practice:
Input: "The quick brown fox jumps over the lazy dog"
When the model processes the word "lazy", its attention weights over the sentence might look like this:
| Token | Attention Weight |
|-------|------------------|
| The   | 0.05             |
| quick | 0.02             |
| brown | 0.01             |
| fox   | 0.03             |
| jumps | 0.10             |
| over  | 0.15             |
| the   | 0.60             |
| lazy  | 0.04             |
| dog   | 0.00             |
The model attends heavily to "the" as it's the determiner for "lazy dog", showing how attention captures grammatical relationships.
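Given a [seq_len, seq_len] weight matrix, a heatmap is the quickest way to spot patterns like this. A minimal matplotlib sketch; the weights here are random stand-ins, and in practice you would take them from a forward pass (for example the attn_weights returned by nn.MultiheadAttention):

import torch
import matplotlib.pyplot as plt

tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"]
# Stand-in weights for illustration only
weights = torch.softmax(torch.randn(len(tokens), len(tokens)), dim=-1)

plt.imshow(weights.numpy(), cmap="viridis")
plt.xticks(range(len(tokens)), tokens, rotation=45)
plt.yticks(range(len(tokens)), tokens)
plt.colorbar(label="attention weight")
plt.title("Query tokens (rows) attending to key tokens (columns)")
plt.tight_layout()
plt.show()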
Optimized Attention Implementations
Flash Attention
Standard attention requires O(n²) memory for sequence length n, which becomes problematic for long sequences. Flash Attention addresses this with:
- Block-wise computation to leverage GPU memory hierarchy
- Recomputation of attention during the backward pass to save memory
# Pseudocode for block-wise Flash Attention
def flash_attention(Q, K, V, block_size=256):
    seq_len = Q.shape[0]
    output = torch.zeros_like(V)
    for i in range(0, seq_len, block_size):
        q_block = Q[i:i+block_size]
        # Initialize the running output and softmax normalizer for this query block
        block_output = torch.zeros(q_block.shape[0], V.shape[1])
        block_weights_sum = torch.zeros(q_block.shape[0], 1)
        for j in range(0, seq_len, block_size):
            k_block = K[j:j+block_size]
            v_block = V[j:j+block_size]
            # Compute scores between this query block and key block
            scores = q_block @ k_block.T / math.sqrt(q_block.shape[1])
            # Exponentiate (simplified: the real implementation tracks a running max
            # and rescales previous blocks for numerical stability)
            block_weights = torch.exp(scores)
            # Accumulate the unnormalized output and the softmax denominator
            block_output += block_weights @ v_block
            block_weights_sum += block_weights.sum(dim=1, keepdim=True)
        # Normalize the accumulated output for this query block
        output[i:i+block_size] = block_output / block_weights_sum
    return output
Flash Attention can reduce memory usage from O(n²) to O(n), enabling much longer context processing.
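You usually don't need to hand-roll this: PyTorch 2.x exposes fused attention kernels (including Flash-style ones where the hardware and dtypes allow) behind torch.nn.functional.scaled_dot_product_attention. A minimal sketch:

import torch
import torch.nn.functional as F

# [batch, num_heads, seq_len, head_dim]
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

# PyTorch dispatches to the fastest available backend
# (Flash / memory-efficient kernels on supported GPUs and dtypes, a plain math fallback otherwise)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 1024, 64])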
KV Caching
When generating text autoregressively, we can cache previously computed key and value projections:
def generate_with_kv_cache(model, prompt, max_tokens=100):
    # Initial forward pass over the full prompt
    tokens = tokenize(prompt)
    states = model.initial_forward(tokens)
    # Initialize the KV cache from the prompt pass
    kv_cache = states['kv_cache']
    generated = list(tokens)
    for _ in range(max_tokens):
        # Forward pass reusing the KV cache (only the last token is computed)
        logits, new_kv = model.forward_with_cache(
            tokens=generated[-1:],  # only the most recent token
            kv_cache=kv_cache
        )
        # Update the KV cache
        kv_cache = new_kv
        # Sample the next token
        next_token = sample_token(logits)
        generated.append(next_token)
        if next_token == EOS_TOKEN:
            break
    return decode(generated)
KV caching dramatically reduces computation during text generation, as we don't need to recompute keys and values for previously processed tokens.
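Under the hood the cache is just the accumulated K and V tensors; each decoding step appends one new row and attends with a single query. A minimal single-layer, single-head sketch with illustrative projection matrices:

import math
import torch
import torch.nn.functional as F

d_model = 8
W_q = torch.randn(d_model, d_model)
W_k = torch.randn(d_model, d_model)
W_v = torch.randn(d_model, d_model)

k_cache = torch.empty(0, d_model)  # grows by one row per generated token
v_cache = torch.empty(0, d_model)

def decode_step(x_new, k_cache, v_cache):
    # x_new: [1, d_model], the embedding of the newest token only
    q = x_new @ W_q
    k_cache = torch.cat([k_cache, x_new @ W_k], dim=0)   # [t, d_model]
    v_cache = torch.cat([v_cache, x_new @ W_v], dim=0)   # [t, d_model]
    scores = q @ k_cache.T / math.sqrt(d_model)           # [1, t]
    out = F.softmax(scores, dim=-1) @ v_cache             # [1, d_model]
    return out, k_cache, v_cache

for _ in range(5):  # pretend we decode 5 tokens
    out, k_cache, v_cache = decode_step(torch.randn(1, d_model), k_cache, v_cache)
print(k_cache.shape)  # torch.Size([5, 8])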
Attention Variants in Modern LLMs
Grouped-Query Attention (GQA)
Grouped-Query Attention reduces computation by sharing key and value heads:
def grouped_query_attention(X, num_q_heads=8, num_kv_heads=2):
    # Create the KV heads (fewer than the Q heads); each is shared by several Q heads
    K_heads = [X @ W_k[h] for h in range(num_kv_heads)]
    V_heads = [X @ W_v[h] for h in range(num_kv_heads)]
    q_outputs = []
    for q_head in range(num_q_heads):
        # Map each Q head to one of the shared KV heads
        kv_head_idx = q_head % num_kv_heads
        Q = X @ W_q[q_head]
        K = K_heads[kv_head_idx]
        V = V_heads[kv_head_idx]
        # Standard scaled dot-product attention
        scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
        weights = F.softmax(scores, dim=-1)
        head_output = weights @ V
        q_outputs.append(head_output)
    return torch.cat(q_outputs, dim=-1) @ W_o
GQA offers a good trade-off between compute and memory cost on one hand and model quality on the other, and is used in open models such as Llama 2 70B and Mistral 7B.
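In tensor form, GQA is often implemented by projecting to fewer KV heads and then repeating them so the head counts line up, for example with repeat_interleave. A sketch with made-up dimensions:

import torch
import torch.nn.functional as F

batch, seq_len, num_q_heads, num_kv_heads, head_dim = 2, 16, 8, 2, 64
group_size = num_q_heads // num_kv_heads  # 4 query heads share each KV head

q = torch.randn(batch, num_q_heads, seq_len, head_dim)
k = torch.randn(batch, num_kv_heads, seq_len, head_dim)
v = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# Expand the KV heads so their count matches the query heads
k = k.repeat_interleave(group_size, dim=1)  # [batch, num_q_heads, seq_len, head_dim]
v = v.repeat_interleave(group_size, dim=1)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 16, 64])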
Multi-Query Attention (MQA)
Multi-Query Attention takes GQA to the extreme with only one KV head:
def multi_query_attention(X, num_q_heads=8):
    # A single K/V pair shared across all query heads
    K = X @ W_k  # shape: [seq_len, d_k]
    V = X @ W_v  # shape: [seq_len, d_v]
    q_outputs = []
    for h in range(num_q_heads):
        Q = X @ W_q[h]
        # Compute attention using the shared K, V
        scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
        weights = F.softmax(scores, dim=-1)
        head_output = weights @ V
        q_outputs.append(head_output)
    return torch.cat(q_outputs, dim=-1) @ W_o
MQA further reduces computation and memory but may sacrifice some performance compared to full multi-head attention.
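The main payoff of MQA and GQA is a much smaller KV cache. A back-of-the-envelope calculation with made-up but plausible dimensions (cache size = 2 tensors × layers × KV heads × head_dim × sequence length × bytes per element):

def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    # The factor of 2 accounts for storing both keys and values
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem

# Hypothetical 32-layer model with 32 query heads of dim 128, 8k context, FP16 cache
full_mha = kv_cache_bytes(32, num_kv_heads=32, head_dim=128, seq_len=8192)
gqa      = kv_cache_bytes(32, num_kv_heads=8,  head_dim=128, seq_len=8192)
mqa      = kv_cache_bytes(32, num_kv_heads=1,  head_dim=128, seq_len=8192)

print(f"MHA: {full_mha / 2**30:.1f} GiB, GQA: {gqa / 2**30:.1f} GiB, MQA: {mqa / 2**30:.2f} GiB")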
Sliding Window Attention
For very long contexts, sliding window attention restricts each token to attend only to its neighborhood:
def sliding_window_attention(X, window_size=1024):
    seq_len = X.shape[0]
    # Create Q, K, V
    Q = X @ W_q
    K = X @ W_k
    V = X @ W_v
    # Calculate scaled attention scores
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    # Build the window mask: token i may attend to token j only if |i - j| <= window_size // 2
    positions = torch.arange(seq_len)
    distance = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs()
    window_mask = distance <= window_size // 2
    # Apply the mask
    scores = scores.masked_fill(~window_mask, -1e10)
    # Softmax and weighted sum
    weights = F.softmax(scores, dim=-1)
    output = weights @ V
    return output
Because each token attends only to a fixed-size neighborhood, an implementation that computes just the in-window scores scales linearly with sequence length, enabling much longer context processing.
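To see where the savings come from, compare how many query-key pairs actually get scored; a quick calculation assuming a 32k-token sequence and a 1,024-token window:

seq_len, window = 32_768, 1_024

full_pairs = seq_len * seq_len           # dense attention: ~1.07 billion scores
window_pairs = seq_len * (window + 1)    # each token scores at most its window plus itself

print(f"dense: {full_pairs:,} scores")
print(f"sliding window: {window_pairs:,} scores ({full_pairs / window_pairs:.0f}x fewer)")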
Practical Implementation Tips
Memory Optimization
- Gradient checkpointing: trade computation for memory by recomputing activations during backpropagation

  # Using PyTorch's gradient checkpointing
  output = torch.utils.checkpoint.checkpoint(attention_fn, query, key, value)

- Mixed precision: using FP16 or BF16 drastically reduces the memory footprint

  # Using PyTorch's automatic mixed precision
  with torch.cuda.amp.autocast():
      output = self_attention(x)

- Attention chunking: process attention in chunks when the sequence length is large

  def chunked_attention(q, k, v, chunk_size=1024):
      outputs = []
      for i in range(0, q.size(1), chunk_size):
          chunk_q = q[:, i:i+chunk_size]
          chunk_output = self_attention(chunk_q, k, v)
          outputs.append(chunk_output)
      return torch.cat(outputs, dim=1)
Performance Tuning
- Fused kernels: use optimized CUDA kernels for attention

  # Using xformers' memory-efficient attention
  from xformers.ops import memory_efficient_attention
  output = memory_efficient_attention(q, k, v, attn_bias=None)

- Inference speed: optimize generation with techniques like KV caching and batch processing

- Flash Attention 2: the latest iteration of Flash Attention makes attention even faster

  from flash_attn import flash_attn_qkvpacked_func
  # Pack Q, K, V (each [batch, seq_len, num_heads, head_dim]) into one [batch, seq_len, 3, num_heads, head_dim] tensor
  qkv = torch.stack([q, k, v], dim=2)
  output = flash_attn_qkvpacked_func(qkv, dropout_p=0.0)
Learn more about how Matter AI helps improve code quality across multiple languages in pull requests: https://docs.matterai.dev/product/code-quality
Are you looking for a way to improve your code review process? Learn how Matter AI helps teams solve code review challenges with AI: https://matterai.so