Attention · LLMs

Understanding Attention: Coherency in LLMs

Vatsal Bajpai · 10 min read

Attention in LLMs: The Core Mechanism Behind Modern AI

Attention mechanisms form the backbone of modern Large Language Models (LLMs), enabling them to process and generate coherent text across long contexts. This blog post breaks down the attention mechanism for engineers who want to understand how these models work under the hood.

What is Attention?

At its core, attention allows a model to focus on relevant parts of the input sequence when producing an output. Unlike traditional RNNs, which process sequences one step at a time, attention gives the model the ability to "look back" at any part of the input, weighting the importance of each token when generating the next one.

The fundamental question attention answers is: "Which parts of the input should I focus on to generate the current output?"

Self-Attention: The Basic Building Block

Self-attention, specifically the "Scaled Dot-Product Attention" introduced in the "Attention Is All You Need" paper, is the fundamental operation in transformer-based LLMs.

Here's how it works:

  1. Each input token is transformed into three vectors:

    • Query (Q): What the token is "looking for"
    • Key (K): What the token "offers" to others
    • Value (V): The actual information the token contains
  2. The attention score between tokens is calculated using the dot product of queries and keys

  3. These scores are scaled, softmaxed, and used to create weighted sums of values

The Math (Simplified)

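In equation form, this is the scaled dot-product attention from the original paper:

    Attention(Q, K, V) = softmax(Q·Kᵀ / √d_k) · V

where d_k is the dimension of the key vectors. The code below walks through the same computation step by step:
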
import math
import torch
import torch.nn.functional as F

def self_attention(sequence, W_q, W_k, W_v):
    # sequence: [seq_len, d_model]; W_q / W_k: [d_model, d_k]; W_v: [d_model, d_v]
    d_k = W_q.shape[-1]

    # Create Q, K, V from the input sequence
    Q = sequence @ W_q  # shape: [seq_len, d_k]
    K = sequence @ W_k  # shape: [seq_len, d_k]
    V = sequence @ W_v  # shape: [seq_len, d_v]

    # Calculate attention scores
    scores = Q @ K.transpose(-2, -1)  # shape: [seq_len, seq_len]

    # Scale the scores to keep the softmax well-behaved
    scores = scores / math.sqrt(d_k)

    # Apply softmax to get attention weights
    weights = F.softmax(scores, dim=-1)

    # Get weighted sum of values
    output = weights @ V  # shape: [seq_len, d_v]

    return output
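
Here's a minimal, hypothetical usage of the function above with random projection matrices (dimensions chosen purely for illustration):

torch.manual_seed(0)
d_model, d_k = 8, 8
x = torch.randn(3, d_model)                          # a 3-token sequence
W_q, W_k, W_v = [torch.randn(d_model, d_k) for _ in range(3)]

out = self_attention(x, W_q, W_k, W_v)
print(out.shape)  # torch.Size([3, 8])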

A Concrete Example

Let's see self-attention in action with a tiny example:

Input sequence: ["The", "cat", "sat"]

  1. Convert tokens to embeddings (simplified)
"The" → [0.2, 0.3, 0.1]
"cat" → [0.5, 0.2, 0.4]
"sat" → [0.1, 0.7, 0.2]
  2. Project to Q, K, V (simplified with identity projection)
Q = K = V = [[0.2, 0.3, 0.1],
             [0.5, 0.2, 0.4],
             [0.1, 0.7, 0.2]]
  3. Calculate attention scores
scores = Q @ K.T =
[[0.14, 0.20, 0.25],
 [0.20, 0.45, 0.27],
 [0.25, 0.27, 0.54]]
  4. Scale and softmax to get weights
weights = softmax(scores / sqrt(3)) =
[[0.32, 0.33, 0.34],
 [0.31, 0.36, 0.33],
 [0.31, 0.32, 0.37]]
  5. Compute weighted sum of values
output = weights @ V =
[[0.27, 0.40, 0.23],
 [0.28, 0.39, 0.24],
 [0.26, 0.42, 0.23]]

The output shows how each token now incorporates information from all other tokens, weighted by relevance.
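
If you want to check these numbers yourself, the whole example fits in a few lines of PyTorch:

import math
import torch
import torch.nn.functional as F

E = torch.tensor([[0.2, 0.3, 0.1],   # "The"
                  [0.5, 0.2, 0.4],   # "cat"
                  [0.1, 0.7, 0.2]])  # "sat"

Q = K = V = E                        # identity projections, as above
scores = Q @ K.T                     # [[0.14, 0.20, 0.25], ...]
weights = F.softmax(scores / math.sqrt(3), dim=-1)
output = weights @ V
print(weights)
print(output)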

Multi-Head Attention

In practice, LLMs use multi-head attention, which runs multiple attention operations in parallel and concatenates the results:

def multi_head_attention(X, num_heads=8):
    # W_q, W_k, W_v are lists of per-head projection matrices ([d_model, head_dim]);
    # W_o maps the concatenated heads back to d_model
    head_dim = d_model // num_heads
    outputs = []

    for h in range(num_heads):
        # Different projection matrices for each head
        Q = X @ W_q[h]  # shape: [seq_len, head_dim]
        K = X @ W_k[h]  # shape: [seq_len, head_dim]
        V = X @ W_v[h]  # shape: [seq_len, head_dim]

        # Compute attention as before
        scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
        weights = F.softmax(scores, dim=-1)
        head_output = weights @ V

        outputs.append(head_output)

    # Concatenate outputs from all heads
    return torch.cat(outputs, dim=-1) @ W_o

Multi-head attention allows the model to jointly attend to information from different representation subspaces, capturing different types of relationships between tokens.
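
In practice, implementations rarely loop over heads: a single fused projection produces Q, K, and V for all heads at once, and a reshape splits them. Here's a rough sketch of that pattern (W_qkv and W_o are assumed weight matrices; math, torch, and F as in the earlier snippets):

def multi_head_attention_vectorized(X, W_qkv, W_o, num_heads):
    # X: [seq_len, d_model]; W_qkv: [d_model, 3 * d_model]; W_o: [d_model, d_model]
    seq_len, d_model = X.shape
    head_dim = d_model // num_heads

    qkv = X @ W_qkv                    # [seq_len, 3 * d_model]
    Q, K, V = qkv.chunk(3, dim=-1)     # each [seq_len, d_model]

    # Split into heads: [num_heads, seq_len, head_dim]
    Q = Q.view(seq_len, num_heads, head_dim).transpose(0, 1)
    K = K.view(seq_len, num_heads, head_dim).transpose(0, 1)
    V = V.view(seq_len, num_heads, head_dim).transpose(0, 1)

    scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
    weights = F.softmax(scores, dim=-1)
    out = weights @ V                  # [num_heads, seq_len, head_dim]

    # Merge heads back and apply the output projection
    out = out.transpose(0, 1).reshape(seq_len, d_model)
    return out @ W_o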

Causal (Masked) Attention

In language generation, we use causal attention to ensure the model only attends to previous tokens:

def causal_self_attention(sequence):
    # sequence: [seq_len, d_model]; W_q, W_k, W_v and d_k as before
    seq_len = sequence.shape[0]

    # Create Q, K, V from input sequence
    Q = sequence @ W_q
    K = sequence @ W_k
    V = sequence @ W_v

    # Calculate attention scores
    scores = Q @ K.transpose(-2, -1)  # shape: [seq_len, seq_len]

    # Create causal mask (lower triangular)
    mask = torch.tril(torch.ones(seq_len, seq_len))

    # Apply mask by setting future positions to a large negative value
    scores = scores.masked_fill(mask == 0, -1e10)

    # Scale, softmax, and weighted sum as before
    scores = scores / math.sqrt(d_k)
    weights = F.softmax(scores, dim=-1)
    output = weights @ V

    return output

The causal mask looks like this for a sequence of length 4:

[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]]

This ensures that token 2 can only attend to tokens 0, 1, and 2 (itself), but not to token 3, maintaining the autoregressive property during generation.
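
A quick way to convince yourself the mask behaves as intended (torch and F imported as before): fill the masked scores, softmax, and check that the weights on future positions are zero.

seq_len = 4
scores = torch.randn(seq_len, seq_len)
mask = torch.tril(torch.ones(seq_len, seq_len))

weights = F.softmax(scores.masked_fill(mask == 0, -1e10), dim=-1)
print(weights)         # upper-triangular entries are 0; every row sums to 1
print(weights[1, 2:])  # tensor([0., 0.]) – token 1 cannot see tokens 2 and 3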

Attention Visualization

Let's visualize attention to understand how it works in practice:

Input: "The quick brown fox jumps over the lazy dog"

When the model processes the token "lazy", causal attention lets it look only at itself and the tokens before it. The weights for one attention head might look like:

Token    Attention Weight
The      0.05
quick    0.02
brown    0.01
fox      0.03
jumps    0.10
over     0.15
the      0.60
lazy     0.04
dog      0.00  (future token, masked out)

The model attends most heavily to the second "the", the determiner of the noun phrase "the lazy dog", showing how attention can capture grammatical relationships.
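
If you want to inspect real attention maps rather than illustrative numbers, Hugging Face Transformers can return the per-layer, per-head weights. A small sketch using GPT-2 (any causal model that supports output_attentions works the same way):

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# outputs.attentions: one tensor per layer, shape [batch, num_heads, seq_len, seq_len]
last_layer = outputs.attentions[-1][0]         # [num_heads, seq_len, seq_len]
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
for tok, w in zip(tokens, last_layer[0, -1]):  # head 0: weights from the final token
    print(f"{tok:>10s}  {w.item():.2f}")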

Optimized Attention Implementations

Flash Attention

Standard attention requires O(n²) memory for sequence length n, which becomes problematic for long sequences. Flash Attention addresses this with:

  1. Block-wise computation to leverage GPU memory hierarchy
  2. Recomputation of attention during the backward pass to save memory

# Pseudocode for block-wise Flash Attention
def flash_attention(Q, K, V, block_size=256):
    seq_len = Q.shape[0]
    output = torch.zeros_like(V)
    
    for i in range(0, seq_len, block_size):
        q_block = Q[i:i+block_size]
        
        # Initialize block outputs
        block_output = torch.zeros_like(q_block)
        block_weights_sum = torch.zeros(q_block.shape[0], 1)
        
        for j in range(0, seq_len, block_size):
            k_block = K[j:j+block_size]
            v_block = V[j:j+block_size]
            
            # Compute scores for this block
            scores = q_block @ k_block.T / math.sqrt(q_block.shape[1])
            
            # Unnormalized softmax (simplified: the real algorithm tracks a running
            # max per row and rescales earlier blocks online for numerical stability)
            block_weights = torch.exp(scores)
            
            # Update block output
            block_output += block_weights @ v_block
            block_weights_sum += block_weights.sum(dim=1, keepdim=True)
        
        # Normalize block output
        output[i:i+block_size] = block_output / block_weights_sum
    
    return output

Flash Attention can reduce memory usage from O(n²) to O(n), enabling much longer context processing.
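
In PyTorch 2.x you rarely need to write this yourself: F.scaled_dot_product_attention dispatches to a Flash-style fused kernel when the device, dtype, and mask pattern allow it. A minimal sketch, assuming a CUDA device and a [batch, heads, seq, head_dim] layout:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

# Uses a memory-efficient / flash kernel under the hood when available
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 1024, 64])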

KV Caching

When generating text autoregressively, we can cache previously computed key and value projections:

def generate_with_kv_cache(model, prompt, max_tokens=100):
    # Illustrative pseudocode: tokenize, sample_token, decode, EOS_TOKEN and the
    # model's initial_forward / forward_with_cache methods stand in for a real API
    # Initial forward pass on the prompt
    tokens = tokenize(prompt)
    states = model.initial_forward(tokens)
    
    # Initialize KV cache
    kv_cache = states['kv_cache']
    
    generated = list(tokens)
    
    for _ in range(max_tokens):
        # Forward pass with existing KV cache (only compute for the last token)
        logits, new_kv = model.forward_with_cache(
            tokens=generated[-1:],  # Only the last token
            kv_cache=kv_cache
        )
        
        # Update KV cache
        kv_cache = new_kv
        
        # Sample next token
        next_token = sample_token(logits)
        generated.append(next_token)
        
        if next_token == EOS_TOKEN:
            break
    
    return decode(generated)

KV caching dramatically reduces computation during text generation, as we don't need to recompute keys and values for previously processed tokens.
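
Under the hood, the cache for each layer is just the stack of keys and values seen so far, extended by one row per generated token. A minimal single-layer sketch (the function name and dict-based cache are illustrative, not a real framework API; math, torch, and F as before):

def attend_with_cache(x_new, W_q, W_k, W_v, cache):
    # x_new: [1, d_model] – only the newest token's embedding
    q = x_new @ W_q
    k_new = x_new @ W_k
    v_new = x_new @ W_v

    # Append the new key/value to the cached ones instead of recomputing them
    cache["k"] = torch.cat([cache["k"], k_new], dim=0)  # [cur_len, d_k]
    cache["v"] = torch.cat([cache["v"], v_new], dim=0)  # [cur_len, d_v]

    scores = q @ cache["k"].transpose(-2, -1) / math.sqrt(q.shape[-1])
    weights = F.softmax(scores, dim=-1)
    return weights @ cache["v"], cache   # [1, d_v]

The cache starts out empty and grows to [context_len, d_k] and [context_len, d_v] per layer; keeping that growth cheap is exactly what the variants below aim for.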

Attention Variants in Modern LLMs

Grouped-Query Attention (GQA)

Grouped-Query Attention reduces computation by sharing key and value heads:

def grouped_query_attention(X, num_q_heads=8, num_kv_heads=2):
    # W_q, W_k, W_v are lists of per-head projections ([d_model, head_dim]);
    # head_dim and W_o as in the multi-head example. Each KV head is shared
    # by num_q_heads // num_kv_heads query heads.
    q_outputs = []

    # Create KV heads (fewer than Q heads)
    K_heads = [X @ W_k[h] for h in range(num_kv_heads)]
    V_heads = [X @ W_v[h] for h in range(num_kv_heads)]

    for q_head in range(num_q_heads):
        # Map each Q head to a KV head
        kv_head_idx = q_head % num_kv_heads

        Q = X @ W_q[q_head]
        K = K_heads[kv_head_idx]
        V = V_heads[kv_head_idx]

        # Standard attention calculation
        scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
        weights = F.softmax(scores, dim=-1)
        head_output = weights @ V

        q_outputs.append(head_output)

    return torch.cat(q_outputs, dim=-1) @ W_o

GQA offers a good trade-off between computation cost and model quality, and is used in models such as Llama 2 (70B) and Mistral 7B.
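
In batched implementations the sharing is often realized by storing K and V with only num_kv_heads heads and expanding them to match the query heads just before the attention product, e.g. with repeat_interleave (a sketch with an assumed [num_heads, seq_len, head_dim] layout):

import torch

def expand_kv_for_gqa(k, v, num_q_heads):
    # k, v: [num_kv_heads, seq_len, head_dim]
    group_size = num_q_heads // k.shape[0]

    # Repeat each KV head once per query head in its group
    k = k.repeat_interleave(group_size, dim=0)  # [num_q_heads, seq_len, head_dim]
    v = v.repeat_interleave(group_size, dim=0)
    return k, v

The KV cache only ever stores the num_kv_heads versions, which is where the saving comes from: with 32 query heads and 8 KV heads, the cached keys and values are 4x smaller.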

Multi-Query Attention (MQA)

Multi-Query Attention takes GQA to the extreme with only one KV head:

def multi_query_attention(X, num_q_heads=8):
    # A single K/V pair (one KV head) is shared across all query heads;
    # W_q is a list of per-head projections, head_dim and W_o as before
    K = X @ W_k  # shape: [seq_len, d_k]
    V = X @ W_v  # shape: [seq_len, d_v]
    
    q_outputs = []
    for h in range(num_q_heads):
        Q = X @ W_q[h] 
        
        # Compute attention using shared K,V
        scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
        weights = F.softmax(scores, dim=-1)
        head_output = weights @ V
        
        q_outputs.append(head_output)
    
    return torch.cat(q_outputs, dim=-1) @ W_o

MQA further reduces computation and memory but may sacrifice some performance compared to full multi-head attention.
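
The biggest practical effect is on the KV cache during generation: with a single KV head, the cache shrinks by a factor of num_heads. Back-of-the-envelope numbers for a hypothetical 32-layer, 32-head model with head_dim 128, a 4,096-token context, and fp16 values:

layers, heads, head_dim, seq_len, bytes_per_value = 32, 32, 128, 4096, 2

mha_cache = 2 * layers * heads * head_dim * seq_len * bytes_per_value  # K and V: ~2.1 GB
mqa_cache = 2 * layers * 1 * head_dim * seq_len * bytes_per_value      # ~67 MB
print(mha_cache // mqa_cache)  # 32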

Sliding Window Attention

For very long contexts, sliding window attention restricts each token to attend only to its neighborhood:

def sliding_window_attention(X, window_size=1024):
    # X: [seq_len, d_model]; W_q, W_k, W_v and d_k as before
    seq_len = X.shape[0]

    # Create Q, K, V
    Q = X @ W_q
    K = X @ W_k
    V = X @ W_v

    # Calculate attention scores
    scores = Q @ K.transpose(-2, -1)

    # Window mask: token i may only attend to j when |i - j| <= window_size // 2
    idx = torch.arange(seq_len)
    mask = (idx[None, :] - idx[:, None]).abs() <= window_size // 2

    # Apply mask
    scores = scores.masked_fill(~mask, -1e10)

    # Scale, softmax and weighted sum
    scores = scores / math.sqrt(d_k)
    weights = F.softmax(scores, dim=-1)
    output = weights @ V

    return output

This approach scales roughly linearly with sequence length, since each token only scores a fixed number of keys. Note that the snippet above still materializes the full score matrix for clarity; efficient implementations compute only the in-window blocks, which is what enables much longer context processing.
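
To see where the saving comes from, count the (query, key) pairs that actually get scored (illustrative numbers, not tied to any specific model):

n, w = 32_768, 4_096               # context length, window size

full_pairs = n * n                 # ~1.1e9 scores with full attention
window_pairs = n * w               # ~1.3e8 scores with a sliding window
print(full_pairs // window_pairs)  # 8 – and the gap widens as n grows

Information can still flow beyond the window, because stacking layers extends the effective receptive field by roughly one window per layer.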

Practical Implementation Tips

Memory Optimization

  1. Gradient checkpointing: Trade computation for memory by recomputing activations during backpropagation

    # Using PyTorch's gradient checkpointing
    output = torch.utils.checkpoint.checkpoint(attention_fn, query, key, value)
    
  2. Mixed precision: Using FP16 or BF16 drastically reduces memory footprint

    # Using PyTorch's automatic mixed precision
    with torch.cuda.amp.autocast():
        output = self_attention(x)
    
  3. Attention chunking: Process attention in chunks when sequence length is large

    def chunked_attention(q, k, v, chunk_size=1024):
        # Assumes an attention function taking separate q, k, v tensors;
        # each chunk of queries still attends over the full keys/values
        outputs = []
        for i in range(0, q.size(1), chunk_size):
            chunk_q = q[:, i:i+chunk_size]
            chunk_output = attention_fn(chunk_q, k, v)
            outputs.append(chunk_output)
        return torch.cat(outputs, dim=1)
    

Performance Tuning

  1. Fused kernels: Use optimized CUDA kernels for attention

    # Using xformers' memory-efficient attention
    from xformers.ops import memory_efficient_attention
    output = memory_efficient_attention(q, k, v, attn_bias=None)
    
  2. Optimize for inference speed with techniques like KV caching and batch processing

  3. Flash Attention 2: Latest optimization makes attention even faster

    from flash_attn import flash_attn_qkvpacked_func
    # Pack Q, K, V along a new dimension: [batch, seqlen, 3, nheads, headdim]
    qkv = torch.stack([q, k, v], dim=2)
    output = flash_attn_qkvpacked_func(qkv, dropout_p=0.0)
    

Learn more about how Matter AI helps improve code quality across multiple languages in Pull Requests: https://docs.matterai.dev/product/code-quality

Are you looking for a way to improve your code review process? Learn more about how Matter AI helps teams solve code review challenges with AI: https://matterai.so
