Source code for langgraph.store.redis.aio

from __future__ import annotations

import asyncio
import json
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from types import TracebackType
from typing import Any, AsyncIterator, Iterable, Optional, Sequence, cast

from langgraph.store.base import (
    GetOp,
    IndexConfig,
    ListNamespacesOp,
    Op,
    PutOp,
    Result,
    SearchOp,
    TTLConfig,
    ensure_embeddings,
    get_text_at_path,
    tokenize_path,
)
from langgraph.store.base.batch import AsyncBatchedBaseStore
from redis.asyncio import Redis as AsyncRedis
from redis.commands.search.query import Query
from redisvl.index import AsyncSearchIndex
from redisvl.query import FilterQuery, VectorQuery
from redisvl.redis.connection import RedisConnectionFactory
from redisvl.utils.token_escaper import TokenEscaper
from ulid import ULID

from langgraph.store.redis.base import (
    REDIS_KEY_SEPARATOR,
    STORE_PREFIX,
    STORE_VECTOR_PREFIX,
    BaseRedisStore,
    RedisDocument,
    _decode_ns,
    _ensure_string_or_literal,
    _group_ops,
    _namespace_to_text,
    _row_to_item,
    _row_to_search_item,
    logger,
)

from .token_unescaper import TokenUnescaper

# Import kept with the other module-level bindings but placed first per PEP 8
# (imports precede module-level statements).
from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster

# Escaper/unescaper pair used when embedding user-supplied keys and namespaces
# into RediSearch query strings (tag/text special characters must be escaped).
_token_escaper = TokenEscaper()
_token_unescaper = TokenUnescaper()


class AsyncRedisStore(
    BaseRedisStore[AsyncRedis, AsyncSearchIndex], AsyncBatchedBaseStore
):
    """Async Redis store with optional vector search.

    Supports standard Redis URLs (redis://), SSL (rediss://), and Sentinel
    URLs (redis+sentinel://host:26379/service_name/db).
    """

    store_index: AsyncSearchIndex
    vector_index: AsyncSearchIndex
    _owns_its_client: bool
    supports_ttl: bool = True
    # Use a different name to avoid conflicting with the base class attribute
    _async_ttl_stop_event: asyncio.Event | None = None
    _ttl_sweeper_task: asyncio.Task | None = None
    ttl_config: Optional[TTLConfig] = None
    # Whether to assume the Redis server is a cluster; None triggers auto-detection
    cluster_mode: Optional[bool] = None

    def __init__(
        self,
        redis_url: Optional[str] = None,
        *,
        redis_client: Optional[AsyncRedis] = None,
        index: Optional[IndexConfig] = None,
        connection_args: Optional[dict[str, Any]] = None,
        ttl: Optional[TTLConfig] = None,
        cluster_mode: Optional[bool] = None,
        store_prefix: str = STORE_PREFIX,
        vector_prefix: str = STORE_VECTOR_PREFIX,
    ) -> None:
        """Initialize store with Redis connection and optional index config.

        Args:
            redis_url: Redis connection URL; required if redis_client is None.
            redis_client: Pre-built async Redis client; required if redis_url
                is None. When supplied, the store does not own the connection.
            index: Optional vector-index configuration enabling vector search.
            connection_args: Extra keyword args for the connection factory.
            ttl: Optional TTL configuration (e.g. default_ttl in minutes).
            cluster_mode: True/False to force cluster behavior, or None to
                auto-detect during setup().
            store_prefix: Redis key prefix for store documents.
            vector_prefix: Redis key prefix for vector documents.

        Raises:
            ValueError: If neither redis_url nor redis_client is provided, or
                the vector index schema cannot be built.
            TypeError: If cluster_mode is neither a bool nor None.
        """
        if redis_url is None and redis_client is None:
            raise ValueError("Either redis_url or redis_client must be provided")

        # Initialize base classes
        AsyncBatchedBaseStore.__init__(self)

        # Configure client first
        self.configure_client(
            redis_url=redis_url,
            redis_client=redis_client,
            connection_args=connection_args or {},
        )

        # Validate and store cluster_mode; None means auto-detect later
        if cluster_mode is not None and not isinstance(cluster_mode, bool):
            raise TypeError("cluster_mode must be a boolean or None")

        # Initialize BaseRedisStore with prefix parameters
        BaseRedisStore.__init__(
            self,
            conn=self._redis,
            index=index,
            ttl=ttl,
            cluster_mode=cluster_mode,
            store_prefix=store_prefix,
            vector_prefix=vector_prefix,
        )

        # Rebuild the store index against the async client; keys are
        # "<store_prefix><sep><ulid>" JSON documents.
        store_schema = {
            "index": {
                "name": self.store_prefix,
                "prefix": self.store_prefix + REDIS_KEY_SEPARATOR,
                "storage_type": "json",
            },
            "fields": [
                {"name": "prefix", "type": "text"},
                {"name": "key", "type": "tag"},
                {"name": "created_at", "type": "numeric"},
                {"name": "updated_at", "type": "numeric"},
                {"name": "ttl_minutes", "type": "numeric"},
                {"name": "expires_at", "type": "numeric"},
            ],
        }
        self.store_index = AsyncSearchIndex.from_dict(
            store_schema, redis_client=self._redis
        )

        # Configure vector index if needed
        if self.index_config:
            # Create custom vector schema with instance prefix
            vector_schema = {
                "index": {
                    "name": self.vector_prefix,
                    "prefix": self.vector_prefix + REDIS_KEY_SEPARATOR,
                    "storage_type": "json",
                },
                "fields": [
                    {"name": "prefix", "type": "text"},
                    {"name": "key", "type": "tag"},
                    {"name": "field_name", "type": "tag"},
                    {"name": "embedding", "type": "vector"},
                    {"name": "created_at", "type": "numeric"},
                    {"name": "updated_at", "type": "numeric"},
                    {"name": "ttl_minutes", "type": "numeric"},
                    {"name": "expires_at", "type": "numeric"},
                ],
            }
            # Locate the embedding field so we can attach vector attributes.
            vector_fields = vector_schema.get("fields", [])
            vector_field = None
            for f in vector_fields:
                if isinstance(f, dict) and f.get("name") == "embedding":
                    vector_field = f
                    break

            if vector_field:
                # Configure vector field with index config values
                vector_field["attrs"] = {
                    "algorithm": "flat",  # Default to flat
                    "datatype": "float32",
                    "dims": self.index_config["dims"],
                    "distance_metric": {
                        "cosine": "COSINE",
                        "inner_product": "IP",
                        "l2": "L2",
                    }[
                        _ensure_string_or_literal(
                            self.index_config.get("distance_type", "cosine")
                        )
                    ],
                }
                # Apply any additional vector type config
                if "ann_index_config" in self.index_config:
                    vector_field["attrs"].update(
                        self.index_config["ann_index_config"]
                    )

            try:
                self.vector_index = AsyncSearchIndex.from_dict(
                    vector_schema, redis_client=self._redis
                )
            except Exception as e:
                raise ValueError(
                    f"Failed to create vector index with schema: {vector_schema}. Error: {str(e)}"
                ) from e
