from __future__ import annotations
import json
import logging
import time
from collections import OrderedDict
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, cast
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
ChannelVersions,
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
PendingWrite,
get_checkpoint_id,
)
from langgraph.constants import TASKS
from redis import Redis
from redisvl.index import SearchIndex
from redisvl.query import FilterQuery
from redisvl.query.filter import Num, Tag
from redisvl.redis.connection import RedisConnectionFactory
from ulid import ULID
from langgraph.checkpoint.redis.base import (
CHECKPOINT_PREFIX,
CHECKPOINT_WRITE_PREFIX,
REDIS_KEY_SEPARATOR,
BaseRedisSaver,
)
from langgraph.checkpoint.redis.util import (
from_storage_safe_str,
to_storage_safe_id,
to_storage_safe_str,
)
# Constants
MILLISECONDS_PER_SECOND = 1000
# Logger for this module
logger = logging.getLogger(__name__)
class ShallowRedisSaver(BaseRedisSaver[Redis, SearchIndex]):
    """Redis implementation that only stores the most recent checkpoint.

    Supports standard Redis URLs (redis://), SSL (rediss://), and
    Sentinel URLs (redis+sentinel://host:26379/service_name/db).
    """

    # Default upper bounds for the per-instance LRU caches below.
    DEFAULT_KEY_CACHE_MAX_SIZE = 1000
    DEFAULT_CHANNEL_CACHE_MAX_SIZE = 100

    def __init__(
        self,
        redis_url: Optional[str] = None,
        *,
        redis_client: Optional[Redis] = None,
        connection_args: Optional[dict[str, Any]] = None,
        ttl: Optional[dict[str, Any]] = None,
        key_cache_max_size: Optional[int] = None,
        channel_cache_max_size: Optional[int] = None,
        checkpoint_prefix: str = CHECKPOINT_PREFIX,
        checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
    ) -> None:
        super().__init__(
            redis_url=redis_url,
            redis_client=redis_client,
            connection_args=connection_args,
            ttl=ttl,
            checkpoint_prefix=checkpoint_prefix,
            checkpoint_write_prefix=checkpoint_write_prefix,
        )
        # Bounded LRU caches for frequently rebuilt Redis keys / channel data.
        # OrderedDict gives us recency ordering for cheap LRU eviction and a
        # hard size cap so long-lived savers don't grow without limit.
        self._key_cache: OrderedDict[str, str] = OrderedDict()
        self._key_cache_max_size = (
            key_cache_max_size or self.DEFAULT_KEY_CACHE_MAX_SIZE
        )
        self._channel_cache: OrderedDict[str, Any] = OrderedDict()
        self._channel_cache_max_size = (
            channel_cache_max_size or self.DEFAULT_CHANNEL_CACHE_MAX_SIZE
        )
        # Key prefixes themselves are configured by BaseRedisSaver.__init__.
        self._separator = REDIS_KEY_SEPARATOR
