Source code for langgraph.store.redis.base

"""Base implementation for Redis-backed store with optional vector search capabilities."""

from __future__ import annotations

import copy
import json
import logging
import threading
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    Iterable,
    Optional,
    Sequence,
    TypedDict,
    TypeVar,
    Union,
)

from langgraph.store.base import (
    GetOp,
    IndexConfig,
    Item,
    ListNamespacesOp,
    Op,
    PutOp,
    SearchItem,
    SearchOp,
    TTLConfig,
    ensure_embeddings,
    get_text_at_path,
    tokenize_path,
)
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
from redis.exceptions import ResponseError
from redisvl.index import SearchIndex
from redisvl.query.filter import Tag, Text
from redisvl.utils.token_escaper import TokenEscaper

from langgraph.checkpoint.redis.jsonplus_redis import JsonPlusRedisSerializer

from .token_unescaper import TokenUnescaper
from .types import IndexType, RedisClientType

_token_escaper = TokenEscaper()
_token_unescaper = TokenUnescaper()

logger = logging.getLogger(__name__)

REDIS_KEY_SEPARATOR = ":"
STORE_PREFIX = "store"
STORE_VECTOR_PREFIX = "store_vectors"

# Schemas for Redis Search indices
SCHEMAS = [
    {
        "index": {
            "name": "store",
            "prefix": STORE_PREFIX + REDIS_KEY_SEPARATOR,
            "storage_type": "json",
        },
        "fields": [
            {"name": "prefix", "type": "text"},
            {"name": "key", "type": "tag"},
            {"name": "created_at", "type": "numeric"},
            {"name": "updated_at", "type": "numeric"},
            {"name": "ttl_minutes", "type": "numeric"},
            {"name": "expires_at", "type": "numeric"},
        ],
    },
    {
        "index": {
            "name": "store_vectors",
            "prefix": STORE_VECTOR_PREFIX + REDIS_KEY_SEPARATOR,
            "storage_type": "json",
        },
        "fields": [
            {"name": "prefix", "type": "text"},
            {"name": "key", "type": "tag"},
            {"name": "field_name", "type": "tag"},
            {"name": "embedding", "type": "vector"},
            {"name": "created_at", "type": "numeric"},
            {"name": "updated_at", "type": "numeric"},
            {"name": "ttl_minutes", "type": "numeric"},
            {"name": "expires_at", "type": "numeric"},
        ],
    },
]
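# Illustrative sketch (not part of the module): assuming redisvl's default JSON
# path mapping ($.<name> AS <name>), the first schema above is roughly
# equivalent to
#
#     FT.CREATE store ON JSON PREFIX 1 "store:" SCHEMA
#         $.prefix AS prefix TEXT
#         $.key AS key TAG
#         $.created_at AS created_at NUMERIC
#         $.updated_at AS updated_at NUMERIC
#         $.ttl_minutes AS ttl_minutes NUMERIC
#         $.expires_at AS expires_at NUMERIC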


def _ensure_string_or_literal(value: Any) -> str:
    """Convert value to string safely."""
    if hasattr(value, "lower"):
        return value.lower()
    return str(value)


C = TypeVar("C", bound=Union[Redis, AsyncRedis])


class RedisDocument(TypedDict, total=False):
    prefix: str
    key: str
    value: Optional[str]
    created_at: int
    updated_at: int
    ttl_minutes: Optional[float]
    expires_at: Optional[int]
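
# For illustration (hypothetical values; not part of the module): a stored
# document resembles
#
#     {
#         "prefix": "users.profiles",
#         "key": "user-123",
#         "value": {...},                  # output of _serialize_value
#         "created_at": 1700000000000000,  # microseconds since epoch
#         "updated_at": 1700000000000000,
#         "ttl_minutes": 5.0,
#         "expires_at": 1700000300,        # seconds since epoch
#     }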


class BaseRedisStore(Generic[RedisClientType, IndexType]):
    """Base Redis implementation for persistent key-value store with optional vector search."""

    _redis: RedisClientType
    store_index: IndexType
    vector_index: IndexType
    _ttl_sweeper_thread: Optional[threading.Thread] = None
    _ttl_stop_event: Optional[threading.Event] = None
    # Whether to operate in Redis cluster mode; None triggers auto-detection
    cluster_mode: Optional[bool] = None

    SCHEMAS = SCHEMAS
    supports_ttl: bool = True
    ttl_config: Optional[TTLConfig] = None

    # Serializer for handling complex objects like LangChain messages
    _serde: JsonPlusRedisSerializer

    def _apply_ttl_to_keys(
        self,
        main_key: str,
        related_keys: Optional[list[str]] = None,
        ttl_minutes: Optional[float] = None,
    ) -> Any:
        """Apply Redis native TTL to keys.

        Args:
            main_key: The primary Redis key
            related_keys: Additional Redis keys that should expire at the same time
            ttl_minutes: Time-to-live in minutes
        """
        if ttl_minutes is None:
            # Check if there's a default TTL in config
            if self.ttl_config and "default_ttl" in self.ttl_config:
                ttl_minutes = self.ttl_config.get("default_ttl")

        if ttl_minutes is not None:
            ttl_seconds = int(ttl_minutes * 60)

            # Use the cluster_mode attribute to determine the approach
            if self.cluster_mode:
                # Cluster path: direct expire calls
                self._redis.expire(main_key, ttl_seconds)
                if related_keys:
                    for key in related_keys:
                        self._redis.expire(key, ttl_seconds)
            else:
                # Non-cluster path: transactional pipeline
                pipeline = self._redis.pipeline(transaction=True)
                pipeline.expire(main_key, ttl_seconds)
                if related_keys:
                    for key in related_keys:
                        pipeline.expire(key, ttl_seconds)
                pipeline.execute()
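    # Illustrative usage sketch (hypothetical keys; not part of the module).
    # On non-cluster deployments the expirations below are applied in a single
    # MULTI/EXEC pipeline, so the store document and its vector sibling expire
    # together:
    #
    #     >>> store._apply_ttl_to_keys(
    #     ...     "store:user-123",
    #     ...     related_keys=["store_vectors:user-123"],
    #     ...     ttl_minutes=5.0,
    #     ... )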
    def sweep_ttl(self) -> int:
        """Clean up any remaining expired items.

        This is not needed with Redis native TTL, but kept for API compatibility.
        Redis automatically removes expired keys.

        Returns:
            int: Always returns 0 as Redis handles expiration automatically
        """
        return 0
    def start_ttl_sweeper(self, _sweep_interval_minutes: Optional[int] = None) -> None:
        """Start TTL sweeper.

        This is a no-op with Redis native TTL, but kept for API compatibility.
        Redis automatically removes expired keys.

        Args:
            _sweep_interval_minutes: Ignored parameter, kept for API compatibility
        """
        # No-op: Redis handles TTL expiration automatically
        pass
    def stop_ttl_sweeper(self, _timeout: Optional[float] = None) -> bool:
        """Stop TTL sweeper.

        This is a no-op with Redis native TTL, but kept for API compatibility.

        Args:
            _timeout: Ignored parameter, kept for API compatibility

        Returns:
            bool: Always True as there's no sweeper to stop
        """
        # No-op: Redis handles TTL expiration automatically
        return True
    def __init__(
        self,
        conn: RedisClientType,
        *,
        index: Optional[IndexConfig] = None,
        ttl: Optional[TTLConfig] = None,
        cluster_mode: Optional[bool] = None,
        store_prefix: str = STORE_PREFIX,
        vector_prefix: str = STORE_VECTOR_PREFIX,
    ) -> None:
        """Initialize store with Redis connection and optional index config.

        Args:
            conn: Redis client connection
            index: Optional index configuration for vector search
            ttl: Optional TTL configuration
            cluster_mode: Optional cluster mode setting (None = auto-detect)
            store_prefix: Prefix for store keys (default: "store")
            vector_prefix: Prefix for vector keys (default: "store_vectors")
        """
        self.index_config = index
        self.ttl_config = ttl
        self._redis = conn
        # Store cluster_mode; None means auto-detect in RedisStore or AsyncRedisStore
        self.cluster_mode = cluster_mode
        # Initialize the serializer for handling complex objects like LangChain messages
        self._serde = JsonPlusRedisSerializer()

        # Store custom prefixes
        self.store_prefix = store_prefix
        self.vector_prefix = vector_prefix

        if self.index_config:
            self.index_config = self.index_config.copy()
            self.embeddings = ensure_embeddings(
                self.index_config.get("embed"),
            )
            fields = self.index_config.get("fields", ["$"]) or []
            if isinstance(fields, str):
                fields = [fields]
            self.index_config["__tokenized_fields"] = [
                (p, tokenize_path(p)) if p != "$" else (p, p)
                for p in fields
            ]

        # Create custom schemas with instance prefixes
        store_schema = {
            "index": {
                "name": self.store_prefix,
                "prefix": self.store_prefix + REDIS_KEY_SEPARATOR,
                "storage_type": "json",
            },
            "fields": [
                {"name": "prefix", "type": "text"},
                {"name": "key", "type": "tag"},
                {"name": "created_at", "type": "numeric"},
                {"name": "updated_at", "type": "numeric"},
                {"name": "ttl_minutes", "type": "numeric"},
                {"name": "expires_at", "type": "numeric"},
            ],
        }

        # Initialize search indices
        self.store_index = SearchIndex.from_dict(
            store_schema, redis_client=self._redis
        )

        # Configure vector index if needed
        if self.index_config:
            # Get storage type from index config, default to "json"
            # Cast to dict to safely access potential extra fields
            index_dict = dict(self.index_config)
            vector_storage_type = index_dict.get("vector_storage_type", "json")

            # Create custom vector schema with instance prefix
            vector_schema: Dict[str, Any] = {
                "index": {
                    "name": self.vector_prefix,
                    "prefix": self.vector_prefix + REDIS_KEY_SEPARATOR,
                    "storage_type": vector_storage_type,
                },
                "fields": [
                    {"name": "prefix", "type": "text"},
                    {"name": "key", "type": "tag"},
                    {"name": "field_name", "type": "tag"},
                    {"name": "embedding", "type": "vector"},
                    {"name": "created_at", "type": "numeric"},
                    {"name": "updated_at", "type": "numeric"},
                    {"name": "ttl_minutes", "type": "numeric"},
                    {"name": "expires_at", "type": "numeric"},
                ],
            }
            vector_fields = vector_schema.get("fields", [])
            vector_field = None
            for f in vector_fields:
                if isinstance(f, dict) and f.get("name") == "embedding":
                    vector_field = f
                    break
            if vector_field:
                # Configure vector field with index config values
                vector_field["attrs"] = {
                    "algorithm": "flat",  # Default to flat
                    "datatype": "float32",
                    "dims": self.index_config["dims"],
                    # Map distance metrics to Redis-accepted literals
                    "distance_metric": {
                        "cosine": "COSINE",
                        "inner_product": "IP",
                        "l2": "L2",
                    }[
                        _ensure_string_or_literal(
                            index_dict.get("distance_type", "cosine")
                        )
                    ],
                }
                # Apply any additional vector type config
                if "ann_index_config" in index_dict:
                    vector_field["attrs"].update(index_dict["ann_index_config"])

            self.vector_index = SearchIndex.from_dict(
                vector_schema, redis_client=self._redis
            )

        # Note: set_client_info() should be called by concrete implementations
        # after initialization to avoid async/sync conflicts
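    # Illustrative construction sketch (not part of the module; assumes the
    # concrete RedisStore subclass forwards these keyword arguments and that
    # my_embeddings is an Embeddings instance you provide):
    #
    #     >>> from redis import Redis
    #     >>> client = Redis.from_url("redis://localhost:6379")
    #     >>> store = RedisStore(
    #     ...     client,
    #     ...     index={"dims": 1536, "embed": my_embeddings, "fields": ["$"]},
    #     ...     ttl={"default_ttl": 60.0},
    #     ... )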
    def set_client_info(self) -> None:
        """Set client info for Redis monitoring."""
        from langgraph.checkpoint.redis.version import __full_lib_name__

        try:
            # Try to use client_setinfo command if available
            self._redis.client_setinfo("LIB-NAME", __full_lib_name__)
        except (ResponseError, AttributeError):
            # Fall back to a simple echo if client_setinfo is not available
            try:
                self._redis.echo(__full_lib_name__)
            except Exception:
                # Silently fail if even echo doesn't work
                pass
    async def aset_client_info(self) -> None:
        """Set client info for Redis monitoring asynchronously."""
        from langgraph.checkpoint.redis.version import __full_lib_name__

        try:
            # Try to use client_setinfo command if available
            await self._redis.client_setinfo("LIB-NAME", __full_lib_name__)
        except (ResponseError, AttributeError):
            # Fall back to a simple echo if client_setinfo is not available
            try:
                # Await the result if the client returned an awaitable
                echo_result = self._redis.echo(__full_lib_name__)
                if hasattr(echo_result, "__await__"):
                    await echo_result
            except Exception:
                # Silently fail if even echo doesn't work
                pass
    def _serialize_value(self, value: Any) -> Any:
        """Serialize a value for storage in Redis.

        This method handles complex objects like LangChain messages by
        serializing them to a JSON-compatible format. It is selective about
        when to serialize:

        - If the value is a plain JSON-serializable dict/list, it is stored as-is
        - If the value contains complex objects (HumanMessage, etc.), it uses the
          serde wrapper format with __serde_type__ and __serde_data__ keys

        Note: Values containing LangChain messages will be wrapped in the serde
        format, which means filters on nested fields won't work for such values.

        Args:
            value: The value to serialize (can contain HumanMessage, AIMessage, etc.)

        Returns:
            A JSON-serializable representation of the value
        """
        if value is None:
            return None

        # First, try standard JSON serialization to check if it's needed
        try:
            json.dumps(value)
            # Value is already JSON-serializable; return as-is for backward
            # compatibility and to preserve filter functionality
            return value
        except TypeError:
            # Value contains non-JSON-serializable objects, use serde wrapper
            pass

        # Use the serializer to handle complex objects
        type_str, data_bytes = self._serde.dumps_typed(value)

        # Store the serialized data with type info for proper deserialization
        # Handle different type formats explicitly for clarity
        if type_str == "json":
            data_encoded = data_bytes.decode("utf-8")
        else:
            # bytes, bytearray, msgpack, and other types are hex-encoded
            data_encoded = data_bytes.hex()

        return {
            "__serde_type__": type_str,
            "__serde_data__": data_encoded,
        }

    def _deserialize_value(self, value: Any) -> Any:
        """Deserialize a value from Redis storage.

        This method handles both the new serialized format and legacy plain
        values for backward compatibility.

        Args:
            value: The value from Redis (may be serialized or plain)

        Returns:
            The deserialized value with proper Python objects (HumanMessage, etc.)
        """
        if value is None:
            return None

        # Check if this is a serialized value (new format)
        # Use exact key check to prevent collisions with user data
        if isinstance(value, dict) and set(value.keys()) == {
            "__serde_type__",
            "__serde_data__",
        }:
            type_str = value["__serde_type__"]
            data_str = value["__serde_data__"]

            try:
                # Convert back to bytes based on type
                if type_str == "json":
                    data_bytes = data_str.encode("utf-8")
                else:
                    # bytes, bytearray, msgpack types are hex-encoded
                    data_bytes = bytes.fromhex(data_str)
                return self._serde.loads_typed((type_str, data_bytes))
            except (ValueError, TypeError) as e:
                # Handle hex decoding errors or deserialization failures
                logger.error(
                    "Failed to deserialize value from Redis: type=%r, error=%s",
                    type_str,
                    e,
                )
                # Return None to indicate deserialization failure
                return None
            except Exception as e:
                # Handle any other unexpected errors during deserialization
                logger.error(
                    "Unexpected error deserializing value from Redis: type=%r, error=%s",
                    type_str,
                    e,
                )
                return None

        # Legacy format: value is stored as-is (plain JSON-serializable data)
        # Return as-is for backward compatibility
        return value

    def _get_batch_GET_ops_queries(
        self,
        get_ops: Sequence[tuple[int, GetOp]],
    ) -> list[tuple[str, Sequence, tuple[str, ...], list]]:
        """Convert GET operations into Redis queries."""
        namespace_groups = defaultdict(list)
        for idx, op in get_ops:
            namespace_groups[op.namespace].append((idx, op.key))

        results: list[tuple[str, Sequence, tuple[str, ...], list]] = []
        for namespace, items in namespace_groups.items():
            _, keys = zip(*items)
            # Use the Text/Tag helpers to properly escape special characters
            prefix_filter = Text("prefix") == _namespace_to_text(namespace)
            filter_str = f"({prefix_filter} "
            if keys:
                key_filter = Tag("key") == list(keys)
                filter_str += f"{key_filter})"
            else:
                filter_str += ")"
            results.append((filter_str, [], namespace, items))
        return results
    def _prepare_batch_PUT_queries(
        self,
        put_ops: Sequence[tuple[int, PutOp]],
    ) -> tuple[
        list[RedisDocument], Optional[tuple[str, list[tuple[str, str, str, str]]]]
    ]:
        # Last-write wins
        dedupped_ops: dict[tuple[tuple[str, ...], str], PutOp] = {}
        for _, op in put_ops:
            dedupped_ops[(op.namespace, op.key)] = op

        inserts: list[PutOp] = []
        deletes: list[PutOp] = []
        for op in dedupped_ops.values():
            if op.value is None:
                deletes.append(op)
            else:
                inserts.append(op)

        operations: list[RedisDocument] = []
        embedding_request = None
        to_embed: list[tuple[str, str, str, str]] = []

        if deletes:
            # Delete matching documents
            for op in deletes:
                prefix = _namespace_to_text(op.namespace)
                query = f"(@prefix:{prefix} @key:{{{op.key}}})"
                results = self.store_index.search(query)
                for doc in results.docs:
                    self._redis.delete(doc.id)

        # Handle inserts
        if inserts:
            for op in inserts:
                now = int(datetime.now(timezone.utc).timestamp() * 1_000_000)

                # With native Redis TTL, we don't need to store TTL in the document
                ttl_minutes = None
                expires_at = None
                if hasattr(op, "ttl") and op.ttl is not None:
                    ttl_minutes = op.ttl
                    # Calculate expiration but don't rely on it for actual
                    # expiration, as we'll use Redis native TTL
                    expires_at = int(
                        (
                            datetime.now(timezone.utc) + timedelta(minutes=op.ttl)
                        ).timestamp()
                    )

                doc = RedisDocument(
                    prefix=_namespace_to_text(op.namespace),
                    key=op.key,
                    value=self._serialize_value(op.value),
                    created_at=now,
                    updated_at=now,
                    ttl_minutes=ttl_minutes,
                    expires_at=expires_at,
                )
                operations.append(doc)

                if self.index_config and op.index is not False:
                    paths = (
                        self.index_config["__tokenized_fields"]
                        if op.index is None
                        else [(ix, tokenize_path(ix)) for ix in op.index]
                    )
                    for path, tokenized_path in paths:
                        texts = get_text_at_path(op.value, tokenized_path)
                        for text in texts:
                            to_embed.append(
                                (_namespace_to_text(op.namespace), op.key, path, text)
                            )

        if to_embed:
            embedding_request = ("", to_embed)

        return operations, embedding_request

    def _get_batch_search_queries(
        self,
        search_ops: Sequence[tuple[int, SearchOp]],
    ) -> tuple[list[tuple[str, list, int, int]], list[tuple[int, str]]]:
        """Convert search operations into Redis queries."""
        queries = []
        embedding_requests = []
        for idx, op in search_ops:
            filter_conditions = []
            if op.namespace_prefix:
                prefix = _namespace_to_text(op.namespace_prefix)
                filter_conditions.append(f"@prefix:{prefix}*")

            if op.query and self.index_config:
                embedding_requests.append((idx, op.query))

            query = " ".join(filter_conditions) if filter_conditions else "*"
            limit = op.limit if op.limit is not None else 10
            offset = op.offset if op.offset is not None else 0
            params = [limit, offset]
            queries.append((query, params, limit, offset))
        return queries, embedding_requests

    def _get_batch_list_namespaces_queries(
        self,
        list_ops: Sequence[tuple[int, ListNamespacesOp]],
    ) -> list[tuple[str, list]]:
        """Convert list namespaces operations into Redis queries."""
        queries = []
        for _, op in list_ops:
            conditions = []
            if op.match_conditions:
                for condition in op.match_conditions:
                    if condition.match_type == "prefix":
                        path = _namespace_to_text(condition.path, handle_wildcards=True)
                        conditions.append(f"@prefix:{path}*")
                    elif condition.match_type == "suffix":
                        path = _namespace_to_text(condition.path, handle_wildcards=True)
                        conditions.append(f"@prefix:*{path}")

            query = " ".join(conditions) if conditions else "*"
            params = [op.limit, op.offset] if op.limit or op.offset else []
            queries.append((query, params))
        return queries

    def _get_filter_condition(self, key: str, op: str, value: Any) -> str:
        """Get Redis search filter condition for an operator."""
        if op == "$eq":
            return f'@{key}:"{value}"'
        elif op == "$gt":
            return f"@{key}:[({value} inf]"
        elif op == "$gte":
            return f"@{key}:[{value} inf]"
        elif op == "$lt":
            return f"@{key}:[-inf ({value}]"
        elif op == "$lte":
            return f"@{key}:[-inf {value}]"
        elif op == "$ne":
            return f'-@{key}:"{value}"'
        else:
            raise ValueError(f"Unsupported operator: {op}")

    def _cosine_similarity(
        self, vec1: list[float], vecs: list[list[float]]
    ) -> list[float]:
        """Compute cosine similarity between one vector and a list of vectors."""
        # Note: For production use, consider importing numpy for better performance
        similarities = []
        for vec2 in vecs:
            dot_product = sum(a * b for a, b in zip(vec1, vec2))
            norm1 = (sum(x * x for x in vec1)) ** 0.5
            norm2 = (sum(x * x for x in vec2)) ** 0.5
            if norm1 == 0 or norm2 == 0:
                similarities.append(0)
            else:
                similarities.append(dot_product / (norm1 * norm2))
        return similarities
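    # For illustration (hypothetical field names; not part of the module):
    # _get_filter_condition maps operators to RediSearch syntax, e.g.
    #
    #     >>> store._get_filter_condition("score", "$gte", 10)
    #     '@score:[10 inf]'
    #     >>> store._get_filter_condition("status", "$ne", "draft")
    #     '-@status:"draft"'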
def _namespace_to_text(
    namespace: tuple[str, ...], handle_wildcards: bool = False
) -> str:
    """Convert namespace tuple to a text string with proper escaping.

    Args:
        namespace: Tuple of strings representing namespace components
        handle_wildcards: Whether to handle wildcard characters specially

    Returns:
        Properly escaped string representation of the namespace
    """
    if handle_wildcards:
        namespace = tuple("%" if val == "*" else val for val in namespace)
    # Join with dots, then escape RediSearch special characters
    ns_text = _token_escaper.escape(".".join(namespace))
    return ns_text


def _decode_ns(ns: str) -> tuple[str, ...]:
    """Convert a dotted namespace string back into a tuple."""
    return tuple(_token_unescaper.unescape(ns).split("."))


def _row_to_item(
    namespace: tuple[str, ...],
    row: dict[str, Any],
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
) -> Item:
    """Convert a row from Redis to an Item.

    Args:
        namespace: The namespace tuple for this item
        row: The raw row data from Redis
        deserialize_fn: Optional function to deserialize the value
            (handles LangChain messages, etc.)

    Returns:
        An Item with a properly deserialized value
    """
    value = row["value"]
    if deserialize_fn is not None:
        value = deserialize_fn(value)

    return Item(
        value=value,
        key=row["key"],
        namespace=namespace,
        created_at=datetime.fromtimestamp(row["created_at"] / 1_000_000, timezone.utc),
        updated_at=datetime.fromtimestamp(row["updated_at"] / 1_000_000, timezone.utc),
    )


def _row_to_search_item(
    namespace: tuple[str, ...],
    row: dict[str, Any],
    score: Optional[float] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
) -> SearchItem:
    """Convert a row from Redis to a SearchItem.

    Args:
        namespace: The namespace tuple for this item
        row: The raw row data from Redis
        score: Optional similarity score from vector search
        deserialize_fn: Optional function to deserialize the value
            (handles LangChain messages, etc.)

    Returns:
        A SearchItem with a properly deserialized value
    """
    value = row["value"]
    if deserialize_fn is not None:
        value = deserialize_fn(value)

    return SearchItem(
        value=value,
        key=row["key"],
        namespace=namespace,
        created_at=datetime.fromtimestamp(row["created_at"] / 1_000_000, timezone.utc),
        updated_at=datetime.fromtimestamp(row["updated_at"] / 1_000_000, timezone.utc),
        score=score,
    )


def _group_ops(ops: Iterable[Op]) -> tuple[dict[type, list[tuple[int, Op]]], int]:
    """Group operations by type for batch processing."""
    grouped_ops: dict[type, list[tuple[int, Op]]] = defaultdict(list)
    tot = 0
    for idx, op in enumerate(ops):
        grouped_ops[type(op)].append((idx, op))
        tot += 1
    return grouped_ops, tot
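
# Illustrative round-trip sketch (hypothetical namespace; not part of the
# module). Assuming redisvl's TokenEscaper escapes "." for RediSearch,
# _namespace_to_text and _decode_ns are inverses of each other:
#
#     >>> _namespace_to_text(("users", "profiles"))
#     'users\\.profiles'
#     >>> _decode_ns("users\\.profiles")
#     ('users', 'profiles')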