"""Async shallow Redis implementation for LangGraph checkpoint saving."""
from __future__ import annotations
import asyncio
import json
import logging
import time
from contextlib import asynccontextmanager
from datetime import datetime
from types import TracebackType
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Type, cast
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
ChannelVersions,
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
PendingWrite,
get_checkpoint_id,
)
from langgraph.constants import TASKS
from redis.asyncio import Redis as AsyncRedis
from redisvl.index import AsyncSearchIndex
from redisvl.query import FilterQuery
from redisvl.query.filter import Num, Tag
from redisvl.redis.connection import RedisConnectionFactory
from ulid import ULID
from langgraph.checkpoint.redis.base import (
CHECKPOINT_PREFIX,
CHECKPOINT_WRITE_PREFIX,
REDIS_KEY_SEPARATOR,
BaseRedisSaver,
)
from langgraph.checkpoint.redis.util import (
to_storage_safe_id,
to_storage_safe_str,
)
# Module-level logger; used below for best-effort TTL failure warnings.
logger = logging.getLogger(__name__)
# Constants
# Multiplier used to convert epoch seconds into epoch milliseconds.
MILLISECONDS_PER_SECOND = 1000
[docs]
class AsyncShallowRedisSaver(BaseRedisSaver[AsyncRedis, AsyncSearchIndex]):
"""Async Redis implementation that only stores the most recent checkpoint.

A single checkpoint key is used per ``(thread_id, checkpoint_ns)`` pair and
is overwritten on every ``aput``; channel values are stored inline in that
document.

Supports standard Redis URLs (redis://), SSL (rediss://), and
Sentinel URLs (redis+sentinel://host:26379/service_name/db).
"""
_redis_url: str
# Search index over checkpoint documents (queried by alist / adelete_thread).
checkpoints_index: AsyncSearchIndex
# Search index over checkpoint-write documents.
checkpoint_writes_index: AsyncSearchIndex
_redis: AsyncRedis # Override the type from the base class
def __init__(
    self,
    redis_url: Optional[str] = None,
    *,
    redis_client: Optional[AsyncRedis] = None,
    connection_args: Optional[dict[str, Any]] = None,
    ttl: Optional[dict[str, Any]] = None,
    checkpoint_prefix: str = CHECKPOINT_PREFIX,
    checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
) -> None:
    """Initialise the saver and bind it to the currently running event loop.

    All arguments are forwarded unchanged to the base-class constructor.

    Raises:
        RuntimeError: If no event loop is running when the saver is
            created (raised by ``asyncio.get_running_loop``).
    """
    base_options = {
        "redis_url": redis_url,
        "redis_client": redis_client,
        "connection_args": connection_args,
        "ttl": ttl,
        "checkpoint_prefix": checkpoint_prefix,
        "checkpoint_write_prefix": checkpoint_write_prefix,
    }
    super().__init__(**base_options)

    # The sync wrappers schedule their coroutines onto this loop.
    self.loop = asyncio.get_running_loop()

    # Per-instance memo of generated Redis key strings, plus its intended
    # size limit.
    self._key_cache: Dict[str, str] = {}
    self._key_cache_max_size = 1000  # Configurable limit
    self._channel_cache: Dict[str, Any] = {}
    self._separator = REDIS_KEY_SEPARATOR
async def __aenter__(self) -> AsyncShallowRedisSaver:
    """Enter the async context: create the indexes, then register client info."""
    # Indexes first; client info is only set once Redis has been set up.
    await self.asetup()
    await self.aset_client_info()
    return self
async def __aexit__(
self,
_exc_type: Optional[Type[BaseException]],
_exc: Optional[BaseException],
_tb: Optional[TracebackType],
) -> None:
"""Async context manager exit: close the Redis client when this saver owns it.

The exception arguments are accepted for protocol compatibility and are
not inspected; nothing is returned, so exceptions are never suppressed.
"""
# NOTE(review): indentation was lost in this copy of the file, so it cannot be
# determined from here which of the statements below are nested under
# ``if self._owns_its_client`` -- confirm against the canonical source.
if self._owns_its_client:
await self._redis.aclose() # type: ignore[attr-defined]
# RedisCluster doesn't have connection_pool attribute
if getattr(self._redis, "connection_pool", None):
# disconnect() may return an awaitable; it is only awaited when truthy.
coro = self._redis.connection_pool.disconnect()
if coro:
await coro
# Prevent RedisVL from attempting to close the client
# on an event loop in a separate thread.
self.checkpoints_index._redis_client = None
self.checkpoint_writes_index._redis_client = None
[docs]
@classmethod
@asynccontextmanager
async def from_conn_string(
    cls,
    redis_url: Optional[str] = None,
    *,
    redis_client: Optional[AsyncRedis] = None,
    connection_args: Optional[dict[str, Any]] = None,
    ttl: Optional[dict[str, Any]] = None,
    checkpoint_prefix: str = CHECKPOINT_PREFIX,
    checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
) -> AsyncIterator[AsyncShallowRedisSaver]:
    """Build a saver, enter its async context, and yield it.

    The arguments mirror the constructor. The saver's ``__aexit__`` runs
    when the surrounding ``async with`` block finishes.
    """
    saver_context = cls(
        redis_url=redis_url,
        redis_client=redis_client,
        connection_args=connection_args,
        ttl=ttl,
        checkpoint_prefix=checkpoint_prefix,
        checkpoint_write_prefix=checkpoint_write_prefix,
    )
    async with saver_context as saver:
        yield saver
