# Source code for langgraph.middleware.redis.tool_cache

"""Tool result cache middleware.

This module provides a middleware that caches tool call results
using exact-match key-value lookup in Redis. Tool caching is
deterministic: same tool + same args = same result. This uses
direct Redis GET/SET instead of vector similarity.

Compatible with LangChain's AgentMiddleware protocol for use with create_agent.
"""

import json
import logging
from typing import Any, Awaitable, Callable, Dict, Tuple, Union

from langchain.agents.middleware.types import (
    ModelCallResult,
    ModelRequest,
    ModelResponse,
)
from langchain_core.messages import ToolMessage as LangChainToolMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command

from .aio import AsyncRedisMiddleware
from .types import ToolCacheConfig

logger = logging.getLogger(__name__)

# Argument names whose values change between otherwise-identical calls, making
# the call non-cacheable (presumably the default for
# ToolCacheConfig.volatile_arg_names — confirm against the config definition).
DEFAULT_VOLATILE_ARG_NAMES: frozenset[str] = frozenset(
    "timestamp current_time now date today "
    "current_date current_timestamp".split()
)

# Tool-name prefixes that conventionally signal a mutating / side-effecting
# tool (presumably the default for ToolCacheConfig.side_effect_prefixes —
# confirm against the config definition).
DEFAULT_SIDE_EFFECT_PREFIXES: Tuple[str, ...] = tuple(
    verb + "_"
    for verb in (
        "send",
        "delete",
        "create",
        "update",
        "remove",
        "write",
        "post",
        "put",
        "patch",
    )
)


