langgraph-cosmosdb-checkpointer 0.0.0__py3-none-any.whl
- langgraph_cosmosdb_checkpointer/__init__.py +11 -0
- langgraph_cosmosdb_checkpointer/cosmosdb.py +1074 -0
- langgraph_cosmosdb_checkpointer-0.0.0.dist-info/METADATA +18 -0
- langgraph_cosmosdb_checkpointer-0.0.0.dist-info/RECORD +6 -0
- langgraph_cosmosdb_checkpointer-0.0.0.dist-info/WHEEL +5 -0
- langgraph_cosmosdb_checkpointer-0.0.0.dist-info/top_level.txt +1 -0
langgraph_cosmosdb_checkpointer/cosmosdb.py
@@ -0,0 +1,1074 @@
from __future__ import annotations

import asyncio
import base64
import hashlib
import json
import random
import zlib
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from dataclasses import dataclass
from typing import Any, AsyncIterator, Iterator, Mapping, Sequence

from azure.core import MatchConditions
from azure.core.credentials import TokenCredential
from azure.cosmos import CosmosClient, PartitionKey
from azure.cosmos.container import ContainerProxy
from azure.cosmos.database import DatabaseProxy
from azure.cosmos.exceptions import (
    CosmosAccessConditionFailedError,
    CosmosHttpResponseError,
    CosmosResourceExistsError,
    CosmosResourceNotFoundError,
)
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
    WRITES_IDX_MAP,
    BaseCheckpointSaver,
    ChannelVersions,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
    SerializerProtocol,
    get_checkpoint_id,
    get_checkpoint_metadata,
)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

MAX_COSMOS_ITEM_BYTES = 2_000_000
DEFAULT_INLINE_PAYLOAD_LIMIT_BYTES = 1_500_000
DEFAULT_CHUNK_SIZE_BYTES = 512_000
DEFAULT_COMPRESS_ABOVE_BYTES = 16_384
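
# Azure Cosmos DB caps a single item at 2 MB, which motivates the thresholds
# above: payloads are stored inline only up to ~1.5 MB (leaving headroom for
# envelope and index fields), anything larger is split into ~512 KB chunk
# documents, and zlib compression is attempted once a payload exceeds 16 KB.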


@dataclass(slots=True)
class CosmosDBSettings:
    database_name: str = "langgraph"
    checkpoint_container_name: str = "checkpoints"
    writes_container_name: str = "writes"
    blob_container_name: str = "checkpoint_blobs"
    create_containers: bool = True
    retry_total: int | None = None
    retry_backoff_max: int | None = None
    retry_fixed_interval: int | None = None
    inline_payload_limit_bytes: int = DEFAULT_INLINE_PAYLOAD_LIMIT_BYTES
    chunk_size_bytes: int = DEFAULT_CHUNK_SIZE_BYTES
    compress_above_bytes: int = DEFAULT_COMPRESS_ABOVE_BYTES
    no_response_on_write: bool = False


class CosmosDBSaver(BaseCheckpointSaver[str], AbstractContextManager):
    def __init__(
        self,
        checkpoint_container: ContainerProxy,
        writes_container: ContainerProxy,
        blob_container: ContainerProxy | None,
        *,
        settings: CosmosDBSettings | None = None,
        serde: SerializerProtocol | None = None,
        client: CosmosClient | None = None,
        owns_client: bool = False,
    ) -> None:
        super().__init__(serde=serde or JsonPlusSerializer())
        self._checkpoints = checkpoint_container
        self._writes = writes_container
        self._blobs = blob_container
        self._settings = settings or CosmosDBSettings()
        self._client = client
        self._owns_client = owns_client

    @classmethod
    def from_connection_string(
        cls,
        connection_string: str,
        *,
        settings: CosmosDBSettings | None = None,
        credential: str | TokenCredential | Mapping[str, Any] | None = None,
        serde: SerializerProtocol | None = None,
    ) -> "CosmosDBSaver":
        cfg = settings or CosmosDBSettings()
        client_kwargs = _client_kwargs(cfg)
        client = CosmosClient.from_connection_string(
            connection_string,
            credential=credential,
            **client_kwargs,
        )
        checkpoint_container, writes_container, blob_container = _resolve_containers_sync(
            client=client,
            settings=cfg,
        )
        return cls(
            checkpoint_container=checkpoint_container,
            writes_container=writes_container,
            blob_container=blob_container,
            settings=cfg,
            serde=serde,
            client=client,
            owns_client=True,
        )

    @classmethod
    def from_endpoint(
        cls,
        endpoint: str,
        *,
        credential: str | TokenCredential | Mapping[str, Any] | None = None,
        settings: CosmosDBSettings | None = None,
        serde: SerializerProtocol | None = None,
    ) -> "CosmosDBSaver":
        cfg = settings or CosmosDBSettings()
        credential_value = credential
        if credential_value is None:
            from azure.identity import DefaultAzureCredential

            credential_value = DefaultAzureCredential()
        client = CosmosClient(endpoint, credential=credential_value, **_client_kwargs(cfg))
        checkpoint_container, writes_container, blob_container = _resolve_containers_sync(
            client=client,
            settings=cfg,
        )
        return cls(
            checkpoint_container=checkpoint_container,
            writes_container=writes_container,
            blob_container=blob_container,
            settings=cfg,
            serde=serde,
            client=client,
            owns_client=True,
        )

    def __enter__(self) -> "CosmosDBSaver":
        return self

    def __exit__(self, exc_type, exc, tb) -> bool | None:
        self.close()
        return None

    def close(self) -> None:
        if self._owns_client and self._client is not None:
            self._client.__exit__(None, None, None)
            self._client = None

    def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        thread_id, checkpoint_ns, checkpoint_id = _normalize_config(config)
        if checkpoint_id:
            doc_id = _checkpoint_doc_id(checkpoint_ns, checkpoint_id)
            try:
                doc = self._checkpoints.read_item(item=doc_id, partition_key=thread_id)
            except CosmosResourceNotFoundError:
                return None
            return self._checkpoint_tuple_from_doc(
                doc=doc,
                thread_id=thread_id,
                checkpoint_ns=checkpoint_ns,
            )

        rows = list(
            self._checkpoints.query_items(
                query=(
                    "SELECT TOP 1 * FROM c "
                    "WHERE c.doc_type=@doc_type AND c.thread_id=@thread_id AND c.checkpoint_ns=@checkpoint_ns "
                    "ORDER BY c.checkpoint_id DESC"
                ),
                parameters=[
                    {"name": "@doc_type", "value": "checkpoint"},
                    {"name": "@thread_id", "value": thread_id},
                    {"name": "@checkpoint_ns", "value": checkpoint_ns},
                ],
                partition_key=thread_id,
            )
        )
        if not rows:
            return None
        return self._checkpoint_tuple_from_doc(
            doc=rows[0],
            thread_id=thread_id,
            checkpoint_ns=checkpoint_ns,
        )

    def list(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int | None = None,
    ) -> Iterator[CheckpointTuple]:
        if config:
            thread_id, checkpoint_ns, checkpoint_id = _normalize_config(config)
        else:
            thread_id, checkpoint_ns, checkpoint_id = None, None, None

        before_checkpoint_id = get_checkpoint_id(before) if before else None

        query = "SELECT * FROM c WHERE c.doc_type=@doc_type"
        parameters: list[dict[str, Any]] = [{"name": "@doc_type", "value": "checkpoint"}]
        if thread_id is not None:
            query += " AND c.thread_id=@thread_id"
            parameters.append({"name": "@thread_id", "value": thread_id})
        if checkpoint_ns is not None:
            query += " AND c.checkpoint_ns=@checkpoint_ns"
            parameters.append({"name": "@checkpoint_ns", "value": checkpoint_ns})
        if checkpoint_id is not None:
            query += " AND c.checkpoint_id=@checkpoint_id"
            parameters.append({"name": "@checkpoint_id", "value": checkpoint_id})
        if before_checkpoint_id is not None:
            query += " AND c.checkpoint_id < @before_checkpoint_id"
            parameters.append(
                {"name": "@before_checkpoint_id", "value": before_checkpoint_id}
            )
        query += " ORDER BY c.checkpoint_id DESC"

        query_kwargs: dict[str, Any] = {
            "query": query,
            "parameters": parameters,
        }
        if thread_id is None:
            query_kwargs["enable_cross_partition_query"] = True
        else:
            query_kwargs["partition_key"] = thread_id

        yielded = 0
        for row in self._checkpoints.query_items(**query_kwargs):
            tuple_value = self._checkpoint_tuple_from_doc(
                doc=row,
                thread_id=row["thread_id"],
                checkpoint_ns=row["checkpoint_ns"],
            )
            if filter and not all(
                tuple_value.metadata.get(k) == v for k, v in filter.items()
            ):
                continue
            yield tuple_value
            yielded += 1
            if limit is not None and yielded >= limit:
                return

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        thread_id, checkpoint_ns, parent_checkpoint_id = _normalize_config(config)
        checkpoint_core = checkpoint.copy()
        channel_values = checkpoint_core.pop("channel_values", {})

        for channel, version in new_versions.items():
            self._store_channel_blob(
                thread_id=thread_id,
                checkpoint_ns=checkpoint_ns,
                channel=channel,
                version=version,
                value=channel_values.get(channel),
            )

        checkpoint_type, checkpoint_payload, checkpoint_hash = self._serialize_typed(
            value=checkpoint_core,
            thread_id=thread_id,
            payload_scope=f"checkpoint:{checkpoint_ns}:{checkpoint['id']}",
        )
        metadata_dict = get_checkpoint_metadata(config, metadata)
        metadata_type, metadata_payload, metadata_hash = self._serialize_typed(
            value=metadata_dict,
            thread_id=thread_id,
            payload_scope=f"metadata:{checkpoint_ns}:{checkpoint['id']}",
        )

        doc = {
            "id": _checkpoint_doc_id(checkpoint_ns, checkpoint["id"]),
            "doc_type": "checkpoint",
            "thread_id": thread_id,
            "checkpoint_ns": checkpoint_ns,
            "checkpoint_id": checkpoint["id"],
            "parent_checkpoint_id": parent_checkpoint_id,
            "checkpoint_type": checkpoint_type,
            "checkpoint_payload": checkpoint_payload,
            "checkpoint_hash": checkpoint_hash,
            "metadata_type": metadata_type,
            "metadata_payload": metadata_payload,
            "metadata_hash": metadata_hash,
            "created_at": checkpoint.get("ts"),
            "run_id": metadata_dict.get("run_id"),
        }
        _assert_item_size(doc, scope=f"checkpoint:{checkpoint['id']}")

        try:
            self._checkpoints.create_item(doc)
        except CosmosResourceExistsError:
            existing = self._checkpoints.read_item(
                item=doc["id"], partition_key=thread_id
            )
            if not _checkpoint_doc_equivalent(existing, doc):
                raise ValueError(
                    "Checkpoint conflict for deterministic checkpoint ID."
                ) from None

        return {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint["id"],
            }
        }
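
    # Checkpoint document IDs are deterministic, so a retried superstep is an
    # idempotent create: an already-existing document is accepted only when its
    # payload hashes match, otherwise put() surfaces the conflict.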

    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        thread_id, checkpoint_ns, checkpoint_id = _normalize_config(config)
        if checkpoint_id is None:
            raise ValueError("put_writes requires checkpoint_id in config.")

        for idx, (channel, value) in enumerate(writes):
            mapped_idx = WRITES_IDX_MAP.get(channel, idx)
            write_type, write_payload, write_hash = self._serialize_typed(
                value=value,
                thread_id=thread_id,
                payload_scope=(
                    f"write:{checkpoint_ns}:{checkpoint_id}:{task_id}:{mapped_idx}"
                ),
            )
            doc = {
                "id": _write_doc_id(
                    checkpoint_ns=checkpoint_ns,
                    checkpoint_id=checkpoint_id,
                    task_id=task_id,
                    idx=mapped_idx,
                ),
                "doc_type": "write",
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint_id,
                "task_id": task_id,
                "task_path": task_path,
                "idx": mapped_idx,
                "channel": channel,
                "value_type": write_type,
                "value_payload": write_payload,
                "value_hash": write_hash,
            }
            _assert_item_size(
                doc, scope=f"write:{checkpoint_id}:{task_id}:{mapped_idx}"
            )

            if mapped_idx >= 0:
                self._create_or_validate_write(doc)
            else:
                self._replace_special_write_with_etag(doc)

    def delete_thread(self, thread_id: str) -> None:
        query = "SELECT c.id FROM c WHERE c.thread_id=@thread_id"
        params = [{"name": "@thread_id", "value": thread_id}]
        for container in (self._checkpoints, self._writes, self._blobs):
            if container is None:
                continue
            ids = list(
                container.query_items(
                    query=query,
                    parameters=params,
                    partition_key=thread_id,
                )
            )
            for row in ids:
                container.delete_item(item=row["id"], partition_key=thread_id)

    async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        return await asyncio.to_thread(self.get_tuple, config)

    async def alist(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int | None = None,
    ) -> AsyncIterator[CheckpointTuple]:
        rows = await asyncio.to_thread(
            lambda: list(self.list(config, filter=filter, before=before, limit=limit))
        )
        for row in rows:
            yield row

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        return await asyncio.to_thread(
            self.put, config, checkpoint, metadata, new_versions
        )

    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        await asyncio.to_thread(self.put_writes, config, writes, task_id, task_path)

    async def adelete_thread(self, thread_id: str) -> None:
        await asyncio.to_thread(self.delete_thread, thread_id)

    def get_next_version(self, current: str | None, channel: None) -> str:
        if current is None:
            current_v = 0
        elif isinstance(current, int):
            current_v = current
        else:
            current_v = int(str(current).split(".")[0])
        return f"{current_v + 1:032}.{random.random():016}"

    def _checkpoint_tuple_from_doc(
        self,
        *,
        doc: Mapping[str, Any],
        thread_id: str,
        checkpoint_ns: str,
    ) -> CheckpointTuple:
        checkpoint_core = self._deserialize_typed(
            thread_id=thread_id,
            payload_type=doc["checkpoint_type"],
            payload=doc["checkpoint_payload"],
        )
        metadata = self._deserialize_typed(
            thread_id=thread_id,
            payload_type=doc["metadata_type"],
            payload=doc["metadata_payload"],
        )
        if not isinstance(checkpoint_core, Mapping):
            raise TypeError("Decoded checkpoint payload is not a mapping.")
        if not isinstance(metadata, Mapping):
            raise TypeError("Decoded checkpoint metadata is not a mapping.")

        channel_values = self._load_channel_values(
            thread_id=thread_id,
            checkpoint_ns=checkpoint_ns,
            channel_versions=checkpoint_core.get("channel_versions", {}),
        )
        pending_writes = self._load_pending_writes(
            thread_id=thread_id,
            checkpoint_ns=checkpoint_ns,
            checkpoint_id=doc["checkpoint_id"],
        )
        checkpoint = dict(checkpoint_core)
        checkpoint["channel_values"] = channel_values

        parent_checkpoint_id = doc.get("parent_checkpoint_id")
        parent_config = None
        if parent_checkpoint_id:
            parent_config = {
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    "checkpoint_id": parent_checkpoint_id,
                }
            }

        return CheckpointTuple(
            config={
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    "checkpoint_id": doc["checkpoint_id"],
                }
            },
            checkpoint=checkpoint,  # type: ignore[arg-type]
            metadata=dict(metadata),  # type: ignore[arg-type]
            parent_config=parent_config,
            pending_writes=pending_writes,
        )

    def _load_pending_writes(
        self,
        *,
        thread_id: str,
        checkpoint_ns: str,
        checkpoint_id: str,
    ) -> list[tuple[str, str, Any]]:
        query = (
            "SELECT * FROM c WHERE c.doc_type=@doc_type AND c.thread_id=@thread_id "
            "AND c.checkpoint_ns=@checkpoint_ns AND c.checkpoint_id=@checkpoint_id "
            "ORDER BY c.task_id ASC, c.idx ASC"
        )
        params = [
            {"name": "@doc_type", "value": "write"},
            {"name": "@thread_id", "value": thread_id},
            {"name": "@checkpoint_ns", "value": checkpoint_ns},
            {"name": "@checkpoint_id", "value": checkpoint_id},
        ]
        rows = self._writes.query_items(
            query=query,
            parameters=params,
            partition_key=thread_id,
        )
        output: list[tuple[str, str, Any]] = []
        for row in rows:
            value = self._deserialize_typed(
                thread_id=thread_id,
                payload_type=row["value_type"],
                payload=row["value_payload"],
            )
            output.append((row["task_id"], row["channel"], value))
        return output

    def _create_or_validate_write(self, doc: dict[str, Any]) -> None:
        try:
            self._writes.create_item(doc)
            return
        except CosmosResourceExistsError:
            existing = self._writes.read_item(item=doc["id"], partition_key=doc["thread_id"])
            if not _write_doc_equivalent(existing, doc):
                raise ValueError(
                    "Detected conflicting duplicate write for deterministic write identity."
                )
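
    # WRITES_IDX_MAP assigns special channels (such as errors and interrupts)
    # negative indices; those writes are replaced last-writer-wins below, with
    # an ETag match condition so a lost race shows up as HTTP 412 and is
    # tolerated only when the competing document is equivalent.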

    def _replace_special_write_with_etag(self, doc: dict[str, Any]) -> None:
        try:
            existing = self._writes.read_item(item=doc["id"], partition_key=doc["thread_id"])
        except CosmosResourceNotFoundError:
            self._writes.create_item(doc)
            return

        if _write_doc_equivalent(existing, doc):
            return

        try:
            self._writes.replace_item(
                item=doc["id"],
                body=doc,
                etag=existing.get("_etag"),
                match_condition=MatchConditions.IfNotModified,
            )
        except (CosmosAccessConditionFailedError, CosmosHttpResponseError) as exc:
            if getattr(exc, "status_code", None) != 412:
                raise
            latest = self._writes.read_item(item=doc["id"], partition_key=doc["thread_id"])
            if not _write_doc_equivalent(latest, doc):
                raise ValueError(
                    "Concurrent update conflict while replacing special write."
                ) from None

    def _store_channel_blob(
        self,
        *,
        thread_id: str,
        checkpoint_ns: str,
        channel: str,
        version: str | int | float,
        value: Any,
    ) -> None:
        if self._blobs is None:
            if value is not None:
                raise ValueError("Blob container is required to persist channel values.")
            return

        doc_id = _channel_blob_doc_id(checkpoint_ns, channel, version)
        if value is None:
            doc = {
                "id": doc_id,
                "doc_type": "channel_blob",
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "channel": channel,
                "version": _stringify_version(version),
                "value_type": "empty",
                "value_payload": {"mode": "empty"},
                "value_hash": "empty",
            }
        else:
            value_type, value_payload, value_hash = self._serialize_typed(
                value=value,
                thread_id=thread_id,
                payload_scope=f"channel:{checkpoint_ns}:{channel}:{_stringify_version(version)}",
            )
            doc = {
                "id": doc_id,
                "doc_type": "channel_blob",
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "channel": channel,
                "version": _stringify_version(version),
                "value_type": value_type,
                "value_payload": value_payload,
                "value_hash": value_hash,
            }
        _assert_item_size(doc, scope=f"channel_blob:{channel}")

        try:
            self._blobs.create_item(doc)
        except CosmosResourceExistsError:
            existing = self._blobs.read_item(item=doc_id, partition_key=thread_id)
            if not _channel_blob_equivalent(existing, doc):
                raise ValueError(
                    "Conflicting channel blob for deterministic channel version identity."
                ) from None

    def _load_channel_values(
        self,
        *,
        thread_id: str,
        checkpoint_ns: str,
        channel_versions: Mapping[str, Any],
    ) -> dict[str, Any]:
        channel_values: dict[str, Any] = {}
        if self._blobs is None:
            return channel_values

        for channel, version in channel_versions.items():
            doc_id = _channel_blob_doc_id(checkpoint_ns, channel, version)
            try:
                row = self._blobs.read_item(item=doc_id, partition_key=thread_id)
            except CosmosResourceNotFoundError:
                continue
            if row.get("value_type") == "empty":
                continue
            channel_values[channel] = self._deserialize_typed(
                thread_id=thread_id,
                payload_type=row["value_type"],
                payload=row["value_payload"],
            )
        return channel_values

    def _serialize_typed(
        self,
        *,
        value: Any,
        thread_id: str,
        payload_scope: str,
    ) -> tuple[str, dict[str, Any], str]:
        payload_type, raw = self.serde.dumps_typed(value)
        payload, raw_hash = self._encode_payload(
            raw=raw,
            thread_id=thread_id,
            payload_scope=payload_scope,
        )
        return payload_type, payload, raw_hash

    def _deserialize_typed(
        self,
        *,
        thread_id: str,
        payload_type: str,
        payload: Mapping[str, Any],
    ) -> Any:
        raw = self._decode_payload(thread_id=thread_id, payload=payload)
        return self.serde.loads_typed((payload_type, raw))

    def _encode_payload(
        self,
        *,
        raw: bytes,
        thread_id: str,
        payload_scope: str,
    ) -> tuple[dict[str, Any], str]:
        raw_hash = hashlib.sha256(raw).hexdigest()
        data = raw
        compressed = False
        if len(raw) >= self._settings.compress_above_bytes:
            compressed_candidate = zlib.compress(raw)
            if len(compressed_candidate) < len(raw):
                data = compressed_candidate
                compressed = True

        if len(data) <= self._settings.inline_payload_limit_bytes:
            return (
                {
                    "mode": "inline",
                    "compressed": compressed,
                    "raw_hash": raw_hash,
                    "size_bytes": len(raw),
                    "data": _b64encode(data),
                },
                raw_hash,
            )

        if self._blobs is None:
            raise ValueError(
                "Payload exceeds inline payload limit and no blob container is configured."
            )

        chunk_ids: list[str] = []
        for idx in range(0, len(data), self._settings.chunk_size_bytes):
            chunk = data[idx : idx + self._settings.chunk_size_bytes]
            chunk_id = _chunk_doc_id(payload_scope, idx // self._settings.chunk_size_bytes)
            chunk_doc = {
                "id": chunk_id,
                "doc_type": "payload_chunk",
                "thread_id": thread_id,
                "scope_hash": _hash_token(payload_scope),
                "chunk_index": idx // self._settings.chunk_size_bytes,
                "payload": _b64encode(chunk),
                "payload_hash": hashlib.sha256(chunk).hexdigest(),
            }
            _assert_item_size(chunk_doc, scope=f"chunk:{payload_scope}")
            try:
                self._blobs.create_item(chunk_doc)
            except CosmosResourceExistsError:
                existing = self._blobs.read_item(item=chunk_id, partition_key=thread_id)
                if not _chunk_doc_equivalent(existing, chunk_doc):
                    raise ValueError("Payload chunk conflict for deterministic chunk ID.") from None
            chunk_ids.append(chunk_id)

        return (
            {
                "mode": "chunks",
                "compressed": compressed,
                "raw_hash": raw_hash,
                "size_bytes": len(raw),
                "chunks": chunk_ids,
            },
            raw_hash,
        )
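
    # The decode path below mirrors _encode_payload: it reassembles "inline" or
    # "chunks" envelopes, decompresses when flagged, and verifies the stored
    # sha256 of the original bytes before handing them back to the serde.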

    def _decode_payload(self, *, thread_id: str, payload: Mapping[str, Any]) -> bytes:
        mode = payload.get("mode")
        if mode == "empty":
            return b""
        if mode == "inline":
            raw = _b64decode(payload["data"])
        elif mode == "chunks":
            if self._blobs is None:
                raise ValueError("Blob container is required to read chunked payloads.")
            chunks = payload.get("chunks", [])
            pieces: list[bytes] = []
            for chunk_id in chunks:
                row = self._blobs.read_item(item=chunk_id, partition_key=thread_id)
                pieces.append(_b64decode(row["payload"]))
            raw = b"".join(pieces)
        else:
            raise ValueError(f"Unknown payload mode: {mode}")

        if payload.get("compressed"):
            raw = zlib.decompress(raw)
        if payload.get("raw_hash") and hashlib.sha256(raw).hexdigest() != payload["raw_hash"]:
            raise ValueError("Payload hash mismatch detected.")
        return raw


class AsyncCosmosDBSaver(BaseCheckpointSaver[str], AbstractAsyncContextManager, AbstractContextManager):
    def __init__(self, saver: CosmosDBSaver) -> None:
        super().__init__(serde=saver.serde)
        self._saver = saver

    @classmethod
    def from_connection_string(
        cls,
        connection_string: str,
        *,
        settings: CosmosDBSettings | None = None,
        credential: str | TokenCredential | Mapping[str, Any] | None = None,
        serde: SerializerProtocol | None = None,
    ) -> "AsyncCosmosDBSaver":
        return cls(
            CosmosDBSaver.from_connection_string(
                connection_string,
                settings=settings,
                credential=credential,
                serde=serde,
            )
        )

    @classmethod
    def from_endpoint(
        cls,
        endpoint: str,
        *,
        credential: str | TokenCredential | Mapping[str, Any] | None = None,
        settings: CosmosDBSettings | None = None,
        serde: SerializerProtocol | None = None,
    ) -> "AsyncCosmosDBSaver":
        return cls(
            CosmosDBSaver.from_endpoint(
                endpoint,
                credential=credential,
                settings=settings,
                serde=serde,
            )
        )

    def __enter__(self) -> "AsyncCosmosDBSaver":
        return self

    def __exit__(self, exc_type, exc, tb) -> bool | None:
        self._saver.close()
        return None

    async def __aenter__(self) -> "AsyncCosmosDBSaver":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> bool | None:
        await asyncio.to_thread(self._saver.close)
        return None

    def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        return self._saver.get_tuple(config)

    def list(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int | None = None,
    ) -> Iterator[CheckpointTuple]:
        return self._saver.list(config, filter=filter, before=before, limit=limit)

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        return self._saver.put(config, checkpoint, metadata, new_versions)

    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        self._saver.put_writes(config, writes, task_id, task_path)

    def delete_thread(self, thread_id: str) -> None:
        self._saver.delete_thread(thread_id)

    async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        return await asyncio.to_thread(self._saver.get_tuple, config)

    async def alist(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int | None = None,
    ) -> AsyncIterator[CheckpointTuple]:
        rows = await asyncio.to_thread(
            lambda: list(self._saver.list(config, filter=filter, before=before, limit=limit))
        )
        for row in rows:
            yield row

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        return await asyncio.to_thread(
            self._saver.put, config, checkpoint, metadata, new_versions
        )

    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        await asyncio.to_thread(self._saver.put_writes, config, writes, task_id, task_path)

    async def adelete_thread(self, thread_id: str) -> None:
        await asyncio.to_thread(self._saver.delete_thread, thread_id)


def _client_kwargs(settings: CosmosDBSettings) -> dict[str, Any]:
    kwargs: dict[str, Any] = {
        "no_response_on_write": settings.no_response_on_write,
    }
    if settings.retry_total is not None:
        kwargs["retry_total"] = settings.retry_total
    if settings.retry_backoff_max is not None:
        kwargs["retry_backoff_max"] = settings.retry_backoff_max
    if settings.retry_fixed_interval is not None:
        kwargs["retry_fixed_interval"] = settings.retry_fixed_interval
    return kwargs


def _resolve_containers_sync(
    *,
    client: CosmosClient,
    settings: CosmosDBSettings,
) -> tuple[ContainerProxy, ContainerProxy, ContainerProxy | None]:
    database: DatabaseProxy
    if settings.create_containers:
        database = client.create_database_if_not_exists(id=settings.database_name)
    else:
        database = client.get_database_client(settings.database_name)

    if settings.create_containers:
        checkpoints = database.create_container_if_not_exists(
            id=settings.checkpoint_container_name,
            partition_key=PartitionKey(path="/thread_id"),
            indexing_policy=_checkpoint_indexing_policy(),
        )
        writes = database.create_container_if_not_exists(
            id=settings.writes_container_name,
            partition_key=PartitionKey(path="/thread_id"),
            indexing_policy=_writes_indexing_policy(),
        )
        blobs = database.create_container_if_not_exists(
            id=settings.blob_container_name,
            partition_key=PartitionKey(path="/thread_id"),
            indexing_policy=_blob_indexing_policy(),
        )
    else:
        checkpoints = database.get_container_client(settings.checkpoint_container_name)
        writes = database.get_container_client(settings.writes_container_name)
        blobs = (
            database.get_container_client(settings.blob_container_name)
            if settings.blob_container_name
            else None
        )
    return checkpoints, writes, blobs


def _normalize_config(config: RunnableConfig) -> tuple[str, str, str | None]:
    configurable = config.get("configurable") or {}
    if "thread_ts" in configurable:
        raise ValueError(
            "Legacy config key thread_ts is not supported. Use thread_id/checkpoint_ns/checkpoint_id."
        )
    thread_id = configurable.get("thread_id")
    if not thread_id:
        raise ValueError("config.configurable.thread_id is required.")
    checkpoint_ns = configurable.get("checkpoint_ns", "")
    checkpoint_id = get_checkpoint_id(config)
    return str(thread_id), str(checkpoint_ns), str(checkpoint_id) if checkpoint_id else None


def _checkpoint_doc_equivalent(existing: Mapping[str, Any], new_doc: Mapping[str, Any]) -> bool:
    return (
        existing.get("checkpoint_hash") == new_doc.get("checkpoint_hash")
        and existing.get("metadata_hash") == new_doc.get("metadata_hash")
        and existing.get("parent_checkpoint_id") == new_doc.get("parent_checkpoint_id")
    )


def _write_doc_equivalent(existing: Mapping[str, Any], new_doc: Mapping[str, Any]) -> bool:
    return (
        existing.get("channel") == new_doc.get("channel")
        and existing.get("value_hash") == new_doc.get("value_hash")
        and existing.get("task_id") == new_doc.get("task_id")
    )


def _channel_blob_equivalent(existing: Mapping[str, Any], new_doc: Mapping[str, Any]) -> bool:
    return (
        existing.get("value_type") == new_doc.get("value_type")
        and existing.get("value_hash") == new_doc.get("value_hash")
    )


def _chunk_doc_equivalent(existing: Mapping[str, Any], new_doc: Mapping[str, Any]) -> bool:
    return existing.get("payload_hash") == new_doc.get("payload_hash")


def _checkpoint_doc_id(checkpoint_ns: str, checkpoint_id: str) -> str:
    return f"ckpt|{_ns_token(checkpoint_ns)}|{checkpoint_id}"


def _write_doc_id(
    *,
    checkpoint_ns: str,
    checkpoint_id: str,
    task_id: str,
    idx: int,
) -> str:
    return f"w|{_ns_token(checkpoint_ns)}|{checkpoint_id}|{_hash_token(task_id)}|{idx}"


def _channel_blob_doc_id(
    checkpoint_ns: str,
    channel: str,
    version: str | int | float,
) -> str:
    return (
        f"b|{_ns_token(checkpoint_ns)}|{_hash_token(channel)}|"
        f"{_hash_token(_stringify_version(version))}"
    )


def _chunk_doc_id(scope: str, chunk_idx: int) -> str:
    return f"chunk|{_hash_token(scope)}|{chunk_idx:06d}"


def _ns_token(checkpoint_ns: str) -> str:
    return "root" if not checkpoint_ns else _hash_token(checkpoint_ns)


def _hash_token(value: str) -> str:
    return hashlib.sha256(value.encode("utf-8")).hexdigest()[:16]


def _stringify_version(version: str | int | float) -> str:
    if isinstance(version, str):
        return version
    return json.dumps(version, sort_keys=True, separators=(",", ":"))


def _assert_item_size(item: Mapping[str, Any], *, scope: str) -> None:
    item_bytes = len(
        json.dumps(item, separators=(",", ":"), sort_keys=True, default=str).encode("utf-8")
    )
    if item_bytes > MAX_COSMOS_ITEM_BYTES:
        raise ValueError(
            f"Cosmos item for {scope} exceeds 2MB limit ({item_bytes} bytes)."
        )


def _b64encode(raw: bytes) -> str:
    return base64.b64encode(raw).decode("ascii")


def _b64decode(data: str) -> bytes:
    return base64.b64decode(data.encode("ascii"))


def _checkpoint_indexing_policy() -> dict[str, Any]:
    return {
        "indexingMode": "consistent",
        "automatic": True,
        "includedPaths": [{"path": "/*"}],
        "excludedPaths": [
            {"path": "/checkpoint_payload/*"},
            {"path": "/metadata_payload/*"},
        ],
    }


def _writes_indexing_policy() -> dict[str, Any]:
    return {
        "indexingMode": "consistent",
        "automatic": True,
        "includedPaths": [{"path": "/*"}],
        "excludedPaths": [{"path": "/value_payload/*"}],
    }


def _blob_indexing_policy() -> dict[str, Any]:
    return {
        "indexingMode": "consistent",
        "automatic": True,
        "includedPaths": [
            {"path": "/thread_id/?"},
            {"path": "/doc_type/?"},
            {"path": "/scope_hash/?"},
            {"path": "/chunk_index/?"},
            {"path": "/channel/?"},
            {"path": "/version/?"},
            {"path": "/id/?"},
        ],
        "excludedPaths": [{"path": "/*"}],
    }
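
The module above behaves as a drop-in LangGraph checkpointer. A minimal usage sketch follows, assuming the dependencies from the METADATA below are installed; the connection string, thread ID, and toy graph are illustrative placeholders, not part of the package:

from typing import TypedDict

from langgraph.graph import START, StateGraph

from langgraph_cosmosdb_checkpointer.cosmosdb import CosmosDBSaver


class State(TypedDict):
    value: int


def increment(state: State) -> State:
    # Trivial node so every invocation produces a new checkpoint.
    return {"value": state["value"] + 1}


builder = StateGraph(State)
builder.add_node("increment", increment)
builder.add_edge(START, "increment")

# Placeholder connection string; substitute a real Cosmos DB account and key.
CONN = "AccountEndpoint=https://<account>.documents.azure.com:443/;AccountKey=<key>;"

with CosmosDBSaver.from_connection_string(CONN) as checkpointer:
    graph = builder.compile(checkpointer=checkpointer)
    config = {"configurable": {"thread_id": "demo-thread"}}
    graph.invoke({"value": 0}, config)
    latest = checkpointer.get_tuple(config)  # latest checkpoint for the thread
    if latest is not None:
        print(latest.checkpoint["channel_values"])
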
langgraph_cosmosdb_checkpointer-0.0.0.dist-info/METADATA
@@ -0,0 +1,18 @@
Metadata-Version: 2.4
Name: langgraph-cosmosdb-checkpointer
Version: 0.0.0
Summary: State-of-the-art Azure Cosmos DB checkpointer for LangGraph
Requires-Python: >=3.12
Description-Content-Type: text/markdown
Requires-Dist: azure-cosmos>=4.14.6
Requires-Dist: azure-identity>=1.25.2
Requires-Dist: langchain>=1.2.10
Requires-Dist: langgraph>=1.0.9
Provides-Extra: dev
Requires-Dist: pytest>=9.0.2; extra == "dev"
Requires-Dist: pytest-asyncio>=1.3.0; extra == "dev"
Requires-Dist: ruff>=0.9.0; extra == "dev"

# langgraph-cosmosdb-checkpointer

State-of-the-art Azure Cosmos DB checkpointer for LangGraph.
langgraph_cosmosdb_checkpointer-0.0.0.dist-info/RECORD
@@ -0,0 +1,6 @@
langgraph_cosmosdb_checkpointer/__init__.py,sha256=8SkbUhhy3suSteK4Ha_7J-O48Tc7RMJQ49N7DrPVveo,177
langgraph_cosmosdb_checkpointer/cosmosdb.py,sha256=kBZ3Qklg10MBlmY1WBtSCjB-bLPy3LBanHsoCBSbbX0,37835
langgraph_cosmosdb_checkpointer-0.0.0.dist-info/METADATA,sha256=cIXspDHPg3XV2qyPYLDGhWJGPs8VxpGJKA-MhdqUmIw,605
langgraph_cosmosdb_checkpointer-0.0.0.dist-info/WHEEL,sha256=YCfwYGOYMi5Jhw2fU4yNgwErybb2IX5PEwBKV4ZbdBo,91
langgraph_cosmosdb_checkpointer-0.0.0.dist-info/top_level.txt,sha256=JF8lejbl_uxDEJyPoTnQQwi7JWWXdd2vdLvmPhE-GyQ,32
langgraph_cosmosdb_checkpointer-0.0.0.dist-info/RECORD,,
langgraph_cosmosdb_checkpointer-0.0.0.dist-info/top_level.txt
@@ -0,0 +1 @@
langgraph_cosmosdb_checkpointer