[docs]
async def asetup(self) -> None:
    """Create the checkpoint and checkpoint-writes indexes without overwriting."""
    # Checkpoints index first, then the writes index; each attribute is looked
    # up only when its turn comes.
    for index_attr in ("checkpoints_index", "checkpoint_writes_index"):
        await getattr(self, index_attr).create(overwrite=False)
[docs]
async def setup(self) -> None: # type: ignore[override]
"""Set up the checkpoint saver asynchronously.
This method creates the necessary indices in Redis.
It MUST be called before using the checkpointer.
This async method follows the canonical pattern used by other
async checkpointers in the LangGraph ecosystem. The type ignore is necessary because
the base class defines a sync setup() method, but async checkpointers require
an async setup() method to properly handle coroutines.
"""
await self.asetup()
[docs]
async def aput(
    self,
    config: RunnableConfig,
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
    new_versions: ChannelVersions,
) -> RunnableConfig:
    """Store the latest checkpoint with its channel values inline.

    Shallow mode keeps one JSON document per ``(thread_id, checkpoint_ns)``;
    every call overwrites that document.

    Args:
        config: The config to associate with the checkpoint. ``thread_id``
            is required; ``checkpoint_ns`` defaults to ``""`` when absent,
            matching ``aget_tuple`` and ``aput_writes``.
        checkpoint: The checkpoint data to store.
        metadata: Additional metadata to save with the checkpoint.
        new_versions: New channel versions as of this write (not used by
            this implementation).

    Returns:
        Updated configuration after storing the checkpoint.

    Raises:
        KeyError: If ``config`` carries no ``thread_id``.

    Errors from the JSON.SET pipeline, including cancellation, propagate
    unchanged.
    """
    configurable = config["configurable"].copy()
    run_id = configurable.pop("run_id", metadata.get("run_id"))
    thread_id = configurable.pop("thread_id")
    checkpoint_ns = configurable.pop("checkpoint_ns", "")
    checkpoint_id = checkpoint["id"]

    next_config: RunnableConfig = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": checkpoint_ns,
            "checkpoint_id": checkpoint_id,
        }
    }

    # LangGraph may generate checkpoint IDs in different formats (ULID,
    # UUIDv6, ...). Try ULID first and fall back silently (Issue #136).
    # NOTE(review): the fallback helper returns milliseconds; confirm that
    # the ULID library's ``timestamp`` attribute uses the same unit.
    checkpoint_ts = None
    if checkpoint_id:
        try:
            checkpoint_ts = ULID.from_str(checkpoint_id).timestamp
        except Exception:
            # Not a valid ULID - expected for UUIDv6 and other formats.
            checkpoint_ts = None
    if checkpoint_ts is None:
        # Also covers an empty checkpoint id, so the stored numeric field
        # is never None.
        checkpoint_ts = self._extract_fallback_timestamp(checkpoint)

    # Store channel values inline in the checkpoint document.
    stored_checkpoint = checkpoint.copy()
    stored_checkpoint["channel_values"] = checkpoint.get("channel_values", {})

    checkpoint_data = {
        "thread_id": thread_id,
        "run_id": to_storage_safe_id(run_id) if run_id else "",
        "checkpoint_ns": checkpoint_ns,
        "checkpoint_id": checkpoint_id,
        "checkpoint_ts": checkpoint_ts,
        "checkpoint": self._dump_checkpoint(stored_checkpoint),
        "metadata": self._dump_metadata(metadata),
        # Note: has_writes tracking removed to support put_writes before
        # the checkpoint exists.
    }

    # Store at top level for the filters used by alist().
    if "source" in metadata and "step" in metadata:
        checkpoint_data["source"] = metadata["source"]
        checkpoint_data["step"] = metadata["step"]

    # SHALLOW MODE: a single key per thread/namespace is overwritten.
    checkpoint_key = self._make_shallow_redis_checkpoint_key_cached(
        thread_id, checkpoint_ns
    )

    # Only the critical command (JSON.SET) goes in the pipeline. EXPIRE is
    # applied separately to avoid pipeline failures on Redis Enterprise
    # proxy with mixed JSON + native commands.
    pipeline = self._redis.pipeline(transaction=False)
    pipeline.json().set(checkpoint_key, "$", checkpoint_data)
    await pipeline.execute()

    # Apply TTL separately (best-effort). A missing or None default_ttl
    # means "no expiry" rather than a TypeError.
    default_ttl = self.ttl_config.get("default_ttl") if self.ttl_config else None
    if default_ttl is not None:
        ttl_seconds = int(default_ttl * 60)
        try:
            await self._redis.expire(checkpoint_key, ttl_seconds)
        except Exception:
            logger.warning(
                "Failed to apply TTL to checkpoint key: %s", checkpoint_key
            )

    # NOTE: We intentionally do NOT clean up old writes here.
    # In the HITL (Human-in-the-Loop) flow, interrupt writes are saved via
    # put_writes BEFORE the new checkpoint is saved. Cleaning up writes when
    # the checkpoint changes would delete the interrupt writes before they
    # can be consumed when resuming.
    #
    # Writes are cleaned up in the following scenarios:
    # 1. When delete_thread is called
    # 2. When TTL expires (if configured)
    # 3. When put_writes is called again for the same task/idx (overwrites)
    #
    # See Issue #133 for details on this bug fix.
    return next_config
