langgraph-cosmosdb-checkpointer 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,11 @@
+ from .cosmosdb import (
+     AsyncCosmosDBSaver,
+     CosmosDBSaver,
+     CosmosDBSettings,
+ )
+
+ __all__ = [
+     "CosmosDBSettings",
+     "CosmosDBSaver",
+     "AsyncCosmosDBSaver",
+ ]
@@ -0,0 +1,1082 @@
+ from __future__ import annotations
+
+ import asyncio
+ import base64
+ import hashlib
+ import json
+ import random
+ import zlib
+ from contextlib import AbstractAsyncContextManager, AbstractContextManager
+ from dataclasses import dataclass
+ from typing import Any, AsyncIterator, Iterator, Mapping, Sequence
+
+ from azure.core import MatchConditions
+ from azure.core.credentials import TokenCredential
+ from azure.cosmos import CosmosClient, PartitionKey
+ from azure.cosmos.container import ContainerProxy
+ from azure.cosmos.database import DatabaseProxy
+ from azure.cosmos.exceptions import (
+     CosmosAccessConditionFailedError,
+     CosmosHttpResponseError,
+     CosmosResourceExistsError,
+     CosmosResourceNotFoundError,
+ )
+ from langchain_core.runnables import RunnableConfig
+ from langgraph.checkpoint.base import (
+     WRITES_IDX_MAP,
+     BaseCheckpointSaver,
+     ChannelVersions,
+     Checkpoint,
+     CheckpointMetadata,
+     CheckpointTuple,
+     SerializerProtocol,
+     get_checkpoint_id,
+     get_checkpoint_metadata,
+ )
+ from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
+
+ MAX_COSMOS_ITEM_BYTES = 2_000_000
+ DEFAULT_INLINE_PAYLOAD_LIMIT_BYTES = 1_500_000
+ DEFAULT_CHUNK_SIZE_BYTES = 512_000
+ DEFAULT_COMPRESS_ABOVE_BYTES = 16_384
+
+
+ @dataclass(slots=True)
+ class CosmosDBSettings:
+     database_name: str = "langgraph"
+     checkpoint_container_name: str = "checkpoints"
+     writes_container_name: str = "writes"
+     blob_container_name: str = "checkpoint_blobs"
+     create_containers: bool = True
+     retry_total: int | None = None
+     retry_backoff_max: int | None = None
+     retry_fixed_interval: int | None = None
+     inline_payload_limit_bytes: int = DEFAULT_INLINE_PAYLOAD_LIMIT_BYTES
+     chunk_size_bytes: int = DEFAULT_CHUNK_SIZE_BYTES
+     compress_above_bytes: int = DEFAULT_COMPRESS_ABOVE_BYTES
+     no_response_on_write: bool = False
+
+
+ class CosmosDBSaver(BaseCheckpointSaver[str], AbstractContextManager):
+     def __init__(
+         self,
+         checkpoint_container: ContainerProxy,
+         writes_container: ContainerProxy,
+         blob_container: ContainerProxy | None,
+         *,
+         settings: CosmosDBSettings | None = None,
+         serde: SerializerProtocol | None = None,
+         client: CosmosClient | None = None,
+         owns_client: bool = False,
+     ) -> None:
+         super().__init__(serde=serde or JsonPlusSerializer())
+         self._checkpoints = checkpoint_container
+         self._writes = writes_container
+         self._blobs = blob_container
+         self._settings = settings or CosmosDBSettings()
+         self._client = client
+         self._owns_client = owns_client
+
+     @classmethod
+     def from_connection_string(
+         cls,
+         connection_string: str,
+         *,
+         settings: CosmosDBSettings | None = None,
+         credential: str | TokenCredential | Mapping[str, Any] | None = None,
+         serde: SerializerProtocol | None = None,
+     ) -> "CosmosDBSaver":
+         cfg = settings or CosmosDBSettings()
+         client_kwargs = _client_kwargs(cfg)
+         client = CosmosClient.from_connection_string(
+             connection_string,
+             credential=credential,
+             **client_kwargs,
+         )
+         checkpoint_container, writes_container, blob_container = _resolve_containers_sync(
+             client=client,
+             settings=cfg,
+         )
+         return cls(
+             checkpoint_container=checkpoint_container,
+             writes_container=writes_container,
+             blob_container=blob_container,
+             settings=cfg,
+             serde=serde,
+             client=client,
+             owns_client=True,
+         )
+
+     @classmethod
+     def from_endpoint(
+         cls,
+         endpoint: str,
+         *,
+         credential: str | TokenCredential | Mapping[str, Any] | None = None,
+         settings: CosmosDBSettings | None = None,
+         serde: SerializerProtocol | None = None,
+     ) -> "CosmosDBSaver":
+         cfg = settings or CosmosDBSettings()
+         credential_value = credential
+         if credential_value is None:
+             from azure.identity import DefaultAzureCredential
+
+             credential_value = DefaultAzureCredential()
+         client = CosmosClient(endpoint, credential=credential_value, **_client_kwargs(cfg))
+         checkpoint_container, writes_container, blob_container = _resolve_containers_sync(
+             client=client,
+             settings=cfg,
+         )
+         return cls(
+             checkpoint_container=checkpoint_container,
+             writes_container=writes_container,
+             blob_container=blob_container,
+             settings=cfg,
+             serde=serde,
+             client=client,
+             owns_client=True,
+         )
+
+     def __enter__(self) -> "CosmosDBSaver":
+         return self
+
+     def __exit__(self, exc_type, exc, tb) -> bool | None:
+         self.close()
+         return None
+
+     def close(self) -> None:
+         if self._owns_client and self._client is not None:
+             self._client.__exit__(None, None, None)
+             self._client = None
+
+     def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
+         thread_id, checkpoint_ns, checkpoint_id = _normalize_config(config)
+         if checkpoint_id:
+             doc_id = _checkpoint_doc_id(checkpoint_ns, checkpoint_id)
+             try:
+                 doc = self._checkpoints.read_item(item=doc_id, partition_key=thread_id)
+             except CosmosResourceNotFoundError:
+                 return None
+             return self._checkpoint_tuple_from_doc(
+                 doc=doc,
+                 thread_id=thread_id,
+                 checkpoint_ns=checkpoint_ns,
+             )
+
+         rows = list(
+             self._checkpoints.query_items(
+                 query=(
+                     "SELECT TOP 1 * FROM c "
+                     "WHERE c.doc_type=@doc_type AND c.thread_id=@thread_id AND c.checkpoint_ns=@checkpoint_ns "
+                     "ORDER BY c.checkpoint_id DESC"
+                 ),
+                 parameters=[
+                     {"name": "@doc_type", "value": "checkpoint"},
+                     {"name": "@thread_id", "value": thread_id},
+                     {"name": "@checkpoint_ns", "value": checkpoint_ns},
+                 ],
+                 partition_key=thread_id,
+             )
+         )
+         if not rows:
+             return None
+         return self._checkpoint_tuple_from_doc(
+             doc=rows[0],
+             thread_id=thread_id,
+             checkpoint_ns=checkpoint_ns,
+         )
+
+     def list(
+         self,
+         config: RunnableConfig | None,
+         *,
+         filter: dict[str, Any] | None = None,
+         before: RunnableConfig | None = None,
+         limit: int | None = None,
+     ) -> Iterator[CheckpointTuple]:
+         if config:
+             thread_id, checkpoint_ns, checkpoint_id = _normalize_config(config)
+         else:
+             thread_id, checkpoint_ns, checkpoint_id = None, None, None
+
+         before_checkpoint_id = get_checkpoint_id(before) if before else None
+
+         query = "SELECT * FROM c WHERE c.doc_type=@doc_type"
+         parameters: list[dict[str, Any]] = [{"name": "@doc_type", "value": "checkpoint"}]
+         if thread_id is not None:
+             query += " AND c.thread_id=@thread_id"
+             parameters.append({"name": "@thread_id", "value": thread_id})
+         if checkpoint_ns is not None:
+             query += " AND c.checkpoint_ns=@checkpoint_ns"
+             parameters.append({"name": "@checkpoint_ns", "value": checkpoint_ns})
+         if checkpoint_id is not None:
+             query += " AND c.checkpoint_id=@checkpoint_id"
+             parameters.append({"name": "@checkpoint_id", "value": checkpoint_id})
+         if before_checkpoint_id is not None:
+             query += " AND c.checkpoint_id < @before_checkpoint_id"
+             parameters.append(
+                 {"name": "@before_checkpoint_id", "value": before_checkpoint_id}
+             )
+         query += " ORDER BY c.checkpoint_id DESC"
+
+         query_kwargs: dict[str, Any] = {
+             "query": query,
+             "parameters": parameters,
+         }
+         if thread_id is None:
+             query_kwargs["enable_cross_partition_query"] = True
+         else:
+             query_kwargs["partition_key"] = thread_id
+
+         yielded = 0
+         for row in self._checkpoints.query_items(**query_kwargs):
+             tuple_value = self._checkpoint_tuple_from_doc(
+                 doc=row,
+                 thread_id=row["thread_id"],
+                 checkpoint_ns=row["checkpoint_ns"],
+             )
+             if filter and not all(
+                 tuple_value.metadata.get(k) == v for k, v in filter.items()
+             ):
+                 continue
+             yield tuple_value
+             yielded += 1
+             if limit is not None and yielded >= limit:
+                 return
+
+     def put(
+         self,
+         config: RunnableConfig,
+         checkpoint: Checkpoint,
+         metadata: CheckpointMetadata,
+         new_versions: ChannelVersions,
+     ) -> RunnableConfig:
+         thread_id, checkpoint_ns, parent_checkpoint_id = _normalize_config(config)
+         checkpoint_core = checkpoint.copy()
+         channel_values = checkpoint_core.pop("channel_values", {})
+
+         for channel, version in new_versions.items():
+             self._store_channel_blob(
+                 thread_id=thread_id,
+                 checkpoint_ns=checkpoint_ns,
+                 channel=channel,
+                 version=version,
+                 value=channel_values.get(channel),
+             )
+
+         checkpoint_type, checkpoint_payload, checkpoint_hash = self._serialize_typed(
+             value=checkpoint_core,
+             thread_id=thread_id,
+             payload_scope=f"checkpoint:{checkpoint_ns}:{checkpoint['id']}",
+         )
+         metadata_dict = get_checkpoint_metadata(config, metadata)
+         metadata_type, metadata_payload, metadata_hash = self._serialize_typed(
+             value=metadata_dict,
+             thread_id=thread_id,
+             payload_scope=f"metadata:{checkpoint_ns}:{checkpoint['id']}",
+         )
+
+         doc = {
+             "id": _checkpoint_doc_id(checkpoint_ns, checkpoint["id"]),
+             "doc_type": "checkpoint",
+             "thread_id": thread_id,
+             "checkpoint_ns": checkpoint_ns,
+             "checkpoint_id": checkpoint["id"],
+             "parent_checkpoint_id": parent_checkpoint_id,
+             "checkpoint_type": checkpoint_type,
+             "checkpoint_payload": checkpoint_payload,
+             "checkpoint_hash": checkpoint_hash,
+             "metadata_type": metadata_type,
+             "metadata_payload": metadata_payload,
+             "metadata_hash": metadata_hash,
+             "created_at": checkpoint.get("ts"),
+             "run_id": metadata_dict.get("run_id"),
+         }
+         _assert_item_size(doc, scope=f"checkpoint:{checkpoint['id']}")
+
+         try:
+             self._checkpoints.create_item(doc)
+         except CosmosResourceExistsError:
+             existing = self._checkpoints.read_item(
+                 item=doc["id"], partition_key=thread_id
+             )
+             if not _checkpoint_doc_equivalent(existing, doc):
+                 raise ValueError(
+                     "Checkpoint conflict for deterministic checkpoint ID."
+                 ) from None
+
+         return {
+             "configurable": {
+                 "thread_id": thread_id,
+                 "checkpoint_ns": checkpoint_ns,
+                 "checkpoint_id": checkpoint["id"],
+             }
+         }
+
+     def put_writes(
+         self,
+         config: RunnableConfig,
+         writes: Sequence[tuple[str, Any]],
+         task_id: str,
+         task_path: str = "",
+     ) -> None:
+         thread_id, checkpoint_ns, checkpoint_id = _normalize_config(config)
+         if checkpoint_id is None:
+             raise ValueError("put_writes requires checkpoint_id in config.")
+
+         for idx, (channel, value) in enumerate(writes):
+             mapped_idx = WRITES_IDX_MAP.get(channel, idx)
+             write_type, write_payload, write_hash = self._serialize_typed(
+                 value=value,
+                 thread_id=thread_id,
+                 payload_scope=(
+                     f"write:{checkpoint_ns}:{checkpoint_id}:{task_id}:{mapped_idx}"
+                 ),
+             )
+             doc = {
+                 "id": _write_doc_id(
+                     checkpoint_ns=checkpoint_ns,
+                     checkpoint_id=checkpoint_id,
+                     task_id=task_id,
+                     idx=mapped_idx,
+                 ),
+                 "doc_type": "write",
+                 "thread_id": thread_id,
+                 "checkpoint_ns": checkpoint_ns,
+                 "checkpoint_id": checkpoint_id,
+                 "task_id": task_id,
+                 "task_path": task_path,
+                 "idx": mapped_idx,
+                 "channel": channel,
+                 "value_type": write_type,
+                 "value_payload": write_payload,
+                 "value_hash": write_hash,
+             }
+             _assert_item_size(
+                 doc, scope=f"write:{checkpoint_id}:{task_id}:{mapped_idx}"
+             )
+
+             if mapped_idx >= 0:
+                 self._create_or_validate_write(doc)
+             else:
+                 self._replace_special_write_with_etag(doc)
+
+     def delete_thread(self, thread_id: str) -> None:
+         query = "SELECT c.id FROM c WHERE c.thread_id=@thread_id"
+         params = [{"name": "@thread_id", "value": thread_id}]
+         for container in (self._checkpoints, self._writes, self._blobs):
+             if container is None:
+                 continue
+             ids = list(
+                 container.query_items(
+                     query=query,
+                     parameters=params,
+                     partition_key=thread_id,
+                 )
+             )
+             for row in ids:
+                 container.delete_item(item=row["id"], partition_key=thread_id)
+
+     async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
+         return await asyncio.to_thread(self.get_tuple, config)
+
+     async def alist(
+         self,
+         config: RunnableConfig | None,
+         *,
+         filter: dict[str, Any] | None = None,
+         before: RunnableConfig | None = None,
+         limit: int | None = None,
+     ) -> AsyncIterator[CheckpointTuple]:
+         rows = await asyncio.to_thread(
+             lambda: list(self.list(config, filter=filter, before=before, limit=limit))
+         )
+         for row in rows:
+             yield row
+
+     async def aput(
+         self,
+         config: RunnableConfig,
+         checkpoint: Checkpoint,
+         metadata: CheckpointMetadata,
+         new_versions: ChannelVersions,
+     ) -> RunnableConfig:
+         return await asyncio.to_thread(
+             self.put, config, checkpoint, metadata, new_versions
+         )
+
+     async def aput_writes(
+         self,
+         config: RunnableConfig,
+         writes: Sequence[tuple[str, Any]],
+         task_id: str,
+         task_path: str = "",
+     ) -> None:
+         await asyncio.to_thread(self.put_writes, config, writes, task_id, task_path)
+
+     async def adelete_thread(self, thread_id: str) -> None:
+         await asyncio.to_thread(self.delete_thread, thread_id)
+
+     def get_next_version(self, current: str | None, channel: None) -> str:
+         if current is None:
+             current_v = 0
+         elif isinstance(current, int):
+             current_v = current
+         else:
+             current_v = int(str(current).split(".")[0])
+         return f"{current_v + 1:032}.{random.random():016}"
+
+     def _checkpoint_tuple_from_doc(
+         self,
+         *,
+         doc: Mapping[str, Any],
+         thread_id: str,
+         checkpoint_ns: str,
+     ) -> CheckpointTuple:
+         checkpoint_core = self._deserialize_typed(
+             thread_id=thread_id,
+             payload_type=doc["checkpoint_type"],
+             payload=doc["checkpoint_payload"],
+         )
+         metadata = self._deserialize_typed(
+             thread_id=thread_id,
+             payload_type=doc["metadata_type"],
+             payload=doc["metadata_payload"],
+         )
+         if not isinstance(checkpoint_core, Mapping):
+             raise TypeError("Decoded checkpoint payload is not a mapping.")
+         if not isinstance(metadata, Mapping):
+             raise TypeError("Decoded checkpoint metadata is not a mapping.")
+
+         channel_values = self._load_channel_values(
+             thread_id=thread_id,
+             checkpoint_ns=checkpoint_ns,
+             channel_versions=checkpoint_core.get("channel_versions", {}),
+         )
+         pending_writes = self._load_pending_writes(
+             thread_id=thread_id,
+             checkpoint_ns=checkpoint_ns,
+             checkpoint_id=doc["checkpoint_id"],
+         )
+         checkpoint = dict(checkpoint_core)
+         checkpoint["channel_values"] = channel_values
+
+         parent_checkpoint_id = doc.get("parent_checkpoint_id")
+         parent_config = None
+         if parent_checkpoint_id:
+             parent_config = {
+                 "configurable": {
+                     "thread_id": thread_id,
+                     "checkpoint_ns": checkpoint_ns,
+                     "checkpoint_id": parent_checkpoint_id,
+                 }
+             }
+
+         return CheckpointTuple(
+             config={
+                 "configurable": {
+                     "thread_id": thread_id,
+                     "checkpoint_ns": checkpoint_ns,
+                     "checkpoint_id": doc["checkpoint_id"],
+                 }
+             },
+             checkpoint=checkpoint,  # type: ignore[arg-type]
+             metadata=dict(metadata),  # type: ignore[arg-type]
+             parent_config=parent_config,
+             pending_writes=pending_writes,
+         )
+
+     def _load_pending_writes(
+         self,
+         *,
+         thread_id: str,
+         checkpoint_ns: str,
+         checkpoint_id: str,
+     ) -> list[tuple[str, str, Any]]:
+         query = (
+             "SELECT * FROM c WHERE c.doc_type=@doc_type AND c.thread_id=@thread_id "
+             "AND c.checkpoint_ns=@checkpoint_ns AND c.checkpoint_id=@checkpoint_id "
+             "ORDER BY c.task_id ASC, c.idx ASC"
+         )
+         params = [
+             {"name": "@doc_type", "value": "write"},
+             {"name": "@thread_id", "value": thread_id},
+             {"name": "@checkpoint_ns", "value": checkpoint_ns},
+             {"name": "@checkpoint_id", "value": checkpoint_id},
+         ]
+         rows = self._writes.query_items(
+             query=query,
+             parameters=params,
+             partition_key=thread_id,
+         )
+         output: list[tuple[str, str, Any]] = []
+         for row in rows:
+             value = self._deserialize_typed(
+                 thread_id=thread_id,
+                 payload_type=row["value_type"],
+                 payload=row["value_payload"],
+             )
+             output.append((row["task_id"], row["channel"], value))
+         return output
+
+     def _create_or_validate_write(self, doc: dict[str, Any]) -> None:
+         try:
+             self._writes.create_item(doc)
+             return
+         except CosmosResourceExistsError:
+             existing = self._writes.read_item(item=doc["id"], partition_key=doc["thread_id"])
+             if not _write_doc_equivalent(existing, doc):
+                 raise ValueError(
+                     "Detected conflicting duplicate write for deterministic write identity."
+                 )
+
+     def _replace_special_write_with_etag(self, doc: dict[str, Any]) -> None:
+         try:
+             existing = self._writes.read_item(item=doc["id"], partition_key=doc["thread_id"])
+         except CosmosResourceNotFoundError:
+             self._writes.create_item(doc)
+             return
+
+         if _write_doc_equivalent(existing, doc):
+             return
+
+         try:
+             self._writes.replace_item(
+                 item=doc["id"],
+                 body=doc,
+                 etag=existing.get("_etag"),
+                 match_condition=MatchConditions.IfNotModified,
+             )
+         except (CosmosAccessConditionFailedError, CosmosHttpResponseError) as exc:
+             if getattr(exc, "status_code", None) != 412:
+                 raise
+             latest = self._writes.read_item(item=doc["id"], partition_key=doc["thread_id"])
+             if not _write_doc_equivalent(latest, doc):
+                 raise ValueError(
+                     "Concurrent update conflict while replacing special write."
+                 ) from None
+
+     def _store_channel_blob(
+         self,
+         *,
+         thread_id: str,
+         checkpoint_ns: str,
+         channel: str,
+         version: str | int | float,
+         value: Any,
+     ) -> None:
+         if self._blobs is None:
+             if value is not None:
+                 raise ValueError("Blob container is required to persist channel values.")
+             return
+
+         doc_id = _channel_blob_doc_id(checkpoint_ns, channel, version)
+         if value is None:
+             doc = {
+                 "id": doc_id,
+                 "doc_type": "channel_blob",
+                 "thread_id": thread_id,
+                 "checkpoint_ns": checkpoint_ns,
+                 "channel": channel,
+                 "version": _stringify_version(version),
+                 "value_type": "empty",
+                 "value_payload": {"mode": "empty"},
+                 "value_hash": "empty",
+             }
+         else:
+             value_type, value_payload, value_hash = self._serialize_typed(
+                 value=value,
+                 thread_id=thread_id,
+                 payload_scope=f"channel:{checkpoint_ns}:{channel}:{_stringify_version(version)}",
+             )
+             doc = {
+                 "id": doc_id,
+                 "doc_type": "channel_blob",
+                 "thread_id": thread_id,
+                 "checkpoint_ns": checkpoint_ns,
+                 "channel": channel,
+                 "version": _stringify_version(version),
+                 "value_type": value_type,
+                 "value_payload": value_payload,
+                 "value_hash": value_hash,
+             }
+         _assert_item_size(doc, scope=f"channel_blob:{channel}")
+
+         try:
+             self._blobs.create_item(doc)
+         except CosmosResourceExistsError:
+             existing = self._blobs.read_item(item=doc_id, partition_key=thread_id)
+             if not _channel_blob_equivalent(existing, doc):
+                 raise ValueError(
+                     "Conflicting channel blob for deterministic channel version identity."
+                 ) from None
+
+     def _load_channel_values(
+         self,
+         *,
+         thread_id: str,
+         checkpoint_ns: str,
+         channel_versions: Mapping[str, Any],
+     ) -> dict[str, Any]:
+         channel_values: dict[str, Any] = {}
+         if self._blobs is None:
+             return channel_values
+
+         for channel, version in channel_versions.items():
+             doc_id = _channel_blob_doc_id(checkpoint_ns, channel, version)
+             try:
+                 row = self._blobs.read_item(item=doc_id, partition_key=thread_id)
+             except CosmosResourceNotFoundError:
+                 continue
+             if row.get("value_type") == "empty":
+                 continue
+             channel_values[channel] = self._deserialize_typed(
+                 thread_id=thread_id,
+                 payload_type=row["value_type"],
+                 payload=row["value_payload"],
+             )
+         return channel_values
+
+     def _serialize_typed(
+         self,
+         *,
+         value: Any,
+         thread_id: str,
+         payload_scope: str,
+     ) -> tuple[str, dict[str, Any], str]:
+         payload_type, raw = self.serde.dumps_typed(value)
+         payload, raw_hash = self._encode_payload(
+             raw=raw,
+             thread_id=thread_id,
+             payload_scope=payload_scope,
+         )
+         return payload_type, payload, raw_hash
+
+     def _deserialize_typed(
+         self,
+         *,
+         thread_id: str,
+         payload_type: str,
+         payload: Mapping[str, Any],
+     ) -> Any:
+         raw = self._decode_payload(thread_id=thread_id, payload=payload)
+         return self.serde.loads_typed((payload_type, raw))
+
+     def _encode_payload(
+         self,
+         *,
+         raw: bytes,
+         thread_id: str,
+         payload_scope: str,
+     ) -> tuple[dict[str, Any], str]:
+         raw_hash = hashlib.sha256(raw).hexdigest()
+         data = raw
+         compressed = False
+         if len(raw) >= self._settings.compress_above_bytes:
+             compressed_candidate = zlib.compress(raw)
+             if len(compressed_candidate) < len(raw):
+                 data = compressed_candidate
+                 compressed = True
+
+         if len(data) <= self._settings.inline_payload_limit_bytes:
+             return (
+                 {
+                     "mode": "inline",
+                     "compressed": compressed,
+                     "raw_hash": raw_hash,
+                     "size_bytes": len(raw),
+                     "data": _b64encode(data),
+                 },
+                 raw_hash,
+             )
+
+         if self._blobs is None:
+             raise ValueError(
+                 "Payload exceeds inline payload limit and no blob container is configured."
+             )
+
+         chunk_ids: list[str] = []
+         for idx in range(0, len(data), self._settings.chunk_size_bytes):
+             chunk = data[idx : idx + self._settings.chunk_size_bytes]
+             chunk_id = _chunk_doc_id(payload_scope, idx // self._settings.chunk_size_bytes)
+             chunk_doc = {
+                 "id": chunk_id,
+                 "doc_type": "payload_chunk",
+                 "thread_id": thread_id,
+                 "scope_hash": _hash_token(payload_scope),
+                 "chunk_index": idx // self._settings.chunk_size_bytes,
+                 "payload": _b64encode(chunk),
+                 "payload_hash": hashlib.sha256(chunk).hexdigest(),
+             }
+             _assert_item_size(chunk_doc, scope=f"chunk:{payload_scope}")
+             try:
+                 self._blobs.create_item(chunk_doc)
+             except CosmosResourceExistsError:
+                 existing = self._blobs.read_item(item=chunk_id, partition_key=thread_id)
+                 if not _chunk_doc_equivalent(existing, chunk_doc):
+                     raise ValueError("Payload chunk conflict for deterministic chunk ID.") from None
+             chunk_ids.append(chunk_id)
+
+         return (
+             {
+                 "mode": "chunks",
+                 "compressed": compressed,
+                 "raw_hash": raw_hash,
+                 "size_bytes": len(raw),
+                 "chunks": chunk_ids,
+             },
+             raw_hash,
+         )
+
+     def _decode_payload(self, *, thread_id: str, payload: Mapping[str, Any]) -> bytes:
+         mode = payload.get("mode")
+         if mode == "empty":
+             return b""
+         if mode == "inline":
+             raw = _b64decode(payload["data"])
+         elif mode == "chunks":
+             if self._blobs is None:
+                 raise ValueError("Blob container is required to read chunked payloads.")
+             chunks = payload.get("chunks", [])
+             pieces: list[bytes] = []
+             for chunk_id in chunks:
+                 row = self._blobs.read_item(item=chunk_id, partition_key=thread_id)
+                 pieces.append(_b64decode(row["payload"]))
+             raw = b"".join(pieces)
+         else:
+             raise ValueError(f"Unknown payload mode: {mode}")
+
+         if payload.get("compressed"):
+             raw = zlib.decompress(raw)
+         if payload.get("raw_hash") and hashlib.sha256(raw).hexdigest() != payload["raw_hash"]:
+             raise ValueError("Payload hash mismatch detected.")
+         return raw
+
+
+ class AsyncCosmosDBSaver(BaseCheckpointSaver[str], AbstractAsyncContextManager, AbstractContextManager):
+     def __init__(self, saver: CosmosDBSaver) -> None:
+         super().__init__(serde=saver.serde)
+         self._saver = saver
+
+     @classmethod
+     def from_connection_string(
+         cls,
+         connection_string: str,
+         *,
+         settings: CosmosDBSettings | None = None,
+         credential: str | TokenCredential | Mapping[str, Any] | None = None,
+         serde: SerializerProtocol | None = None,
+     ) -> "AsyncCosmosDBSaver":
+         return cls(
+             CosmosDBSaver.from_connection_string(
+                 connection_string,
+                 settings=settings,
+                 credential=credential,
+                 serde=serde,
+             )
+         )
+
+     @classmethod
+     def from_endpoint(
+         cls,
+         endpoint: str,
+         *,
+         credential: str | TokenCredential | Mapping[str, Any] | None = None,
+         settings: CosmosDBSettings | None = None,
+         serde: SerializerProtocol | None = None,
+     ) -> "AsyncCosmosDBSaver":
+         return cls(
+             CosmosDBSaver.from_endpoint(
+                 endpoint,
+                 credential=credential,
+                 settings=settings,
+                 serde=serde,
+             )
+         )
+
+     def __enter__(self) -> "AsyncCosmosDBSaver":
+         return self
+
+     def __exit__(self, exc_type, exc, tb) -> bool | None:
+         self._saver.close()
+         return None
+
+     async def __aenter__(self) -> "AsyncCosmosDBSaver":
+         return self
+
+     async def __aexit__(self, exc_type, exc, tb) -> bool | None:
+         await asyncio.to_thread(self._saver.close)
+         return None
+
+     def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
+         return self._saver.get_tuple(config)
+
+     def list(
+         self,
+         config: RunnableConfig | None,
+         *,
+         filter: dict[str, Any] | None = None,
+         before: RunnableConfig | None = None,
+         limit: int | None = None,
+     ) -> Iterator[CheckpointTuple]:
+         return self._saver.list(config, filter=filter, before=before, limit=limit)
+
+     def put(
+         self,
+         config: RunnableConfig,
+         checkpoint: Checkpoint,
+         metadata: CheckpointMetadata,
+         new_versions: ChannelVersions,
+     ) -> RunnableConfig:
+         return self._saver.put(config, checkpoint, metadata, new_versions)
+
+     def put_writes(
+         self,
+         config: RunnableConfig,
+         writes: Sequence[tuple[str, Any]],
+         task_id: str,
+         task_path: str = "",
+     ) -> None:
+         self._saver.put_writes(config, writes, task_id, task_path)
+
+     def delete_thread(self, thread_id: str) -> None:
+         self._saver.delete_thread(thread_id)
+
+     async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
+         return await asyncio.to_thread(self._saver.get_tuple, config)
+
+     async def alist(
+         self,
+         config: RunnableConfig | None,
+         *,
+         filter: dict[str, Any] | None = None,
+         before: RunnableConfig | None = None,
+         limit: int | None = None,
+     ) -> AsyncIterator[CheckpointTuple]:
+         rows = await asyncio.to_thread(
+             lambda: list(self._saver.list(config, filter=filter, before=before, limit=limit))
+         )
+         for row in rows:
+             yield row
+
+     async def aput(
+         self,
+         config: RunnableConfig,
+         checkpoint: Checkpoint,
+         metadata: CheckpointMetadata,
+         new_versions: ChannelVersions,
+     ) -> RunnableConfig:
+         return await asyncio.to_thread(
+             self._saver.put, config, checkpoint, metadata, new_versions
+         )
+
+     async def aput_writes(
+         self,
+         config: RunnableConfig,
+         writes: Sequence[tuple[str, Any]],
+         task_id: str,
+         task_path: str = "",
+     ) -> None:
+         await asyncio.to_thread(self._saver.put_writes, config, writes, task_id, task_path)
+
+     async def adelete_thread(self, thread_id: str) -> None:
+         await asyncio.to_thread(self._saver.delete_thread, thread_id)
+
+
+ def _client_kwargs(settings: CosmosDBSettings) -> dict[str, Any]:
+     kwargs: dict[str, Any] = {
+         "no_response_on_write": settings.no_response_on_write,
+     }
+     if settings.retry_total is not None:
+         kwargs["retry_total"] = settings.retry_total
+     if settings.retry_backoff_max is not None:
+         kwargs["retry_backoff_max"] = settings.retry_backoff_max
+     if settings.retry_fixed_interval is not None:
+         kwargs["retry_fixed_interval"] = settings.retry_fixed_interval
+     return kwargs
+
+
+ def _resolve_containers_sync(
+     *,
+     client: CosmosClient,
+     settings: CosmosDBSettings,
+ ) -> tuple[ContainerProxy, ContainerProxy, ContainerProxy | None]:
+     database: DatabaseProxy
+     if settings.create_containers:
+         database = client.create_database_if_not_exists(id=settings.database_name)
+     else:
+         database = client.get_database_client(settings.database_name)
+
+     if settings.create_containers:
+         checkpoints = database.create_container_if_not_exists(
+             id=settings.checkpoint_container_name,
+             partition_key=PartitionKey(path="/thread_id"),
+             indexing_policy=_checkpoint_indexing_policy(),
+         )
+         writes = database.create_container_if_not_exists(
+             id=settings.writes_container_name,
+             partition_key=PartitionKey(path="/thread_id"),
+             indexing_policy=_writes_indexing_policy(),
+         )
+         blobs = database.create_container_if_not_exists(
+             id=settings.blob_container_name,
+             partition_key=PartitionKey(path="/thread_id"),
+             indexing_policy=_blob_indexing_policy(),
+         )
+     else:
+         checkpoints = database.get_container_client(settings.checkpoint_container_name)
+         writes = database.get_container_client(settings.writes_container_name)
+         blobs = (
+             database.get_container_client(settings.blob_container_name)
+             if settings.blob_container_name
+             else None
+         )
+     return checkpoints, writes, blobs
+
+
+ def _normalize_config(config: RunnableConfig) -> tuple[str, str, str | None]:
+     configurable = config.get("configurable") or {}
+     if "thread_ts" in configurable:
+         raise ValueError(
+             "Legacy config key thread_ts is not supported. Use thread_id/checkpoint_ns/checkpoint_id."
+         )
+     thread_id = configurable.get("thread_id")
+     if not thread_id:
+         raise ValueError("config.configurable.thread_id is required.")
+     checkpoint_ns = configurable.get("checkpoint_ns", "")
+     checkpoint_id = get_checkpoint_id(config)
+     return str(thread_id), str(checkpoint_ns), str(checkpoint_id) if checkpoint_id else None
+
+
+ def _checkpoint_doc_equivalent(existing: Mapping[str, Any], new_doc: Mapping[str, Any]) -> bool:
+     return (
+         existing.get("checkpoint_hash") == new_doc.get("checkpoint_hash")
+         and existing.get("metadata_hash") == new_doc.get("metadata_hash")
+         and existing.get("parent_checkpoint_id") == new_doc.get("parent_checkpoint_id")
+     )
+
+
+ def _write_doc_equivalent(existing: Mapping[str, Any], new_doc: Mapping[str, Any]) -> bool:
+     return (
+         existing.get("channel") == new_doc.get("channel")
+         and existing.get("value_hash") == new_doc.get("value_hash")
+         and existing.get("task_id") == new_doc.get("task_id")
+     )
+
+
+ def _channel_blob_equivalent(existing: Mapping[str, Any], new_doc: Mapping[str, Any]) -> bool:
+     return (
+         existing.get("value_type") == new_doc.get("value_type")
+         and existing.get("value_hash") == new_doc.get("value_hash")
+     )
+
+
+ def _chunk_doc_equivalent(existing: Mapping[str, Any], new_doc: Mapping[str, Any]) -> bool:
+     return existing.get("payload_hash") == new_doc.get("payload_hash")
+
+
+ def _checkpoint_doc_id(checkpoint_ns: str, checkpoint_id: str) -> str:
+     return f"ckpt|{_ns_token(checkpoint_ns)}|{checkpoint_id}"
+
+
+ def _write_doc_id(
+     *,
+     checkpoint_ns: str,
+     checkpoint_id: str,
+     task_id: str,
+     idx: int,
+ ) -> str:
+     return f"w|{_ns_token(checkpoint_ns)}|{checkpoint_id}|{_hash_token(task_id)}|{idx}"
+
+
+ def _channel_blob_doc_id(
+     checkpoint_ns: str,
+     channel: str,
+     version: str | int | float,
+ ) -> str:
+     return (
+         f"b|{_ns_token(checkpoint_ns)}|{_hash_token(channel)}|"
+         f"{_hash_token(_stringify_version(version))}"
+     )
+
+
+ def _chunk_doc_id(scope: str, chunk_idx: int) -> str:
+     return f"chunk|{_hash_token(scope)}|{chunk_idx:06d}"
+
+
+ def _ns_token(checkpoint_ns: str) -> str:
+     return "root" if not checkpoint_ns else _hash_token(checkpoint_ns)
+
+
+ def _hash_token(value: str) -> str:
+     return hashlib.sha256(value.encode("utf-8")).hexdigest()[:16]
+
+
+ def _stringify_version(version: str | int | float) -> str:
+     if isinstance(version, str):
+         return version
+     return json.dumps(version, sort_keys=True, separators=(",", ":"))
+
+
+ def _assert_item_size(item: Mapping[str, Any], *, scope: str) -> None:
+     item_bytes = len(
+         json.dumps(item, separators=(",", ":"), sort_keys=True, default=str).encode("utf-8")
+     )
+     if item_bytes > MAX_COSMOS_ITEM_BYTES:
+         raise ValueError(
+             f"Cosmos item for {scope} exceeds 2MB limit ({item_bytes} bytes)."
+         )
+
+
+ def _b64encode(raw: bytes) -> str:
+     return base64.b64encode(raw).decode("ascii")
+
+
+ def _b64decode(data: str) -> bytes:
+     return base64.b64decode(data.encode("ascii"))
+
+
+ def _checkpoint_indexing_policy() -> dict[str, Any]:
+     return {
+         "indexingMode": "consistent",
+         "automatic": True,
+         "includedPaths": [{"path": "/*"}],
+         "excludedPaths": [
+             {"path": "/checkpoint_payload/*"},
+             {"path": "/metadata_payload/*"},
+         ],
+     }
+
+
+ def _writes_indexing_policy() -> dict[str, Any]:
+     return {
+         "indexingMode": "consistent",
+         "automatic": True,
+         "includedPaths": [{"path": "/*"}],
+         "excludedPaths": [{"path": "/value_payload/*"}],
+         # Cosmos DB requires a composite index for the multi-property
+         # ORDER BY (c.task_id ASC, c.idx ASC) used in _load_pending_writes.
+         "compositeIndexes": [
+             [
+                 {"path": "/task_id", "order": "ascending"},
+                 {"path": "/idx", "order": "ascending"},
+             ]
+         ],
+     }
+
+
+ def _blob_indexing_policy() -> dict[str, Any]:
+     return {
+         "indexingMode": "consistent",
+         "automatic": True,
+         "includedPaths": [
+             {"path": "/thread_id/?"},
+             {"path": "/doc_type/?"},
+             {"path": "/scope_hash/?"},
+             {"path": "/chunk_index/?"},
+             {"path": "/channel/?"},
+             {"path": "/version/?"},
+             {"path": "/id/?"},
+         ],
+         "excludedPaths": [{"path": "/*"}],
+     }
@@ -0,0 +1,35 @@
+ Metadata-Version: 2.4
+ Name: langgraph-cosmosdb-checkpointer
+ Version: 0.0.0
+ Summary: State-of-the-art Azure Cosmos DB checkpointer for LangGraph
+ Requires-Python: >=3.12
+ Description-Content-Type: text/markdown
+ Requires-Dist: azure-cosmos>=4.14.6
+ Requires-Dist: azure-identity>=1.25.2
+ Requires-Dist: langchain>=1.2.10
+ Requires-Dist: langgraph>=1.0.9
+ Provides-Extra: dev
+ Requires-Dist: pytest>=9.0.2; extra == "dev"
+ Requires-Dist: pytest-asyncio>=1.3.0; extra == "dev"
+ Requires-Dist: ruff>=0.9.0; extra == "dev"
+
+ # langgraph-cosmosdb-checkpointer
+
+ State-of-the-art Azure Cosmos DB checkpointer for LangGraph.
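+
+ ## Usage
+
+ A minimal synchronous sketch; `builder` and `input_state` stand in for your
+ own LangGraph graph and inputs, and the endpoint is a placeholder:
+
+ ```python
+ from langgraph_cosmosdb_checkpointer import CosmosDBSaver, CosmosDBSettings
+
+ # from_endpoint falls back to DefaultAzureCredential when no credential is given.
+ with CosmosDBSaver.from_endpoint(
+     "https://<account>.documents.azure.com:443/",
+     settings=CosmosDBSettings(database_name="langgraph"),
+ ) as saver:
+     graph = builder.compile(checkpointer=saver)
+     graph.invoke(input_state, {"configurable": {"thread_id": "thread-1"}})
+ ```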
@@ -0,0 +1,6 @@
+ langgraph_cosmosdb_checkpointer/__init__.py,sha256=8SkbUhhy3suSteK4Ha_7J-O48Tc7RMJQ49N7DrPVveo,177
+ langgraph_cosmosdb_checkpointer/cosmosdb.py,sha256=kBZ3Qklg10MBlmY1WBtSCjB-bLPy3LBanHsoCBSbbX0,37835
+ langgraph_cosmosdb_checkpointer-0.0.0.dist-info/METADATA,sha256=cIXspDHPg3XV2qyPYLDGhWJGPs8VxpGJKA-MhdqUmIw,605
+ langgraph_cosmosdb_checkpointer-0.0.0.dist-info/WHEEL,sha256=YCfwYGOYMi5Jhw2fU4yNgwErybb2IX5PEwBKV4ZbdBo,91
+ langgraph_cosmosdb_checkpointer-0.0.0.dist-info/top_level.txt,sha256=JF8lejbl_uxDEJyPoTnQQwi7JWWXdd2vdLvmPhE-GyQ,32
+ langgraph_cosmosdb_checkpointer-0.0.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (82.0.0)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
@@ -0,0 +1 @@
+ langgraph_cosmosdb_checkpointer