Middleware Composition with LangChain Agents#
This notebook demonstrates how to compose multiple Redis middleware together and use them with LangChain agents using the standard create_agent pattern.
Key Features#
Combine multiple middleware: Stack caching, memory, and routing
MiddlewareStack: Compose middleware into a single unit with automatic request sanitization
Connection sharing: Share Redis connections with checkpointers
Factory functions: Quick setup with create_caching_stack
Responses API safety: Automatic stripping of provider-specific IDs from cached content
Request Sanitization#
When using the Responses API (ChatOpenAI(use_responses_api=True)), AIMessage content
is a list of blocks with embedded provider IDs (e.g., rs_...). The MiddlewareStack
automatically sanitizes requests before they reach the LLM handler, stripping these IDs
to prevent duplicate ID errors across turns.
Prerequisites#
Redis 8.0+ or Redis Stack
OpenAI API key
Note on Async Usage#
The Redis middleware uses async methods internally. When using with create_agent, you must use await agent.ainvoke() rather than agent.invoke().
Setup#
Install required packages and set API keys.
%%capture --no-stderr
# When running via docker-compose, the local library is already installed via editable mount.
# Only install from PyPI if not already available.
# NOTE: %%capture and %pip are IPython magics — this cell only runs inside a notebook kernel.
try:
    import langgraph.middleware.redis
    print("langgraph-checkpoint-redis already installed")
except ImportError:
    %pip install -U langgraph-checkpoint-redis
%pip install -U langchain langchain-openai sentence-transformers
import getpass
import os
def _set_env(var: str):
    """Interactively prompt for *var* and export it, unless it is already set."""
    if os.environ.get(var):
        return
    os.environ[var] = getpass.getpass(f"{var}: ")
# Prompt once for the required secret, then record the Redis endpoint.
# The default hostname "redis" matches the docker-compose service name.
_set_env("OPENAI_API_KEY")
REDIS_URL = os.environ.get("REDIS_URL", "redis://redis:6379")
Two-Model Setup and Tools#
We demonstrate composition patterns with both API modes.
import uuid
from langchain_openai import ChatOpenAI

# Default mode: content is a plain string
model_default = ChatOpenAI(model="gpt-4o-mini")
# Responses API mode: content is a list of blocks with embedded IDs
# Used by Azure OpenAI and advanced features (reasoning, annotations)
model_responses_api = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)
print("Models created:")
print("- model_default: Chat Completions (string content)")
print("- model_responses_api: Responses API (list-of-blocks content)")
Models created:
- model_default: Chat Completions (string content)
- model_responses_api: Responses API (list-of-blocks content)
def format_content(content, max_len=200):
    """Extract readable text from AI message content (handles both API modes)."""
    if isinstance(content, list):
        # Responses API mode: gather text from each block (dicts or raw strings).
        pieces = []
        for item in content:
            if isinstance(item, str):
                pieces.append(item)
            elif isinstance(item, dict):
                pieces.append(item.get("text", ""))
        text = " ".join(pieces)
    elif isinstance(content, str):
        text = content
    else:
        text = str(content)
    # Truncate long responses for display; max_len=0/None disables truncation.
    if max_len and len(text) > max_len:
        return text[:max_len] + "..."
    return text
def inspect_response(result, label=""):
    """Print the structure and (truncated) text of the final AI message in *result*."""
    message = result["messages"][-1]
    content = message.content
    print(f"\n--- {label} ---")
    print(f"Content type: {type(content).__name__}")
    if isinstance(content, list):
        # Responses API mode: show each block's type and whether it carries a provider ID.
        print(f"Number of content blocks: {len(content)}")
        for index, block in enumerate(content):
            if isinstance(block, dict):
                print(f" Block {index}: type={block.get('type')}, has_id={'id' in block}")
    print(f"Response: {format_content(content)}")
    # Middleware marks cache hits via additional_kwargs — TODO confirm key semantics.
    cached = message.additional_kwargs.get("cached", False)
    print(f"Cached: {cached}")
