# Source code for langgraph.checkpoint.redis.base

import base64
import binascii
import logging
import random
from abc import abstractmethod
from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, Union, cast

import orjson
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
    WRITES_IDX_MAP,
    BaseCheckpointSaver,
    Checkpoint,
    CheckpointMetadata,
    PendingWrite,
)
from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.types import ChannelProtocol
from redisvl.query import FilterQuery
from redisvl.query.filter import Tag

from langgraph.checkpoint.redis.util import (
    safely_decode,
    to_storage_safe_id,
    to_storage_safe_str,
)

from .jsonplus_redis import JsonPlusRedisSerializer
from .types import IndexType, RedisClientType

logger = logging.getLogger(__name__)

REDIS_KEY_SEPARATOR = ":"
CHECKPOINT_PREFIX = "checkpoint"
CHECKPOINT_WRITE_PREFIX = "checkpoint_write"


class BaseRedisSaver(BaseCheckpointSaver[str], Generic[RedisClientType, IndexType]):
    """Base Redis implementation for checkpoint saving.

    Uses Redis JSON for storing checkpoints and related data, with
    RediSearch for querying.
    """

    # Shared state populated by subclasses via ``configure_client`` /
    # ``create_indexes``.
    _redis: RedisClientType
    _owns_its_client: bool = False
    _key_registry: Optional[Any] = None
    checkpoints_index: IndexType
    checkpoint_writes_index: IndexType

    def __init__(
        self,
        redis_url: Optional[str] = None,
        *,
        redis_client: Optional[RedisClientType] = None,
        connection_args: Optional[Dict[str, Any]] = None,
        ttl: Optional[Dict[str, Any]] = None,
        checkpoint_prefix: str = CHECKPOINT_PREFIX,
        checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
    ) -> None:
        """Initialize Redis-backed checkpoint saver.

        Args:
            redis_url: Redis connection URL
            redis_client: Redis client instance to use (alternative to redis_url)
            connection_args: Additional arguments for Redis connection
            ttl: Optional TTL configuration dict with optional keys:
                - default_ttl: TTL in minutes for all checkpoint keys
                - refresh_on_read: Whether to refresh TTL on reads
            checkpoint_prefix: Prefix for checkpoint keys (default: "checkpoint")
            checkpoint_write_prefix: Prefix for checkpoint write keys
                (default: "checkpoint_write")

        Raises:
            ValueError: If neither ``redis_url`` nor ``redis_client`` is given.
        """
        super().__init__(serde=JsonPlusRedisSerializer())
        if redis_url is None and redis_client is None:
            raise ValueError("Either redis_url or redis_client must be provided")

        # Store TTL configuration
        self.ttl_config = ttl

        # Store custom prefixes
        self._checkpoint_prefix = checkpoint_prefix
        self._checkpoint_write_prefix = checkpoint_write_prefix

        self.configure_client(
            redis_url=redis_url,
            redis_client=redis_client,
            connection_args=connection_args or {},
        )

        # Initialize indexes (created by the subclass implementation)
        self.checkpoints_index: IndexType
        self.checkpoint_writes_index: IndexType
        self.create_indexes()

    @property
    def checkpoints_schema(self) -> Dict[str, Any]:
        """Schema for the checkpoints index."""
        return {
            "index": {
                "name": self._checkpoint_prefix,
                "prefix": self._checkpoint_prefix + REDIS_KEY_SEPARATOR,
                "storage_type": "json",
            },
            "fields": [
                {"name": "thread_id", "type": "tag"},
                {"name": "run_id", "type": "tag"},
                {"name": "checkpoint_ns", "type": "tag"},
                {"name": "checkpoint_id", "type": "tag"},
                {"name": "parent_checkpoint_id", "type": "tag"},
                {"name": "checkpoint_ts", "type": "numeric"},
                {"name": "source", "type": "tag"},
                {"name": "step", "type": "numeric"},
                {"name": "has_writes", "type": "tag"},
            ],
        }

    @property
    def writes_schema(self) -> Dict[str, Any]:
        """Schema for the checkpoint writes index."""
        return {
            "index": {
                "name": self._checkpoint_write_prefix,
                "prefix": self._checkpoint_write_prefix + REDIS_KEY_SEPARATOR,
                "storage_type": "json",
            },
            "fields": [
                {"name": "thread_id", "type": "tag"},
                {"name": "checkpoint_ns", "type": "tag"},
                {"name": "checkpoint_id", "type": "tag"},
                {"name": "task_id", "type": "tag"},
                {"name": "idx", "type": "numeric"},
                {"name": "channel", "type": "tag"},
                {"name": "type", "type": "tag"},
            ],
        }
