# Source code for redis_openai_agents.middleware.semantic_router
"""SemanticRouterMiddleware - short-circuit model calls by matched intent.
Routes the user input through a :class:`SemanticRouter`. If the input
matches a known route with a configured canned response, that response is
returned immediately and the LLM call is skipped. Otherwise, the request
is forwarded to the inner model.
"""
from __future__ import annotations
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any
from ._response import is_model_response, text_response
from ._utils import extract_user_text
from .base import ModelCallHandler, ModelRequest
if TYPE_CHECKING:
from ..semantic_router import RouteMatch, SemanticRouter
ResponseFactory = Callable[["RouteMatch"], Any]
class SemanticRouterMiddleware:
    """Short-circuit the LLM call when the input matches a known intent.

    Args:
        router: A :class:`SemanticRouter` configured with the intents to
            recognize.
        responses: Optional mapping from route name to canned response.
            Used when ``response_factory`` is not provided.
        response_factory: Optional callable that receives the
            :class:`RouteMatch` and returns the response. Takes
            precedence over ``responses`` when supplied.
        auto_wrap: When ``True``, a resolved response that is neither a
            string nor already a model response is stringified and
            wrapped in a text response before being returned. When
            ``False`` (the default), such responses are passed through
            unchanged.

    Either ``responses`` or ``response_factory`` must yield a value for
    the short-circuit to trigger. If neither produces a response (for
    instance, an unmapped route name), the request delegates to the
    inner model.
    """

    def __init__(
        self,
        router: SemanticRouter,
        *,
        responses: Mapping[str, Any] | None = None,
        response_factory: ResponseFactory | None = None,
        auto_wrap: bool = False,
    ) -> None:
        self._router = router
        # Copy so later mutation of the caller's mapping has no effect here.
        self._responses = dict(responses) if responses else {}
        self._response_factory = response_factory
        self._auto_wrap = auto_wrap

    async def awrap_model_call(self, request: ModelRequest, handler: ModelCallHandler) -> Any:
        """Return a canned response for a matched route, else call ``handler``.

        Delegates to the inner model when there is no user text in the
        request, no route matched, or no response is configured for the
        matched route.
        """
        statement = self._extract_statement(request)
        if not statement:
            return await handler(request)
        match = await self._router(statement)
        if match.name is None:
            return await handler(request)
        response = self._resolve_response(match)
        if response is _SENTINEL:
            return await handler(request)
        # Plain strings are auto-wrapped so the Runner receives a real
        # ModelResponse. Pre-built responses are passed through.
        if isinstance(response, str):
            return text_response(response)
        if not is_model_response(response) and self._auto_wrap:
            return text_response(str(response))
        return response

    @staticmethod
    def _extract_statement(request: ModelRequest) -> str:
        """Pick the most recent user text out of the request input."""
        return extract_user_text(request.input, fallback_to_last=True)

    def _resolve_response(self, match: RouteMatch) -> Any:
        """Map *match* to a canned response, or ``_SENTINEL`` when none applies.

        A configured ``response_factory`` takes precedence over the
        ``responses`` mapping. Factory exceptions are swallowed
        deliberately: a buggy factory degrades to a normal model call
        rather than crashing the request.
        """
        if self._response_factory is not None:
            try:
                return self._response_factory(match)
            except Exception:
                # Best-effort: fall back to the inner model on factory failure.
                return _SENTINEL
        if match.name in self._responses:
            return self._responses[match.name]
        return _SENTINEL
# Sentinel distinguishing "no canned response" from a user-supplied None.
_SENTINEL: Any = object()