# Source code for langgraph.checkpoint.redis

from __future__ import annotations

import json
import logging
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union, cast

import orjson
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
    WRITES_IDX_MAP,
    ChannelVersions,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
    PendingWrite,
    get_checkpoint_id,
)
from langgraph.constants import TASKS
from redis import Redis
from redis.cluster import RedisCluster
from redisvl.index import SearchIndex
from redisvl.query import FilterQuery
from redisvl.query.filter import Num, Tag
from redisvl.redis.connection import RedisConnectionFactory
from ulid import ULID

from langgraph.checkpoint.redis.aio import AsyncRedisSaver
from langgraph.checkpoint.redis.ashallow import AsyncShallowRedisSaver
from langgraph.checkpoint.redis.base import (
    CHECKPOINT_PREFIX,
    CHECKPOINT_WRITE_PREFIX,
    REDIS_KEY_SEPARATOR,
    BaseRedisSaver,
)
from langgraph.checkpoint.redis.key_registry import SyncCheckpointKeyRegistry
from langgraph.checkpoint.redis.message_exporter import (
    LangChainRecipe,
    MessageExporter,
    MessageRecipe,
)
from langgraph.checkpoint.redis.shallow import ShallowRedisSaver
from langgraph.checkpoint.redis.util import (
    EMPTY_ID_SENTINEL,
    from_storage_safe_id,
    from_storage_safe_str,
    to_storage_safe_id,
    to_storage_safe_str,
)
from langgraph.checkpoint.redis.version import __lib_name__, __version__

logger = logging.getLogger(__name__)


