# Source code for redis_openai_agents.session

"""AgentSession - Persistent session management for OpenAI Agents SDK.

This module provides session/conversation state management using Redis,
built on top of RedisVL's MessageHistory.
"""

import asyncio
import re
from typing import TYPE_CHECKING, Any, Optional
from uuid import uuid4

from pydantic import BaseModel, Field
from redisvl.extensions.message_history import MessageHistory  # type: ignore[import-untyped]

if TYPE_CHECKING:
    from .pool import RedisConnectionPool


class SessionMetadata(BaseModel):
    """Descriptive record for a single agent session.

    Only ``user_id`` and ``conversation_id`` are required; every other field
    has a default, so a minimal instance can be built from the two identifiers.
    """

    # Identifiers: the owning user and the conversation this record describes.
    user_id: str
    conversation_id: str

    # Agent tracking: the agent currently handling the conversation (if any)
    # and the names of the agents recorded for it. A default factory is used
    # so each instance gets its own list.
    current_agent: Optional[str] = None
    agents_used: list[str] = Field(default_factory=list)

    # Creation / last-update times.
    # NOTE(review): presumably epoch seconds - confirm against the writer.
    created_at: Optional[float] = None
    updated_at: Optional[float] = None


[docs] class AgentSession: """ Persistent session storage for OpenAI Agents SDK conversations. Built on top of RedisVL's MessageHistory, this class provides: - Message persistence across Python sessions - Multi-conversation management per user - Agent handoff tracking - Conversation metadata Example: >>> # Create a new session >>> session = AgentSession.create(user_id="user_123") >>> >>> # Run agent with session >>> result = Runner.run(agent, input=inputs, session=session) >>> >>> # Later, load and continue >>> session = AgentSession.load(conversation_id="conv_abc") >>> history = session.to_agent_inputs() >>> result = Runner.run(agent, input=history + new_inputs, session=session) """
[docs] def __init__( self, user_id: str, conversation_id: str | None = None, redis_url: str = "redis://localhost:6379", pool: Optional["RedisConnectionPool"] = None, ): """ Initialize an AgentSession. Args: user_id: User identifier conversation_id: Optional conversation ID (generates one if not provided) redis_url: Redis connection URL pool: Optional shared connection pool """ self.user_id = user_id self.conversation_id = conversation_id or str(uuid4().hex[:16]) self._pool = pool # Use pool's URL if provided if pool is not None: self.redis_url = pool.redis_url else: self.redis_url = redis_url # Create RedisVL MessageHistory instance # CRITICAL: Use conversation_id as session_tag to enable loading existing conversations # The name is the index name (shared across all sessions) # The session_tag is what identifies this specific conversation self._history = MessageHistory( name="agent_sessions", session_tag=self.conversation_id, # This makes messages retrievable! redis_url=self.redis_url, ) # Track current agent self._current_agent: str | None = None self._agents_used: set[str] = set()
@classmethod def create( cls, user_id: str, redis_url: str = "redis://localhost:6379", ) -> "AgentSession": """ Create a new session with a generated conversation ID. Args: user_id: User identifier redis_url: Redis connection URL Returns: New AgentSession instance """ return cls(user_id=user_id, redis_url=redis_url) @classmethod def load( cls, conversation_id: str, user_id: str | None = None, redis_url: str = "redis://localhost:6379", ) -> "AgentSession": """ Load an existing session from Redis. Args: conversation_id: Conversation ID to load user_id: Optional user ID (will be extracted from session if not provided) redis_url: Redis connection URL Returns: Loaded AgentSession instance Raises: ValueError: If session not found and user_id not provided """ if not user_id: # TODO: Extract user_id from session metadata in Redis raise ValueError("user_id must be provided when loading session") session = cls(user_id=user_id, conversation_id=conversation_id, redis_url=redis_url) # Load metadata to restore agent tracking metadata = session.get_metadata() session._current_agent = metadata.get("current_agent") session._agents_used = set(metadata.get("agents_used", [])) return session @classmethod def list_conversations( cls, user_id: str, redis_url: str = "redis://localhost:6379", ) -> list["AgentSession"]: """ List all conversations for a user. Args: user_id: User identifier redis_url: Redis connection URL Returns: List of AgentSession instances for this user """ # TODO: Implement by querying Redis for all sessions with this user_id prefix # For now, return empty list - this requires maintaining a user index return [] def add_message(self, role: str, content: str, **metadata: Any) -> None: """ Add a message to the session. 
Args: role: Message role (user, assistant, system) content: Message content **metadata: Additional metadata to store with message """ self._history.add_message({"role": role, "content": content, **metadata}) def store_exchange( self, user_message: str, assistant_response: str, agent_name: str | None = None ) -> None: """ Store a user-assistant message exchange. This is a convenience method for the common pattern of storing a user message and the assistant's response together. Args: user_message: The user's message assistant_response: The assistant's response agent_name: Optional name of the agent that responded """ self.add_message(role="user", content=user_message) self.add_message(role="assistant", content=assistant_response) if agent_name: self.track_agent(agent_name) def store_agent_result(self, result: Any) -> None: """ Store messages from OpenAI Agents SDK Runner result. This works with both `Runner.run()` and `Runner.run_streamed()` results. After streaming completes, call `result.to_input_list()` to get all messages including the assistant's response. 
Args: result: The result object from Runner.run() or Runner.run_streamed() (after consuming the stream) """ # Track the current agent if hasattr(result, "current_agent") and hasattr(result.current_agent, "name"): self.track_agent(result.current_agent.name) # Get all messages from the result using to_input_list() # This includes both the input messages and the assistant's response if hasattr(result, "to_input_list"): all_messages = result.to_input_list() # Track what we've stored to avoid duplicates from handoff wrapping stored_user_messages = set() # Store all messages, but filter out internal/function messages for msg in all_messages: # Handle different message formats # OpenAI Agents SDK returns messages as objects with .role and .content if hasattr(msg, "role") and hasattr(msg, "content"): role = msg.role if isinstance(msg.role, str) else str(msg.role) # Skip system and function messages - only store user/assistant if role not in ["user", "assistant"]: continue # Content can be a string or a list of content parts content_raw = msg.content if isinstance(content_raw, str): # Check if this is a wrapped handoff context message # These start with "For context, here is the conversation" if "For context, here is the conversation" in content_raw: # Extract the original user message from the history # Format: "1. user: <original message>\n2. function_call:..." match = re.search(r"1\. 
user: (.+?)(?:\n2\.|$)", content_raw, re.DOTALL) if match: content = match.group(1).strip() else: # Fallback: skip this wrapped message entirely continue else: content = content_raw elif isinstance(content_raw, list): # Extract text from content parts text_parts = [] for part in content_raw: part_text = None if hasattr(part, "text"): part_text = part.text elif isinstance(part, dict) and "text" in part: part_text = part["text"] if part_text: # Check if this text part is a wrapped handoff message if "For context, here is the conversation" in part_text: # Extract original from wrapped match = re.search( r"1\. user: (.+?)(?:\n2\.|$)", part_text, re.DOTALL ) if match: text_parts.append(match.group(1).strip()) else: text_parts.append(part_text) content = " ".join(text_parts) else: content = str(content_raw) if content: # Only store non-empty messages # Deduplicate: skip if this exact user message was already stored if role == "user" and content in stored_user_messages: continue if role == "user": stored_user_messages.add(content) self.add_message(role=role, content=content) elif isinstance(msg, dict): role = msg.get("role", "unknown") # Skip non-user/assistant messages if role not in ["user", "assistant"]: continue content_raw = msg.get("content", "") # Handle list format in dict - extract text from content parts if isinstance(content_raw, list): text_parts = [] for part in content_raw: if isinstance(part, dict) and "text" in part: text_parts.append(part["text"]) content = " ".join(text_parts) else: content = content_raw # Unwrap handoff context messages # These come from agent handoffs and wrap the original user message # They have role="assistant" but contain "For context, here is the conversation" if ( isinstance(content, str) and "For context, here is the conversation" in content ): match = re.search(r"1\. 
user: (.+?)(?:\n2\.|$)", content, re.DOTALL) if match: # Extract the original user message and correct the role content = match.group(1).strip() role = "user" # Correct the role to user else: # Skip this wrapped message if we can't extract the original continue if content: # Deduplicate for dict format too if role == "user" and content in stored_user_messages: continue if role == "user": stored_user_messages.add(content) self.add_message(role=role, content=content) def get_messages(self, top_k: int | None = None) -> list[dict[str, Any]]: """ Get recent messages from the session. Args: top_k: Number of recent messages to retrieve (None for all) Returns: List of message dictionaries """ if top_k is None: # Get all messages result: list[dict[str, Any]] = self._history.get_recent(top_k=1000) return result result = self._history.get_recent(top_k=top_k) return list(result) def to_agent_inputs(self) -> list[dict[str, str]]: """ Convert session messages to OpenAI Agents SDK input format. Returns: List of messages in format [{" content": str, "role": str}, ...] """ messages = self.get_messages() # Ensure format matches TResponseInputItem return [{"content": msg["content"], "role": msg["role"]} for msg in messages] def message_count(self) -> int: """ Get the number of messages in the session. Returns: Count of messages """ return len(self.get_messages()) def track_agent(self, agent_name: str) -> None: """ Track agent usage in this session. Args: agent_name: Name of the agent being used """ self._current_agent = agent_name self._agents_used.add(agent_name) @property def current_agent(self) -> str | None: """Get the current agent name.""" return self._current_agent def get_metadata(self) -> dict[str, Any]: """ Get session metadata. 
Returns: Dictionary containing session metadata """ return { "user_id": self.user_id, "conversation_id": self.conversation_id, "current_agent": self._current_agent, "agents_used": list(self._agents_used), "message_count": self.message_count(), } def clear(self) -> None: """Clear all messages from the session.""" self._history.clear() self._current_agent = None self._agents_used.clear() def delete(self) -> None: """Delete the session and all its data from Redis.""" self._history.clear() self._current_agent = None self._agents_used.clear() # Async methods async def aadd_message(self, role: str, content: str, **metadata: Any) -> None: """Async version of add_message() - add a message to the session. Args: role: Message role (user, assistant, system) content: Message content **metadata: Additional metadata to store with message """ await asyncio.to_thread( self._history.add_message, {"role": role, "content": content, **metadata} ) async def aget_messages(self, top_k: int | None = None) -> list[dict[str, Any]]: """Async version of get_messages() - get recent messages from the session. Args: top_k: Number of recent messages to retrieve (None for all) Returns: List of message dictionaries """ if top_k is None: result: list[dict[str, Any]] = await asyncio.to_thread( self._history.get_recent, top_k=1000 ) return result result = await asyncio.to_thread(self._history.get_recent, top_k=top_k) return list(result) async def astore_agent_result(self, result: Any) -> None: """Async version of store_agent_result() - store messages from Runner result. Args: result: The result object from Runner.run() """ await asyncio.to_thread(self.store_agent_result, result)