# Source code for langgraph.middleware.redis.semantic_cache

"""Semantic cache middleware for LLM responses.

This module provides a middleware that caches LLM responses based on
semantic similarity using Redis and vector embeddings. Compatible with
LangChain's AgentMiddleware protocol for use with create_agent.
"""

import json
import logging
import uuid
from typing import Any, Awaitable, Callable, List, Union

from langchain.agents.middleware.types import (
    ModelCallResult,
    ModelRequest,
    ModelResponse,
)
from langchain_core.messages import AIMessage
from redisvl.extensions.cache.llm import SemanticCache

from langgraph.checkpoint.redis.jsonplus_redis import JsonPlusRedisSerializer

from .aio import AsyncRedisMiddleware
from .types import SemanticCacheConfig

# Module-level logger used for cache hit/miss and degradation diagnostics.
logger = logging.getLogger(__name__)

# Use the project's serializer for proper LangChain object handling
# (revives constructor-format dicts back into message objects).
_serializer = JsonPlusRedisSerializer()


def _strip_content_ids(content: Any) -> Any:
    """Strip provider-specific IDs from content blocks.

    When using the OpenAI Responses API, content is a list of blocks with
    embedded item IDs (rs_, msg_ prefixes). These must be removed from cached
    messages to prevent duplicate ID errors.
    """
    if not isinstance(content, list):
        return content
    stripped = []
    for block in content:
        if isinstance(block, dict) and "id" in block:
            stripped.append({k: v for k, v in block.items() if k != "id"})
        else:
            stripped.append(block)
    return stripped


def _serialize_response(response: Any) -> str:
    """Serialize a model response for cache storage.

    Uses LangChain's to_json() for proper serialization of LangChain objects.

    Args:
        response: The response to serialize.

    Returns:
        JSON string representation of the response.
    """
    # Handle ModelResponse (has .result which is list of messages)
    if hasattr(response, "result") and isinstance(response.result, list):
        # Serialize the last AI message from the result list
        for msg in reversed(response.result):
            if hasattr(msg, "to_json"):
                # Use LangChain's built-in serialization
                return json.dumps(msg.to_json())
            elif hasattr(msg, "content"):
                # Fallback: extract content
                return json.dumps({"content": getattr(msg, "content", "")})
        return json.dumps({"content": ""})

    # Handle objects with to_json() (LangChain objects like AIMessage)
    if hasattr(response, "to_json"):
        return json.dumps(response.to_json())

    # Handle dict-style responses
    if isinstance(response, dict):
        return json.dumps(response)

    # Fallback: try to get content attribute
    content = getattr(response, "content", str(response))
    return json.dumps({"content": content})


def _deserialize_response(cached_str: str) -> ModelResponse:
    """Deserialize a cached response into a ModelResponse.

    Uses the project's JsonPlusRedisSerializer for proper LangChain object
    revival. Always returns a ModelResponse to maintain compatibility with
    agent routing.

    IMPORTANT: Each cache hit generates a NEW message ID (UUID). This is
    critical for frontend streaming compatibility - without unique IDs, the
    frontend deduplicates messages and cached responses don't appear.

    The cached response is also marked with additional_kwargs={"cached": True}
    to allow consumers to identify cached responses.

    Args:
        cached_str: The cached JSON string.

    Returns:
        A ModelResponse containing the cached message with a unique ID.
    """
    # Fresh UUID per cache hit so each cached response surfaces as a new
    # message in the frontend.
    fresh_id = str(uuid.uuid4())

    def _wrap(content: Any) -> ModelResponse:
        # Build a new AIMessage around plain content, flagged as cached.
        message = AIMessage(
            content=content,
            id=fresh_id,
            additional_kwargs={"cached": True},
        )
        return ModelResponse(result=[message], structured_response=None)

    try:
        payload = json.loads(cached_str)
    except json.JSONDecodeError:
        # Not valid JSON - treat the raw string as plain content.
        return _wrap(cached_str)

    if not isinstance(payload, dict):
        # Non-dict data - stringify it.
        return _wrap(str(payload))

    # LangChain constructor format: revive via the project's serializer.
    if payload.get("lc") in (1, 2) and payload.get("type") == "constructor":
        revived = _serializer._revive_if_needed(payload)
        if isinstance(revived, AIMessage):
            # Preserve all fields, but swap in a fresh ID, strip provider
            # content-block IDs, mark as cached, and clear response metadata.
            refreshed = revived.model_copy(
                update={
                    "id": fresh_id,
                    "content": _strip_content_ids(revived.content),
                    "additional_kwargs": {"cached": True},
                    "response_metadata": {},
                }
            )
            return ModelResponse(result=[refreshed], structured_response=None)
        # Revived to something other than an AIMessage - rewrap its content.
        return _wrap(getattr(revived, "content", str(revived)))

    # Simple dict - use its content field (empty string if absent).
    return _wrap(payload.get("content", ""))