class RedisSaver(BaseRedisSaver[Union[Redis, RedisCluster], SearchIndex]):
    """Standard Redis implementation for checkpoint saving.

    Supports standard Redis URLs (redis://), SSL (rediss://), and Sentinel
    URLs (redis+sentinel://host:26379/service_name/db).
    """

    # Support both standalone and cluster clients
    _redis: Union[Redis, RedisCluster]

    # Whether to assume the Redis server is a cluster; None triggers auto-detection
    cluster_mode: Optional[bool] = None

    def __init__(
        self,
        redis_url: Optional[str] = None,
        *,
        redis_client: Optional[Union[Redis, RedisCluster]] = None,
        connection_args: Optional[Dict[str, Any]] = None,
        ttl: Optional[Dict[str, Any]] = None,
        checkpoint_prefix: str = CHECKPOINT_PREFIX,
        checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
    ) -> None:
        """Create a saver bound to either ``redis_url`` or a pre-built client.

        Args:
            redis_url: Connection URL; used when ``redis_client`` is omitted.
            redis_client: Pre-configured client (standalone or cluster).
            connection_args: Extra kwargs forwarded to the connection factory.
            ttl: TTL configuration dict (e.g. ``default_ttl`` in minutes,
                ``refresh_on_read``) — consumed by BaseRedisSaver and the
                TTL helpers below.
            checkpoint_prefix: Key prefix for checkpoint documents.
            checkpoint_write_prefix: Key prefix for checkpoint-write documents.
        """
        super().__init__(
            redis_url=redis_url,
            redis_client=redis_client,
            connection_args=connection_args,
            ttl=ttl,
            checkpoint_prefix=checkpoint_prefix,
            checkpoint_write_prefix=checkpoint_write_prefix,
        )
        # Prefixes are now set in BaseRedisSaver.__init__
        self._separator = REDIS_KEY_SEPARATOR
        # Instance-level cache for frequently used keys (limited size to prevent memory issues)
        self._key_cache: Dict[str, str] = {}
        self._key_cache_max_size = 1000  # Configurable limit
        # Key registry will be initialized in setup()
        self._key_registry: Optional[SyncCheckpointKeyRegistry] = None
    def configure_client(
        self,
        redis_url: Optional[str] = None,
        redis_client: Optional[Union[Redis, RedisCluster]] = None,
        connection_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Configure the Redis client.

        Supports standard Redis URLs (redis://), SSL (rediss://), and
        Sentinel URLs (redis+sentinel://host:26379/service_name/db).

        Args:
            redis_url: URL used to create a client when ``redis_client`` is
                not supplied.
            redis_client: Pre-built client; when supplied, this saver marks
                itself as not owning the connection (it won't close it).
            connection_args: Extra kwargs forwarded to the connection factory.
                NOTE(review): unpacked with ``**`` — presumably the caller
                (BaseRedisSaver) always passes a dict; a ``None`` here with
                no ``redis_client`` would raise. Confirm against the base.
        """
        from redis.exceptions import ResponseError

        from langgraph.checkpoint.redis.version import __full_lib_name__

        self._owns_its_client = redis_client is None
        self._redis = redis_client or RedisConnectionFactory.get_redis_connection(
            redis_url, **connection_args
        )

        # Set client info for Redis monitoring
        try:
            self._redis.client_setinfo("LIB-NAME", __full_lib_name__)
        except (ResponseError, AttributeError):
            # Fall back to a simple echo if client_setinfo is not available
            try:
                self._redis.echo(__full_lib_name__)
            except Exception:
                # Silently fail if even echo doesn't work
                pass
    def create_indexes(self) -> None:
        """Construct the RediSearch index wrappers for checkpoints and writes.

        Builds both SearchIndex objects from the schemas defined on the base
        class, bound to this saver's Redis client.
        """
        self.checkpoints_index = SearchIndex.from_dict(
            self.checkpoints_schema, redis_client=self._redis
        )
        self.checkpoint_writes_index = SearchIndex.from_dict(
            self.writes_schema, redis_client=self._redis
        )
def _make_redis_checkpoint_key_cached( self, thread_id: str, checkpoint_ns: str, checkpoint_id: str ) -> str: """Optimized key generation with caching.""" # Create cache key cache_key = f"ckpt:{thread_id}:{checkpoint_ns}:{checkpoint_id}" # Check cache first if cache_key in self._key_cache: return self._key_cache[cache_key] # Generate key using optimized string operations safe_thread_id = str(to_storage_safe_id(thread_id)) safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns) safe_checkpoint_id = str(to_storage_safe_id(checkpoint_id)) # Use pre-computed prefix and join key = self._separator.join( [ self._checkpoint_prefix, safe_thread_id, safe_checkpoint_ns, safe_checkpoint_id, ] ) # Cache for future use (limit cache size to prevent memory issues) if len(self._key_cache) < self._key_cache_max_size: self._key_cache[cache_key] = key return key def _make_redis_checkpoint_writes_key_cached( self, thread_id: str, checkpoint_ns: str, checkpoint_id: str, task_id: str, idx: Optional[int], ) -> str: """Optimized writes key generation with caching.""" # Create cache key cache_key = f"write:{thread_id}:{checkpoint_ns}:{checkpoint_id}:{task_id}:{idx}" # Check cache first if cache_key in self._key_cache: return self._key_cache[cache_key] # Generate key using optimized string operations safe_thread_id = str(to_storage_safe_id(thread_id)) safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns) safe_checkpoint_id = str(to_storage_safe_id(checkpoint_id)) # Build key components key_parts = [ self._checkpoint_write_prefix, safe_thread_id, safe_checkpoint_ns, safe_checkpoint_id, task_id, ] if idx is not None: key_parts.append(str(idx)) key = self._separator.join(key_parts) # Cache for future use (limit cache size) if len(self._key_cache) < self._key_cache_max_size: self._key_cache[cache_key] = key return key
    def setup(self) -> None:
        """Initialize the indices in Redis and detect cluster mode.

        Also creates the per-instance key registry used for fast
        write-key lookups (only once a client is available).
        """
        self._detect_cluster_mode()
        super().setup()

        # Initialize key registry for this instance
        if self._redis and not self._key_registry:
            self._key_registry = SyncCheckpointKeyRegistry(self._redis)
def _detect_cluster_mode(self) -> None: """Detect if the Redis client is a cluster client by inspecting its class.""" if self.cluster_mode is not None: logger.info( f"Redis cluster_mode explicitly set to {self.cluster_mode}, skipping detection." ) return # Determine cluster mode based on client class if isinstance(self._redis, RedisCluster): logger.info("Redis client is a cluster client") self.cluster_mode = True else: logger.info("Redis client is a standalone client") self.cluster_mode = False
    def list(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> Iterator[CheckpointTuple]:
        """List checkpoints from Redis.

        Args:
            config: Optional config whose ``configurable`` keys (thread_id,
                run_id, checkpoint_ns, checkpoint_id) narrow the search.
            filter: Metadata filter; supports ``source``, ``step``,
                ``thread_id`` and ``run_id`` keys only.
            before: Only yield checkpoints whose ULID timestamp precedes
                this config's checkpoint_id (ignored if not a valid ULID).
            limit: Maximum number of results (defaults to 10000).

        Yields:
            CheckpointTuple for each matching checkpoint, most recent first.
        """
        # Construct the filter expression
        filter_expression = []
        if config:
            filter_expression.append(
                Tag("thread_id")
                == to_storage_safe_id(config["configurable"]["thread_id"])
            )
            if run_id := config["configurable"].get("run_id"):
                filter_expression.append(Tag("run_id") == to_storage_safe_id(run_id))

            # Search for checkpoints with any namespace, including an empty
            # string, while `checkpoint_id` has to have a value.
            if checkpoint_ns := config["configurable"].get("checkpoint_ns"):
                filter_expression.append(
                    Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns)
                )
            if checkpoint_id := get_checkpoint_id(config):
                filter_expression.append(
                    Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id)
                )

        if filter:
            for k, v in filter.items():
                if k == "source":
                    filter_expression.append(Tag("source") == v)
                elif k == "step":
                    filter_expression.append(Num("step") == v)
                elif k == "thread_id":
                    filter_expression.append(Tag("thread_id") == to_storage_safe_id(v))
                elif k == "run_id":
                    filter_expression.append(Tag("run_id") == to_storage_safe_id(v))
                else:
                    raise ValueError(f"Unsupported filter key: {k}")

        if before:
            before_checkpoint_id = get_checkpoint_id(before)
            if before_checkpoint_id:
                try:
                    before_ulid = ULID.from_str(before_checkpoint_id)
                    before_ts = before_ulid.timestamp
                    # Use numeric range query: checkpoint_ts < before_ts
                    filter_expression.append(Num("checkpoint_ts") < before_ts)
                except Exception:
                    # If not a valid ULID, ignore the before filter
                    pass

        # Combine all filter expressions
        combined_filter = filter_expression[0] if filter_expression else "*"
        for expr in filter_expression[1:]:
            combined_filter &= expr

        # Construct the Redis query
        # Sort by checkpoint_id in descending order to get most recent checkpoints first
        query = FilterQuery(
            filter_expression=combined_filter,
            return_fields=[
                "thread_id",
                "checkpoint_ns",
                "checkpoint_id",
                "parent_checkpoint_id",
                "$.checkpoint",
                "$.metadata",
                "has_writes",  # Include has_writes to optimize pending_writes loading
            ],
            num_results=limit or 10000,
            sort_by=("checkpoint_id", "DESC"),
        )

        # Execute the query
        results = self.checkpoints_index.search(query)

        # Pre-process all docs to collect batch query requirements
        all_docs_data = []
        pending_sends_batch_keys = []
        pending_writes_batch_keys = []

        for doc in results.docs:
            # Extract all attributes once
            doc_dict = doc.__dict__ if hasattr(doc, "__dict__") else {}

            thread_id = from_storage_safe_id(doc["thread_id"])
            checkpoint_ns = from_storage_safe_str(doc["checkpoint_ns"])
            checkpoint_id = from_storage_safe_id(doc["checkpoint_id"])
            parent_checkpoint_id = from_storage_safe_id(doc["parent_checkpoint_id"])

            # Get channel values from inline checkpoint data (already returned by FT.SEARCH)
            checkpoint_data = doc_dict.get("$.checkpoint") or getattr(
                doc, "$.checkpoint", None
            )

            if checkpoint_data:
                # Parse checkpoint to extract inline channel_values
                if isinstance(checkpoint_data, list) and checkpoint_data:
                    checkpoint_data = checkpoint_data[0]
                # Use orjson for faster parsing
                checkpoint_dict = (
                    checkpoint_data
                    if isinstance(checkpoint_data, dict)
                    else orjson.loads(checkpoint_data)
                )
                channel_values = checkpoint_dict.get("channel_values", {})
            else:
                # If checkpoint data is missing, the document is corrupted
                # Set empty channel values rather than attempting a fallback
                channel_values = {}

            # Collect batch keys for pending_sends
            if parent_checkpoint_id and parent_checkpoint_id != "None":
                batch_key = (thread_id, checkpoint_ns, parent_checkpoint_id)
                pending_sends_batch_keys.append(batch_key)

            # Collect batch keys for pending_writes
            checkpoint_has_writes = doc_dict.get("has_writes") or getattr(
                doc, "has_writes", False
            )
            # Convert string "False" to boolean false if needed (optimize for common case)
            if checkpoint_has_writes == "true":
                checkpoint_has_writes = True
            elif checkpoint_has_writes == "false" or checkpoint_has_writes == "False":
                checkpoint_has_writes = False

            if checkpoint_has_writes:
                batch_key = (thread_id, checkpoint_ns, checkpoint_id)
                pending_writes_batch_keys.append(batch_key)

            # Store processed doc data for final iteration.
            # NOTE: `checkpoint_dict if checkpoint_data else None` is safe even
            # on the first iteration — the false branch never evaluates the name.
            all_docs_data.append(
                {
                    "doc": doc,
                    "doc_dict": doc_dict,
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    "checkpoint_id": checkpoint_id,
                    "parent_checkpoint_id": parent_checkpoint_id,
                    "checkpoint_data": checkpoint_data,
                    "checkpoint_dict": checkpoint_dict if checkpoint_data else None,
                    "channel_values": channel_values,
                    "has_writes": checkpoint_has_writes,
                }
            )

        # Load pending_sends for all parent checkpoints at once
        pending_sends_map = {}
        if pending_sends_batch_keys:
            pending_sends_map = self._batch_load_pending_sends(pending_sends_batch_keys)

        # Load pending_writes for all checkpoints with writes at once
        pending_writes_map = {}
        if pending_writes_batch_keys:
            pending_writes_map = self._batch_load_pending_writes(
                pending_writes_batch_keys
            )

        # Process the results using pre-loaded batch data
        for doc_data in all_docs_data:
            thread_id = doc_data["thread_id"]
            checkpoint_ns = doc_data["checkpoint_ns"]
            checkpoint_id = doc_data["checkpoint_id"]
            parent_checkpoint_id = doc_data["parent_checkpoint_id"]

            # Get pending_sends from batch results
            pending_sends: List[Tuple[str, bytes]] = []
            if parent_checkpoint_id:
                batch_key = (thread_id, checkpoint_ns, parent_checkpoint_id)
                pending_sends = pending_sends_map.get(batch_key, [])

            # Fetch and parse metadata
            doc_dict = doc_data["doc_dict"]
            raw_metadata = doc_dict.get("$.metadata") or getattr(
                doc_data["doc"], "$.metadata", "{}"
            )
            # Use orjson for faster parsing
            metadata_dict = (
                orjson.loads(raw_metadata)
                if isinstance(raw_metadata, str)
                else raw_metadata
            )

            # Only sanitize if null bytes detected (rare case)
            if any(
                "\u0000" in str(v)
                for v in metadata_dict.values()
                if isinstance(v, str)
            ):
                sanitized_metadata = {
                    k.replace("\u0000", ""): (
                        v.replace("\u0000", "") if isinstance(v, str) else v
                    )
                    for k, v in metadata_dict.items()
                }
                metadata = cast(CheckpointMetadata, sanitized_metadata)
            else:
                metadata = cast(CheckpointMetadata, metadata_dict)

            # Pre-create the config structure more efficiently
            config_param: RunnableConfig = {
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    "checkpoint_id": checkpoint_id,
                }
            }

            # Pass already parsed checkpoint_dict to avoid re-parsing
            checkpoint_param = self._load_checkpoint(
                (
                    doc_data["checkpoint_dict"]
                    if doc_data["checkpoint_data"]
                    else doc_data["doc"]["$.checkpoint"]
                ),
                doc_data["channel_values"],
                pending_sends,
            )

            # Get pending_writes from batch results
            pending_writes: List[PendingWrite] = []
            if doc_data["has_writes"]:
                batch_key = (thread_id, checkpoint_ns, checkpoint_id)
                pending_writes = pending_writes_map.get(batch_key, [])

            # Build parent config if parent_checkpoint_id exists
            parent_config: RunnableConfig | None = None
            if parent_checkpoint_id:
                parent_config = {
                    "configurable": {
                        "thread_id": thread_id,
                        "checkpoint_ns": checkpoint_ns,
                        "checkpoint_id": parent_checkpoint_id,
                    }
                }

            yield CheckpointTuple(
                config=config_param,
                checkpoint=checkpoint_param,
                metadata=metadata,
                parent_config=parent_config,
                pending_writes=pending_writes,
            )
    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Store a checkpoint to Redis with inline channel value storage.

        Args:
            config: Config carrying thread_id / checkpoint_ns and, when saving
                a child checkpoint, the parent's checkpoint_id.
            checkpoint: Checkpoint payload (channel_values stored inline).
            metadata: Checkpoint metadata; ``source``/``step`` are duplicated
                at top level so list() can filter on them.
            new_versions: Channel versions (unused here; kept for interface
                compatibility with the base saver).

        Returns:
            The config pointing at the newly stored checkpoint.
        """
        configurable = config["configurable"].copy()

        run_id = configurable.pop("run_id", metadata.get("run_id"))
        thread_id = configurable.pop("thread_id")
        checkpoint_ns = configurable.pop("checkpoint_ns")

        # Get checkpoint_id from config - this will be parent if saving a child
        config_checkpoint_id = configurable.pop("checkpoint_id", None)

        # For backward compatibility with thread_ts
        thread_ts = configurable.pop("thread_ts", "")

        # Determine the checkpoint ID
        checkpoint_id = config_checkpoint_id or thread_ts or checkpoint.get("id", "")

        # If checkpoint has its own ID that's different from what we'd use,
        # and we have a config checkpoint_id, then config checkpoint_id is the parent
        parent_checkpoint_id = None
        if (
            checkpoint.get("id")
            and config_checkpoint_id
            and checkpoint.get("id") != config_checkpoint_id
        ):
            parent_checkpoint_id = config_checkpoint_id
            checkpoint_id = checkpoint["id"]

        # Convert empty strings to the sentinel value.
        storage_safe_thread_id = to_storage_safe_id(thread_id)
        storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
        storage_safe_checkpoint_id = to_storage_safe_id(checkpoint_id)

        copy = checkpoint.copy()

        # When we return the config, we need to preserve empty strings that
        # were passed in, instead of the sentinel value.
        next_config = {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint_id,
            }
        }

        # Extract timestamp from checkpoint_id (ULID)
        checkpoint_ts = None
        if checkpoint_id:
            try:
                from ulid import ULID

                ulid_obj = ULID.from_str(checkpoint_id)
                checkpoint_ts = ulid_obj.timestamp  # milliseconds since epoch
            except Exception:
                # If not a valid ULID, use current time
                import time

                checkpoint_ts = time.time() * 1000

        checkpoint_data = {
            "thread_id": storage_safe_thread_id,
            "run_id": to_storage_safe_id(run_id) if run_id else "",
            "checkpoint_ns": storage_safe_checkpoint_ns,
            "checkpoint_id": storage_safe_checkpoint_id,
            "parent_checkpoint_id": (
                to_storage_safe_id(parent_checkpoint_id) if parent_checkpoint_id else ""
            ),
            "checkpoint_ts": checkpoint_ts,
            "checkpoint": self._dump_checkpoint(copy),  # Includes channel_values inline
            "metadata": self._dump_metadata(metadata),
            "has_writes": False,  # Track if this checkpoint has pending writes
        }

        # Store at top-level for filters in list()
        if all(key in metadata for key in ["source", "step"]):
            checkpoint_data["source"] = metadata["source"]
            checkpoint_data["step"] = metadata["step"]

        # Create the checkpoint key
        checkpoint_key = self._make_redis_checkpoint_key_cached(
            thread_id,
            checkpoint_ns,
            checkpoint_id,
        )

        # Calculate TTL in seconds if configured
        ttl_seconds = None
        if self.ttl_config and "default_ttl" in self.ttl_config:
            ttl_seconds = int(self.ttl_config["default_ttl"] * 60)

        # Store checkpoint with TTL in a single pipeline operation
        self.checkpoints_index.load(
            [checkpoint_data],
            keys=[checkpoint_key],
            ttl=ttl_seconds,  # RedisVL applies TTL in its internal pipeline
        )

        # Update latest checkpoint pointer
        latest_pointer_key = (
            f"checkpoint_latest:{storage_safe_thread_id}:{storage_safe_checkpoint_ns}"
        )
        self._redis.set(latest_pointer_key, checkpoint_key)

        # Apply TTL to latest pointer key as well (best-effort)
        if ttl_seconds is not None:
            try:
                self._redis.expire(latest_pointer_key, ttl_seconds)
            except Exception:
                logger.warning(
                    "Failed to apply TTL to latest pointer key: %s",
                    latest_pointer_key,
                )

        return next_config
    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Store intermediate writes linked to a checkpoint with integrated key registry.

        Args:
            config: Config identifying the target checkpoint
                (thread_id / checkpoint_ns / checkpoint_id).
            writes: Sequence of (channel, value) pairs to persist.
            task_id: ID of the task producing the writes.
            task_path: Optional task path stored alongside each write.
        """
        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
        checkpoint_id = config["configurable"]["checkpoint_id"]

        # Transform writes into appropriate format
        writes_objects = []
        for idx, (channel, value) in enumerate(writes):
            type_, blob = self.serde.dumps_typed(value)
            write_obj = {
                "thread_id": to_storage_safe_id(thread_id),
                "checkpoint_ns": to_storage_safe_str(checkpoint_ns),
                "checkpoint_id": to_storage_safe_id(checkpoint_id),
                "task_id": task_id,
                "task_path": task_path,
                # Special channels get a fixed negative index so repeated puts
                # overwrite the same key instead of appending.
                "idx": WRITES_IDX_MAP.get(channel, idx),
                "channel": channel,
                "type": type_,
                "blob": self._encode_blob(
                    blob
                ),  # Encode bytes to base64 string for Redis
            }
            writes_objects.append(write_obj)

        # IMPORTANT: Only critical commands (JSON.SET, JSON.MERGE) go in the pipeline.
        # EXPIRE (TTL) commands are applied separately afterward to avoid pipeline
        # failures on Redis Enterprise proxy, where mixed JSON module + native commands
        # in a single pipeline can cause EXPIRE to fail, aborting the entire pipeline
        # and losing interrupt writes.
        write_keys: list[str] = []
        checkpoint_key = ""
        merge_failed = False

        with self._redis.pipeline(transaction=False) as pipeline:
            for write_obj in writes_objects:
                idx_value = write_obj["idx"]
                assert isinstance(idx_value, int)
                key = self._make_redis_checkpoint_writes_key_cached(
                    thread_id,
                    checkpoint_ns,
                    checkpoint_id,
                    task_id,
                    idx_value,
                )
                write_keys.append(key)
                pipeline.json().set(key, "$", cast(Any, write_obj))

            # Update checkpoint to indicate it has writes (critical)
            if writes_objects:
                checkpoint_key = self._make_redis_checkpoint_key_cached(
                    thread_id, checkpoint_ns, checkpoint_id
                )
                pipeline.json().merge(checkpoint_key, "$", {"has_writes": True})

            # Execute critical commands with raise_on_error=False
            results = pipeline.execute(raise_on_error=False)

            # Check results for critical command failures
            for result in results:
                if isinstance(result, Exception):
                    err_str = str(result)
                    if "JSON.MERGE" in err_str or "merge" in err_str.lower():
                        # JSON.MERGE unsupported (older Redis) — handled below.
                        merge_failed = True
                    else:
                        raise result

        # Handle JSON.MERGE fallback for older Redis versions
        if merge_failed and checkpoint_key:
            try:
                checkpoint_data = self._redis.json().get(checkpoint_key)
                if isinstance(checkpoint_data, dict) and not checkpoint_data.get(
                    "has_writes"
                ):
                    checkpoint_data["has_writes"] = True
                    self._redis.json().set(checkpoint_key, "$", checkpoint_data)
            except Exception:
                # Best-effort flag update; the writes themselves are already stored.
                pass

        # Apply TTL separately (best-effort — failures here don't lose writes).
        # Individual calls ensure partial success: if one key's EXPIRE fails
        # on RE proxy, the others still get TTL applied.
        if write_keys and self.ttl_config and "default_ttl" in self.ttl_config:
            ttl_seconds = int(self.ttl_config["default_ttl"] * 60)
            for key in write_keys:
                try:
                    self._redis.expire(key, ttl_seconds)
                except Exception:
                    logger.warning(
                        "Failed to apply TTL to checkpoint write key: %s", key
                    )

        # Update key registry with the write keys
        if self._key_registry and write_keys:
            self._key_registry.register_write_keys_batch(
                thread_id, checkpoint_ns, checkpoint_id, write_keys
            )

            # Apply TTL to registry key (already best-effort inside apply_ttl)
            if self.ttl_config and "default_ttl" in self.ttl_config:
                ttl_seconds = int(self.ttl_config["default_ttl"] * 60)
                self._key_registry.apply_ttl(
                    thread_id, checkpoint_ns, checkpoint_id, ttl_seconds
                )
    def _get_checkpoint_document_by_id(
        self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
    ) -> Optional[dict]:
        """Get checkpoint document by specific ID using direct key access.

        Returns a search-result-shaped dict (``$.checkpoint`` / ``$.metadata``
        keys) so callers can treat direct-access and FT.SEARCH results
        uniformly, or None when the key is absent or malformed.
        """
        checkpoint_key = self._make_redis_checkpoint_key_cached(
            thread_id, checkpoint_ns, checkpoint_id
        )
        checkpoint_data = self._redis.json().get(checkpoint_key)
        if not checkpoint_data or not isinstance(checkpoint_data, dict):
            return None

        # Extract the actual checkpoint data
        checkpoint_inner = checkpoint_data.get("checkpoint", {})

        return {
            "thread_id": checkpoint_data.get(
                "thread_id", to_storage_safe_id(thread_id)
            ),
            "checkpoint_ns": checkpoint_data.get(
                "checkpoint_ns", to_storage_safe_str(checkpoint_ns)
            ),
            "checkpoint_id": checkpoint_data.get(
                "checkpoint_id", to_storage_safe_id(checkpoint_id)
            ),
            # NOTE(review): the default here is the checkpoint's OWN id, not "" —
            # looks suspicious for a *parent* field; confirm intended fallback.
            "parent_checkpoint_id": checkpoint_data.get(
                "parent_checkpoint_id", to_storage_safe_id(checkpoint_id)
            ),
            "$.checkpoint": (
                json.dumps(checkpoint_inner)
                if isinstance(checkpoint_inner, dict)
                else checkpoint_inner
            ),
            "$.metadata": checkpoint_data.get("metadata", "{}"),
            "_channel_versions": (
                checkpoint_inner.get("channel_versions")
                if isinstance(checkpoint_inner, dict)
                else None
            ),
            "has_writes": checkpoint_data.get("has_writes", False),
        }

    def _get_latest_checkpoint_document(
        self, thread_id: str, checkpoint_ns: str
    ) -> Optional[dict]:
        """Get latest checkpoint document using pointer.

        Follows the ``checkpoint_latest:*`` pointer key written by put();
        returns None if the pointer or the checkpoint it references is missing.
        """
        storage_safe_thread_id = to_storage_safe_id(thread_id)
        storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)

        # Get latest checkpoint using pointer
        latest_pointer_key = (
            f"checkpoint_latest:{storage_safe_thread_id}:{storage_safe_checkpoint_ns}"
        )
        checkpoint_key_bytes = self._redis.get(latest_pointer_key)

        if not checkpoint_key_bytes:
            # No pointer means no checkpoints exist
            return None

        # Decode bytes to string
        checkpoint_key = (
            checkpoint_key_bytes.decode()
            if isinstance(checkpoint_key_bytes, bytes)
            else checkpoint_key_bytes
        )

        checkpoint_data = self._redis.json().get(str(checkpoint_key))
        if not checkpoint_data or not isinstance(checkpoint_data, dict):
            # Pointer exists but checkpoint is missing - data inconsistency
            return None

        checkpoint_inner = checkpoint_data.get("checkpoint", {})

        return {
            "thread_id": checkpoint_data.get("thread_id", storage_safe_thread_id),
            "checkpoint_ns": checkpoint_data.get(
                "checkpoint_ns", storage_safe_checkpoint_ns
            ),
            "checkpoint_id": checkpoint_data.get("checkpoint_id"),
            "parent_checkpoint_id": checkpoint_data.get("parent_checkpoint_id"),
            "$.checkpoint": (
                json.dumps(checkpoint_inner)
                if isinstance(checkpoint_inner, dict)
                else checkpoint_inner
            ),
            "$.metadata": checkpoint_data.get("metadata", "{}"),
            "_channel_versions": (
                checkpoint_inner.get("channel_versions")
                if isinstance(checkpoint_inner, dict)
                else None
            ),
            "has_writes": checkpoint_data.get("has_writes", False),
            # Store the full checkpoint data to avoid re-fetching
            "_checkpoint_data": checkpoint_data,
        }

    def _refresh_checkpoint_ttl(
        self, doc_thread_id: str, doc_checkpoint_ns: str, doc_checkpoint_id: str
    ) -> None:
        """Refresh TTL for checkpoint and all related keys.

        No-op unless ``refresh_on_read`` is enabled in the TTL config.
        """
        if not self.ttl_config or not self.ttl_config.get("refresh_on_read"):
            return

        checkpoint_key = self._make_redis_checkpoint_key_cached(
            doc_thread_id,
            doc_checkpoint_ns,
            doc_checkpoint_id,
        )

        # Get write keys
        write_keys = []
        if self._key_registry:
            write_keys = self._key_registry.get_write_keys(
                doc_thread_id, doc_checkpoint_ns, doc_checkpoint_id
            )
        else:
            # Use search indices as fallback
            write_keys = self._get_write_keys_from_search(
                doc_thread_id, doc_checkpoint_ns, doc_checkpoint_id
            )

        # Apply TTL to all keys
        self._apply_ttl_to_keys(checkpoint_key, write_keys)

        # Refresh registry key TTL
        if self._key_registry and self.ttl_config:
            ttl_minutes = self.ttl_config.get("default_ttl")
            if ttl_minutes is not None:
                ttl_seconds = int(ttl_minutes * 60)
                # Registry TTL is handled per checkpoint
                self._key_registry.apply_ttl(
                    doc_thread_id, doc_checkpoint_ns, doc_checkpoint_id, ttl_seconds
                )

    def _get_write_keys_from_search(
        self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
    ) -> List[str]:
        """Get write keys using search index.

        Fallback path when no key registry is available; reconstructs write
        keys from the task_id/idx fields indexed on each write document.
        """
        write_query = FilterQuery(
            filter_expression=(Tag("thread_id") == to_storage_safe_id(thread_id))
            & (Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns))
            & (Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id)),
            return_fields=["task_id", "idx"],
            num_results=1000,
        )
        write_results = self.checkpoint_writes_index.search(write_query)

        return [
            self._make_redis_checkpoint_writes_key(
                to_storage_safe_id(thread_id),
                to_storage_safe_str(checkpoint_ns),
                to_storage_safe_id(checkpoint_id),
                getattr(doc, "task_id", ""),
                getattr(doc, "idx", 0),
            )
            for doc in write_results.docs
        ]
    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Get a checkpoint tuple from Redis.

        Args:
            config (RunnableConfig): The config to use for retrieving the checkpoint.

        Returns:
            Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
        """
        thread_id = config["configurable"]["thread_id"]
        checkpoint_id = get_checkpoint_id(config)
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")

        # For values we store in Redis, we need to convert empty strings to the
        # sentinel value.
        storage_safe_thread_id = to_storage_safe_id(thread_id)
        storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)

        if checkpoint_id and checkpoint_id != EMPTY_ID_SENTINEL:
            # Direct key access when checkpoint_id is known - no fallback needed
            storage_safe_checkpoint_id = to_storage_safe_id(checkpoint_id)

            # Construct direct key for checkpoint data
            checkpoint_key = self._make_redis_checkpoint_key_cached(
                thread_id, checkpoint_ns, checkpoint_id
            )

            # Direct key access only
            checkpoint_data = self._redis.json().get(checkpoint_key)
            if not checkpoint_data or not isinstance(checkpoint_data, dict):
                # Checkpoint doesn't exist
                return None

            # Process checkpoint data from direct access
            # Create doc-like object from direct access
            # Extract the actual checkpoint data
            checkpoint_inner = checkpoint_data.get("checkpoint", {})

            doc = {
                "thread_id": checkpoint_data.get("thread_id", storage_safe_thread_id),
                "checkpoint_ns": checkpoint_data.get(
                    "checkpoint_ns", storage_safe_checkpoint_ns
                ),
                "checkpoint_id": checkpoint_data.get(
                    "checkpoint_id", storage_safe_checkpoint_id
                ),
                # NOTE(review): default is the checkpoint's OWN id — mirrors
                # _get_checkpoint_document_by_id; confirm intended fallback.
                "parent_checkpoint_id": checkpoint_data.get(
                    "parent_checkpoint_id", storage_safe_checkpoint_id
                ),
                "$.checkpoint": (
                    json.dumps(checkpoint_inner)
                    if isinstance(checkpoint_inner, dict)
                    else checkpoint_inner
                ),
                "$.metadata": checkpoint_data.get(
                    "metadata", "{}"
                ),  # metadata is already a JSON string
                # Store channel_versions for easy access
                "_channel_versions": (
                    checkpoint_inner.get("channel_versions")
                    if isinstance(checkpoint_inner, dict)
                    else None
                ),
                # Store has_writes flag
                "has_writes": checkpoint_data.get(
                    "has_writes", False
                ),  # Default to False to avoid expensive searches
                # Store the full checkpoint data to avoid re-fetching
                "_checkpoint_data": checkpoint_data,
            }
        else:
            # Get latest checkpoint using the helper method
            doc = self._get_latest_checkpoint_document(thread_id, checkpoint_ns)
            if not doc:
                return None

        # Handle both dict (from direct access) and Document objects (from FT.SEARCH)
        if isinstance(doc, dict):
            doc_thread_id = from_storage_safe_id(doc["thread_id"])
            doc_checkpoint_ns = from_storage_safe_str(doc["checkpoint_ns"])
            doc_checkpoint_id = from_storage_safe_id(doc["checkpoint_id"])
            doc_parent_checkpoint_id = from_storage_safe_id(doc["parent_checkpoint_id"])
        else:
            doc_thread_id = from_storage_safe_id(doc.thread_id)
            doc_checkpoint_ns = from_storage_safe_str(doc.checkpoint_ns)
            doc_checkpoint_id = from_storage_safe_id(doc.checkpoint_id)
            doc_parent_checkpoint_id = from_storage_safe_id(doc.parent_checkpoint_id)

        # Lazy TTL refresh - only refresh if TTL is below threshold
        if self.ttl_config and self.ttl_config.get("refresh_on_read"):
            # Get the checkpoint key
            checkpoint_key = self._make_redis_checkpoint_key_cached(
                doc_thread_id,
                doc_checkpoint_ns,
                doc_checkpoint_id,
            )

            # Always refresh TTL when refresh_on_read is enabled
            # This ensures all related keys maintain synchronized TTLs
            current_ttl = self._redis.ttl(checkpoint_key)

            # Only refresh if key exists and has TTL (skip keys with no expiry)
            # TTL states: -2 = key doesn't exist, -1 = key exists but no TTL, 0 = expired, >0 = seconds remaining
            if current_ttl > 0:
                # Note: We don't refresh TTL for keys with no expiry (TTL = -1)
                # Get write keys - use key registry if available, otherwise fall back to search
                write_keys = []
                if self._key_registry:
                    # Use key registry for faster lookup
                    write_keys = self._key_registry.get_write_keys(
                        doc_thread_id, doc_checkpoint_ns, doc_checkpoint_id
                    )
                else:
                    # Fallback to search index
                    write_keys = self._get_write_keys_from_search(
                        doc_thread_id, doc_checkpoint_ns, doc_checkpoint_id
                    )

                # Apply TTL to checkpoint and write keys
                self._apply_ttl_to_keys(checkpoint_key, write_keys)

        # Fetch channel_values - pass channel_versions if we have them from direct access
        # First check if we stored channel_versions during direct access
        channel_versions_from_checkpoint = doc.get("_channel_versions")

        if channel_versions_from_checkpoint is None:
            # Fall back to extracting from checkpoint data
            checkpoint_raw = (
                doc.get("$.checkpoint")
                if isinstance(doc, dict)
                else getattr(doc, "$.checkpoint", None)
            )
            if isinstance(checkpoint_raw, str):
                checkpoint_data_dict = json.loads(checkpoint_raw)
            else:
                checkpoint_data_dict = checkpoint_raw
            channel_versions_from_checkpoint = (
                checkpoint_data_dict.get("channel_versions")
                if checkpoint_data_dict
                else None
            )

        # Get channel values from the checkpoint we already fetched
        # Extract the checkpoint data based on doc type
        if isinstance(doc, dict):
            # From direct access - we have the full data
            checkpoint_inner = doc.get("_checkpoint_data", {}).get("checkpoint", {})
            if isinstance(checkpoint_inner, str):
                checkpoint_inner = json.loads(checkpoint_inner)
        else:
            # From search - parse the checkpoint
            checkpoint_str = getattr(doc, "$.checkpoint", "{}")
            checkpoint_inner = (
                json.loads(checkpoint_str)
                if isinstance(checkpoint_str, str)
                else checkpoint_str
            )

        # Channel values are already inline in the checkpoint
        channel_values = checkpoint_inner.get("channel_values", {})
        # Deserialize them since they're stored in serialized form
        channel_values = self._deserialize_channel_values(channel_values)

        # Fetch pending_sends from parent checkpoint
        pending_sends = []
        if doc_parent_checkpoint_id:
            pending_sends = self._load_pending_sends_with_registry_check(
                thread_id=doc_thread_id,
                checkpoint_ns=doc_checkpoint_ns,
                parent_checkpoint_id=doc_parent_checkpoint_id,
            )

        # Fetch and parse metadata
        raw_metadata = (
            doc.get("$.metadata", "{}")
            if isinstance(doc, dict)
            else getattr(doc, "$.metadata", "{}")
        )
        metadata_dict = (
            json.loads(raw_metadata) if isinstance(raw_metadata, str) else raw_metadata
        )

        # Ensure metadata matches CheckpointMetadata type
        sanitized_metadata = {
            k.replace("\u0000", ""): (
                v.replace("\u0000", "") if isinstance(v, str) else v
            )
            for k, v in metadata_dict.items()
        }
        metadata = cast(CheckpointMetadata, sanitized_metadata)

        config_param: RunnableConfig = {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": doc_checkpoint_id,
            }
        }

        # Handle both direct dict access and FT.SEARCH results efficiently
        checkpoint_data = (
            doc.get("$.checkpoint")
            if isinstance(doc, dict)
            else getattr(doc, "$.checkpoint")
        )
        checkpoint_param = self._load_checkpoint(
            checkpoint_data or {},
            channel_values,
            pending_sends,
        )

        # Skip pending_writes if we can determine there are none
        checkpoint_has_writes = (
            doc.get("has_writes")
            if isinstance(doc, dict)
            else getattr(doc, "has_writes", False)
        )
        pending_writes = self._load_pending_writes_with_registry_check(
            doc_thread_id,
            doc_checkpoint_ns,
            doc_checkpoint_id,
            checkpoint_has_writes=bool(checkpoint_has_writes),
            registry_has_writes=False,  # We don't have registry info here
        )

        # Build parent config if parent_checkpoint_id exists
        parent_config: RunnableConfig | None = None
        if doc_parent_checkpoint_id:
            parent_config = {
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    "checkpoint_id": doc_parent_checkpoint_id,
                }
            }

        return CheckpointTuple(
            config=config_param,
            checkpoint=checkpoint_param,
            metadata=metadata,
            parent_config=parent_config,
            pending_writes=pending_writes,
        )