@abstractmethod
def create_indexes(self) -> None:
    """Create appropriate SearchIndex instances.

    Subclasses must populate ``self.checkpoints_index`` and
    ``self.checkpoint_writes_index``.
    """
    pass
@abstractmethod
def configure_client(
    self,
    redis_url: Optional[str] = None,
    redis_client: Optional[RedisClientType] = None,
    connection_args: Optional[Dict[str, Any]] = None,
) -> None:
    """Configure the Redis client.

    Subclasses must set ``self._redis`` from either an existing client
    or a connection URL plus extra connection arguments.
    """
    pass
def set_client_info(self) -> None:
    """Set client info for Redis monitoring.

    Best-effort: uses ``CLIENT SETINFO`` when the server/client support
    it, falls back to a plain ``ECHO``, and silently gives up otherwise.
    """
    from redis.exceptions import ResponseError

    from langgraph.checkpoint.redis.version import __full_lib_name__

    try:
        # Try to use client_setinfo command if available
        self._redis.client_setinfo("LIB-NAME", __full_lib_name__)
    except (ResponseError, AttributeError):
        # Fall back to a simple echo if client_setinfo is not available
        try:
            self._redis.echo(__full_lib_name__)
        except Exception:
            # Silently fail if even echo doesn't work
            pass
async def aset_client_info(self) -> None:
    """Set client info for Redis monitoring asynchronously.

    Async counterpart of :meth:`set_client_info`; also handles sync
    clients whose ``echo`` does not return an awaitable.
    """
    from redis.exceptions import ResponseError

    from langgraph.checkpoint.redis.version import __full_lib_name__

    try:
        # Try to use client_setinfo command if available
        await self._redis.client_setinfo("LIB-NAME", __full_lib_name__)
    except (ResponseError, AttributeError):
        # Fall back to a simple echo if client_setinfo is not available
        try:
            # Call with await to ensure it's an async call
            echo_result = self._redis.echo(__full_lib_name__)
            if hasattr(echo_result, "__await__"):
                await echo_result
        except Exception:
            # Silently fail if even echo doesn't work
            pass
[docs] def setup(self) -> None: """Initialize the indices in Redis.""" # Create indexes in Redis self.checkpoints_index.create(overwrite=False) self.checkpoint_writes_index.create(overwrite=False)
def _load_checkpoint(
    self,
    checkpoint: Union[Dict[str, Any], str],
    channel_values: Dict[str, Any],
    pending_sends: List[Any],
) -> Checkpoint:
    # Rehydrate a Checkpoint from its stored form: accepts either an
    # already-parsed dict or a raw JSON string, then attaches the
    # deserialized pending_sends and the provided channel_values.
    if not checkpoint:
        return {}

    # OPTIMIZED: Handle both dict and string inputs efficiently
    loaded = (
        checkpoint
        if isinstance(checkpoint, dict)
        else cast(dict, orjson.loads(checkpoint))
    )
    return {
        **loaded,
        "pending_sends": [
            self.serde.loads_typed((safely_decode(c), b))
            for c, b in pending_sends or []
        ],
        "channel_values": channel_values,
    }

