# Source code for langgraph.checkpoint.redis.shallow

from __future__ import annotations

import json
import logging
import time
from collections import OrderedDict
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, cast

from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
    WRITES_IDX_MAP,
    ChannelVersions,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
    PendingWrite,
    get_checkpoint_id,
)
from langgraph.constants import TASKS
from redis import Redis
from redisvl.index import SearchIndex
from redisvl.query import FilterQuery
from redisvl.query.filter import Num, Tag
from redisvl.redis.connection import RedisConnectionFactory
from ulid import ULID

from langgraph.checkpoint.redis.base import (
    CHECKPOINT_PREFIX,
    CHECKPOINT_WRITE_PREFIX,
    REDIS_KEY_SEPARATOR,
    BaseRedisSaver,
)
from langgraph.checkpoint.redis.util import (
    from_storage_safe_str,
    to_storage_safe_id,
    to_storage_safe_str,
)

# Constants
MILLISECONDS_PER_SECOND = 1000

# Logger for this module
logger = logging.getLogger(__name__)


class ShallowRedisSaver(BaseRedisSaver[Redis, SearchIndex]):
    """Shallow Redis checkpoint saver: keeps only the most recent checkpoint.

    Supports standard Redis URLs (redis://), SSL (rediss://), and Sentinel
    URLs (redis+sentinel://host:26379/service_name/db).
    """

    # Upper bounds for the per-instance LRU caches (guard against unbounded growth).
    DEFAULT_KEY_CACHE_MAX_SIZE = 1000
    DEFAULT_CHANNEL_CACHE_MAX_SIZE = 100

    def __init__(
        self,
        redis_url: Optional[str] = None,
        *,
        redis_client: Optional[Redis] = None,
        connection_args: Optional[dict[str, Any]] = None,
        ttl: Optional[dict[str, Any]] = None,
        key_cache_max_size: Optional[int] = None,
        channel_cache_max_size: Optional[int] = None,
        checkpoint_prefix: str = CHECKPOINT_PREFIX,
        checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
    ) -> None:
        """Initialize the saver and its per-instance key/channel caches.

        Args:
            redis_url: Redis connection URL (used when no client is supplied).
            redis_client: Pre-built Redis client to reuse instead of connecting.
            connection_args: Extra keyword arguments for the connection factory.
            ttl: Optional TTL configuration (minutes-based ``default_ttl``).
            key_cache_max_size: Max entries in the key-string LRU cache.
            channel_cache_max_size: Max entries in the channel LRU cache.
            checkpoint_prefix: Redis key prefix for checkpoint documents.
            checkpoint_write_prefix: Redis key prefix for write documents.
        """
        super().__init__(
            redis_url=redis_url,
            redis_client=redis_client,
            connection_args=connection_args,
            ttl=ttl,
            checkpoint_prefix=checkpoint_prefix,
            checkpoint_write_prefix=checkpoint_write_prefix,
        )
        # OrderedDict gives cheap LRU eviction (move_to_end / popitem) for
        # both bounded caches; bounding prevents memory issues.
        self._key_cache: OrderedDict[str, str] = OrderedDict()
        self._channel_cache: OrderedDict[str, Any] = OrderedDict()
        self._key_cache_max_size = key_cache_max_size or self.DEFAULT_KEY_CACHE_MAX_SIZE
        self._channel_cache_max_size = (
            channel_cache_max_size or self.DEFAULT_CHANNEL_CACHE_MAX_SIZE
        )
        # Key prefixes themselves are configured by BaseRedisSaver.__init__.
        self._separator = REDIS_KEY_SEPARATOR
