# Source code for redis_openai_agents.middleware.conversation_memory

"""ConversationMemoryMiddleware - inject semantically relevant past messages.

On each model call, looks up the top-K past messages most relevant to the
current user input and prepends them to the request. After the handler
returns, stores both the user turn and the assistant reply in the history
for future retrieval.

Backed by :class:`redisvl.extensions.message_history.SemanticMessageHistory`.
"""

from __future__ import annotations

import asyncio
import logging
from typing import TYPE_CHECKING, Any

from ._utils import extract_user_text
from .base import ModelCallHandler, ModelRequest

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from redisvl.extensions.message_history import (  # type: ignore[import-untyped]
        SemanticMessageHistory,
    )


class ConversationMemoryMiddleware:
    """Prepend semantically relevant past messages into the request.

    For each model call the middleware extracts the user text from the
    request, asks ``history`` for the most relevant stored turns, puts them
    in front of the request input and then invokes the handler. Afterwards
    the new exchange is optionally written back to ``history``.

    History lookup, reply-text extraction and persistence are best-effort:
    a failure in any of them is logged at debug level and does not fail the
    model call. Exceptions raised by the handler itself propagate unchanged.

    Args:
        history: A :class:`SemanticMessageHistory` instance backing the
            retrieval and storage. Callers are responsible for its lifecycle.
        session_tag: Tag passed to history queries and inserts. Allows
            tenant / user / conversation isolation.
        top_k: Maximum number of past messages to prepend.
        distance_threshold: Optional override for relevance matching.
        persist_reply: When True (default), the user turn and the extracted
            assistant reply are stored back into history after the handler
            returns so follow-up turns can retrieve them. When False,
            nothing is written to history.
        response_text_extractor: Optional callable that turns a model
            response into the text to store. Defaults to best-effort
            extraction from OpenAI Responses-shaped responses.
    """

    def __init__(
        self,
        history: SemanticMessageHistory,
        *,
        session_tag: str | None = None,
        top_k: int = 5,
        distance_threshold: float | None = None,
        persist_reply: bool = True,
        response_text_extractor: Any = None,
    ) -> None:
        self._history = history
        self._session_tag = session_tag
        self._top_k = top_k
        self._distance_threshold = distance_threshold
        self._persist_reply = persist_reply
        # Fall back to the module's default extractor when none is supplied.
        self._extract_response_text = response_text_extractor or _default_extract

    async def awrap_model_call(self, request: ModelRequest, handler: ModelCallHandler) -> Any:
        """Run ``handler`` with relevant history prepended to ``request.input``.

        Args:
            request: The outgoing model request. Its ``input`` attribute is
                replaced in place when relevant history is found.
            handler: Awaitable callable that performs the model call.

        Returns:
            Whatever ``handler`` returns, unmodified.
        """
        prompt = self._extract_prompt(request)
        if prompt:
            # The history client is called synchronously, so run the lookup
            # in a worker thread instead of on the event loop.
            relevant = await asyncio.to_thread(self._fetch_relevant, prompt)
            if relevant:
                request.input = self._merge_input(request.input, relevant)

        response = await handler(request)

        if prompt and self._persist_reply:
            # The extractor may be user-supplied. The model call has already
            # succeeded at this point, so an extractor failure must not
            # discard the response; store the user turn alone instead.
            try:
                reply_text = self._extract_response_text(response)
            except Exception as exc:
                logger.debug("Failed to extract response text: %s", exc)
                reply_text = ""
            await asyncio.to_thread(self._persist, prompt, reply_text)
        return response

    def _fetch_relevant(self, prompt: str) -> list[dict[str, Any]]:
        """Return stored turns relevant to ``prompt``, or ``[]`` on failure."""
        kwargs: dict[str, Any] = {
            "top_k": self._top_k,
            "raw": False,
            "as_text": False,
        }
        # Optional filters are only forwarded when explicitly configured.
        if self._session_tag is not None:
            kwargs["session_tag"] = self._session_tag
        if self._distance_threshold is not None:
            kwargs["distance_threshold"] = self._distance_threshold
        try:
            result = self._history.get_relevant(prompt, **kwargs)
        except Exception as exc:
            # Retrieval is best-effort: proceed without history.
            logger.debug("Failed to fetch relevant history: %s", exc)
            return []
        # get_relevant returns List[Dict[str,str]] when as_text=False.
        # Copy each entry and drop anything that is not a dict.
        return [dict(item) for item in result if isinstance(item, dict)]

    @staticmethod
    def _extract_prompt(request: ModelRequest) -> str:
        """Return the user text found in ``request.input``."""
        return extract_user_text(request.input)

    @staticmethod
    def _merge_input(current: Any, prepend: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Return a list input with past turns in front of the current input.

        A list ``current`` is kept as-is after the normalized past turns; any
        other value is stringified and wrapped as a single user message.
        """
        normalized = [_normalize_turn(turn) for turn in prepend]
        if isinstance(current, list):
            return [*normalized, *current]
        return [*normalized, {"role": "user", "content": str(current)}]

    def _persist(self, user_text: str, assistant_text: str) -> None:
        """Store the user turn, plus the assistant turn when it is non-empty."""
        try:
            messages = [{"role": "user", "content": user_text}]
            if assistant_text:
                messages.append({"role": "assistant", "content": assistant_text})
            if self._session_tag is not None:
                self._history.add_messages(messages, session_tag=self._session_tag)
            else:
                self._history.add_messages(messages)
        except Exception as exc:
            # History persistence errors must not fail the model call.
            logger.debug("Failed to persist conversation history: %s", exc)
def _normalize_turn(turn: dict[str, Any]) -> dict[str, Any]:
    """Convert a stored history turn into a ``{"role", "content"}`` message.

    The role is taken from ``"role"``, then ``"type"``, and falls back to
    ``"user"``. Older entries may carry the deprecated ``"llm"`` role, which
    is rewritten to ``"assistant"`` as expected by the Agents SDK. A missing
    ``"content"`` key yields an empty string.
    """
    speaker = turn.get("role") or turn.get("type") or "user"
    return {
        "role": "assistant" if speaker == "llm" else speaker,
        "content": turn.get("content", ""),
    }


def _default_extract(response: Any) -> str:
    """Best-effort extraction of the assistant text from a ModelResponse.

    Scans ``response.output`` in order and returns the first non-empty
    string ``text`` attribute found on a block inside an item's list-valued
    ``content``. Returns ``""`` when nothing matches.
    """
    items = getattr(response, "output", None)
    if not items:
        return ""
    for entry in items:
        blocks = getattr(entry, "content", None)
        if not isinstance(blocks, list):
            # Items without a list of content blocks carry no text for us.
            continue
        for part in blocks:
            candidate = getattr(part, "text", None)
            if isinstance(candidate, str) and candidate:
                return candidate
    return ""