[docs] @classmethod @contextmanager def from_conn_string( cls, redis_url: Optional[str] = None, *, redis_client: Optional[Union[Redis, RedisCluster]] = None, connection_args: Optional[Dict[str, Any]] = None, ttl: Optional[Dict[str, Any]] = None, checkpoint_prefix: str = CHECKPOINT_PREFIX, checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX, ) -> Iterator[RedisSaver]: """Create a new RedisSaver instance.""" saver: Optional[RedisSaver] = None try: saver = cls( redis_url=redis_url, redis_client=redis_client, connection_args=connection_args, ttl=ttl, checkpoint_prefix=checkpoint_prefix, checkpoint_write_prefix=checkpoint_write_prefix, ) yield saver finally: if saver and saver._owns_its_client: # Ensure saver is not None saver._redis.close() # RedisCluster doesn't have connection_pool attribute if getattr(saver._redis, "connection_pool", None): saver._redis.connection_pool.disconnect()
[docs] def get_channel_values( self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str = "", channel_versions: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: """Retrieve channel_values using efficient FT.SEARCH with checkpoint_id.""" # Get checkpoint with inline channel_values using single JSON.GET operation checkpoint_key = self._make_redis_checkpoint_key_cached( thread_id, checkpoint_ns, checkpoint_id, ) # Single JSON.GET operation to retrieve checkpoint with inline channel_values checkpoint_data = self._redis.json().get(checkpoint_key, "$.checkpoint") if not checkpoint_data: return {} # checkpoint_data[0] is already a deserialized dict, not a typed tuple checkpoint = checkpoint_data[0] return checkpoint.get("channel_values", {})
    def _load_pending_writes(
        self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
    ) -> List[PendingWrite]:
        """Load pending writes using sorted set registry."""
        # Delegate to the registry-aware loader; we assume writes may exist so
        # that the registry fast path is tried before the FT.SEARCH fallback.
        return self._load_pending_writes_with_registry_check(
            thread_id,
            checkpoint_ns,
            checkpoint_id,
            checkpoint_has_writes=True,  # Assume writes exist if we're calling this
            registry_has_writes=False,
        )

    def _load_pending_sends(
        self,
        thread_id: str,
        checkpoint_ns: str,
        parent_checkpoint_id: str,
    ) -> List[Tuple[str, Union[str, bytes]]]:
        """Load pending sends for a parent checkpoint.

        Queries the writes index for documents on the TASKS channel belonging
        to the parent checkpoint and returns them in deterministic order.

        Args:
            thread_id: The thread ID
            checkpoint_ns: The checkpoint namespace
            parent_checkpoint_id: The ID of the parent checkpoint

        Returns:
            List of (type, blob) tuples representing pending sends
        """
        # Normalize identifiers to their storage-safe forms before querying.
        storage_safe_thread_id = to_storage_safe_id(thread_id)
        storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
        storage_safe_parent_checkpoint_id = to_storage_safe_id(parent_checkpoint_id)

        parent_writes_query = FilterQuery(
            filter_expression=(Tag("thread_id") == storage_safe_thread_id)
            & (Tag("checkpoint_ns") == storage_safe_checkpoint_ns)
            & (Tag("checkpoint_id") == storage_safe_parent_checkpoint_id)
            & (Tag("channel") == TASKS),
            return_fields=["type", "$.blob", "task_path", "task_id", "idx"],
            num_results=100,  # Adjust as needed
        )
        parent_writes_results = self.checkpoint_writes_index.search(parent_writes_query)

        # Sort results by task_path, task_id, idx
        sorted_writes = sorted(
            parent_writes_results.docs,
            key=lambda x: (
                getattr(x, "task_path", ""),
                getattr(x, "task_id", ""),
                getattr(x, "idx", 0),
            ),
        )

        # Extract type and blob pairs
        # Handle both direct attribute access and JSON path access
        return [
            (
                getattr(doc, "type", ""),
                getattr(doc, "$.blob", getattr(doc, "blob", b"")),
            )
            for doc in sorted_writes
        ]

    def _batch_load_pending_sends(
        self, batch_keys: List[Tuple[str, str, str]]
    ) -> Dict[Tuple[str, str, str], List[Tuple[str, bytes]]]:
        """Batch load pending sends for multiple parent checkpoints.

        Groups keys by (thread_id, checkpoint_ns) so each group needs only a
        single FT.SEARCH (Tag filters accept a list of checkpoint IDs).

        Args:
            batch_keys: List of (thread_id, checkpoint_ns, parent_checkpoint_id) tuples

        Returns:
            Dict mapping batch_key -> list of (type, blob) tuples
        """
        if not batch_keys:
            return {}

        results_map = {}

        # Group by thread_id and checkpoint_ns for efficient querying
        grouped_keys: Dict[Tuple[str, str], List[str]] = {}
        for thread_id, checkpoint_ns, parent_checkpoint_id in batch_keys:
            group_key = (thread_id, checkpoint_ns)
            if group_key not in grouped_keys:
                grouped_keys[group_key] = []
            grouped_keys[group_key].append(parent_checkpoint_id)

        # Batch query for each group
        for (thread_id, checkpoint_ns), parent_checkpoint_ids in grouped_keys.items():
            storage_safe_thread_id = to_storage_safe_id(thread_id)
            storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
            storage_safe_parent_checkpoint_ids = [
                to_storage_safe_id(pid) for pid in parent_checkpoint_ids
            ]

            # Build filter for multiple parent checkpoint IDs
            thread_filter = Tag("thread_id") == storage_safe_thread_id
            ns_filter = Tag("checkpoint_ns") == storage_safe_checkpoint_ns
            channel_filter = Tag("channel") == TASKS

            # Create filter for multiple parent checkpoint IDs (Tag supports lists)
            checkpoint_filter = (
                Tag("checkpoint_id") == storage_safe_parent_checkpoint_ids
            )

            batch_query = FilterQuery(
                filter_expression=thread_filter
                & ns_filter
                & checkpoint_filter
                & channel_filter,
                return_fields=[
                    "checkpoint_id",
                    "type",
                    "$.blob",
                    "task_path",
                    "task_id",
                    "idx",
                ],
                num_results=1000,  # Increased limit for batch loading
            )

            batch_results = self.checkpoint_writes_index.search(batch_query)

            # Group results by parent checkpoint ID
            writes_by_checkpoint: Dict[str, List[Any]] = {}
            for doc in batch_results.docs:
                parent_checkpoint_id = from_storage_safe_id(doc.checkpoint_id)
                if parent_checkpoint_id not in writes_by_checkpoint:
                    writes_by_checkpoint[parent_checkpoint_id] = []
                writes_by_checkpoint[parent_checkpoint_id].append(doc)

            # Sort and format results for each parent checkpoint
            for parent_checkpoint_id in parent_checkpoint_ids:
                batch_key = (thread_id, checkpoint_ns, parent_checkpoint_id)
                writes = writes_by_checkpoint.get(parent_checkpoint_id, [])

                # Sort results by task_path, task_id, idx
                sorted_writes = sorted(
                    writes,
                    key=lambda x: (
                        getattr(x, "task_path", ""),
                        getattr(x, "task_id", ""),
                        getattr(x, "idx", 0),
                    ),
                )

                # Extract type and blob pairs
                # Handle both direct attribute access and JSON path access
                results_map[batch_key] = [
                    (
                        getattr(doc, "type", ""),
                        getattr(doc, "$.blob", getattr(doc, "blob", b"")),
                    )
                    for doc in sorted_writes
                ]

        return results_map

    def _batch_load_pending_writes(
        self, batch_keys: List[Tuple[str, str, str]]
    ) -> Dict[Tuple[str, str, str], List[PendingWrite]]:
        """Batch load pending writes for multiple checkpoints.

        Same grouping strategy as ``_batch_load_pending_sends``: one
        FT.SEARCH per (thread_id, checkpoint_ns) group.

        Args:
            batch_keys: List of (thread_id, checkpoint_ns, checkpoint_id) tuples

        Returns:
            Dict mapping batch_key -> list of PendingWrite objects
        """
        if not batch_keys:
            return {}

        results_map = {}

        # Group by thread_id and checkpoint_ns for efficient querying
        grouped_keys: Dict[Tuple[str, str], List[str]] = {}
        for thread_id, checkpoint_ns, checkpoint_id in batch_keys:
            group_key = (thread_id, checkpoint_ns)
            if group_key not in grouped_keys:
                grouped_keys[group_key] = []
            grouped_keys[group_key].append(checkpoint_id)

        # Batch query for each group
        for (thread_id, checkpoint_ns), checkpoint_ids in grouped_keys.items():
            storage_safe_thread_id = to_storage_safe_id(thread_id)
            storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
            storage_safe_checkpoint_ids = [
                to_storage_safe_id(cid) for cid in checkpoint_ids
            ]

            # Build filter for multiple checkpoint IDs
            thread_filter = Tag("thread_id") == storage_safe_thread_id
            ns_filter = Tag("checkpoint_ns") == storage_safe_checkpoint_ns

            # Create filter for multiple checkpoint IDs (Tag supports lists)
            checkpoint_filter = Tag("checkpoint_id") == storage_safe_checkpoint_ids

            batch_query = FilterQuery(
                filter_expression=thread_filter & ns_filter & checkpoint_filter,
                return_fields=[
                    "checkpoint_id",
                    "task_id",
                    "idx",
                    "channel",
                    "type",
                    "$.blob",
                ],
                num_results=10000,  # Large limit for batch loading
            )

            batch_results = self.checkpoint_writes_index.search(batch_query)

            # Group results by checkpoint ID; keyed by (task_id, idx) to
            # deduplicate and to match the shape _load_writes expects.
            writes_by_checkpoint: Dict[str, Dict[Tuple[str, str], Dict[str, Any]]] = {}
            for doc in batch_results.docs:
                checkpoint_id = from_storage_safe_id(doc.checkpoint_id)
                if checkpoint_id not in writes_by_checkpoint:
                    writes_by_checkpoint[checkpoint_id] = {}

                task_id = str(doc.task_id)
                idx = str(doc.idx)
                writes_by_checkpoint[checkpoint_id][(task_id, idx)] = {
                    "task_id": task_id,
                    "idx": idx,
                    "channel": getattr(doc, "channel", ""),
                    "type": getattr(doc, "type", ""),
                    "blob": getattr(doc, "$.blob", b""),
                }

            # Format results for each checkpoint
            for checkpoint_id in checkpoint_ids:
                batch_key = (thread_id, checkpoint_ns, checkpoint_id)
                writes_dict = writes_by_checkpoint.get(checkpoint_id, {})

                # Use base class method to deserialize
                results_map[batch_key] = BaseRedisSaver._load_writes(
                    self.serde, writes_dict
                )

        return results_map

    def _load_pending_writes_with_registry_check(
        self,
        thread_id: str,
        checkpoint_ns: str,
        checkpoint_id: str,
        checkpoint_has_writes: bool,
        registry_has_writes: bool,
    ) -> List[PendingWrite]:
        """Load pending writes with registry optimization and fallback.

        NOTE(review): ``registry_has_writes`` is currently unused in this
        body — it appears to be kept for interface parity with callers;
        confirm before removing.
        """
        if not checkpoint_has_writes:
            return []

        # FAST PATH: Try sorted set registry first
        if self._key_registry:
            try:
                # Check write count from registry
                write_count = self._key_registry.get_write_count(
                    thread_id, checkpoint_ns, checkpoint_id
                )
                if write_count == 0:
                    return []

                # Get write keys from registry
                write_keys = self._key_registry.get_write_keys(
                    thread_id, checkpoint_ns, checkpoint_id
                )

                if write_keys:
                    # Batch fetch all writes using pipeline
                    with self._redis.pipeline(transaction=False) as pipeline:
                        for key in write_keys:
                            pipeline.json().get(key)
                        results = pipeline.execute()

                    # Build writes dictionary
                    writes_dict = {}
                    for write_data in results:
                        if write_data:
                            task_id = write_data.get("task_id", "")
                            idx = write_data.get("idx", 0)
                            writes_dict[(task_id, idx)] = write_data

                    # Use base class method to deserialize
                    return BaseRedisSaver._load_writes(self.serde, writes_dict)
            except Exception:
                # Fall through to FT.SEARCH fallback
                pass

        # FALLBACK: Use FT.SEARCH if registry not available or failed
        # Call the base class implementation to avoid recursion
        return super()._load_pending_writes(thread_id, checkpoint_ns, checkpoint_id)

    def _load_pending_sends_with_registry_check(
        self,
        thread_id: str,
        checkpoint_ns: str,
        parent_checkpoint_id: str,
    ) -> List[Tuple[str, Union[str, bytes]]]:
        """Load pending sends for a parent checkpoint with pre-computed registry check."""
        if not parent_checkpoint_id:
            return []

        # FAST PATH: Try sorted set registry first
        if self._key_registry:
            try:
                # Check if parent checkpoint has any writes in the sorted set
                write_count = self._key_registry.get_write_count(
                    thread_id, checkpoint_ns, parent_checkpoint_id
                )
                if write_count == 0:
                    # No writes for parent checkpoint - return immediately
                    return []

                # Get exact write keys from the per-checkpoint registry
                write_keys = self._key_registry.get_write_keys(
                    thread_id, checkpoint_ns, parent_checkpoint_id
                )

                # Filter for TASKS channel writes
                task_write_keys = []
                for key in write_keys:
                    # Keys contain channel info: checkpoint_write:thread:ns:checkpoint:task:idx
                    # We need to check if it's a TASKS channel write
                    # This is a simple heuristic - we might need to fetch to be sure
                    if TASKS in key or "__pregel_tasks" in key:
                        task_write_keys.append(key)

                if not task_write_keys:
                    return []

                # Fetch task writes using pipeline (safe for cluster mode)
                with self._redis.pipeline(transaction=False) as pipeline:
                    for key in task_write_keys:
                        pipeline.json().get(key)
                    results = pipeline.execute()

                # Extract pending sends and sort them
                # (the channel is re-checked here, so the key heuristic above
                # can only over-select, never mis-classify).
                pending_sends_with_sort_keys = []
                for write_data in results:
                    if write_data and write_data.get("channel") == TASKS:
                        pending_sends_with_sort_keys.append(
                            (
                                write_data.get("task_path", ""),
                                write_data.get("task_id", ""),
                                write_data.get("idx", 0),
                                write_data.get("type", ""),
                                write_data.get("blob", b""),
                            )
                        )

                # Sort by task_path, task_id, idx
                pending_sends_with_sort_keys.sort(key=lambda x: (x[0], x[1], x[2]))

                # Return just the (type, blob) tuples
                return [(item[3], item[4]) for item in pending_sends_with_sort_keys]
            except Exception:
                # If sorted set approach fails, fall back to FT.SEARCH
                pass

        storage_safe_thread_id = to_storage_safe_id(thread_id)
        storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
        storage_safe_parent_checkpoint_id = to_storage_safe_id(parent_checkpoint_id)

        parent_writes_query = FilterQuery(
            filter_expression=(Tag("thread_id") == storage_safe_thread_id)
            & (Tag("checkpoint_ns") == storage_safe_checkpoint_ns)
            & (Tag("checkpoint_id") == storage_safe_parent_checkpoint_id)
            & (Tag("channel") == TASKS),
            return_fields=["type", "$.blob", "task_path", "task_id", "idx"],
            num_results=100,  # Adjust as needed
        )
        parent_writes_results = self.checkpoint_writes_index.search(parent_writes_query)

        # Sort results by task_path, task_id, idx (matching Postgres implementation)
        sorted_writes = sorted(
            parent_writes_results.docs,
            key=lambda x: (
                getattr(x, "task_path", ""),
                getattr(x, "task_id", ""),
                getattr(x, "idx", 0),
            ),
        )

        # Extract type and blob pairs
        # Handle both direct attribute access and JSON path access
        return [
            (
                getattr(doc, "type", ""),
                getattr(doc, "$.blob", getattr(doc, "blob", b"")),
            )
            for doc in sorted_writes
        ]
    def delete_thread(self, thread_id: str) -> None:
        """Delete all checkpoints and writes associated with a specific thread ID.

        Collects every key belonging to the thread (checkpoint documents,
        latest-checkpoint pointers, write documents, and key-registry sorted
        sets) and deletes them in one pass.

        Args:
            thread_id: The thread ID whose checkpoints should be deleted.
        """
        storage_safe_thread_id = to_storage_safe_id(thread_id)

        # Delete all checkpoints for this thread
        # NOTE(review): num_results caps the query at 10 000 docs — threads
        # with more checkpoints would be only partially deleted; confirm
        # this bound is acceptable.
        checkpoint_query = FilterQuery(
            filter_expression=Tag("thread_id") == storage_safe_thread_id,
            return_fields=["checkpoint_ns", "checkpoint_id"],
            num_results=10000,  # Get all checkpoints for this thread
        )
        checkpoint_results = self.checkpoints_index.search(checkpoint_query)

        # Collect all keys to delete
        keys_to_delete = []
        checkpoint_namespaces = set()

        for doc in checkpoint_results.docs:
            checkpoint_ns = getattr(doc, "checkpoint_ns", "")
            checkpoint_id = getattr(doc, "checkpoint_id", "")

            # Track unique namespaces for latest pointer cleanup
            checkpoint_namespaces.add(checkpoint_ns)

            # Delete checkpoint key
            checkpoint_key = self._make_redis_checkpoint_key_cached(
                thread_id, checkpoint_ns, checkpoint_id
            )
            keys_to_delete.append(checkpoint_key)

        # Add latest checkpoint pointers to deletion list
        for checkpoint_ns in checkpoint_namespaces:
            latest_pointer_key = f"checkpoint_latest:{storage_safe_thread_id}:{to_storage_safe_str(checkpoint_ns)}"
            keys_to_delete.append(latest_pointer_key)

        # Channel values are stored inline — no separate blob keys to clean up.

        # Delete all writes for this thread
        writes_query = FilterQuery(
            filter_expression=Tag("thread_id") == storage_safe_thread_id,
            return_fields=["checkpoint_ns", "checkpoint_id", "task_id", "idx"],
            num_results=10000,
        )
        writes_results = self.checkpoint_writes_index.search(writes_query)

        for doc in writes_results.docs:
            checkpoint_ns = getattr(doc, "checkpoint_ns", "")
            checkpoint_id = getattr(doc, "checkpoint_id", "")
            task_id = getattr(doc, "task_id", "")
            idx = getattr(doc, "idx", 0)

            write_key = self._make_redis_checkpoint_writes_key(
                storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx
            )
            keys_to_delete.append(write_key)

        # Delete the registry sorted sets for each checkpoint
        if self._key_registry:
            # Get unique checkpoints from the results we already have
            processed_checkpoints = set()
            for doc in checkpoint_results.docs:
                checkpoint_ns = getattr(doc, "checkpoint_ns", "")
                checkpoint_id = getattr(doc, "checkpoint_id", "")
                checkpoint_key = (thread_id, checkpoint_ns, checkpoint_id)
                if checkpoint_key not in processed_checkpoints:
                    processed_checkpoints.add(checkpoint_key)
                    # Add the write registry key for this checkpoint
                    zset_key = self._key_registry.make_write_keys_zset_key(
                        thread_id, checkpoint_ns, checkpoint_id
                    )
                    keys_to_delete.append(zset_key)

        # Execute all deletions based on cluster mode
        if self.cluster_mode:
            # For cluster mode, delete keys individually
            # (pipelines can't span hash slots).
            for key in keys_to_delete:
                self._redis.delete(key)
        else:
            # For non-cluster mode, use pipeline for efficiency
            pipeline = self._redis.pipeline()
            for key in keys_to_delete:
                pipeline.delete(key)
            pipeline.execute()
    def prune(
        self,
        thread_ids: Sequence[str],
        *,
        strategy: str = "keep_latest",
        keep_last: Optional[int] = None,
        max_results: int = 10000,
    ) -> None:
        """Prune old checkpoints for the given threads per namespace.

        Retains the most-recent checkpoints **per checkpoint namespace** and
        removes the rest, along with their associated write keys and
        key-registry sorted sets. Each namespace (root ``""`` and any
        subgraph namespaces) is treated as an independent checkpoint chain.

        Channel values are stored inline within each checkpoint document, so
        they are automatically removed when the checkpoint document is
        deleted.

        Args:
            thread_ids: Thread IDs whose old checkpoints should be pruned.
            strategy: Pruning strategy. ``"keep_latest"`` retains only the
                most recent checkpoint per namespace (default). ``"delete"``
                removes all checkpoints for the thread.
            keep_last: Optional override — number of recent checkpoints to
                retain per namespace. When provided, takes precedence over
                ``strategy``. Use ``keep_last=0`` to remove all checkpoints.
            max_results: Maximum number of checkpoints fetched from the index
                per thread in a single query. Defaults to 10 000.

        Raises:
            ValueError: If ``thread_ids`` is empty, ``keep_last`` is
                negative, or ``max_results`` is less than 1.
        """
        # Resolve keep_last from strategy if not explicitly provided
        if keep_last is None:
            if strategy == "delete":
                keep_last = 0
            else:
                keep_last = 1

        # Validate input
        if not thread_ids:
            raise ValueError("``thread_ids`` must be a non-empty sequence")
        if keep_last < 0:
            raise ValueError(f"``keep_last`` must be >= 0, got {keep_last}")
        if max_results < 1:
            raise ValueError(f"``max_results`` must be >= 1, got {max_results}")

        for thread_id in thread_ids:
            storage_safe_thread_id = to_storage_safe_id(thread_id)

            # Fetch all checkpoints for this thread across all namespaces
            checkpoint_query = FilterQuery(
                filter_expression=Tag("thread_id") == storage_safe_thread_id,
                return_fields=["checkpoint_ns", "checkpoint_id"],
                num_results=max_results,
            )
            checkpoint_results = self.checkpoints_index.search(checkpoint_query)
            if not checkpoint_results.docs:
                continue

            # Group by namespace — each namespace is an independent checkpoint chain
            # (root graph vs. subgraph checkpoints must be evicted independently).
            by_ns: Dict[str, list] = defaultdict(list)
            for doc in checkpoint_results.docs:
                ns = getattr(doc, "checkpoint_ns", "")
                by_ns[ns].append(doc)

            # Within each namespace sort newest-first (ULIDs are lex time-ordered)
            # and collect checkpoints that fall outside the keep_last window.
            to_evict = []
            # Track namespaces where every checkpoint is evicted so we can clean
            # up the checkpoint_latest:{thread}:{ns} pointer key too.
            fully_evicted_ns: set = set()
            for ns, ns_docs in by_ns.items():
                ns_sorted = sorted(
                    ns_docs,
                    key=lambda d: getattr(d, "checkpoint_id", ""),
                    reverse=True,
                )
                ns_evicted = ns_sorted[keep_last:]
                to_evict.extend(ns_evicted)
                if len(ns_evicted) == len(ns_docs):
                    # nothing left in this namespace
                    fully_evicted_ns.add(ns)

            if not to_evict:
                continue

            keys_to_delete = []
            for doc in to_evict:
                checkpoint_ns = getattr(doc, "checkpoint_ns", "")
                checkpoint_id = getattr(doc, "checkpoint_id", "")

                # Evict checkpoint document
                keys_to_delete.append(
                    self._make_redis_checkpoint_key_cached(
                        thread_id, checkpoint_ns, checkpoint_id
                    )
                )

                # Evict all write documents for this checkpoint
                writes_query = FilterQuery(
                    filter_expression=(
                        (Tag("thread_id") == storage_safe_thread_id)
                        & (Tag("checkpoint_id") == checkpoint_id)
                    ),
                    return_fields=["checkpoint_ns", "checkpoint_id", "task_id", "idx"],
                    num_results=max_results,
                )
                writes_results = self.checkpoint_writes_index.search(writes_query)
                for wdoc in writes_results.docs:
                    keys_to_delete.append(
                        self._make_redis_checkpoint_writes_key(
                            storage_safe_thread_id,
                            getattr(wdoc, "checkpoint_ns", ""),
                            getattr(wdoc, "checkpoint_id", ""),
                            getattr(wdoc, "task_id", ""),
                            int(getattr(wdoc, "idx", 0)),
                        )
                    )

                # Evict key-registry sorted set for this checkpoint
                if self._key_registry:
                    keys_to_delete.append(
                        self._key_registry.make_write_keys_zset_key(
                            thread_id, checkpoint_ns, checkpoint_id
                        )
                    )

            # Delete checkpoint_latest pointers for fully_evicted namespaces.
            # ns values here come from the index and are already storage-safe,
            # matching the format written by put(): checkpoint_latest:{tid}:{safe_ns}
            for ns in fully_evicted_ns:
                keys_to_delete.append(
                    f"checkpoint_latest:{storage_safe_thread_id}:{ns}"
                )

            if self.cluster_mode:
                # Cluster mode: delete individually (keys may span hash slots).
                for key in keys_to_delete:
                    self._redis.delete(key)
            else:
                # Standalone mode: batch deletions in a pipeline.
                pipeline = self._redis.pipeline()
                for key in keys_to_delete:
                    pipeline.delete(key)
                pipeline.execute()
# Public API of the package: version info plus the saver, exporter and
# recipe classes re-exported from the submodules imported above.
__all__ = [
    "__version__",
    "__lib_name__",
    "RedisSaver",
    "AsyncRedisSaver",
    "BaseRedisSaver",
    "ShallowRedisSaver",
    "AsyncShallowRedisSaver",
    "MessageExporter",
    "LangChainRecipe",
    "MessageRecipe",
]