[docs] def configure_client( self, redis_url: Optional[str] = None, redis_client: Optional[AsyncRedis] = None, connection_args: Optional[dict[str, Any]] = None, ) -> None: """Configure the Redis client. Supports standard Redis URLs (redis://), SSL (rediss://), and Sentinel URLs (redis+sentinel://host:26379/service_name/db). """ self._owns_its_client = redis_client is None if redis_client is None: self._redis = RedisConnectionFactory.get_async_redis_connection( redis_url, **(connection_args or {}) ) else: self._redis = redis_client
[docs] async def setup(self) -> None: """Initialize store indices.""" # Handle embeddings in same way as sync store if self.index_config: self.embeddings = ensure_embeddings( self.index_config.get("embed"), ) # Auto-detect cluster mode if not explicitly set if self.cluster_mode is None: await self._detect_cluster_mode() else: logger.info( f"Redis cluster_mode explicitly set to {self.cluster_mode}, skipping detection." ) # Create indices in Redis await self.store_index.create(overwrite=False) if self.index_config: await self.vector_index.create(overwrite=False)
async def _detect_cluster_mode(self) -> None: """Detect if the Redis client is a cluster client by inspecting its class.""" # Determine cluster mode based on client class if isinstance(self._redis, AsyncRedisCluster): self.cluster_mode = True logger.info("Redis cluster client detected for AsyncRedisStore.") else: self.cluster_mode = False logger.info("Redis standalone client detected for AsyncRedisStore.") # This can't be properly typed due to covariance issues with async methods async def _apply_ttl_to_keys( self, main_key: str, related_keys: Optional[list[str]] = None, ttl_minutes: Optional[float] = None, ) -> Any: """Apply Redis native TTL to keys asynchronously. Args: main_key: The primary Redis key related_keys: Additional Redis keys that should expire at the same time ttl_minutes: Time-to-live in minutes """ if ttl_minutes is None: # Check if there's a default TTL in config if self.ttl_config and "default_ttl" in self.ttl_config: ttl_minutes = self.ttl_config.get("default_ttl") if ttl_minutes is not None: ttl_seconds = int(ttl_minutes * 60) if self.cluster_mode: await self._redis.expire(main_key, ttl_seconds) if related_keys: for key in related_keys: await self._redis.expire(key, ttl_seconds) else: pipeline = self._redis.pipeline(transaction=True) # Set TTL for main key pipeline.expire(main_key, ttl_seconds) # Set TTL for related keys if related_keys: # Check if related_keys is not None for key in related_keys: pipeline.expire(key, ttl_seconds) await pipeline.execute() # This can't be properly typed due to covariance issues with async methods
async def sweep_ttl(self) -> int:  # type: ignore[override]
    """Clean up any remaining expired items.

    This is not needed with Redis native TTL, but kept for API compatibility.
    Redis automatically removes expired keys.

    Returns:
        int: Always returns 0 as Redis handles expiration automatically
    """
    return 0