import ast
import operator as op
import time
from langchain_core.tools import tool
# Safe math evaluator - no arbitrary code execution.
# Maps whitelisted AST operator node types to their runtime implementations.
SAFE_OPS = {
    ast.Add: op.add, ast.Sub: op.sub, ast.Mult: op.mul,
    ast.Div: op.truediv, ast.Pow: op.pow, ast.USub: op.neg,
}

def _eval_node(node):
    """Recursively evaluate a whitelisted arithmetic AST node.

    Accepts only real-number literals and the operators in SAFE_OPS;
    anything else (names, calls, strings, booleans, ...) raises ValueError.
    """
    if isinstance(node, ast.Constant):
        # Restrict literals to real numbers: the original accepted any constant,
        # so expressions like "'a' * 3" slipped through a "math" evaluator.
        # bool is excluded explicitly because it is a subclass of int.
        if isinstance(node.value, (int, float)) and not isinstance(node.value, bool):
            return node.value
        raise ValueError("Unsupported expression")
    if isinstance(node, ast.BinOp) and type(node.op) in SAFE_OPS:
        return SAFE_OPS[type(node.op)](_eval_node(node.left), _eval_node(node.right))
    if isinstance(node, ast.UnaryOp) and type(node.op) in SAFE_OPS:
        return SAFE_OPS[type(node.op)](_eval_node(node.operand))
    raise ValueError("Unsupported expression")

def safe_eval(expr: str) -> float:
    """Evaluate a basic arithmetic expression string without eval/exec.

    Supports +, -, *, /, ** and unary minus over numeric literals.

    Raises:
        ValueError: if the expression uses anything outside the whitelist.
        SyntaxError: if *expr* is not valid Python expression syntax.
    """
    return _eval_node(ast.parse(expr, mode='eval').body)
# Track executions so cache behavior is observable: a cache hit leaves these
# counters unchanged because the underlying tool body never runs.
tool_calls = {"search": 0, "calculate": 0}

# NOTE: each tool's docstring below is the description sent to the model by the
# @tool decorator — do not edit it casually, it changes runtime behavior.
@tool
def search(query: str) -> str:
    """Search the web for information."""
    tool_calls["search"] += 1
    time.sleep(0.5)  # Simulate API call
    return f"Search results for '{query}': Found relevant information."

@tool
def calculate(expression: str) -> str:
    """Evaluate a mathematical expression."""
    tool_calls["calculate"] += 1
    try:
        result = safe_eval(expression)
        return f"{expression} = {result}"
    except Exception as e:
        # Surface evaluator failures to the model as a plain string rather than raising.
        return f"Error: {str(e)}"

tools = [search, calculate]
print("Tools defined: search, calculate")
Tools defined: search, calculate
Using Multiple Middleware with create_agent#
You can pass multiple middleware directly to create_agent. They are applied in order.
from langchain.agents import create_agent
from langgraph.middleware.redis import (
    SemanticCacheConfig,
    SemanticCacheMiddleware,
    ToolCacheConfig,
    ToolResultCacheMiddleware,
)

# Define which tools are deterministic (same input = same output)
DETERMINISTIC_TOOLS = ["search", "calculate"]

# Unique cache names — the random suffix avoids collisions across notebook runs
llm_cache_name = f"composition_llm_cache_{uuid.uuid4().hex[:8]}"
tool_cache_name = f"composition_tool_cache_{uuid.uuid4().hex[:8]}"

# Create semantic cache for LLM responses
semantic_cache = SemanticCacheMiddleware(
    SemanticCacheConfig(
        redis_url=REDIS_URL,
        name=llm_cache_name,
        ttl_seconds=3600,  # cache LLM responses for 1 hour
        deterministic_tools=DETERMINISTIC_TOOLS,
    )
)

# Create tool cache for tool results
tool_cache = ToolResultCacheMiddleware(
    ToolCacheConfig(
        redis_url=REDIS_URL,
        name=tool_cache_name,
        cacheable_tools=DETERMINISTIC_TOOLS,
        ttl_seconds=1800,  # cache tool results for 30 minutes
    )
)