[docs]
async def alist(
    self,
    config: Optional[RunnableConfig],
    *,
    filter: Optional[Dict[str, Any]] = None,
    before: Optional[RunnableConfig] = None,
    limit: Optional[int] = None,
) -> AsyncIterator[CheckpointTuple]:
    """List checkpoints from Redis asynchronously.

    Args:
        config: Optional config; its ``thread_id`` (required when a config
            is given), ``run_id`` and ``checkpoint_ns`` narrow the search.
        filter: Optional metadata filter. Only the keys ``source``,
            ``step``, ``thread_id`` and ``run_id`` are honoured.
        before: Optional config whose checkpoint id, when it parses as a
            ULID, restricts results to earlier ``checkpoint_ts`` values.
            Other id formats are ignored.
        limit: Maximum number of results (100 when falsy).

    Yields:
        Checkpoint tuples built from the indexed fields only: channel data
        and pending writes are left empty.
    """
    conditions = []

    if config:
        configurable = config["configurable"]
        conditions.append(
            Tag("thread_id") == to_storage_safe_id(configurable["thread_id"])
        )
        run_id = configurable.get("run_id")
        if run_id:
            conditions.append(Tag("run_id") == to_storage_safe_id(run_id))
        checkpoint_ns = configurable.get("checkpoint_ns")
        if checkpoint_ns:
            conditions.append(
                Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns)
            )

    if filter:
        for field_name, wanted in filter.items():
            if field_name == "source":
                conditions.append(Tag("source") == wanted)
            elif field_name == "step":
                conditions.append(Num("step") == wanted)
            elif field_name in ("thread_id", "run_id"):
                conditions.append(Tag(field_name) == to_storage_safe_id(wanted))

    before_checkpoint_id = get_checkpoint_id(before) if before else None
    if before_checkpoint_id:
        try:
            # Numeric range query: checkpoint_ts < timestamp of `before`.
            cutoff = ULID.from_str(before_checkpoint_id).timestamp
            conditions.append(Num("checkpoint_ts") < cutoff)
        except Exception:
            # If not a valid ULID, ignore the before filter
            pass

    # AND all collected expressions together; match everything when empty.
    if conditions:
        combined_filter = conditions[0]
        for extra in conditions[1:]:
            combined_filter &= extra
    else:
        combined_filter = "*"

    # Sort by checkpoint_id in descending order to get most recent first.
    query = FilterQuery(
        filter_expression=combined_filter,
        return_fields=[
            "thread_id",
            "checkpoint_ns",
            "checkpoint_id",
            "source",
            "step",
            "score",
            "ts",
        ],
        num_results=limit or 100,
        sort_by=("checkpoint_id", "DESC"),
    )
    results = await self.checkpoints_index.search(query)

    for doc in results.docs:
        doc_checkpoint_id = getattr(doc, "checkpoint_id", "")
        doc_config = {
            "configurable": {
                "thread_id": getattr(doc, "thread_id", ""),
                "checkpoint_ns": getattr(doc, "checkpoint_ns", ""),
                "checkpoint_id": doc_checkpoint_id,
            }
        }
        doc_checkpoint = {
            "v": 1,
            "ts": getattr(doc, "ts", ""),
            "id": doc_checkpoint_id,
            "channel_values": {},
            "channel_versions": {},
            "versions_seen": {},
            "pending_sends": [],
        }
        doc_metadata = {
            "source": getattr(doc, "source", "input"),
            "step": int(getattr(doc, "step", 0)),
            "writes": {},
            "score": float(getattr(doc, "score", 0)),
        }
        yield CheckpointTuple(
            config=doc_config,
            checkpoint=doc_checkpoint,
            metadata=doc_metadata,
            pending_writes=[],
        )
