Source code for redis_openai_agents.tool_cache
"""cached_tool - memoize deterministic tool results in Redis.
The OpenAI Agents SDK routes tool execution through the Runner, not through
the ``Model`` interface, so tool caching cannot be implemented as a regular
middleware. ``cached_tool`` is a decorator that wraps the underlying
callable before it becomes a ``function_tool``, hashing its arguments to
produce a cache key and serving repeat calls from Redis.
Inspired by the ``ToolResultCacheMiddleware`` in ``langgraph-redis``.
Example::
from agents import function_tool
from redis_openai_agents import cached_tool
@function_tool
@cached_tool(
name="weather",
redis_url="redis://localhost:6379",
ttl=3600,
volatile_arg_names={"timestamp"},
)
async def get_weather(city: str) -> str:
return fetch_forecast(city)
"""
from __future__ import annotations
import asyncio
import hashlib
import inspect
import json
import logging
import pickle
from collections.abc import Callable, Set
from functools import wraps
from typing import Any, TypeVar
from redis import Redis
logger = logging.getLogger(__name__)

# Generic callable type: ``cached_tool(...)`` returns a decorator typed
# ``T -> T``, so the decorated function keeps its static type.
T = TypeVar("T", bound=Callable[..., Any])

# Argument names that ``cached_tool`` treats as volatile unless overridden
# via ``volatile_arg_names``: when any of them appears in a call's
# arguments, that call bypasses the cache.
DEFAULT_VOLATILE_ARG_NAMES: frozenset[str] = frozenset(
    "timestamp current_time now date today current_date current_timestamp".split()
)