print("Created coordinated middleware:")
print(f"- Deterministic tools: {DETERMINISTIC_TOOLS}")
print("- SemanticCacheMiddleware: caches LLM responses")
print("- ToolResultCacheMiddleware: caches tool results")
print("\nBoth middlewares are aware of which tools are safe to cache!")
Created coordinated middleware:
- Deterministic tools: ['search', 'calculate']
- SemanticCacheMiddleware: caches LLM responses
- ToolResultCacheMiddleware: caches tool results
Both middlewares are aware of which tools are safe to cache!
# Create agent with multiple middleware; they are applied in list order.
agent = create_agent(
    model=model_default,
    tools=tools,
    middleware=[semantic_cache, tool_cache],  # Multiple middleware!
)
print("Agent created with both SemanticCache and ToolCache middleware!")
Agent created with both SemanticCache and ToolCache middleware!
from langchain_core.messages import HumanMessage

# Reset counters so any increase below can be attributed to these two calls
tool_calls = {"search": 0, "calculate": 0}

print("Test 1: Search query")
print("=" * 50)
# The Redis middleware is async internally, so the agent must be awaited via ainvoke()
result1 = await agent.ainvoke({"messages": [HumanMessage(content="Search for Python tutorials")]})
print(f"Response: {result1['messages'][-1].content[:100]}...")
print(f"Tool calls: {tool_calls}")

print("\nTest 2: Similar search query (should hit cache)")
print("=" * 50)
# Semantically similar wording — expected to hit the semantic cache instead of the LLM
result2 = await agent.ainvoke({"messages": [HumanMessage(content="Find Python tutorials online")]})
print(f"Response: {result2['messages'][-1].content[:100]}...")
print(f"Tool calls: {tool_calls}")
print("Note: tool_calls should not increase if cache hit!")
Test 1: Search query
==================================================
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
This vectorizer has no async embed method. Falling back to sync.
This vectorizer has no async embed method. Falling back to sync.
This vectorizer has no async embed method. Falling back to sync.
This vectorizer has no async embed method. Falling back to sync.
Response: I found some relevant information on Python tutorials. Here are a few resources that might help you:...
Tool calls: {'search': 1, 'calculate': 0}
Test 2: Similar search query (should hit cache)
==================================================
Response: I found some relevant information on Python tutorials. Here are a few resources that might help you:...
Tool calls: {'search': 1, 'calculate': 0}
Note: tool_calls should not increase if cache hit!
Using MiddlewareStack#
The MiddlewareStack class lets you compose multiple middleware into a single unit that
can be passed to create_agent.
Important: MiddlewareStack automatically sanitizes requests before they reach the
LLM handler. This includes stripping provider-specific IDs from Responses API content
blocks, preventing duplicate ID errors in multi-turn conversations.
from langgraph.middleware.redis import (
    ConversationMemoryConfig,
    ConversationMemoryMiddleware,
    MiddlewareStack,
)

# Unique index names so repeated runs don't collide in Redis
stack_cache_name = f"stack_llm_cache_{uuid.uuid4().hex[:8]}"
stack_memory_name = f"stack_memory_{uuid.uuid4().hex[:8]}"

