Source code for redis_openai_agents.robust_processor

"""RobustStreamProcessor - Fault-tolerant stream processing with DLQ.

This module provides enhanced stream processing with automatic recovery features:
- Automatic pending message recovery via XCLAIM
- Dead-letter queue (DLQ) for failed messages after max retries
- Processing timeout detection
- Health statistics and monitoring

Key Features:
- Crash recovery: Automatically claims pending messages from crashed consumers
- DLQ: Messages that fail repeatedly are moved to a dead-letter queue
- Replay: Failed messages can be replayed back to the main stream
- Observability: Health stats for monitoring stream processor status
"""

from __future__ import annotations

import logging
import time
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any
from uuid import uuid4

from redis import asyncio as aioredis
from redis.exceptions import ResponseError

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from redis.asyncio import Redis


class RobustStreamProcessor:
    """
    Fault-tolerant stream processor with DLQ support.

    Features:
    - Automatic pending message recovery (XCLAIM)
    - Dead-letter queue for failed messages
    - Processing timeout detection
    - Health statistics

    Example:
        >>> processor = RobustStreamProcessor(
        ...     redis_url="redis://localhost:6379",
        ...     stream_name="agent_events",
        ...     consumer_group="workers",
        ... )
        >>> await processor.initialize()
        >>>
        >>> async def handle_message(msg: dict) -> bool:
        ...     # Process message, return True on success
        ...     return True
        >>>
        >>> await processor.process_with_recovery(handle_message)
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        stream_name: str = "agent_events",
        consumer_group: str = "workers",
        consumer_name: str | None = None,
        dlq_stream: str | None = None,
        max_retries: int = 3,
        claim_timeout_ms: int = 300000,  # 5 minutes
    ) -> None:
        """
        Initialize RobustStreamProcessor.

        Args:
            redis_url: Redis connection URL
            stream_name: Name of the main Redis Stream
            consumer_group: Consumer group name
            consumer_name: Consumer name (auto-generated if not provided)
            dlq_stream: Dead-letter queue stream name (default: {stream_name}:dlq)
            max_retries: Maximum delivery attempts before moving to DLQ
            claim_timeout_ms: Idle time (ms) before claiming pending messages
        """
        self._redis_url = redis_url
        self._stream_name = stream_name
        self._consumer_group = consumer_group
        self._consumer_name = consumer_name or f"consumer_{uuid4().hex[:8]}"
        self._dlq_stream = dlq_stream or f"{stream_name}:dlq"
        self._max_retries = max_retries
        self._claim_timeout_ms = claim_timeout_ms
        self._client: Redis | None = None
        self._initialized = False
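
    # Configuration sketch (comment-only so it stays out of the class's
    # runtime behavior). The "orders" stream name, custom DLQ name, and the
    # tighter retry/claim budget are illustrative assumptions, not defaults
    # from this module:
    #
    #     processor = RobustStreamProcessor(
    #         stream_name="orders",
    #         dlq_stream="orders:failed",  # overrides the "orders:dlq" default
    #         max_retries=5,               # two extra delivery attempts
    #         claim_timeout_ms=60_000,     # reclaim after 1 minute idle
    #     )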

    async def initialize(self) -> None:
        """
        Initialize the processor and create the consumer group if needed.

        Must be called before using other methods.
        """
        if self._initialized:
            return

        self._client = aioredis.from_url(self._redis_url, decode_responses=True)

        # Create the consumer group if it doesn't exist yet
        try:
            await self._client.xgroup_create(
                self._stream_name,
                self._consumer_group,
                id="0",
                mkstream=True,
            )
        except ResponseError as exc:
            if "BUSYGROUP" not in str(exc):
                raise
            # Group already exists; safe to continue

        self._initialized = True

    async def _get_client(self) -> Redis:
        """Get Redis client, ensuring initialization."""
        if not self._initialized or self._client is None:
            await self.initialize()
        return self._client  # type: ignore[return-value]

    async def process_batch(
        self,
        handler: Callable[[dict[str, Any]], Awaitable[bool]],
        batch_size: int = 10,
        block_ms: int = 5000,
        max_batches: int | None = None,
    ) -> int:
        """
        Process a batch of messages from the stream.

        Args:
            handler: Async function to process each message. Return True on success.
            batch_size: Number of messages to read per batch
            block_ms: Blocking timeout in milliseconds
            max_batches: Maximum batches to process (None = unlimited)

        Returns:
            Number of messages successfully processed
        """
        client = await self._get_client()
        processed = 0
        batches = 0

        while max_batches is None or batches < max_batches:
            # First, try to claim any abandoned messages
            await self.claim_pending_messages()

            # Read new messages
            messages = await client.xreadgroup(
                self._consumer_group,
                self._consumer_name,
                {self._stream_name: ">"},
                count=batch_size,
                block=block_ms,
            )

            if not messages:
                batches += 1
                continue

            for _stream_name, events in messages:
                for msg_id, data in events:
                    success = await self._process_message(msg_id, data, handler)
                    if success:
                        await client.xack(self._stream_name, self._consumer_group, msg_id)
                        processed += 1
                    # If not successful, the message stays pending for retry

            batches += 1

        return processed

    async def process_with_recovery(
        self,
        handler: Callable[[dict[str, Any]], Awaitable[bool]],
        batch_size: int = 10,
        block_ms: int = 5000,
    ) -> None:
        """
        Process the stream with automatic failure recovery.

        This method runs indefinitely:
        1. Claims abandoned messages from crashed consumers
        2. Processes new messages
        3. Moves failed messages to the DLQ after max retries

        Args:
            handler: Async function to process each message
            batch_size: Number of messages per batch
            block_ms: Blocking timeout in milliseconds
        """
        await self.process_batch(handler, batch_size, block_ms, max_batches=None)

    async def _process_message(
        self,
        msg_id: str,
        data: dict[str, Any],
        handler: Callable[[dict[str, Any]], Awaitable[bool]],
    ) -> bool:
        """Process a single message with error handling."""
        try:
            return await handler(data)
        except Exception as e:
            # Log the error but don't ACK - the message will be retried
            logger.error("Error processing %s: %s", msg_id, e)
            return False
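
    # Bounded-run sketch (comment-only; "handle" is a hypothetical handler).
    # Unlike process_with_recovery's infinite loop, capping max_batches lets
    # tests and cron-style workers drain some work and return:
    #
    #     done = await processor.process_batch(handle, batch_size=50, max_batches=3)
    #     print(f"processed {done} messages across up to 3 batches")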

    async def claim_pending_messages(self) -> int:
        """
        Claim messages from crashed/slow consumers.

        Messages older than claim_timeout_ms are claimed for reprocessing.
        Messages exceeding max_retries are moved to the DLQ.

        Returns:
            Number of messages claimed for reprocessing
        """
        client = await self._get_client()

        # Get pending messages
        pending = await client.xpending_range(
            self._stream_name,
            self._consumer_group,
            min="-",
            max="+",
            count=100,
        )

        claimed = 0
        for entry in pending:
            idle_time = entry.get("time_since_delivered", 0)
            times_delivered = entry.get("times_delivered", 0)
            msg_id = entry.get("message_id")

            if idle_time < self._claim_timeout_ms:
                # Message not old enough to claim
                continue

            if times_delivered >= self._max_retries:
                # Move to DLQ
                await self._move_to_dlq(
                    msg_id,
                    reason="max_retries_exceeded",
                    attempts=times_delivered,
                )
            else:
                # Claim for reprocessing
                await client.xclaim(
                    self._stream_name,
                    self._consumer_group,
                    self._consumer_name,
                    self._claim_timeout_ms,
                    [msg_id],
                )
                claimed += 1

        return claimed

    async def _move_to_dlq(
        self,
        msg_id: str,
        reason: str,
        attempts: int,
    ) -> None:
        """Move a failed message to the dead-letter queue."""
        client = await self._get_client()

        # Get the original message
        messages = await client.xrange(self._stream_name, msg_id, msg_id)
        if messages:
            _, data = messages[0]

            # Add to DLQ with metadata
            await client.xadd(
                self._dlq_stream,
                {
                    **data,
                    "original_stream": self._stream_name,
                    "original_id": msg_id,
                    "failure_reason": reason,
                    "attempts": str(attempts),
                    "dlq_timestamp": str(time.time()),
                },
            )

        # Acknowledge the original (remove from pending) even if the entry
        # was trimmed from the stream, so it can't get stuck in the PEL
        await client.xack(self._stream_name, self._consumer_group, msg_id)

    async def get_dlq_messages(
        self,
        count: int = 100,
    ) -> list[dict[str, Any]]:
        """
        Get messages from the dead-letter queue for inspection.

        Args:
            count: Maximum number of messages to retrieve

        Returns:
            List of DLQ messages with id and data
        """
        client = await self._get_client()
        messages = await client.xrange(self._dlq_stream, "-", "+", count=count)
        return [{"id": mid, **data} for mid, data in messages]

    async def replay_dlq_message(
        self,
        dlq_message_id: str,
    ) -> str:
        """
        Replay a DLQ message back to the main stream.

        Call this after fixing the root cause of the failure.

        Args:
            dlq_message_id: Message ID in the DLQ

        Returns:
            New message ID in the main stream

        Raises:
            ValueError: If the DLQ message is not found
        """
        client = await self._get_client()

        # Get the DLQ message
        messages = await client.xrange(self._dlq_stream, dlq_message_id, dlq_message_id)
        if not messages:
            raise ValueError(f"DLQ message not found: {dlq_message_id}")

        _, data = messages[0]

        # Remove DLQ metadata fields
        dlq_fields = {
            "original_stream",
            "original_id",
            "failure_reason",
            "attempts",
            "dlq_timestamp",
        }
        replay_data = {k: v for k, v in data.items() if k not in dlq_fields}
        replay_data["replayed_from_dlq"] = dlq_message_id

        # Add back to the main stream
        new_id = await client.xadd(self._stream_name, replay_data)

        # Remove from the DLQ
        await client.xdel(self._dlq_stream, dlq_message_id)

        return str(new_id)
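
    # DLQ triage sketch (comment-only; fixing the root cause is assumed to
    # happen out of band, between inspection and replay):
    #
    #     for msg in await processor.get_dlq_messages(count=10):
    #         print(msg["id"], msg.get("failure_reason"), msg.get("attempts"))
    #     # ...deploy a fix, then re-enqueue a message to the main stream:
    #     new_id = await processor.replay_dlq_message(msg["id"])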

    async def get_health_stats(self) -> dict[str, Any]:
        """
        Get processor health statistics.

        Returns:
            Dictionary with:
            - stream_length: Total messages in stream
            - pending_messages: Messages awaiting processing
            - consumers: Number of consumers in group
            - dlq_length: Messages in dead-letter queue
            - last_delivered_id: Last delivered message ID
        """
        client = await self._get_client()

        try:
            # Stream info
            stream_info = await client.xinfo_stream(self._stream_name)

            # Group info
            groups = await client.xinfo_groups(self._stream_name)
            group_info: dict[str, Any] = next(
                (g for g in groups if g["name"] == self._consumer_group),
                {},
            )

            # DLQ count
            dlq_len = await client.xlen(self._dlq_stream)

            return {
                "stream_length": stream_info.get("length", 0),
                "pending_messages": group_info.get("pending", 0),
                "consumers": group_info.get("consumers", 0),
                "dlq_length": dlq_len,
                "last_delivered_id": group_info.get("last-delivered-id"),
            }
        except Exception as exc:
            logger.debug("get_health_stats failed (stream may not exist): %s", exc)
            return {
                "stream_length": 0,
                "pending_messages": 0,
                "consumers": 0,
                "dlq_length": 0,
                "last_delivered_id": None,
            }

    async def close(self) -> None:
        """Close the Redis connection."""
        if self._client:
            await self._client.aclose()
            self._client = None
            self._initialized = False
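

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch, runnable against a local Redis. Everything below
# is illustrative only: the "payload" field, the failure rule, and the
# redis://localhost:6379 URL are assumptions, not part of the class above.
# ---------------------------------------------------------------------------
async def _demo() -> None:
    processor = RobustStreamProcessor(
        redis_url="redis://localhost:6379",
        stream_name="agent_events",
        consumer_group="workers",
    )
    await processor.initialize()

    async def handle_message(msg: dict[str, Any]) -> bool:
        # Treat messages without a "payload" field as unprocessable;
        # they stay pending and eventually land in the DLQ
        return "payload" in msg

    # Seed one good and one bad message with a separate producer client
    producer = aioredis.from_url("redis://localhost:6379", decode_responses=True)
    await producer.xadd("agent_events", {"payload": "ok"})
    await producer.xadd("agent_events", {"broken": "no payload"})
    await producer.aclose()

    # Drain a single batch, then report health
    done = await processor.process_batch(handle_message, max_batches=1, block_ms=1000)
    print("processed:", done)
    print("health:", await processor.get_health_stats())

    await processor.close()


if __name__ == "__main__":
    import asyncio

    asyncio.run(_demo())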