# Source code for redis_openai_agents.resumable_streaming

"""ResumableStreamRunner - Resumable LLM streaming backed by Redis Streams.

This module provides resumable streaming for LLM responses, allowing clients
to disconnect and reconnect without losing data. Events are stored durably
in Redis Streams and can be replayed from any point.

Features:
- Durable event storage in Redis Streams
- Resumable subscriptions from any message ID
- Consumer groups for per-client progress tracking
- Automatic stream trimming with configurable max length
- Multi-client support (same stream, different progress)

Example:
    >>> from redis_openai_agents import ResumableStreamRunner
    >>>
    >>> runner = ResumableStreamRunner(redis_url="redis://localhost:6379")
    >>> await runner.initialize()
    >>>
    >>> # Producer publishes events
    >>> session_id = "chat_123"
    >>> await runner.publish_event(session_id, "text_delta", {"delta": "Hello"})
    >>> await runner.publish_event(session_id, "text_delta", {"delta": " world"})
    >>>
    >>> # Consumer subscribes (can reconnect anytime)
    >>> async for event in runner.subscribe(session_id, from_id="0"):
    ...     print(event["data"]["delta"], end="")
    Hello world
"""

import json
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any

import redis.asyncio as redis


@dataclass
class StreamEvent:
    """A single streaming event together with its Redis Stream metadata.

    Attributes:
        id: Redis Stream message ID.
        type: Event type (e.g., "text_delta", "tool_call").
        data: Event payload.
        timestamp: Unix timestamp when event was published.
    """

    id: str
    type: str
    data: dict
    timestamp: float

    def to_dict(self) -> dict:
        """Return the event as a plain dictionary (shallow copy of fields)."""
        return dict(
            id=self.id,
            type=self.type,
            data=self.data,
            timestamp=self.timestamp,
        )