# This can't be properly typed due to covariance issues with async methods
async def start_ttl_sweeper(  # type: ignore[override]
    self, _sweep_interval_minutes: Optional[int] = None
) -> None:
    """Start TTL sweeper.

    This is a no-op with Redis native TTL, but kept for API compatibility.
    Redis automatically removes expired keys.

    Args:
        _sweep_interval_minutes: Ignored parameter, kept for API compatibility
    """
    # No-op: Redis handles TTL expiration automatically
    pass

# This can't be properly typed due to covariance issues with async methods
async def stop_ttl_sweeper(  # type: ignore[override]
    self, _timeout: Optional[float] = None
) -> bool:
    """Stop TTL sweeper.

    This is a no-op with Redis native TTL, but kept for API compatibility.

    Args:
        _timeout: Ignored parameter, kept for API compatibility

    Returns:
        bool: Always True as there's no sweeper to stop
    """
    # No-op: Redis handles TTL expiration automatically
    return True
[docs] @classmethod @asynccontextmanager async def from_conn_string( cls, conn_string: str, *, index: Optional[IndexConfig] = None, ttl: Optional[TTLConfig] = None, store_prefix: str = STORE_PREFIX, vector_prefix: str = STORE_VECTOR_PREFIX, ) -> AsyncIterator[AsyncRedisStore]: """Create store from Redis connection string.""" async with cls( redis_url=conn_string, index=index, ttl=ttl, store_prefix=store_prefix, vector_prefix=vector_prefix, ) as store: await store.setup() # Set client information after setup await store.aset_client_info() yield store
[docs] def create_indexes(self) -> None: """Create async indices.""" self.store_index = AsyncSearchIndex.from_dict( self.SCHEMAS[0], redis_client=self._redis ) if self.index_config: self.vector_index = AsyncSearchIndex.from_dict( self.SCHEMAS[1], redis_client=self._redis )
async def __aenter__(self) -> AsyncRedisStore: """Async context manager enter.""" # Client info was already set in __init__, # but we'll set it again here to be consistent with checkpoint code await self.aset_client_info() return self async def __aexit__( self, _exc_type: Optional[type[BaseException]] = None, _exc_value: Optional[BaseException] = None, _traceback: Optional[TracebackType] = None, ) -> None: """Async context manager exit.""" # Cancel the background task created by AsyncBatchedBaseStore if hasattr(self, "_task") and self._task is not None and not self._task.done(): self._task.cancel() try: await self._task except asyncio.CancelledError: pass # Close Redis connections if we own them if self._owns_its_client: await self._redis.aclose() # type: ignore[attr-defined] await self._redis.connection_pool.disconnect()
[docs] async def abatch(self, ops: Iterable[Op]) -> list[Result]: """Execute batch of operations asynchronously.""" grouped_ops, num_ops = _group_ops(ops) results: list[Result] = [None] * num_ops tasks = [] if GetOp in grouped_ops: tasks.append( self._batch_get_ops( list(cast(list[tuple[int, GetOp]], grouped_ops[GetOp])), results ) ) if PutOp in grouped_ops: tasks.append( self._batch_put_ops( list(cast(list[tuple[int, PutOp]], grouped_ops[PutOp])) ) ) if SearchOp in grouped_ops: tasks.append( self._batch_search_ops( list(cast(list[tuple[int, SearchOp]], grouped_ops[SearchOp])), results, ) ) if ListNamespacesOp in grouped_ops: tasks.append( self._batch_list_namespaces_ops( list( cast( list[tuple[int, ListNamespacesOp]], grouped_ops[ListNamespacesOp], ) ), results, ) ) await asyncio.gather(*tasks) return results
[docs] def batch(self: AsyncRedisStore, ops: Iterable[Op]) -> list[Result]: """Execute batch of operations synchronously. Args: ops: Operations to execute in batch Returns: Results from batch execution Raises: asyncio.InvalidStateError: If called from the main event loop """ try: if asyncio.get_running_loop(): raise asyncio.InvalidStateError( "Synchronous calls to AsyncRedisStore are only allowed from a " "different thread. From the main thread, use the async interface." "For example, use `await store.abatch(...)` or `await " "store.aget(...)`" ) except RuntimeError: pass return asyncio.run_coroutine_threadsafe( self.abatch(ops), asyncio.get_event_loop() ).result()
async def _batch_get_ops(
    self,
    get_ops: Sequence[tuple[int, GetOp]],
    results: list[Result],
) -> None:
    """Execute GET operations in batch asynchronously.

    Fills results[idx] for each found key and, for ops with refresh_ttl,
    refreshes the TTL of the store document and its vector twin.
    """
    refresh_keys_by_idx: dict[int, list[str]] = (
        {}
    )  # Track keys that need TTL refreshed by op index

    for query, _, namespace, items in self._get_batch_GET_ops_queries(get_ops):
        res = await self.store_index.search(Query(query))
        # Parse JSON from each document
        key_to_row = {
            json.loads(doc.json)["key"]: (json.loads(doc.json), doc.id)
            for doc in res.docs
        }
        for idx, key in items:
            if key in key_to_row:
                data, doc_id = key_to_row[key]
                results[idx] = _row_to_item(
                    namespace, data, deserialize_fn=self._deserialize_value
                )

                # Find the corresponding operation by looking it up in the operation list
                # This is needed because idx is the index in the overall operation list
                op_idx = None
                for i, (local_idx, op) in enumerate(get_ops):
                    if local_idx == idx:
                        op_idx = i
                        break

                if op_idx is not None:
                    op = get_ops[op_idx][1]
                    if hasattr(op, "refresh_ttl") and op.refresh_ttl:
                        if idx not in refresh_keys_by_idx:
                            refresh_keys_by_idx[idx] = []
                        refresh_keys_by_idx[idx].append(doc_id)
                        # Also add vector keys for the same document
                        doc_uuid = doc_id.split(":")[-1]
                        vector_key = (
                            f"{self.vector_prefix}{REDIS_KEY_SEPARATOR}{doc_uuid}"
                        )
                        refresh_keys_by_idx[idx].append(vector_key)

    # Now refresh TTLs for any keys that need it
    if refresh_keys_by_idx and self.ttl_config:
        # Get default TTL from config
        ttl_minutes = None
        if "default_ttl" in self.ttl_config:
            ttl_minutes = self.ttl_config.get("default_ttl")

        if ttl_minutes is not None:
            ttl_seconds = int(ttl_minutes * 60)
            if self.cluster_mode:
                # Cluster: per-key TTL check + EXPIRE (keys span hash slots).
                for keys_to_refresh in refresh_keys_by_idx.values():
                    for key in keys_to_refresh:
                        ttl = await self._redis.ttl(key)
                        if ttl > 0:
                            await self._redis.expire(key, ttl_seconds)
            else:
                # Standalone: batch all EXPIREs into one transactional pipeline.
                pipeline = self._redis.pipeline(transaction=True)
                for keys in refresh_keys_by_idx.values():
                    for key in keys:
                        # Only refresh TTL if the key exists and has a TTL
                        ttl = await self._redis.ttl(key)
                        if ttl > 0:  # Only refresh if key exists and has TTL
                            pipeline.expire(key, ttl_seconds)
                await pipeline.execute()

