# Source code for redis_openai_agents.streams

"""RedisStreamTransport - Reliable token streaming via Redis Streams.

This module provides reliable event streaming using Redis Streams,
enabling real-time token delivery with replay capability.

Features:
- Reliable event delivery via Redis Streams
- Consumer groups for multiple clients
- Event replay from any point
- Automatic stream trimming
"""

import asyncio
import json
import logging
import time
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, Optional

from redis import Redis, ResponseError
from redis import asyncio as aioredis
from redis.exceptions import RedisError

# Module-level logger; asubscribe() reports stream read errors through it.
logger = logging.getLogger(__name__)

# Imported for static type checking only; __init__ refers to the pool type
# through a string annotation.
if TYPE_CHECKING:
    from .pool import RedisConnectionPool


[docs] class RedisStreamTransport: """Redis Streams-based event transport for token streaming. Provides reliable, replayable event streaming using Redis Streams. Supports consumer groups for multiple concurrent clients. Example: >>> stream = RedisStreamTransport(stream_name="agent_output") >>> stream.publish(event_type="token", data={"token": "Hello"}) >>> events = stream.read_all(count=10) Args: stream_name: Name of the Redis Stream redis_url: Redis connection URL consumer_group: Consumer group name for reading max_len: Maximum stream length (older entries trimmed) """
[docs] def __init__( self, stream_name: str, redis_url: str = "redis://localhost:6379", consumer_group: str = "agents", max_len: int = 10000, pool: Optional["RedisConnectionPool"] = None, ) -> None: """Initialize the stream transport. Args: stream_name: Name of the Redis Stream redis_url: Redis connection URL consumer_group: Consumer group name for reading max_len: Maximum stream length (older entries trimmed) pool: Optional shared connection pool """ self._stream_name = stream_name self._consumer_group = consumer_group self._max_len = max_len self._pool = pool # Use pool's client if provided if pool is not None: self._redis_url = pool.redis_url self._client = pool.get_sync_client() else: self._redis_url = redis_url self._client = Redis.from_url(redis_url, decode_responses=True) # Create consumer group if it doesn't exist try: self._client.xgroup_create( stream_name, consumer_group, id="0", mkstream=True, ) except ResponseError as e: if "BUSYGROUP" not in str(e): raise
@staticmethod def _parse_event(msg_id: str, fields: dict[str, str]) -> dict[str, Any]: """Parse raw stream fields into a structured event dict.""" event: dict[str, Any] = { "id": msg_id, "type": fields.get("type", "unknown"), "timestamp": float(fields.get("timestamp", 0)), } data_str = fields.get("data", "{}") try: event["data"] = json.loads(data_str) except (json.JSONDecodeError, TypeError): event["data"] = {} metadata_str = fields.get("metadata") if metadata_str: try: event["metadata"] = json.loads(metadata_str) except (json.JSONDecodeError, TypeError): event["metadata"] = {} return event @property def stream_name(self) -> str: """Name of the Redis Stream.""" return self._stream_name @property def consumer_group(self) -> str: """Consumer group name.""" return self._consumer_group def publish( self, event_type: str, data: dict[str, Any], metadata: dict[str, Any] | None = None, ) -> str: """Publish an event to the stream. Args: event_type: Type of event (token, tool_call, tool_result, complete) data: Event data metadata: Optional metadata Returns: Message ID from Redis """ event = { "type": event_type, "timestamp": str(time.time()), "data": json.dumps(data), } if metadata: event["metadata"] = json.dumps(metadata) # Add to stream with automatic trimming msg_id: str = self._client.xadd( # type: ignore[assignment] self._stream_name, event, # type: ignore[arg-type] maxlen=self._max_len, approximate=True, ) return msg_id def read_all( self, count: int = 100, start_id: str = "0", ) -> list[dict[str, Any]]: """Read all events from the stream. 
Args: count: Maximum number of events to read start_id: Start reading from this ID (default: beginning) Returns: List of events """ try: # Use XRANGE to read all messages messages = self._client.xrange( self._stream_name, min=start_id, max="+", count=count, ) except ResponseError: return [] if not messages: return [] events: list[dict[str, Any]] = [] for msg_id, fields in messages: # type: ignore[union-attr] events.append(self._parse_event(msg_id, fields)) return events def info(self) -> dict[str, Any]: """Get stream information. Returns: Dictionary with stream info (length, groups, etc.) """ try: # Get stream info info: dict[str, Any] = self._client.xinfo_stream( # type: ignore[assignment] self._stream_name ) groups_info: list[Any] = self._client.xinfo_groups( # type: ignore[assignment] self._stream_name ) return { "length": info.get("length", 0), "first_entry": info.get("first-entry"), "last_entry": info.get("last-entry"), "groups": len(groups_info), } except ResponseError: return { "length": 0, "first_entry": None, "last_entry": None, "groups": 0, } def read_group( self, consumer: str, count: int = 100, block_ms: int | None = None, ) -> list[dict[str, Any]]: """Read events using consumer group (XREADGROUP). Events read but not ACKed are tracked as pending. Use ack() to acknowledge processed events. 
Args: consumer: Consumer name within the group count: Maximum number of events to read block_ms: Block for this many milliseconds waiting for events (None = non-blocking) Returns: List of events """ try: # XREADGROUP reads only new messages for this consumer streams: dict[str, str] = {self._stream_name: ">"} result = self._client.xreadgroup( groupname=self._consumer_group, consumername=consumer, streams=streams, # type: ignore[arg-type] count=count, block=block_ms, ) if not result: return [] # Result format: [(stream_name, [(msg_id, {fields}), ...])] events: list[dict[str, Any]] = [] for _stream_name, messages in result: # type: ignore[union-attr] for msg_id, fields in messages: events.append(self._parse_event(msg_id, fields)) return events except ResponseError: return [] def ack(self, ids: list[str]) -> int: """Acknowledge events as processed (XACK). ACKed events are removed from the pending list. Args: ids: List of message IDs to acknowledge Returns: Number of messages acknowledged """ if not ids: return 0 try: acked: int = self._client.xack( # type: ignore[assignment] self._stream_name, self._consumer_group, *ids, ) return acked except ResponseError: return 0 def pending(self) -> dict[str, Any]: """Get information about pending messages (XPENDING). 
Returns: Dictionary with pending info: - count: Number of pending messages - min_id: Smallest pending message ID - max_id: Largest pending message ID - consumers: Dict mapping consumer names to pending counts """ try: result = self._client.xpending( self._stream_name, self._consumer_group, ) if not result: return {"count": 0, "min_id": None, "max_id": None, "consumers": {}} # redis-py returns a dict with keys: pending, min, max, consumers pending_count = result.get("pending", 0) # type: ignore[union-attr] if pending_count == 0: return {"count": 0, "min_id": None, "max_id": None, "consumers": {}} consumers_list = result.get("consumers", []) # type: ignore[union-attr] consumers: dict[str, int] = {} if consumers_list: for consumer_info in consumers_list: consumer_name = consumer_info.get("name", "") consumer_count = consumer_info.get("pending", 0) if consumer_name: consumers[consumer_name] = consumer_count return { "count": pending_count, "min_id": result.get("min"), # type: ignore[union-attr] "max_id": result.get("max"), # type: ignore[union-attr] "consumers": consumers, } except ResponseError: return {"count": 0, "min_id": None, "max_id": None, "consumers": {}} def claim( self, consumer: str, min_idle_ms: int, ids: list[str], ) -> list[dict[str, Any]]: """Claim pending messages from another consumer (XCLAIM). Used to recover messages from dead/stuck consumers. 
Args: consumer: Consumer to assign messages to min_idle_ms: Only claim messages idle for at least this long ids: Message IDs to claim Returns: List of claimed events """ if not ids: return [] try: result = self._client.xclaim( self._stream_name, self._consumer_group, consumer, min_idle_ms, ids, # type: ignore[arg-type] ) if not result: return [] events: list[dict[str, Any]] = [] for msg_id, fields in result: # type: ignore[union-attr] if fields is None: # Message doesn't exist or already ACKed continue events.append(self._parse_event(msg_id, fields)) return events except ResponseError: return [] def delete(self) -> None: """Delete the stream and all its data.""" try: self._client.delete(self._stream_name) except ResponseError: pass def close(self) -> None: """Close the Redis connection.""" self._client.close() # Async methods def _get_async_redis(self) -> aioredis.Redis: """Get or create async Redis client.""" if not hasattr(self, "_async_client"): self._async_client: aioredis.Redis = aioredis.from_url( self._redis_url, decode_responses=True ) return self._async_client async def apublish( self, event_type: str, data: dict[str, Any], metadata: dict[str, Any] | None = None, ) -> str: """Async version of publish() - publish an event to the stream. Args: event_type: Type of event (token, tool_call, tool_result, complete) data: Event data metadata: Optional metadata Returns: Message ID from Redis """ client = self._get_async_redis() event = { "type": event_type, "timestamp": str(time.time()), "data": json.dumps(data), } if metadata: event["metadata"] = json.dumps(metadata) msg_id = await client.xadd( self._stream_name, event, # type: ignore[arg-type] maxlen=self._max_len, approximate=True, ) return str(msg_id) async def asubscribe( self, last_id: str = "$", block_ms: int = 1000, ) -> AsyncIterator[dict[str, Any]]: """Async generator that yields events from the stream. 
Args: last_id: Start from this ID ("$" = new events only, "0" = all) block_ms: Block timeout in milliseconds Yields: Event dictionaries """ client = self._get_async_redis() current_id = last_id while True: try: result = await client.xread( streams={self._stream_name: current_id}, block=block_ms, count=10, ) if not result: continue for _stream_name, messages in result: for msg_id, fields in messages: current_id = msg_id yield self._parse_event(msg_id, fields) except RedisError as exc: logger.warning("Transient error in asubscribe, retrying: %s", exc) await asyncio.sleep(0.5) continue except Exception as exc: logger.error("Fatal error in asubscribe, stopping: %s", exc) break