Vatsal Bajpai

Prompt Caching for LLMs: Implementation and Benefits

Large Language Models power modern AI applications but come with significant computational costs. Prompt caching can dramatically reduce these costs while improving response times. Let's explore implementation approaches, performance metrics, and cost benefits.

Understanding Prompt Caching

Prompt caching stores and retrieves previously computed results instead of re-running the entire inference process when identical or similar prompts are sent to an LLM.

Key Concepts:

  • Cache Keys: Hashes or other deterministic representations of input prompts
  • Cache Values: Stored model outputs associated with specific inputs
  • Invalidation Strategies: Methods to maintain cache freshness
  • Hit Ratio: Percentage of requests served from cache vs. total requests
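
As a concrete example of the last concept, the hit ratio is simply hits divided by total lookups (a minimal sketch, not tied to any particular cache backend):

def hit_ratio(hits: int, misses: int) -> float:
    # Fraction of requests served from cache
    total = hits + misses
    return hits / total if total else 0.0

print(hit_ratio(hits=578, misses=422))  # 0.578 -> 57.8% of requests skipped the model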

Implementation Approaches

Basic Hash-Based Caching

import hashlib
import json
from collections import OrderedDict
from typing import Any, Dict, Optional

class PromptCache:
    def __init__(self, capacity: int = 1000):
        self.cache: OrderedDict[str, str] = OrderedDict()
        self.capacity = capacity
    
    def _generate_key(self, prompt: str, params: Dict[str, Any]) -> str:
        # Deterministic key: serialize prompt and params with sorted keys, then hash
        cache_input = json.dumps({"prompt": prompt, "params": params}, sort_keys=True)
        return hashlib.sha256(cache_input.encode()).hexdigest()
    
    def get(self, prompt: str, params: Dict[str, Any]) -> Optional[str]:
        key = self._generate_key(prompt, params)
        if key not in self.cache:
            return None
        # Mark the entry as recently used so LRU eviction keeps hot prompts
        self.cache.move_to_end(key)
        return self.cache[key]
    
    def set(self, prompt: str, params: Dict[str, Any], response: str) -> None:
        key = self._generate_key(prompt, params)
        
        # LRU eviction: drop the least recently used entry once capacity is reached
        if len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)
        self.cache[key] = response
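
A quick usage sketch (llm_call here is a hypothetical stand-in for whatever client function actually calls your model):

cache = PromptCache(capacity=500)
params = {"model": "claude-3-sonnet", "temperature": 0.0}

def cached_completion(prompt: str) -> str:
    # Serve from the cache when possible; otherwise call the model and store the result
    cached = cache.get(prompt, params)
    if cached is not None:
        return cached
    response = llm_call(prompt, **params)  # hypothetical LLM client call
    cache.set(prompt, params, response)
    return response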

Semantic Caching

Semantic caching retrieves cached results for prompts that are similar, not just identical, to earlier requests by comparing their embeddings:

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from typing import Any, Dict, Optional

class SemanticPromptCache:
    def __init__(self, model_name: str = "all-mpnet-base-v2", 
                 threshold: float = 0.95, capacity: int = 1000):
        self.encoder = SentenceTransformer(model_name)
        self.threshold = threshold
        self.capacity = capacity
        self.embeddings = []
        self.responses = []
    
    def get(self, prompt: str, params: Dict[str, Any]) -> Optional[str]:
        if not self.embeddings:
            return None
            
        # Generate embedding for the query prompt
        query_embedding = self.encoder.encode([prompt])[0].reshape(1, -1)
        
        # Calculate similarity with all cached prompts
        similarities = cosine_similarity(query_embedding, np.array(self.embeddings))
        
        # Find the most similar cached prompt
        max_idx = int(np.argmax(similarities))
        max_similarity = similarities[0][max_idx]
        
        # Return the cached response only if similarity clears the threshold
        if max_similarity >= self.threshold:
            return self.responses[max_idx]
        return None
    
    def set(self, prompt: str, params: Dict[str, Any], response: str) -> None:
        # Evict the oldest entry once capacity is reached, then store the new pair
        if len(self.embeddings) >= self.capacity:
            self.embeddings.pop(0)
            self.responses.pop(0)
        self.embeddings.append(self.encoder.encode([prompt])[0])
        self.responses.append(response)
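
Usage mirrors the hash-based cache; the main knob is the similarity threshold (0.95 above), which trades hit ratio against the risk of returning a response for a prompt that only looks similar (again, llm_call is a hypothetical placeholder):

semantic_cache = SemanticPromptCache(threshold=0.95)
params = {"temperature": 0.0}

query = "How do I reset my password?"
answer = semantic_cache.get(query, params)
if answer is None:
    answer = llm_call(query)  # hypothetical LLM client call
    semantic_cache.set(query, params, answer)

# A paraphrased query can now be served from the cache if its embedding is close enough
paraphrase = semantic_cache.get("What's the way to reset my password?", params)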

Distributed Caching with Redis

import redis
import json
import hashlib
import pickle
from typing import Any, Dict, Optional

class RedisPromptCache:
    def __init__(self, host='localhost', port=6379, expiration=86400):
        self.client = redis.Redis(host=host, port=port)
        self.expiration = expiration  # Default TTL: 24 hours
    
    def _generate_key(self, prompt: str, params: Dict[str, Any]) -> str:
        cache_input = json.dumps({"prompt": prompt, "params": params}, sort_keys=True)
        return f"prompt_cache:{hashlib.sha256(cache_input.encode()).hexdigest()}"
    
    def get(self, prompt: str, params: Dict[str, Any]) -> Optional[str]:
        key = self._generate_key(prompt, params)
        cached_data = self.client.get(key)
        if cached_data:
            return pickle.loads(cached_data)
        return None
    
    def set(self, prompt: str, params: Dict[str, Any], response: str) -> None:
        key = self._generate_key(prompt, params)
        # setex stores the value with a TTL so stale entries expire automatically
        self.client.setex(key, self.expiration, pickle.dumps(response))
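
A short usage sketch, assuming a Redis server is reachable on localhost:6379. Because entries live in Redis, every API worker shares the same cache, and the TTL doubles as a simple invalidation strategy:

redis_cache = RedisPromptCache(host="localhost", port=6379, expiration=3600)
params = {"model": "claude-3-sonnet", "temperature": 0.0}

prompt = "Explain prompt caching in one sentence."
response = redis_cache.get(prompt, params)
if response is None:
    response = llm_call(prompt)  # hypothetical LLM client call
    redis_cache.set(prompt, params, response)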

Performance Benefits

Response Time Reduction

| Cache Type  | Cold Start (ms) | Cache Hit (ms) | Speedup Factor |
|-------------|-----------------|----------------|----------------|
| No Cache    | 1250            | N/A            | 1.0x           |
| Local Cache | 1250            | 15             | 83.3x          |
| Redis Cache | 1250            | 35             | 35.7x          |
| Semantic    | 1250            | 150            | 8.3x           |

Cache Hit Ratios in Production

| Application Type | Exact Match (%) | Semantic Match (%) | Combined (%) |
|------------------|-----------------|--------------------|--------------|
| Customer Support | 22.5            | 35.3               | 57.8         |
| Document Q&A     | 43.7            | 28.1               | 71.8         |
| Code Generation  | 18.2            | 9.7                | 27.9         |

Cost Analysis

For a service using Claude 3.5 Sonnet with:

  • 1M API calls per day
  • Average of 1000 tokens per request
  • $0.01/1000 tokens (blended rate)
  • 50% cache hit ratio

Without Caching:

  • 1M requests × 1000 tokens × $0.01/1000 tokens = $10,000/day

With Caching:

  • 500K requests × 1000 tokens × $0.01/1000 tokens = $5,000/day
  • Infrastructure cost: ~$500/day
  • Net savings: $4,500/day or $1.64M annually
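
The arithmetic is easy to sanity-check in a few lines (the volume, blended rate, hit ratio, and infrastructure cost are the assumptions listed above):

requests_per_day = 1_000_000
tokens_per_request = 1000
cost_per_1k_tokens = 0.01      # blended rate
hit_ratio = 0.5
infra_cost_per_day = 500       # estimated caching infrastructure

baseline = requests_per_day * tokens_per_request / 1000 * cost_per_1k_tokens
with_cache = baseline * (1 - hit_ratio) + infra_cost_per_day
annual_savings = (baseline - with_cache) * 365
print(baseline, with_cache, annual_savings)  # 10000.0 5500.0 1642500.0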

Production Implementation with FastAPI

from fastapi import FastAPI
import redis
import hashlib
import json
import time
import anthropic
from pydantic import BaseModel
from typing import Dict, Optional

# Models
class PromptRequest(BaseModel):
    prompt: str
    temperature: float = 0.7
    max_tokens: int = 1000

# Initialize Redis and LLM client
redis_client = redis.Redis(host="localhost", port=6379)
client = anthropic.Anthropic(api_key="your-api-key")

app = FastAPI()

def get_cache_key(request: PromptRequest) -> str:
    request_dict = {
        "prompt": request.prompt,
        "temperature": request.temperature,
        "max_tokens": request.max_tokens
    }
    serialized = json.dumps(request_dict, sort_keys=True)
    return f"prompt_cache:{hashlib.sha256(serialized.encode()).hexdigest()}"

@app.post("/generate")
async def generate_text(request: PromptRequest):
    # Check cache first
    cache_key = get_cache_key(request)
    cached_data = redis_client.get(cache_key)
    
    if cached_data:
        redis_client.incr("stats:cache_hits")
        return {
            "text": json.loads(cached_data),
            "cached": True,
            "latency_ms": 0
        }
    
    # Cache miss - call the LLM API
    start_time = time.time()
    response = client.messages.create(
        model="claude-3-sonnet-20240229",
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        messages=[
            {"role": "user", "content": request.prompt}
        ]
    )
    response_text = response.content[0].text
    
    # Calculate latency and cache the response
    latency_ms = int((time.time() - start_time) * 1000)
    redis_client.setex(cache_key, 86400, json.dumps(response_text))
    redis_client.incr("stats:cache_misses")
    
    return {
        "text": response_text,
        "cached": False,
        "latency_ms": latency_ms
    }
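
An example client call against this service, assuming it runs locally on port 8000; the second identical request should come back with cached set to true:

import requests

payload = {"prompt": "Summarize our refund policy.", "temperature": 0.0, "max_tokens": 500}

first = requests.post("http://localhost:8000/generate", json=payload).json()
second = requests.post("http://localhost:8000/generate", json=payload).json()
print(first["cached"], first["latency_ms"])    # False, full model latency
print(second["cached"], second["latency_ms"])  # True, served from Redis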

Advanced Techniques

Prompt Templating for Higher Hit Ratios

from string import Template

class PromptTemplate:
    def __init__(self, template_string: str):
        self.template = Template(template_string)
    
    def format(self, **kwargs) -> str:
        return self.template.safe_substitute(**kwargs)

# Example template
customer_support_template = PromptTemplate(
    "Please help with this customer query about $product_category: $query"
)
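
The point for caching is that free-form requests collapse onto a small set of canonical prompt strings, so more of them end up sharing a cache key. A small sketch tying the template back to the hash-based cache above:

cache = PromptCache()
params = {"temperature": 0.0}

# Two tickets about the same issue render to the exact same prompt string
ticket_1 = customer_support_template.format(product_category="billing",
                                            query="Why was I charged twice?")
ticket_2 = customer_support_template.format(product_category="billing",
                                            query="Why was I charged twice?")

cache.set(ticket_1, params, "Duplicate charges are usually pending authorizations...")
print(cache.get(ticket_2, params) is not None)  # True: identical prompt, cache hit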

Monitoring Cache Performance

@app.get("/cache/stats")
async def get_cache_stats():
    hits = int(redis_client.get("stats:cache_hits") or 0)
    misses = int(redis_client.get("stats:cache_misses") or 0)
    total = hits + misses
    hit_ratio = hits / total if total > 0 else 0
    
    return {
        "hits": hits,
        "misses": misses,
        "total": total,
        "hit_ratio": hit_ratio,
        "estimated_savings_usd": hits * 0.01  # $0.01 per request saved
    }

Learn more about how Matter AI helps improve code quality across multiple languages in pull requests: https://docs.matterai.dev/product/code-quality

Are you looking for a way to improve your code review process? Learn more about how Matter AI helps teams solve code review challenges with AI: https://matterai.so