[docs]
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
    """Fetch the single stored checkpoint for the config's thread/namespace.

    Args:
        config: Config whose ``thread_id`` (required) and ``checkpoint_ns``
            (default ``""``) select the shallow checkpoint key.

    Returns:
        The checkpoint tuple with inline channel values and pending writes,
        or ``None`` when no dict-shaped document is stored under the key.
    """
    configurable = config["configurable"]
    thread_id = configurable["thread_id"]
    checkpoint_ns = configurable.get("checkpoint_ns", "")

    # Direct key access: shallow mode keeps one document per thread/namespace.
    checkpoint_key = self._make_shallow_redis_checkpoint_key_cached(
        thread_id, checkpoint_ns
    )
    stored = await self._redis.json().get(checkpoint_key)  # type: ignore[misc]
    if not (stored and isinstance(stored, dict)):
        return None

    # Best-effort TTL refresh on read, when enabled.
    ttl_config = self.ttl_config
    if ttl_config and ttl_config.get("refresh_on_read"):
        ttl_seconds = int(ttl_config.get("default_ttl", 60) * 60)
        try:
            await self._redis.expire(checkpoint_key, ttl_seconds)
        except Exception:
            logger.warning(
                "Failed to refresh TTL on checkpoint key: %s", checkpoint_key
            )

    # The checkpoint may be stored as a dict or as a JSON string.
    raw_checkpoint = stored.get("checkpoint", {})
    if isinstance(raw_checkpoint, str):
        raw_checkpoint = json.loads(raw_checkpoint)

    # Channel values live inline in the checkpoint, in serialized form.
    channel_values: Dict[str, Any] = self._deserialize_channel_values(
        raw_checkpoint.get("channel_values", {})
    )

    raw_metadata = stored.get("metadata", {})
    if isinstance(raw_metadata, str):
        raw_metadata = json.loads(raw_metadata)

    # Strip NUL characters from keys and from string values.
    sanitized_metadata = {}
    for meta_key, meta_value in raw_metadata.items():
        if isinstance(meta_value, str):
            meta_value = meta_value.replace("\u0000", "")
        sanitized_metadata[meta_key.replace("\u0000", "")] = meta_value

    # Shallow mode never carries pending sends.
    no_pending_sends: list[tuple[str, bytes]] = []

    result_config: RunnableConfig = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": checkpoint_ns,
            "checkpoint_id": raw_checkpoint["id"],
        }
    }
    loaded_checkpoint = self._load_checkpoint(
        json.dumps(raw_checkpoint),
        channel_values,
        no_pending_sends,
    )
    pending_writes = await self._aload_pending_writes(
        thread_id, checkpoint_ns, loaded_checkpoint["id"]
    )

    return CheckpointTuple(
        config=result_config,
        checkpoint=loaded_checkpoint,
        metadata=cast(CheckpointMetadata, sanitized_metadata),
        parent_config=None,
        pending_writes=pending_writes,
    )
[docs]
async def aput_writes(
    self,
    config: RunnableConfig,
    writes: Sequence[tuple[str, Any]],
    task_id: str,
    task_path: str = "",
) -> None:
    """Store intermediate writes and register them in the thread-level zset.

    The JSON.SET commands and the ZADD that registers the write keys are
    sent in one non-transactional pipeline (``transaction=False``), so the
    batch is not atomic. Nothing is deleted here: existing entries are only
    overwritten when the same write key is produced again. TTLs are applied
    afterwards on a best-effort basis.

    Args:
        config (RunnableConfig): Configuration of the related checkpoint;
            needs ``thread_id`` and ``checkpoint_id``.
        writes (List[Tuple[str, Any]]): List of writes to store. An empty
            sequence is a no-op.
        task_id (str): Identifier for the task creating the writes.
        task_path (str): Path of the task creating the writes.

    Raises:
        Exception: The first error reported by the pipeline for any of the
            queued commands.
    """
    if not writes:
        return

    configurable = config["configurable"]
    thread_id = configurable["thread_id"]
    checkpoint_ns = configurable.get("checkpoint_ns", "")
    checkpoint_id = configurable["checkpoint_id"]

    # Thread-level sorted set that records every write key.
    # to_storage_safe_str keeps the key naming consistent with the reader.
    safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
    thread_zset_key = f"write_keys_zset:{thread_id}:{safe_checkpoint_ns}:shallow"

    # No transaction, to avoid lock contention.
    pipeline = self._redis.pipeline(transaction=False)

    write_keys = []
    for position, (channel, value) in enumerate(writes):
        type_, blob = self.serde.dumps_typed(value)
        # Special channels get their fixed index; others use their position.
        write_idx = WRITES_IDX_MAP.get(channel, position)
        write_key = self._make_redis_checkpoint_writes_key_cached(
            thread_id, checkpoint_ns, checkpoint_id, task_id, write_idx
        )
        write_obj = {
            "thread_id": thread_id,
            "checkpoint_ns": checkpoint_ns,
            "checkpoint_id": checkpoint_id,
            "task_id": task_id,
            "task_path": task_path,
            "idx": write_idx,
            "channel": channel,
            "type": type_,
            # Encode bytes to a base64 string for Redis.
            "blob": self._encode_blob(blob),
        }
        pipeline.json().set(write_key, "$", write_obj)
        write_keys.append(write_key)

    # Registry entry: score is the position of the key in this batch.
    zadd_mapping = {key: score for score, key in enumerate(write_keys)}
    pipeline.zadd(thread_zset_key, zadd_mapping)  # type: ignore[arg-type]

    # Execute critical commands only - no EXPIRE in the pipeline, to avoid
    # failures on Redis Enterprise proxy with mixed JSON + native commands.
    results = await pipeline.execute(raise_on_error=False)
    for result in results:
        if isinstance(result, Exception):
            raise result

    # Apply TTL separately (best-effort - failures don't lose writes).
    # A missing or None default_ttl means "no expiry" instead of a TypeError.
    default_ttl = self.ttl_config.get("default_ttl") if self.ttl_config else None
    if default_ttl is None:
        return
    ttl_seconds = int(default_ttl * 60)
    for key in write_keys:
        try:
            await self._redis.expire(key, ttl_seconds)
        except Exception:
            logger.warning(
                "Failed to apply TTL to checkpoint write key: %s",
                key,
                exc_info=True,
            )
    try:
        await self._redis.expire(thread_zset_key, ttl_seconds)
    except Exception:
        logger.warning(
            "Failed to apply TTL to write registry key: %s",
            thread_zset_key,
            exc_info=True,
        )