async def _aprepare_batch_PUT_queries(
    self,
    put_ops: Sequence[tuple[int, PutOp]],
) -> tuple[
    list[RedisDocument], Optional[tuple[str, list[tuple[str, str, str, str]]]]
]:
    """Prepare queries - no Redis operations in async version.

    Dedupes puts (last write wins), performs deletes for ops whose value is
    None, and returns (documents to insert, embedding request) where the
    embedding request lists (namespace, key, path, text) tuples to embed.
    """
    # Last-write wins
    dedupped_ops: dict[tuple[tuple[str, ...], str], PutOp] = {}
    for _, op in put_ops:
        dedupped_ops[(op.namespace, op.key)] = op

    inserts: list[PutOp] = []
    deletes: list[PutOp] = []
    for op in dedupped_ops.values():
        if op.value is None:
            deletes.append(op)
        else:
            inserts.append(op)

    operations: list[RedisDocument] = []
    embedding_request = None
    to_embed: list[tuple[str, str, str, str]] = []

    if deletes:
        # Delete matching documents
        for op in deletes:
            prefix = _namespace_to_text(op.namespace)
            query = f"(@prefix:{prefix} @key:{{{op.key}}})"
            results = await self.store_index.search(query)
            for doc in results.docs:
                await self._redis.delete(doc.id)

    # Handle inserts
    if inserts:
        for op in inserts:
            # Microsecond-resolution timestamp for created_at/updated_at.
            now = int(datetime.now(timezone.utc).timestamp() * 1_000_000)

            # Handle TTL
            ttl_minutes = None
            expires_at = None
            if op.ttl is not None:
                ttl_minutes = op.ttl
                expires_at = int(
                    (
                        datetime.now(timezone.utc) + timedelta(minutes=op.ttl)
                    ).timestamp()
                )

            doc = RedisDocument(
                prefix=_namespace_to_text(op.namespace),
                key=op.key,
                value=self._serialize_value(op.value),
                created_at=now,
                updated_at=now,
                ttl_minutes=ttl_minutes,
                expires_at=expires_at,
            )
            operations.append(doc)

            # op.index semantics: False disables indexing; None uses the
            # configured default fields; a list selects explicit paths.
            if self.index_config and op.index is not False:
                paths = (
                    self.index_config["__tokenized_fields"]
                    if op.index is None
                    else [(ix, tokenize_path(ix)) for ix in op.index]
                )
                for path, tokenized_path in paths:
                    texts = get_text_at_path(op.value, tokenized_path)
                    for text in texts:
                        to_embed.append(
                            (_namespace_to_text(op.namespace), op.key, path, text)
                        )

    if to_embed:
        embedding_request = ("", to_embed)

    return operations, embedding_request

