Source code for redis_openai_agents.tracing

"""RedisTracingProcessor - Store OpenAI Agents SDK traces in Redis.

This module provides a tracing processor that stores trace and span data
in Redis for debugging, analysis, and replay capabilities.

Features:
- Buffered writes for performance
- Redis Streams for replay capability
- Redis Hash for quick trace lookup
- Parent-child span relationship tracking

Example:
    >>> from redis_openai_agents import RedisTracingProcessor
    >>> from agents.tracing import setup_tracing
    >>>
    >>> # Create processor
    >>> processor = RedisTracingProcessor(redis_url="redis://localhost:6379")
    >>> await processor.initialize()
    >>>
    >>> # Register with OpenAI Agents SDK
    >>> setup_tracing(processors=[processor])
    >>>
    >>> # Query traces later
    >>> traces = await processor.list_traces(limit=10)
    >>> spans = await processor.get_spans(trace_id="trace_123")
"""

import asyncio
import json
import logging
import time
from typing import Any

import redis.asyncio as redis

logger = logging.getLogger(__name__)


[docs] class RedisTracingProcessor: """Stores OpenAI Agents SDK traces in Redis for observability. This processor implements the TracingProcessor interface from the OpenAI Agents SDK to capture trace and span lifecycle events. Traces are stored in: - Redis Streams for replay capability and time-series access - Redis Hashes for quick trace/span lookup by ID Attributes: redis_url: Redis connection URL. stream_name: Name of the Redis Stream for storing events. buffer_size: Number of events to buffer before flushing. """
[docs] def __init__( self, redis_url: str = "redis://localhost:6379", stream_name: str = "agent_traces", buffer_size: int = 100, trace_ttl: int = 86400 * 7, # 7 days default TTL ) -> None: """Initialize the tracing processor. Args: redis_url: Redis connection URL. stream_name: Name of the Redis Stream for events. buffer_size: Number of events to buffer before auto-flush. trace_ttl: TTL in seconds for trace data (default 7 days). """ self._redis_url = redis_url self._stream_name = stream_name self._buffer_size = buffer_size self._trace_ttl = trace_ttl self._client: redis.Redis | None = None self._buffer: list[dict[str, Any]] = [] self._initialized = False
async def initialize(self) -> None: """Initialize Redis connection. Must be called before using the processor. """ self._client = redis.from_url(self._redis_url, decode_responses=True) self._initialized = True async def close(self) -> None: """Close Redis connection.""" if self._client: await self._client.aclose() self._client = None self._initialized = False def on_trace_start(self, trace: Any) -> None: """Called when a trace begins. Args: trace: The Trace object from OpenAI Agents SDK. """ event = { "event_type": "trace_start", "trace_id": getattr(trace, "trace_id", str(id(trace))), "name": getattr(trace, "name", "unknown"), "started_at": getattr(trace, "started_at", time.time()), "timestamp": time.time(), } self._buffer.append(event) self._maybe_flush() def on_trace_end(self, trace: Any) -> None: """Called when a trace completes. Args: trace: The Trace object from OpenAI Agents SDK. """ event = { "event_type": "trace_end", "trace_id": getattr(trace, "trace_id", str(id(trace))), "name": getattr(trace, "name", "unknown"), "completed_at": getattr(trace, "completed_at", time.time()), "error": getattr(trace, "error", None), "timestamp": time.time(), } self._buffer.append(event) self._maybe_flush() def on_span_start(self, span: Any) -> None: """Called when a span begins. Args: span: The Span object from OpenAI Agents SDK. """ span_data = getattr(span, "span_data", None) span_type = getattr(span_data, "type", "unknown") if span_data else "unknown" event = { "event_type": "span_start", "trace_id": getattr(span, "trace_id", "unknown"), "span_id": getattr(span, "span_id", str(id(span))), "parent_id": getattr(span, "parent_id", None), "name": getattr(span, "name", "unknown"), "span_type": span_type, "started_at": getattr(span, "started_at", time.time()), "timestamp": time.time(), } self._buffer.append(event) self._maybe_flush() def on_span_end(self, span: Any) -> None: """Called when a span completes. Args: span: The Span object from OpenAI Agents SDK. """ span_data = getattr(span, "span_data", None) # Export span data if available exported_data = {} if span_data and hasattr(span_data, "export"): try: exported_data = span_data.export() except Exception as exc: logger.debug("span_data.export() failed: %s", exc) event = { "event_type": "span_end", "trace_id": getattr(span, "trace_id", "unknown"), "span_id": getattr(span, "span_id", str(id(span))), "parent_id": getattr(span, "parent_id", None), "name": getattr(span, "name", "unknown"), "span_type": exported_data.get("type", "unknown"), "finished_at": getattr(span, "finished_at", time.time()), "error": getattr(span, "error", None), "span_data": exported_data, "timestamp": time.time(), } self._buffer.append(event) self._maybe_flush() def shutdown(self) -> None: """Called on application shutdown. Flushes any buffered events before shutdown. """ self.force_flush() def force_flush(self) -> None: """Force immediate flushing of buffered events. This method blocks until all buffered events are written. For use in async contexts, use aforce_flush() instead. """ if not self._buffer or not self._client: return # Use asyncio to run the async flush try: loop = asyncio.get_event_loop() if loop.is_running(): # We're in an async context - this won't work well # The sync flush in async context is best-effort # Use aforce_flush() for reliable async flushing import threading # Create a new event loop in a thread for sync behavior result_event = threading.Event() exception_holder: list = [] def run_in_thread() -> None: try: new_loop = asyncio.new_event_loop() asyncio.set_event_loop(new_loop) try: new_loop.run_until_complete(self._flush_async_direct()) finally: new_loop.close() except Exception as e: exception_holder.append(e) finally: result_event.set() thread = threading.Thread(target=run_in_thread) thread.start() result_event.wait(timeout=5.0) thread.join(timeout=1.0) if exception_holder: raise exception_holder[0] else: loop.run_until_complete(self._flush_async()) except RuntimeError: # No event loop, create one asyncio.run(self._flush_async()) async def aforce_flush(self) -> None: """Async version of force_flush(). Force immediate flushing of buffered events. Use this in async contexts for reliable flushing. """ await self._flush_async() def _maybe_flush(self) -> None: """Flush buffer if it's full.""" if len(self._buffer) >= self._buffer_size: # Schedule async flush - don't block try: loop = asyncio.get_event_loop() if loop.is_running(): asyncio.create_task(self._flush_async()) else: self.force_flush() except RuntimeError: pass def _build_flush_pipeline(self, pipe: Any, events: list[dict[str, Any]]) -> None: """Populate a Redis pipeline with flush commands for the given events.""" for event in events: event_type = event.get("event_type", "unknown") trace_id = event.get("trace_id", "unknown") span_id = event.get("span_id") # Store in stream for replay stream_data = { k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) for k, v in event.items() } pipe.xadd(self._stream_name, stream_data) # Store trace metadata in hash for quick lookup if event_type in ("trace_start", "trace_end"): hash_key = f"trace:{trace_id}" pipe.hset( hash_key, mapping={ "trace_id": trace_id, "name": event.get("name", ""), "started_at": str(event.get("started_at", "")), "completed_at": str(event.get("completed_at", "")), "error": str(event.get("error", "")), "status": "completed" if event_type == "trace_end" else "running", }, ) pipe.expire(hash_key, self._trace_ttl) # Store span in trace's span list if span_id and event_type in ("span_start", "span_end"): span_key = f"trace:{trace_id}:spans" span_data = json.dumps( { "span_id": span_id, "parent_id": event.get("parent_id"), "name": event.get("name", ""), "span_type": event.get("span_type", ""), "started_at": event.get("started_at"), "finished_at": event.get("finished_at"), "error": event.get("error"), "span_data": event.get("span_data", {}), "status": "completed" if event_type == "span_end" else "running", } ) pipe.hset(span_key, span_id, span_data) pipe.expire(span_key, self._trace_ttl) async def _flush_async_direct(self) -> None: """Flush using a fresh Redis connection (for threaded contexts).""" if not self._buffer: return events = self._buffer.copy() self._buffer.clear() client = redis.from_url(self._redis_url, decode_responses=True) try: pipe = client.pipeline() self._build_flush_pipeline(pipe, events) await pipe.execute() finally: await client.aclose() async def _flush_async(self) -> None: """Async implementation of buffer flush.""" if not self._buffer or not self._client: return events = self._buffer.copy() self._buffer.clear() pipe = self._client.pipeline() self._build_flush_pipeline(pipe, events) await pipe.execute() async def get_trace(self, trace_id: str) -> dict[str, Any] | None: """Get trace data by ID. Args: trace_id: The trace ID to retrieve. Returns: Trace data dictionary or None if not found. """ if not self._client: return None hash_key = f"trace:{trace_id}" data_result = await self._client.hgetall(hash_key) # type: ignore[misc] data: dict[str, str] = data_result if isinstance(data_result, dict) else {} if not data: return None return { "trace_id": data.get("trace_id", trace_id), "name": data.get("name", ""), "started_at": float(data.get("started_at", 0)) if data.get("started_at") else None, "completed_at": float(data.get("completed_at", 0)) if data.get("completed_at") and data.get("completed_at") != "None" else None, "error": data.get("error") if data.get("error") != "None" else None, "status": data.get("status", "unknown"), } async def get_spans(self, trace_id: str) -> list[dict[str, Any]]: """Get all spans for a trace. Args: trace_id: The trace ID to get spans for. Returns: List of span data dictionaries. """ if not self._client: return [] span_key = f"trace:{trace_id}:spans" data_result = await self._client.hgetall(span_key) # type: ignore[misc] data: dict[str, str] = data_result if isinstance(data_result, dict) else {} if not data: return [] spans = [] for _span_id, span_json in data.items(): try: span_data = json.loads(span_json) spans.append(span_data) except json.JSONDecodeError: continue return spans async def list_traces( self, limit: int = 100, name_filter: str | None = None, ) -> list[dict[str, Any]]: """List recent traces. Args: limit: Maximum number of traces to return. name_filter: Optional filter by trace name substring. Returns: List of trace data dictionaries. """ if not self._client: return [] # Scan for trace keys traces = [] cursor = 0 while True: cursor, keys = await self._client.scan( cursor=cursor, match="trace:*", count=100, ) for key in keys: # Skip span keys if ":spans" in key: continue trace_id = key.replace("trace:", "") trace_data = await self.get_trace(trace_id) if trace_data: # Apply name filter if name_filter: if name_filter.lower() not in trace_data.get("name", "").lower(): continue traces.append(trace_data) if len(traces) >= limit: break if cursor == 0 or len(traces) >= limit: break # Sort by started_at descending (most recent first) traces.sort(key=lambda x: x.get("started_at", 0) or 0, reverse=True) return traces[:limit] async def get_stream_length(self) -> int: """Get the length of the trace stream. Returns: Number of events in the stream. """ if not self._client: return 0 return int(await self._client.xlen(self._stream_name)) async def trim_stream(self, max_length: int = 10000) -> int: """Trim the trace stream to a maximum length. Args: max_length: Maximum number of events to keep. Returns: Number of events trimmed. """ if not self._client: return 0 current_length = await self.get_stream_length() if current_length <= max_length: return 0 # Trim using XTRIM MAXLEN await self._client.xtrim(self._stream_name, maxlen=max_length, approximate=True) return current_length - max_length