[docs]
async def aget_channel_values(
    self,
    thread_id: str,
    checkpoint_ns: str,
    checkpoint_id: str,
    channel_versions: Optional[Dict[str, Any]] = None,
) -> dict[str, Any]:
    """Retrieve the channel_values dictionary from inline checkpoint data.

    Args:
        thread_id: Thread whose shallow checkpoint is read.
        checkpoint_ns: Checkpoint namespace of that checkpoint.
        checkpoint_id: Unused; there is only one checkpoint per
            thread/namespace in shallow mode.
        channel_versions: Unused; kept for interface compatibility.

    Returns:
        The deserialized channel values, or ``{}`` when no checkpoint is
        stored or it has no usable shape.
    """
    checkpoint_key = self._make_shallow_redis_checkpoint_key_cached(
        thread_id, checkpoint_ns
    )

    # Single JSON.GET operation; a "$..." path returns a list of matches.
    checkpoint_data = await self._redis.json().get(checkpoint_key, "$.checkpoint")  # type: ignore[misc]
    if not checkpoint_data:
        return {}

    checkpoint = (
        checkpoint_data[0] if isinstance(checkpoint_data, list) else checkpoint_data
    )
    # aget_tuple accepts a checkpoint stored as a JSON string; do the same
    # here instead of failing on ``str.get``.
    if isinstance(checkpoint, str):
        checkpoint = json.loads(checkpoint)
    if not isinstance(checkpoint, dict):
        return {}

    # Channel values are stored in serialized form.
    return self._deserialize_channel_values(checkpoint.get("channel_values", {}))
async def _aload_pending_sends(
    self,
    thread_id: str,
    checkpoint_ns: str,
) -> list[tuple[str, bytes]]:
    """Load pending sends (TASKS-channel writes) for a thread and namespace.

    Args:
        thread_id: The thread ID
        checkpoint_ns: The checkpoint namespace

    Returns:
        List of (type, blob) tuples representing pending sends, ordered by
        task_path, task_id and idx. Documents without a blob are skipped.
    """
    # Query checkpoint_writes for writes on the TASKS channel.
    tasks_filter = (
        (Tag("thread_id") == thread_id)
        & (Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns))
        & (Tag("channel") == TASKS)
    )
    tasks_query = FilterQuery(
        filter_expression=tasks_filter,
        return_fields=["type", "$.blob", "task_path", "task_id", "idx"],
        num_results=100,
    )
    search_results = await self.checkpoint_writes_index.search(tasks_query)

    def _ordering(doc: Any) -> tuple:
        return (
            getattr(doc, "task_path", ""),
            getattr(doc, "task_id", ""),
            getattr(doc, "idx", 0),
        )

    pending = []
    for doc in sorted(search_results.docs, key=_ordering):
        # The blob may be exposed under its JSON path or as a plain attribute.
        blob = getattr(doc, "$.blob", getattr(doc, "blob", None))
        if blob is not None:
            pending.append((getattr(doc, "type", ""), blob))
    return pending