[docs] class ResumableStreamRunner: """Resumable LLM streaming backed by Redis Streams. Enables durable, resumable streaming of LLM responses. Events are published to Redis Streams and can be consumed by multiple clients, each tracking their own progress. The key insight is separating generation from consumption: - Generation continues even if clients disconnect - Clients can reconnect and resume from where they left off - Multiple clients can consume the same stream independently Attributes: stream_prefix: Prefix for Redis Stream keys. max_stream_length: Maximum events per stream (for trimming). consumer_group: Default consumer group name. """
[docs] def __init__( self, redis_url: str = "redis://localhost:6379", stream_prefix: str = "llm_stream", max_stream_length: int | None = None, consumer_group: str = "consumers", ) -> None: """Initialize the resumable stream runner. Args: redis_url: Redis connection URL. stream_prefix: Prefix for stream keys. max_stream_length: Max events per stream (None = unlimited). consumer_group: Default consumer group name. """ self._redis_url = redis_url self._stream_prefix = stream_prefix self._max_stream_length = max_stream_length self._consumer_group = consumer_group self._client: redis.Redis | None = None
async def initialize(self) -> None: """Initialize Redis connection.""" self._client = redis.from_url(self._redis_url, decode_responses=True) async def close(self) -> None: """Close Redis connection.""" if self._client: await self._client.aclose() self._client = None def _get_stream_key(self, session_id: str) -> str: """Get Redis Stream key for a session. Args: session_id: Session identifier. Returns: Full Redis key for the stream. """ return f"{self._stream_prefix}:{session_id}" async def publish_event( self, session_id: str, event_type: str, data: dict, metadata: dict | None = None, ) -> str: """Publish a streaming event to Redis. Args: session_id: Session identifier for the stream. event_type: Type of event (e.g., "text_delta", "tool_call"). data: Event payload data. metadata: Optional additional metadata. Returns: Redis Stream message ID. """ if not self._client: raise RuntimeError("Runner not initialized. Call initialize() first.") stream_key = self._get_stream_key(session_id) timestamp = time.time() # Build event fields fields = { "type": event_type, "data": json.dumps(data), "timestamp": str(timestamp), } if metadata: fields["metadata"] = json.dumps(metadata) # Publish to stream with optional trimming if self._max_stream_length: msg_id = await self._client.xadd( stream_key, fields, # type: ignore[arg-type] maxlen=self._max_stream_length, approximate=False, ) else: msg_id = await self._client.xadd(stream_key, fields) # type: ignore[arg-type] return str(msg_id) async def get_all_events(self, session_id: str) -> list[dict]: """Get all events from a stream. Args: session_id: Session identifier. Returns: List of events with id, type, data, and timestamp. 
""" if not self._client: return [] stream_key = self._get_stream_key(session_id) try: messages = await self._client.xrange(stream_key, "-", "+") except redis.ResponseError: return [] events = [] for msg_id, fields in messages: event = self._parse_message(msg_id, fields) events.append(event) return events def _parse_message(self, msg_id: str, fields: dict) -> dict: """Parse a Redis Stream message into an event dict. Args: msg_id: Redis Stream message ID. fields: Message fields from Redis. Returns: Parsed event dictionary. """ data = json.loads(fields.get("data", "{}")) timestamp = float(fields.get("timestamp", 0)) event = { "id": msg_id, "type": fields.get("type", "unknown"), "data": data, "timestamp": timestamp, } if "metadata" in fields: event["metadata"] = json.loads(fields["metadata"]) return event async def subscribe( self, session_id: str, from_id: str = "$", timeout_ms: int = 5000, count: int = 100, ) -> AsyncIterator[dict]: """Subscribe to streaming events. Yields events from the stream, starting from the specified ID. Use from_id="0" to start from the beginning, or "$" for new events only. Args: session_id: Session identifier. from_id: Message ID to start from ("0" = beginning, "$" = new only). timeout_ms: Block timeout in milliseconds (0 = don't block). count: Max events per read. Yields: Event dictionaries with id, type, data, and timestamp. 
""" if not self._client: return stream_key = self._get_stream_key(session_id) last_id = from_id # First, read any existing messages if starting from beginning or specific ID if from_id != "$": try: messages = await self._client.xrange( stream_key, "(" + from_id if from_id != "0" else "-", "+", count=count, ) for msg_id, fields in messages: event = self._parse_message(msg_id, fields) yield event last_id = msg_id except redis.ResponseError: pass # Then block for new messages if timeout_ms > 0: try: result = await self._client.xread( {stream_key: last_id}, count=count, block=timeout_ms, ) if result: for _, messages in result: for msg_id, fields in messages: event = self._parse_message(msg_id, fields) yield event except redis.ResponseError: pass async def subscribe_as_consumer( self, session_id: str, consumer_id: str, timeout_ms: int = 5000, count: int = 100, ) -> AsyncIterator[dict]: """Subscribe as a consumer in a consumer group. Uses Redis consumer groups to track which messages each consumer has processed. Consumers must call ack() after processing each message. Args: session_id: Session identifier. consumer_id: Unique consumer identifier. timeout_ms: Block timeout in milliseconds. count: Max events per read. Yields: Event dictionaries with id, type, data, and timestamp. 
""" if not self._client: return stream_key = self._get_stream_key(session_id) group_name = f"{self._consumer_group}:{session_id}" # Ensure consumer group exists try: await self._client.xgroup_create( stream_key, group_name, id="0", mkstream=True, ) except redis.ResponseError as e: if "BUSYGROUP" not in str(e): raise # First, read any pending messages for this consumer try: pending = await self._client.xreadgroup( group_name, consumer_id, {stream_key: "0"}, count=count, ) if pending: for _, messages in pending: for msg_id, fields in messages: if fields: # Skip deleted messages event = self._parse_message(msg_id, fields) yield event except redis.ResponseError: pass # Then read new messages try: result = await self._client.xreadgroup( group_name, consumer_id, {stream_key: ">"}, count=count, block=timeout_ms, ) if result: for _, messages in result: for msg_id, fields in messages: event = self._parse_message(msg_id, fields) yield event except redis.ResponseError: pass async def ack( self, session_id: str, consumer_id: str, message_id: str, ) -> int: """Acknowledge a message as processed. After acknowledging, the message won't be redelivered to this consumer. Args: session_id: Session identifier. consumer_id: Consumer identifier (for group name lookup). message_id: Message ID to acknowledge. Returns: Number of messages acknowledged (0 or 1). """ if not self._client: return 0 stream_key = self._get_stream_key(session_id) group_name = f"{self._consumer_group}:{session_id}" return int(await self._client.xack(stream_key, group_name, message_id)) async def get_stream_info(self, session_id: str) -> dict: """Get information about a stream. Args: session_id: Session identifier. Returns: Dictionary with stream info (length, first/last entry IDs, etc.). 
""" if not self._client: return {} stream_key = self._get_stream_key(session_id) try: info = await self._client.xinfo_stream(stream_key) return { "length": info.get("length", 0), "first_entry_id": info.get("first-entry", [None])[0] if info.get("first-entry") else None, "last_entry_id": info.get("last-entry", [None])[0] if info.get("last-entry") else None, "groups": info.get("groups", 0), } except redis.ResponseError: return {"length": 0} async def delete_stream(self, session_id: str) -> bool: """Delete a stream and all its messages. Args: session_id: Session identifier. Returns: True if stream was deleted, False otherwise. """ if not self._client: return False stream_key = self._get_stream_key(session_id) result = await self._client.delete(stream_key) return int(result) > 0 async def get_pending_count( self, session_id: str, consumer_id: str | None = None, ) -> int: """Get count of pending (unacknowledged) messages. Args: session_id: Session identifier. consumer_id: Optional consumer ID to filter by. Returns: Number of pending messages. """ if not self._client: return 0 stream_key = self._get_stream_key(session_id) group_name = f"{self._consumer_group}:{session_id}" try: info = await self._client.xpending(stream_key, group_name) if not info or (isinstance(info, (list, tuple)) and info[0] == 0): return 0 return int(info[0]) if isinstance(info, (list, tuple)) else 0 except redis.ResponseError: return 0 async def claim_pending( self, session_id: str, consumer_id: str, min_idle_time_ms: int = 60000, count: int = 10, ) -> list[dict]: """Claim pending messages from dead consumers. Transfers ownership of messages that have been pending longer than min_idle_time_ms to the specified consumer. Args: session_id: Session identifier. consumer_id: Consumer to claim messages for. min_idle_time_ms: Minimum idle time to claim. count: Max messages to claim. Returns: List of claimed message events. 
""" if not self._client: return [] stream_key = self._get_stream_key(session_id) group_name = f"{self._consumer_group}:{session_id}" try: # Get pending message IDs pending = await self._client.xpending_range( stream_key, group_name, "-", "+", count, ) if not pending: return [] # Filter by idle time and claim message_ids = [ p["message_id"] for p in pending if p.get("time_since_delivered", 0) >= min_idle_time_ms ] if not message_ids: return [] claimed = await self._client.xclaim( stream_key, group_name, consumer_id, min_idle_time_ms, message_ids, ) return [self._parse_message(msg_id, fields) for msg_id, fields in claimed if fields] except redis.ResponseError: return []
class StreamingEventPublisher:
    """High-level helper for publishing SDK streaming events.

    Provides convenient methods for publishing common event types from
    the OpenAI Agents SDK streaming interface.

    Example:
        >>> publisher = StreamingEventPublisher(runner, session_id="chat_123")
        >>>
        >>> # Publish text deltas as they arrive
        >>> async for event in sdk_stream_events:
        ...     if isinstance(event, RawResponsesStreamEvent):
        ...         await publisher.publish_raw_event(event.data)
        ...     elif isinstance(event, RunItemStreamEvent):
        ...         await publisher.publish_item_event(event.name, event.item)
    """

    def __init__(
        self,
        runner: ResumableStreamRunner,
        session_id: str,
    ) -> None:
        """Initialize the publisher.

        Args:
            runner: ResumableStreamRunner instance.
            session_id: Session identifier for this stream.
        """
        self._runner = runner
        self._session_id = session_id
        self._message_count = 0

    async def publish_text_delta(self, delta: str) -> str:
        """Publish a text delta event.

        Args:
            delta: Text content to publish.

        Returns:
            Message ID.
        """
        return await self._runner.publish_event(
            session_id=self._session_id,
            event_type="text_delta",
            data={"delta": delta},
        )

    async def publish_tool_call(
        self,
        tool_name: str,
        arguments: dict,
        call_id: str | None = None,
    ) -> str:
        """Publish a tool call event.

        Args:
            tool_name: Name of the tool being called.
            arguments: Tool arguments.
            call_id: Optional tool call ID.

        Returns:
            Message ID.
        """
        payload = {"tool": tool_name, "arguments": arguments}
        # call_id is only attached when provided, keeping the payload minimal.
        if call_id:
            payload["call_id"] = call_id
        return await self._runner.publish_event(
            session_id=self._session_id,
            event_type="tool_call",
            data=payload,
        )

    async def publish_tool_result(
        self,
        tool_name: str,
        result: Any,
        call_id: str | None = None,
    ) -> str:
        """Publish a tool result event.

        Args:
            tool_name: Name of the tool.
            result: Tool execution result.
            call_id: Optional tool call ID.

        Returns:
            Message ID.
        """
        payload = {"tool": tool_name, "result": result}
        if call_id:
            payload["call_id"] = call_id
        return await self._runner.publish_event(
            session_id=self._session_id,
            event_type="tool_result",
            data=payload,
        )

    async def publish_stream_start(
        self,
        agent_name: str,
        metadata: dict | None = None,
    ) -> str:
        """Publish stream start event.

        Args:
            agent_name: Name of the agent starting.
            metadata: Optional additional metadata.

        Returns:
            Message ID.
        """
        return await self._runner.publish_event(
            session_id=self._session_id,
            event_type="stream_start",
            data={"agent": agent_name},
            metadata=metadata,
        )

    async def publish_stream_end(
        self,
        reason: str = "complete",
        metadata: dict | None = None,
    ) -> str:
        """Publish stream end event.

        Args:
            reason: Reason for ending (e.g., "complete", "error", "cancelled").
            metadata: Optional additional metadata.

        Returns:
            Message ID.
        """
        return await self._runner.publish_event(
            session_id=self._session_id,
            event_type="stream_end",
            data={"reason": reason},
            metadata=metadata,
        )

    async def publish_handoff(
        self,
        from_agent: str,
        to_agent: str,
    ) -> str:
        """Publish agent handoff event.

        Args:
            from_agent: Agent handing off.
            to_agent: Agent receiving handoff.

        Returns:
            Message ID.
        """
        return await self._runner.publish_event(
            session_id=self._session_id,
            event_type="handoff",
            data={
                "from_agent": from_agent,
                "to_agent": to_agent,
            },
        )

    async def publish_error(
        self,
        error_type: str,
        message: str,
        details: dict | None = None,
    ) -> str:
        """Publish error event.

        Args:
            error_type: Type of error.
            message: Error message.
            details: Optional error details.

        Returns:
            Message ID.
        """
        payload: dict[str, Any] = {
            "error_type": error_type,
            "message": message,
        }
        if details:
            payload["details"] = details
        return await self._runner.publish_event(
            session_id=self._session_id,
            event_type="error",
            data=payload,
        )