async def _batch_put_ops(
    self,
    put_ops: Sequence[tuple[int, PutOp]],
) -> None:
    """Execute PUT operations in batch asynchronously.

    Deletes stale store/vector documents, loads new documents under fresh
    ULID keys, embeds indexed text fields, then applies TTLs last so keys
    only expire after they fully exist.
    """
    operations, embedding_request = await self._aprepare_batch_PUT_queries(put_ops)

    # First delete any existing documents that are being updated/deleted
    for _, op in put_ops:
        namespace = _namespace_to_text(op.namespace)
        query = f"@prefix:{namespace} @key:{{{_token_escaper.escape(op.key)}}}"
        results = await self.store_index.search(query)

        if self.cluster_mode:
            for doc in results.docs:
                await self._redis.delete(doc.id)
            if self.index_config:
                vector_results = await self.vector_index.search(query)
                for doc_vec in vector_results.docs:
                    await self._redis.delete(doc_vec.id)
        else:
            pipeline = self._redis.pipeline(transaction=True)
            for doc in results.docs:
                pipeline.delete(doc.id)
            if self.index_config:
                vector_results = await self.vector_index.search(query)
                for doc_vec in vector_results.docs:
                    pipeline.delete(doc_vec.id)
            if (
                pipeline.command_stack
            ):  # Check if pipeline has commands before executing
                await pipeline.execute()

    # Now handle new document creation
    doc_ids: dict[tuple[str, str], str] = {}
    store_docs: list[RedisDocument] = []
    store_keys: list[str] = []
    ttl_tracking: dict[str, tuple[list[str], Optional[float]]] = (
        {}
    )  # Tracks keys that need TTL + their TTL values

    # Generate IDs for PUT operations
    for _, op in put_ops:
        if op.value is not None:
            generated_doc_id = str(ULID())
            namespace = _namespace_to_text(op.namespace)
            doc_ids[(namespace, op.key)] = generated_doc_id
            # Track TTL for this document if specified
            if hasattr(op, "ttl") and op.ttl is not None:
                main_key = (
                    f"{self.store_prefix}{REDIS_KEY_SEPARATOR}{generated_doc_id}"
                )
                ttl_tracking[main_key] = ([], op.ttl)

    # Load store docs with explicit keys
    for doc in operations:
        store_key = (doc["prefix"], doc["key"])
        doc_id = doc_ids[store_key]

        # Remove TTL fields - they're not needed with Redis native TTL
        if "ttl_minutes" in doc:
            doc.pop("ttl_minutes", None)
        if "expires_at" in doc:
            doc.pop("expires_at", None)

        store_docs.append(doc)
        redis_key = f"{self.store_prefix}{REDIS_KEY_SEPARATOR}{doc_id}"
        store_keys.append(redis_key)

    if store_docs:
        if self.cluster_mode:
            # For cluster mode, load documents individually if SearchIndex.load isn't cluster-safe for batching.
            # This is a conservative approach. If redisvl's load is cluster-safe, this can be optimized.
            for i, store_doc_item in enumerate(store_docs):
                await self.store_index.load([store_doc_item], keys=[store_keys[i]])
        else:
            await self.store_index.load(store_docs, keys=store_keys)

    # Handle vector embeddings with same IDs
    if embedding_request and self.embeddings:
        _, text_params = embedding_request
        vectors = await self.embeddings.aembed_documents(
            [text for _, _, _, text in text_params]
        )

        vector_docs: list[dict[str, Any]] = []
        vector_keys: list[str] = []
        for (ns, key, path, _), vector in zip(text_params, vectors):
            vector_key: tuple[str, str] = (ns, key)
            doc_id = doc_ids[vector_key]
            vector_docs.append(
                {
                    "prefix": ns,
                    "key": key,
                    "field_name": path,
                    "embedding": (
                        vector.tolist() if hasattr(vector, "tolist") else vector
                    ),
                    "created_at": datetime.now(timezone.utc).timestamp(),
                    "updated_at": datetime.now(timezone.utc).timestamp(),
                }
            )
            redis_vector_key = f"{self.vector_prefix}{REDIS_KEY_SEPARATOR}{doc_id}"
            vector_keys.append(redis_vector_key)

            # Add this vector key to the related keys list for TTL
            main_key = f"{self.store_prefix}{REDIS_KEY_SEPARATOR}{doc_id}"
            if main_key in ttl_tracking:
                ttl_tracking[main_key][0].append(redis_vector_key)

        if vector_docs:
            if self.cluster_mode:
                # Similar to store_docs, load vector docs individually in cluster mode as a precaution.
                for i, vector_doc_item in enumerate(vector_docs):
                    await self.vector_index.load(
                        [vector_doc_item], keys=[vector_keys[i]]
                    )
            else:
                await self.vector_index.load(vector_docs, keys=vector_keys)

    # Now apply TTLs after all documents are loaded
    for main_key, (related_keys, ttl_minutes) in ttl_tracking.items():
        await self._apply_ttl_to_keys(main_key, related_keys, ttl_minutes)