def _apply_ttl_to_keys(
    self,
    main_key: str,
    related_keys: Optional[list[str]] = None,
    ttl_minutes: Optional[float] = None,
) -> Any:
    """Apply Redis native TTL to keys.

    Args:
        main_key: The primary Redis key
        related_keys: Additional Redis keys that should expire at the same time
        ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided
            Use -1 to remove TTL (make keys persistent)

    Returns:
        Result of the Redis operation
    """
    if ttl_minutes is None:
        # Check if there's a default TTL in config
        if self.ttl_config and "default_ttl" in self.ttl_config:
            ttl_minutes = self.ttl_config.get("default_ttl")

    if ttl_minutes is not None:
        # Special case: -1 means remove TTL (make persistent)
        if ttl_minutes == -1:
            # Apply PERSIST individually per key so that a single failure
            # does not prevent TTL removal on the remaining keys.
            all_keys = [main_key] + (related_keys or [])
            for key in all_keys:
                try:
                    self._redis.persist(key)
                except Exception:
                    logger.warning("Failed to remove TTL from key: %s", key)
            return True

        # Regular TTL setting
        ttl_seconds = int(ttl_minutes * 60)

        # Apply TTL individually per key so that a single EXPIRE failure
        # (e.g. MOVED on Redis Enterprise proxy) does not prevent TTL
        # from being set on the remaining keys.
        all_keys = [main_key] + (related_keys or [])
        for key in all_keys:
            try:
                self._redis.expire(key, ttl_seconds)
            except Exception:
                logger.warning("Failed to apply TTL to key: %s", key)
    # Always reports success; individual key failures are only logged.
    return True

def _dump_checkpoint(self, checkpoint: Checkpoint) -> dict[str, Any]:
    """Convert checkpoint to Redis format."""
    type_, data = self.serde.dumps_typed(checkpoint)

    # Decode the serialized data - handle both JSON and msgpack
    if type_ == "json":
        checkpoint_data = cast(dict, orjson.loads(data))
    else:
        # For msgpack or other types, deserialize with loads_typed
        checkpoint_data = cast(dict, self.serde.loads_typed((type_, data)))

    # When using msgpack, bytes are preserved - but Redis JSON.SET can't handle them
    # Encode bytes in channel_values with type marker for JSON storage
    if "channel_values" in checkpoint_data:
        for key, value in checkpoint_data["channel_values"].items():
            if isinstance(value, bytes):
                checkpoint_data["channel_values"][key] = {
                    "__bytes__": self._encode_blob(value)
                }

    # Ensure channel_versions are always strings to fix issue #40
    if "channel_versions" in checkpoint_data:
        checkpoint_data["channel_versions"] = {
            k: str(v) for k, v in checkpoint_data["channel_versions"].items()
        }

    return {"type": type_, **checkpoint_data, "pending_sends": []}

def _deserialize_channel_values(
    self, channel_values: dict[str, Any]
) -> dict[str, Any]:
    """Deserialize channel values that were stored inline.

    When channel values are stored inline in the checkpoint, they're in their
    serialized form. This method deserializes them back to their original types.

    This specifically handles LangChain message objects that may be stored
    in their serialized format:
    {'lc': 1, 'type': 'constructor', 'id': [...], 'kwargs': {...}}
    and ensures they are properly reconstructed as message objects.
    """
    if not channel_values:
        return {}

    try:
        # Apply recursive deserialization to handle nested structures and LangChain objects
        return self._recursive_deserialize(channel_values)
    except Exception as e:
        logger.warning(
            f"Error deserializing channel values, attempting recovery: {e}"
        )
        # Attempt to recover by processing each channel individually
        recovered = {}
        for key, value in channel_values.items():
            try:
                recovered[key] = self._recursive_deserialize(value)
            except Exception as inner_e:
                logger.error(
                    f"Failed to deserialize channel '{key}': {inner_e}. "
                    f"Value will be returned as-is."
                )
                recovered[key] = value
        return recovered