# Tool-name prefixes that ``cached_tool`` treats as side-effecting unless
# overridden via ``side_effect_prefixes``: a tool whose ``name`` starts with
# one of these is never cached.
DEFAULT_SIDE_EFFECT_PREFIXES: tuple[str, ...] = tuple(
    f"{verb}_"
    for verb in (
        "send",
        "delete",
        "create",
        "update",
        "remove",
        "write",
        "post",
        "put",
        "patch",
    )
)
[docs]
def cached_tool(
    *,
    name: str,
    redis_url: str = "redis://localhost:6379",
    ttl: int | None = None,
    key_prefix: str = "tool_cache",
    volatile_arg_names: Set[str] | None = None,
    ignored_arg_names: Set[str] | None = None,
    side_effect_prefixes: tuple[str, ...] | None = None,
) -> Callable[[T], T]:
    """Decorator that memoizes a callable's return value in Redis.

    Works on both plain and ``async`` callables. For coroutine functions the
    blocking Redis calls are run in a worker thread via
    :func:`asyncio.to_thread` so they do not block the event loop.

    Caching is best-effort and never turns a successful call into an error.
    In each of the following cases the wrapped callable simply runs uncached
    for that call:

    * the arguments cannot be bound to the signature;
    * the key builder cannot serialize the arguments;
    * a volatile argument is present.

    Failed Redis reads and undecodable payloads count as misses, and failed
    writes are ignored. All such failures are logged at ``DEBUG`` level.

    Cached values are serialized with :mod:`pickle`, which has two
    consequences. Return values that cannot be pickled are simply not
    cached. More importantly, the Redis instance must be trusted, because
    unpickling a crafted payload can execute arbitrary code.

    Args:
        name: Logical tool name. Also used as part of the cache key and for
            checking side-effect prefixes.
        redis_url: Redis connection URL. The underlying Redis client is
            created once when ``cached_tool(...)`` is called and shared
            across all calls of the decorated function.
        ttl: Optional TTL in seconds for cache entries, written with
            ``SETEX``, so it should be positive. ``None`` keeps entries
            indefinitely.
        key_prefix: Prefix for the Redis key; useful for namespacing across
            environments or deployments. Keys have the form
            ``"{key_prefix}:{name}:{sha256 hex digest of the arguments}"``.
        volatile_arg_names: Set of argument names whose presence bypasses
            the cache entirely (even if the value is ``None``). Nested dict
            keys are checked too. Use this for arguments that always change
            (timestamps, trace IDs that must be observed, random seeds).
            Defaults to :data:`DEFAULT_VOLATILE_ARG_NAMES`.
        ignored_arg_names: Set of argument names to strip from the cache
            key before hashing. Use for values that do not change the
            result but vary between calls (trace IDs, request IDs). Names
            are stripped from the named parameters and from the callable's
            ``**kwargs`` catch-all, if it has one.
        side_effect_prefixes: Tuple of name prefixes that mark the tool as
            side-effecting, in which case caching is disabled. Defaults to
            :data:`DEFAULT_SIDE_EFFECT_PREFIXES`.

    Returns:
        A decorator. It returns an ``async`` wrapper for coroutine functions
        and a plain wrapper otherwise.
    """
    volatile = (
        frozenset(volatile_arg_names)
        if volatile_arg_names is not None
        else DEFAULT_VOLATILE_ARG_NAMES
    )
    ignored = (
        frozenset(ignored_arg_names) if ignored_arg_names is not None else frozenset()
    )
    prefixes = (
        tuple(side_effect_prefixes)
        if side_effect_prefixes is not None
        else DEFAULT_SIDE_EFFECT_PREFIXES
    )
    # Serving a stored result would silently skip the side effect, so tools
    # whose name marks them as mutating always run.
    is_side_effecting = name.startswith(prefixes)
    client = Redis.from_url(redis_url)

    def decorator(fn: T) -> T:
        signature = inspect.signature(fn)
        # Name of fn's ``**kwargs`` parameter, if any. bind_partial nests any
        # extra keyword arguments under it, so ignored names must be stripped
        # from that nested dict as well, or they would still vary the key.
        var_keyword = next(
            (
                param.name
                for param in signature.parameters.values()
                if param.kind is inspect.Parameter.VAR_KEYWORD
            ),
            None,
        )

        def make_key(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str | None:
            """Normalize args to a canonical key; return ``None`` to skip the cache."""
            try:
                bound = signature.bind_partial(*args, **kwargs)
                bound.apply_defaults()
            except TypeError:
                # Let the underlying call raise the real TypeError.
                return None
            arguments = dict(bound.arguments)
            try:
                if _contains_volatile(arguments, volatile):
                    return None
                if ignored and var_keyword is not None and var_keyword in arguments:
                    # Build a filtered copy instead of mutating the bound dict.
                    arguments[var_keyword] = {
                        k: v
                        for k, v in arguments[var_keyword].items()
                        if k not in ignored
                    }
                for arg_name in ignored:
                    arguments.pop(arg_name, None)
                canonical = _canonicalize(arguments)
            except Exception as exc:
                # Arguments the key builder cannot handle must not break the
                # tool call. Examples: dict keys JSON cannot encode or sort,
                # self-referential containers, a raising __repr__. Skip the
                # cache for this call instead.
                logger.debug("cache key build failed for %s: %s", name, exc)
                return None
            digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()
            return f"{key_prefix}:{name}:{digest}"

        def load(key: str) -> tuple[bool, Any]:
            """Return ``(True, value)`` on a cache hit, ``(False, None)`` otherwise."""
            try:
                raw = client.get(key)
            except Exception as exc:
                logger.debug("cache read failed for %s: %s", key, exc)
                return False, None
            # A miss comes back as None; any non-bytes reply cannot be a
            # payload written by ``store``.
            if not isinstance(raw, (bytes, bytearray)):
                return False, None
            try:
                # SECURITY: pickle.loads executes code embedded in a crafted
                # payload. This is only safe if nothing untrusted can write to
                # this Redis instance / key prefix.
                return True, pickle.loads(raw)
            except Exception as exc:
                logger.debug("cache payload decode failed for %s: %s", key, exc)
                return False, None

        def store(key: str, value: Any) -> None:
            """Pickle ``value`` and write it under ``key``; failures are only logged."""
            try:
                payload = pickle.dumps(value)
                if ttl is not None:
                    client.setex(key, ttl, payload)
                else:
                    client.set(key, payload)
            except Exception as exc:
                logger.debug("cache write failed for %s: %s", key, exc)

        # inspect.iscoroutinefunction is the supported check;
        # asyncio.iscoroutinefunction is deprecated since Python 3.14.
        if inspect.iscoroutinefunction(fn):

            @wraps(fn)
            async def async_wrapped(*args: Any, **kwargs: Any) -> Any:
                key = None if is_side_effecting else make_key(args, kwargs)
                if key is None:
                    return await fn(*args, **kwargs)
                # The Redis client is synchronous; keep it off the event loop.
                hit, value = await asyncio.to_thread(load, key)
                if hit:
                    return value
                result = await fn(*args, **kwargs)
                await asyncio.to_thread(store, key, result)
                return result

            return async_wrapped  # type: ignore[return-value]

        @wraps(fn)
        def sync_wrapped(*args: Any, **kwargs: Any) -> Any:
            key = None if is_side_effecting else make_key(args, kwargs)
            if key is None:
                return fn(*args, **kwargs)
            hit, value = load(key)
            if hit:
                return value
            result = fn(*args, **kwargs)
            store(key, result)
            return result

        return sync_wrapped  # type: ignore[return-value]

    return decorator
def _contains_volatile(value: Any, volatile: Set[str]) -> bool:
    """Report whether ``value`` uses a volatile name as a dict key.

    The search compares dict keys against ``volatile`` and descends into
    dict values and into the items of lists, tuples and sets. Any other
    object (strings included) is a leaf and never matches. An empty
    ``volatile`` set always yields ``False``.
    """
    if not volatile:
        return False
    if isinstance(value, dict):
        return any(
            key in volatile or _contains_volatile(child, volatile)
            for key, child in value.items()
        )
    if not isinstance(value, (list, tuple, set)):
        return False
    return any(_contains_volatile(child, volatile) for child in value)
def _canonicalize(arguments: dict[str, Any]) -> str:
    """Serialize arguments to a deterministic JSON string.

    JSON-native values (str, int, float, bool, None, lists, tuples and dicts)
    are encoded as-is, with sorted keys and compact separators. Any other
    value is encoded as a small tagged object, so it cannot collide with a
    plain JSON value that happens to look the same:

    * sets and frozensets become ``{"__set__": [...]}``. Their elements are
      sorted by ``repr``, so the output does not depend on iteration order.
    * everything else becomes
      ``{"__repr__": repr(o), "__type__": "<module>.<qualname>"}``.

    Known limitations inherited from JSON: a tuple and a list with the same
    elements encode identically, and non-string dict keys such as ``1`` are
    encoded as strings (``"1"``).

    Raises:
        TypeError: If a dict has keys JSON cannot encode (e.g. tuples) or
            keys of mutually unorderable types (required by ``sort_keys``).
        ValueError: If ``arguments`` contains a circular reference.
    """

    def tag(obj: Any) -> Any:
        if isinstance(obj, (set, frozenset)):
            return {"__set__": sorted(obj, key=repr)}
        cls = type(obj)
        # Tag with the type so e.g. b"a" never hashes like the string "b'a'".
        return {
            "__repr__": repr(obj),
            "__type__": f"{cls.__module__}.{cls.__qualname__}",
        }

    return json.dumps(arguments, sort_keys=True, default=tag, separators=(",", ":"))