async def _batch_search_ops(
    self,
    search_ops: Sequence[tuple[int, SearchOp]],
    results: list[Result],
) -> None:
    """Execute search operations in batch asynchronously.

    Ops with a query string and configured embeddings run as vector
    similarity searches; the rest run as regular RediSearch queries.
    Both paths apply value filters and optional TTL refresh.
    """
    queries, embedding_requests = self._get_batch_search_queries(search_ops)

    # Handle vector search
    query_vectors = {}
    if embedding_requests and self.embeddings:
        vectors = await self.embeddings.aembed_documents(
            [query for _, query in embedding_requests]
        )
        query_vectors = dict(zip([idx for idx, _ in embedding_requests], vectors))

    # Process each search operation
    for (idx, op), (query_str, params, limit, offset) in zip(search_ops, queries):
        if op.query and idx in query_vectors:
            # Vector similarity search
            vector = query_vectors[idx]
            vector_query = VectorQuery(
                vector=vector.tolist() if hasattr(vector, "tolist") else vector,
                vector_field_name="embedding",
                filter_expression=f"@prefix:{_namespace_to_text(op.namespace_prefix)}*",
                return_fields=["prefix", "key", "vector_distance"],
                num_results=limit,  # Use the user-specified limit
            )
            vector_query.paging(offset, limit)

            vector_results_docs = await self.vector_index.query(vector_query)

            # Get matching store docs: map each vector hit back to its store
            # document key (same ULID suffix, different prefix).
            result_map = {}
            if self.cluster_mode:
                store_docs = []
                for doc in vector_results_docs:
                    doc_id = (
                        doc.get("id")
                        if isinstance(doc, dict)
                        else getattr(doc, "id", None)
                    )
                    if doc_id:
                        doc_uuid = doc_id.split(":")[1]
                        store_key = (
                            f"{self.store_prefix}{REDIS_KEY_SEPARATOR}{doc_uuid}"
                        )
                        result_map[store_key] = doc
                        # Fetch individually in cluster mode
                        store_doc_item = await self._redis.json().get(store_key)  # type: ignore[misc]
                        store_docs.append(store_doc_item)
                store_docs_raw = store_docs
            else:
                pipeline = self._redis.pipeline(transaction=False)
                for (
                    doc
                ) in (
                    vector_results_docs
                ):  # each entry is an individual doc from the result list
                    doc_id = (
                        doc.get("id")
                        if isinstance(doc, dict)
                        else getattr(doc, "id", None)
                    )
                    if doc_id:
                        doc_uuid = doc_id.split(":")[1]
                        store_key = (
                            f"{self.store_prefix}{REDIS_KEY_SEPARATOR}{doc_uuid}"
                        )
                        result_map[store_key] = doc
                        pipeline.json().get(store_key)
                store_docs_raw = await pipeline.execute()

            # Process results maintaining order and applying filters
            items = []
            refresh_keys = []  # Track keys that need TTL refreshed
            store_docs_iter = iter(store_docs_raw)

            for store_key in result_map.keys():
                store_doc = next(store_docs_iter, None)
                if store_doc:
                    vector_result = result_map[store_key]
                    # Get vector_distance from original search result
                    dist = (
                        vector_result.get("vector_distance")
                        if isinstance(vector_result, dict)
                        else getattr(vector_result, "vector_distance", 0)
                    )
                    # Convert to similarity score
                    score = (1.0 - float(dist)) if dist is not None else 0.0

                    # Ensure store_doc is a dictionary before trying to assign to it
                    if not isinstance(store_doc, dict):
                        try:
                            store_doc = json.loads(
                                store_doc
                            )  # Attempt to parse if it's a JSON string
                        except (json.JSONDecodeError, TypeError):
                            logger.error(f"Failed to parse store_doc: {store_doc}")
                            continue  # Skip this problematic document

                    if isinstance(
                        store_doc, dict
                    ):  # Check again after potential parsing
                        store_doc["vector_distance"] = dist
                    else:
                        # if still not a dict, this means it's a problematic entry
                        logger.error(
                            f"store_doc is not a dict after parsing attempt: {store_doc}"
                        )
                        continue

                    # Apply value filters if needed
                    if op.filter:
                        matches = True
                        value = store_doc.get("value", {})
                        for key, expected in op.filter.items():
                            actual = value.get(key)
                            if isinstance(expected, list):
                                if actual not in expected:
                                    matches = False
                                    break
                            elif actual != expected:
                                matches = False
                                break
                        if not matches:
                            continue

                    # If refresh_ttl is true, add to list for refreshing
                    if op.refresh_ttl:
                        refresh_keys.append(store_key)
                        # Also find associated vector keys with same ID
                        doc_id = store_key.split(":")[-1]
                        vector_key = (
                            f"{self.vector_prefix}{REDIS_KEY_SEPARATOR}{doc_id}"
                        )
                        refresh_keys.append(vector_key)

                    items.append(
                        _row_to_search_item(
                            _decode_ns(store_doc["prefix"]),
                            store_doc,
                            score=score,
                            deserialize_fn=self._deserialize_value,
                        )
                    )

            # Refresh TTL if requested
            if op.refresh_ttl and refresh_keys and self.ttl_config:
                # Get default TTL from config
                ttl_minutes = None
                if "default_ttl" in self.ttl_config:
                    ttl_minutes = self.ttl_config.get("default_ttl")

                if ttl_minutes is not None:
                    ttl_seconds = int(ttl_minutes * 60)
                    if self.cluster_mode:
                        for key in refresh_keys:
                            ttl = await self._redis.ttl(key)
                            if ttl > 0:
                                await self._redis.expire(key, ttl_seconds)
                    else:
                        pipeline = self._redis.pipeline(transaction=True)
                        for key in refresh_keys:
                            # Only refresh TTL if the key exists and has a TTL
                            ttl = await self._redis.ttl(key)
                            if ttl > 0:  # Only refresh if key exists and has TTL
                                pipeline.expire(key, ttl_seconds)
                        if pipeline.command_stack:
                            await pipeline.execute()

            results[idx] = items
        else:
            # Regular search
            # Create a query with LIMIT and OFFSET parameters
            query = Query(query_str).paging(offset, limit)

            # Execute search with limit and offset applied by Redis
            res = await self.store_index.search(query)
            items = []
            refresh_keys = []  # Track keys that need TTL refreshed

            for doc in res.docs:
                data = json.loads(doc.json)
                # Apply value filters
                if op.filter:
                    matches = True
                    value = data.get("value", {})
                    for key, expected in op.filter.items():
                        actual = value.get(key)
                        if isinstance(expected, list):
                            if actual not in expected:
                                matches = False
                                break
                        elif actual != expected:
                            matches = False
                            break
                    if not matches:
                        continue

                # If refresh_ttl is true, add the key to refresh list
                if op.refresh_ttl:
                    refresh_keys.append(doc.id)
                    # Also find associated vector keys with same ID
                    doc_id = doc.id.split(":")[-1]
                    vector_key = (
                        f"{self.vector_prefix}{REDIS_KEY_SEPARATOR}{doc_id}"
                    )
                    refresh_keys.append(vector_key)

                items.append(
                    _row_to_search_item(
                        _decode_ns(data["prefix"]),
                        data,
                        deserialize_fn=self._deserialize_value,
                    )
                )

            # Refresh TTL if requested
            if op.refresh_ttl and refresh_keys and self.ttl_config:
                # Get default TTL from config
                ttl_minutes = None
                if "default_ttl" in self.ttl_config:
                    ttl_minutes = self.ttl_config.get("default_ttl")

                if ttl_minutes is not None:
                    ttl_seconds = int(ttl_minutes * 60)
                    if self.cluster_mode:
                        for key in refresh_keys:
                            ttl = await self._redis.ttl(key)
                            if ttl > 0:
                                await self._redis.expire(key, ttl_seconds)
                    else:
                        pipeline = self._redis.pipeline(transaction=True)
                        for key in refresh_keys:
                            # Only refresh TTL if the key exists and has a TTL
                            ttl = await self._redis.ttl(key)
                            if ttl > 0:  # Only refresh if key exists and has TTL
                                pipeline.expire(key, ttl_seconds)
                        if pipeline.command_stack:
                            await pipeline.execute()

            results[idx] = items