def _recursive_deserialize(self, obj: Any) -> Any:
    """Recursively deserialize LangChain objects and nested structures.

    This method specifically handles the deserialization of LangChain message
    objects that may be stored in their serialized format to prevent
    MESSAGE_COERCION_FAILURE.

    Args:
        obj: The object to deserialize, which may be a dict, list, or primitive.

    Returns:
        The deserialized object, with LangChain objects properly reconstructed.
    """
    if isinstance(obj, dict):
        # Check if this is a bytes marker from msgpack storage
        if "__bytes__" in obj and len(obj) == 1:
            # Decode base64-encoded bytes
            return self._decode_blob(obj["__bytes__"])

        # Check if this is a Send object marker (issue #94)
        if (
            obj.get("__send__") is True
            and "node" in obj
            and "arg" in obj
            and len(obj) == 3
        ):
            try:
                from langgraph.types import Send

                return Send(
                    node=obj["node"],
                    arg=self._recursive_deserialize(obj["arg"]),
                )
            except (ImportError, TypeError, ValueError) as e:
                # On failure, fall through to the generic dict handling below.
                logger.debug(
                    "Failed to deserialize Send object: %s", e, exc_info=True
                )

        # Check if this is a LangChain serialized object
        if obj.get("lc") in (1, 2) and obj.get("type") == "constructor":
            try:
                # Use the serde's reviver to reconstruct the object
                if hasattr(self.serde, "_revive_if_needed"):
                    return self.serde._revive_if_needed(obj)
                elif hasattr(self.serde, "_reviver"):
                    return self.serde._reviver(obj)
                else:
                    # Log warning if serde doesn't have reviver
                    logger.warning(
                        "Serializer does not have a reviver method. "
                        "LangChain object may not be properly deserialized. "
                        f"Object ID: {obj.get('id')}"
                    )
                    return obj
            except Exception as e:
                # Provide detailed error message for debugging
                obj_id = obj.get("id", "unknown")
                obj_type = (
                    obj.get("id", ["unknown"])[-1]
                    if isinstance(obj.get("id"), list)
                    else "unknown"
                )
                logger.error(
                    f"Failed to deserialize LangChain object of type '{obj_type}'. "
                    f"This may cause MESSAGE_COERCION_FAILURE. Error: {e}. "
                    f"Object structure: lc={obj.get('lc')}, type={obj.get('type')}, "
                    f"id={obj_id}"
                )
                # Return the object as-is to prevent complete failure
                return obj

        # Recursively process nested dicts
        return {k: self._recursive_deserialize(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        # Recursively process lists
        return [self._recursive_deserialize(item) for item in obj]
    else:
        # Return primitives as-is
        return obj

def _dump_writes(
    self,
    thread_id: str,
    checkpoint_ns: str,
    checkpoint_id: str,
    task_id: str,
    writes: Sequence[tuple[str, Any]],
) -> list[dict[str, Any]]:
    """Convert write operations for Redis storage."""
    return [
        {
            "thread_id": to_storage_safe_id(thread_id),
            "checkpoint_ns": to_storage_safe_str(checkpoint_ns),
            "checkpoint_id": to_storage_safe_id(checkpoint_id),
            "task_id": task_id,
            # Special channels map to fixed slots; others keep their position.
            "idx": WRITES_IDX_MAP.get(channel, idx),
            "channel": channel,
            "type": t,
            "blob": self._encode_blob(b),  # Encode bytes to base64 string for Redis
        }
        for idx, (channel, value) in enumerate(writes)
        for t, b in [self.serde.dumps_typed(value)]
    ]

def _load_metadata(self, metadata: dict[str, Any]) -> CheckpointMetadata:
    """Load metadata from Redis-compatible dictionary.

    Args:
        metadata: Dictionary representation from Redis.

    Returns:
        Original metadata dictionary.
    """
    # Roundtrip through serializer to ensure proper type handling
    type_str, data_bytes = self.serde.dumps_typed(metadata)
    return self.serde.loads_typed((type_str, data_bytes))

def _dump_metadata(self, metadata: CheckpointMetadata) -> str:
    """Convert metadata to a Redis-compatible dictionary.

    Args:
        metadata: Metadata to convert.

    Returns:
        Dictionary representation of metadata for Redis storage.
    """
    type_str, serialized_bytes = self.serde.dumps_typed(metadata)
    # NOTE: we're using JSON serializer (not msgpack), so we need to remove null characters before writing
    return serialized_bytes.decode().replace("\\u0000", "")
def get_next_version(  # type: ignore[override]
    self, current: Optional[str], channel: ChannelProtocol[Any, Any, Any]
) -> str:
    """Generate next version number.

    Versions are ``"<counter>.<random>"`` strings: a 32-digit
    zero-padded monotonically increasing counter, followed by a random
    float (16-char zero-padded) used as a tiebreaker.  ``channel`` is
    accepted for interface compatibility but unused.

    Args:
        current: Previous version string (or legacy int), or None for
            the first version.

    Returns:
        The next version string, lexicographically greater than ``current``.
    """
    if current is None:
        current_v = 0
    elif isinstance(current, int):
        # Legacy integer versions are accepted for backward compatibility.
        current_v = current
    else:
        current_v = int(current.split(".")[0])
    next_v = current_v + 1
    next_h = random.random()
    return f"{next_v:032}.{next_h:016}"
def _encode_blob(self, blob: Any) -> str: """Encode blob data for Redis storage.""" if isinstance(blob, bytes): return base64.b64encode(blob).decode() return blob def _decode_blob(self, blob: str) -> bytes: """Decode blob data from Redis storage.""" try: return base64.b64decode(blob) except (binascii.Error, TypeError): # Handle both malformed base64 data and incorrect input types return blob.encode() if isinstance(blob, str) else blob def _load_writes_from_redis(self, write_key: str) -> List[Tuple[str, str, Any]]: """Load writes from Redis JSON storage by key.""" if not write_key: return [] # Get the full JSON document # Cast needed: redis-py types json().get() as List[JsonType] but returns dict result = cast(Optional[Dict[str, Any]], self._redis.json().get(write_key)) if not result: return [] writes = [] for write in result["writes"]: writes.append( ( write["task_id"], write["channel"], self.serde.loads_typed( (write["type"], self._decode_blob(write["blob"])) ), ) ) return writes
def put_writes(
    self,
    config: RunnableConfig,
    writes: Sequence[tuple[str, Any]],
    task_id: str,
    task_path: str = "",
) -> None:
    """Store intermediate writes linked to a checkpoint.

    Args:
        config: Configuration of the related checkpoint.
        writes: List of writes to store, each as (channel, value) pair.
        task_id: Identifier for the task creating the writes.
        task_path: Optional path info for the task.
    """
    thread_id = config["configurable"]["thread_id"]
    checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
    checkpoint_id = config["configurable"]["checkpoint_id"]

    # Transform writes into appropriate format
    writes_objects = []
    for idx, (channel, value) in enumerate(writes):
        type_, blob = self.serde.dumps_typed(value)
        write_obj = {
            "thread_id": to_storage_safe_id(thread_id),
            "checkpoint_ns": to_storage_safe_str(checkpoint_ns),
            "checkpoint_id": to_storage_safe_id(checkpoint_id),
            "task_id": task_id,
            "task_path": task_path,
            "idx": WRITES_IDX_MAP.get(channel, idx),
            "channel": channel,
            "type": type_,
            "blob": self._encode_blob(
                blob
            ),  # Encode bytes to base64 string for Redis
        }
        writes_objects.append(write_obj)

    # UPSERT vs INSERT is decided once for the whole batch: when every
    # channel is a "special" write (in WRITES_IDX_MAP), existing records
    # are updated in place; otherwise records are only created when
    # missing.  Hoisted out of the loop below since it is loop-invariant.
    upsert_all = all(w[0] in WRITES_IDX_MAP for w in writes)

    # For each write, check existence and then perform appropriate operation
    with self._redis.json().pipeline(transaction=False) as pipeline:
        # Keep track of keys we're creating
        created_keys = []

        for write_obj in writes_objects:
            idx_value = write_obj["idx"]
            assert isinstance(idx_value, int)
            key = self._make_redis_checkpoint_writes_key(
                thread_id,
                checkpoint_ns,
                checkpoint_id,
                task_id,
                idx_value,
            )

            # First check if key exists
            key_exists = self._redis.exists(key) == 1

            if upsert_all:
                # UPSERT case - only update specific fields
                if key_exists:
                    # Update only channel, type, and blob fields
                    pipeline.json().set(key, "$.channel", write_obj["channel"])
                    pipeline.json().set(key, "$.type", write_obj["type"])
                    pipeline.json().set(key, "$.blob", write_obj["blob"])
                else:
                    # For new records, set the complete object
                    pipeline.json().set(key, "$", write_obj)
                    created_keys.append(key)
            else:
                # INSERT case - only insert if doesn't exist
                if not key_exists:
                    pipeline.json().set(key, "$", write_obj)
                    created_keys.append(key)

        pipeline.execute()

    # Apply TTL to newly created keys
    if created_keys and self.ttl_config and "default_ttl" in self.ttl_config:
        self._apply_ttl_to_keys(
            created_keys[0], created_keys[1:] if len(created_keys) > 1 else None
        )

    # Update checkpoint to indicate it has writes
    if writes_objects:
        checkpoint_key = self._make_redis_checkpoint_key(
            to_storage_safe_id(thread_id),
            to_storage_safe_str(checkpoint_ns),
            to_storage_safe_id(checkpoint_id),
        )

        # Check if the checkpoint exists before updating
        if self._redis.exists(checkpoint_key):
            # JSON.SET can add new fields at non-root paths for existing documents
            # Use JSONPath $ to update at root level
            self._redis.json().set(checkpoint_key, "$.has_writes", True)
def _load_pending_writes( self, thread_id: str, checkpoint_ns: str, checkpoint_id: str ) -> List[PendingWrite]: if checkpoint_id is None: return [] # Early return if no checkpoint_id # Most checkpoints don't have writes, return empty list quickly # Quick check: see if write registry exists and has any keys write_registry_key = self._key_registry.make_write_keys_zset_key( thread_id, checkpoint_ns, checkpoint_id ) registry_exists = self._redis.exists(write_registry_key) if not registry_exists: # No writes registry means no writes return [] # Use search index instead of keys() to avoid CrossSlot errors # Note: All tag fields use sentinel values for consistency writes_query = FilterQuery( filter_expression=(Tag("thread_id") == to_storage_safe_id(thread_id)) & (Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns)) & (Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id)), return_fields=["task_id", "idx", "channel", "type", "$.blob"], num_results=1000, # Adjust as needed ) writes_results = self.checkpoint_writes_index.search(writes_query) # Sort results by idx to maintain order sorted_writes = sorted(writes_results.docs, key=lambda x: getattr(x, "idx", 0)) # Build the writes dictionary writes_dict: Dict[Tuple[str, str], Dict[str, Any]] = {} for doc in sorted_writes: task_id = str(getattr(doc, "task_id", "")) idx = str(getattr(doc, "idx", 0)) blob_data = getattr(doc, "$.blob", "") # Ensure blob is bytes for deserialization if isinstance(blob_data, str): blob_data = blob_data.encode("utf-8") writes_dict[(task_id, idx)] = { "task_id": task_id, "idx": idx, "channel": str(getattr(doc, "channel", "")), "type": str(getattr(doc, "type", "")), "blob": blob_data, } pending_writes = BaseRedisSaver._load_writes(self.serde, writes_dict) return pending_writes def _load_pending_writes_with_registry_check( self, thread_id: str, checkpoint_ns: str, checkpoint_id: str, checkpoint_has_writes: bool, registry_has_writes: bool, ) -> List[PendingWrite]: """Load pending writes with 
pre-computed registry check to avoid duplicate Redis calls.""" if checkpoint_id is None: return [] # Early return if no checkpoint_id # Pre-computed registry check instead of making another Redis call if not registry_has_writes: # No writes in registry means no writes to load return [] # Also check checkpoint-level has_writes flag for additional optimization if not checkpoint_has_writes: return [] # Fallback to original FT.SEARCH logic since registry indicates writes exist # Use search index instead of keys() to avoid CrossSlot errors # Note: All tag fields use sentinel values for consistency writes_query = FilterQuery( filter_expression=(Tag("thread_id") == to_storage_safe_id(thread_id)) & (Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns)) & (Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id)), return_fields=["task_id", "idx", "channel", "type", "$.blob"], num_results=1000, # Adjust as needed ) writes_results = self.checkpoint_writes_index.search(writes_query) # Sort results by idx to maintain order sorted_writes = sorted(writes_results.docs, key=lambda x: getattr(x, "idx", 0)) # Build the writes dictionary writes_dict: Dict[Tuple[str, str], Dict[str, Any]] = {} for doc in sorted_writes: task_id = str(getattr(doc, "task_id", "")) idx = str(getattr(doc, "idx", 0)) blob_data = getattr(doc, "$.blob", "") # Ensure blob is bytes for deserialization if isinstance(blob_data, str): blob_data = blob_data.encode("utf-8") writes_dict[(task_id, idx)] = { "task_id": task_id, "idx": idx, "channel": str(getattr(doc, "channel", "")), "type": str(getattr(doc, "type", "")), "blob": blob_data, } pending_writes = BaseRedisSaver._load_writes(self.serde, writes_dict) return pending_writes @staticmethod def _load_writes( serde: SerializerProtocol, task_id_to_data: dict[tuple[str, str], dict] ) -> list[PendingWrite]: """Deserialize pending writes.""" writes = [ ( task_id, data["channel"], serde.loads_typed( (data["type"], BaseRedisSaver._decode_blob_static(data["blob"])) 
), ) for (task_id, _), data in task_id_to_data.items() ] return writes @staticmethod def _decode_blob_static(blob: bytes | str) -> bytes: """Decode blob data from Redis storage (static method).""" try: # If it's already bytes, try to decode as base64 if isinstance(blob, bytes): return base64.b64decode(blob) # If it's a string, encode to bytes first then decode return base64.b64decode(blob.encode("utf-8")) except (binascii.Error, TypeError, ValueError): # Handle both malformed base64 data and incorrect input types return blob.encode("utf-8") if isinstance(blob, str) else blob @staticmethod def _parse_redis_checkpoint_writes_key(redis_key: str) -> dict: # Ensure redis_key is a string redis_key = safely_decode(redis_key) parts = redis_key.split(REDIS_KEY_SEPARATOR) # Ensure we have at least 6 parts if len(parts) < 6: raise ValueError( f"Expected at least 6 parts in Redis key, got {len(parts)}" ) # Extract the first 6 parts regardless of total length namespace, thread_id, checkpoint_ns, checkpoint_id, task_id, idx = parts[:6] if namespace != CHECKPOINT_WRITE_PREFIX: raise ValueError("Expected checkpoint key to start with 'checkpoint'") return { "thread_id": to_storage_safe_str(thread_id), "checkpoint_ns": to_storage_safe_str(checkpoint_ns), "checkpoint_id": to_storage_safe_str(checkpoint_id), "task_id": task_id, "idx": idx, } def _make_redis_checkpoint_key( self, thread_id: str, checkpoint_ns: str, checkpoint_id: str ) -> str: return REDIS_KEY_SEPARATOR.join( [ self._checkpoint_prefix, str(to_storage_safe_id(thread_id)), to_storage_safe_str(checkpoint_ns), str(to_storage_safe_id(checkpoint_id)), ] ) def _make_redis_checkpoint_writes_key( self, thread_id: str, checkpoint_ns: str, checkpoint_id: str, task_id: str, idx: Optional[int], ) -> str: storage_safe_thread_id = str(to_storage_safe_id(thread_id)) storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns) storage_safe_checkpoint_id = str(to_storage_safe_id(checkpoint_id)) if idx is None: return 
REDIS_KEY_SEPARATOR.join( [ self._checkpoint_write_prefix, storage_safe_thread_id, storage_safe_checkpoint_ns, storage_safe_checkpoint_id, task_id, ] ) return REDIS_KEY_SEPARATOR.join( [ self._checkpoint_write_prefix, storage_safe_thread_id, storage_safe_checkpoint_ns, storage_safe_checkpoint_id, task_id, str(idx), ] )