Source code for redis_openai_agents.middleware.semantic_cache
"""SemanticCacheMiddleware - cache model responses by semantic similarity.
Wraps the project's ``SemanticCache`` (which itself layers an L1 exact-match
hash and an L2 RedisVL vector cache) and exposes it as an
:class:`~redis_openai_agents.middleware.AgentMiddleware` so it can compose
with other middlewares in a ``MiddlewareStack``.
"""
from __future__ import annotations
import base64
import pickle
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from .base import ModelCallHandler, ModelRequest
if TYPE_CHECKING:
from ..cache import SemanticCache
# Conversion hooks for cache storage: a Serializer turns an arbitrary model
# response into a storable string, and a Deserializer inverts it. Defaults are
# _default_serialize / _default_deserialize (pickle + base64).
Serializer = Callable[[Any], str]
Deserializer = Callable[[str], Any]
def _default_serialize(response: Any) -> str:
"""Pickle the response and base64-encode the bytes.
Pickle handles arbitrary Python values including dataclasses and
Pydantic models, at the cost of being unsafe to load from untrusted
sources. Callers caching responses across trust boundaries should
supply a structured serializer (e.g. ``json.dumps`` with a custom
``default``).
"""
return base64.b64encode(pickle.dumps(response)).decode("ascii")
def _default_deserialize(payload: str) -> Any:
return pickle.loads(base64.b64decode(payload.encode("ascii")))
[docs]
class SemanticCacheMiddleware:
    """Cache LLM responses keyed by the semantic similarity of the input.

    Guards against non-deterministic call contexts - requests are not
    cached when any of ``tools``, ``handoffs``, or ``output_schema`` are
    present, because the response typically depends on those side
    conditions and may not repeat.

    Args:
        cache: A :class:`SemanticCache` instance managing the underlying
            two-level cache.
        serializer: Callable that converts a model response to a string
            for storage. Defaults to pickle+base64.
        deserializer: Inverse of ``serializer``.
        cacheable: Optional predicate that returns False to skip caching
            for a given request even if the default guards would allow
            it. Applied after the default guards.
    """

    def __init__(
        self,
        cache: SemanticCache,
        *,
        serializer: Serializer = _default_serialize,
        deserializer: Deserializer = _default_deserialize,
        cacheable: Callable[[ModelRequest], bool] | None = None,
    ) -> None:
        self._cache = cache
        self._serialize = serializer
        self._deserialize = deserializer
        self._cacheable = cacheable

    async def awrap_model_call(self, request: ModelRequest, handler: ModelCallHandler) -> Any:
        """Serve from cache when possible; otherwise call ``handler`` and store.

        Falls through to ``handler`` untouched when the request fails the
        cacheability guards or produces an empty prompt.
        """
        if not self._is_cacheable(request):
            return await handler(request)
        prompt = self._build_prompt(request)
        if not prompt:
            return await handler(request)
        hit = await self._lookup(prompt)
        if hit is not None:
            return hit
        response = await handler(request)
        await self._store(prompt, response)
        return response

    def _is_cacheable(self, request: ModelRequest) -> bool:
        """Default guards plus optional user predicate."""
        # Tool/handoff/schema-bearing requests are context-dependent; never cache.
        if request.tools or request.handoffs or request.output_schema:
            return False
        if self._cacheable is not None:
            return self._cacheable(request)
        return True

    def _build_prompt(self, request: ModelRequest) -> str:
        """Build the cache key from the request's textual content.

        Includes system_instructions to distinguish otherwise-identical
        user inputs run with different system prompts.
        """
        parts: list[str] = []
        if request.system_instructions:
            parts.append(f"sys:{request.system_instructions}")
        if isinstance(request.input, str):
            parts.append(f"input:{request.input}")
        elif isinstance(request.input, list):
            for item in request.input:
                parts.append(f"item:{self._stringify(item)}")
        else:
            parts.append(f"input:{self._stringify(request.input)}")
        return "\n".join(parts)

    @staticmethod
    def _stringify(value: Any) -> str:
        """Render *value* deterministically for use in a cache key.

        Sorts dict keys and recurses into dict values so that nested
        mappings produce the same key regardless of insertion order.
        (Previously only the top-level keys were sorted; nested dicts fell
        through to ``str()``, which reflects insertion order and caused
        spurious cache misses for semantically identical inputs.)
        """
        if isinstance(value, dict):
            # Stable order for keys so cache keys are deterministic.
            inner = ",".join(
                f"{k}={SemanticCacheMiddleware._stringify(value[k])}"
                for k in sorted(value)
            )
            return "{" + inner + "}"
        return str(value)

    async def _lookup(self, prompt: str) -> Any | None:
        """Return the deserialized cached response, or None on miss/corruption."""
        hit = await self._cache.aget(prompt)
        if hit is None:
            return None
        try:
            return self._deserialize(hit.response)
        except Exception:
            # A stale or corrupt payload is treated as a miss rather than an
            # error; the fresh response will overwrite it.
            return None

    async def _store(self, prompt: str, response: Any) -> None:
        """Best-effort write: unserializable responses are silently skipped."""
        try:
            payload = self._serialize(response)
        except Exception:
            return
        await self._cache.aset(prompt, payload)