amsdal_langgraph 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amsdal_langgraph/Third-Party Materials - AMSDAL Dependencies - License Notices.md +30 -0
- amsdal_langgraph/__about__.py +1 -0
- amsdal_langgraph/__init__.py +0 -0
- amsdal_langgraph/app.py +8 -0
- amsdal_langgraph/checkpoint.py +616 -0
- amsdal_langgraph/migrations/0000_initial.py +66 -0
- amsdal_langgraph/models/__init__.py +0 -0
- amsdal_langgraph/models/checkpoint.py +24 -0
- amsdal_langgraph/models/checkpoint_writes.py +24 -0
- amsdal_langgraph/py.typed +0 -0
- amsdal_langgraph/utils.py +67 -0
- amsdal_langgraph-0.1.0.dist-info/METADATA +292 -0
- amsdal_langgraph-0.1.0.dist-info/RECORD +14 -0
- amsdal_langgraph-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# Third-Party Materials - AMSDAL Dependencies - License Notices
|
|
2
|
+
|
|
3
|
+
## **LangGraph v0.6.8 or later**
|
|
4
|
+
|
|
5
|
+
### [https://github.com/langchain-ai/langgraph](https://github.com/langchain-ai/langgraph)
|
|
6
|
+
|
|
7
|
+
### **MIT License**
|
|
8
|
+
|
|
9
|
+
MIT License
|
|
10
|
+
|
|
11
|
+
Copyright (c) 2024 LangChain, Inc.
|
|
12
|
+
|
|
13
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
14
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
15
|
+
in the Software without restriction, including without limitation the rights
|
|
16
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
17
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
18
|
+
furnished to do so, subject to the following conditions:
|
|
19
|
+
|
|
20
|
+
The above copyright notice and this permission notice shall be included in all
|
|
21
|
+
copies or substantial portions of the Software.
|
|
22
|
+
|
|
23
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
24
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
25
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
26
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
27
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
28
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
29
|
+
SOFTWARE.
|
|
30
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Single source of truth for the package version (read by packaging metadata).
__version__ = '0.1.0'
|
|
File without changes
|
amsdal_langgraph/app.py
ADDED
|
@@ -0,0 +1,616 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from collections.abc import AsyncIterator
|
|
3
|
+
from collections.abc import Iterator
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from typing import Any
|
|
6
|
+
from typing import cast
|
|
7
|
+
|
|
8
|
+
from langchain_core.runnables import RunnableConfig
|
|
9
|
+
from langgraph.checkpoint.base import WRITES_IDX_MAP
|
|
10
|
+
from langgraph.checkpoint.base import BaseCheckpointSaver
|
|
11
|
+
from langgraph.checkpoint.base import ChannelVersions
|
|
12
|
+
from langgraph.checkpoint.base import Checkpoint
|
|
13
|
+
from langgraph.checkpoint.base import CheckpointMetadata
|
|
14
|
+
from langgraph.checkpoint.base import CheckpointTuple
|
|
15
|
+
from langgraph.checkpoint.base import get_checkpoint_id
|
|
16
|
+
from langgraph.checkpoint.serde.base import SerializerProtocol
|
|
17
|
+
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
|
18
|
+
|
|
19
|
+
from .models.checkpoint import Checkpoint as CheckpointModel
|
|
20
|
+
from .models.checkpoint_writes import CheckpointWrites as CheckpointWritesModel
|
|
21
|
+
from .utils import get_checkpoint_metadata
|
|
22
|
+
from .utils import search_where
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AmsdalCheckpointSaver(BaseCheckpointSaver[str]):
    """AMSDAL-based checkpoint saver for LangGraph workflows.

    Persists checkpoints in the ``Checkpoint`` model and pending channel
    writes in the ``CheckpointWrites`` model, providing both sync and async
    variants of the ``BaseCheckpointSaver`` interface.
    """

    def __init__(self, *, serde: SerializerProtocol | None = None):
        """Initialize the AMSDAL checkpoint saver.

        Args:
            serde: Optional serializer protocol for checkpoint serialization.
        """
        super().__init__(serde=serde)
        # NOTE(review): not referenced by any method in this class — presumably
        # kept for parity with other langgraph savers; confirm before removing.
        self.jsonplus_serde = JsonPlusSerializer()

    def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        """Get a checkpoint tuple synchronously.

        Args:
            config: The runnable configuration.

        Returns:
            The checkpoint tuple or None if not found.
        """
        checkpoint_ns = config['configurable'].get('checkpoint_ns', '')

        # Find the specific checkpoint or latest one
        if checkpoint_id := get_checkpoint_id(config):
            checkpoint_obj = (
                CheckpointModel.objects.filter(
                    thread_id=str(config['configurable']['thread_id']),
                    checkpoint_ns=checkpoint_ns,
                    checkpoint_id=checkpoint_id,
                )
                .first()
                .execute()
            )
        else:
            # No explicit id: descending sort on checkpoint_id picks the latest.
            checkpoint_obj = (
                CheckpointModel.objects.filter(
                    thread_id=str(config['configurable']['thread_id']),
                    checkpoint_ns=checkpoint_ns,
                )
                .order_by('-checkpoint_id')
                .first()
                .execute()
            )

        if not checkpoint_obj:
            return None

        (
            thread_id,
            checkpoint_id,
            parent_checkpoint_id,
            _type,
            checkpoint,
            metadata,
        ) = (
            checkpoint_obj.thread_id,
            checkpoint_obj.checkpoint_id,
            checkpoint_obj.parent_checkpoint_id,
            checkpoint_obj.type,
            checkpoint_obj.checkpoint,
            checkpoint_obj.meta,
        )

        # If the caller did not pin a checkpoint id, rewrite the config so the
        # returned tuple (and the writes query below) reference the row found.
        if not get_checkpoint_id(config):
            config = {
                'configurable': {
                    'thread_id': thread_id,
                    'checkpoint_ns': checkpoint_ns,
                    'checkpoint_id': checkpoint_id,
                }
            }

        # Get pending writes for this checkpoint
        writes = (
            CheckpointWritesModel.objects.filter(
                thread_id=str(config['configurable']['thread_id']),
                checkpoint_ns=checkpoint_ns,
                checkpoint_id=str(config['configurable']['checkpoint_id']),
            )
            .order_by('task_id', 'idx')
            .execute()
        )

        # deserialize the checkpoint and metadata
        return CheckpointTuple(
            config,
            self.serde.loads_typed((_type, checkpoint)),
            cast(
                CheckpointMetadata,
                metadata if metadata is not None else {},
            ),
            (
                {
                    'configurable': {
                        'thread_id': thread_id,
                        'checkpoint_ns': checkpoint_ns,
                        'checkpoint_id': parent_checkpoint_id,
                    }
                }
                if parent_checkpoint_id
                else None
            ),
            [(write.task_id, write.channel, self.serde.loads_typed((write.type, write.value))) for write in writes],
        )

    def list(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,  # noqa: A002, ARG002
        before: RunnableConfig | None = None,
        limit: int | None = None,
    ) -> Iterator[CheckpointTuple]:
        """List checkpoints synchronously.

        Args:
            config: The runnable configuration.
            filter: Optional filter criteria.
            before: Optional before configuration.
            limit: Optional limit on results.

        Yields:
            Checkpoint tuples.
        """
        # Newest first (checkpoint ids are assumed lexicographically ordered
        # by creation — TODO confirm against the id scheme used by callers).
        query = CheckpointModel.objects.all().order_by('-checkpoint_id')
        where = search_where(config, filter, before)

        if where:
            query = query.filter(where)

        if limit:
            checkpoints = query[0:limit].execute()
        else:
            checkpoints = query.execute()

        for checkpoint_obj in checkpoints:
            # N+1 query pattern: one writes lookup per checkpoint row.
            writes = (
                CheckpointWritesModel.objects.filter(
                    thread_id=checkpoint_obj.thread_id,
                    checkpoint_ns=checkpoint_obj.checkpoint_ns,
                    checkpoint_id=checkpoint_obj.checkpoint_id,
                )
                .order_by('task_id', 'idx')
                .execute()
            )

            yield CheckpointTuple(
                {
                    'configurable': {
                        'thread_id': checkpoint_obj.thread_id,
                        'checkpoint_ns': checkpoint_obj.checkpoint_ns,
                        'checkpoint_id': checkpoint_obj.checkpoint_id,
                    }
                },
                self.serde.loads_typed((checkpoint_obj.type, checkpoint_obj.checkpoint)),
                cast(
                    CheckpointMetadata,
                    checkpoint_obj.meta if checkpoint_obj.meta is not None else {},
                ),
                (
                    {
                        'configurable': {
                            'thread_id': checkpoint_obj.thread_id,
                            'checkpoint_ns': checkpoint_obj.checkpoint_ns,
                            'checkpoint_id': checkpoint_obj.parent_checkpoint_id,
                        }
                    }
                    if checkpoint_obj.parent_checkpoint_id
                    else None
                ),
                [(write.task_id, write.channel, self.serde.loads_typed((write.type, write.value))) for write in writes],
            )

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,  # noqa: ARG002
    ) -> RunnableConfig:
        """Put a checkpoint synchronously.

        Args:
            config: The runnable configuration.
            checkpoint: The checkpoint data.
            metadata: The checkpoint metadata.
            new_versions: The new channel versions.

        Returns:
            The updated runnable configuration.
        """
        thread_id = config['configurable']['thread_id']
        checkpoint_ns = config['configurable']['checkpoint_ns']
        type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint)

        # Create or update checkpoint
        # The incoming config's checkpoint_id (if any) is the *parent* of the
        # checkpoint being stored.
        checkpoint_obj = CheckpointModel(
            thread_id=str(config['configurable']['thread_id']),
            checkpoint_ns=checkpoint_ns,
            checkpoint_id=checkpoint['id'],
            parent_checkpoint_id=config['configurable'].get('checkpoint_id'),
            type=type_,
            checkpoint=serialized_checkpoint,
            meta=get_checkpoint_metadata(config, metadata),  # type: ignore[arg-type]
        )

        # NOTE(review): presumably Model.save() updates the row keyed by the
        # composite primary key, while force_insert creates it — confirm
        # against AMSDAL semantics. Check-then-write is not atomic, so a
        # concurrent writer could race here.
        if (
            CheckpointModel.objects.filter(
                thread_id=str(config['configurable']['thread_id']),
                checkpoint_ns=checkpoint_ns,
                checkpoint_id=checkpoint['id'],
            )
            .count()
            .execute()
        ):
            checkpoint_obj.save()
        else:
            checkpoint_obj.save(force_insert=True)

        return {
            'configurable': {
                'thread_id': thread_id,
                'checkpoint_ns': checkpoint_ns,
                'checkpoint_id': checkpoint['id'],
            }
        }

    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = '',  # noqa: ARG002
    ) -> None:
        """Put writes synchronously.

        Args:
            config: The runnable configuration.
            writes: The writes to store.
            task_id: The task ID.
            task_path: The task path.
        """
        # Special channels (those in WRITES_IDX_MAP) may overwrite an existing
        # row; batches containing any regular write are insert-only when the
        # row already exists.
        _replace = all(w[0] in WRITES_IDX_MAP for w in writes)

        for idx, (channel, value) in enumerate(writes):
            type_name, serialized_value = self.serde.dumps_typed(value)
            write_obj = CheckpointWritesModel(
                thread_id=str(config['configurable']['thread_id']),
                checkpoint_ns=str(config['configurable']['checkpoint_ns']),
                checkpoint_id=str(config['configurable']['checkpoint_id']),
                task_id=task_id,
                # Special channels get a fixed (negative) slot; others use the
                # enumeration position.
                idx=WRITES_IDX_MAP.get(channel, idx),
                channel=channel,
                type=type_name,
                value=serialized_value,
            )

            if (
                CheckpointWritesModel.objects.filter(
                    thread_id=str(config['configurable']['thread_id']),
                    checkpoint_ns=str(config['configurable']['checkpoint_ns']),
                    checkpoint_id=str(config['configurable']['checkpoint_id']),
                    task_id=task_id,
                    idx=WRITES_IDX_MAP.get(channel, idx),
                )
                .count()
                .execute()
            ):
                # Row exists: replace only for special-channel batches;
                # otherwise leave the existing write untouched.
                if _replace:
                    write_obj.save()
            else:
                write_obj.save(force_insert=True)

    def delete_thread(self, thread_id: str) -> None:
        """Delete a thread synchronously.

        Args:
            thread_id: The thread ID to delete.
        """
        # Delete all checkpoints for this thread
        checkpoints = CheckpointModel.objects.filter(thread_id=str(thread_id)).execute()

        for checkpoint in checkpoints:
            checkpoint.delete()

        # Delete all writes for this thread
        writes = CheckpointWritesModel.objects.filter(thread_id=str(thread_id)).execute()

        for write in writes:
            write.delete()

    async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        """Get a checkpoint tuple asynchronously.

        Async mirror of :meth:`get_tuple`.

        Args:
            config: The runnable configuration.

        Returns:
            The checkpoint tuple or None if not found.
        """
        checkpoint_ns = config['configurable'].get('checkpoint_ns', '')

        # Find the specific checkpoint or latest one
        if checkpoint_id := get_checkpoint_id(config):
            checkpoint_obj = (
                await CheckpointModel.objects.filter(
                    thread_id=str(config['configurable']['thread_id']),
                    checkpoint_ns=checkpoint_ns,
                    checkpoint_id=checkpoint_id,
                )
                .first()
                .aexecute()
            )
        else:
            checkpoint_obj = await (
                CheckpointModel.objects.filter(
                    thread_id=str(config['configurable']['thread_id']),
                    checkpoint_ns=checkpoint_ns,
                )
                .order_by('-checkpoint_id')
                .first()
                .aexecute()
            )

        if not checkpoint_obj:
            return None

        (
            thread_id,
            checkpoint_id,
            parent_checkpoint_id,
            _type,
            checkpoint,
            metadata,
        ) = (
            checkpoint_obj.thread_id,
            checkpoint_obj.checkpoint_id,
            checkpoint_obj.parent_checkpoint_id,
            checkpoint_obj.type,
            checkpoint_obj.checkpoint,
            checkpoint_obj.meta,
        )

        # Pin the resolved checkpoint id into the config (see get_tuple).
        if not get_checkpoint_id(config):
            config = {
                'configurable': {
                    'thread_id': thread_id,
                    'checkpoint_ns': checkpoint_ns,
                    'checkpoint_id': checkpoint_id,
                }
            }

        # Get pending writes for this checkpoint
        writes = await (
            CheckpointWritesModel.objects.filter(
                thread_id=str(config['configurable']['thread_id']),
                checkpoint_ns=checkpoint_ns,
                checkpoint_id=str(config['configurable']['checkpoint_id']),
            )
            .order_by('task_id', 'idx')
            .aexecute()
        )

        # deserialize the checkpoint and metadata
        return CheckpointTuple(
            config,
            self.serde.loads_typed((_type, checkpoint)),
            cast(
                CheckpointMetadata,
                metadata if metadata is not None else {},
            ),
            (
                {
                    'configurable': {
                        'thread_id': thread_id,
                        'checkpoint_ns': checkpoint_ns,
                        'checkpoint_id': parent_checkpoint_id,
                    }
                }
                if parent_checkpoint_id
                else None
            ),
            [(write.task_id, write.channel, self.serde.loads_typed((write.type, write.value))) for write in writes],
        )

    async def alist(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,  # noqa: A002, ARG002
        before: RunnableConfig | None = None,
        limit: int | None = None,
    ) -> AsyncIterator[CheckpointTuple]:
        """List checkpoints asynchronously.

        Async mirror of :meth:`list`.

        Args:
            config: The runnable configuration.
            filter: Optional filter criteria.
            before: Optional before configuration.
            limit: Optional limit on results.

        Yields:
            Checkpoint tuples.
        """
        query = CheckpointModel.objects.all().order_by('-checkpoint_id')
        where = search_where(config, filter, before)

        if where:
            query = query.filter(where)

        if limit:
            checkpoints = await query[:limit].aexecute()
        else:
            checkpoints = await query.aexecute()

        for checkpoint_obj in checkpoints:
            writes = await (
                CheckpointWritesModel.objects.filter(
                    thread_id=checkpoint_obj.thread_id,
                    checkpoint_ns=checkpoint_obj.checkpoint_ns,
                    checkpoint_id=checkpoint_obj.checkpoint_id,
                )
                .order_by('task_id', 'idx')
                .aexecute()
            )

            yield CheckpointTuple(
                {
                    'configurable': {
                        'thread_id': checkpoint_obj.thread_id,
                        'checkpoint_ns': checkpoint_obj.checkpoint_ns,
                        'checkpoint_id': checkpoint_obj.checkpoint_id,
                    }
                },
                self.serde.loads_typed((checkpoint_obj.type, checkpoint_obj.checkpoint)),
                cast(
                    CheckpointMetadata,
                    checkpoint_obj.meta if checkpoint_obj.meta is not None else {},
                ),
                (
                    {
                        'configurable': {
                            'thread_id': checkpoint_obj.thread_id,
                            'checkpoint_ns': checkpoint_obj.checkpoint_ns,
                            'checkpoint_id': checkpoint_obj.parent_checkpoint_id,
                        }
                    }
                    if checkpoint_obj.parent_checkpoint_id
                    else None
                ),
                [(write.task_id, write.channel, self.serde.loads_typed((write.type, write.value))) for write in writes],
            )

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,  # noqa: ARG002
    ) -> RunnableConfig:
        """Put a checkpoint asynchronously.

        Async mirror of :meth:`put`.

        Args:
            config: The runnable configuration.
            checkpoint: The checkpoint data.
            metadata: The checkpoint metadata.
            new_versions: The new channel versions.

        Returns:
            The updated runnable configuration.
        """
        thread_id = config['configurable']['thread_id']
        checkpoint_ns = config['configurable']['checkpoint_ns']
        type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint)

        # Create or update checkpoint
        checkpoint_obj = CheckpointModel(
            thread_id=str(config['configurable']['thread_id']),
            checkpoint_ns=checkpoint_ns,
            checkpoint_id=checkpoint['id'],
            parent_checkpoint_id=config['configurable'].get('checkpoint_id'),
            type=type_,
            checkpoint=serialized_checkpoint,
            meta=get_checkpoint_metadata(config, metadata),  # type: ignore[arg-type]
        )

        # Same non-atomic check-then-write as in put() — see note there.
        if (
            await CheckpointModel.objects.filter(
                thread_id=str(config['configurable']['thread_id']),
                checkpoint_ns=checkpoint_ns,
                checkpoint_id=checkpoint['id'],
            )
            .count()
            .aexecute()
        ):
            await checkpoint_obj.asave()  # type: ignore[misc]
        else:
            await checkpoint_obj.asave(force_insert=True)  # type: ignore[misc]

        return {
            'configurable': {
                'thread_id': thread_id,
                'checkpoint_ns': checkpoint_ns,
                'checkpoint_id': checkpoint['id'],
            }
        }

    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = '',  # noqa: ARG002
    ) -> None:
        """Put writes asynchronously.

        Async mirror of :meth:`put_writes`.

        Args:
            config: The runnable configuration.
            writes: The writes to store.
            task_id: The task ID.
            task_path: The task path.
        """
        _replace = all(w[0] in WRITES_IDX_MAP for w in writes)

        for idx, (channel, value) in enumerate(writes):
            type_name, serialized_value = self.serde.dumps_typed(value)
            write_obj = CheckpointWritesModel(
                thread_id=str(config['configurable']['thread_id']),
                checkpoint_ns=str(config['configurable']['checkpoint_ns']),
                checkpoint_id=str(config['configurable']['checkpoint_id']),
                task_id=task_id,
                idx=WRITES_IDX_MAP.get(channel, idx),
                channel=channel,
                type=type_name,
                value=serialized_value,
            )

            if await (
                CheckpointWritesModel.objects.filter(
                    thread_id=str(config['configurable']['thread_id']),
                    checkpoint_ns=str(config['configurable']['checkpoint_ns']),
                    checkpoint_id=str(config['configurable']['checkpoint_id']),
                    task_id=task_id,
                    idx=WRITES_IDX_MAP.get(channel, idx),
                )
                .count()
                .aexecute()
            ):
                if _replace:
                    await write_obj.asave()  # type: ignore[misc]
            else:
                await write_obj.asave(force_insert=True)  # type: ignore[misc]

    async def adelete_thread(self, thread_id: str) -> None:
        """Delete a thread asynchronously.

        Args:
            thread_id: The thread ID to delete.
        """
        # Delete all checkpoints for this thread
        checkpoints = await CheckpointModel.objects.filter(thread_id=str(thread_id)).aexecute()

        for checkpoint in checkpoints:
            await checkpoint.adelete()

        # Delete all writes for this thread
        writes = await CheckpointWritesModel.objects.filter(thread_id=str(thread_id)).aexecute()

        for write in writes:
            await write.adelete()

    def get_next_version(self, current: str | None, channel: None) -> str:  # noqa: ARG002
        """Generate the next version ID for a channel.

        This method creates a new version identifier for a channel based on its current version.

        Args:
            current (Optional[str]): The current version identifier of the channel.

        Returns:
            str: The next version identifier, which is guaranteed to be monotonically increasing.
        """
        # Accepts None / int / "int.hash" strings; only the integer prefix
        # carries ordering — the random suffix just makes versions distinct.
        if current is None:
            current_v = 0
        elif isinstance(current, int):
            current_v = current
        else:
            current_v = int(current.split('.')[0])
        next_v = current_v + 1
        next_h = random.random()  # noqa: S311
        # Zero-padded so lexicographic comparison matches numeric ordering.
        return f'{next_v:032}.{next_h:016}'
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from amsdal_models.migration import migrations
|
|
2
|
+
from amsdal_utils.models.enums import ModuleType
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Migration(migrations.Migration):
    """Initial migration: creates the two CONTRIB tables backing the
    LangGraph checkpoint saver (``checkpoints`` and ``checkpoint_writes``).

    The ``custom_code`` strings inject sync/async pre-create/pre-update hooks
    that maintain the ``created_at`` / ``updated_at`` timestamps.
    """

    operations: list[migrations.Operation] = [
        migrations.CreateClass(
            module_type=ModuleType.CONTRIB,
            class_name="Checkpoint",
            new_schema={
                "title": "Checkpoint",
                # NOTE: meta is intentionally not required here.
                "required": ["thread_id", "checkpoint_id", "checkpoint"],
                "properties": {
                    "created_at": {"type": "datetime", "title": "Created At", "format": "date-time"},
                    "updated_at": {"type": "datetime", "title": "Updated At", "format": "date-time"},
                    "thread_id": {"type": "string", "title": "Thread ID"},
                    "checkpoint_ns": {"type": "string", "default": "", "title": "Checkpoint Namespace"},
                    "checkpoint_id": {"type": "string", "title": "Checkpoint ID"},
                    "parent_checkpoint_id": {"type": "string", "title": "Parent Checkpoint ID"},
                    "type": {"type": "string", "title": "Type"},
                    # Serialized checkpoint payload (opaque bytes).
                    "checkpoint": {"type": "binary", "title": "Checkpoint Data"},
                    "meta": {
                        "type": "dictionary",
                        "items": {"key": {"type": "string"}, "value": {"type": "anything"}},
                        "title": "Metadata",
                    },
                },
                "custom_code": "import datetime\n\n\nasync def apre_create(self) -> None:\n    self.created_at = datetime.datetime.now(tz=datetime.UTC)\n    await super().apre_create()\n\nasync def apre_update(self) -> None:\n    self.updated_at = datetime.datetime.now(tz=datetime.UTC)\n    if not self.created_at:\n        _metadata = await self.aget_metadata()\n        self.created_at = datetime.datetime.fromtimestamp(_metadata.created_at / 1000, tz=datetime.UTC)\n    await super().apre_update()\n\ndef pre_create(self) -> None:\n    self.created_at = datetime.datetime.now(tz=datetime.UTC)\n    super().pre_create()\n\ndef pre_update(self) -> None:\n    self.updated_at = datetime.datetime.now(tz=datetime.UTC)\n    if not self.created_at:\n        _metadata = self.get_metadata()\n        self.created_at = datetime.datetime.fromtimestamp(_metadata.created_at / 1000, tz=datetime.UTC)\n    super().pre_update()",
                "storage_metadata": {
                    "table_name": "checkpoints",
                    "db_fields": {},
                    # Composite PK mirrors the langgraph checkpoint identity.
                    "primary_key": ["thread_id", "checkpoint_ns", "checkpoint_id"],
                    "foreign_keys": {},
                },
                "description": "AMSDAL model for storing LangGraph checkpoints.",
            },
        ),
        migrations.CreateClass(
            module_type=ModuleType.CONTRIB,
            class_name="CheckpointWrites",
            new_schema={
                "title": "CheckpointWrites",
                "required": ["thread_id", "checkpoint_id", "task_id", "idx", "channel"],
                "properties": {
                    "created_at": {"type": "datetime", "title": "Created At", "format": "date-time"},
                    "updated_at": {"type": "datetime", "title": "Updated At", "format": "date-time"},
                    "thread_id": {"type": "string", "title": "Thread ID"},
                    "checkpoint_ns": {"type": "string", "default": "", "title": "Checkpoint Namespace"},
                    "checkpoint_id": {"type": "string", "title": "Checkpoint ID"},
                    "task_id": {"type": "string", "title": "Task ID"},
                    "idx": {"type": "integer", "title": "Index"},
                    "channel": {"type": "string", "title": "Channel"},
                    "type": {"type": "string", "title": "Type"},
                    # Serialized write payload (opaque bytes).
                    "value": {"type": "binary", "title": "Value"},
                },
                "custom_code": "import datetime\n\n\nasync def apre_create(self) -> None:\n    self.created_at = datetime.datetime.now(tz=datetime.UTC)\n    await super().apre_create()\n\nasync def apre_update(self) -> None:\n    self.updated_at = datetime.datetime.now(tz=datetime.UTC)\n    if not self.created_at:\n        _metadata = await self.aget_metadata()\n        self.created_at = datetime.datetime.fromtimestamp(_metadata.created_at / 1000, tz=datetime.UTC)\n    await super().apre_update()\n\ndef pre_create(self) -> None:\n    self.created_at = datetime.datetime.now(tz=datetime.UTC)\n    super().pre_create()\n\ndef pre_update(self) -> None:\n    self.updated_at = datetime.datetime.now(tz=datetime.UTC)\n    if not self.created_at:\n        _metadata = self.get_metadata()\n        self.created_at = datetime.datetime.fromtimestamp(_metadata.created_at / 1000, tz=datetime.UTC)\n    super().pre_update()",
                "storage_metadata": {
                    "table_name": "checkpoint_writes",
                    "db_fields": {},
                    "primary_key": ["thread_id", "checkpoint_ns", "checkpoint_id", "task_id", "idx"],
                    "foreign_keys": {},
                },
                "description": "AMSDAL model for storing checkpoint write operations.",
            },
        ),
    ]
