
Commit 78222c9
release: Bump to 0.0.8 and loosen redisvl dep (#67)
* Widen redisvl deps and simplify env
* Improve mypy warnings and update package ranges
* Fix pipeline async awaiting
* Ignore the rest of the mypy issues
* Update imports
* Run tests against latest 8.0.2
* Remove dead code
* Use individual redis calls instead of pipeline for cluster ops
* Comment fixes
* Fix faulty test and address langgraph-checkpoint 2.1.0 issues
1 parent c425317 commit 78222c9
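
The headline behavioral fix is the pipeline awaiting change: with redis-py's asyncio client, commands invoked on a pipeline are buffered locally and return the pipeline itself, so the old per-command `await`s no longer work; only `execute()` performs I/O and is awaited. A minimal sketch of the corrected pattern (hypothetical keys and payloads, not code from this repo):

    from redis.asyncio import Redis

    async def write_checkpoint(redis: Redis) -> None:
        pipe = redis.pipeline(transaction=True)
        # Commands buffer on the pipeline and are NOT awaited individually.
        pipe.json().set("checkpoint:thread-1", "$", {"source": "loop", "step": 1})
        pipe.json().set("checkpoint_blob:thread-1", "$", {"type": "msgpack", "blob": "..."})
        # Only execute() hits the server; it is the single awaited call.
        await pipe.execute()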

11 files changed, +814 −839 lines

.github/workflows/test.yml
Lines changed: 2 additions & 2 deletions

@@ -25,7 +25,7 @@ jobs:
       fail-fast: false
       matrix:
         python-version: [3.9, '3.10', 3.11, 3.12, 3.13]
-        redis-version: ['6.2.6-v9', 'latest', '8.0-M03']
+        redis-version: ['6.2.6-v9', 'latest', '8.0.2']

     steps:
       - name: Check out repository
@@ -49,7 +49,7 @@ jobs:
       - name: Set Redis image name
         run: |
-          if [[ "${{ matrix.redis-version }}" == "8.0-M03" ]]; then
+          if [[ "${{ matrix.redis-version }}" == "8.0.2" ]]; then
            echo "REDIS_IMAGE=redis:${{ matrix.redis-version }}" >> $GITHUB_ENV
           else
            echo "REDIS_IMAGE=redis/redis-stack-server:${{ matrix.redis-version }}" >> $GITHUB_ENV
langgraph/checkpoint/redis/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -280,7 +280,7 @@ def put(
         # store at top-level for filters in list()
         if all(key in metadata for key in ["source", "step"]):
             checkpoint_data["source"] = metadata["source"]
-            checkpoint_data["step"] = metadata["step"]  # type: ignore
+            checkpoint_data["step"] = metadata["step"]

         # Create the checkpoint key
         checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(

langgraph/checkpoint/redis/aio.py
Lines changed: 28 additions & 46 deletions

@@ -7,7 +7,6 @@
 import logging
 import os
 from contextlib import asynccontextmanager
-from functools import partial
 from types import TracebackType
 from typing import (
     Any,
@@ -34,12 +33,10 @@
 )
 from langgraph.constants import TASKS
 from redis.asyncio import Redis as AsyncRedis
-from redis.asyncio.client import Pipeline
 from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
 from redisvl.index import AsyncSearchIndex
 from redisvl.query import FilterQuery
 from redisvl.query.filter import Num, Tag
-from redisvl.redis.connection import RedisConnectionFactory

 from langgraph.checkpoint.redis.base import BaseRedisSaver
 from langgraph.checkpoint.redis.util import (
@@ -54,25 +51,6 @@
 logger = logging.getLogger(__name__)


-async def _write_obj_tx(
-    pipe: Pipeline,
-    key: str,
-    write_obj: Dict[str, Any],
-    upsert_case: bool,
-) -> None:
-    exists: int = await pipe.exists(key)
-    if upsert_case:
-        if exists:
-            await pipe.json().set(key, "$.channel", write_obj["channel"])
-            await pipe.json().set(key, "$.type", write_obj["type"])
-            await pipe.json().set(key, "$.blob", write_obj["blob"])
-        else:
-            await pipe.json().set(key, "$", write_obj)
-    else:
-        if not exists:
-            await pipe.json().set(key, "$", write_obj)
-
-
 class AsyncRedisSaver(
     BaseRedisSaver[Union[AsyncRedis, AsyncRedisCluster], AsyncSearchIndex]
 ):
@@ -568,7 +546,7 @@ async def aput(
         # store at top-level for filters in list()
         if all(key in metadata for key in ["source", "step"]):
             checkpoint_data["source"] = metadata["source"]
-            checkpoint_data["step"] = metadata["step"]  # type: ignore
+            checkpoint_data["step"] = metadata["step"]

         # Prepare checkpoint key
         checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
@@ -587,11 +565,11 @@ async def aput(

         if self.cluster_mode:
             # For cluster mode, execute operations individually
-            await self._redis.json().set(checkpoint_key, "$", checkpoint_data)
+            await self._redis.json().set(checkpoint_key, "$", checkpoint_data)  # type: ignore[misc]

             if blobs:
                 for key, data in blobs:
-                    await self._redis.json().set(key, "$", data)
+                    await self._redis.json().set(key, "$", data)  # type: ignore[misc]

             # Apply TTL if configured
             if self.ttl_config and "default_ttl" in self.ttl_config:
@@ -604,12 +582,12 @@ async def aput(
             pipeline = self._redis.pipeline(transaction=True)

             # Add checkpoint data to pipeline
-            await pipeline.json().set(checkpoint_key, "$", checkpoint_data)
+            pipeline.json().set(checkpoint_key, "$", checkpoint_data)

             if blobs:
                 # Add all blob operations to the pipeline
                 for key, data in blobs:
-                    await pipeline.json().set(key, "$", data)
+                    pipeline.json().set(key, "$", data)

             # Execute all operations atomically
             await pipeline.execute()
@@ -654,13 +632,13 @@ async def aput(

                 if self.cluster_mode:
                     # For cluster mode, execute operation directly
-                    await self._redis.json().set(
+                    await self._redis.json().set(  # type: ignore[misc]
                         checkpoint_key, "$", checkpoint_data
                     )
                 else:
                     # For non-cluster mode, use pipeline
                     pipeline = self._redis.pipeline(transaction=True)
-                    await pipeline.json().set(checkpoint_key, "$", checkpoint_data)
+                    pipeline.json().set(checkpoint_key, "$", checkpoint_data)
                     await pipeline.execute()
             except Exception:
                 # If this also fails, we just propagate the original cancellation
@@ -739,24 +717,18 @@ async def aput_writes(
                     exists = await self._redis.exists(key)
                     if exists:
                         # Update existing key
-                        await self._redis.json().set(
-                            key, "$.channel", write_obj["channel"]
-                        )
-                        await self._redis.json().set(
-                            key, "$.type", write_obj["type"]
-                        )
-                        await self._redis.json().set(
-                            key, "$.blob", write_obj["blob"]
-                        )
+                        await self._redis.json().set(key, "$.channel", write_obj["channel"])  # type: ignore[misc, arg-type]
+                        await self._redis.json().set(key, "$.type", write_obj["type"])  # type: ignore[misc, arg-type]
+                        await self._redis.json().set(key, "$.blob", write_obj["blob"])  # type: ignore[misc, arg-type]
                     else:
                         # Create new key
-                        await self._redis.json().set(key, "$", write_obj)
+                        await self._redis.json().set(key, "$", write_obj)  # type: ignore[misc]
                         created_keys.append(key)
                 else:
                     # For non-upsert case, only set if key doesn't exist
                     exists = await self._redis.exists(key)
                     if not exists:
-                        await self._redis.json().set(key, "$", write_obj)
+                        await self._redis.json().set(key, "$", write_obj)  # type: ignore[misc]
                         created_keys.append(key)

             # Apply TTL to newly created keys
@@ -788,20 +760,30 @@ async def aput_writes(
                     exists = await self._redis.exists(key)
                     if exists:
                         # Update existing key
-                        await pipeline.json().set(
-                            key, "$.channel", write_obj["channel"]
+                        pipeline.json().set(
+                            key,
+                            "$.channel",
+                            write_obj["channel"],  # type: ignore[arg-type]
+                        )
+                        pipeline.json().set(
+                            key,
+                            "$.type",
+                            write_obj["type"],  # type: ignore[arg-type]
+                        )
+                        pipeline.json().set(
+                            key,
+                            "$.blob",
+                            write_obj["blob"],  # type: ignore[arg-type]
                         )
-                        await pipeline.json().set(key, "$.type", write_obj["type"])
-                        await pipeline.json().set(key, "$.blob", write_obj["blob"])
                     else:
                         # Create new key
-                        await pipeline.json().set(key, "$", write_obj)
+                        pipeline.json().set(key, "$", write_obj)
                         created_keys.append(key)
                 else:
                     # For non-upsert case, only set if key doesn't exist
                     exists = await self._redis.exists(key)
                     if not exists:
-                        await pipeline.json().set(key, "$", write_obj)
+                        pipeline.json().set(key, "$", write_obj)
                        created_keys.append(key)

             # Execute all operations atomically
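
Taken together, the aio.py changes split the write path by deployment mode: Redis Cluster cannot run a multi-key MULTI/EXEC transaction across hash slots, so cluster mode now issues individually awaited calls, while standalone mode buffers everything in one transactional pipeline. A condensed sketch of the resulting shape (hypothetical helper name, simplified from the diff; `blobs` is a list of `(key, data)` pairs):

    async def _put_checkpoint_and_blobs(self, checkpoint_key, checkpoint_data, blobs):
        if self.cluster_mode:
            # Cluster: keys may live on different nodes, so write one at a time.
            await self._redis.json().set(checkpoint_key, "$", checkpoint_data)
            for key, data in blobs:
                await self._redis.json().set(key, "$", data)
        else:
            # Standalone: queue everything and flush atomically in one round trip.
            pipe = self._redis.pipeline(transaction=True)
            pipe.json().set(checkpoint_key, "$", checkpoint_data)
            for key, data in blobs:
                pipe.json().set(key, "$", data)
            await pipe.execute()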

langgraph/checkpoint/redis/ashallow.py
Lines changed: 12 additions & 23 deletions

@@ -86,17 +86,6 @@
 ]


-# func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
-async def _write_obj_tx(pipe: Pipeline, key: str, write_obj: dict[str, Any]) -> None:
-    exists: int = await pipe.exists(key)
-    if exists:
-        await pipe.json().set(key, "$.channel", write_obj["channel"])
-        await pipe.json().set(key, "$.type", write_obj["type"])
-        await pipe.json().set(key, "$.blob", write_obj["blob"])
-    else:
-        await pipe.json().set(key, "$", write_obj)
-
-
 class AsyncShallowRedisSaver(BaseRedisSaver[AsyncRedis, AsyncSearchIndex]):
     """Async Redis implementation that only stores the most recent checkpoint."""

@@ -240,7 +229,7 @@ async def aput(
             )

             # Add checkpoint data to pipeline
-            await pipeline.json().set(checkpoint_key, "$", checkpoint_data)
+            pipeline.json().set(checkpoint_key, "$", checkpoint_data)

             # Before storing the new blobs, clean up old ones that won't be needed
             # - Get a list of all blob keys for this thread_id and checkpoint_ns
@@ -274,7 +263,7 @@ async def aput(
                     continue
                 else:
                     # This is an old version, delete it
-                    await pipeline.delete(blob_key)
+                    pipeline.delete(blob_key)

             # Store the new blob values
             blobs = self._dump_blobs(
@@ -287,7 +276,7 @@ async def aput(
             if blobs:
                 # Add all blob data to pipeline
                 for key, data in blobs:
-                    await pipeline.json().set(key, "$", data)
+                    pipeline.json().set(key, "$", data)

             # Execute all operations atomically
             await pipeline.execute()
@@ -571,7 +560,7 @@ async def aput_writes(

                 # If the write is for a different checkpoint_id, delete it
                 if key_checkpoint_id != checkpoint_id:
-                    await pipeline.delete(write_key)
+                    pipeline.delete(write_key)

             # Add new writes to the pipeline
             upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes)
@@ -589,17 +578,15 @@ async def aput_writes(
                 exists = await self._redis.exists(key)
                 if exists:
                     # Update existing key
-                    await pipeline.json().set(
-                        key, "$.channel", write_obj["channel"]
-                    )
-                    await pipeline.json().set(key, "$.type", write_obj["type"])
-                    await pipeline.json().set(key, "$.blob", write_obj["blob"])
+                    pipeline.json().set(key, "$.channel", write_obj["channel"])
+                    pipeline.json().set(key, "$.type", write_obj["type"])
+                    pipeline.json().set(key, "$.blob", write_obj["blob"])
                 else:
                     # Create new key
-                    await pipeline.json().set(key, "$", write_obj)
+                    pipeline.json().set(key, "$", write_obj)
             else:
                 # For shallow implementation, always set the full object
-                await pipeline.json().set(key, "$", write_obj)
+                pipeline.json().set(key, "$", write_obj)

             # Execute all operations atomically
             await pipeline.execute()
@@ -722,7 +709,9 @@ async def _aload_pending_writes(
             (
                 parsed_key["task_id"],
                 parsed_key["idx"],
-            ): await self._redis.json().get(key)
+            ): await self._redis.json().get(
+                key
+            )  # type: ignore[misc]
             for key, parsed_key in sorted(
                 zip(matching_keys, parsed_keys), key=lambda x: x[1]["idx"]
             )
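
One subtlety preserved in `_aload_pending_writes`: the `await` sits inside a dict comprehension, so the JSON documents are fetched sequentially in `idx` order rather than concurrently. A minimal sketch of the idiom (hypothetical function name and key shapes):

    async def load_writes(redis, matching_keys, parsed_keys):
        # Await inside a comprehension is legal in an async function;
        # fetches happen one by one, ordered by the write index.
        return {
            (pk["task_id"], pk["idx"]): await redis.json().get(key)
            for key, pk in sorted(
                zip(matching_keys, parsed_keys), key=lambda x: x[1]["idx"]
            )
        }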

langgraph/checkpoint/redis/base.py
Lines changed: 4 additions & 2 deletions

@@ -383,7 +383,9 @@ def _dump_metadata(self, metadata: CheckpointMetadata) -> str:
         # NOTE: we're using JSON serializer (not msgpack), so we need to remove null characters before writing
         return serialized_metadata.decode().replace("\\u0000", "")

-    def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
+    def get_next_version(  # type: ignore[override]
+        self, current: Optional[str], channel: ChannelProtocol[Any, Any, Any]
+    ) -> str:
         """Generate next version number."""
         if current is None:
             current_v = 0
@@ -420,7 +422,7 @@ def _load_writes_from_redis(self, write_key: str) -> List[Tuple[str, str, Any]]:
             return []

         writes = []
-        for write in result["writes"]:
+        for write in result["writes"]:  # type: ignore[call-overload]
             writes.append(
                 (
                     write["task_id"],

langgraph/store/redis/__init__.py
Lines changed: 5 additions & 9 deletions

@@ -515,7 +515,7 @@ def _batch_search_ops(
                 if not isinstance(store_doc, dict):
                     try:
                         store_doc = json.loads(
-                            store_doc
+                            store_doc  # type: ignore[arg-type]
                         )  # Attempt to parse if it's a JSON string
                     except (json.JSONDecodeError, TypeError):
                         logger.error(f"Failed to parse store_doc: {store_doc}")
@@ -578,16 +578,14 @@ def _batch_search_ops(
             if self.cluster_mode:
                 for key in refresh_keys:
                     ttl = self._redis.ttl(key)
-                    if ttl > 0:  # type: ignore
+                    if ttl > 0:
                         self._redis.expire(key, ttl_seconds)
             else:
                 pipeline = self._redis.pipeline(transaction=True)
                 for key in refresh_keys:
                     # Only refresh TTL if the key exists and has a TTL
                     ttl = self._redis.ttl(key)
-                    if (
-                        ttl > 0
-                    ):  # Only refresh if key exists and has TTL  # type: ignore
+                    if ttl > 0:  # Only refresh if key exists and has TTL
                         pipeline.expire(key, ttl_seconds)
                 if pipeline.command_stack:
                     pipeline.execute()
@@ -645,16 +643,14 @@ def _batch_search_ops(
             if self.cluster_mode:
                 for key in refresh_keys:
                     ttl = self._redis.ttl(key)
-                    if ttl > 0:  # type: ignore
+                    if ttl > 0:
                         self._redis.expire(key, ttl_seconds)
             else:
                 pipeline = self._redis.pipeline(transaction=True)
                 for key in refresh_keys:
                     # Only refresh TTL if the key exists and has a TTL
                     ttl = self._redis.ttl(key)
-                    if (
-                        ttl > 0
-                    ):  # Only refresh if key exists and has TTL  # type: ignore
+                    if ttl > 0:  # Only refresh if key exists and has TTL
                         pipeline.expire(key, ttl_seconds)
                 if pipeline.command_stack:
                     pipeline.execute()
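
The TTL-refresh idiom, now free of the stale ignores, is worth calling out: a key is only re-expired if it already has a TTL, and outside cluster mode the EXPIREs are batched through a pipeline that is skipped entirely when empty. A standalone sketch (hypothetical function name and arguments):

    def refresh_ttls(redis, refresh_keys, ttl_seconds, cluster_mode=False):
        if cluster_mode:
            for key in refresh_keys:
                if redis.ttl(key) > 0:  # key exists and has a TTL
                    redis.expire(key, ttl_seconds)
        else:
            pipe = redis.pipeline(transaction=True)
            for key in refresh_keys:
                if redis.ttl(key) > 0:
                    pipe.expire(key, ttl_seconds)
            if pipe.command_stack:  # avoid an empty round trip
                pipe.execute()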

langgraph/store/redis/aio.py
Lines changed: 2 additions & 2 deletions

@@ -87,7 +87,7 @@ def __init__(

         # Set up store configuration
         self.index_config = index
-        self.ttl_config = ttl  # type: ignore
+        self.ttl_config = ttl

         if self.index_config:
             self.index_config = self.index_config.copy()
@@ -744,7 +744,7 @@ async def _batch_search_ops(
                     store_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_uuid}"
                     result_map[store_key] = doc
                     # Fetch individually in cluster mode
-                    store_doc_item = await self._redis.json().get(store_key)
+                    store_doc_item = await self._redis.json().get(store_key)  # type: ignore
                     store_docs.append(store_doc_item)
                 store_docs_raw = store_docs
             else:
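
(Same cluster-mode pattern as the checkpoint savers: documents are fetched one by one with `json().get` rather than batched across nodes, and redis-py's loosely typed JSON commands pick up a `# type: ignore` under the stricter mypy settings.)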
