# Source code for langgraph.store.redis

"""Synchronous Redis store implementation."""

from __future__ import annotations

import asyncio
import json
import logging
import math
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Any, Iterable, Iterator, Optional, Sequence, cast

from langgraph.store.base import (
    BaseStore,
    GetOp,
    IndexConfig,
    ListNamespacesOp,
    Op,
    PutOp,
    Result,
    SearchOp,
    TTLConfig,
)
from redis import Redis
from redis.cluster import RedisCluster as SyncRedisCluster
from redis.commands.search.query import Query
from redisvl.index import SearchIndex
from redisvl.query import FilterQuery, VectorQuery
from redisvl.redis.connection import RedisConnectionFactory
from redisvl.utils.token_escaper import TokenEscaper
from ulid import ULID

from langgraph.store.redis.aio import AsyncRedisStore
from langgraph.store.redis.base import (
    REDIS_KEY_SEPARATOR,
    STORE_PREFIX,
    STORE_VECTOR_PREFIX,
    BaseRedisStore,
    RedisDocument,
    _decode_ns,
    _group_ops,
    _namespace_to_text,
    _row_to_item,
    _row_to_search_item,
)

from .token_unescaper import TokenUnescaper

# Shared helpers for RediSearch query construction: namespace segments and
# keys are escaped before being embedded in query strings (see
# _batch_put_ops) and unescaped when read back from stored documents
# (see _batch_list_namespaces_ops).
_token_escaper = TokenEscaper()
_token_unescaper = TokenUnescaper()

logger = logging.getLogger(__name__)


def _convert_redis_score_to_similarity(score: float, distance_type: str) -> float:
    """Convert Redis vector distance to similarity score."""
    if distance_type == "cosine":
        # Redis returns cosine distance (1 - cosine_similarity)
        # Convert back to similarity
        return 1.0 - score
    elif distance_type == "l2":
        # For L2, smaller distance means more similar
        # Use a simple exponential decay
        return math.exp(-score)
    elif distance_type == "inner_product":
        # For inner product, Redis already returns what we want
        return score
    return score


class RedisStore(BaseStore, BaseRedisStore[Redis, SearchIndex]):
    """Redis-backed store with optional vector search.

    Provides synchronous operations for storing and retrieving data with
    optional vector similarity search support.

    Supports standard Redis URLs (redis://), SSL (rediss://), and Sentinel
    URLs (redis+sentinel://host:26379/service_name/db).
    """

    # Enable TTL support
    supports_ttl = True
    # TTL settings (e.g. "default_ttl" in minutes); populated by __init__.
    ttl_config: Optional[TTLConfig] = None

    def __init__(
        self,
        conn: Redis,
        *,
        index: Optional[IndexConfig] = None,
        ttl: Optional[TTLConfig] = None,
        cluster_mode: Optional[bool] = None,
        store_prefix: str = STORE_PREFIX,
        vector_prefix: str = STORE_VECTOR_PREFIX,
    ) -> None:
        """Initialize the store around an existing Redis connection.

        Args:
            conn: An already-connected Redis client.
            index: Optional vector-index configuration; enables similarity
                search when provided.
            ttl: Optional TTL configuration for stored documents.
            cluster_mode: Explicit cluster flag; when ``None`` it is
                auto-detected later in ``setup()``.
            store_prefix: Key prefix for store documents.
            vector_prefix: Key prefix for vector documents.
        """
        # Initialize both bases explicitly (cooperative super() is not used
        # here); BaseRedisStore wires the connection, index config, TTL
        # settings and key prefixes.
        BaseStore.__init__(self)
        BaseRedisStore.__init__(
            self,
            conn,
            index=index,
            ttl=ttl,
            cluster_mode=cluster_mode,
            store_prefix=store_prefix,
            vector_prefix=vector_prefix,
        )
        # Set client info for monitoring (sync store can call this safely)
        self.set_client_info()
        # Detection will happen in setup()
    @classmethod
    @contextmanager
    def from_conn_string(
        cls,
        conn_string: str,
        *,
        index: Optional[IndexConfig] = None,
        ttl: Optional[TTLConfig] = None,
        store_prefix: str = STORE_PREFIX,
        vector_prefix: str = STORE_VECTOR_PREFIX,
    ) -> Iterator[RedisStore]:
        """Create store from Redis connection string.

        Context manager: yields a configured :class:`RedisStore` and closes
        the underlying client (and drains its connection pool) on exit,
        even if construction or the managed block raises.
        """
        client = None
        try:
            client = RedisConnectionFactory.get_redis_connection(conn_string)
            store = cls(
                client,
                index=index,
                ttl=ttl,
                store_prefix=store_prefix,
                vector_prefix=vector_prefix,
            )
            # Client info is set in __init__, but set it again here to ensure
            # it's available even if called before setup()
            store.set_client_info()
            yield store
        finally:
            # client may still be None if get_redis_connection raised.
            if client:
                client.close()
                client.connection_pool.disconnect()