[docs]
@classmethod
@contextmanager
def from_conn_string(
cls,
redis_url: Optional[str] = None,
*,
redis_client: Optional[Redis] = None,
connection_args: Optional[dict[str, Any]] = None,
ttl: Optional[dict[str, Any]] = None,
key_cache_max_size: Optional[int] = None,
channel_cache_max_size: Optional[int] = None,
checkpoint_prefix: str = CHECKPOINT_PREFIX,
checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
) -> Iterator[ShallowRedisSaver]:
"""Create a new ShallowRedisSaver instance."""
saver: Optional[ShallowRedisSaver] = None
try:
saver = cls(
redis_url=redis_url,
redis_client=redis_client,
connection_args=connection_args,
ttl=ttl,
key_cache_max_size=key_cache_max_size,
channel_cache_max_size=channel_cache_max_size,
checkpoint_prefix=checkpoint_prefix,
checkpoint_write_prefix=checkpoint_write_prefix,
)
yield saver
finally:
if saver and saver._owns_its_client:
saver._redis.close()
# RedisCluster doesn't have connection_pool attribute
if getattr(saver._redis, "connection_pool", None):
saver._redis.connection_pool.disconnect()
[docs]
def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Store checkpoint with inline channel values."""
configurable = config["configurable"].copy()
run_id = configurable.pop("run_id", metadata.get("run_id"))
thread_id = configurable.pop("thread_id")
checkpoint_ns = configurable.pop("checkpoint_ns")
copy = checkpoint.copy()
next_config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint["id"],
}
}
# Extract timestamp from checkpoint_id (ULID) or fallback to checkpoint's ts field
# Note: LangGraph may generate checkpoint IDs in different formats (ULID, UUIDv6, etc.)
# We try ULID first, then fall back gracefully without warnings (Issue #136)
checkpoint_ts = None
if checkpoint["id"]:
try:
ulid_obj = ULID.from_str(checkpoint["id"])
checkpoint_ts = ulid_obj.timestamp # milliseconds since epoch
except Exception:
# Not a valid ULID - this is expected for UUIDv6 and other formats
# Fall back to checkpoint's timestamp field or current time
checkpoint_ts = self._extract_fallback_timestamp(checkpoint)
# Parse metadata from string to dict to avoid double serialization
metadata_str = self._dump_metadata(metadata)
metadata_dict = (
json.loads(metadata_str) if isinstance(metadata_str, str) else metadata_str
)
# Store channel values inline in the checkpoint
copy["channel_values"] = checkpoint.get("channel_values", {})
checkpoint_data = {
"thread_id": thread_id,
"run_id": to_storage_safe_id(run_id) if run_id else "",
"checkpoint_ns": to_storage_safe_str(checkpoint_ns),
"checkpoint_id": checkpoint["id"],
"checkpoint_ts": checkpoint_ts,
"checkpoint": self._dump_checkpoint(copy),
"metadata": metadata_dict,
# Note: has_writes tracking removed to support put_writes before checkpoint exists
}
# Store at top-level for filters in list()
if all(key in metadata for key in ["source", "step"]):
checkpoint_data["source"] = metadata["source"]
checkpoint_data["step"] = metadata["step"]
checkpoint_key = self._make_shallow_redis_checkpoint_key_cached(
thread_id, checkpoint_ns
)
# Only critical commands (JSON.SET) go in the pipeline.
# EXPIRE is applied separately to avoid pipeline failures on
# Redis Enterprise proxy with mixed JSON + native commands.
with self._redis.pipeline(transaction=False) as pipeline:
pipeline.json().set(checkpoint_key, "$", checkpoint_data)
pipeline.execute()
# Apply TTL separately (best-effort)
if self.ttl_config and "default_ttl" in self.ttl_config:
ttl_seconds = int(self.ttl_config.get("default_ttl") * 60)
try:
self._redis.expire(checkpoint_key, ttl_seconds)
except Exception:
logger.warning(
"Failed to apply TTL to checkpoint key: %s", checkpoint_key
)
# NOTE: We intentionally do NOT clean up old writes here.
# In the HITL (Human-in-the-Loop) flow, interrupt writes are saved via
# put_writes BEFORE the new checkpoint is saved. If we clean up writes
# when the checkpoint changes, we would delete the interrupt writes
# before they can be consumed when resuming.
#
# Writes are cleaned up in the following scenarios:
# 1. When delete_thread is called
# 2. When TTL expires (if configured)
# 3. When put_writes is called again for the same task/idx (overwrites)
#
# See Issue #133 for details on this bug fix.
return next_config
    def list(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> Iterator[CheckpointTuple]:
        """List checkpoints from Redis.

        Args:
            config: Base config; its ``thread_id`` (and, when present,
                ``run_id`` / ``checkpoint_ns``) narrow the search.
            filter: Metadata filters; supported keys are ``source``, ``step``,
                ``thread_id`` and ``run_id``.
            before: Only yield checkpoints whose ``checkpoint_ts`` is earlier
                than this config's checkpoint id (ULID-format ids only; the
                filter is silently ignored otherwise).
            limit: Maximum number of checkpoints to yield.

        Yields:
            Matching checkpoints, most recent first.

        Raises:
            ValueError: If ``filter`` contains an unsupported key.
        """
        # Construct the filter expression
        filter_expression = []
        if config:
            filter_expression.append(
                Tag("thread_id")
                == to_storage_safe_id(config["configurable"]["thread_id"])
            )
            if run_id := config["configurable"].get("run_id"):
                filter_expression.append(Tag("run_id") == to_storage_safe_id(run_id))
            if checkpoint_ns := config["configurable"].get("checkpoint_ns"):
                filter_expression.append(
                    Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns)
                )
        if filter:
            for k, v in filter.items():
                if k == "source":
                    filter_expression.append(Tag("source") == v)
                elif k == "step":
                    filter_expression.append(Num("step") == v)
                elif k == "thread_id":
                    filter_expression.append(Tag("thread_id") == to_storage_safe_id(v))
                elif k == "run_id":
                    filter_expression.append(Tag("run_id") == to_storage_safe_id(v))
                else:
                    raise ValueError(f"Unsupported filter key: {k}")
        if before:
            before_checkpoint_id = get_checkpoint_id(before)
            if before_checkpoint_id:
                try:
                    before_ulid = ULID.from_str(before_checkpoint_id)
                    before_ts = before_ulid.timestamp
                    # Use numeric range query: checkpoint_ts < before_ts
                    filter_expression.append(Num("checkpoint_ts") < before_ts)
                except Exception:
                    # If not a valid ULID, ignore the before filter
                    pass
        # Combine all filter expressions with AND; "*" matches everything.
        combined_filter = filter_expression[0] if filter_expression else "*"
        for expr in filter_expression[1:]:
            combined_filter &= expr
        # Get checkpoint data
        # Sort by checkpoint_id in descending order to get most recent checkpoints first
        query = FilterQuery(
            filter_expression=combined_filter,
            return_fields=[
                "thread_id",
                "checkpoint_ns",
                "$.checkpoint",
                "$.metadata",
            ],
            num_results=limit or 10000,
            sort_by=("checkpoint_id", "DESC"),
        )
        # Execute the query
        results = self.checkpoints_index.search(query)
        # Process the results
        for doc in results.docs:
            thread_id = cast(str, getattr(doc, "thread_id", ""))
            checkpoint_ns = cast(str, getattr(doc, "checkpoint_ns", ""))
            checkpoint = json.loads(doc["$.checkpoint"])
            # Extract channel values from the checkpoint (they're stored inline)
            channel_values: Dict[str, Any] = checkpoint.get("channel_values", {})
            # Deserialize them since they're stored in serialized form
            channel_values = self._deserialize_channel_values(channel_values)
            # Parse metadata (may arrive as a JSON string or an already-parsed dict)
            raw_metadata = getattr(doc, "$.metadata", "{}")
            metadata_dict = (
                json.loads(raw_metadata)
                if isinstance(raw_metadata, str)
                else raw_metadata
            )
            # Sanitize metadata: strip NUL characters from keys and string values
            sanitized_metadata = {
                k.replace("\u0000", ""): (
                    v.replace("\u0000", "") if isinstance(v, str) else v
                )
                for k, v in metadata_dict.items()
            }
            metadata = cast(CheckpointMetadata, sanitized_metadata)
            # Load checkpoint with inline channel values
            checkpoint_param = self._load_checkpoint(
                doc["$.checkpoint"],
                channel_values,  # Pass the extracted channel values
                [],  # No pending_sends in shallow mode
            )
            config_param: RunnableConfig = {
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    "checkpoint_id": checkpoint_param["id"],
                }
            }
            # Load pending writes (still uses separate keys - already efficient)
            pending_writes = self._load_pending_writes(
                thread_id, checkpoint_ns, checkpoint_param["id"]
            )
            yield CheckpointTuple(
                config=config_param,
                checkpoint=checkpoint_param,
                metadata=metadata,
                parent_config=None,
                pending_writes=pending_writes,
            )
    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Get checkpoint with inline channel values.

        Args:
            config: Config whose ``thread_id`` (and optional ``checkpoint_ns``)
                identify the single shallow checkpoint to load.

        Returns:
            The stored CheckpointTuple, or ``None`` when no checkpoint exists
            for this thread/namespace.
        """
        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
        # Single key access gets everything inline
        checkpoint_key = self._make_shallow_redis_checkpoint_key_cached(
            thread_id, checkpoint_ns
        )
        checkpoint_data = self._redis.json().get(checkpoint_key)
        if not checkpoint_data or not isinstance(checkpoint_data, dict):
            return None
        # TTL refresh if enabled (best-effort)
        if self.ttl_config and self.ttl_config.get("refresh_on_read"):
            default_ttl_minutes = self.ttl_config.get("default_ttl", 60)
            ttl_seconds = int(default_ttl_minutes * 60)
            try:
                self._redis.expire(checkpoint_key, ttl_seconds)
            except Exception:
                logger.warning(
                    "Failed to refresh TTL on checkpoint key: %s", checkpoint_key
                )
        # Parse the checkpoint data (may be stored as a JSON string or dict)
        checkpoint = checkpoint_data.get("checkpoint", {})
        if isinstance(checkpoint, str):
            checkpoint = json.loads(checkpoint)
        # Extract channel values from the checkpoint (they're stored inline)
        channel_values: Dict[str, Any] = checkpoint.get("channel_values", {})
        # Deserialize them since they're stored in serialized form
        channel_values = self._deserialize_channel_values(channel_values)
        # Parse metadata (may be stored as a JSON string or dict)
        metadata = checkpoint_data.get("metadata", {})
        if isinstance(metadata, str):
            metadata = json.loads(metadata)
        # Sanitize metadata: strip NUL characters from keys and string values
        sanitized_metadata = {
            k.replace("\u0000", ""): (
                v.replace("\u0000", "") if isinstance(v, str) else v
            )
            for k, v in metadata.items()
        }
        # Load checkpoint with the (already deserialized) inline channel values
        checkpoint_param = self._load_checkpoint(
            json.dumps(checkpoint),
            channel_values,  # deserialized above via _deserialize_channel_values
            [],  # No pending_sends in shallow mode
        )
        # Load pending writes (still uses separate keys - already efficient)
        pending_writes = self._load_pending_writes(
            thread_id, checkpoint_ns, checkpoint_param["id"]
        )
        return CheckpointTuple(
            config={
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    # NOTE(review): uses the raw stored checkpoint's id rather
                    # than checkpoint_param["id"]; these should be identical —
                    # confirm if _load_checkpoint ever rewrites the id.
                    "checkpoint_id": checkpoint["id"],
                }
            },
            checkpoint=checkpoint_param,
            metadata=cast(CheckpointMetadata, sanitized_metadata),
            parent_config=None,
            pending_writes=pending_writes,
        )