async def _aload_pending_writes(
    self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
) -> List[PendingWrite]:
    """Load pending writes using the thread-level sorted set registry.

    Args:
        thread_id: The thread ID.
        checkpoint_ns: The checkpoint namespace.
        checkpoint_id: Only checked for ``None``; registry entries are not
            filtered by checkpoint id.

    Returns:
        The deserialized pending writes, or an empty list when there are
        none or when the registry cannot be read. Read failures are logged
        and treated as best-effort.
    """
    if checkpoint_id is None:
        return []

    # Same key naming as aput_writes (to_storage_safe_str on the namespace).
    safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
    thread_zset_key = f"write_keys_zset:{thread_id}:{safe_checkpoint_ns}:shallow"

    try:
        # ZRANGE on a missing key returns an empty list, so no separate
        # ZCARD round trip is needed.
        write_keys = await self._redis.zrange(thread_zset_key, 0, -1)
        if not write_keys:
            return []

        # NOTE(review): every key in the registry is loaded; nothing here
        # restricts the result to ``checkpoint_id`` - confirm this is the
        # intended behaviour.
        pipeline = self._redis.pipeline(transaction=False)
        for key in write_keys:
            pipeline.json().get(key.decode() if isinstance(key, bytes) else key)
        results = await pipeline.execute()

        # Index the fetched documents by (task_id, idx); missing keys
        # (e.g. expired) come back falsy and are skipped.
        writes_dict: Dict[Tuple[str, str], Dict[str, Any]] = {}
        for write_data in results:
            if write_data:
                task_id = write_data.get("task_id", "")
                idx = write_data.get("idx", 0)
                writes_dict[(task_id, idx)] = write_data

        # Use base class method to deserialize.
        return BaseRedisSaver._load_writes(self.serde, writes_dict)
    except Exception:
        # Stay best-effort, but leave a trace instead of failing silently.
        logger.warning(
            "Failed to load pending writes from registry key: %s",
            thread_zset_key,
            exc_info=True,
        )
        return []
[docs]
def create_indexes(self) -> None:
    """Build the two search-index objects from their schemas.

    Both indexes share this saver's Redis client.
    """
    client = self._redis
    self.checkpoints_index = AsyncSearchIndex.from_dict(
        self.checkpoints_schema, redis_client=client
    )
    self.checkpoint_writes_index = AsyncSearchIndex.from_dict(
        self.writes_schema, redis_client=client
    )
[docs]
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
    """Retrieve a checkpoint tuple from Redis synchronously.

    Schedules ``aget_tuple`` on the saver's event loop and blocks until it
    completes, so it must be called from a different thread than the one
    running that loop.

    Raises:
        asyncio.InvalidStateError: If called from the saver's own event
            loop, where blocking would deadlock.
    """
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so blocking is safe.
        running_loop = None
    if running_loop is not None and running_loop is self.loop:
        raise asyncio.InvalidStateError(
            "Synchronous calls to AsyncShallowRedisSaver are only allowed from a "
            "different thread. From the main thread, use the async interface. "
            "For example, use `await checkpointer.aget_tuple(...)` or `await "
            "graph.ainvoke(...)`."
        )
    return asyncio.run_coroutine_threadsafe(
        self.aget_tuple(config), self.loop
    ).result()
[docs]
def put(
    self,
    config: RunnableConfig,
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
    new_versions: ChannelVersions,
) -> RunnableConfig:
    """Store only the latest checkpoint synchronously.

    Schedules ``aput`` on the saver's event loop and blocks until it
    completes, so it must be called from a different thread than the one
    running that loop.

    Raises:
        asyncio.InvalidStateError: If called from the saver's own event
            loop, where blocking on the result would deadlock.
    """
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so blocking is safe.
        running_loop = None
    if running_loop is not None and running_loop is self.loop:
        # Same guard as get_tuple / get_channel_values.
        raise asyncio.InvalidStateError(
            "Synchronous calls to AsyncShallowRedisSaver are only allowed from a "
            "different thread. From the main thread, use the async interface. "
            "For example, use `await checkpointer.aput(...)` or `await "
            "graph.ainvoke(...)`."
        )
    return asyncio.run_coroutine_threadsafe(
        self.aput(config, checkpoint, metadata, new_versions), self.loop
    ).result()
[docs]
def put_writes(
    self,
    config: RunnableConfig,
    writes: Sequence[Tuple[str, Any]],
    task_id: str,
    task_path: str = "",
) -> None:
    """Store intermediate writes synchronously.

    Schedules ``aput_writes`` on the saver's event loop and blocks until it
    completes, so it must be called from a different thread than the one
    running that loop.

    Args:
        config: Configuration of the related checkpoint.
        writes: List of writes to store.
        task_id: Identifier for the task creating the writes.
        task_path: Path of the task creating the writes; forwarded to
            ``aput_writes``.

    Raises:
        asyncio.InvalidStateError: If called from the saver's own event
            loop, where blocking on the result would deadlock.
    """
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so blocking is safe.
        running_loop = None
    if running_loop is not None and running_loop is self.loop:
        # Same guard as get_tuple / get_channel_values.
        raise asyncio.InvalidStateError(
            "Synchronous calls to AsyncShallowRedisSaver are only allowed from a "
            "different thread. From the main thread, use the async interface. "
            "For example, use `await checkpointer.aput_writes(...)` or `await "
            "graph.ainvoke(...)`."
        )
    return asyncio.run_coroutine_threadsafe(
        self.aput_writes(config, writes, task_id, task_path), self.loop
    ).result()