class ToolResultCacheMiddleware(AsyncRedisMiddleware):
    """Middleware that caches tool call results using exact-match lookup.

    Uses direct Redis GET/SET for deterministic tool result caching. When a
    tool is called with the same arguments as a previous call, the cached
    result is returned without executing the tool.

    Tool caching is especially useful for expensive operations like:

    - API calls to external services
    - Database queries
    - Web searches
    - Complex calculations

    Example:
        ```python
        from langgraph.middleware.redis import (
            ToolResultCacheMiddleware,
            ToolCacheConfig,
        )

        config = ToolCacheConfig(
            redis_url="redis://localhost:6379",
            cacheable_tools=["search", "calculate"],
            excluded_tools=["random_number"],
            ttl_seconds=3600,
        )

        middleware = ToolResultCacheMiddleware(config)


        async def execute_tool(request):
            # Your tool execution here
            return result


        # Use middleware
        result = await middleware.awrap_tool_call(request, execute_tool)
        ```
    """

    # Stored again locally (beyond the base class) so the narrower
    # ToolCacheConfig type is available to the cache-policy helpers below.
    _config: ToolCacheConfig

    def __init__(self, config: ToolCacheConfig) -> None:
        """Initialize the tool cache middleware.

        Args:
            config: Configuration for the tool cache.
        """
        super().__init__(config)
        self._config = config

    async def _setup_async(self) -> None:
        """Set up the tool cache.

        No additional setup needed — the tool cache uses the async Redis
        client from the base class directly for GET/SET operations.
        """

    def _is_tool_cacheable_by_config(self, tool_name: str) -> bool:
        """Check if a tool's results should be cached based on config.

        Args:
            tool_name: The name of the tool.

        Returns:
            True if the tool's results should be cached, False otherwise.
        """
        # If cacheable_tools is set, only those tools are cached (whitelist
        # takes precedence over the blacklist).
        if self._config.cacheable_tools is not None:
            return tool_name in self._config.cacheable_tools

        # Otherwise, cache all tools except excluded ones
        return tool_name not in self._config.excluded_tools

    def _has_volatile_args(self, args: Dict[str, Any]) -> bool:
        """Check if args contain volatile argument names at any nesting depth.

        Recurses into nested dicts and into dicts inside lists/tuples.

        Args:
            args: The tool arguments dict.

        Returns:
            True if any key in args (recursively) matches a configured
            volatile arg name, False otherwise.
        """
        volatile_names = self._config.volatile_arg_names
        if not volatile_names:
            return False

        for key, value in args.items():
            if key in volatile_names:
                return True
            if isinstance(value, dict):
                if self._has_volatile_args(value):
                    return True
            elif isinstance(value, (list, tuple)):
                for item in value:
                    if isinstance(item, dict) and self._has_volatile_args(item):
                        return True
        return False

    def _has_side_effect_prefix(self, tool_name: str) -> bool:
        """Check if tool name starts with a configured side-effect prefix.

        Args:
            tool_name: The name of the tool.

        Returns:
            True if the tool name matches a side-effect prefix.
        """
        prefixes = self._config.side_effect_prefixes
        if not prefixes:
            return False
        # str.startswith only accepts a tuple for multiple prefixes; coerce so
        # configs that store the prefixes as a list don't raise TypeError.
        return tool_name.startswith(tuple(prefixes))

    def _strip_ignored_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of args with ignored names removed (top-level only).

        Args:
            args: The tool arguments dict.

        Returns:
            A new dict without the ignored keys, or the original if nothing
            to strip.
        """
        if not isinstance(args, dict):
            return {}
        ignored = self._config.ignored_arg_names
        if not ignored:
            return args
        return {k: v for k, v in args.items() if k not in ignored}

    @staticmethod
    def _extract_args(request: ToolCallRequest) -> Dict[str, Any]:
        """Extract args dict from a request (dict or ToolCallRequest).

        Args:
            request: The tool call request.

        Returns:
            The arguments dict.
        """
        if isinstance(request, dict):
            return request.get("args", {})
        tool_call = getattr(request, "tool_call", None)
        if isinstance(tool_call, dict):
            return tool_call.get("args", {})
        return {}

    def _is_tool_cacheable(self, request: ToolCallRequest) -> bool:
        """Check if a tool's results should be cached.

        Uses a priority chain inspired by SQL function volatility and MCP
        ToolAnnotations:

        1. ``metadata["cacheable"]`` — explicit override (highest priority)
        2. ``metadata["destructive"]`` — never cache destructive tools
        3. ``metadata["volatile"]`` — never cache volatile tools
        4. ``metadata["read_only"] and metadata["idempotent"]`` — cache
        5. Side-effect prefix match — never cache
        6. Volatile arg name in call args — never cache
        7. Config whitelist / blacklist — existing fallback

        Args:
            request: The tool call request containing the tool object.

        Returns:
            True if the tool's results should be cached, False otherwise.
        """
        # Extract tool name and tool object from ToolCallRequest
        if isinstance(request, dict):
            tool_name = request.get("tool_name", "")
            tool = None
        else:
            tool_call = getattr(request, "tool_call", None)
            tool_name = (
                tool_call.get("name", "") if isinstance(tool_call, dict) else ""
            )
            tool = getattr(request, "tool", None)

        # --- Priority 1: explicit cacheable flag (highest) ---
        metadata: Dict[str, Any] = {}
        if tool is not None:
            metadata = getattr(tool, "metadata", None) or {}
        if "cacheable" in metadata:
            return bool(metadata["cacheable"])

        # --- Priority 2: destructive metadata → never cache ---
        if metadata.get("destructive") is True:
            return False

        # --- Priority 3: volatile metadata → never cache ---
        if metadata.get("volatile") is True:
            return False

        # --- Priority 4: read_only + idempotent → cache ---
        if metadata.get("read_only") is True and metadata.get("idempotent") is True:
            return True

        # --- Priority 5: side-effect prefix → never cache ---
        if self._has_side_effect_prefix(tool_name):
            return False

        # --- Priority 6: volatile arg names → never cache ---
        args = self._extract_args(request)
        if self._has_volatile_args(args):
            return False

        # --- Priority 7: config whitelist / blacklist ---
        return self._is_tool_cacheable_by_config(tool_name)

    def _build_cache_key(self, request: ToolCallRequest) -> str:
        """Build a deterministic cache key from the tool request.

        Creates an exact-match Redis key from the tool name and sorted JSON
        arguments. This ensures that identical tool calls always produce the
        same key, and different calls always produce different keys.

        Args:
            request: The tool call request (dict or LangChain type).

        Returns:
            A deterministic string key for Redis GET/SET.
        """
        # Support both dict-style and LangChain ToolCallRequest types
        if isinstance(request, dict):
            tool_name = request.get("tool_name", "")
            args = request.get("args", {})
        else:
            tool_call = getattr(request, "tool_call", None)
            if isinstance(tool_call, dict):
                tool_name = tool_call.get("name", "")
                args = tool_call.get("args", {})
            else:
                tool_name = ""
                args = {}

        # Strip ignored args before building the key
        effective_args = self._strip_ignored_args(args)

        # Deterministic key: config name + tool name + sorted JSON args
        args_str = json.dumps(effective_args, sort_keys=True)
        return f"{self._config.name}:{tool_name}:{args_str}"

    def _serialize_tool_result(self, value: Any) -> str:
        """Serialize a tool result to a JSON string for caching.

        Supports LangChain ToolMessage/Command objects by converting them to
        JSON-compatible structures before encoding.

        Args:
            value: The tool result to serialize.

        Returns:
            A JSON string representation of the result.
        """
        # Handle known LangChain message/command types
        if isinstance(value, (LangChainToolMessage, Command)):
            to_json = getattr(value, "to_json", None)
            if callable(to_json):
                # Convert to plain data first, then dump. Guarded because
                # to_json() may still contain non-JSON-serializable payloads;
                # in that case fall through to the generic fallbacks below.
                try:
                    return json.dumps(to_json())
                except (TypeError, ValueError):
                    pass

        # Fallback: try direct JSON serialization
        try:
            return json.dumps(value)
        except TypeError:
            # Last resort: store string representation
            return json.dumps(str(value))

    def _deserialize_tool_result(
        self, cached_response: str, tool_name: str, tool_call_id: str
    ) -> LangChainToolMessage:
        """Deserialize a cached tool result into a ToolMessage.

        Converts the cached JSON string back into a proper LangChain
        ToolMessage so it conforms to the AgentMiddleware protocol.

        Args:
            cached_response: The cached JSON string.
            tool_name: The name of the tool that produced this result.
            tool_call_id: The ID of the tool call this result is for.

        Returns:
            A LangChainToolMessage containing the cached result.
        """
        # Parse the cached content (tolerate non-JSON payloads stored by the
        # string-representation fallback in _serialize_tool_result)
        if isinstance(cached_response, str):
            try:
                parsed = json.loads(cached_response)
            except json.JSONDecodeError:
                parsed = cached_response
        else:
            parsed = cached_response

        # Extract content from the parsed result
        if isinstance(parsed, dict):
            content = parsed.get("content", json.dumps(parsed))
        elif isinstance(parsed, str):
            content = parsed
        else:
            content = json.dumps(parsed)

        return LangChainToolMessage(
            content=content,
            name=tool_name,
            tool_call_id=tool_call_id or "",
        )

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[
            [ToolCallRequest], Awaitable[Union[LangChainToolMessage, Command]]
        ],
    ) -> Union[LangChainToolMessage, Command]:
        """Wrap a tool call with exact-match caching.

        This method is part of the LangChain AgentMiddleware protocol.

        Checks the cache for an exact tool+args match. If found, returns the
        cached result. Otherwise, calls the handler and caches the result.

        Args:
            request: The tool call request.
            handler: The async function to execute the tool.

        Returns:
            The tool result (from cache or handler).

        Raises:
            Exception: If graceful_degradation is False and cache operations
                fail.
        """
        # Extract tool name from the request
        if isinstance(request, dict):
            tool_name = request.get("tool_name", "")
            tool_call_id = request.get("id", "")
        else:
            tool_call = getattr(request, "tool_call", None)
            if isinstance(tool_call, dict):
                tool_name = tool_call.get("name", "")
                tool_call_id = tool_call.get("id", "")
            else:
                tool_name = ""
                tool_call_id = ""

        # If no tool name or tool is not cacheable, skip caching
        if not tool_name or not self._is_tool_cacheable(request):
            return await handler(request)

        await self._ensure_initialized_async()
        cache_key = self._build_cache_key(request)

        # Try to get from cache using exact-match Redis GET
        try:
            cached_response = await self._redis.get(cache_key)
            if cached_response is not None:
                # Decode bytes to string if needed
                if isinstance(cached_response, bytes):
                    cached_response = cached_response.decode("utf-8")
                logger.debug("Tool cache hit for key: %s", cache_key[:80])
                return self._deserialize_tool_result(
                    cached_response, tool_name, tool_call_id
                )
        except Exception as e:
            if not self._graceful_degradation:
                raise
            logger.warning("Tool cache check failed, calling handler: %s", e)

        # Cache miss - call handler
        result = await handler(request)

        # Do not cache Command results: _deserialize_tool_result can only
        # rebuild a ToolMessage, so replaying a cached Command would silently
        # drop its graph-control semantics (e.g. state updates).
        if isinstance(result, Command):
            return result

        # Store in cache using Redis SET with optional TTL
        try:
            result_str = self._serialize_tool_result(result)
            if self._config.ttl_seconds is not None:
                await self._redis.set(
                    cache_key, result_str, ex=self._config.ttl_seconds
                )
            else:
                await self._redis.set(cache_key, result_str)
            logger.debug("Tool cache stored for key: %s", cache_key[:80])
        except Exception as e:
            if not self._graceful_degradation:
                raise
            logger.warning("Tool cache store failed: %s", e)

        return result

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Pass through model calls without caching.

        This method is part of the LangChain AgentMiddleware protocol. Tool
        cache middleware only caches tool calls, not model calls.

        Args:
            request: The model request.
            handler: The async function to call the model.

        Returns:
            The model response from the handler.
        """
        return await handler(request)