async def _batch_list_namespaces_ops(
    self,
    list_ops: Sequence[tuple[int, ListNamespacesOp]],
    results: list[Result],
) -> None:
    """Execute list namespaces operations in batch."""
    for idx, op in list_ops:
        # Construct base query for namespace search
        base_query = "*"  # Start with all documents
        if op.match_conditions:
            conditions = []
            for condition in op.match_conditions:
                if condition.match_type == "prefix":
                    prefix = _namespace_to_text(
                        condition.path, handle_wildcards=True
                    )
                    conditions.append(f"@prefix:{prefix}*")
                elif condition.match_type == "suffix":
                    suffix = _namespace_to_text(
                        condition.path, handle_wildcards=True
                    )
                    conditions.append(f"@prefix:*{suffix}")
            if conditions:
                base_query = " ".join(conditions)

        # Execute search with return_fields=["prefix"] to get just namespaces
        query = FilterQuery(filter_expression=base_query, return_fields=["prefix"])
        res = await self.store_index.search(query)

        # Extract unique namespaces
        namespaces = set()
        for doc in res.docs:
            if hasattr(doc, "prefix"):
                ns = tuple(_token_unescaper.unescape(doc.prefix).split("."))
                # Apply max_depth if specified
                if op.max_depth is not None:
                    ns = ns[: op.max_depth]
                namespaces.add(ns)

        # Sort and apply pagination
        sorted_namespaces = sorted(namespaces)
        if op.limit or op.offset:
            offset = op.offset or 0
            limit = op.limit or 10
            sorted_namespaces = sorted_namespaces[offset : offset + limit]

        results[idx] = sorted_namespaces
# We don't need _run_background_tasks anymore as AsyncBatchedBaseStore provides this