[docs]
def get_channel_values(
    self,
    thread_id: str,
    checkpoint_ns: str,
    checkpoint_id: str,
    channel_versions: Optional[Dict[str, Any]] = None,
) -> dict[str, Any]:
    """Retrieve the channel_values dictionary (sync wrapper).

    Schedules ``aget_channel_values`` on the saver's event loop and blocks
    until it completes, so it must be called from a different thread than
    the one running that loop.

    Raises:
        asyncio.InvalidStateError: If called from the saver's own event
            loop, where blocking would deadlock.
    """
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so blocking is safe.
        running_loop = None
    if running_loop is not None and running_loop is self.loop:
        raise asyncio.InvalidStateError(
            "Synchronous calls to AsyncShallowRedisSaver are only allowed from a "
            "different thread. From the main thread, use the async interface. "
            "For example, use `await checkpointer.aget_channel_values(...)`."
        )
    return asyncio.run_coroutine_threadsafe(
        self.aget_channel_values(
            thread_id, checkpoint_ns, checkpoint_id, channel_versions
        ),
        self.loop,
    ).result()
def _make_shallow_redis_checkpoint_key_cached(
    self, thread_id: str, checkpoint_ns: str
) -> str:
    """Return the shallow checkpoint key for a thread/namespace, memoised.

    The key joins the instance's checkpoint prefix, ``thread_id`` and
    ``checkpoint_ns`` with the configured separator. The memo is bounded by
    ``_key_cache_max_size``; the oldest entries are evicted first.
    """
    cache = self._key_cache
    cache_key = f"shallow_checkpoint:{thread_id}:{checkpoint_ns}"
    if cache_key not in cache:
        # Enforce the size limit that __init__ configures, so long-running
        # savers do not accumulate one entry per thread forever.
        max_size = getattr(self, "_key_cache_max_size", 1000)
        while cache and len(cache) >= max_size:
            # Dicts preserve insertion order: the first key is the oldest.
            del cache[next(iter(cache))]
        cache[cache_key] = self._separator.join(
            [self._checkpoint_prefix, thread_id, checkpoint_ns]
        )
    return cache[cache_key]
@staticmethod
def _make_shallow_redis_checkpoint_key(thread_id: str, checkpoint_ns: str) -> str:
    """Build the shallow checkpoint key from the default prefix, thread and namespace."""
    key_parts = (CHECKPOINT_PREFIX, thread_id, checkpoint_ns)
    return REDIS_KEY_SEPARATOR.join(key_parts)
def _make_redis_checkpoint_writes_key_cached(
    self,
    thread_id: str,
    checkpoint_ns: str,
    checkpoint_id: str,
    task_id: str,
    idx: Optional[int],
) -> str:
    """Return the Redis key for one checkpoint write, memoised.

    Key construction is delegated to ``_make_redis_checkpoint_writes_key``.
    The memo is bounded by ``_key_cache_max_size``; the oldest entries are
    evicted first.
    """
    cache = self._key_cache
    cache_key = (
        f"writes:{thread_id}:{checkpoint_ns}:{checkpoint_id}:{task_id}:{idx}"
    )
    if cache_key not in cache:
        # Write keys embed checkpoint and task ids, so without a limit this
        # memo would grow by one entry per write.
        max_size = getattr(self, "_key_cache_max_size", 1000)
        while cache and len(cache) >= max_size:
            # Dicts preserve insertion order: the first key is the oldest.
            del cache[next(iter(cache))]
        cache[cache_key] = self._make_redis_checkpoint_writes_key(
            thread_id, checkpoint_ns, checkpoint_id, task_id, idx
        )
    return cache[cache_key]
@staticmethod
def _make_shallow_redis_checkpoint_writes_key_pattern(
    thread_id: str, checkpoint_ns: str
) -> str:
    """Build a glob pattern matching every writes key of a thread and namespace."""
    key_prefix = REDIS_KEY_SEPARATOR.join(
        (
            CHECKPOINT_WRITE_PREFIX,
            str(to_storage_safe_id(thread_id)),
            to_storage_safe_str(checkpoint_ns),
        )
    )
    return f"{key_prefix}:*"