|
|
File without changes
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
from typing import ClassVar
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from amsdal.models.mixins import TimestampMixin
|
|
6
|
+
from amsdal_models.classes.model import Model
|
|
7
|
+
from amsdal_utils.models.enums import ModuleType
|
|
8
|
+
from pydantic.fields import Field
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Checkpoint(TimestampMixin, Model):
    """AMSDAL model for storing LangGraph checkpoints.

    One row per ``(thread_id, checkpoint_ns, checkpoint_id)``. The serialized
    checkpoint payload is stored as opaque bytes in ``checkpoint`` together
    with the serializer ``type`` tag; ``meta`` holds the checkpoint metadata
    mapping.
    """

    __module_type__: ClassVar[ModuleType] = ModuleType.CONTRIB
    __table_name__ = 'checkpoints'
    __primary_key__: ClassVar[list[str]] = ['thread_id', 'checkpoint_ns', 'checkpoint_id']

    thread_id: str = Field(..., title='Thread ID')
    checkpoint_ns: str = Field(default='', title='Checkpoint Namespace')
    checkpoint_id: str = Field(..., title='Checkpoint ID')
    parent_checkpoint_id: Optional[str] = Field(default=None, title='Parent Checkpoint ID')
    type: Optional[str] = Field(default=None, title='Type')
    checkpoint: bytes = Field(..., title='Checkpoint Data')
    # Fix: was `dict[str, Any | None] | None = Field(...)` — a *required*
    # field with a nullable type, contradicting the migration schema where
    # only thread_id, checkpoint_id and checkpoint are required. Defaulting to
    # None matches the schema and stays backward-compatible (the saver always
    # passes meta explicitly). `Any | None` also collapses to `Any`.
    meta: dict[str, Any] | None = Field(default=None, title='Metadata')
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from typing import ClassVar
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from amsdal.models.mixins import TimestampMixin
|
|
5
|
+
from amsdal_models.classes.model import Model
|
|
6
|
+
from amsdal_utils.models.enums import ModuleType
|
|
7
|
+
from pydantic.fields import Field
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class CheckpointWrites(TimestampMixin, Model):
    """AMSDAL model persisting pending channel writes for a checkpoint.

    Each row is a single write produced by a task, identified by the owning
    checkpoint (thread, namespace, checkpoint id) plus the task id and the
    write's position within the batch (``idx``). The payload is stored as
    opaque bytes in ``value`` with its serializer tag in ``type``.
    """

    __module_type__: ClassVar[ModuleType] = ModuleType.CONTRIB
    __table_name__ = 'checkpoint_writes'
    __primary_key__: ClassVar[list[str]] = [
        'thread_id',
        'checkpoint_ns',
        'checkpoint_id',
        'task_id',
        'idx',
    ]

    # Identity of the owning checkpoint.
    thread_id: str = Field(..., title='Thread ID')
    checkpoint_ns: str = Field(default='', title='Checkpoint Namespace')
    checkpoint_id: str = Field(..., title='Checkpoint ID')
    # Producer task and position of this write within its batch.
    task_id: str = Field(..., title='Task ID')
    idx: int = Field(..., title='Index')
    channel: str = Field(..., title='Channel')
    # Serializer tag + serialized payload.
    type: str | None = Field(default=None, title='Type')
    value: bytes | None = Field(default=None, title='Value')