class SemanticCacheMiddleware(AsyncRedisMiddleware):
    """Middleware that caches LLM responses based on semantic similarity.

    Uses redisvl's SemanticCache to store and retrieve cached responses.
    When a request is semantically similar to a previous request (within the
    distance threshold), the cached response is returned without calling the
    LLM.

    By default, only "final" responses (those without tool_calls) are cached.
    This prevents caching intermediate responses that require tool execution.

    Example:
        ```python
        from langgraph.middleware.redis import (
            SemanticCacheMiddleware,
            SemanticCacheConfig,
        )

        config = SemanticCacheConfig(
            redis_url="redis://localhost:6379",
            distance_threshold=0.1,
            ttl_seconds=3600,
        )
        middleware = SemanticCacheMiddleware(config)


        async def call_llm(request):
            # Your LLM call here
            return response


        # Use middleware
        result = await middleware.awrap_model_call(request, call_llm)
        ```
    """

    # The redisvl semantic cache instance, created lazily in _setup_async.
    _cache: SemanticCache
    # Configuration captured at construction time.
    _config: SemanticCacheConfig

    def __init__(self, config: SemanticCacheConfig) -> None:
        """Initialize the semantic cache middleware.

        Args:
            config: Configuration for the semantic cache.
        """
        super().__init__(config)
        self._config = config

    async def _setup_async(self) -> None:
        """Set up the SemanticCache instance.

        Note: SemanticCache from redisvl uses synchronous Redis operations
        internally, so we must provide redis_url and let it manage its own
        sync connection rather than passing our async client.
        """
        cache_kwargs: dict[str, Any] = {
            "name": self._config.name,
            "distance_threshold": self._config.distance_threshold,
        }

        # SemanticCache requires a sync Redis connection; pass redis_url so
        # it creates its own rather than sharing our async client.
        if self._config.redis_url:
            cache_kwargs["redis_url"] = self._config.redis_url
        elif self._config.connection_args:
            cache_kwargs["connection_kwargs"] = self._config.connection_args

        if self._config.vectorizer is not None:
            cache_kwargs["vectorizer"] = self._config.vectorizer
        if self._config.ttl_seconds is not None:
            cache_kwargs["ttl"] = self._config.ttl_seconds

        self._cache = SemanticCache(**cache_kwargs)

    def _extract_prompt(self, messages: List[Union[dict[str, Any], Any]]) -> str:
        """Extract the prompt to use for cache lookup.

        Extracts the last user message content from the messages list.
        Handles both dict-style messages and LangChain message objects.

        Args:
            messages: List of messages from the request.

        Returns:
            The extracted prompt string (empty string if no user message).
        """
        if not messages:
            return ""

        # Walk backwards to find the most recent user message.
        for message in reversed(messages):
            if isinstance(message, dict):
                # Dict-style message.
                if message.get("role", "") == "user":
                    return message.get("content", "")
            else:
                # LangChain-style message object ("human" is LangChain's
                # type name for user messages).
                msg_type = getattr(message, "type", None) or getattr(
                    message, "role", None
                )
                if msg_type in ("user", "human"):
                    return getattr(message, "content", "")
        return ""

    def _is_final_response(self, response: Any) -> bool:
        """Check if the response is a final response (no tool calls).

        Args:
            response: The model response to check (dict or LangChain type).

        Returns:
            True if the response is final (should be cached), False otherwise.
        """
        # Support both dict-style and LangChain response types.
        if isinstance(response, dict):
            tool_calls = response.get("tool_calls")
        else:
            # ModelResponse itself doesn't carry tool_calls; fall back to
            # the first message in its result list.
            tool_calls = getattr(response, "tool_calls", None)
            if tool_calls is None and hasattr(response, "result"):
                result = response.result
                if result and len(result) > 0:
                    tool_calls = getattr(result[0], "tool_calls", None)

        # Final iff tool_calls is absent or empty.
        return not tool_calls

    def _get_tool_names_from_results(
        self, messages: List[Union[dict[str, Any], Any]]
    ) -> List[str]:
        """Extract tool names from tool result messages.

        Args:
            messages: List of messages from the request.

        Returns:
            List of tool names that have results in the messages.
        """
        tool_names = []
        for message in messages:
            if isinstance(message, dict):
                # Dict-style: tool results use role/type "tool".
                role = message.get("role", "") or message.get("type", "")
                if role == "tool":
                    tool_names.append(message.get("name", ""))
            else:
                # LangChain-style: type "tool" or a ToolMessage instance.
                msg_type = getattr(message, "type", None) or getattr(
                    message, "role", None
                )
                if msg_type == "tool" or message.__class__.__name__ == "ToolMessage":
                    tool_names.append(getattr(message, "name", ""))
        return tool_names

    def _should_skip_cache_for_tool_results(
        self, messages: List[Union[dict[str, Any], Any]]
    ) -> bool:
        """Check if cache should be skipped due to tool results.

        When tool results are present, we check if ALL tools are in the
        deterministic_tools list. If so, caching is safe. Otherwise, we
        skip the cache to avoid returning stale responses.

        Args:
            messages: List of messages from the request.

        Returns:
            True if cache should be skipped, False if caching is OK.
        """
        tool_names = self._get_tool_names_from_results(messages)
        if not tool_names:
            # No tool results - caching is OK.
            return False

        # If deterministic_tools is not configured, always skip cache when
        # tool results are present (safest default).
        if self._config.deterministic_tools is None:
            return True

        # Skip if ANY tool result comes from a non-deterministic tool.
        for tool_name in tool_names:
            if tool_name and tool_name not in self._config.deterministic_tools:
                return True

        # All tools are deterministic - caching is OK.
        return False

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Wrap a model call with semantic caching.

        Checks the cache for a semantically similar request. If found,
        returns the cached response. Otherwise, calls the handler and
        optionally caches the result.

        Args:
            request: The model request containing messages.
            handler: The async function to call the model.

        Returns:
            The model response (from cache or handler).

        Raises:
            Exception: If graceful_degradation is False and cache operations
                fail.
        """
        await self._ensure_initialized_async()

        # Support both dict-style and LangChain ModelRequest types.
        if isinstance(request, dict):
            messages = request.get("messages", [])
        else:
            messages = getattr(request, "messages", [])

        prompt = self._extract_prompt(messages)
        if not prompt:
            # No prompt to cache, just call handler.
            return await handler(request)

        # Bypass the cache entirely if the request contains non-deterministic
        # tool results: the model must process the tool output, and a lookup
        # keyed only by the user prompt could return a response that ignores
        # it. BUGFIX: we also must NOT store the response here - it is
        # derived from non-deterministic tool output, and storing it keyed by
        # the prompt alone would serve a stale answer to a later identical
        # prompt that has no tool results.
        if self._should_skip_cache_for_tool_results(messages):
            logger.debug(
                "Skipping cache - request contains non-deterministic tool results"
            )
            return await handler(request)

        # Try to get from cache using async method.
        try:
            cached = await self._cache.acheck(prompt=prompt)
            if cached:
                cached_response = cached[0].get("response")
                if cached_response:
                    logger.debug(f"Cache hit for prompt: {prompt[:50]}...")
                    return _deserialize_response(cached_response)
        except Exception as e:
            if not self._graceful_degradation:
                raise
            logger.warning(f"Cache check failed, calling handler: {e}")

        # Cache miss - call handler.
        response = await handler(request)

        # Store in cache if appropriate (final-only unless configured
        # otherwise).
        should_cache = not self._config.cache_final_only or self._is_final_response(
            response
        )
        if should_cache:
            try:
                # Serialize response for storage using async method.
                response_str = _serialize_response(response)
                await self._cache.astore(prompt=prompt, response=response_str)
                logger.debug(f"Cached response for prompt: {prompt[:50]}...")
            except Exception as e:
                if not self._graceful_degradation:
                    raise
                logger.warning(f"Cache store failed: {e}")

        return response