langgraph-api 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langgraph-api might be problematic. Click here for more details.
- LICENSE +93 -0
- langgraph_api/__init__.py +0 -0
- langgraph_api/api/__init__.py +63 -0
- langgraph_api/api/assistants.py +326 -0
- langgraph_api/api/meta.py +71 -0
- langgraph_api/api/openapi.py +32 -0
- langgraph_api/api/runs.py +463 -0
- langgraph_api/api/store.py +116 -0
- langgraph_api/api/threads.py +263 -0
- langgraph_api/asyncio.py +201 -0
- langgraph_api/auth/__init__.py +0 -0
- langgraph_api/auth/langsmith/__init__.py +0 -0
- langgraph_api/auth/langsmith/backend.py +67 -0
- langgraph_api/auth/langsmith/client.py +145 -0
- langgraph_api/auth/middleware.py +41 -0
- langgraph_api/auth/noop.py +14 -0
- langgraph_api/cli.py +209 -0
- langgraph_api/config.py +70 -0
- langgraph_api/cron_scheduler.py +60 -0
- langgraph_api/errors.py +52 -0
- langgraph_api/graph.py +314 -0
- langgraph_api/http.py +168 -0
- langgraph_api/http_logger.py +89 -0
- langgraph_api/js/.gitignore +2 -0
- langgraph_api/js/build.mts +49 -0
- langgraph_api/js/client.mts +849 -0
- langgraph_api/js/global.d.ts +6 -0
- langgraph_api/js/package.json +33 -0
- langgraph_api/js/remote.py +673 -0
- langgraph_api/js/server_sent_events.py +126 -0
- langgraph_api/js/src/graph.mts +88 -0
- langgraph_api/js/src/hooks.mjs +12 -0
- langgraph_api/js/src/parser/parser.mts +443 -0
- langgraph_api/js/src/parser/parser.worker.mjs +12 -0
- langgraph_api/js/src/schema/types.mts +2136 -0
- langgraph_api/js/src/schema/types.template.mts +74 -0
- langgraph_api/js/src/utils/importMap.mts +85 -0
- langgraph_api/js/src/utils/pythonSchemas.mts +28 -0
- langgraph_api/js/src/utils/serde.mts +21 -0
- langgraph_api/js/tests/api.test.mts +1566 -0
- langgraph_api/js/tests/compose-postgres.yml +56 -0
- langgraph_api/js/tests/graphs/.gitignore +1 -0
- langgraph_api/js/tests/graphs/agent.mts +127 -0
- langgraph_api/js/tests/graphs/error.mts +17 -0
- langgraph_api/js/tests/graphs/langgraph.json +8 -0
- langgraph_api/js/tests/graphs/nested.mts +44 -0
- langgraph_api/js/tests/graphs/package.json +7 -0
- langgraph_api/js/tests/graphs/weather.mts +57 -0
- langgraph_api/js/tests/graphs/yarn.lock +159 -0
- langgraph_api/js/tests/parser.test.mts +870 -0
- langgraph_api/js/tests/utils.mts +17 -0
- langgraph_api/js/yarn.lock +1340 -0
- langgraph_api/lifespan.py +41 -0
- langgraph_api/logging.py +121 -0
- langgraph_api/metadata.py +101 -0
- langgraph_api/models/__init__.py +0 -0
- langgraph_api/models/run.py +229 -0
- langgraph_api/patch.py +42 -0
- langgraph_api/queue.py +245 -0
- langgraph_api/route.py +118 -0
- langgraph_api/schema.py +190 -0
- langgraph_api/serde.py +124 -0
- langgraph_api/server.py +48 -0
- langgraph_api/sse.py +118 -0
- langgraph_api/state.py +67 -0
- langgraph_api/stream.py +289 -0
- langgraph_api/utils.py +60 -0
- langgraph_api/validation.py +141 -0
- langgraph_api-0.0.1.dist-info/LICENSE +93 -0
- langgraph_api-0.0.1.dist-info/METADATA +26 -0
- langgraph_api-0.0.1.dist-info/RECORD +86 -0
- langgraph_api-0.0.1.dist-info/WHEEL +4 -0
- langgraph_api-0.0.1.dist-info/entry_points.txt +3 -0
- langgraph_license/__init__.py +0 -0
- langgraph_license/middleware.py +21 -0
- langgraph_license/validation.py +11 -0
- langgraph_storage/__init__.py +0 -0
- langgraph_storage/checkpoint.py +94 -0
- langgraph_storage/database.py +190 -0
- langgraph_storage/ops.py +1523 -0
- langgraph_storage/queue.py +108 -0
- langgraph_storage/retry.py +27 -0
- langgraph_storage/store.py +28 -0
- langgraph_storage/ttl_dict.py +54 -0
- logging.json +22 -0
- openapi.json +4304 -0
langgraph_storage/ops.py
ADDED
|
@@ -0,0 +1,1523 @@
|
|
|
1
|
+
"""Implementation of the LangGraph API using in-memory checkpointer & store."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import base64
|
|
5
|
+
import copy
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
import uuid
|
|
9
|
+
from collections import defaultdict
|
|
10
|
+
from collections.abc import AsyncIterator, Sequence
|
|
11
|
+
from contextlib import asynccontextmanager
|
|
12
|
+
from copy import deepcopy
|
|
13
|
+
from datetime import UTC, datetime, timedelta
|
|
14
|
+
from typing import Any, Literal
|
|
15
|
+
from uuid import UUID, uuid4
|
|
16
|
+
|
|
17
|
+
import structlog
|
|
18
|
+
from langgraph.pregel.debug import CheckpointPayload
|
|
19
|
+
from langgraph.pregel.types import StateSnapshot
|
|
20
|
+
from starlette.exceptions import HTTPException
|
|
21
|
+
|
|
22
|
+
from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent, create_task
|
|
23
|
+
from langgraph_api.errors import UserInterrupt, UserRollback
|
|
24
|
+
from langgraph_api.graph import get_graph
|
|
25
|
+
from langgraph_api.schema import (
|
|
26
|
+
Assistant,
|
|
27
|
+
Checkpoint,
|
|
28
|
+
Config,
|
|
29
|
+
Cron,
|
|
30
|
+
IfNotExists,
|
|
31
|
+
MetadataInput,
|
|
32
|
+
MetadataValue,
|
|
33
|
+
MultitaskStrategy,
|
|
34
|
+
OnConflictBehavior,
|
|
35
|
+
QueueStats,
|
|
36
|
+
Run,
|
|
37
|
+
RunStatus,
|
|
38
|
+
StreamMode,
|
|
39
|
+
Thread,
|
|
40
|
+
ThreadStatus,
|
|
41
|
+
ThreadUpdateResponse,
|
|
42
|
+
)
|
|
43
|
+
from langgraph_api.serde import Fragment
|
|
44
|
+
from langgraph_api.utils import fetchone
|
|
45
|
+
from langgraph_storage.checkpoint import Checkpointer
|
|
46
|
+
from langgraph_storage.database import InMemConnectionProto, connect
|
|
47
|
+
from langgraph_storage.queue import Message, get_stream_manager
|
|
48
|
+
|
|
49
|
+
logger = structlog.stdlib.get_logger(__name__)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _ensure_uuid(id_: str | uuid.UUID | None) -> uuid.UUID:
|
|
53
|
+
if isinstance(id_, str):
|
|
54
|
+
return uuid.UUID(id_)
|
|
55
|
+
if id_ is None:
|
|
56
|
+
return uuid4()
|
|
57
|
+
return id_
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Right now the whole API types as UUID but frequently passes a str
|
|
61
|
+
# We ensure UUIDs for everything EXCEPT the checkpoint storage/writes,
|
|
62
|
+
# which we leave as strings. This is because I'm too lazy to subclass fully
|
|
63
|
+
# and we use non-UUID examples in the OSS version
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class Assistants:
    """Assistant operations backed by the connection's in-memory store.

    Assistants live in ``conn.store["assistants"]`` and every revision of an
    assistant is recorded in ``conn.store["assistant_versions"]``. Each method
    is a coroutine that resolves to an async iterator over its result rows.
    """

    @staticmethod
    async def search(
        conn: InMemConnectionProto,
        *,
        graph_id: str | None,
        metadata: MetadataInput,
        limit: int,
        offset: int,
    ) -> AsyncIterator[Assistant]:
        """Search assistants, newest first.

        Args:
            conn: Connection exposing the in-memory ``store``.
            graph_id: When truthy, only assistants with this graph ID match.
            metadata: When truthy, only assistants whose metadata contains
                these key/value pairs match.
            limit: Maximum number of assistants to yield.
            offset: Number of matching assistants to skip.

        Returns:
            An async iterator over the selected page of assistants.
        """

        async def filter_and_yield() -> AsyncIterator[Assistant]:
            # The store is read when the iterator is consumed, not when
            # ``search`` is awaited.
            matches = [
                assistant
                for assistant in conn.store["assistants"]
                if (not graph_id or assistant["graph_id"] == graph_id)
                and (
                    not metadata
                    or is_jsonb_contained(assistant["metadata"], metadata)
                )
            ]
            matches.sort(key=lambda a: a["created_at"], reverse=True)
            for assistant in matches[offset : offset + limit]:
                yield assistant

        return filter_and_yield()

    @staticmethod
    async def get(
        conn: InMemConnectionProto, assistant_id: UUID
    ) -> AsyncIterator[Assistant]:
        """Get an assistant by ID.

        Returns:
            An async iterator yielding every stored assistant whose ID equals
            ``assistant_id``; it is empty when there is no match.
        """
        assistant_id = _ensure_uuid(assistant_id)

        async def _yield_result():
            for assistant in conn.store["assistants"]:
                if assistant["assistant_id"] == assistant_id:
                    yield assistant

        return _yield_result()

    @staticmethod
    async def put(
        conn: InMemConnectionProto,
        assistant_id: UUID,
        *,
        graph_id: str,
        config: Config,
        metadata: MetadataInput,
        if_exists: OnConflictBehavior,
        name: str,
    ) -> AsyncIterator[Assistant]:
        """Insert an assistant and record it as version 1.

        Args:
            conn: Connection exposing the in-memory ``store``.
            assistant_id: ID of the assistant to create.
            graph_id: Graph the assistant is bound to.
            config: Assistant config; a falsy value is stored as ``{}``.
            metadata: Assistant metadata; a falsy value is stored as ``{}``.
            if_exists: ``"raise"`` to fail on an existing ID, ``"do_nothing"``
                to return the existing assistant unchanged.
            name: Assistant name.

        Returns:
            An async iterator yielding the new (or pre-existing) assistant.

        Raises:
            HTTPException: 409 if the assistant exists and ``if_exists`` is
                ``"raise"``.
        """
        assistant_id = _ensure_uuid(assistant_id)
        existing_assistant = next(
            (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
            None,
        )
        if existing_assistant:
            if if_exists == "raise":
                raise HTTPException(
                    status_code=409, detail=f"Assistant {assistant_id} already exists"
                )
            elif if_exists == "do_nothing":

                async def _yield_existing():
                    yield existing_assistant

                return _yield_existing()

        now = datetime.now(UTC)
        new_assistant: Assistant = {
            "assistant_id": assistant_id,
            "graph_id": graph_id,
            "config": config or {},
            "metadata": metadata or {},
            "name": name,
            "created_at": now,
            "updated_at": now,
            "version": 1,
        }
        new_version = {
            "assistant_id": assistant_id,
            "version": 1,
            "graph_id": graph_id,
            "config": config or {},
            "metadata": metadata or {},
            "created_at": now,
        }
        conn.store["assistants"].append(new_assistant)
        conn.store["assistant_versions"].append(new_version)

        async def _yield_new():
            yield new_assistant

        return _yield_new()

    @staticmethod
    async def patch(
        conn: InMemConnectionProto,
        assistant_id: UUID,
        *,
        config: dict | None = None,
        graph_id: str | None = None,
        metadata: MetadataInput | None = None,
        name: str | None = None,
    ) -> AsyncIterator[Assistant]:
        """Update an assistant and record the result as a new version.

        Fields passed as ``None`` keep the assistant's current value.

        Args:
            conn: Connection exposing the in-memory ``store``.
            assistant_id: The assistant ID.
            graph_id: The graph ID.
            config: The assistant config.
            metadata: The assistant metadata (replaces the stored metadata).
            name: The assistant name.

        Returns:
            An async iterator yielding the updated assistant.

        Raises:
            HTTPException: 404 if the assistant does not exist.
        """
        assistant_id = _ensure_uuid(assistant_id)
        assistant = next(
            (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
            None,
        )
        if not assistant:
            raise HTTPException(
                status_code=404, detail=f"Assistant {assistant_id} not found"
            )

        now = datetime.now(UTC)
        # ``default=0`` covers the case where no version row exists for this
        # assistant (even if other assistants have versions), yielding 1
        # instead of raising ValueError on an empty sequence.
        new_version = (
            max(
                (
                    v["version"]
                    for v in conn.store["assistant_versions"]
                    if v["assistant_id"] == assistant_id
                ),
                default=0,
            )
            + 1
        )

        # Update assistant_versions table
        new_version_entry = {
            "assistant_id": assistant_id,
            "version": new_version,
            "graph_id": graph_id if graph_id is not None else assistant["graph_id"],
            "config": config if config is not None else assistant["config"],
            "metadata": metadata if metadata is not None else assistant["metadata"],
            "created_at": now,
        }
        conn.store["assistant_versions"].append(new_version_entry)

        # Update assistants table (mutates the stored dict in place)
        assistant.update(
            {
                "graph_id": new_version_entry["graph_id"],
                "config": new_version_entry["config"],
                "metadata": new_version_entry["metadata"],
                "name": name if name is not None else assistant["name"],
                "updated_at": now,
                "version": new_version,
            }
        )

        async def _yield_updated():
            yield assistant

        return _yield_updated()

    @staticmethod
    async def delete(
        conn: InMemConnectionProto, assistant_id: UUID
    ) -> AsyncIterator[UUID]:
        """Delete an assistant by ID, along with its versions and runs.

        Returns:
            An async iterator yielding ``assistant_id``; it is yielded whether
            or not a matching assistant was stored.
        """
        assistant_id = _ensure_uuid(assistant_id)
        conn.store["assistants"] = [
            a for a in conn.store["assistants"] if a["assistant_id"] != assistant_id
        ]
        # Cascade delete assistant versions & runs on this assistant
        conn.store["assistant_versions"] = [
            v
            for v in conn.store["assistant_versions"]
            if v["assistant_id"] != assistant_id
        ]
        # Iterate over a snapshot: Runs.delete may modify the runs list while
        # we walk it, which would otherwise cause entries to be skipped.
        for run in list(conn.store["runs"]):
            if run["assistant_id"] == assistant_id:
                res = await Runs.delete(conn, run["run_id"], thread_id=run["thread_id"])
                await anext(res)

        async def _yield_deleted():
            yield assistant_id

        return _yield_deleted()

    @staticmethod
    async def set_latest(
        conn: InMemConnectionProto, assistant_id: UUID, version: int
    ) -> AsyncIterator[Assistant]:
        """Change the version of an assistant.

        Copies ``config``, ``metadata`` and ``version`` from the stored
        version entry onto the assistant and refreshes ``updated_at``.

        Raises:
            HTTPException: 404 if the assistant or the requested version does
                not exist.
        """
        assistant_id = _ensure_uuid(assistant_id)
        assistant = next(
            (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
            None,
        )
        if not assistant:
            raise HTTPException(
                status_code=404, detail=f"Assistant {assistant_id} not found"
            )

        version_data = next(
            (
                v
                for v in conn.store["assistant_versions"]
                if v["assistant_id"] == assistant_id and v["version"] == version
            ),
            None,
        )
        if not version_data:
            raise HTTPException(
                status_code=404,
                detail=f"Version {version} not found for assistant {assistant_id}",
            )

        assistant.update(
            {
                "config": version_data["config"],
                "metadata": version_data["metadata"],
                "version": version_data["version"],
                "updated_at": datetime.now(UTC),
            }
        )

        async def _yield_updated():
            yield assistant

        return _yield_updated()

    @staticmethod
    async def get_versions(
        conn: InMemConnectionProto,
        assistant_id: UUID,
        metadata: MetadataInput,
        limit: int,
        offset: int,
    ) -> AsyncIterator[Assistant]:
        """Get all versions of an assistant, highest version first.

        Args:
            conn: Connection exposing the in-memory ``store``.
            assistant_id: The assistant whose versions are listed.
            metadata: When truthy, only versions whose metadata contains these
                key/value pairs are returned.
            limit: Maximum number of versions to yield.
            offset: Number of matching versions to skip.

        Returns:
            An async iterator over the selected page of version entries.
        """
        assistant_id = _ensure_uuid(assistant_id)
        versions = [
            v
            for v in conn.store["assistant_versions"]
            if v["assistant_id"] == assistant_id
            and (not metadata or is_jsonb_contained(v["metadata"], metadata))
        ]
        versions.sort(key=lambda x: x["version"], reverse=True)

        async def _yield_versions():
            for version in versions[offset : offset + limit]:
                yield version

        return _yield_versions()
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def is_jsonb_contained(superset: dict[str, Any], subset: dict[str, Any]) -> bool:
    """Dictionary version of Postgres' ``@>`` (containment) operator.

    Every key of ``subset`` must be present in ``superset``. When both values
    for a key are dictionaries they are compared recursively with the same
    rule; otherwise the two values must compare equal.

    Returns:
        True if ``superset`` contains all key/value pairs from ``subset``
        (an empty ``subset`` is always contained), False otherwise.
    """

    def _pair_contained(key: str, expected: Any) -> bool:
        if key not in superset:
            return False
        actual = superset[key]
        if isinstance(expected, dict) and isinstance(actual, dict):
            return is_jsonb_contained(actual, expected)
        return actual == expected

    return all(_pair_contained(key, expected) for key, expected in subset.items())
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def bytes_decoder(obj):
    """JSON ``object_hook`` that turns base64-wrapped bytes back into ``bytes``.

    A mapping whose ``"__type__"`` entry equals ``"bytes"`` is replaced by the
    base64-decoded content of its ``"value"`` entry. Any other mapping is
    returned unchanged.
    """
    if obj.get("__type__") != "bytes":
        return obj
    encoded = obj["value"].encode("utf-8")
    return base64.b64decode(encoded)
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def _replace_thread_id(data, new_thread_id, thread_id):
    """Return a copy of ``data`` with ``thread_id`` swapped for ``new_thread_id``.

    The copy is made by a JSON round trip: ``data`` is serialized, every
    occurrence of ``str(thread_id)`` in the resulting text is replaced with
    ``str(new_thread_id)``, and the text is parsed again. ``bytes`` and
    ``bytearray`` values have the ID replaced in their raw content and are
    carried through the round trip as base64; they come back as ``bytes``.

    Raises:
        ValueError: If ``data`` cannot be serialized to JSON (the original
            error is attached as the cause).
    """
    old_id = str(thread_id)
    new_id = str(new_thread_id)

    class _ThreadIdBytesEncoder(json.JSONEncoder):
        """JSON encoder that emits bytes as a tagged base64 object."""

        def default(self, o):
            if not isinstance(o, bytes | bytearray):
                # Defer to the base class, which raises for unsupported types.
                return super().default(o)
            swapped = o.replace(old_id.encode(), new_id.encode())
            return {
                "__type__": "bytes",
                "value": base64.b64encode(swapped).decode("utf-8"),
            }

    try:
        serialized = json.dumps(data, cls=_ThreadIdBytesEncoder, indent=2)
    except Exception as e:
        raise ValueError(data) from e

    # Textual replacement covers IDs that appear in string keys and values.
    rewritten = serialized.replace(old_id, new_id)

    # Decoding back from JSON
    return json.loads(rewritten, object_hook=bytes_decoder)
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
class Threads:
|
|
383
|
+
@staticmethod
|
|
384
|
+
async def search(
|
|
385
|
+
conn: InMemConnectionProto,
|
|
386
|
+
*,
|
|
387
|
+
metadata: MetadataInput,
|
|
388
|
+
values: MetadataInput,
|
|
389
|
+
status: ThreadStatus | None,
|
|
390
|
+
limit: int,
|
|
391
|
+
offset: int,
|
|
392
|
+
) -> AsyncIterator[Thread]:
|
|
393
|
+
threads = conn.store["threads"]
|
|
394
|
+
filtered_threads: list[Thread] = []
|
|
395
|
+
|
|
396
|
+
# Apply filters
|
|
397
|
+
for thread in threads:
|
|
398
|
+
matches = True
|
|
399
|
+
|
|
400
|
+
if metadata and not is_jsonb_contained(thread["metadata"], metadata):
|
|
401
|
+
matches = False
|
|
402
|
+
|
|
403
|
+
if (
|
|
404
|
+
values
|
|
405
|
+
and "values" in thread
|
|
406
|
+
and not is_jsonb_contained(thread["values"], values)
|
|
407
|
+
):
|
|
408
|
+
matches = False
|
|
409
|
+
|
|
410
|
+
if status and thread.get("status") != status:
|
|
411
|
+
matches = False
|
|
412
|
+
|
|
413
|
+
if matches:
|
|
414
|
+
filtered_threads.append(thread)
|
|
415
|
+
|
|
416
|
+
# Sort by created_at in descending order
|
|
417
|
+
sorted_threads = sorted(
|
|
418
|
+
filtered_threads, key=lambda x: x["created_at"], reverse=True
|
|
419
|
+
)
|
|
420
|
+
|
|
421
|
+
# Apply limit and offset
|
|
422
|
+
paginated_threads = sorted_threads[offset : offset + limit]
|
|
423
|
+
|
|
424
|
+
async def thread_iterator() -> AsyncIterator[Thread]:
|
|
425
|
+
for thread in paginated_threads:
|
|
426
|
+
yield thread
|
|
427
|
+
|
|
428
|
+
return thread_iterator()
|
|
429
|
+
|
|
430
|
+
@staticmethod
|
|
431
|
+
async def get(conn: InMemConnectionProto, thread_id: UUID) -> AsyncIterator[Thread]:
|
|
432
|
+
"""Get a thread by ID."""
|
|
433
|
+
thread_id = _ensure_uuid(thread_id)
|
|
434
|
+
matching_thread = next(
|
|
435
|
+
(
|
|
436
|
+
thread
|
|
437
|
+
for thread in conn.store["threads"]
|
|
438
|
+
if thread["thread_id"] == thread_id
|
|
439
|
+
),
|
|
440
|
+
None,
|
|
441
|
+
)
|
|
442
|
+
if not matching_thread:
|
|
443
|
+
raise HTTPException(
|
|
444
|
+
status_code=404, detail=f"Thread with ID {thread_id} not found"
|
|
445
|
+
)
|
|
446
|
+
|
|
447
|
+
async def _yield_result():
|
|
448
|
+
if matching_thread:
|
|
449
|
+
yield matching_thread
|
|
450
|
+
|
|
451
|
+
return _yield_result()
|
|
452
|
+
|
|
453
|
+
@staticmethod
|
|
454
|
+
async def put(
|
|
455
|
+
conn: InMemConnectionProto,
|
|
456
|
+
thread_id: UUID,
|
|
457
|
+
*,
|
|
458
|
+
metadata: MetadataInput,
|
|
459
|
+
if_exists: OnConflictBehavior,
|
|
460
|
+
) -> AsyncIterator[Thread]:
|
|
461
|
+
"""Insert or update a thread."""
|
|
462
|
+
thread_id = _ensure_uuid(thread_id)
|
|
463
|
+
if metadata is None:
|
|
464
|
+
metadata = {}
|
|
465
|
+
|
|
466
|
+
# Check if thread already exists
|
|
467
|
+
existing_thread = next(
|
|
468
|
+
(t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
|
|
469
|
+
)
|
|
470
|
+
|
|
471
|
+
if existing_thread:
|
|
472
|
+
if if_exists == "raise":
|
|
473
|
+
raise HTTPException(
|
|
474
|
+
status_code=409, detail=f"Thread with ID {thread_id} already exists"
|
|
475
|
+
)
|
|
476
|
+
elif if_exists == "do_nothing":
|
|
477
|
+
|
|
478
|
+
async def _yield_existing():
|
|
479
|
+
yield existing_thread
|
|
480
|
+
|
|
481
|
+
return _yield_existing()
|
|
482
|
+
|
|
483
|
+
# Create new thread
|
|
484
|
+
new_thread: Thread = {
|
|
485
|
+
"thread_id": thread_id,
|
|
486
|
+
"created_at": datetime.now(UTC),
|
|
487
|
+
"updated_at": datetime.now(UTC),
|
|
488
|
+
"metadata": copy.deepcopy(metadata),
|
|
489
|
+
"status": "idle",
|
|
490
|
+
"config": {},
|
|
491
|
+
"values": None,
|
|
492
|
+
}
|
|
493
|
+
|
|
494
|
+
# Add to store
|
|
495
|
+
conn.store["threads"].append(new_thread)
|
|
496
|
+
|
|
497
|
+
async def _yield_new():
|
|
498
|
+
yield new_thread
|
|
499
|
+
|
|
500
|
+
return _yield_new()
|
|
501
|
+
|
|
502
|
+
@staticmethod
|
|
503
|
+
async def patch(
|
|
504
|
+
conn: InMemConnectionProto, thread_id: UUID, *, metadata: MetadataValue
|
|
505
|
+
) -> AsyncIterator[Thread]:
|
|
506
|
+
"""Update a thread."""
|
|
507
|
+
thread_list = conn.store["threads"]
|
|
508
|
+
thread_idx = None
|
|
509
|
+
thread_id = _ensure_uuid(thread_id)
|
|
510
|
+
|
|
511
|
+
for idx, thread in enumerate(thread_list):
|
|
512
|
+
if thread["thread_id"] == thread_id:
|
|
513
|
+
thread_idx = idx
|
|
514
|
+
break
|
|
515
|
+
|
|
516
|
+
if thread_idx is not None:
|
|
517
|
+
thread = copy.deepcopy(thread_list[thread_idx])
|
|
518
|
+
thread["metadata"] = {**thread["metadata"], **metadata}
|
|
519
|
+
thread["updated_at"] = datetime.now(UTC)
|
|
520
|
+
thread_list[thread_idx] = thread
|
|
521
|
+
|
|
522
|
+
async def thread_iterator() -> AsyncIterator[Thread]:
|
|
523
|
+
yield thread
|
|
524
|
+
|
|
525
|
+
return thread_iterator()
|
|
526
|
+
|
|
527
|
+
async def empty_iterator() -> AsyncIterator[Thread]:
|
|
528
|
+
if False: # This ensures the iterator is empty
|
|
529
|
+
yield
|
|
530
|
+
|
|
531
|
+
return empty_iterator()
|
|
532
|
+
|
|
533
|
+
@staticmethod
|
|
534
|
+
async def set_status(
|
|
535
|
+
conn: InMemConnectionProto,
|
|
536
|
+
thread_id: UUID,
|
|
537
|
+
checkpoint: CheckpointPayload | None,
|
|
538
|
+
exception: BaseException | None,
|
|
539
|
+
) -> None:
|
|
540
|
+
"""Set the status of a thread."""
|
|
541
|
+
thread_id = _ensure_uuid(thread_id)
|
|
542
|
+
|
|
543
|
+
async def has_pending_runs(conn_: InMemConnectionProto, tid: UUID) -> bool:
|
|
544
|
+
"""Check if thread has any pending runs."""
|
|
545
|
+
return any(
|
|
546
|
+
run["status"] == "pending" and run["thread_id"] == tid
|
|
547
|
+
for run in conn_.store["runs"]
|
|
548
|
+
)
|
|
549
|
+
|
|
550
|
+
# Find the thread
|
|
551
|
+
thread = next(
|
|
552
|
+
(
|
|
553
|
+
thread
|
|
554
|
+
for thread in conn.store["threads"]
|
|
555
|
+
if thread["thread_id"] == thread_id
|
|
556
|
+
),
|
|
557
|
+
None,
|
|
558
|
+
)
|
|
559
|
+
|
|
560
|
+
if not thread:
|
|
561
|
+
raise HTTPException(
|
|
562
|
+
status_code=404, detail=f"Thread {thread_id} not found."
|
|
563
|
+
)
|
|
564
|
+
|
|
565
|
+
# Determine has_next from checkpoint
|
|
566
|
+
has_next = False if checkpoint is None else bool(checkpoint["next"])
|
|
567
|
+
|
|
568
|
+
# Determine base status
|
|
569
|
+
if exception:
|
|
570
|
+
status = "error"
|
|
571
|
+
elif has_next:
|
|
572
|
+
status = "interrupted"
|
|
573
|
+
else:
|
|
574
|
+
status = "idle"
|
|
575
|
+
|
|
576
|
+
# Check for pending runs and update to busy if found
|
|
577
|
+
if await has_pending_runs(conn, thread_id):
|
|
578
|
+
status = "busy"
|
|
579
|
+
|
|
580
|
+
# Update thread
|
|
581
|
+
thread.update(
|
|
582
|
+
{
|
|
583
|
+
"updated_at": datetime.now(UTC),
|
|
584
|
+
"values": checkpoint["values"] if checkpoint else None,
|
|
585
|
+
"status": status,
|
|
586
|
+
}
|
|
587
|
+
)
|
|
588
|
+
|
|
589
|
+
@staticmethod
|
|
590
|
+
async def delete(
|
|
591
|
+
conn: InMemConnectionProto, thread_id: UUID
|
|
592
|
+
) -> AsyncIterator[UUID]:
|
|
593
|
+
"""Delete a thread by ID and cascade delete all associated runs."""
|
|
594
|
+
thread_list = conn.store["threads"]
|
|
595
|
+
thread_idx = None
|
|
596
|
+
thread_id = _ensure_uuid(thread_id)
|
|
597
|
+
conn.locks.pop(thread_id, None)
|
|
598
|
+
|
|
599
|
+
# Find the thread to delete
|
|
600
|
+
for idx, thread in enumerate(thread_list):
|
|
601
|
+
if thread["thread_id"] == thread_id:
|
|
602
|
+
thread_idx = idx
|
|
603
|
+
break
|
|
604
|
+
# Cascade delete all runs associated with this thread
|
|
605
|
+
conn.store["runs"] = [
|
|
606
|
+
run for run in conn.store["runs"] if run["thread_id"] != thread_id
|
|
607
|
+
]
|
|
608
|
+
_delete_checkpoints_for_thread(thread_id, conn)
|
|
609
|
+
|
|
610
|
+
if thread_idx is not None:
|
|
611
|
+
# Remove the thread from the store
|
|
612
|
+
deleted_thread = thread_list.pop(thread_idx)
|
|
613
|
+
|
|
614
|
+
# Return an async iterator with the deleted thread_id
|
|
615
|
+
async def id_iterator() -> AsyncIterator[UUID]:
|
|
616
|
+
yield deleted_thread["thread_id"]
|
|
617
|
+
|
|
618
|
+
return id_iterator()
|
|
619
|
+
|
|
620
|
+
# If thread not found, return empty iterator
|
|
621
|
+
async def empty_iterator() -> AsyncIterator[UUID]:
|
|
622
|
+
if False: # This ensures the iterator is empty
|
|
623
|
+
yield
|
|
624
|
+
|
|
625
|
+
return empty_iterator()
|
|
626
|
+
|
|
627
|
+
@staticmethod
|
|
628
|
+
async def copy(
|
|
629
|
+
conn: InMemConnectionProto, thread_id: UUID
|
|
630
|
+
) -> AsyncIterator[Thread]:
|
|
631
|
+
"""Create a copy of an existing thread."""
|
|
632
|
+
thread_id = _ensure_uuid(thread_id)
|
|
633
|
+
new_thread_id = uuid4()
|
|
634
|
+
|
|
635
|
+
async with conn.pipeline():
|
|
636
|
+
# Find the original thread in our store
|
|
637
|
+
original_thread = next(
|
|
638
|
+
(t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
|
|
639
|
+
)
|
|
640
|
+
|
|
641
|
+
if not original_thread:
|
|
642
|
+
return
|
|
643
|
+
|
|
644
|
+
# Create new thread with copied metadata
|
|
645
|
+
new_thread: Thread = {
|
|
646
|
+
"thread_id": new_thread_id,
|
|
647
|
+
"created_at": datetime.now(tz=UTC),
|
|
648
|
+
"updated_at": datetime.now(tz=UTC),
|
|
649
|
+
"metadata": deepcopy(original_thread["metadata"]),
|
|
650
|
+
"status": "idle",
|
|
651
|
+
"config": {},
|
|
652
|
+
}
|
|
653
|
+
|
|
654
|
+
# Add new thread to store
|
|
655
|
+
conn.store["threads"].append(new_thread)
|
|
656
|
+
|
|
657
|
+
checkpointer = Checkpointer(conn)
|
|
658
|
+
copied_storage = _replace_thread_id(
|
|
659
|
+
checkpointer.storage[str(thread_id)], new_thread_id, thread_id
|
|
660
|
+
)
|
|
661
|
+
checkpointer.storage[str(new_thread_id)] = copied_storage
|
|
662
|
+
# Copy the writes over (if any)
|
|
663
|
+
outer_keys = []
|
|
664
|
+
for k in checkpointer.writes:
|
|
665
|
+
if k[0] == str(thread_id):
|
|
666
|
+
outer_keys.append(k)
|
|
667
|
+
for tid, checkpoint_ns, checkpoint_id in outer_keys:
|
|
668
|
+
mapped = {
|
|
669
|
+
k: _replace_thread_id(v, new_thread_id, thread_id)
|
|
670
|
+
for k, v in checkpointer.writes[
|
|
671
|
+
(str(tid), checkpoint_ns, checkpoint_id)
|
|
672
|
+
].items()
|
|
673
|
+
}
|
|
674
|
+
|
|
675
|
+
checkpointer.writes[
|
|
676
|
+
(str(new_thread_id), checkpoint_ns, checkpoint_id)
|
|
677
|
+
] = mapped
|
|
678
|
+
|
|
679
|
+
async def row_generator() -> AsyncIterator[Thread]:
|
|
680
|
+
yield new_thread
|
|
681
|
+
|
|
682
|
+
return row_generator()
|
|
683
|
+
|
|
684
|
+
class State:
|
|
685
|
+
@staticmethod
|
|
686
|
+
async def get(
|
|
687
|
+
conn: InMemConnectionProto, config: Config, subgraphs: bool = False
|
|
688
|
+
) -> StateSnapshot:
|
|
689
|
+
"""Get state for a thread."""
|
|
690
|
+
checkpointer = Checkpointer(conn)
|
|
691
|
+
thread_id = _ensure_uuid(config["configurable"]["thread_id"])
|
|
692
|
+
thread_iter = await Threads.get(conn, thread_id)
|
|
693
|
+
thread = await anext(thread_iter)
|
|
694
|
+
checkpoint = await checkpointer.aget(config)
|
|
695
|
+
|
|
696
|
+
if not thread:
|
|
697
|
+
return StateSnapshot(
|
|
698
|
+
values={},
|
|
699
|
+
next=[],
|
|
700
|
+
config=None,
|
|
701
|
+
metadata=None,
|
|
702
|
+
created_at=None,
|
|
703
|
+
parent_config=None,
|
|
704
|
+
tasks=tuple(),
|
|
705
|
+
)
|
|
706
|
+
|
|
707
|
+
metadata = thread.get("metadata", {})
|
|
708
|
+
thread_config = thread.get("config", {})
|
|
709
|
+
|
|
710
|
+
if graph_id := metadata.get("graph_id"):
|
|
711
|
+
# format latest checkpoint for response
|
|
712
|
+
checkpointer.latest_iter = checkpoint
|
|
713
|
+
graph = get_graph(graph_id, thread_config, checkpointer=checkpointer)
|
|
714
|
+
result = await graph.aget_state(config, subgraphs=subgraphs)
|
|
715
|
+
if (
|
|
716
|
+
result.metadata is not None
|
|
717
|
+
and "checkpoint_ns" in result.metadata
|
|
718
|
+
and result.metadata["checkpoint_ns"] == ""
|
|
719
|
+
):
|
|
720
|
+
result.metadata.pop("checkpoint_ns")
|
|
721
|
+
return result
|
|
722
|
+
else:
|
|
723
|
+
return StateSnapshot(
|
|
724
|
+
values={},
|
|
725
|
+
next=[],
|
|
726
|
+
config=None,
|
|
727
|
+
metadata=None,
|
|
728
|
+
created_at=None,
|
|
729
|
+
parent_config=None,
|
|
730
|
+
tasks=tuple(),
|
|
731
|
+
)
|
|
732
|
+
|
|
733
|
+
@staticmethod
|
|
734
|
+
async def post(
|
|
735
|
+
conn: InMemConnectionProto,
|
|
736
|
+
config: Config,
|
|
737
|
+
values: Sequence[dict] | dict[str, Any] | None,
|
|
738
|
+
as_node: str | None = None,
|
|
739
|
+
) -> ThreadUpdateResponse:
|
|
740
|
+
"""Add state to a thread."""
|
|
741
|
+
|
|
742
|
+
checkpointer = Checkpointer(conn)
|
|
743
|
+
thread_id = _ensure_uuid(config["configurable"]["thread_id"])
|
|
744
|
+
thread_iter = await Threads.get(conn, thread_id)
|
|
745
|
+
thread = await fetchone(
|
|
746
|
+
thread_iter, not_found_detail=f"Thread {thread_id} not found."
|
|
747
|
+
)
|
|
748
|
+
checkpoint = await checkpointer.aget(config)
|
|
749
|
+
|
|
750
|
+
if not thread:
|
|
751
|
+
raise HTTPException(status_code=404, detail="Thread not found")
|
|
752
|
+
|
|
753
|
+
metadata = thread["metadata"]
|
|
754
|
+
thread_config = thread["config"]
|
|
755
|
+
|
|
756
|
+
if graph_id := metadata.get("graph_id"):
|
|
757
|
+
config["configurable"].setdefault("graph_id", graph_id)
|
|
758
|
+
|
|
759
|
+
checkpointer.latest_iter = checkpoint
|
|
760
|
+
graph = get_graph(graph_id, thread_config, checkpointer=checkpointer)
|
|
761
|
+
update_config = config.copy()
|
|
762
|
+
update_config["configurable"] = {
|
|
763
|
+
**config["configurable"],
|
|
764
|
+
"checkpoint_ns": config["configurable"].get("checkpoint_ns", ""),
|
|
765
|
+
}
|
|
766
|
+
next_config = await graph.aupdate_state(
|
|
767
|
+
update_config, values, as_node=as_node
|
|
768
|
+
)
|
|
769
|
+
|
|
770
|
+
# Get current state
|
|
771
|
+
state = await Threads.State.get(conn, config, subgraphs=False)
|
|
772
|
+
# Update thread values
|
|
773
|
+
for thread in conn.store["threads"]:
|
|
774
|
+
if thread["thread_id"] == thread_id:
|
|
775
|
+
thread["values"] = state.values
|
|
776
|
+
break
|
|
777
|
+
|
|
778
|
+
return ThreadUpdateResponse(
|
|
779
|
+
checkpoint=next_config["configurable"],
|
|
780
|
+
# Including deprecated fields
|
|
781
|
+
configurable=next_config["configurable"],
|
|
782
|
+
checkpoint_id=next_config["configurable"]["checkpoint_id"],
|
|
783
|
+
)
|
|
784
|
+
else:
|
|
785
|
+
raise HTTPException(status_code=400, detail="Thread has no graph ID.")
|
|
786
|
+
|
|
787
|
+
@staticmethod
async def list(
    conn: InMemConnectionProto,
    *,
    config: Config,
    limit: int = 10,
    before: str | Checkpoint | None = None,
    metadata: MetadataInput = None,
) -> list[StateSnapshot]:
    """Return the state history of the thread named in ``config``.

    Args:
        conn: Connection whose ``store["threads"]`` is searched.
        config: Must contain ``config["configurable"]["thread_id"]``.
        limit: Forwarded to ``graph.aget_state_history``.
        before: A checkpoint ID string (wrapped into a config dict) or a
            value that is forwarded unchanged.
        metadata: Forwarded as the ``filter`` argument.

    Returns:
        The collected snapshots, or an empty list when the thread is not
        in the store or its metadata carries no ``graph_id``.
    """
    wanted_id = _ensure_uuid(config["configurable"]["thread_id"])

    record = None
    for candidate in conn.store["threads"]:
        if candidate["thread_id"] == wanted_id:
            record = candidate
            break
    if not record:
        return []

    record_metadata, record_config = record["metadata"], record["config"]
    graph_id = record_metadata.get("graph_id")
    if not graph_id:
        # Without a graph there is nothing to read history from.
        return []

    graph = get_graph(graph_id, record_config, checkpointer=Checkpointer(conn))

    # A bare string is treated as a checkpoint ID.
    if isinstance(before, str):
        cursor = {"configurable": {"checkpoint_id": before}}
    else:
        cursor = before

    history = []
    async for snapshot in graph.aget_state_history(
        config, limit=limit, filter=metadata, before=cursor
    ):
        history.append(snapshot)
    return history
|
|
835
|
+
|
|
836
|
+
|
|
837
|
+
class Runs:
|
|
838
|
+
@staticmethod
|
|
839
|
+
async def stats(conn: InMemConnectionProto) -> QueueStats:
|
|
840
|
+
"""Get stats about the queue."""
|
|
841
|
+
pending_runs = [run for run in conn.store["runs"] if run["status"] == "pending"]
|
|
842
|
+
|
|
843
|
+
if not pending_runs:
|
|
844
|
+
return {"n_pending": 0, "max_age_secs": None, "med_age_secs": None}
|
|
845
|
+
|
|
846
|
+
# Get all creation timestamps
|
|
847
|
+
created_times = [run.get("created_at") for run in pending_runs]
|
|
848
|
+
created_times = [
|
|
849
|
+
t for t in created_times if t is not None
|
|
850
|
+
] # Filter out None values
|
|
851
|
+
|
|
852
|
+
if not created_times:
|
|
853
|
+
return {
|
|
854
|
+
"n_pending": len(pending_runs),
|
|
855
|
+
"max_age_secs": None,
|
|
856
|
+
"med_age_secs": None,
|
|
857
|
+
}
|
|
858
|
+
|
|
859
|
+
# Find oldest (max age)
|
|
860
|
+
oldest_time = min(created_times) # Earliest timestamp = oldest run
|
|
861
|
+
|
|
862
|
+
# Find median age
|
|
863
|
+
sorted_times = sorted(created_times)
|
|
864
|
+
median_idx = len(sorted_times) // 2
|
|
865
|
+
median_time = sorted_times[median_idx]
|
|
866
|
+
|
|
867
|
+
return {
|
|
868
|
+
"n_pending": len(pending_runs),
|
|
869
|
+
"max_age_secs": oldest_time,
|
|
870
|
+
"med_age_secs": median_time,
|
|
871
|
+
}
|
|
872
|
+
|
|
873
|
+
@asynccontextmanager
@staticmethod
async def next(conn: InMemConnectionProto) -> AsyncIterator[tuple[Run, int] | None]:
    """Get the next run from the queue, and the attempt number.

    1 is the first attempt, 2 is the first retry, etc.

    Yields ``None`` when no pending run whose ``created_at`` is already in
    the past can be claimed. Otherwise yields ``(run, attempt)`` while
    holding the lock of the run's thread; the lock is released when the
    context exits.
    """
    now = datetime.now(UTC)

    # Only runs that are pending and already due, oldest first.
    candidates = [
        queued
        for queued in conn.store["runs"]
        if queued["status"] == "pending" and queued.get("created_at", now) < now
    ]
    candidates.sort(key=lambda queued: queued.get("created_at", datetime.min))

    for candidate in candidates:
        run_id = candidate["run_id"]
        thread_lock = conn.locks[candidate["thread_id"]]
        if not thread_lock.acquire(blocking=False):
            # The thread's lock is already held; try the next run.
            continue
        try:
            # Re-check now that the lock is held.
            if candidate["status"] != "pending":
                continue

            owner = None
            for known_thread in conn.store["threads"]:
                if known_thread["thread_id"] == candidate["thread_id"]:
                    owner = known_thread
                    break

            if owner is None:
                await logger.awarning(
                    "Unexpected missing thread in Runs.next",
                    thread_id=candidate["thread_id"],
                )
                continue

            # Increment attempt counter
            attempt = await conn.retry_counter.increment(run_id)
            yield (
                {**candidate, "thread_created_at": owner.get("created_at", now)},
                attempt,
            )
        finally:
            thread_lock.release()
        # Exactly one run is handed out per context.
        return

    yield None
|
|
932
|
+
|
|
933
|
+
@asynccontextmanager
@staticmethod
async def enter(run_id: UUID) -> AsyncIterator[ValueEvent]:
    """Enter a run, listen for cancellation while running, signal when done.

    This method should be called as a context manager by a worker executing a run.
    The yielded event is the one handed to ``listen_for_cancellation``.
    """
    manager = get_stream_manager()
    # Get queue for this run
    run_queue = await Runs.Stream.subscribe(run_id)
    control_topic = f"run:{run_id}:control".encode()

    async with SimpleTaskGroup(cancel=True) as group:
        finished = ValueEvent()
        group.create_task(listen_for_cancellation(run_queue, run_id, finished))

        try:
            # Give done event to caller
            yield finished
        finally:
            # Signal done to all subscribers
            done_message = Message(topic=control_topic, data=b"done")

            await manager.put(run_id, done_message)
            # Store the control message for late subscribers
            manager.control_queues[run_id].append(done_message)
            # Clean up this queue
            await manager.remove_queue(run_id, run_queue)
|
|
961
|
+
|
|
962
|
+
@staticmethod
|
|
963
|
+
def _merge_jsonb(*objects: dict) -> dict:
|
|
964
|
+
"""Mimics PostgreSQL's JSONB merge behavior"""
|
|
965
|
+
result = {}
|
|
966
|
+
for obj in objects:
|
|
967
|
+
if obj is not None:
|
|
968
|
+
result.update(copy.deepcopy(obj))
|
|
969
|
+
return result
|
|
970
|
+
|
|
971
|
+
@staticmethod
|
|
972
|
+
def _get_configurable(config: dict) -> dict:
|
|
973
|
+
"""Extract configurable from config, mimicking PostgreSQL's coalesce"""
|
|
974
|
+
return config.get("configurable", {})
|
|
975
|
+
|
|
976
|
+
@staticmethod
async def put(
    conn: InMemConnectionProto,
    assistant_id: UUID,
    kwargs: dict,
    *,
    thread_id: UUID | None = None,
    user_id: str | None = None,
    run_id: UUID | None = None,
    status: RunStatus | None = "pending",
    metadata: MetadataInput,
    prevent_insert_if_inflight: bool,
    multitask_strategy: MultitaskStrategy = "reject",
    if_not_exists: IfNotExists = "reject",
    after_seconds: int = 0,
) -> AsyncIterator[Run]:
    """Create a run.

    Looks up the assistant, creates the target thread or updates an existing
    one that is not already busy (marking it ``"busy"``), and appends a new
    run to ``conn.store["runs"]``.

    Args:
        conn: Connection whose store holds assistants, threads and runs.
        assistant_id: Assistant the run executes.
        kwargs: Run arguments; ``kwargs["config"]`` is merged over the
            assistant's config.
        thread_id: Thread to run on. When ``None`` a new thread is created.
        user_id: Fallback for ``configurable["user_id"]`` when neither the
            run config nor the assistant config supplies one.
        run_id: ID for the new run. When ``None`` a random UUID is generated.
        status: Initial status of the run.
        metadata: Run metadata, merged over assistant and thread metadata.
        prevent_insert_if_inflight: When true and the thread already has
            pending runs, nothing is inserted.
        multitask_strategy: Stored on the run.
        if_not_exists: ``"create"`` creates a missing thread for a given
            ``thread_id``; any other value makes that case return no runs.
        after_seconds: Delay added to the run's ``created_at``.

    Returns:
        An async iterator. It is empty when the assistant is unknown or the
        requested thread is missing and may not be created. When the insert
        was prevented it yields only the already pending runs. Otherwise it
        yields the new run first, followed by the runs that were pending on
        the thread before the insert.
    """

    async def empty_generator():
        # Async generator that yields nothing.
        if False:
            yield

    assistant_id = _ensure_uuid(assistant_id)
    assistant = next(
        (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
        None,
    )
    if not assistant:
        return empty_generator()

    thread_id = _ensure_uuid(thread_id) if thread_id else None
    # Always give the run a real ID; otherwise it would be stored with
    # run_id=None and the literal string "None" in its configurable.
    run_id = _ensure_uuid(run_id) if run_id else uuid4()
    metadata = metadata or {}
    config = kwargs.get("config", {})
    now = datetime.now(UTC)

    # Handle thread creation/update
    existing_thread = next(
        (t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
    )

    if not existing_thread and (thread_id is None or if_not_exists == "create"):
        # Create new thread
        if thread_id is None:
            thread_id = uuid4()
        thread = Thread(
            thread_id=thread_id,
            status="busy",
            metadata={"graph_id": assistant["graph_id"]},
            config=Runs._merge_jsonb(
                assistant["config"],
                config,
                {
                    "configurable": Runs._merge_jsonb(
                        Runs._get_configurable(assistant["config"]),
                        Runs._get_configurable(config),
                    )
                },
            ),
            created_at=now,
            updated_at=now,
        )
        conn.store["threads"].append(thread)
    elif existing_thread:
        # Update existing thread, but leave a thread that is already busy
        # untouched.
        if existing_thread["status"] != "busy":
            existing_thread["status"] = "busy"
            existing_thread["metadata"] = Runs._merge_jsonb(
                existing_thread["metadata"], {"graph_id": assistant["graph_id"]}
            )
            existing_thread["config"] = Runs._merge_jsonb(
                assistant["config"],
                existing_thread["config"],
                config,
                {
                    "configurable": Runs._merge_jsonb(
                        Runs._get_configurable(assistant["config"]),
                        Runs._get_configurable(existing_thread["config"]),
                        Runs._get_configurable(config),
                    )
                },
            )
            existing_thread["updated_at"] = now
    else:
        # A thread_id was given, it does not exist, and creation is not allowed.
        return empty_generator()

    # Check for inflight runs if needed
    inflight_runs = [
        r
        for r in conn.store["runs"]
        if r["thread_id"] == thread_id and r["status"] == "pending"
    ]
    if prevent_insert_if_inflight and inflight_runs:

        async def _return_inflight():
            for run in inflight_runs:
                yield run

        return _return_inflight()

    # Create new run
    configurable = Runs._merge_jsonb(
        Runs._get_configurable(assistant["config"]),
        Runs._get_configurable(config),
        {
            "run_id": str(run_id),
            "thread_id": str(thread_id),
            "graph_id": assistant["graph_id"],
            "assistant_id": str(assistant_id),
            "user_id": (
                config.get("configurable", {}).get("user_id")
                or assistant["config"].get("configurable", {}).get("user_id")
                or user_id
            ),
        },
    )
    merged_metadata = Runs._merge_jsonb(
        assistant["metadata"],
        existing_thread["metadata"] if existing_thread else {},
        metadata,
    )
    new_run = Run(
        run_id=run_id,
        thread_id=thread_id,
        assistant_id=assistant_id,
        metadata=merged_metadata,
        status=status,
        kwargs=Runs._merge_jsonb(
            kwargs,
            {
                "config": Runs._merge_jsonb(
                    assistant["config"],
                    config,
                    {"configurable": configurable},
                    {
                        "metadata": merged_metadata,
                    },
                )
            },
        ),
        multitask_strategy=multitask_strategy,
        # A delayed run is modelled by a creation time in the future.
        created_at=now + timedelta(seconds=after_seconds),
        updated_at=now,
    )
    conn.store["runs"].append(new_run)

    async def _yield_new():
        yield new_run
        for r in inflight_runs:
            yield r

    return _yield_new()
|
|
1128
|
+
|
|
1129
|
+
@staticmethod
async def get(
    conn: InMemConnectionProto, run_id: UUID, *, thread_id: UUID
) -> AsyncIterator[Run]:
    """Get a run by ID.

    Returns an async iterator that yields the first run matching both
    ``run_id`` and ``thread_id``, or nothing when there is no match. The
    store is searched when the iterator is consumed, not when this method
    returns.
    """
    run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)

    async def _yield_result():
        hit = next(
            (
                stored
                for stored in conn.store["runs"]
                if stored["run_id"] == run_id and stored["thread_id"] == thread_id
            ),
            None,
        )
        if hit:
            yield hit

    return _yield_result()
|
|
1147
|
+
|
|
1148
|
+
@staticmethod
async def delete(
    conn: InMemConnectionProto, run_id: UUID, *, thread_id: UUID
) -> AsyncIterator[UUID]:
    """Delete a run by ID.

    The checkpoints recorded for the run are removed first, then the run
    itself is dropped from ``conn.store["runs"]``.

    Returns:
        An async iterator yielding the deleted run's ID.

    Raises:
        HTTPException: 404 when no run matches ``run_id`` and ``thread_id``.
    """
    run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
    _delete_checkpoints_for_thread(thread_id, conn, run_id=run_id)

    position = None
    for index, stored in enumerate(conn.store["runs"]):
        if stored["run_id"] == run_id and stored["thread_id"] == thread_id:
            position = index
            break
    if position is None:
        raise HTTPException(status_code=404, detail="Run not found")
    del conn.store["runs"][position]

    async def _yield_deleted():
        yield run_id

    return _yield_deleted()
|
|
1168
|
+
|
|
1169
|
+
@staticmethod
async def join(
    run_id: UUID,
    *,
    thread_id: UUID,
) -> Fragment:
    """Wait for a run to complete. If already done, return immediately.

    Returns:
        the final state of the run: the last ``values`` chunk seen on the
        stream wrapped in a ``Fragment``, or, when no such chunk arrived,
        the ``values`` currently stored on the thread.
    """
    final_chunk: bytes | None = None
    # Drain the stream until it ends, remembering only the latest values chunk.
    async for mode, chunk in Runs.Stream.join(
        run_id, thread_id=thread_id, stream_mode="values"
    ):
        if mode == b"values":
            final_chunk = chunk

    if final_chunk is not None:
        # The run completed while we were waiting for it.
        return Fragment(final_chunk)

    # Otherwise the run had already finished, so fetch the state from the thread.
    async with connect() as conn:
        thread = await fetchone(await Threads.get(conn, thread_id))
        return thread["values"]
|
|
1197
|
+
|
|
1198
|
+
@staticmethod
async def cancel(
    conn: InMemConnectionProto,
    run_ids: Sequence[UUID],
    *,
    action: Literal["interrupt", "rollback"] = "interrupt",
    thread_id: UUID,
) -> None:
    """Cancel a run.

    For every requested run found on the thread:

    * a control message carrying ``action`` is sent through the stream manager;
    * a pending run is marked ``"interrupted"`` when it has subscriber
      queues or the action is not ``"rollback"``; a pending run with no
      queues and a rollback action is deleted instead;
    * a non-pending run only gets the control message, and a warning is logged.

    Raises:
        HTTPException: 404 when at least one requested run was not found.
            Runs that were found have already been processed by then.
    """
    run_ids = [_ensure_uuid(run_id) for run_id in run_ids]
    thread_id = _ensure_uuid(thread_id)

    manager = get_stream_manager()
    matched = []
    deferred = []  # coroutines awaited together once every run was inspected
    for run_id in run_ids:
        target = next(
            (
                stored
                for stored in conn.store["runs"]
                if stored["run_id"] == run_id and stored["thread_id"] == thread_id
            ),
            None,
        )
        if not target:
            continue
        matched.append(target)

        # Send cancellation message through stream manager
        signal = Message(
            topic=f"run:{run_id}:control".encode(),
            data=action.encode(),
        )
        listeners = manager.get_queues(run_id)
        deferred.append(manager.put(run_id, signal))

        if target["status"] != "pending":
            await logger.awarning(
                "Attempted to cancel non-pending run.",
                run_id=run_id,
                status=target["status"],
            )
            continue

        # Update status for pending runs
        if listeners or action != "rollback":
            target["status"] = "interrupted"
            target["updated_at"] = datetime.now(tz=UTC)
        else:
            await logger.ainfo(
                "Eagerly deleting unscheduled run with rollback action",
                run_id=run_id,
                thread_id=thread_id,
            )
            deferred.append(Runs.delete(conn, run_id, thread_id=thread_id))

    if deferred:
        await asyncio.gather(*deferred)

    if len(matched) != len(run_ids):
        raise HTTPException(status_code=404, detail="Run not found")

    await logger.ainfo(
        "Cancelled runs",
        run_ids=[str(run_id) for run_id in run_ids],
        thread_id=str(thread_id),
        action=action,
    )
|
|
1268
|
+
|
|
1269
|
+
@staticmethod
async def search(
    conn: InMemConnectionProto,
    thread_id: UUID,
    *,
    limit: int = 10,
    offset: int = 0,
    metadata: MetadataInput,
) -> AsyncIterator[Run]:
    """List all runs by thread.

    Keeps the runs of ``thread_id`` whose metadata passes
    ``is_jsonb_contained`` against ``metadata``, orders them newest first
    by ``created_at``, and returns an async iterator over the slice
    ``[offset, offset + limit)``.
    """
    stored_runs = conn.store["runs"]
    wanted = metadata or {}
    thread_id = _ensure_uuid(thread_id)

    matches = []
    for stored in stored_runs:
        if stored["thread_id"] == thread_id and is_jsonb_contained(
            stored["metadata"], wanted
        ):
            matches.append(stored)
    matches.sort(key=lambda stored: stored["created_at"], reverse=True)
    page = matches[offset : offset + limit]

    async def _return():
        for item in page:
            yield item

    return _return()
|
|
1296
|
+
|
|
1297
|
+
@staticmethod
async def set_status(
    conn: InMemConnectionProto, run_id: UUID, status: RunStatus
) -> None:
    """Set the status of a run.

    Updates ``status`` and ``updated_at`` on the first stored run with the
    given ID. Despite the ``None`` annotation, the updated run is returned
    when one is found; ``None`` is returned otherwise.
    """
    run_id = _ensure_uuid(run_id)
    for record in conn.store["runs"]:
        if record["run_id"] != run_id:
            continue
        # Update the status and updated_at timestamp
        record["status"] = status
        record["updated_at"] = datetime.now(tz=UTC)
        return record
    return None
|
|
1312
|
+
|
|
1313
|
+
class Stream:
|
|
1314
|
+
@staticmethod
async def subscribe(
    run_id: UUID,
    *,
    stream_mode: "StreamMode | None" = None,
) -> asyncio.Queue:
    """Subscribe to the run stream, returning a queue.

    Control messages already stored for the run are replayed into the new
    queue, so a late subscriber still sees them.

    Args:
        run_id: Run to subscribe to; normalised with ``_ensure_uuid``.
        stream_mode: Accepted but not used by this implementation.
    """
    # Normalise once so the queue registration and the control-message
    # lookup use the same key.
    run_id = _ensure_uuid(run_id)
    stream_manager = get_stream_manager()
    queue = await stream_manager.add_queue(run_id)

    # If there's a control message already stored, send it to the new subscriber
    for control_msg in stream_manager.control_queues.get(run_id) or ():
        await queue.put(control_msg)
    return queue
|
|
1329
|
+
|
|
1330
|
+
@staticmethod
async def join(
    run_id: UUID,
    *,
    thread_id: UUID,
    ignore_404: bool = False,
    cancel_on_disconnect: bool = False,
    stream_mode: "StreamMode | asyncio.Queue | None" = None,
) -> AsyncIterator[tuple[bytes, bytes]]:
    """Stream the run output.

    Yields ``(mode, data)`` pairs, where ``mode`` is the part of the
    message topic after the ``run:<run_id>:stream:`` prefix. The stream
    ends on a ``done`` control message. Whenever no message arrives for
    0.5 seconds the run is looked up: the stream ends if the run is no
    longer pending or is missing. A missing run additionally yields
    ``(b"error", HTTPException(404))`` unless ``ignore_404`` is set.

    Args:
        run_id: Run to stream.
        thread_id: Thread the run belongs to.
        ignore_404: End silently when the run cannot be found.
        cancel_on_disconnect: Schedule ``cancel_run`` when the stream is
            torn down by any exception.
        stream_mode: An existing queue to read from; any other value
            causes a fresh subscription.
    """
    debug_enabled = logger.isEnabledFor(logging.DEBUG)
    if isinstance(stream_mode, asyncio.Queue):
        queue = stream_mode
    else:
        queue = await Runs.Stream.subscribe(run_id)

    try:
        async with connect() as conn:
            prefix_len = len(f"run:{run_id}:stream:".encode())
            control_topic = f"run:{run_id}:control"

            while True:
                try:
                    # Wait for messages with a timeout
                    message = await asyncio.wait_for(queue.get(), timeout=0.5)
                    topic, data = message.topic, message.data

                    if topic.decode() == control_topic:
                        # Only "done" ends the stream; other control
                        # messages are ignored here.
                        if data == b"done":
                            break
                    else:
                        # Extract mode from topic
                        mode = topic[prefix_len:]
                        yield mode, data
                        if debug_enabled:
                            await logger.adebug(
                                "Streamed run event",
                                run_id=str(run_id),
                                stream_mode=mode,
                                data=data,
                            )
                except TimeoutError:
                    # Check if the run is still pending
                    run_iter = await Runs.get(conn, run_id, thread_id=thread_id)
                    run = await anext(run_iter, None)

                    if run is None:
                        if not ignore_404:
                            yield (
                                b"error",
                                HTTPException(
                                    status_code=404, detail="Run not found"
                                ),
                            )
                        break
                    if run["status"] != "pending":
                        break
    except BaseException:
        # Catches everything, like the bare except it replaces; the
        # exception is always re-raised.
        if cancel_on_disconnect:
            create_task(cancel_run(thread_id, run_id))
        raise
    finally:
        stream_manager = get_stream_manager()
        await stream_manager.remove_queue(run_id, queue)
|
|
1395
|
+
|
|
1396
|
+
@staticmethod
async def publish(
    run_id: UUID,
    event: str,
    message: bytes,
) -> None:
    """Publish a message to all subscribers of the run stream.

    The message is sent under the topic ``run:<run_id>:stream:<event>``.
    """
    stream_topic = f"run:{run_id}:stream:{event}".encode()
    manager = get_stream_manager()
    envelope = Message(topic=stream_topic, data=message)
    # Send to all queues subscribed to this run_id
    await manager.put(run_id, envelope)
|
|
1408
|
+
|
|
1409
|
+
|
|
1410
|
+
async def listen_for_cancellation(
    queue: asyncio.Queue, run_id: UUID, done: "ValueEvent"
):
    """Listen for cancellation messages and set the done event accordingly.

    Stored control messages are checked first. Then messages are read from
    ``queue`` until ``done`` is set, a ``done`` payload arrives, or no
    message arrives for 240 seconds. ``rollback`` sets the event with
    ``UserRollback()``, ``interrupt`` with ``UserInterrupt()``, and ``done``
    sets it without a value.
    """
    manager = get_stream_manager()
    control_topic = f"run:{run_id}:control"

    # Honour a cancellation that was recorded before this listener started.
    for earlier in manager.control_queues.get(run_id) or ():
        if earlier.data == b"rollback":
            done.set(UserRollback())
        elif earlier.data == b"interrupt":
            done.set(UserInterrupt())

    while not done.is_set():
        try:
            # This task gets cancelled when Runs.enter exits anyway,
            # so we can have a pretty lengthy timeout here
            message = await asyncio.wait_for(queue.get(), timeout=240)
        except TimeoutError:
            # NOTE(review): after 240 s of silence the listener stops, so a
            # later cancellation is no longer picked up by this task.
            break

        signal = message.data
        if signal == b"done":
            done.set()
            break
        if signal == b"rollback":
            done.set(UserRollback())
        elif signal == b"interrupt":
            done.set(UserInterrupt())

        # Store control messages for late subscribers
        if message.topic.decode() == control_topic:
            manager.control_queues[run_id].append(message)
|
|
1444
|
+
|
|
1445
|
+
|
|
1446
|
+
class Crons:
    """Cron operations.

    None of them is implemented here: every method raises
    ``NotImplementedError``.
    """

    @staticmethod
    async def put(
        conn: InMemConnectionProto,
        *,
        payload: dict,
        schedule: str,
        cron_id: UUID | None = None,
        thread_id: UUID | None = None,
        user_id: str | None = None,
        end_time: datetime | None = None,
    ) -> AsyncIterator[Cron]:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    async def delete(conn: InMemConnectionProto, cron_id: UUID) -> AsyncIterator[UUID]:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    async def next(conn: InMemConnectionProto) -> AsyncIterator[Cron]:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    async def set_next_run_date(
        conn: InMemConnectionProto, cron_id: UUID, next_run_date: datetime
    ) -> None:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    async def search(
        conn: InMemConnectionProto,
        *,
        assistant_id: UUID | None,
        thread_id: UUID | None,
        limit: int,
        offset: int,
    ) -> AsyncIterator[Cron]:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError
|
|
1484
|
+
|
|
1485
|
+
|
|
1486
|
+
async def cancel_run(thread_id: UUID, run_id: UUID) -> None:
    """Cancel a single run over a freshly opened connection.

    Delegates to ``Runs.cancel`` with its default action.
    """
    async with connect() as connection:
        await Runs.cancel(connection, [run_id], thread_id=thread_id)
|
|
1489
|
+
|
|
1490
|
+
|
|
1491
|
+
def _delete_checkpoints_for_thread(
    thread_id: str | UUID,
    conn: InMemConnectionProto,
    run_id: str | UUID | None = None,
):
    """Delete stored checkpoints of a thread, optionally only those of one run.

    Args:
        thread_id: Thread whose checkpoints are removed.
        conn: Connection used to build the ``Checkpointer``.
        run_id: When given, only checkpoints whose metadata ``run_id``
            equals it are removed, together with their entries in
            ``checkpointer.writes``. When omitted, the thread's whole
            storage and all of its writes are removed.

    Does nothing when the thread has no entry in the checkpointer storage.
    """
    checkpointer = Checkpointer(conn)
    thread_id = str(thread_id)
    if thread_id not in checkpointer.storage:
        return
    if run_id:
        # Look through metadata
        run_id = str(run_id)
        thread_storage = checkpointer.storage[thread_id]
        # Iterate over snapshots because entries are deleted while looping.
        for checkpoint_ns, checkpoints in list(thread_storage.items()):
            for checkpoint_id, (_, metadata_b, _) in list(checkpoints.items()):
                metadata = checkpointer.serde.loads_typed(metadata_b)
                if metadata.get("run_id") != run_id:
                    continue
                del checkpoints[checkpoint_id]
                # Writes are keyed by (thread_id, checkpoint_ns, checkpoint_id);
                # drop the ones belonging to the deleted checkpoint so they
                # do not linger as orphans.
                checkpointer.writes.pop(
                    (thread_id, checkpoint_ns, checkpoint_id), None
                )
                if not checkpoints:
                    del thread_storage[checkpoint_ns]
    else:
        del checkpointer.storage[thread_id]
        # Keys are (thread_id, checkpoint_ns, checkpoint_id)
        checkpointer.writes = defaultdict(
            dict, {k: v for k, v in checkpointer.writes.items() if k[0] != thread_id}
        )
|
|
1516
|
+
|
|
1517
|
+
|
|
1518
|
+
# Names exported by `from <this module> import *`.
__all__ = [
    "Assistants",
    "Crons",
    "Runs",
    "Threads",
]
|