# Create a stack with cache + memory
stack = MiddlewareStack(
    [
        SemanticCacheMiddleware(
            SemanticCacheConfig(
                redis_url=REDIS_URL,
                name=stack_cache_name,
                ttl_seconds=3600,
            )
        ),
        ConversationMemoryMiddleware(
            ConversationMemoryConfig(
                redis_url=REDIS_URL,
                name=stack_memory_name,
                session_tag="stack_demo",
                top_k=3,  # presumably the number of memory entries retrieved per turn — TODO confirm
            )
        ),
    ]
)
print("MiddlewareStack created with:")
print("- SemanticCacheMiddleware")
print("- ConversationMemoryMiddleware")
print("\nThe stack automatically sanitizes AIMessages before LLM calls.")
MiddlewareStack created with:
- SemanticCacheMiddleware
- ConversationMemoryMiddleware
The stack automatically sanitizes AIMessages before LLM calls.
# Create agent with the stack (stack is also an AgentMiddleware!)
agent_with_stack = create_agent(
model=model_default,
tools=tools,
middleware=[stack], # Pass the stack as a single middleware
)
print("Agent created with MiddlewareStack!")
# Test it
result = await agent_with_stack.ainvoke({"messages": [HumanMessage(content="Hi, I'm testing the middleware stack!")]})
print(f"Response: {result['messages'][-1].content}")
Agent created with MiddlewareStack!
This vectorizer has no async embed method. Falling back to sync.
MPNetModel LOAD REPORT from: sentence-transformers/all-mpnet-base-v2
Key | Status | |
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED | |
Notes:
- UNEXPECTED :can be ignored when loading from different task/architecture; not ok if you expect identical arch.
This vectorizer has no async embed method. Falling back to sync.
Response: Hello! It sounds like you're working on something interesting. How can I assist you with your middleware stack testing?
MiddlewareStack with Responses API Mode#
When using the Responses API, the MiddlewareStack plays a critical safety role.
It sanitizes all AIMessages in the conversation state before they reach the LLM handler,
stripping provider-specific IDs (rs_..., msg_...) that would cause duplicate ID errors.
This is especially important in multi-turn conversations where state accumulates.
# Create a stack specifically for Responses API testing
resp_cache_name = f"resp_stack_cache_{uuid.uuid4().hex[:8]}"
resp_memory_name = f"resp_stack_memory_{uuid.uuid4().hex[:8]}"

responses_stack = MiddlewareStack(
    [
        SemanticCacheMiddleware(
            SemanticCacheConfig(
                redis_url=REDIS_URL,
                name=resp_cache_name,
                ttl_seconds=3600,
            )
        ),
        ConversationMemoryMiddleware(
            ConversationMemoryConfig(
                redis_url=REDIS_URL,
                name=resp_memory_name,
                session_tag="responses_demo",
                top_k=3,
            )
        ),
    ]
)

# model_responses_api produces list-of-blocks content with embedded provider IDs,
# which is exactly what the stack's sanitization is meant to handle.
agent_responses_stack = create_agent(
    model=model_responses_api,
    tools=tools,
    middleware=[responses_stack],
)
print("Agent created with MiddlewareStack + Responses API!")

# Multi-turn demo showing sanitization
print("\nTurn 1: 'Hello, I like Python programming'")
print("=" * 50)
result_r1 = await agent_responses_stack.ainvoke(
    {"messages": [HumanMessage(content="Hello, I like Python programming")]}
)
inspect_response(result_r1, label="Turn 1")

print("\nTurn 2: 'What language did I mention?'")
print("=" * 50)
result_r2 = await agent_responses_stack.ainvoke(
    {"messages": [HumanMessage(content="What language did I mention?")]}
)
inspect_response(result_r2, label="Turn 2")

# Verify no duplicate rs_ IDs in any response: collect every content-block id
# across both turns and assert each appears at most once.
all_ids = set()
for label, result in [("Turn 1", result_r1), ("Turn 2", result_r2)]:
    ai_msg = result["messages"][-1]
    if isinstance(ai_msg.content, list):
        for block in ai_msg.content:
            if isinstance(block, dict) and "id" in block:
                block_id = block["id"]
                assert block_id not in all_ids, f"Duplicate ID in {label}: {block_id}"
                all_ids.add(block_id)
