"""Synchronous Redis store implementation."""
from __future__ import annotations
import asyncio
import json
import logging
import math
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Any, Iterable, Iterator, Optional, Sequence, cast
from langgraph.store.base import (
BaseStore,
GetOp,
IndexConfig,
ListNamespacesOp,
Op,
PutOp,
Result,
SearchOp,
TTLConfig,
)
from redis import Redis
from redis.cluster import RedisCluster as SyncRedisCluster
from redis.commands.search.query import Query
from redisvl.index import SearchIndex
from redisvl.query import FilterQuery, VectorQuery
from redisvl.redis.connection import RedisConnectionFactory
from redisvl.utils.token_escaper import TokenEscaper
from ulid import ULID
from langgraph.store.redis.aio import AsyncRedisStore
from langgraph.store.redis.base import (
REDIS_KEY_SEPARATOR,
STORE_PREFIX,
STORE_VECTOR_PREFIX,
BaseRedisStore,
RedisDocument,
_decode_ns,
_group_ops,
_namespace_to_text,
_row_to_item,
_row_to_search_item,
)
from .token_unescaper import TokenUnescaper
_token_escaper = TokenEscaper()
_token_unescaper = TokenUnescaper()
logger = logging.getLogger(__name__)
def _convert_redis_score_to_similarity(score: float, distance_type: str) -> float:
"""Convert Redis vector distance to similarity score."""
if distance_type == "cosine":
# Redis returns cosine distance (1 - cosine_similarity)
# Convert back to similarity
return 1.0 - score
elif distance_type == "l2":
# For L2, smaller distance means more similar
# Use a simple exponential decay
return math.exp(-score)
elif distance_type == "inner_product":
# For inner product, Redis already returns what we want
return score
return score
class RedisStore(BaseStore, BaseRedisStore[Redis, SearchIndex]):
"""Redis-backed store with optional vector search.
Provides synchronous operations for storing and retrieving data with optional
vector similarity search support.
Supports standard Redis URLs (redis://), SSL (rediss://), and
Sentinel URLs (redis+sentinel://host:26379/service_name/db).
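
    Example (a minimal sketch; assumes a local Redis with the RediSearch
    module available):

        from langgraph.store.redis import RedisStore

        with RedisStore.from_conn_string("redis://localhost:6379") as store:
            store.setup()  # create indices once before first use
            store.put(("users", "42"), "prefs", {"theme": "dark"})
            item = store.get(("users", "42"), "prefs")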
"""
# Enable TTL support
supports_ttl = True
ttl_config: Optional[TTLConfig] = None
def __init__(
self,
conn: Redis,
*,
index: Optional[IndexConfig] = None,
ttl: Optional[TTLConfig] = None,
cluster_mode: Optional[bool] = None,
store_prefix: str = STORE_PREFIX,
vector_prefix: str = STORE_VECTOR_PREFIX,
) -> None:
BaseStore.__init__(self)
BaseRedisStore.__init__(
self,
conn,
index=index,
ttl=ttl,
cluster_mode=cluster_mode,
store_prefix=store_prefix,
vector_prefix=vector_prefix,
)
# Set client info for monitoring (sync store can call this safely)
self.set_client_info()
# Detection will happen in setup()
@classmethod
@contextmanager
def from_conn_string(
cls,
conn_string: str,
*,
index: Optional[IndexConfig] = None,
ttl: Optional[TTLConfig] = None,
store_prefix: str = STORE_PREFIX,
vector_prefix: str = STORE_VECTOR_PREFIX,
) -> Iterator[RedisStore]:
"""Create store from Redis connection string."""
client = None
try:
client = RedisConnectionFactory.get_redis_connection(conn_string)
store = cls(
client,
index=index,
ttl=ttl,
store_prefix=store_prefix,
vector_prefix=vector_prefix,
)
# Client info is set in __init__, but set it again here to ensure
# it's available even if called before setup()
store.set_client_info()
yield store
finally:
if client:
client.close()
client.connection_pool.disconnect()
def setup(self) -> None:
"""Initialize store indices."""
# Detect if we're connected to a Redis cluster
self._detect_cluster_mode()
self.store_index.create(overwrite=False)
if self.index_config:
self.vector_index.create(overwrite=False)
def batch(self, ops: Iterable[Op]) -> list[Result]:
"""Execute batch of operations."""
grouped_ops, num_ops = _group_ops(ops)
results: list[Result] = [None] * num_ops
if GetOp in grouped_ops:
self._batch_get_ops(
cast(list[tuple[int, GetOp]], grouped_ops[GetOp]), results
)
if PutOp in grouped_ops:
self._batch_put_ops(cast(list[tuple[int, PutOp]], grouped_ops[PutOp]))
if SearchOp in grouped_ops:
self._batch_search_ops(
cast(list[tuple[int, SearchOp]], grouped_ops[SearchOp]), results
)
if ListNamespacesOp in grouped_ops:
self._batch_list_namespaces_ops(
cast(
Sequence[tuple[int, ListNamespacesOp]],
grouped_ops[ListNamespacesOp],
),
results,
)
return results
def _detect_cluster_mode(self) -> None:
"""Detect if the Redis client is a cluster client by inspecting its class."""
        # If cluster_mode was passed explicitly, respect it
if self.cluster_mode is not None:
logger.info(
f"Redis cluster_mode explicitly set to {self.cluster_mode}, skipping detection."
)
return
if isinstance(self._redis, SyncRedisCluster):
self.cluster_mode = True
logger.info("Redis cluster client detected for RedisStore.")
else:
self.cluster_mode = False
logger.info("Redis standalone client detected for RedisStore.")
def _batch_list_namespaces_ops(
self,
list_ops: Sequence[tuple[int, ListNamespacesOp]],
results: list[Result],
) -> None:
"""Execute list namespaces operations in batch."""
for idx, op in list_ops:
# Construct base query for namespace search
base_query = "*" # Start with all documents
if op.match_conditions:
conditions = []
for condition in op.match_conditions:
if condition.match_type == "prefix":
prefix = _namespace_to_text(
condition.path, handle_wildcards=True
)
conditions.append(f"@prefix:{prefix}*")
elif condition.match_type == "suffix":
suffix = _namespace_to_text(
condition.path, handle_wildcards=True
)
conditions.append(f"@prefix:*{suffix}")
if conditions:
base_query = " ".join(conditions)
# Execute search with return_fields=["prefix"] to get just namespaces
query = FilterQuery(filter_expression=base_query, return_fields=["prefix"])
res = self.store_index.search(query)
# Extract unique namespaces
namespaces = set()
for doc in res.docs:
if hasattr(doc, "prefix"):
ns = tuple(_token_unescaper.unescape(doc.prefix).split("."))
# Apply max_depth if specified
if op.max_depth is not None:
ns = ns[: op.max_depth]
namespaces.add(ns)
# Sort and apply pagination
sorted_namespaces = sorted(namespaces)
if op.limit or op.offset:
offset = op.offset or 0
limit = op.limit or 10
sorted_namespaces = sorted_namespaces[offset : offset + limit]
results[idx] = sorted_namespaces
def _batch_get_ops(
self,
get_ops: list[tuple[int, GetOp]],
results: list[Result],
) -> None:
"""Execute GET operations in batch."""
        # Track keys that need their TTL refreshed, grouped by op index
        refresh_keys_by_idx: dict[int, list[str]] = {}
for query, _, namespace, items in self._get_batch_GET_ops_queries(get_ops):
res = self.store_index.search(Query(query))
# Parse JSON from each document
            key_to_row = {}
            for doc in res.docs:
                data = json.loads(doc.json)
                key_to_row[data["key"]] = (data, doc.id)
for idx, key in items:
if key in key_to_row:
data, doc_id = key_to_row[key]
results[idx] = _row_to_item(
namespace, data, deserialize_fn=self._deserialize_value
)
                    # Map this result index back to its originating GetOp
                    # (idx indexes the overall operation list, not this group)
                    op = next((o for i, o in get_ops if i == idx), None)
                    if op is not None and getattr(op, "refresh_ttl", False):
                        refresh_keys_by_idx.setdefault(idx, []).append(doc_id)
                        # Also refresh the vector key for the same document
                        doc_uuid = doc_id.split(":")[-1]
                        refresh_keys_by_idx[idx].append(
                            f"{self.vector_prefix}{REDIS_KEY_SEPARATOR}{doc_uuid}"
                        )
# Now refresh TTLs for any keys that need it
if refresh_keys_by_idx and self.ttl_config:
# Get default TTL from config
ttl_minutes = None
if "default_ttl" in self.ttl_config:
ttl_minutes = self.ttl_config.get("default_ttl")
if ttl_minutes is not None:
ttl_seconds = int(ttl_minutes * 60)
if self.cluster_mode:
for keys_to_refresh in refresh_keys_by_idx.values():
for key in keys_to_refresh:
ttl = self._redis.ttl(key)
if ttl > 0:
self._redis.expire(key, ttl_seconds)
else:
pipeline = self._redis.pipeline(transaction=True)
for keys in refresh_keys_by_idx.values():
for key in keys:
# Only refresh TTL if the key exists and has a TTL
ttl = self._redis.ttl(key)
if ttl > 0: # Only refresh if key exists and has TTL
pipeline.expire(key, ttl_seconds)
if pipeline.command_stack:
pipeline.execute()
def _batch_put_ops(
self,
put_ops: list[tuple[int, PutOp]],
) -> None:
"""Execute PUT operations in batch."""
operations, embedding_request = self._prepare_batch_PUT_queries(put_ops)
        # Updates are implemented as delete-then-recreate: each PUT generates a
        # fresh ULID-keyed document, so any existing copies must be removed first
for _, op in put_ops:
namespace = _namespace_to_text(op.namespace)
query = f"@prefix:{namespace} @key:{{{_token_escaper.escape(op.key)}}}"
            existing_docs = self.store_index.search(query)
            if self.cluster_mode:
                for doc in existing_docs.docs:
                    self._redis.delete(doc.id)
                if self.index_config:
                    vector_results = self.vector_index.search(query)
                    for doc_vec in vector_results.docs:
                        self._redis.delete(doc_vec.id)
            else:
                pipeline = self._redis.pipeline(transaction=True)
                for doc in existing_docs.docs:
pipeline.delete(doc.id)
if self.index_config:
vector_results = self.vector_index.search(query)
for doc_vec in vector_results.docs:
pipeline.delete(doc_vec.id)
if pipeline.command_stack:
pipeline.execute()
# Now handle new document creation
doc_ids: dict[tuple[str, str], str] = {}
store_docs: list[RedisDocument] = []
store_keys: list[str] = []
        # Maps each main document key to (related vector keys, TTL in minutes)
        ttl_tracking: dict[str, tuple[list[str], Optional[float]]] = {}
# Generate IDs for PUT operations
for _, op in put_ops:
if op.value is not None:
generated_doc_id = str(ULID())
namespace = _namespace_to_text(op.namespace)
doc_ids[(namespace, op.key)] = generated_doc_id
# Track TTL for this document if specified
if hasattr(op, "ttl") and op.ttl is not None:
main_key = (
f"{self.store_prefix}{REDIS_KEY_SEPARATOR}{generated_doc_id}"
)
ttl_tracking[main_key] = ([], op.ttl)
# Load store docs with explicit keys
for doc in operations:
store_key = (doc["prefix"], doc["key"])
doc_id = doc_ids[store_key]
# Remove TTL fields - they're not needed with Redis native TTL
if "ttl_minutes" in doc:
doc.pop("ttl_minutes", None)
if "expires_at" in doc:
doc.pop("expires_at", None)
store_docs.append(doc)
redis_key = f"{self.store_prefix}{REDIS_KEY_SEPARATOR}{doc_id}"
store_keys.append(redis_key)
if store_docs:
if self.cluster_mode:
# Load individually if cluster
for i, store_doc_item in enumerate(store_docs):
self.store_index.load([store_doc_item], keys=[store_keys[i]])
else:
self.store_index.load(store_docs, keys=store_keys)
# Handle vector embeddings with same IDs
if embedding_request and self.embeddings:
_, text_params = embedding_request
vectors = self.embeddings.embed_documents(
[text for _, _, _, text in text_params]
)
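            # embed_documents returns vectors in input order, so the zip()
            # below pairs each vector with its originating (ns, key, path, text)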
vector_docs: list[dict[str, Any]] = []
vector_keys: list[str] = []
# Check if we're using hash storage for vectors
vector_storage_type = "json" # default
if self.index_config:
index_dict = dict(self.index_config)
vector_storage_type = index_dict.get("vector_storage_type", "json")
for (ns, key, path, _), vector in zip(text_params, vectors):
vector_key: tuple[str, str] = (ns, key)
doc_id = doc_ids[vector_key]
# Prepare vector based on storage type
if vector_storage_type == "hash":
# For hash storage, convert vector to byte string
from redisvl.redis.utils import array_to_buffer
vector_list = (
vector.tolist() if hasattr(vector, "tolist") else vector
)
embedding_value = array_to_buffer(vector_list, "float32")
else:
# For JSON storage, keep as list
embedding_value = (
vector.tolist() if hasattr(vector, "tolist") else vector
)
vector_docs.append(
{
"prefix": ns,
"key": key,
"field_name": path,
"embedding": embedding_value,
"created_at": datetime.now(timezone.utc).timestamp(),
"updated_at": datetime.now(timezone.utc).timestamp(),
}
)
redis_vector_key = f"{self.vector_prefix}{REDIS_KEY_SEPARATOR}{doc_id}"
vector_keys.append(redis_vector_key)
# Add this vector key to the related keys list for TTL
main_key = f"{self.store_prefix}{REDIS_KEY_SEPARATOR}{doc_id}"
if main_key in ttl_tracking:
ttl_tracking[main_key][0].append(redis_vector_key)
if vector_docs:
if self.cluster_mode:
# Load individually if cluster
for i, vector_doc_item in enumerate(vector_docs):
self.vector_index.load([vector_doc_item], keys=[vector_keys[i]])
else:
self.vector_index.load(vector_docs, keys=vector_keys)
# Now apply TTLs after all documents are loaded
for main_key, (related_keys, ttl_minutes) in ttl_tracking.items():
self._apply_ttl_to_keys(main_key, related_keys, ttl_minutes)
def _batch_search_ops(
self,
search_ops: list[tuple[int, SearchOp]],
results: list[Result],
) -> None:
"""Execute search operations in batch."""
queries, embedding_requests = self._get_batch_search_queries(search_ops)
# Handle vector search
query_vectors = {}
if embedding_requests and self.embeddings:
vectors = self.embeddings.embed_documents(
[query for _, query in embedding_requests]
)
query_vectors = dict(zip([idx for idx, _ in embedding_requests], vectors))
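        # Ops without a semantic query (no entry in query_vectors) fall
        # through to the plain filtered-search branch below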
# Process each search operation
for (idx, op), (query_str, params, limit, offset) in zip(search_ops, queries):
if op.query and idx in query_vectors:
# Vector similarity search
vector = query_vectors[idx]
vector_query = VectorQuery(
vector=vector.tolist() if hasattr(vector, "tolist") else vector,
vector_field_name="embedding",
filter_expression=f"@prefix:{_namespace_to_text(op.namespace_prefix)}*",
return_fields=["prefix", "key", "vector_distance"],
num_results=limit, # Use the user-specified limit
)
vector_results = self.vector_index.query(vector_query)
# Get matching store docs
result_map = {} # Map store key to vector result with distances
if self.cluster_mode:
store_docs = []
# Direct JSON GET for cluster mode
for doc in vector_results:
doc_id = (
doc.get("id")
if isinstance(doc, dict)
else getattr(doc, "id", None)
)
if doc_id:
                            doc_uuid = doc_id.split(":")[-1]  # ULID after the last separator
store_key = (
f"{self.store_prefix}{REDIS_KEY_SEPARATOR}{doc_uuid}"
)
result_map[store_key] = doc
# Fetch individually in cluster mode
store_doc_item = self._redis.json().get(store_key)
store_docs.append(store_doc_item)
store_docs_raw = store_docs
else:
pipe = self._redis.pipeline(transaction=True)
for doc in vector_results:
doc_id = (
doc.get("id")
if isinstance(doc, dict)
else getattr(doc, "id", None)
)
if not doc_id:
continue
                    doc_uuid = doc_id.split(":")[-1]  # ULID after the last separator
store_key = (
f"{self.store_prefix}{REDIS_KEY_SEPARATOR}{doc_uuid}"
)
result_map[store_key] = doc
pipe.json().get(store_key)
# Execute all lookups in one batch
store_docs_raw = pipe.execute()
# Process results maintaining order and applying filters
items = []
refresh_keys = [] # Track keys that need TTL refreshed
store_docs_iter = iter(store_docs_raw)
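                # result_map preserves insertion order (Python dicts), and
                # store_docs_raw was fetched in that same order, so the
                # iterator below stays aligned with each store_key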
for store_key in result_map.keys():
store_doc = next(store_docs_iter, None)
if store_doc:
vector_result = result_map[store_key]
# Get vector_distance from original search result
dist = (
vector_result.get("vector_distance")
if isinstance(vector_result, dict)
else getattr(vector_result, "vector_distance", 0)
)
                        # Convert distance to similarity; this inline conversion
                        # assumes cosine distance (similarity = 1 - distance),
                        # cf. _convert_redis_score_to_similarity above
                        score = (1.0 - float(dist)) if dist is not None else 0.0
if not isinstance(store_doc, dict):
try:
# Cast needed: redis-py types json().get() incorrectly
store_doc = json.loads(
cast(str, store_doc)
) # Attempt to parse if it's a JSON string
except (json.JSONDecodeError, TypeError):
logger.error(f"Failed to parse store_doc: {store_doc}")
continue # Skip this problematic document
if isinstance(
store_doc, dict
): # Check again after potential parsing
store_doc["vector_distance"] = dist
else:
# if still not a dict, this means it's a problematic entry
logger.error(
f"store_doc is not a dict after parsing attempt: {store_doc}"
)
continue
# Apply value filters if needed
if op.filter:
matches = True
value = store_doc.get("value", {})
for key, expected in op.filter.items():
actual = value.get(key)
if isinstance(expected, list):
if actual not in expected:
matches = False
break
elif actual != expected:
matches = False
break
if not matches:
continue
# If refresh_ttl is true, add to list for refreshing
if op.refresh_ttl:
refresh_keys.append(store_key)
# Also find associated vector keys with same ID
doc_id = store_key.split(":")[-1]
vector_key = (
f"{self.vector_prefix}{REDIS_KEY_SEPARATOR}{doc_id}"
)
refresh_keys.append(vector_key)
items.append(
_row_to_search_item(
_decode_ns(store_doc["prefix"]),
store_doc,
score=score,
deserialize_fn=self._deserialize_value,
)
)
# Refresh TTL if requested
if op.refresh_ttl and refresh_keys and self.ttl_config:
# Get default TTL from config
ttl_minutes = None
if "default_ttl" in self.ttl_config:
ttl_minutes = self.ttl_config.get("default_ttl")
if ttl_minutes is not None:
ttl_seconds = int(ttl_minutes * 60)
if self.cluster_mode:
for key in refresh_keys:
ttl = self._redis.ttl(key)
if ttl > 0:
self._redis.expire(key, ttl_seconds)
else:
pipeline = self._redis.pipeline(transaction=True)
for key in refresh_keys:
# Only refresh TTL if the key exists and has a TTL
ttl = self._redis.ttl(key)
if ttl > 0: # Only refresh if key exists and has TTL
pipeline.expire(key, ttl_seconds)
if pipeline.command_stack:
pipeline.execute()
results[idx] = items
else:
# Regular search
# Create a query with LIMIT and OFFSET parameters
query = Query(query_str).paging(offset, limit)
# Execute search with limit and offset applied by Redis
res = self.store_index.search(query)
items = []
refresh_keys = [] # Track keys that need TTL refreshed
for doc in res.docs:
data = json.loads(doc.json)
# Apply value filters
if op.filter:
matches = True
value = data.get("value", {})
for key, expected in op.filter.items():
actual = value.get(key)
if isinstance(expected, list):
if actual not in expected:
matches = False
break
elif actual != expected:
matches = False
break
if not matches:
continue
# If refresh_ttl is true, add the key to refresh list
if op.refresh_ttl:
refresh_keys.append(doc.id)
# Also find associated vector keys with same ID
doc_id = doc.id.split(":")[-1]
vector_key = (
f"{self.vector_prefix}{REDIS_KEY_SEPARATOR}{doc_id}"
)
refresh_keys.append(vector_key)
items.append(
_row_to_search_item(
_decode_ns(data["prefix"]),
data,
deserialize_fn=self._deserialize_value,
)
)
# Refresh TTL if requested
if op.refresh_ttl and refresh_keys and self.ttl_config:
# Get default TTL from config
ttl_minutes = None
if "default_ttl" in self.ttl_config:
ttl_minutes = self.ttl_config.get("default_ttl")
if ttl_minutes is not None:
ttl_seconds = int(ttl_minutes * 60)
if self.cluster_mode:
for key in refresh_keys:
ttl = self._redis.ttl(key)
if ttl > 0:
self._redis.expire(key, ttl_seconds)
else:
pipeline = self._redis.pipeline(transaction=True)
for key in refresh_keys:
# Only refresh TTL if the key exists and has a TTL
ttl = self._redis.ttl(key)
if ttl > 0: # Only refresh if key exists and has TTL
pipeline.expire(key, ttl_seconds)
if pipeline.command_stack:
pipeline.execute()
results[idx] = items
async def abatch(self, ops: Iterable[Op]) -> list[Result]:
"""Execute batch of operations asynchronously."""
return await asyncio.get_running_loop().run_in_executor(None, self.batch, ops)
__all__ = ["AsyncRedisStore", "RedisStore"]