def _extract_fallback_timestamp(self, checkpoint: Checkpoint) -> float:
    """Derive a timestamp from the checkpoint's ``ts`` field, else use "now".

    Used when the checkpoint_id is not a valid ULID (e.g., UUIDv6 format).
    See Issue #136 for details.

    Args:
        checkpoint: The checkpoint object containing an optional ts field.

    Returns:
        Milliseconds since epoch for ISO-string values and for the
        current-time fallback. Non-string ``ts`` values are returned
        unchanged.
    """
    raw_ts = checkpoint.get("ts")

    # Missing or falsy ts: fall back to the current time.
    if not raw_ts:
        return time.time() * MILLISECONDS_PER_SECOND

    # Numeric timestamps are passed through as-is.
    if not isinstance(raw_ts, str):
        return raw_ts

    try:
        # A trailing "Z" is rewritten as an explicit UTC offset before parsing.
        parsed = datetime.fromisoformat(raw_ts.replace("Z", "+00:00"))
        return parsed.timestamp() * MILLISECONDS_PER_SECOND
    except Exception:
        return time.time() * MILLISECONDS_PER_SECOND
[docs]
async def adelete_thread(self, thread_id: str) -> None:
    """Delete the checkpoints, writes and write registries of a thread.

    Checkpoint keys are built with this instance's configured checkpoint
    prefix (the same one ``aput`` writes under), so savers created with a
    custom ``checkpoint_prefix`` are cleaned up as well.

    Args:
        thread_id: The thread ID which checkpoint should be deleted.
    """
    storage_safe_thread_id = to_storage_safe_id(thread_id)

    # Ordered, de-duplicated collection of keys to delete.
    keys_to_delete: Dict[str, None] = {}
    checkpoint_namespaces = set()

    # Despite the fact that shallow saver stores only the current version
    # of checkpoint, there may be several while using subgraphs.
    checkpoint_query = FilterQuery(
        filter_expression=Tag("thread_id") == thread_id,
        return_fields=["checkpoint_ns", "checkpoint_id"],
        num_results=10000,
    )
    checkpoint_results = await self.checkpoints_index.search(checkpoint_query)
    for doc in checkpoint_results.docs:
        checkpoint_ns = getattr(doc, "checkpoint_ns", "")
        # Remember the namespace so its write registry is removed later.
        checkpoint_namespaces.add(checkpoint_ns)
        # Same layout as _make_shallow_redis_checkpoint_key_cached, built
        # directly so a deleted thread is not added to the key memo.
        checkpoint_key = self._separator.join(
            [self._checkpoint_prefix, thread_id, checkpoint_ns]
        )
        keys_to_delete[checkpoint_key] = None

    checkpoint_writes_query = FilterQuery(
        filter_expression=Tag("thread_id") == thread_id,
        return_fields=["checkpoint_ns", "checkpoint_id", "task_id", "idx"],
        num_results=10000,
    )
    checkpoint_writes_results = await self.checkpoint_writes_index.search(
        checkpoint_writes_query
    )
    for doc in checkpoint_writes_results.docs:
        checkpoint_ns = getattr(doc, "checkpoint_ns", "")
        write_key = self._make_redis_checkpoint_writes_key(
            thread_id,
            checkpoint_ns,
            getattr(doc, "checkpoint_id", ""),
            getattr(doc, "task_id", ""),
            getattr(doc, "idx", 0),
        )
        keys_to_delete[write_key] = None
        checkpoint_namespaces.add(checkpoint_ns)

    # One write registry (sorted set) per namespace seen above.
    for checkpoint_ns in checkpoint_namespaces:
        registry_key = (
            f"write_keys_zset:{storage_safe_thread_id}:"
            f"{to_storage_safe_str(checkpoint_ns)}:shallow"
        )
        keys_to_delete[registry_key] = None

    if not keys_to_delete:
        # Nothing is stored for this thread.
        return

    # use pipeline for efficiency
    pipeline = self._redis.pipeline()
    for key in keys_to_delete:
        pipeline.delete(key)
    await pipeline.execute()
[docs]
async def aprune(
    self,
    thread_ids: Sequence[str],
    *,
    strategy: str = "keep_latest",
    keep_last: Optional[int] = None,
) -> None:
    """Prune checkpoints for the given threads.

    The shallow saver keeps at most one checkpoint per namespace, so any
    request that retains one or more checkpoints has nothing to do. Only
    ``strategy="delete"`` or ``keep_last=0`` removes data, by calling
    ``adelete_thread`` for each thread in order.

    Args:
        thread_ids: Thread IDs to prune.
        strategy: ``"delete"`` removes everything. Any other value behaves
            like ``"keep_latest"`` (a no-op). Ignored when ``keep_last``
            is given.
        keep_last: Optional override. Any value >= 1 is a no-op.
            Pass ``0`` to delete all.

    Raises:
        ValueError: If ``thread_ids`` is empty or ``keep_last`` is negative.
    """
    # An explicit keep_last wins; otherwise derive it from the strategy.
    if keep_last is None:
        keep_last = 0 if strategy == "delete" else 1

    # Validate input
    if not thread_ids:
        raise ValueError("``thread_ids`` must be a non-empty sequence")
    if keep_last < 0:
        raise ValueError(f"``keep_last`` must be >= 0, got {keep_last}")

    if keep_last < 1:
        for thread_id in thread_ids:
            await self.adelete_thread(thread_id)