"""AgentCoordinator - Streams-based coordination for distributed agents.
This module provides real-time coordination between distributed agent instances
using Redis Streams with consumer groups.
Key Features:
- Real-time handoff notifications
- Tool result broadcasting to multiple consumers
- State synchronization across replicas
- Crash recovery via XCLAIM for pending messages
- Consumer group support for work distribution
"""

from __future__ import annotations

import json
import logging
import time
from collections.abc import AsyncIterator
from enum import Enum
from typing import TYPE_CHECKING, Any
from uuid import uuid4

from redis import asyncio as aioredis
from redis.exceptions import ResponseError

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from redis.asyncio import Redis


class EventType(str, Enum):
    """Standard event types for agent coordination."""

    HANDOFF_READY = "handoff_ready"
    TOOL_RESULT = "tool_result"
    STATE_CHANGED = "state_changed"
    AGENT_STARTED = "agent_started"
    AGENT_COMPLETED = "agent_completed"
    ERROR = "error"


class AgentCoordinator:
    """
    Streams-based coordination for distributed agents.

    Enables:
    - Real-time handoff notifications (low latency vs polling)
    - Tool result broadcasting to multiple consumers
    - State synchronization across replicas
    - Crash recovery via pending message claiming

    Example:
        >>> coordinator = AgentCoordinator(
        ...     redis_url="redis://localhost:6379",
        ...     stream_name="agent_events",
        ...     consumer_group="workers",
        ... )
        >>> await coordinator.initialize()
        >>> await coordinator.publish_handoff_ready(
        ...     from_agent="research",
        ...     to_agent="analysis",
        ...     session_id="sess_123",
        ...     context={"data": "value"},
        ... )
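
        Consumers can then receive these events with ``subscribe()``
        (illustrative sketch; ``handle`` is a placeholder for application
        logic):

        >>> async for event in coordinator.subscribe(
        ...     event_types=[EventType.HANDOFF_READY.value],
        ...     max_events=1,
        ... ):
        ...     await handle(event)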
"""

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        stream_name: str = "agent_events",
        consumer_group: str | None = None,
        consumer_name: str | None = None,
    ) -> None:
        """
        Initialize AgentCoordinator.

        Args:
            redis_url: Redis connection URL
            stream_name: Name of the Redis Stream
            consumer_group: Consumer group name (required for subscribing)
            consumer_name: Consumer name within group (auto-generated if not provided)
        """
        self._redis_url = redis_url
        self._stream_name = stream_name
        self._consumer_group = consumer_group
        self._consumer_name = consumer_name or (
            f"consumer_{uuid4().hex[:8]}" if consumer_group else None
        )
        self._client: Redis | None = None
        self._initialized = False

    async def initialize(self) -> None:
        """
        Initialize the coordinator and create the consumer group if needed.

        Idempotent; other methods call this lazily if it has not been
        called explicitly.
        """
        if self._initialized:
            return
        self._client = aioredis.from_url(self._redis_url, decode_responses=True)
        # Create consumer group if specified
        if self._consumer_group:
            try:
                await self._client.xgroup_create(
                    self._stream_name,
                    self._consumer_group,
                    id="0",
                    mkstream=True,
                )
            except ResponseError as exc:
                if "BUSYGROUP" not in str(exc):
                    raise
                # Group already exists
        self._initialized = True

    async def _get_client(self) -> Redis:
        """Get Redis client, ensuring initialization."""
        if not self._initialized or self._client is None:
            await self.initialize()
        return self._client  # type: ignore[return-value]

    async def _publish(self, event_data: dict[str, Any]) -> str:
        """
        Publish event to stream.

        Args:
            event_data: Event data dictionary

        Returns:
            Message ID
        """
        client = await self._get_client()
        # Serialize complex values to JSON strings. Drop keys whose value is
        # None, since XADD rejects NoneType fields.
        serialized = {}
        for key, value in event_data.items():
            if value is None:
                continue
            if isinstance(value, (dict, list)):
                serialized[key] = json.dumps(value)
            elif isinstance(value, (int, float)):
                serialized[key] = str(value)
            else:
                serialized[key] = value
        result = await client.xadd(self._stream_name, serialized)  # type: ignore[arg-type]
        return str(result)

    async def publish_handoff_ready(
        self,
        from_agent: str,
        to_agent: str,
        session_id: str,
        context: dict[str, Any],
    ) -> str:
        """
        Notify target agent that handoff is ready.

        Args:
            from_agent: Name of agent initiating handoff
            to_agent: Name of target agent
            session_id: Session identifier
            context: Handoff context data

        Returns:
            Message ID for tracking
        """
        return await self._publish(
            {
                "type": EventType.HANDOFF_READY.value,
                "from_agent": from_agent,
                "to_agent": to_agent,
                "session_id": session_id,
                "context": context,
                "timestamp": time.time(),
            }
        )

    async def publish_tool_result(
        self,
        tool_name: str,
        session_id: str,
        result: Any,
        execution_time_ms: float,
    ) -> str:
        """
        Broadcast tool completion to all interested consumers.

        Args:
            tool_name: Name of the tool
            session_id: Session identifier
            result: Tool execution result
            execution_time_ms: Execution time in milliseconds

        Returns:
            Message ID
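
        Example (illustrative values):
            >>> await coordinator.publish_tool_result(
            ...     tool_name="web_search",
            ...     session_id="sess_123",
            ...     result={"hits": 3},
            ...     execution_time_ms=42.0,
            ... )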
"""
return await self._publish(
{
"type": EventType.TOOL_RESULT.value,
"tool_name": tool_name,
"session_id": session_id,
"result": result,
"execution_time_ms": execution_time_ms,
"timestamp": time.time(),
}
)

    async def publish_state_changed(
        self,
        session_id: str,
        changes: dict[str, Any],
    ) -> str:
        """
        Notify about session state changes.

        Args:
            session_id: Session identifier
            changes: Dictionary of changed fields

        Returns:
            Message ID
        """
        return await self._publish(
            {
                "type": EventType.STATE_CHANGED.value,
                "session_id": session_id,
                "changes": changes,
                "timestamp": time.time(),
            }
        )

    async def publish_agent_started(
        self,
        agent_name: str,
        session_id: str,
        input_summary: str,
    ) -> str:
        """
        Notify that an agent started processing.

        Args:
            agent_name: Name of the agent
            session_id: Session identifier
            input_summary: Summary of input

        Returns:
            Message ID
        """
        return await self._publish(
            {
                "type": EventType.AGENT_STARTED.value,
                "agent_name": agent_name,
                "session_id": session_id,
                "input_summary": input_summary,
                "timestamp": time.time(),
            }
        )

    async def publish_agent_completed(
        self,
        agent_name: str,
        session_id: str,
        output_summary: str,
        duration_ms: float,
        tokens_used: int,
    ) -> str:
        """
        Notify that an agent completed processing.

        Args:
            agent_name: Name of the agent
            session_id: Session identifier
            output_summary: Summary of output
            duration_ms: Processing duration in milliseconds
            tokens_used: Number of tokens used

        Returns:
            Message ID
        """
        return await self._publish(
            {
                "type": EventType.AGENT_COMPLETED.value,
                "agent_name": agent_name,
                "session_id": session_id,
                "output_summary": output_summary,
                "duration_ms": duration_ms,
                "tokens_used": tokens_used,
                "timestamp": time.time(),
            }
        )

    async def publish_error(
        self,
        session_id: str,
        error_type: str,
        error_message: str,
        agent_name: str | None = None,
    ) -> str:
        """
        Publish error event.

        Args:
            session_id: Session identifier
            error_type: Type of error
            error_message: Error message
            agent_name: Optional agent name if error is agent-specific

        Returns:
            Message ID
        """
        return await self._publish(
            {
                "type": EventType.ERROR.value,
                "session_id": session_id,
                "error_type": error_type,
                "error_message": error_message,
                "agent_name": agent_name,
                "timestamp": time.time(),
            }
        )

    def _parse_event(self, data: dict[str, str]) -> dict[str, Any]:
        """Parse event data, deserializing JSON fields."""
        result = {}
        for key, value in data.items():
            # Try to parse JSON strings
            if isinstance(value, str):
                try:
                    result[key] = json.loads(value)
                except json.JSONDecodeError:
                    result[key] = value
            else:
                result[key] = value
        return result

    async def subscribe(
        self,
        event_types: list[str] | None = None,
        timeout_ms: int = 5000,
        max_events: int | None = None,
    ) -> AsyncIterator[dict[str, Any]]:
        """
        Subscribe to coordination events.

        Args:
            event_types: Filter to specific types (None = all)
            timeout_ms: Block timeout in milliseconds
            max_events: Maximum events to yield (None = unlimited)

        Yields:
            Event dictionaries with automatic acknowledgment
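
        Example (a minimal sketch; ``handle`` is a placeholder callback):
            >>> async for event in coordinator.subscribe(
            ...     event_types=[EventType.TOOL_RESULT.value],
            ...     max_events=10,
            ... ):
            ...     await handle(event)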
"""
if not self._consumer_group:
raise ValueError("Consumer group required for subscription")
client = await self._get_client()
events_yielded = 0
while True:
if max_events and events_yielded >= max_events:
break
messages = await client.xreadgroup(
self._consumer_group,
self._consumer_name or "",
{self._stream_name: ">"},
count=10,
block=timeout_ms,
)
if not messages:
if max_events:
break
continue
for _stream_name, events in messages:
for msg_id, data in events:
event = self._parse_event(data)
event["_msg_id"] = msg_id
# Filter by type if specified
if event_types and event.get("type") not in event_types:
await client.xack(self._stream_name, self._consumer_group, msg_id)
continue
yield event
# Acknowledge after yielding
await client.xack(self._stream_name, self._consumer_group, msg_id)
events_yielded += 1
if max_events and events_yielded >= max_events:
return

    async def claim_abandoned_messages(
        self,
        min_idle_ms: int = 300000,  # 5 minutes
    ) -> list[dict[str, Any]]:
        """
        Claim messages from crashed consumers.

        Call periodically to recover from worker failures.

        Args:
            min_idle_ms: Minimum idle time to consider message abandoned

        Returns:
            List of claimed event dictionaries
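
        Example (sketch; assumes an initialized coordinator with a group):
            >>> recovered = await coordinator.claim_abandoned_messages(
            ...     min_idle_ms=60_000,
            ... )
            >>> for event in recovered:
            ...     print(event.get("type"))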
"""
if not self._consumer_group:
raise ValueError("Consumer group required for claiming")
client = await self._get_client()
# Get pending messages
pending = await client.xpending_range(
self._stream_name,
self._consumer_group,
min="-",
max="+",
count=100,
)
# Filter to old messages only
old_messages = [p for p in pending if p.get("time_since_delivered", 0) > min_idle_ms]
if not old_messages:
return []
# Claim them for this consumer
msg_ids = [p["message_id"] for p in old_messages]
claimed = await client.xclaim(
self._stream_name,
self._consumer_group,
self._consumer_name or "",
min_idle_time=min_idle_ms,
message_ids=msg_ids,
)
return [self._parse_event(data) for _, data in claimed]

    async def get_stream_info(self) -> dict[str, Any]:
        """
        Get stream statistics.

        Returns:
            Dictionary with stream length, groups, etc.
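
        Example (sketch):
            >>> info = await coordinator.get_stream_info()
            >>> print(info["length"], len(info["groups"]))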
"""
client = await self._get_client()
try:
stream_info = await client.xinfo_stream(self._stream_name)
groups = await client.xinfo_groups(self._stream_name)
except Exception as exc:
logger.debug("get_stream_info failed (stream may not exist): %s", exc)
return {"length": 0, "groups": []}
return {
"length": stream_info.get("length", 0),
"first_entry": stream_info.get("first-entry"),
"last_entry": stream_info.get("last-entry"),
"groups": groups,
}

    async def trim_stream(
        self,
        max_length: int,
        approximate: bool = True,
    ) -> int:
        """
        Trim stream to maximum length.

        Args:
            max_length: Maximum entries to keep
            approximate: Use approximate trimming (more efficient)

        Returns:
            Number of entries trimmed
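
        Example (sketch):
            >>> removed = await coordinator.trim_stream(max_length=10_000)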
"""
client = await self._get_client()
result = await client.xtrim(
self._stream_name,
maxlen=max_length,
approximate=approximate,
)
return int(result)

    async def close(self) -> None:
        """Close the Redis connection."""
        if self._client:
            await self._client.aclose()
        self._client = None
        self._initialized = False