print("\nNo duplicate content block IDs across turns!")
print("MiddlewareStack sanitized all requests before they reached the LLM.")
Agent created with MiddlewareStack + Responses API!
Turn 1: 'Hello, I like Python programming'
==================================================
This vectorizer has no async embed method. Falling back to sync.
MPNetModel LOAD REPORT from: sentence-transformers/all-mpnet-base-v2
Key | Status | |
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED | |
Notes:
- UNEXPECTED :can be ignored when loading from different task/architecture; not ok if you expect identical arch.
This vectorizer has no async embed method. Falling back to sync.
This vectorizer has no async embed method. Falling back to sync.
--- Turn 1 ---
Content type: list
Number of content blocks: 1
Block 0: type=text, has_id=True
Response: That's great to hear! Python is a versatile and powerful programming language. What specific areas of Python are you interested in? Are you working on any projects or learning particular frameworks or...
Cached: False
Turn 2: 'What language did I mention?'
==================================================
This vectorizer has no async embed method. Falling back to sync.
--- Turn 2 ---
Content type: list
Number of content blocks: 1
Block 0: type=text, has_id=True
Response: You mentioned that you like Python programming.
Cached: False
No duplicate content block IDs across turns!
MiddlewareStack sanitized all requests before they reached the LLM.
Multi-Turn with Checkpointer and Responses API#
This section demonstrates the most complex real-world scenario:
IntegratedRedisMiddleware shares Redis connections with AsyncRedisSaver
The Responses API produces content blocks with provider IDs
State accumulates across turns via the checkpointer
The middleware stack sanitizes all AIMessages before each LLM call
This mirrors the exact customer scenario that exposed the duplicate ID bug (PR #170).
from langgraph.checkpoint.redis.aio import AsyncRedisSaver
from langgraph.middleware.redis import IntegratedRedisMiddleware

# Create async checkpointer; asetup() performs its async initialization
# (presumably Redis index creation — see AsyncRedisSaver docs).
async_checkpointer = AsyncRedisSaver(redis_url=REDIS_URL)
await async_checkpointer.asetup()

# Create middleware stack that shares connection with checkpointer
integrated_cache_name = f"integrated_cache_{uuid.uuid4().hex[:8]}"
integrated_stack = IntegratedRedisMiddleware.from_saver(
    async_checkpointer,
    configs=[
        SemanticCacheConfig(name=integrated_cache_name, ttl_seconds=3600),
    ],
)
print("Created IntegratedRedisMiddleware from AsyncRedisSaver!")
# _middlewares is a private attribute; inspected here only for demonstration
print(f"Number of middleware: {len(integrated_stack._middlewares)}")
Created IntegratedRedisMiddleware from AsyncRedisSaver!
Number of middleware: 1
# Create agent with both checkpointer and middleware
integrated_agent = create_agent(
    model=model_responses_api,
    tools=tools,
    checkpointer=async_checkpointer,
    middleware=[integrated_stack],
)

# Multi-turn conversation with state accumulation: all three turns share one
# thread_id, so the checkpointer carries prior messages into each new call.
thread_id = f"integrated_{uuid.uuid4().hex[:8]}"
config = {"configurable": {"thread_id": thread_id}}

# Turn 1
print("Turn 1: 'What is the population of Tokyo?'")
print("=" * 50)
result_t1 = await integrated_agent.ainvoke(
    {"messages": [HumanMessage(content="What is the population of Tokyo?")]},
    config=config,
)
inspect_response(result_t1, label="Turn 1")

# Turn 2 - follow-up in same thread (state accumulates)
print("\nTurn 2: 'And what about New York?'")
print("=" * 50)
result_t2 = await integrated_agent.ainvoke(
    {"messages": [HumanMessage(content="And what about New York?")]},
    config=config,
)
inspect_response(result_t2, label="Turn 2")

# Turn 3 - another follow-up (state has 2 previous AI responses)
print("\nTurn 3: 'Which one is larger?'")
print("=" * 50)
result_t3 = await integrated_agent.ainvoke(
    {"messages": [HumanMessage(content="Which one is larger?")]},
    config=config,
)
inspect_response(result_t3, label="Turn 3")

print("\nAll 3 turns completed successfully with Responses API + Checkpointer!")
print("The MiddlewareStack sanitized AIMessages before each LLM call,")
print("preventing duplicate provider IDs from accumulating in state.")
Turn 1: 'What is the population of Tokyo?'
==================================================
This vectorizer has no async embed method. Falling back to sync.
This vectorizer has no async embed method. Falling back to sync.
--- Turn 1 ---
Content type: list
Number of content blocks: 1
Block 0: type=text, has_id=True
Response: As of 2023, the population of Tokyo is approximately 14 million within the city proper, while the greater Tokyo metropolitan area has a population of around 37 million, making it one of the most popul...
Cached: False
Turn 2: 'And what about New York?'
==================================================
This vectorizer has no async embed method. Falling back to sync.
--- Turn 2 ---
Content type: list
Number of content blocks: 1
Block 0: type=text, has_id=True
Response: As of 2023, the population of New York City is approximately 8.5 million people. The larger metropolitan area has a population of around 20 million.
Cached: False
Turn 3: 'Which one is larger?'
==================================================
This vectorizer has no async embed method. Falling back to sync.
--- Turn 3 ---
Content type: list
Number of content blocks: 1
Block 0: type=text, has_id=True
Response: Tokyo is larger, with a population of about 14 million in the city proper and approximately 37 million in the greater metropolitan area. In contrast, New York City has about 8.5 million people, with a...
Cached: False
All 3 turns completed successfully with Responses API + Checkpointer!
The MiddlewareStack sanitized AIMessages before each LLM call,
preventing duplicate provider IDs from accumulating in state.
Factory Functions#
For common patterns, use factory functions like create_caching_stack and from_configs.
from langgraph.middleware.redis import create_caching_stack

# Unique names (random suffix avoids collisions across runs)
factory_llm_name = f"factory_llm_cache_{uuid.uuid4().hex[:8]}"
factory_tool_name = f"factory_tool_cache_{uuid.uuid4().hex[:8]}"

# Quick setup for caching both LLM and tool results
caching_stack = create_caching_stack(
    redis_url=REDIS_URL,
    semantic_cache_name=factory_llm_name,
    semantic_cache_ttl=3600,
    tool_cache_name=factory_tool_name,
    tool_cache_ttl=1800,
    cacheable_tools=["search", "calculate"],
)
print("Created caching stack with create_caching_stack()")
# _middlewares is private; inspected here only to show the stack's composition
print(f"Number of middleware in stack: {len(caching_stack._middlewares)}")
Created caching stack with create_caching_stack()
Number of middleware in stack: 2
from langgraph.middleware.redis import from_configs

# Unique names (random suffix avoids collisions across runs)
custom_llm_name = f"custom_llm_cache_{uuid.uuid4().hex[:8]}"
custom_tool_name = f"custom_tool_cache_{uuid.uuid4().hex[:8]}"
custom_memory_name = f"custom_memory_{uuid.uuid4().hex[:8]}"

# Create stack from config objects; the single redis_url is shared by every config
custom_stack = from_configs(
    configs=[
        SemanticCacheConfig(
            name=custom_llm_name,
            distance_threshold=0.15,  # semantic-match strictness — TODO confirm units/scale
            ttl_seconds=3600,
        ),
        ToolCacheConfig(
            name=custom_tool_name,
            cacheable_tools=["search"],
            excluded_tools=["calculate"],
            ttl_seconds=600,
        ),
        ConversationMemoryConfig(
            name=custom_memory_name,
            session_tag="custom_session",
            top_k=5,
        ),
    ],
    redis_url=REDIS_URL,
)
print("Created custom stack with from_configs()")
print(f"Number of middleware in stack: {len(custom_stack._middlewares)}")
Created custom stack with from_configs()
Number of middleware in stack: 3
Summary#
Multiple middleware: Pass a list to create_agent(middleware=[...])
MiddlewareStack: Compose middleware into a single unit with automatic request sanitization
create_caching_stack(): Quick setup for LLM + tool caching
from_configs(): Create stack from config objects
IntegratedRedisMiddleware: Share connections with checkpointers
Responses API safety: MiddlewareStack strips provider IDs to prevent duplicates in multi-turn conversations
Cleanup#
# Close all middleware to release their Redis connections
await semantic_cache.aclose()
await tool_cache.aclose()
await stack.aclose()
await responses_stack.aclose()
await caching_stack.aclose()
await custom_stack.aclose()

# Close the async checkpointer
try:
    await async_checkpointer.aclose()
except Exception:
    # Best-effort cleanup — presumably some saver versions lack aclose(); verify
    pass

print("All middleware closed.")
print("Demo complete!")
All middleware closed.
Demo complete!