|
|
File without changes
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from amsdal_utils.query.utils import Q
|
|
5
|
+
from langchain_core.runnables import RunnableConfig
|
|
6
|
+
from langgraph.checkpoint.base import EXCLUDED_METADATA_KEYS
|
|
7
|
+
from langgraph.checkpoint.base import CheckpointMetadata
|
|
8
|
+
from langgraph.checkpoint.base import get_checkpoint_id
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def search_where(
    config: RunnableConfig | None,
    filter: dict[str, Any] | None,  # noqa: A002
    before: RunnableConfig | None = None,
) -> Optional[Q]:
    """Build a combined ``Q`` predicate for checkpoint queries.

    AND-combines three optional sources of constraints: the runnable
    ``config`` (thread id, namespace, checkpoint id), a metadata ``filter``
    mapping, and a ``before`` config whose checkpoint id bounds results
    from above. Returns ``None`` when no constraint applies.
    """
    where: Optional[Q] = None

    # Constraints derived from the runnable config itself.
    if config is not None:
        configurable = config['configurable']
        where = Q(thread_id=configurable['thread_id'])

        namespace = configurable.get('checkpoint_ns')
        if namespace is not None:
            where &= Q(checkpoint_ns=namespace)

        if checkpoint_id := get_checkpoint_id(config):
            where &= Q(checkpoint_id=checkpoint_id)

    # Collect the remaining clauses (metadata filter first, then `before`),
    # then AND them into the predicate in one pass.
    extra_clauses: list[Q] = []
    if filter:
        extra_clauses.extend(Q(**{f'meta__{query_key}': query_value}) for query_key, query_value in filter.items())
    if before is not None:
        extra_clauses.append(Q(checkpoint_id__lt=get_checkpoint_id(before)))

    for clause in extra_clauses:
        if not where:
            where = clause
        else:
            where &= clause

    return where
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def get_checkpoint_metadata(config: RunnableConfig, metadata: CheckpointMetadata) -> CheckpointMetadata:
    """Get checkpoint metadata in a backwards-compatible manner."""

    def _strip_nul(value: Any) -> Any:
        # NUL characters are rejected by some storage backends; drop them from strings.
        return value.replace('\u0000', '') if isinstance(value, str) else value

    sanitized = {key: _strip_nul(value) for key, value in metadata.items()}

    # Fold in extra scalar entries from the run config, skipping keys that are
    # already present, explicitly excluded, or internal (double-underscore prefixed).
    for source in (config.get('metadata'), config.get('configurable')):
        for key, value in (source or {}).items():
            if key in sanitized or key in EXCLUDED_METADATA_KEYS or key.startswith('__'):
                continue
            if isinstance(value, str):
                sanitized[key] = value.replace('\u0000', '')
            elif isinstance(value, (int, bool, float)):
                sanitized[key] = value

    return sanitized  # type: ignore[return-value]