[docs] @classmethod @contextmanager def from_conn_string( cls, redis_url: Optional[str] = None, *, redis_client: Optional[Redis] = None, connection_args: Optional[dict[str, Any]] = None, ttl: Optional[dict[str, Any]] = None, key_cache_max_size: Optional[int] = None, channel_cache_max_size: Optional[int] = None, checkpoint_prefix: str = CHECKPOINT_PREFIX, checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX, ) -> Iterator[ShallowRedisSaver]: """Create a new ShallowRedisSaver instance.""" saver: Optional[ShallowRedisSaver] = None try: saver = cls( redis_url=redis_url, redis_client=redis_client, connection_args=connection_args, ttl=ttl, key_cache_max_size=key_cache_max_size, channel_cache_max_size=channel_cache_max_size, checkpoint_prefix=checkpoint_prefix, checkpoint_write_prefix=checkpoint_write_prefix, ) yield saver finally: if saver and saver._owns_its_client: saver._redis.close() # RedisCluster doesn't have connection_pool attribute if getattr(saver._redis, "connection_pool", None): saver._redis.connection_pool.disconnect()
[docs] def put( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: """Store checkpoint with inline channel values.""" configurable = config["configurable"].copy() run_id = configurable.pop("run_id", metadata.get("run_id")) thread_id = configurable.pop("thread_id") checkpoint_ns = configurable.pop("checkpoint_ns") copy = checkpoint.copy() next_config = { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint["id"], } } # Extract timestamp from checkpoint_id (ULID) or fallback to checkpoint's ts field # Note: LangGraph may generate checkpoint IDs in different formats (ULID, UUIDv6, etc.) # We try ULID first, then fall back gracefully without warnings (Issue #136) checkpoint_ts = None if checkpoint["id"]: try: ulid_obj = ULID.from_str(checkpoint["id"]) checkpoint_ts = ulid_obj.timestamp # milliseconds since epoch except Exception: # Not a valid ULID - this is expected for UUIDv6 and other formats # Fall back to checkpoint's timestamp field or current time checkpoint_ts = self._extract_fallback_timestamp(checkpoint) # Parse metadata from string to dict to avoid double serialization metadata_str = self._dump_metadata(metadata) metadata_dict = ( json.loads(metadata_str) if isinstance(metadata_str, str) else metadata_str ) # Store channel values inline in the checkpoint copy["channel_values"] = checkpoint.get("channel_values", {}) checkpoint_data = { "thread_id": thread_id, "run_id": to_storage_safe_id(run_id) if run_id else "", "checkpoint_ns": to_storage_safe_str(checkpoint_ns), "checkpoint_id": checkpoint["id"], "checkpoint_ts": checkpoint_ts, "checkpoint": self._dump_checkpoint(copy), "metadata": metadata_dict, # Note: has_writes tracking removed to support put_writes before checkpoint exists } # Store at top-level for filters in list() if all(key in metadata for key in ["source", "step"]): checkpoint_data["source"] = metadata["source"] 
checkpoint_data["step"] = metadata["step"] checkpoint_key = self._make_shallow_redis_checkpoint_key_cached( thread_id, checkpoint_ns ) # Only critical commands (JSON.SET) go in the pipeline. # EXPIRE is applied separately to avoid pipeline failures on # Redis Enterprise proxy with mixed JSON + native commands. with self._redis.pipeline(transaction=False) as pipeline: pipeline.json().set(checkpoint_key, "$", checkpoint_data) pipeline.execute() # Apply TTL separately (best-effort) if self.ttl_config and "default_ttl" in self.ttl_config: ttl_seconds = int(self.ttl_config.get("default_ttl") * 60) try: self._redis.expire(checkpoint_key, ttl_seconds) except Exception: logger.warning( "Failed to apply TTL to checkpoint key: %s", checkpoint_key ) # NOTE: We intentionally do NOT clean up old writes here. # In the HITL (Human-in-the-Loop) flow, interrupt writes are saved via # put_writes BEFORE the new checkpoint is saved. If we clean up writes # when the checkpoint changes, we would delete the interrupt writes # before they can be consumed when resuming. # # Writes are cleaned up in the following scenarios: # 1. When delete_thread is called # 2. When TTL expires (if configured) # 3. When put_writes is called again for the same task/idx (overwrites) # # See Issue #133 for details on this bug fix. return next_config
[docs] def list( self, config: Optional[RunnableConfig], *, filter: Optional[Dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> Iterator[CheckpointTuple]: """List checkpoints from Redis.""" # Construct the filter expression filter_expression = [] if config: filter_expression.append( Tag("thread_id") == to_storage_safe_id(config["configurable"]["thread_id"]) ) if run_id := config["configurable"].get("run_id"): filter_expression.append(Tag("run_id") == to_storage_safe_id(run_id)) if checkpoint_ns := config["configurable"].get("checkpoint_ns"): filter_expression.append( Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns) ) if filter: for k, v in filter.items(): if k == "source": filter_expression.append(Tag("source") == v) elif k == "step": filter_expression.append(Num("step") == v) elif k == "thread_id": filter_expression.append(Tag("thread_id") == to_storage_safe_id(v)) elif k == "run_id": filter_expression.append(Tag("run_id") == to_storage_safe_id(v)) else: raise ValueError(f"Unsupported filter key: {k}") if before: before_checkpoint_id = get_checkpoint_id(before) if before_checkpoint_id: try: before_ulid = ULID.from_str(before_checkpoint_id) before_ts = before_ulid.timestamp # Use numeric range query: checkpoint_ts < before_ts filter_expression.append(Num("checkpoint_ts") < before_ts) except Exception: # If not a valid ULID, ignore the before filter pass # Combine all filter expressions combined_filter = filter_expression[0] if filter_expression else "*" for expr in filter_expression[1:]: combined_filter &= expr # Get checkpoint data # Sort by checkpoint_id in descending order to get most recent checkpoints first query = FilterQuery( filter_expression=combined_filter, return_fields=[ "thread_id", "checkpoint_ns", "$.checkpoint", "$.metadata", ], num_results=limit or 10000, sort_by=("checkpoint_id", "DESC"), ) # Execute the query results = self.checkpoints_index.search(query) # Process the results for doc in 
results.docs: thread_id = cast(str, getattr(doc, "thread_id", "")) checkpoint_ns = cast(str, getattr(doc, "checkpoint_ns", "")) checkpoint = json.loads(doc["$.checkpoint"]) # Extract channel values from the checkpoint (they're stored inline) channel_values: Dict[str, Any] = checkpoint.get("channel_values", {}) # Deserialize them since they're stored in serialized form channel_values = self._deserialize_channel_values(channel_values) # Parse metadata raw_metadata = getattr(doc, "$.metadata", "{}") metadata_dict = ( json.loads(raw_metadata) if isinstance(raw_metadata, str) else raw_metadata ) # Sanitize metadata sanitized_metadata = { k.replace("\u0000", ""): ( v.replace("\u0000", "") if isinstance(v, str) else v ) for k, v in metadata_dict.items() } metadata = cast(CheckpointMetadata, sanitized_metadata) # Load checkpoint with inline channel values checkpoint_param = self._load_checkpoint( doc["$.checkpoint"], channel_values, # Pass the extracted channel values [], # No pending_sends in shallow mode ) config_param: RunnableConfig = { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint_param["id"], } } # Load pending writes (still uses separate keys - already efficient) pending_writes = self._load_pending_writes( thread_id, checkpoint_ns, checkpoint_param["id"] ) yield CheckpointTuple( config=config_param, checkpoint=checkpoint_param, metadata=metadata, parent_config=None, pending_writes=pending_writes, )
[docs] def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get checkpoint with inline channel values.""" thread_id = config["configurable"]["thread_id"] checkpoint_ns = config["configurable"].get("checkpoint_ns", "") # Single key access gets everything inline checkpoint_key = self._make_shallow_redis_checkpoint_key_cached( thread_id, checkpoint_ns ) checkpoint_data = self._redis.json().get(checkpoint_key) if not checkpoint_data or not isinstance(checkpoint_data, dict): return None # TTL refresh if enabled (best-effort) if self.ttl_config and self.ttl_config.get("refresh_on_read"): default_ttl_minutes = self.ttl_config.get("default_ttl", 60) ttl_seconds = int(default_ttl_minutes * 60) try: self._redis.expire(checkpoint_key, ttl_seconds) except Exception: logger.warning( "Failed to refresh TTL on checkpoint key: %s", checkpoint_key ) # Parse the checkpoint data checkpoint = checkpoint_data.get("checkpoint", {}) if isinstance(checkpoint, str): checkpoint = json.loads(checkpoint) # Extract channel values from the checkpoint (they're stored inline) channel_values: Dict[str, Any] = checkpoint.get("channel_values", {}) # Deserialize them since they're stored in serialized form channel_values = self._deserialize_channel_values(channel_values) # Parse metadata metadata = checkpoint_data.get("metadata", {}) if isinstance(metadata, str): metadata = json.loads(metadata) # Sanitize metadata sanitized_metadata = { k.replace("\u0000", ""): ( v.replace("\u0000", "") if isinstance(v, str) else v ) for k, v in metadata.items() } # Load checkpoint with inline channel values checkpoint_param = self._load_checkpoint( json.dumps(checkpoint), channel_values, # Pass the raw channel values - no deserialization needed [], # No pending_sends in shallow mode ) # Load pending writes (still uses separate keys - already efficient) pending_writes = self._load_pending_writes( thread_id, checkpoint_ns, checkpoint_param["id"] ) return CheckpointTuple( config={ 
"configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint["id"], } }, checkpoint=checkpoint_param, metadata=cast(CheckpointMetadata, sanitized_metadata), parent_config=None, pending_writes=pending_writes, )
[docs] def configure_client( self, redis_url: Optional[str] = None, redis_client: Optional[Redis] = None, connection_args: Optional[dict[str, Any]] = None, ) -> None: """Configure the Redis client.""" self._owns_its_client = redis_client is None self._redis = redis_client or RedisConnectionFactory.get_redis_connection( redis_url, **connection_args ) # Set client info for Redis monitoring self.set_client_info()
[docs] def create_indexes(self) -> None: self.checkpoints_index = SearchIndex.from_dict( self.checkpoints_schema, redis_client=self._redis ) self.checkpoint_writes_index = SearchIndex.from_dict( self.writes_schema, redis_client=self._redis )
[docs] def setup(self) -> None: """Initialize the indices in Redis.""" self.checkpoints_index.create(overwrite=False) self.checkpoint_writes_index.create(overwrite=False)
[docs] def put_writes( self, config: RunnableConfig, writes: Sequence[tuple[str, Any]], task_id: str, task_path: str = "", ) -> None: """Store intermediate writes linked to a checkpoint with checkpoint-level registry. Args: config: Configuration of the related checkpoint. writes: List of writes to store, each as (channel, value) pair. task_id: Identifier for the task creating the writes. task_path: Optional path info for the task. """ thread_id = config["configurable"]["thread_id"] checkpoint_ns = config["configurable"].get("checkpoint_ns", "") checkpoint_id = config["configurable"]["checkpoint_id"] # Transform writes into appropriate format writes_objects = [] for idx, (channel, value) in enumerate(writes): type_, blob = self.serde.dumps_typed(value) write_obj = { "thread_id": thread_id, "checkpoint_ns": to_storage_safe_str(checkpoint_ns), "checkpoint_id": checkpoint_id, "task_id": task_id, "task_path": task_path, "idx": WRITES_IDX_MAP.get(channel, idx), "channel": channel, "type": type_, "blob": self._encode_blob( blob ), # Encode bytes to base64 string for Redis } writes_objects.append(write_obj) # THREAD-LEVEL REGISTRY: Only keep writes for the current checkpoint # Use to_storage_safe_str for consistent key naming with delete_thread safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns) thread_write_registry_key = ( f"write_registry:{thread_id}:{safe_checkpoint_ns}:shallow" ) # Collect all write keys write_keys = [] for write_obj in writes_objects: key = self._make_redis_checkpoint_writes_key_cached( thread_id, checkpoint_ns, checkpoint_id, task_id, write_obj["idx"] ) write_keys.append(key) # Only critical commands (JSON.SET, ZADD) go in the pipeline. # EXPIRE (TTL) is applied separately to avoid pipeline failures on # Redis Enterprise proxy with mixed JSON + native commands. 
with self._redis.pipeline(transaction=False) as pipeline: # Add all JSON write operations (critical) for idx, write_obj in enumerate(writes_objects): key = write_keys[idx] pipeline.json().set(key, "$", write_obj) # Registry operation (critical) zadd_mapping = {key: idx for idx, key in enumerate(write_keys)} pipeline.zadd(thread_write_registry_key, zadd_mapping) # type: ignore[arg-type] # Execute critical commands only results = pipeline.execute(raise_on_error=False) # Check results for critical command failures for result in results: if isinstance(result, Exception): raise result # Apply TTL separately (best-effort — failures don't lose writes) if self.ttl_config and "default_ttl" in self.ttl_config: ttl_seconds = int(self.ttl_config.get("default_ttl") * 60) try: self._redis.expire(thread_write_registry_key, ttl_seconds) except Exception: logger.warning( "Failed to apply TTL to write registry key: %s", thread_write_registry_key, exc_info=True, ) for key in write_keys: try: self._redis.expire(key, ttl_seconds) except Exception: logger.warning( "Failed to apply TTL to checkpoint write key: %s", key, exc_info=True, )
def _load_pending_writes( self, thread_id: str, checkpoint_ns: str, checkpoint_id: str ) -> List[PendingWrite]: """Load pending writes efficiently using thread-level write registry.""" if checkpoint_id is None: return [] # Use thread-level registry that only contains current checkpoint writes # All writes belong to the current checkpoint # Use to_storage_safe_str for consistent key naming with delete_thread safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns) thread_write_registry_key = ( f"write_registry:{thread_id}:{safe_checkpoint_ns}:shallow" ) # Get all write keys from the thread's registry (already sorted by index) write_keys = self._redis.zrange(thread_write_registry_key, 0, -1) if not write_keys: return [] # Batch fetch all writes using pipeline with self._redis.pipeline(transaction=False) as pipeline: for key in write_keys: # Decode bytes to string if needed key_str = key.decode() if isinstance(key, bytes) else key pipeline.json().get(key_str) results = pipeline.execute() # Build the writes dictionary writes_dict: Dict[Tuple[str, str], Dict[str, Any]] = {} for write_data in results: if write_data: task_id = write_data.get("task_id", "") idx = write_data.get("idx", 0) writes_dict[(task_id, idx)] = write_data # Use base class method to deserialize return BaseRedisSaver._load_writes(self.serde, writes_dict)
[docs] def get_channel_values( self, thread_id: str, checkpoint_ns: str, checkpoint_id: str, channel_versions: Optional[Dict[str, Any]] = None, ) -> dict[str, Any]: """Retrieve channel_values dictionary from inline checkpoint data.""" # For shallow checkpoints, channel values are stored inline in the checkpoint checkpoint_key = self._make_shallow_redis_checkpoint_key_cached( thread_id, checkpoint_ns ) # Single JSON.GET operation to retrieve checkpoint with inline channel_values checkpoint_data = self._redis.json().get(checkpoint_key, "$.checkpoint") if not checkpoint_data: return {} # checkpoint_data[0] is already a deserialized dict checkpoint = ( checkpoint_data[0] if isinstance(checkpoint_data, list) else checkpoint_data ) channel_values = checkpoint.get("channel_values", {}) # Deserialize channel values since they're stored in serialized form # Cast to dict[str, Any] as we know this is the correct type from checkpoint structure from typing import cast return self._deserialize_channel_values( cast(dict[str, Any], channel_values) if channel_values else {} )
def _load_pending_sends( self, thread_id: str, checkpoint_ns: str, ) -> List[Tuple[str, bytes]]: """Load pending sends for a parent checkpoint. Args: thread_id: The thread ID checkpoint_ns: The checkpoint namespace parent_checkpoint_id: The ID of the parent checkpoint Returns: List of (type, blob) tuples representing pending sends """ # Query checkpoint_writes for parent checkpoint's TASKS channel parent_writes_query = FilterQuery( filter_expression=(Tag("thread_id") == thread_id) & (Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns)) & (Tag("channel") == TASKS), return_fields=["type", "$.blob", "task_path", "task_id", "idx"], num_results=100, ) parent_writes_results = self.checkpoint_writes_index.search(parent_writes_query) # Sort results by task_path, task_id, idx (matching Postgres implementation) sorted_writes = sorted( parent_writes_results.docs, key=lambda x: ( getattr(x, "task_path", ""), getattr(x, "task_id", ""), getattr(x, "idx", 0), ), ) # Extract type and blob pairs # Handle both direct attribute access and JSON path access # Filter out documents where blob is None (similar to RedisSaver in __init__.py) return [ (getattr(doc, "type", ""), blob) for doc in sorted_writes if (blob := getattr(doc, "$.blob", getattr(doc, "blob", None))) is not None ] def _make_shallow_redis_checkpoint_key_cached( self, thread_id: str, checkpoint_ns: str ) -> str: """Create a cached key for shallow checkpoints using only thread_id and checkpoint_ns.""" cache_key = f"shallow_checkpoint:{thread_id}:{checkpoint_ns}" if cache_key in self._key_cache: # Move to end for LRU (most recently used) self._key_cache.move_to_end(cache_key) else: # Add new entry, evicting oldest if necessary if len(self._key_cache) >= self._key_cache_max_size: # Remove least recently used (first item) self._key_cache.popitem(last=False) self._key_cache[cache_key] = self._separator.join( [self._checkpoint_prefix, thread_id, checkpoint_ns] ) return self._key_cache[cache_key] @staticmethod def 
_make_shallow_redis_checkpoint_key(thread_id: str, checkpoint_ns: str) -> str: """Create a key for shallow checkpoints using only thread_id and checkpoint_ns.""" return REDIS_KEY_SEPARATOR.join([CHECKPOINT_PREFIX, thread_id, checkpoint_ns]) def _make_redis_checkpoint_writes_key_cached( self, thread_id: str, checkpoint_ns: str, checkpoint_id: str, task_id: str, idx: Optional[int], ) -> str: """Create a cached key for checkpoint writes.""" cache_key = ( f"writes:{thread_id}:{checkpoint_ns}:{checkpoint_id}:{task_id}:{idx}" ) if cache_key in self._key_cache: # Move to end for LRU (most recently used) self._key_cache.move_to_end(cache_key) else: # Add new entry, evicting oldest if necessary if len(self._key_cache) >= self._key_cache_max_size: # Remove least recently used (first item) self._key_cache.popitem(last=False) self._key_cache[cache_key] = self._make_redis_checkpoint_writes_key( thread_id, checkpoint_ns, checkpoint_id, task_id, idx ) return self._key_cache[cache_key] @staticmethod def _make_shallow_redis_checkpoint_writes_key_pattern( thread_id: str, checkpoint_ns: str ) -> str: """Create a pattern to match all writes keys for a thread and namespace.""" return ( REDIS_KEY_SEPARATOR.join( [ CHECKPOINT_WRITE_PREFIX, str(to_storage_safe_id(thread_id)), to_storage_safe_str(checkpoint_ns), ] ) + ":*" )
[docs] def delete_thread(self, thread_id: str) -> None: """Delete all checkpoints and writes associated with a specific thread ID. Args: thread_id: The thread ID whose checkpoints should be deleted. """ # Only one checkpoint per thread/namespace combination # Find all namespaces for this thread and delete them storage_safe_thread_id = to_storage_safe_id(thread_id) # Find all checkpoints for this thread to get checkpoint IDs checkpoint_query = FilterQuery( filter_expression=Tag("thread_id") == storage_safe_thread_id, return_fields=["checkpoint_ns", "checkpoint_id"], num_results=10000, ) checkpoint_results = self.checkpoints_index.search(checkpoint_query) # Collect namespaces and checkpoint IDs # The index stores checkpoint_ns in storage-safe form ("" -> "__empty__"). # _make_shallow_redis_checkpoint_key_cached() does its own to_storage_safe_str conversion # internally, so it needs the raw namespace. # The wrote_registry / current_checkpoint keys are raw f-strings, so they # need the storage-safe form that was used when those keys were originally written. 
checkpoint_data = [] for doc in checkpoint_results.docs: safe_ns = getattr(doc, "checkpoint_ns", "") # storage-safe: for f-strings raw_ns = from_storage_safe_str(safe_ns) # raw: for key builder method checkpoint_id = getattr(doc, "checkpoint_id", "") checkpoint_data.append((raw_ns, safe_ns, checkpoint_id)) # Delete all checkpoints and related data if checkpoint_data: with self._redis.pipeline(transaction=False) as pipeline: for raw_ns, safe_ns, checkpoint_id in checkpoint_data: # Key builder converts internally - pass raw namespace checkpoint_key = self._make_shallow_redis_checkpoint_key_cached( thread_id, raw_ns ) pipeline.delete(checkpoint_key) # write_registry key was stored with storage-safe ns - use safe_ns here thread_write_registry_key = ( f"write_registry:{thread_id}:{safe_ns}:shallow" ) # Get all write keys from the thread registry before deleting write_keys = self._redis.zrange(thread_write_registry_key, 0, -1) for write_key in write_keys: write_key_str = ( write_key.decode() if isinstance(write_key, bytes) else write_key ) pipeline.delete(write_key_str) # Delete the registry itself pipeline.delete(thread_write_registry_key) # Delete the current checkpoint tracker - use safe_ns here current_checkpoint_key = ( f"current_checkpoint:{thread_id}:{safe_ns}:shallow" ) pipeline.delete(current_checkpoint_key) pipeline.execute()
[docs] def prune( self, thread_ids: Sequence[str], *, strategy: str = "keep_latest", keep_last: Optional[int] = None, ) -> None: """Prune checkpoints for the given threads. ``ShallowRedisSaver`` stores at most one checkpoint per namespace by design, so ``strategy="keep_latest"`` (or ``keep_last >= 1``) is always a no-op. ``strategy="delete"`` (or ``keep_last=0``) removes all checkpoints for each thread (equivalent to ``delete_thread``). Args: thread_ids: Thread IDs to prune. strategy: Pruning strategy. ``"keep_latest"`` is a no-op for shallow savers (default). ``"delete"`` removes all. keep_last: Optional override. Any value >= 1 is a no-op. Pass ``0`` to delete all. """ # Resolve keep_last from strategy if not explicitly provided if keep_last is None: if strategy == "delete": keep_last = 0 else: keep_last = 1 # Validate input if not thread_ids: raise ValueError("``thread_ids`` must be a non-empty sequence") if keep_last < 0: raise ValueError(f"``keep_last`` must be >= 0, got {keep_last}") if keep_last >= 1: return for thread_id in thread_ids: self.delete_thread(thread_id)
def _extract_fallback_timestamp(self, checkpoint: Checkpoint) -> float: """Extract timestamp from checkpoint's ts field or use current time. This is used when the checkpoint_id is not a valid ULID (e.g., UUIDv6 format). See Issue #136 for details. Args: checkpoint: The checkpoint object containing an optional ts field. Returns: Timestamp in milliseconds since epoch. """ ts_value = checkpoint.get("ts") if ts_value: # Handle both ISO string and numeric timestamps if isinstance(ts_value, str): try: dt = datetime.fromisoformat(ts_value.replace("Z", "+00:00")) return dt.timestamp() * MILLISECONDS_PER_SECOND except Exception: return time.time() * MILLISECONDS_PER_SECOND else: return ts_value return time.time() * MILLISECONDS_PER_SECOND