[docs]
def create_indexes(self) -> None:
self.checkpoints_index = SearchIndex.from_dict(
self.checkpoints_schema, redis_client=self._redis
)
self.checkpoint_writes_index = SearchIndex.from_dict(
self.writes_schema, redis_client=self._redis
)
[docs]
def setup(self) -> None:
"""Initialize the indices in Redis."""
self.checkpoints_index.create(overwrite=False)
self.checkpoint_writes_index.create(overwrite=False)
    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Store intermediate writes linked to a checkpoint with checkpoint-level registry.

        Args:
            config: Configuration of the related checkpoint.
            writes: List of writes to store, each as (channel, value) pair.
            task_id: Identifier for the task creating the writes.
            task_path: Optional path info for the task.
        """
        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
        checkpoint_id = config["configurable"]["checkpoint_id"]
        # Transform writes into appropriate format
        writes_objects = []
        for idx, (channel, value) in enumerate(writes):
            type_, blob = self.serde.dumps_typed(value)
            write_obj = {
                "thread_id": thread_id,
                "checkpoint_ns": to_storage_safe_str(checkpoint_ns),
                "checkpoint_id": checkpoint_id,
                "task_id": task_id,
                "task_path": task_path,
                # Special channels get a fixed index from WRITES_IDX_MAP so a
                # repeated write for the same task/channel overwrites in place.
                "idx": WRITES_IDX_MAP.get(channel, idx),
                "channel": channel,
                "type": type_,
                "blob": self._encode_blob(
                    blob
                ),  # Encode bytes to base64 string for Redis
            }
            writes_objects.append(write_obj)
        # THREAD-LEVEL REGISTRY: Only keep writes for the current checkpoint
        # Use to_storage_safe_str for consistent key naming with delete_thread
        safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
        thread_write_registry_key = (
            f"write_registry:{thread_id}:{safe_checkpoint_ns}:shallow"
        )
        # Collect all write keys
        write_keys = []
        for write_obj in writes_objects:
            key = self._make_redis_checkpoint_writes_key_cached(
                thread_id, checkpoint_ns, checkpoint_id, task_id, write_obj["idx"]
            )
            write_keys.append(key)
        # Only critical commands (JSON.SET, ZADD) go in the pipeline.
        # EXPIRE (TTL) is applied separately to avoid pipeline failures on
        # Redis Enterprise proxy with mixed JSON + native commands.
        with self._redis.pipeline(transaction=False) as pipeline:
            # Add all JSON write operations (critical)
            for idx, write_obj in enumerate(writes_objects):
                key = write_keys[idx]
                pipeline.json().set(key, "$", write_obj)
            # Registry operation (critical): score by position so zrange
            # later returns the write keys in insertion order.
            zadd_mapping = {key: idx for idx, key in enumerate(write_keys)}
            pipeline.zadd(thread_write_registry_key, zadd_mapping)  # type: ignore[arg-type]
            # Execute critical commands only
            results = pipeline.execute(raise_on_error=False)
        # Check results for critical command failures
        for result in results:
            if isinstance(result, Exception):
                raise result
        # Apply TTL separately (best-effort — failures don't lose writes)
        if self.ttl_config and "default_ttl" in self.ttl_config:
            ttl_seconds = int(self.ttl_config.get("default_ttl") * 60)
            try:
                self._redis.expire(thread_write_registry_key, ttl_seconds)
            except Exception:
                logger.warning(
                    "Failed to apply TTL to write registry key: %s",
                    thread_write_registry_key,
                    exc_info=True,
                )
            for key in write_keys:
                try:
                    self._redis.expire(key, ttl_seconds)
                except Exception:
                    logger.warning(
                        "Failed to apply TTL to checkpoint write key: %s",
                        key,
                        exc_info=True,
                    )
def _load_pending_writes(
self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
) -> List[PendingWrite]:
"""Load pending writes efficiently using thread-level write registry."""
if checkpoint_id is None:
return []
# Use thread-level registry that only contains current checkpoint writes
# All writes belong to the current checkpoint
# Use to_storage_safe_str for consistent key naming with delete_thread
safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
thread_write_registry_key = (
f"write_registry:{thread_id}:{safe_checkpoint_ns}:shallow"
)
# Get all write keys from the thread's registry (already sorted by index)
write_keys = self._redis.zrange(thread_write_registry_key, 0, -1)
if not write_keys:
return []
# Batch fetch all writes using pipeline
with self._redis.pipeline(transaction=False) as pipeline:
for key in write_keys:
# Decode bytes to string if needed
key_str = key.decode() if isinstance(key, bytes) else key
pipeline.json().get(key_str)
results = pipeline.execute()
# Build the writes dictionary
writes_dict: Dict[Tuple[str, str], Dict[str, Any]] = {}
for write_data in results:
if write_data:
task_id = write_data.get("task_id", "")
idx = write_data.get("idx", 0)
writes_dict[(task_id, idx)] = write_data
# Use base class method to deserialize
return BaseRedisSaver._load_writes(self.serde, writes_dict)
[docs]
def get_channel_values(
self,
thread_id: str,
checkpoint_ns: str,
checkpoint_id: str,
channel_versions: Optional[Dict[str, Any]] = None,
) -> dict[str, Any]:
"""Retrieve channel_values dictionary from inline checkpoint data."""
# For shallow checkpoints, channel values are stored inline in the checkpoint
checkpoint_key = self._make_shallow_redis_checkpoint_key_cached(
thread_id, checkpoint_ns
)
# Single JSON.GET operation to retrieve checkpoint with inline channel_values
checkpoint_data = self._redis.json().get(checkpoint_key, "$.checkpoint")
if not checkpoint_data:
return {}
# checkpoint_data[0] is already a deserialized dict
checkpoint = (
checkpoint_data[0] if isinstance(checkpoint_data, list) else checkpoint_data
)
channel_values = checkpoint.get("channel_values", {})
# Deserialize channel values since they're stored in serialized form
# Cast to dict[str, Any] as we know this is the correct type from checkpoint structure
from typing import cast
return self._deserialize_channel_values(
cast(dict[str, Any], channel_values) if channel_values else {}
)
def _load_pending_sends(
self,
thread_id: str,
checkpoint_ns: str,
) -> List[Tuple[str, bytes]]:
"""Load pending sends for a parent checkpoint.
Args:
thread_id: The thread ID
checkpoint_ns: The checkpoint namespace
parent_checkpoint_id: The ID of the parent checkpoint
Returns:
List of (type, blob) tuples representing pending sends
"""
# Query checkpoint_writes for parent checkpoint's TASKS channel
parent_writes_query = FilterQuery(
filter_expression=(Tag("thread_id") == thread_id)
& (Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns))
& (Tag("channel") == TASKS),
return_fields=["type", "$.blob", "task_path", "task_id", "idx"],
num_results=100,
)
parent_writes_results = self.checkpoint_writes_index.search(parent_writes_query)
# Sort results by task_path, task_id, idx (matching Postgres implementation)
sorted_writes = sorted(
parent_writes_results.docs,
key=lambda x: (
getattr(x, "task_path", ""),
getattr(x, "task_id", ""),
getattr(x, "idx", 0),
),
)
# Extract type and blob pairs
# Handle both direct attribute access and JSON path access
# Filter out documents where blob is None (similar to RedisSaver in __init__.py)
return [
(getattr(doc, "type", ""), blob)
for doc in sorted_writes
if (blob := getattr(doc, "$.blob", getattr(doc, "blob", None))) is not None
]
def _make_shallow_redis_checkpoint_key_cached(
self, thread_id: str, checkpoint_ns: str
) -> str:
"""Create a cached key for shallow checkpoints using only thread_id and checkpoint_ns."""
cache_key = f"shallow_checkpoint:{thread_id}:{checkpoint_ns}"
if cache_key in self._key_cache:
# Move to end for LRU (most recently used)
self._key_cache.move_to_end(cache_key)
else:
# Add new entry, evicting oldest if necessary
if len(self._key_cache) >= self._key_cache_max_size:
# Remove least recently used (first item)
self._key_cache.popitem(last=False)
self._key_cache[cache_key] = self._separator.join(
[self._checkpoint_prefix, thread_id, checkpoint_ns]
)
return self._key_cache[cache_key]
@staticmethod
def _make_shallow_redis_checkpoint_key(thread_id: str, checkpoint_ns: str) -> str:
"""Create a key for shallow checkpoints using only thread_id and checkpoint_ns."""
return REDIS_KEY_SEPARATOR.join([CHECKPOINT_PREFIX, thread_id, checkpoint_ns])
def _make_redis_checkpoint_writes_key_cached(
self,
thread_id: str,
checkpoint_ns: str,
checkpoint_id: str,
task_id: str,
idx: Optional[int],
) -> str:
"""Create a cached key for checkpoint writes."""
cache_key = (
f"writes:{thread_id}:{checkpoint_ns}:{checkpoint_id}:{task_id}:{idx}"
)
if cache_key in self._key_cache:
# Move to end for LRU (most recently used)
self._key_cache.move_to_end(cache_key)
else:
# Add new entry, evicting oldest if necessary
if len(self._key_cache) >= self._key_cache_max_size:
# Remove least recently used (first item)
self._key_cache.popitem(last=False)
self._key_cache[cache_key] = self._make_redis_checkpoint_writes_key(
thread_id, checkpoint_ns, checkpoint_id, task_id, idx
)
return self._key_cache[cache_key]
@staticmethod
def _make_shallow_redis_checkpoint_writes_key_pattern(
thread_id: str, checkpoint_ns: str
) -> str:
"""Create a pattern to match all writes keys for a thread and namespace."""
return (
REDIS_KEY_SEPARATOR.join(
[
CHECKPOINT_WRITE_PREFIX,
str(to_storage_safe_id(thread_id)),
to_storage_safe_str(checkpoint_ns),
]
)
+ ":*"
)
    def delete_thread(self, thread_id: str) -> None:
        """Delete all checkpoints and writes associated with a specific thread ID.

        Args:
            thread_id: The thread ID whose checkpoints should be deleted.
        """
        # Only one checkpoint per thread/namespace combination
        # Find all namespaces for this thread and delete them
        storage_safe_thread_id = to_storage_safe_id(thread_id)
        # Find all checkpoints for this thread to get checkpoint IDs
        checkpoint_query = FilterQuery(
            filter_expression=Tag("thread_id") == storage_safe_thread_id,
            return_fields=["checkpoint_ns", "checkpoint_id"],
            num_results=10000,
        )
        checkpoint_results = self.checkpoints_index.search(checkpoint_query)
        # Collect namespaces and checkpoint IDs.
        # The index stores checkpoint_ns in storage-safe form ("" -> "__empty__").
        # NOTE(review): the key builder is given the RAW namespace on the
        # assumption that _make_shallow_redis_checkpoint_key_cached() converts
        # it internally — its current implementation does NOT convert; confirm
        # which form put() actually stored under before relying on this.
        # The write_registry / current_checkpoint keys are raw f-strings, so
        # they need the storage-safe form used when those keys were written.
        checkpoint_data = []
        for doc in checkpoint_results.docs:
            safe_ns = getattr(doc, "checkpoint_ns", "")  # storage-safe: for f-strings
            raw_ns = from_storage_safe_str(safe_ns)  # raw: for key builder method
            checkpoint_id = getattr(doc, "checkpoint_id", "")
            checkpoint_data.append((raw_ns, safe_ns, checkpoint_id))
        # Delete all checkpoints and related data
        if checkpoint_data:
            with self._redis.pipeline(transaction=False) as pipeline:
                for raw_ns, safe_ns, checkpoint_id in checkpoint_data:
                    # Key builder converts internally - pass raw namespace
                    checkpoint_key = self._make_shallow_redis_checkpoint_key_cached(
                        thread_id, raw_ns
                    )
                    pipeline.delete(checkpoint_key)
                    # write_registry key was stored with storage-safe ns - use safe_ns here
                    thread_write_registry_key = (
                        f"write_registry:{thread_id}:{safe_ns}:shallow"
                    )
                    # Read the member keys now (zrange goes to the live client,
                    # not the pipeline) before queueing their deletion.
                    write_keys = self._redis.zrange(thread_write_registry_key, 0, -1)
                    for write_key in write_keys:
                        write_key_str = (
                            write_key.decode()
                            if isinstance(write_key, bytes)
                            else write_key
                        )
                        pipeline.delete(write_key_str)
                    # Delete the registry itself
                    pipeline.delete(thread_write_registry_key)
                    # Delete the current checkpoint tracker - use safe_ns here
                    current_checkpoint_key = (
                        f"current_checkpoint:{thread_id}:{safe_ns}:shallow"
                    )
                    pipeline.delete(current_checkpoint_key)
                pipeline.execute()
def prune(
self,
thread_ids: Sequence[str],
*,
strategy: str = "keep_latest",
keep_last: Optional[int] = None,
) -> None:
"""Prune checkpoints for the given threads.
``ShallowRedisSaver`` stores at most one checkpoint per namespace by
design, so ``strategy="keep_latest"`` (or ``keep_last >= 1``) is
always a no-op. ``strategy="delete"`` (or ``keep_last=0``) removes
all checkpoints for each thread (equivalent to ``delete_thread``).
Args:
thread_ids: Thread IDs to prune.
strategy: Pruning strategy. ``"keep_latest"`` is a no-op for
shallow savers (default). ``"delete"`` removes all.
keep_last: Optional override. Any value >= 1 is a no-op.
Pass ``0`` to delete all.
"""
# Resolve keep_last from strategy if not explicitly provided
if keep_last is None:
if strategy == "delete":
keep_last = 0
else:
keep_last = 1
# Validate input
if not thread_ids:
raise ValueError("``thread_ids`` must be a non-empty sequence")
if keep_last < 0:
raise ValueError(f"``keep_last`` must be >= 0, got {keep_last}")
if keep_last >= 1:
return
for thread_id in thread_ids:
self.delete_thread(thread_id)
def _extract_fallback_timestamp(self, checkpoint: Checkpoint) -> float:
"""Extract timestamp from checkpoint's ts field or use current time.
This is used when the checkpoint_id is not a valid ULID (e.g., UUIDv6 format).
See Issue #136 for details.
Args:
checkpoint: The checkpoint object containing an optional ts field.
Returns:
Timestamp in milliseconds since epoch.
"""
ts_value = checkpoint.get("ts")
if ts_value:
# Handle both ISO string and numeric timestamps
if isinstance(ts_value, str):
try:
dt = datetime.fromisoformat(ts_value.replace("Z", "+00:00"))
return dt.timestamp() * MILLISECONDS_PER_SECOND
except Exception:
return time.time() * MILLISECONDS_PER_SECOND
else:
return ts_value
return time.time() * MILLISECONDS_PER_SECOND