|
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: amsdal_langgraph
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: amsdal_langgraph plugin for AMSDAL Framework
|
|
5
|
+
Requires-Python: >=3.11
|
|
6
|
+
Requires-Dist: amsdal[cli]>=0.5.14
|
|
7
|
+
Requires-Dist: langgraph>=0.6.8
|
|
8
|
+
Provides-Extra: openai
|
|
9
|
+
Requires-Dist: langchain[openai]; extra == 'openai'
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
|
|
12
|
+
# AMSDAL LangGraph
|
|
13
|
+
|
|
14
|
+
[](https://www.python.org/downloads/)
|
|
15
|
+
[](LICENSE)
|
|
16
|
+
|
|
17
|
+
A LangGraph checkpoint persistence plugin for the [AMSDAL Framework](https://github.com/amsdal/amsdal). This plugin enables persistent and recoverable LangGraph workflows by storing checkpoint state in AMSDAL-managed databases.
|
|
18
|
+
|
|
19
|
+
## Features
|
|
20
|
+
|
|
21
|
+
- **Persistent Workflow State**: Store LangGraph checkpoints in any AMSDAL-supported database (SQLite, PostgreSQL, etc.)
|
|
22
|
+
- **Dual Mode Support**: Both synchronous and asynchronous operations
|
|
23
|
+
- **Thread-based Organization**: Manage multiple workflow threads with checkpoint namespacing
|
|
24
|
+
- **Drop-in Replacement**: Compatible with LangGraph's `BaseCheckpointSaver` interface
|
|
25
|
+
- **Production Ready**: Built on the robust AMSDAL ORM with comprehensive testing
|
|
26
|
+
|
|
27
|
+
## Installation
|
|
28
|
+
|
|
29
|
+
Install via pip:
|
|
30
|
+
|
|
31
|
+
```bash
|
|
32
|
+
pip install amsdal-langgraph
|
|
33
|
+
```
|
|
34
|
+
|
|
35
|
+
Or with optional dependencies:
|
|
36
|
+
|
|
37
|
+
```bash
|
|
38
|
+
# With OpenAI support
|
|
39
|
+
pip install amsdal-langgraph[openai]
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
## Quick Start
|
|
43
|
+
|
|
44
|
+
### Basic Usage
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
from langgraph.graph import StateGraph
|
|
48
|
+
from amsdal_langgraph.checkpoint import AmsdalCheckpointSaver
|
|
49
|
+
|
|
50
|
+
# Initialize the checkpoint saver
|
|
51
|
+
checkpointer = AmsdalCheckpointSaver()
|
|
52
|
+
|
|
53
|
+
# Create your LangGraph workflow
|
|
54
|
+
workflow = StateGraph(...)
|
|
55
|
+
# ... define your workflow nodes and edges ...
|
|
56
|
+
|
|
57
|
+
# Compile with checkpoint support
|
|
58
|
+
app = workflow.compile(checkpointer=checkpointer)
|
|
59
|
+
|
|
60
|
+
# Run with persistence
|
|
61
|
+
config = {'configurable': {'thread_id': 'user-123'}}
|
|
62
|
+
result = app.invoke(input_data, config=config)
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
### Async Usage
|
|
66
|
+
|
|
67
|
+
```python
|
|
68
|
+
from amsdal_langgraph.checkpoint import AmsdalCheckpointSaver
|
|
69
|
+
|
|
70
|
+
# Same checkpointer works for async
|
|
71
|
+
checkpointer = AmsdalCheckpointSaver()
|
|
72
|
+
|
|
73
|
+
# Compile async workflow
|
|
74
|
+
app = workflow.compile(checkpointer=checkpointer)
|
|
75
|
+
|
|
76
|
+
# Run async with persistence
|
|
77
|
+
config = {'configurable': {'thread_id': 'user-123'}}
|
|
78
|
+
result = await app.ainvoke(input_data, config=config)
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
### Advanced Configuration
|
|
82
|
+
|
|
83
|
+
```python
|
|
84
|
+
from langchain_core.runnables import RunnableConfig
|
|
85
|
+
from amsdal_langgraph.checkpoint import AmsdalCheckpointSaver
|
|
86
|
+
|
|
87
|
+
checkpointer = AmsdalCheckpointSaver()
|
|
88
|
+
|
|
89
|
+
# Configuration with checkpoint namespace
|
|
90
|
+
config: RunnableConfig = {
|
|
91
|
+
'configurable': {
|
|
92
|
+
'thread_id': 'conversation-456',
|
|
93
|
+
'checkpoint_ns': 'production', # Optional namespace
|
|
94
|
+
}
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
# Run workflow
|
|
98
|
+
result = app.invoke(input_data, config=config)
|
|
99
|
+
|
|
100
|
+
# Resume from checkpoint
|
|
101
|
+
checkpoint_tuple = checkpointer.get_tuple(config)
|
|
102
|
+
if checkpoint_tuple:
|
|
103
|
+
# Continue from last checkpoint
|
|
104
|
+
result = app.invoke(input_data, config=config)
|
|
105
|
+
```
|
|
106
|
+
|
|
107
|
+
## Architecture
|
|
108
|
+
|
|
109
|
+
### Core Components
|
|
110
|
+
|
|
111
|
+
- **AmsdalCheckpointSaver**: Main class implementing LangGraph's `BaseCheckpointSaver` protocol
|
|
112
|
+
- **Checkpoint Model**: Stores checkpoint snapshots with metadata
|
|
113
|
+
- **CheckpointWrites Model**: Stores pending write operations for each checkpoint
|
|
114
|
+
|
|
115
|
+
### Data Models
|
|
116
|
+
|
|
117
|
+
#### Checkpoint
|
|
118
|
+
Stores workflow state snapshots:
|
|
119
|
+
- `thread_id`: Workflow thread identifier
|
|
120
|
+
- `checkpoint_ns`: Optional namespace for organization
|
|
121
|
+
- `checkpoint_id`: Unique checkpoint identifier
|
|
122
|
+
- `parent_checkpoint_id`: Reference to parent checkpoint
|
|
123
|
+
- `checkpoint`: Serialized checkpoint data
|
|
124
|
+
- `meta`: Checkpoint metadata
|
|
125
|
+
|
|
126
|
+
#### CheckpointWrites
|
|
127
|
+
Stores pending writes associated with checkpoints:
|
|
128
|
+
- `thread_id`, `checkpoint_ns`, `checkpoint_id`: Links to checkpoint
|
|
129
|
+
- `task_id`: Task identifier
|
|
130
|
+
- `idx`: Write operation index
|
|
131
|
+
- `channel`: Channel name
|
|
132
|
+
- `value`: Serialized write value
|
|
133
|
+
|
|
134
|
+
## API Reference
|
|
135
|
+
|
|
136
|
+
### AmsdalCheckpointSaver
|
|
137
|
+
|
|
138
|
+
#### Methods
|
|
139
|
+
|
|
140
|
+
##### Synchronous Methods
|
|
141
|
+
|
|
142
|
+
- `get_tuple(config: RunnableConfig) -> CheckpointTuple | None`
|
|
143
|
+
- Retrieve a checkpoint tuple by configuration
|
|
144
|
+
|
|
145
|
+
- `list(config: RunnableConfig | None, *, filter: dict | None = None, before: RunnableConfig | None = None, limit: int | None = None) -> Iterator[CheckpointTuple]`
|
|
146
|
+
- List checkpoints with optional filtering
|
|
147
|
+
|
|
148
|
+
- `put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions) -> RunnableConfig`
|
|
149
|
+
- Store a new checkpoint
|
|
150
|
+
|
|
151
|
+
- `put_writes(config: RunnableConfig, writes: Sequence[tuple[str, Any]], task_id: str) -> None`
|
|
152
|
+
- Store pending writes for a checkpoint
|
|
153
|
+
|
|
154
|
+
- `delete_thread(thread_id: str) -> None`
|
|
155
|
+
- Delete all checkpoints and writes for a thread
|
|
156
|
+
|
|
157
|
+
##### Asynchronous Methods
|
|
158
|
+
|
|
159
|
+
All synchronous methods have async equivalents prefixed with `a`:
|
|
160
|
+
- `aget_tuple(...)`
|
|
161
|
+
- `alist(...)`
|
|
162
|
+
- `aput(...)`
|
|
163
|
+
- `aput_writes(...)`
|
|
164
|
+
- `adelete_thread(...)`
|
|
165
|
+
|
|
166
|
+
## Configuration
|
|
167
|
+
|
|
168
|
+
### Database Setup
|
|
169
|
+
|
|
170
|
+
AMSDAL LangGraph uses your existing AMSDAL configuration. Ensure you have configured your database connection:
|
|
171
|
+
|
|
172
|
+
```python
|
|
173
|
+
from amsdal.manager import AmsdalManager
|
|
174
|
+
|
|
175
|
+
# Initialize AMSDAL
|
|
176
|
+
manager = AmsdalManager()
|
|
177
|
+
manager.setup()
|
|
178
|
+
|
|
179
|
+
# Now use AmsdalCheckpointSaver
|
|
180
|
+
checkpointer = AmsdalCheckpointSaver()
|
|
181
|
+
```
|
|
182
|
+
|
|
183
|
+
### Migration
|
|
184
|
+
|
|
185
|
+
The plugin includes migration files for creating the required tables. Run migrations before first use:
|
|
186
|
+
|
|
187
|
+
```bash
|
|
188
|
+
amsdal migrate
|
|
189
|
+
```
|
|
190
|
+
|
|
191
|
+
## Development
|
|
192
|
+
|
|
193
|
+
### Setup Development Environment
|
|
194
|
+
|
|
195
|
+
```bash
|
|
196
|
+
# Clone the repository
|
|
197
|
+
git clone https://github.com/amsdal/amsdal-workflow.git
|
|
198
|
+
cd amsdal-workflow
|
|
199
|
+
|
|
200
|
+
# Install dependencies
|
|
201
|
+
hatch run sync
|
|
202
|
+
```
|
|
203
|
+
|
|
204
|
+
### Running Tests
|
|
205
|
+
|
|
206
|
+
```bash
|
|
207
|
+
# Run all tests
|
|
208
|
+
hatch run test
|
|
209
|
+
|
|
210
|
+
# Run with coverage
|
|
211
|
+
hatch run cov
|
|
212
|
+
|
|
213
|
+
# Run specific test file
|
|
214
|
+
hatch run test tests/test_checkpoint.py
|
|
215
|
+
|
|
216
|
+
# Run with verbose output
|
|
217
|
+
hatch run test -v
|
|
218
|
+
```
|
|
219
|
+
|
|
220
|
+
### Code Quality
|
|
221
|
+
|
|
222
|
+
```bash
|
|
223
|
+
# Format code
|
|
224
|
+
hatch run fmt
|
|
225
|
+
|
|
226
|
+
# Check code style
|
|
227
|
+
hatch run style
|
|
228
|
+
|
|
229
|
+
# Run type checking
|
|
230
|
+
hatch run typing
|
|
231
|
+
|
|
232
|
+
# Run all checks
|
|
233
|
+
hatch run all
|
|
234
|
+
```
|
|
235
|
+
|
|
236
|
+
### Project Structure
|
|
237
|
+
|
|
238
|
+
```
|
|
239
|
+
amsdal_langgraph/
|
|
240
|
+
├── amsdal_langgraph/ # Main package
|
|
241
|
+
│ ├── __init__.py
|
|
242
|
+
│ ├── checkpoint.py # AmsdalCheckpointSaver implementation
|
|
243
|
+
│ ├── utils.py # Utility functions
|
|
244
|
+
│ ├── models/ # Data models
|
|
245
|
+
│ │ ├── checkpoint.py # Checkpoint model
|
|
246
|
+
│ │ └── checkpoint_writes.py # CheckpointWrites model
|
|
247
|
+
│ └── migrations/ # Database migrations
|
|
248
|
+
├── tests/ # Test suite
|
|
249
|
+
│ ├── conftest.py # Test fixtures
|
|
250
|
+
│ └── test_checkpoint.py # Checkpoint tests
|
|
251
|
+
├── pyproject.toml # Project configuration
|
|
252
|
+
├── README.md # This file
|
|
253
|
+
```
|
|
254
|
+
|
|
255
|
+
## Contributing
|
|
256
|
+
|
|
257
|
+
Contributions are welcome! Please follow these steps:
|
|
258
|
+
|
|
259
|
+
1. Fork the repository
|
|
260
|
+
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
|
|
261
|
+
3. Make your changes
|
|
262
|
+
4. Run tests and code quality checks
|
|
263
|
+
5. Commit your changes (`git commit -m 'Add amazing feature'`)
|
|
264
|
+
6. Push to the branch (`git push origin feature/amazing-feature`)
|
|
265
|
+
7. Open a Pull Request
|
|
266
|
+
|
|
267
|
+
### Code Standards
|
|
268
|
+
|
|
269
|
+
- Python 3.11+ required
|
|
270
|
+
- Follow PEP 8 style guide (enforced by Ruff)
|
|
271
|
+
- Add type hints to all functions
|
|
272
|
+
- Write tests for new features
|
|
273
|
+
- Maintain test coverage above 90%
|
|
274
|
+
|
|
275
|
+
## License
|
|
276
|
+
|
|
277
|
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
|
278
|
+
|
|
279
|
+
## Acknowledgments
|
|
280
|
+
|
|
281
|
+
- Built on [LangGraph](https://github.com/langchain-ai/langgraph) by LangChain
|
|
282
|
+
- Powered by [AMSDAL Framework](https://github.com/amsdal/amsdal)
|
|
283
|
+
|
|
284
|
+
## Support
|
|
285
|
+
|
|
286
|
+
- **Documentation**: [AMSDAL Docs](https://docs.amsdal.com)
|
|
287
|
+
- **Issues**: [GitHub Issues](https://github.com/amsdal/amsdal-workflow/issues)
|
|
288
|
+
- **Discussions**: [GitHub Discussions](https://github.com/amsdal/amsdal-workflow/discussions)
|
|
289
|
+
|
|
290
|
+
## Changelog
|
|
291
|
+
|
|
292
|
+
See [CHANGELOG.md](CHANGELOG.md) for a list of changes in each release.
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
amsdal_langgraph/Third-Party Materials - AMSDAL Dependencies - License Notices.md,sha256=y5L5esHBNdYnlmq34lyHCSiH_3N7M6SH7TFC92gFZ40,1285
|
|
2
|
+
amsdal_langgraph/__about__.py,sha256=IMjkMO3twhQzluVTo8Z6rE7Eg-9U79_LGKMcsWLKBkY,22
|
|
3
|
+
amsdal_langgraph/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
|
+
amsdal_langgraph/app.py,sha256=x2Hv4cyHUsiXIv6atviu20jiHGZiw6w0rFKhRv7wlXs,203
|
|
5
|
+
amsdal_langgraph/checkpoint.py,sha256=AYuhIHdNRYr6QLWq25hq3RmP-MorOTX_rmuwQEOn0lg,21638
|
|
6
|
+
amsdal_langgraph/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
|
+
amsdal_langgraph/utils.py,sha256=4IMu8_JMgmaprXL0MkFcxKIT2mgGKCp2lIccvzgg7HA,2355
|
|
8
|
+
amsdal_langgraph/migrations/0000_initial.py,sha256=Ku40eJwP9VFO5oqE9U5LHYA_2IU-Up-Vxyxf5wDkM-Q,5198
|
|
9
|
+
amsdal_langgraph/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
|
+
amsdal_langgraph/models/checkpoint.py,sha256=vgxzmbdVs4vYKxodHklIzHMvDge6xH9QyfewgNctmHc,1014
|
|
11
|
+
amsdal_langgraph/models/checkpoint_writes.py,sha256=sIBzKZZHjz5-HRjHblMSIA7eZwHHR-LeK9uCeFRGhlc,1004
|
|
12
|
+
amsdal_langgraph-0.1.0.dist-info/METADATA,sha256=4nHshcbd7OH4wHSwauxApg3xtfd70uodhMyI0FlyoGo,7872
|
|
13
|
+
amsdal_langgraph-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
14
|
+
amsdal_langgraph-0.1.0.dist-info/RECORD,,
|