[docs] def setup(self) -> None: """Initialize store indices.""" # Detect if we're connected to a Redis cluster self._detect_cluster_mode() self.store_index.create(overwrite=False) if self.index_config: self.vector_index.create(overwrite=False)
[docs] def batch(self, ops: Iterable[Op]) -> list[Result]: """Execute batch of operations.""" grouped_ops, num_ops = _group_ops(ops) results: list[Result] = [None] * num_ops if GetOp in grouped_ops: self._batch_get_ops( cast(list[tuple[int, GetOp]], grouped_ops[GetOp]), results ) if PutOp in grouped_ops: self._batch_put_ops(cast(list[tuple[int, PutOp]], grouped_ops[PutOp])) if SearchOp in grouped_ops: self._batch_search_ops( cast(list[tuple[int, SearchOp]], grouped_ops[SearchOp]), results ) if ListNamespacesOp in grouped_ops: self._batch_list_namespaces_ops( cast( Sequence[tuple[int, ListNamespacesOp]], grouped_ops[ListNamespacesOp], ), results, ) return results
    def _detect_cluster_mode(self) -> None:
        """Detect if the Redis client is a cluster client by inspecting its class."""
        # If we passed in_cluster_mode explicitly, respect it
        if self.cluster_mode is not None:
            logger.info(
                f"Redis cluster_mode explicitly set to {self.cluster_mode}, skipping detection."
            )
            return

        if isinstance(self._redis, SyncRedisCluster):
            self.cluster_mode = True
            logger.info("Redis cluster client detected for RedisStore.")
        else:
            self.cluster_mode = False
            logger.info("Redis standalone client detected for RedisStore.")

    def _batch_list_namespaces_ops(
        self,
        list_ops: Sequence[tuple[int, ListNamespacesOp]],
        results: list[Result],
    ) -> None:
        """Execute list namespaces operations in batch.

        For each op, builds a RediSearch filter from the match conditions,
        collects the distinct ``prefix`` values, optionally truncates them to
        ``max_depth``, and writes the sorted (paginated) namespace tuples into
        ``results`` at the op's index.
        """
        for idx, op in list_ops:
            # Construct base query for namespace search
            base_query = "*"  # Start with all documents
            if op.match_conditions:
                conditions = []
                for condition in op.match_conditions:
                    if condition.match_type == "prefix":
                        prefix = _namespace_to_text(
                            condition.path, handle_wildcards=True
                        )
                        conditions.append(f"@prefix:{prefix}*")
                    elif condition.match_type == "suffix":
                        suffix = _namespace_to_text(
                            condition.path, handle_wildcards=True
                        )
                        conditions.append(f"@prefix:*{suffix}")
                if conditions:
                    base_query = " ".join(conditions)

            # Execute search with return_fields=["prefix"] to get just namespaces
            query = FilterQuery(filter_expression=base_query, return_fields=["prefix"])
            res = self.store_index.search(query)

            # Extract unique namespaces. Stored prefixes are dot-separated and
            # token-escaped, so unescape before splitting back into tuples.
            namespaces = set()
            for doc in res.docs:
                if hasattr(doc, "prefix"):
                    ns = tuple(_token_unescaper.unescape(doc.prefix).split("."))
                    # Apply max_depth if specified
                    if op.max_depth is not None:
                        ns = ns[: op.max_depth]
                    namespaces.add(ns)

            # Sort and apply pagination
            sorted_namespaces = sorted(namespaces)
            if op.limit or op.offset:
                offset = op.offset or 0
                limit = op.limit or 10
                sorted_namespaces = sorted_namespaces[offset : offset + limit]

            results[idx] = sorted_namespaces

    def _batch_get_ops(
        self,
        get_ops: list[tuple[int, GetOp]],
        results: list[Result],
    ) -> None:
        """Execute GET operations in batch.

        Runs one RediSearch query per namespace group, converts matching rows
        into items at each op's result index, and collects keys whose TTL
        should be refreshed (when ``op.refresh_ttl`` is set). TTL refreshes
        are applied at the end — per-key in cluster mode, or via a single
        pipeline otherwise.
        """
        refresh_keys_by_idx: dict[int, list[str]] = (
            {}
        )  # Track keys that need TTL refreshed by op index

        for query, _, namespace, items in self._get_batch_GET_ops_queries(get_ops):
            res = self.store_index.search(Query(query))
            # Parse JSON from each document
            key_to_row = {
                json.loads(doc.json)["key"]: (json.loads(doc.json), doc.id)
                for doc in res.docs
            }

            for idx, key in items:
                if key in key_to_row:
                    data, doc_id = key_to_row[key]
                    results[idx] = _row_to_item(
                        namespace, data, deserialize_fn=self._deserialize_value
                    )

                    # Find the corresponding operation by looking it up in the operation list
                    # This is needed because idx is the index in the overall operation list
                    op_idx = None
                    for i, (local_idx, op) in enumerate(get_ops):
                        if local_idx == idx:
                            op_idx = i
                            break

                    if op_idx is not None:
                        op = get_ops[op_idx][1]
                        if hasattr(op, "refresh_ttl") and op.refresh_ttl:
                            if idx not in refresh_keys_by_idx:
                                refresh_keys_by_idx[idx] = []
                            refresh_keys_by_idx[idx].append(doc_id)
                            # Also add vector keys for the same document
                            # (store and vector docs share the trailing ULID).
                            doc_uuid = doc_id.split(":")[-1]
                            vector_key = (
                                f"{self.vector_prefix}{REDIS_KEY_SEPARATOR}{doc_uuid}"
                            )
                            refresh_keys_by_idx[idx].append(vector_key)

        # Now refresh TTLs for any keys that need it
        if refresh_keys_by_idx and self.ttl_config:
            # Get default TTL from config
            ttl_minutes = None
            if "default_ttl" in self.ttl_config:
                ttl_minutes = self.ttl_config.get("default_ttl")

            if ttl_minutes is not None:
                ttl_seconds = int(ttl_minutes * 60)
                if self.cluster_mode:
                    # Cluster: keys may live on different nodes, so issue
                    # TTL/EXPIRE calls individually instead of pipelining.
                    for keys_to_refresh in refresh_keys_by_idx.values():
                        for key in keys_to_refresh:
                            ttl = self._redis.ttl(key)
                            if ttl > 0:
                                self._redis.expire(key, ttl_seconds)
                else:
                    pipeline = self._redis.pipeline(transaction=True)
                    for keys in refresh_keys_by_idx.values():
                        for key in keys:
                            # Only refresh TTL if the key exists and has a TTL
                            ttl = self._redis.ttl(key)
                            if ttl > 0:  # Only refresh if key exists and has TTL
                                pipeline.expire(key, ttl_seconds)
                    if pipeline.command_stack:
                        pipeline.execute()

    def _batch_put_ops(
        self,
        put_ops: list[tuple[int, PutOp]],
    ) -> None:
        """Execute PUT operations in batch.

        Deletes any existing store/vector documents for each (namespace, key),
        writes the new documents under freshly generated ULID-based keys,
        embeds and stores vectors when indexing is configured, and finally
        applies per-document TTLs.
        """
        operations, embedding_request = self._prepare_batch_PUT_queries(put_ops)

        # First delete any existing documents that are being updated/deleted
        for _, op in put_ops:
            namespace = _namespace_to_text(op.namespace)
            query = f"@prefix:{namespace} @key:{{{_token_escaper.escape(op.key)}}}"

            results = self.store_index.search(query)
            if self.cluster_mode:
                for doc in results.docs:
                    self._redis.delete(doc.id)
                if self.index_config:
                    vector_results = self.vector_index.search(query)
                    for doc_vec in vector_results.docs:
                        self._redis.delete(doc_vec.id)
            else:
                pipeline = self._redis.pipeline(transaction=True)
                for doc in results.docs:
                    pipeline.delete(doc.id)
                if self.index_config:
                    vector_results = self.vector_index.search(query)
                    for doc_vec in vector_results.docs:
                        pipeline.delete(doc_vec.id)
                if pipeline.command_stack:
                    pipeline.execute()

        # Now handle new document creation
        doc_ids: dict[tuple[str, str], str] = {}
        store_docs: list[RedisDocument] = []
        store_keys: list[str] = []
        ttl_tracking: dict[str, tuple[list[str], Optional[float]]] = (
            {}
        )  # Tracks keys that need TTL + their TTL values

        # Generate IDs for PUT operations (op.value is None means delete-only,
        # which was already handled by the deletion pass above).
        for _, op in put_ops:
            if op.value is not None:
                generated_doc_id = str(ULID())
                namespace = _namespace_to_text(op.namespace)
                doc_ids[(namespace, op.key)] = generated_doc_id
                # Track TTL for this document if specified
                if hasattr(op, "ttl") and op.ttl is not None:
                    main_key = (
                        f"{self.store_prefix}{REDIS_KEY_SEPARATOR}{generated_doc_id}"
                    )
                    ttl_tracking[main_key] = ([], op.ttl)

        # Load store docs with explicit keys
        for doc in operations:
            store_key = (doc["prefix"], doc["key"])
            doc_id = doc_ids[store_key]
            # Remove TTL fields - they're not needed with Redis native TTL
            if "ttl_minutes" in doc:
                doc.pop("ttl_minutes", None)
            if "expires_at" in doc:
                doc.pop("expires_at", None)
            store_docs.append(doc)
            redis_key = f"{self.store_prefix}{REDIS_KEY_SEPARATOR}{doc_id}"
            store_keys.append(redis_key)

        if store_docs:
            if self.cluster_mode:
                # Load individually if cluster
                for i, store_doc_item in enumerate(store_docs):
                    self.store_index.load([store_doc_item], keys=[store_keys[i]])
            else:
                self.store_index.load(store_docs, keys=store_keys)

        # Handle vector embeddings with same IDs
        if embedding_request and self.embeddings:
            _, text_params = embedding_request
            vectors = self.embeddings.embed_documents(
                [text for _, _, _, text in text_params]
            )

            vector_docs: list[dict[str, Any]] = []
            vector_keys: list[str] = []

            # Check if we're using hash storage for vectors
            vector_storage_type = "json"  # default
            if self.index_config:
                index_dict = dict(self.index_config)
                vector_storage_type = index_dict.get("vector_storage_type", "json")

            for (ns, key, path, _), vector in zip(text_params, vectors):
                vector_key: tuple[str, str] = (ns, key)
                doc_id = doc_ids[vector_key]

                # Prepare vector based on storage type
                if vector_storage_type == "hash":
                    # For hash storage, convert vector to byte string
                    from redisvl.redis.utils import array_to_buffer

                    vector_list = (
                        vector.tolist() if hasattr(vector, "tolist") else vector
                    )
                    embedding_value = array_to_buffer(vector_list, "float32")
                else:
                    # For JSON storage, keep as list
                    embedding_value = (
                        vector.tolist() if hasattr(vector, "tolist") else vector
                    )

                vector_docs.append(
                    {
                        "prefix": ns,
                        "key": key,
                        "field_name": path,
                        "embedding": embedding_value,
                        "created_at": datetime.now(timezone.utc).timestamp(),
                        "updated_at": datetime.now(timezone.utc).timestamp(),
                    }
                )
                redis_vector_key = f"{self.vector_prefix}{REDIS_KEY_SEPARATOR}{doc_id}"
                vector_keys.append(redis_vector_key)

                # Add this vector key to the related keys list for TTL
                main_key = f"{self.store_prefix}{REDIS_KEY_SEPARATOR}{doc_id}"
                if main_key in ttl_tracking:
                    ttl_tracking[main_key][0].append(redis_vector_key)

            if vector_docs:
                if self.cluster_mode:
                    # Load individually if cluster
                    for i, vector_doc_item in enumerate(vector_docs):
                        self.vector_index.load(
                            [vector_doc_item], keys=[vector_keys[i]]
                        )
                else:
                    self.vector_index.load(vector_docs, keys=vector_keys)

        # Now apply TTLs after all documents are loaded
        for main_key, (related_keys, ttl_minutes) in ttl_tracking.items():
            self._apply_ttl_to_keys(main_key, related_keys, ttl_minutes)

    def _batch_search_ops(
        self,
        search_ops: list[tuple[int, SearchOp]],
        results: list[Result],
    ) -> None:
        """Execute search operations in batch.

        Ops with a query string and a computed embedding go through vector
        similarity search (KNN + per-doc JSON fetch); all others use a plain
        RediSearch query with Redis-side paging. Both paths apply optional
        value filters, optionally refresh TTLs, and write the item list into
        ``results`` at the op's index.
        """
        queries, embedding_requests = self._get_batch_search_queries(search_ops)

        # Handle vector search
        query_vectors = {}
        if embedding_requests and self.embeddings:
            vectors = self.embeddings.embed_documents(
                [query for _, query in embedding_requests]
            )
            query_vectors = dict(zip([idx for idx, _ in embedding_requests], vectors))

        # Process each search operation
        for (idx, op), (query_str, params, limit, offset) in zip(search_ops, queries):
            if op.query and idx in query_vectors:
                # Vector similarity search
                vector = query_vectors[idx]
                vector_query = VectorQuery(
                    vector=vector.tolist() if hasattr(vector, "tolist") else vector,
                    vector_field_name="embedding",
                    filter_expression=f"@prefix:{_namespace_to_text(op.namespace_prefix)}*",
                    return_fields=["prefix", "key", "vector_distance"],
                    num_results=limit,  # Use the user-specified limit
                )

                vector_results = self.vector_index.query(vector_query)

                # Get matching store docs
                result_map = {}  # Map store key to vector result with distances

                if self.cluster_mode:
                    store_docs = []  # Direct JSON GET for cluster mode
                    for doc in vector_results:
                        doc_id = (
                            doc.get("id")
                            if isinstance(doc, dict)
                            else getattr(doc, "id", None)
                        )
                        if doc_id:
                            doc_uuid = doc_id.split(":")[1]
                            store_key = (
                                f"{self.store_prefix}{REDIS_KEY_SEPARATOR}{doc_uuid}"
                            )
                            result_map[store_key] = doc
                            # Fetch individually in cluster mode
                            store_doc_item = self._redis.json().get(store_key)
                            store_docs.append(store_doc_item)
                    store_docs_raw = store_docs
                else:
                    pipe = self._redis.pipeline(transaction=True)
                    for doc in vector_results:
                        doc_id = (
                            doc.get("id")
                            if isinstance(doc, dict)
                            else getattr(doc, "id", None)
                        )
                        if not doc_id:
                            continue
                        doc_uuid = doc_id.split(":")[1]
                        store_key = (
                            f"{self.store_prefix}{REDIS_KEY_SEPARATOR}{doc_uuid}"
                        )
                        result_map[store_key] = doc
                        pipe.json().get(store_key)
                    # Execute all lookups in one batch
                    store_docs_raw = pipe.execute()

                # Process results maintaining order and applying filters.
                # NOTE: this relies on dict insertion order — result_map keys
                # were inserted in the same order the JSON GETs were issued,
                # so zipping them with store_docs_raw pairs each store doc
                # with its vector result.
                items = []
                refresh_keys = []  # Track keys that need TTL refreshed
                store_docs_iter = iter(store_docs_raw)

                for store_key in result_map.keys():
                    store_doc = next(store_docs_iter, None)
                    if store_doc:
                        vector_result = result_map[store_key]
                        # Get vector_distance from original search result
                        dist = (
                            vector_result.get("vector_distance")
                            if isinstance(vector_result, dict)
                            else getattr(vector_result, "vector_distance", 0)
                        )
                        # Convert to similarity score
                        score = (1.0 - float(dist)) if dist is not None else 0.0

                        if not isinstance(store_doc, dict):
                            try:
                                # Cast needed: redis-py types json().get() incorrectly
                                store_doc = json.loads(
                                    cast(str, store_doc)
                                )  # Attempt to parse if it's a JSON string
                            except (json.JSONDecodeError, TypeError):
                                logger.error(f"Failed to parse store_doc: {store_doc}")
                                continue  # Skip this problematic document

                        if isinstance(
                            store_doc, dict
                        ):  # Check again after potential parsing
                            store_doc["vector_distance"] = dist
                        else:
                            # if still not a dict, this means it's a problematic entry
                            logger.error(
                                f"store_doc is not a dict after parsing attempt: {store_doc}"
                            )
                            continue

                        # Apply value filters if needed
                        if op.filter:
                            matches = True
                            value = store_doc.get("value", {})
                            for key, expected in op.filter.items():
                                actual = value.get(key)
                                if isinstance(expected, list):
                                    if actual not in expected:
                                        matches = False
                                        break
                                elif actual != expected:
                                    matches = False
                                    break
                            if not matches:
                                continue

                        # If refresh_ttl is true, add to list for refreshing
                        if op.refresh_ttl:
                            refresh_keys.append(store_key)
                            # Also find associated vector keys with same ID
                            doc_id = store_key.split(":")[-1]
                            vector_key = (
                                f"{self.vector_prefix}{REDIS_KEY_SEPARATOR}{doc_id}"
                            )
                            refresh_keys.append(vector_key)

                        items.append(
                            _row_to_search_item(
                                _decode_ns(store_doc["prefix"]),
                                store_doc,
                                score=score,
                                deserialize_fn=self._deserialize_value,
                            )
                        )

                # Refresh TTL if requested
                if op.refresh_ttl and refresh_keys and self.ttl_config:
                    # Get default TTL from config
                    ttl_minutes = None
                    if "default_ttl" in self.ttl_config:
                        ttl_minutes = self.ttl_config.get("default_ttl")
                    if ttl_minutes is not None:
                        ttl_seconds = int(ttl_minutes * 60)
                        if self.cluster_mode:
                            for key in refresh_keys:
                                ttl = self._redis.ttl(key)
                                if ttl > 0:
                                    self._redis.expire(key, ttl_seconds)
                        else:
                            pipeline = self._redis.pipeline(transaction=True)
                            for key in refresh_keys:
                                # Only refresh TTL if the key exists and has a TTL
                                ttl = self._redis.ttl(key)
                                if ttl > 0:  # Only refresh if key exists and has TTL
                                    pipeline.expire(key, ttl_seconds)
                            if pipeline.command_stack:
                                pipeline.execute()

                results[idx] = items
            else:
                # Regular search
                # Create a query with LIMIT and OFFSET parameters
                query = Query(query_str).paging(offset, limit)

                # Execute search with limit and offset applied by Redis
                res = self.store_index.search(query)
                items = []
                refresh_keys = []  # Track keys that need TTL refreshed

                for doc in res.docs:
                    data = json.loads(doc.json)
                    # Apply value filters
                    if op.filter:
                        matches = True
                        value = data.get("value", {})
                        for key, expected in op.filter.items():
                            actual = value.get(key)
                            if isinstance(expected, list):
                                if actual not in expected:
                                    matches = False
                                    break
                            elif actual != expected:
                                matches = False
                                break
                        if not matches:
                            continue

                    # If refresh_ttl is true, add the key to refresh list
                    if op.refresh_ttl:
                        refresh_keys.append(doc.id)
                        # Also find associated vector keys with same ID
                        doc_id = doc.id.split(":")[-1]
                        vector_key = (
                            f"{self.vector_prefix}{REDIS_KEY_SEPARATOR}{doc_id}"
                        )
                        refresh_keys.append(vector_key)

                    items.append(
                        _row_to_search_item(
                            _decode_ns(data["prefix"]),
                            data,
                            deserialize_fn=self._deserialize_value,
                        )
                    )

                # Refresh TTL if requested
                if op.refresh_ttl and refresh_keys and self.ttl_config:
                    # Get default TTL from config
                    ttl_minutes = None
                    if "default_ttl" in self.ttl_config:
                        ttl_minutes = self.ttl_config.get("default_ttl")
                    if ttl_minutes is not None:
                        ttl_seconds = int(ttl_minutes * 60)
                        if self.cluster_mode:
                            for key in refresh_keys:
                                ttl = self._redis.ttl(key)
                                if ttl > 0:
                                    self._redis.expire(key, ttl_seconds)
                        else:
                            pipeline = self._redis.pipeline(transaction=True)
                            for key in refresh_keys:
                                # Only refresh TTL if the key exists and has a TTL
                                ttl = self._redis.ttl(key)
                                if ttl > 0:  # Only refresh if key exists and has TTL
                                    pipeline.expire(key, ttl_seconds)
                            if pipeline.command_stack:
                                pipeline.execute()

                results[idx] = items
[docs] async def abatch(self, ops: Iterable[Op]) -> list[Result]: """Execute batch of operations asynchronously.""" return await asyncio.get_running_loop().run_in_executor(None, self.batch, ops)
# Public API of this module.
__all__ = ["AsyncRedisStore", "RedisStore"]