langgraph-api 0.0.47__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff compares publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
Potentially problematic release: this version of langgraph-api might be problematic.
- langgraph_api/__init__.py +1 -1
- langgraph_api/api/__init__.py +2 -2
- langgraph_api/api/assistants.py +3 -3
- langgraph_api/api/meta.py +9 -11
- langgraph_api/api/runs.py +3 -3
- langgraph_api/api/store.py +2 -2
- langgraph_api/api/threads.py +3 -3
- langgraph_api/cli.py +3 -1
- langgraph_api/config.py +3 -0
- langgraph_api/cron_scheduler.py +3 -3
- langgraph_api/graph.py +2 -2
- langgraph_api/js/remote.py +3 -3
- langgraph_api/metadata.py +7 -0
- langgraph_api/middleware/http_logger.py +19 -16
- langgraph_api/models/run.py +10 -1
- langgraph_api/queue_entrypoint.py +1 -1
- langgraph_api/server.py +2 -2
- langgraph_api/stream.py +3 -3
- langgraph_api/thread_ttl.py +2 -2
- langgraph_api/worker.py +3 -3
- {langgraph_api-0.0.47.dist-info → langgraph_api-0.1.0.dist-info}/METADATA +1 -1
- {langgraph_api-0.0.47.dist-info → langgraph_api-0.1.0.dist-info}/RECORD +26 -34
- langgraph_runtime/__init__.py +39 -0
- langgraph_api/lifespan.py +0 -74
- langgraph_storage/__init__.py +0 -0
- langgraph_storage/checkpoint.py +0 -123
- langgraph_storage/database.py +0 -200
- langgraph_storage/inmem_stream.py +0 -109
- langgraph_storage/ops.py +0 -2175
- langgraph_storage/queue.py +0 -186
- langgraph_storage/retry.py +0 -31
- langgraph_storage/store.py +0 -100
- {langgraph_api-0.0.47.dist-info → langgraph_api-0.1.0.dist-info}/LICENSE +0 -0
- {langgraph_api-0.0.47.dist-info → langgraph_api-0.1.0.dist-info}/WHEEL +0 -0
- {langgraph_api-0.0.47.dist-info → langgraph_api-0.1.0.dist-info}/entry_points.txt +0 -0
langgraph_storage/ops.py
DELETED
@@ -1,2175 +0,0 @@
-"""Implementation of the LangGraph API using in-memory checkpointer & store."""
-
-import asyncio
-import base64
-import copy
-import json
-import logging
-import uuid
-from collections import defaultdict
-from collections.abc import AsyncIterator, Sequence
-from contextlib import asynccontextmanager
-from copy import deepcopy
-from datetime import UTC, datetime, timedelta
-from typing import Any, Literal, cast
-from uuid import UUID, uuid4
-
-import structlog
-from langgraph.checkpoint.serde.jsonplus import _msgpack_ext_hook_to_json
-from langgraph.pregel.debug import CheckpointPayload
-from langgraph.pregel.types import StateSnapshot
-from langgraph_sdk import Auth
-from starlette.exceptions import HTTPException
-
-from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent, create_task
-from langgraph_api.auth.custom import handle_event
-from langgraph_api.command import map_cmd
-from langgraph_api.config import ThreadTTLConfig
-from langgraph_api.errors import UserInterrupt, UserRollback
-from langgraph_api.graph import get_graph
-from langgraph_api.schema import (
-    Assistant,
-    Checkpoint,
-    Config,
-    Cron,
-    IfNotExists,
-    MetadataInput,
-    MetadataValue,
-    MultitaskStrategy,
-    OnConflictBehavior,
-    QueueStats,
-    Run,
-    RunStatus,
-    StreamMode,
-    Thread,
-    ThreadStatus,
-    ThreadUpdateResponse,
-)
-from langgraph_api.serde import Fragment
-from langgraph_api.utils import fetchone, get_auth_ctx
-from langgraph_storage.checkpoint import Checkpointer
-from langgraph_storage.database import InMemConnectionProto, connect
-from langgraph_storage.inmem_stream import Message, get_stream_manager
-
-logger = structlog.stdlib.get_logger(__name__)
-
-
-def _ensure_uuid(id_: str | uuid.UUID | None) -> uuid.UUID:
-    if isinstance(id_, str):
-        return uuid.UUID(id_)
-    if id_ is None:
-        return uuid4()
-    return id_
-
-
-class WrappedHTTPException(Exception):
-    def __init__(self, http_exception: HTTPException):
-        self.http_exception = http_exception
-
-
-# Right now the whole API types as UUID but frequently passes a str.
-# We ensure UUIDs for everything EXCEPT the checkpoint storage/writes,
-# which we leave as strings. This is because I'm too lazy to subclass fully
-# and we use non-UUID examples in the OSS version
-
-
-class Authenticated:
-    resource: Literal["threads", "crons", "assistants"]
-
-    @classmethod
-    def _context(
-        cls,
-        ctx: Auth.types.BaseAuthContext | None,
-        action: Literal["create", "read", "update", "delete", "create_run"],
-    ) -> Auth.types.AuthContext | None:
-        if not ctx:
-            return
-        return Auth.types.AuthContext(
-            user=ctx.user,
-            permissions=ctx.permissions,
-            resource=cls.resource,
-            action=action,
-        )
-
-    @classmethod
-    async def handle_event(
-        cls,
-        ctx: Auth.types.BaseAuthContext | None,
-        action: Literal["create", "read", "update", "delete", "search", "create_run"],
-        value: Any,
-    ) -> Auth.types.FilterType | None:
-        ctx = ctx or get_auth_ctx()
-        if not ctx:
-            return
-        return await handle_event(cls._context(ctx, action), value)
-
-
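# Annotation — a minimal sketch, not part of the deleted file: each resource
# class below builds an AuthContext for the requested action and passes it to
# the custom-auth handler, which returns either None (allow everything) or a
# metadata filter that the storage methods then apply row by row. Plain dicts
# stand in for the Auth.types objects here.
def sketch_handle_event(ctx: dict | None, resource: str, action: str) -> dict | None:
    if not ctx:
        return None  # no auth configured: no filtering
    # e.g. restrict every query to rows owned by the calling user
    return {"owner": ctx["user_id"]}

assert sketch_handle_event({"user_id": "u1"}, "assistants", "search") == {"owner": "u1"}
assert sketch_handle_event(None, "assistants", "search") is None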
-class Assistants(Authenticated):
-    resource = "assistants"
-
-    @staticmethod
-    async def search(
-        conn: InMemConnectionProto,
-        *,
-        graph_id: str | None,
-        metadata: MetadataInput,
-        limit: int,
-        offset: int,
-        ctx: Auth.types.BaseAuthContext | None = None,
-    ) -> AsyncIterator[Assistant]:
-        metadata = metadata if metadata is not None else {}
-        filters = await Assistants.handle_event(
-            ctx,
-            "search",
-            Auth.types.AssistantsSearch(
-                graph_id=graph_id, metadata=metadata, limit=limit, offset=offset
-            ),
-        )
-
-        async def filter_and_yield() -> AsyncIterator[Assistant]:
-            assistants = conn.store["assistants"]
-            filtered_assistants = [
-                assistant
-                for assistant in assistants
-                if (not graph_id or assistant["graph_id"] == graph_id)
-                and (
-                    not metadata or is_jsonb_contained(assistant["metadata"], metadata)
-                )
-                and (not filters or _check_filter_match(assistant["metadata"], filters))
-            ]
-            filtered_assistants.sort(key=lambda x: x["created_at"], reverse=True)
-            for assistant in filtered_assistants[offset : offset + limit]:
-                yield assistant
-
-        return filter_and_yield()
-
-    @staticmethod
-    async def get(
-        conn: InMemConnectionProto,
-        assistant_id: UUID,
-        ctx: Auth.types.BaseAuthContext | None = None,
-    ) -> AsyncIterator[Assistant]:
-        """Get an assistant by ID."""
-        assistant_id = _ensure_uuid(assistant_id)
-        filters = await Assistants.handle_event(
-            ctx,
-            "read",
-            Auth.types.AssistantsRead(assistant_id=assistant_id),
-        )
-
-        async def _yield_result():
-            for assistant in conn.store["assistants"]:
-                if assistant["assistant_id"] == assistant_id and (
-                    not filters or _check_filter_match(assistant["metadata"], filters)
-                ):
-                    yield assistant
-
-        return _yield_result()
-
-    @staticmethod
-    async def put(
-        conn: InMemConnectionProto,
-        assistant_id: UUID,
-        *,
-        graph_id: str,
-        config: Config,
-        metadata: MetadataInput,
-        if_exists: OnConflictBehavior,
-        name: str,
-        ctx: Auth.types.BaseAuthContext | None = None,
-        description: str | None = None,
-    ) -> AsyncIterator[Assistant]:
-        """Insert an assistant."""
-        assistant_id = _ensure_uuid(assistant_id)
-        metadata = metadata if metadata is not None else {}
-        filters = await Assistants.handle_event(
-            ctx,
-            "create",
-            Auth.types.AssistantsCreate(
-                assistant_id=assistant_id,
-                graph_id=graph_id,
-                config=config,
-                metadata=metadata,
-                name=name,
-            ),
-        )
-        existing_assistant = next(
-            (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
-            None,
-        )
-        if existing_assistant:
-            if filters and not _check_filter_match(
-                existing_assistant["metadata"], filters
-            ):
-                raise HTTPException(
-                    status_code=409, detail=f"Assistant {assistant_id} already exists"
-                )
-            if if_exists == "raise":
-                raise HTTPException(
-                    status_code=409, detail=f"Assistant {assistant_id} already exists"
-                )
-            elif if_exists == "do_nothing":
-
-                async def _yield_existing():
-                    yield existing_assistant
-
-                return _yield_existing()
-
-        now = datetime.now(UTC)
-        new_assistant: Assistant = {
-            "assistant_id": assistant_id,
-            "graph_id": graph_id,
-            "config": config or {},
-            "metadata": metadata or {},
-            "name": name,
-            "created_at": now,
-            "updated_at": now,
-            "version": 1,
-            "description": description,
-        }
-        new_version = {
-            "assistant_id": assistant_id,
-            "version": 1,
-            "graph_id": graph_id,
-            "config": config or {},
-            "metadata": metadata or {},
-            "created_at": now,
-            "name": name,
-        }
-        conn.store["assistants"].append(new_assistant)
-        conn.store["assistant_versions"].append(new_version)
-
-        async def _yield_new():
-            yield new_assistant
-
-        return _yield_new()
-
-    @staticmethod
-    async def patch(
-        conn: InMemConnectionProto,
-        assistant_id: UUID,
-        *,
-        config: dict | None = None,
-        graph_id: str | None = None,
-        metadata: MetadataInput | None = None,
-        name: str | None = None,
-        description: str | None = None,
-        ctx: Auth.types.BaseAuthContext | None = None,
-    ) -> AsyncIterator[Assistant]:
-        """Update an assistant.
-
-        Args:
-            conn: The connection to the in-memory store.
-            assistant_id: The assistant ID.
-            graph_id: The graph ID.
-            config: The assistant config.
-            metadata: The assistant metadata.
-            name: The assistant name.
-            description: The assistant description.
-            ctx: The auth context.
-
-        Returns:
-            The updated assistant model.
-        """
-        assistant_id = _ensure_uuid(assistant_id)
-        metadata = metadata if metadata is not None else {}
-        filters = await Assistants.handle_event(
-            ctx,
-            "update",
-            Auth.types.AssistantsUpdate(
-                assistant_id=assistant_id,
-                graph_id=graph_id,
-                config=config,
-                metadata=metadata,
-                name=name,
-            ),
-        )
-        assistant = next(
-            (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
-            None,
-        )
-        if not assistant:
-            raise HTTPException(
-                status_code=404, detail=f"Assistant {assistant_id} not found"
-            )
-        elif filters and not _check_filter_match(assistant["metadata"], filters):
-            raise HTTPException(
-                status_code=404, detail=f"Assistant {assistant_id} not found"
-            )
-
-        now = datetime.now(UTC)
-        new_version = (
-            max(
-                v["version"]
-                for v in conn.store["assistant_versions"]
-                if v["assistant_id"] == assistant_id
-            )
-            + 1
-            if conn.store["assistant_versions"]
-            else 1
-        )
-
-        # Update assistant_versions table
-        if metadata:
-            metadata = {
-                **assistant["metadata"],
-                **metadata,
-            }
-        new_version_entry = {
-            "assistant_id": assistant_id,
-            "version": new_version,
-            "graph_id": graph_id if graph_id is not None else assistant["graph_id"],
-            "config": config if config is not None else assistant["config"],
-            "metadata": metadata if metadata is not None else assistant["metadata"],
-            "created_at": now,
-            "name": name if name is not None else assistant["name"],
-            "description": description
-            if description is not None
-            else assistant["description"],
-        }
-        conn.store["assistant_versions"].append(new_version_entry)
-
-        # Update assistants table
-        assistant.update(
-            {
-                "graph_id": new_version_entry["graph_id"],
-                "config": new_version_entry["config"],
-                "metadata": new_version_entry["metadata"],
-                "name": name if name is not None else assistant["name"],
-                "description": description
-                if description is not None
-                else assistant["description"],
-                "updated_at": now,
-                "version": new_version,
-            }
-        )
-
-        async def _yield_updated():
-            yield assistant
-
-        return _yield_updated()
-
-    @staticmethod
-    async def delete(
-        conn: InMemConnectionProto,
-        assistant_id: UUID,
-        ctx: Auth.types.BaseAuthContext | None = None,
-    ) -> AsyncIterator[UUID]:
-        """Delete an assistant by ID."""
-        assistant_id = _ensure_uuid(assistant_id)
-        filters = await Assistants.handle_event(
-            ctx,
-            "delete",
-            Auth.types.AssistantsDelete(
-                assistant_id=assistant_id,
-            ),
-        )
-        assistant = next(
-            (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
-            None,
-        )
-
-        if not assistant:
-            raise HTTPException(
-                status_code=404, detail=f"Assistant with ID {assistant_id} not found"
-            )
-        elif filters and not _check_filter_match(assistant["metadata"], filters):
-            raise HTTPException(
-                status_code=404, detail=f"Assistant with ID {assistant_id} not found"
-            )
-
-        conn.store["assistants"] = [
-            a for a in conn.store["assistants"] if a["assistant_id"] != assistant_id
-        ]
-        # Cascade delete assistant versions, crons, & runs on this assistant
-        conn.store["assistant_versions"] = [
-            v
-            for v in conn.store["assistant_versions"]
-            if v["assistant_id"] != assistant_id
-        ]
-        retained = []
-        for run in conn.store["runs"]:
-            if run["assistant_id"] == assistant_id:
-                res = await Runs.delete(
-                    conn, run["run_id"], thread_id=run["thread_id"], ctx=ctx
-                )
-                await anext(res)
-            else:
-                retained.append(run)
-
-        async def _yield_deleted():
-            yield assistant_id
-
-        return _yield_deleted()
-
-    @staticmethod
-    async def set_latest(
-        conn: InMemConnectionProto,
-        assistant_id: UUID,
-        version: int,
-        ctx: Auth.types.BaseAuthContext | None = None,
-    ) -> AsyncIterator[Assistant]:
-        """Change the version of an assistant."""
-        assistant_id = _ensure_uuid(assistant_id)
-        filters = await Assistants.handle_event(
-            ctx,
-            "update",
-            Auth.types.AssistantsUpdate(
-                assistant_id=assistant_id,
-                version=version,
-            ),
-        )
-        assistant = next(
-            (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
-            None,
-        )
-        if not assistant:
-            raise HTTPException(
-                status_code=404, detail=f"Assistant {assistant_id} not found"
-            )
-        elif filters and not _check_filter_match(assistant["metadata"], filters):
-            raise HTTPException(
-                status_code=404, detail=f"Assistant {assistant_id} not found"
-            )
-
-        version_data = next(
-            (
-                v
-                for v in conn.store["assistant_versions"]
-                if v["assistant_id"] == assistant_id and v["version"] == version
-            ),
-            None,
-        )
-        if not version_data:
-            raise HTTPException(
-                status_code=404,
-                detail=f"Version {version} not found for assistant {assistant_id}",
-            )
-
-        assistant.update(
-            {
-                "config": version_data["config"],
-                "metadata": version_data["metadata"],
-                "version": version_data["version"],
-                "updated_at": datetime.now(UTC),
-            }
-        )
-
-        async def _yield_updated():
-            yield assistant
-
-        return _yield_updated()
-
-    @staticmethod
-    async def get_versions(
-        conn: InMemConnectionProto,
-        assistant_id: UUID,
-        metadata: MetadataInput,
-        limit: int,
-        offset: int,
-        ctx: Auth.types.BaseAuthContext | None = None,
-    ) -> AsyncIterator[Assistant]:
-        """Get all versions of an assistant."""
-        assistant_id = _ensure_uuid(assistant_id)
-        filters = await Assistants.handle_event(
-            ctx,
-            "read",
-            Auth.types.AssistantsRead(assistant_id=assistant_id),
-        )
-        assistant = next(
-            (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
-            None,
-        )
-        if not assistant:
-            raise HTTPException(
-                status_code=404, detail=f"Assistant {assistant_id} not found"
-            )
-        versions = [
-            v
-            for v in conn.store["assistant_versions"]
-            if v["assistant_id"] == assistant_id
-            and (not metadata or is_jsonb_contained(v["metadata"], metadata))
-            and (not filters or _check_filter_match(v["metadata"], filters))
-        ]
-
-        # Previously, the name was not included in the assistant_versions table, so we add it here.
-        for v in versions:
-            if "name" not in v:
-                v["name"] = assistant["name"]
-
-        versions.sort(key=lambda x: x["version"], reverse=True)
-
-        async def _yield_versions():
-            for version in versions[offset : offset + limit]:
-                yield version
-
-        return _yield_versions()
-
-
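# Annotation — a sketch, not from the deleted file, of the version-bump rule
# used in Assistants.patch above. max(..., default=0) also covers the edge
# case where assistant_versions is non-empty but holds no rows for this
# assistant_id, which the original truthiness check on the whole list did not.
def next_version(assistant_versions: list[dict], assistant_id) -> int:
    return max(
        (v["version"] for v in assistant_versions if v["assistant_id"] == assistant_id),
        default=0,
    ) + 1

assert next_version([], "a1") == 1
assert next_version([{"assistant_id": "a1", "version": 2}], "a1") == 3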
-def is_jsonb_contained(superset: dict[str, Any], subset: dict[str, Any]) -> bool:
-    """
-    Implements Postgres' @> (containment) operator for dictionaries.
-    Returns True if superset contains all key/value pairs from subset.
-    """
-    for key, value in subset.items():
-        if key not in superset:
-            return False
-        if isinstance(value, dict) and isinstance(superset[key], dict):
-            if not is_jsonb_contained(superset[key], value):
-                return False
-        elif superset[key] != value:
-            return False
-    return True
-
-
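# Annotation — worked examples, not from the deleted file, of the @>-style
# containment implemented by is_jsonb_contained above: nested dicts recurse,
# scalars must match exactly, extra superset keys are ignored, and (unlike
# Postgres arrays) lists compare by plain equality.
assert is_jsonb_contained({"a": 1, "b": {"c": 2, "d": 3}}, {"b": {"c": 2}})
assert not is_jsonb_contained({"a": 1}, {"a": 2})                # scalar mismatch
assert not is_jsonb_contained({"a": {"b": 1}}, {"a": {"c": 1}})  # missing nested key
assert is_jsonb_contained({"a": [1, 2]}, {"a": [1, 2]})          # lists: equality, not element containment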
-def bytes_decoder(obj):
-    """Custom JSON decoder that converts base64 back to bytes."""
-    if "__type__" in obj and obj["__type__"] == "bytes":
-        return base64.b64decode(obj["value"].encode("utf-8"))
-    return obj
-
-
-def _replace_thread_id(data, new_thread_id, thread_id):
-    class BytesEncoder(json.JSONEncoder):
-        """Custom JSON encoder that handles bytes by converting them to base64."""
-
-        def default(self, obj):
-            if isinstance(obj, bytes | bytearray):
-                return {
-                    "__type__": "bytes",
-                    "value": base64.b64encode(
-                        obj.replace(
-                            str(thread_id).encode(), str(new_thread_id).encode()
-                        )
-                    ).decode("utf-8"),
-                }
-
-            return super().default(obj)
-
-    try:
-        json_str = json.dumps(data, cls=BytesEncoder, indent=2)
-    except Exception as e:
-        raise ValueError(data) from e
-    json_str = json_str.replace(str(thread_id), str(new_thread_id))
-
-    # Decoding back from JSON
-    d = json.loads(json_str, object_hook=bytes_decoder)
-    return d
-
-
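# Annotation — a round-trip sketch, not from the deleted file, of the
# encode/decode scheme used by _replace_thread_id above: bytes become
# {"__type__": "bytes", "value": <base64>} on the way into json.dumps, and an
# object_hook restores them on the way out.
import base64, json

class _SketchBytesEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, (bytes, bytearray)):
            return {"__type__": "bytes", "value": base64.b64encode(obj).decode("utf-8")}
        return super().default(obj)

def _sketch_bytes_decoder(obj):
    if obj.get("__type__") == "bytes":
        return base64.b64decode(obj["value"].encode("utf-8"))
    return obj

data = {"checkpoint": b"\x00\x01 thread-1 \x02"}
encoded = json.dumps(data, cls=_SketchBytesEncoder)
assert json.loads(encoded, object_hook=_sketch_bytes_decoder) == data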
-class Threads(Authenticated):
-    resource = "threads"
-
-    @staticmethod
-    async def search(
-        conn: InMemConnectionProto,
-        *,
-        metadata: MetadataInput,
-        values: MetadataInput,
-        status: ThreadStatus | None,
-        limit: int,
-        offset: int,
-        ctx: Auth.types.BaseAuthContext | None = None,
-    ) -> AsyncIterator[Thread]:
-        threads = conn.store["threads"]
-        filtered_threads: list[Thread] = []
-        metadata = metadata if metadata is not None else {}
-        values = values if values is not None else {}
-        filters = await Threads.handle_event(
-            ctx,
-            "search",
-            Auth.types.ThreadsSearch(
-                metadata=metadata,
-                values=values,
-                status=status,
-                limit=limit,
-                offset=offset,
-            ),
-        )
-
-        # Apply filters
-        for thread in threads:
-            if filters and not _check_filter_match(thread["metadata"], filters):
-                continue
-
-            if metadata and not is_jsonb_contained(thread["metadata"], metadata):
-                continue
-
-            if (
-                values
-                and "values" in thread
-                and not is_jsonb_contained(thread["values"], values)
-            ):
-                continue
-
-            if status and thread.get("status") != status:
-                continue
-
-            filtered_threads.append(thread)
-
-        # Sort by created_at in descending order
-        sorted_threads = sorted(
-            filtered_threads, key=lambda x: x["created_at"], reverse=True
-        )
-
-        # Apply limit and offset
-        paginated_threads = sorted_threads[offset : offset + limit]
-
-        async def thread_iterator() -> AsyncIterator[Thread]:
-            for thread in paginated_threads:
-                yield thread
-
-        return thread_iterator()
-
-    @staticmethod
-    async def _get_with_filters(
-        conn: InMemConnectionProto,
-        thread_id: UUID,
-        filters: Auth.types.FilterType | None,
-    ) -> Thread | None:
-        thread_id = _ensure_uuid(thread_id)
-        matching_thread = next(
-            (
-                thread
-                for thread in conn.store["threads"]
-                if thread["thread_id"] == thread_id
-            ),
-            None,
-        )
-        if not matching_thread or (
-            filters and not _check_filter_match(matching_thread["metadata"], filters)
-        ):
-            return
-
-        return matching_thread
-
-    @staticmethod
-    async def _get(
-        conn: InMemConnectionProto,
-        thread_id: UUID,
-        ctx: Auth.types.BaseAuthContext | None = None,
-    ) -> Thread | None:
-        """Get a thread by ID."""
-        thread_id = _ensure_uuid(thread_id)
-        filters = await Threads.handle_event(
-            ctx,
-            "read",
-            Auth.types.ThreadsRead(thread_id=thread_id),
-        )
-        return await Threads._get_with_filters(conn, thread_id, filters)
-
-    @staticmethod
-    async def get(
-        conn: InMemConnectionProto,
-        thread_id: UUID,
-        ctx: Auth.types.BaseAuthContext | None = None,
-    ) -> AsyncIterator[Thread]:
-        """Get a thread by ID."""
-        matching_thread = await Threads._get(conn, thread_id, ctx)
-
-        if not matching_thread:
-            raise HTTPException(
-                status_code=404, detail=f"Thread with ID {thread_id} not found"
-            )
-
-        async def _yield_result():
-            if matching_thread:
-                yield matching_thread
-
-        return _yield_result()
-
-    @staticmethod
-    async def put(
-        conn: InMemConnectionProto,
-        thread_id: UUID,
-        *,
-        metadata: MetadataInput,
-        if_exists: OnConflictBehavior,
-        ttl: ThreadTTLConfig | None = None,
-        ctx: Auth.types.BaseAuthContext | None = None,
-    ) -> AsyncIterator[Thread]:
-        """Insert or update a thread."""
-        thread_id = _ensure_uuid(thread_id)
-        if metadata is None:
-            metadata = {}
-
-        # Check if thread already exists
-        existing_thread = next(
-            (t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
-        )
-        filters = await Threads.handle_event(
-            ctx,
-            "create",
-            Auth.types.ThreadsCreate(
-                thread_id=thread_id, metadata=metadata, if_exists=if_exists
-            ),
-        )
-
-        if existing_thread:
-            if filters and not _check_filter_match(
-                existing_thread["metadata"], filters
-            ):
-                # Should we use a different status code here?
-                raise HTTPException(
-                    status_code=409, detail=f"Thread with ID {thread_id} already exists"
-                )
-            if if_exists == "raise":
-                raise HTTPException(
-                    status_code=409, detail=f"Thread with ID {thread_id} already exists"
-                )
-            elif if_exists == "do_nothing":
-
-                async def _yield_existing():
-                    yield existing_thread
-
-                return _yield_existing()
-        # Create new thread
-        new_thread: Thread = {
-            "thread_id": thread_id,
-            "created_at": datetime.now(UTC),
-            "updated_at": datetime.now(UTC),
-            "metadata": copy.deepcopy(metadata),
-            "status": "idle",
-            "config": {},
-            "values": None,
-        }
-
-        # Add to store
-        conn.store["threads"].append(new_thread)
-
-        async def _yield_new():
-            yield new_thread
-
-        return _yield_new()
-
-    @staticmethod
-    async def patch(
-        conn: InMemConnectionProto,
-        thread_id: UUID,
-        *,
-        metadata: MetadataValue,
-        ctx: Auth.types.BaseAuthContext | None = None,
-    ) -> AsyncIterator[Thread]:
-        """Update a thread."""
-        thread_list = conn.store["threads"]
-        thread_idx = None
-        thread_id = _ensure_uuid(thread_id)
-
-        for idx, thread in enumerate(thread_list):
-            if thread["thread_id"] == thread_id:
-                thread_idx = idx
-                break
-
-        if thread_idx is not None:
-            filters = await Threads.handle_event(
-                ctx,
-                "update",
-                Auth.types.ThreadsUpdate(thread_id=thread_id, metadata=metadata),
-            )
-            if not filters or _check_filter_match(
-                thread_list[thread_idx]["metadata"], filters
-            ):
-                thread = copy.deepcopy(thread_list[thread_idx])
-                thread["metadata"] = {**thread["metadata"], **metadata}
-                thread["updated_at"] = datetime.now(UTC)
-                thread_list[thread_idx] = thread
-
-                async def thread_iterator() -> AsyncIterator[Thread]:
-                    yield thread
-
-                return thread_iterator()
-
-        async def empty_iterator() -> AsyncIterator[Thread]:
-            if False:  # This ensures the iterator is empty
-                yield
-
-        return empty_iterator()
-
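# Annotation — the `if False: yield` idiom above as a standalone sketch (not
# from the deleted file): the mere presence of `yield` makes the function an
# async generator, while the unreachable branch guarantees it yields nothing.
import asyncio
from collections.abc import AsyncIterator

async def empty() -> AsyncIterator[None]:
    if False:  # never taken; only here to make this an async generator
        yield

async def main() -> None:
    assert [x async for x in empty()] == []

asyncio.run(main())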
-    @staticmethod
-    async def set_status(
-        conn: InMemConnectionProto,
-        thread_id: UUID,
-        checkpoint: CheckpointPayload | None,
-        exception: BaseException | None,
-        # This does not accept the auth context since it's only used internally
-    ) -> None:
-        """Set the status of a thread."""
-        thread_id = _ensure_uuid(thread_id)
-
-        async def has_pending_runs(conn_: InMemConnectionProto, tid: UUID) -> bool:
-            """Check if thread has any pending runs."""
-            return any(
-                run["status"] in ("pending", "running") and run["thread_id"] == tid
-                for run in conn_.store["runs"]
-            )
-
-        # Find the thread
-        thread = next(
-            (
-                thread
-                for thread in conn.store["threads"]
-                if thread["thread_id"] == thread_id
-            ),
-            None,
-        )
-
-        if not thread:
-            raise HTTPException(
-                status_code=404, detail=f"Thread {thread_id} not found."
-            )
-
-        # Determine has_next from checkpoint
-        has_next = False if checkpoint is None else bool(checkpoint["next"])
-
-        # Determine base status
-        if exception:
-            status = "error"
-        elif has_next:
-            status = "interrupted"
-        else:
-            status = "idle"
-
-        # Check for pending runs and update to busy if found
-        if await has_pending_runs(conn, thread_id):
-            status = "busy"
-
-        # Update thread
-        thread.update(
-            {
-                "updated_at": datetime.now(UTC),
-                "values": checkpoint["values"] if checkpoint else None,
-                "status": status,
-                "interrupts": (
-                    {
-                        t["id"]: t["interrupts"]
-                        for t in checkpoint["tasks"]
-                        if t.get("interrupts")
-                    }
-                    if checkpoint
-                    else {}
-                ),
-            }
-        )
-
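# Annotation — a condensed sketch, not from the deleted file, of the status
# derivation in Threads.set_status above: an exception wins over a pending
# interrupt, which wins over idle, and any pending/running run overrides them
# all with "busy" (the original applies the busy check last, same outcome).
def derive_status(exception: bool, has_next: bool, has_pending_runs: bool) -> str:
    if has_pending_runs:
        return "busy"
    if exception:
        return "error"
    return "interrupted" if has_next else "idle"

assert derive_status(False, True, False) == "interrupted"
assert derive_status(True, True, True) == "busy"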
-    @staticmethod
-    async def delete(
-        conn: InMemConnectionProto,
-        thread_id: UUID,
-        ctx: Auth.types.BaseAuthContext | None = None,
-    ) -> AsyncIterator[UUID]:
-        """Delete a thread by ID and cascade delete all associated runs."""
-        thread_list = conn.store["threads"]
-        thread_idx = None
-        thread_id = _ensure_uuid(thread_id)
-
-        # Find the thread to delete
-        for idx, thread in enumerate(thread_list):
-            if thread["thread_id"] == thread_id:
-                thread_idx = idx
-                break
-        filters = await Threads.handle_event(
-            ctx,
-            "delete",
-            Auth.types.ThreadsDelete(thread_id=thread_id),
-        )
-        if (filters and not _check_filter_match(thread["metadata"], filters)) or (
-            thread_idx is None
-        ):
-            raise HTTPException(
-                status_code=404, detail=f"Thread with ID {thread_id} not found"
-            )
-        # Cascade delete all runs associated with this thread
-        conn.store["runs"] = [
-            run for run in conn.store["runs"] if run["thread_id"] != thread_id
-        ]
-        _delete_checkpoints_for_thread(thread_id, conn)
-
-        if thread_idx is not None:
-            # Remove the thread from the store
-            deleted_thread = thread_list.pop(thread_idx)
-
-            # Return an async iterator with the deleted thread_id
-            async def id_iterator() -> AsyncIterator[UUID]:
-                yield deleted_thread["thread_id"]
-
-            return id_iterator()
-
-        # If thread not found, return empty iterator
-        async def empty_iterator() -> AsyncIterator[UUID]:
-            if False:  # This ensures the iterator is empty
-                yield
-
-        return empty_iterator()
-
-    @staticmethod
-    async def copy(
-        conn: InMemConnectionProto,
-        thread_id: UUID,
-        ctx: Auth.types.BaseAuthContext | None = None,
-    ) -> AsyncIterator[Thread]:
-        """Create a copy of an existing thread."""
-        thread_id = _ensure_uuid(thread_id)
-        new_thread_id = uuid4()
-        filters = await Threads.handle_event(
-            ctx,
-            "read",
-            Auth.types.ThreadsRead(
-                thread_id=new_thread_id,
-            ),
-        )
-        async with conn.pipeline():
-            # Find the original thread in our store
-            original_thread = next(
-                (t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
-            )
-
-            if not original_thread:
-                return _empty_generator()
-            if filters and not _check_filter_match(
-                original_thread["metadata"], filters
-            ):
-                return _empty_generator()
-
-            # Create new thread with copied metadata
-            new_thread: Thread = {
-                "thread_id": new_thread_id,
-                "created_at": datetime.now(tz=UTC),
-                "updated_at": datetime.now(tz=UTC),
-                "metadata": deepcopy(original_thread["metadata"]),
-                "status": "idle",
-                "config": {},
-            }
-
-            # Add new thread to store
-            conn.store["threads"].append(new_thread)
-
-            checkpointer = Checkpointer(conn)
-            copied_storage = _replace_thread_id(
-                checkpointer.storage[str(thread_id)], new_thread_id, thread_id
-            )
-            checkpointer.storage[str(new_thread_id)] = copied_storage
-            # Copy the writes over (if any)
-            outer_keys = []
-            for k in checkpointer.writes:
-                if k[0] == str(thread_id):
-                    outer_keys.append(k)
-            for tid, checkpoint_ns, checkpoint_id in outer_keys:
-                mapped = {
-                    k: _replace_thread_id(v, new_thread_id, thread_id)
-                    for k, v in checkpointer.writes[
-                        (str(tid), checkpoint_ns, checkpoint_id)
-                    ].items()
-                }
-
-                checkpointer.writes[
-                    (str(new_thread_id), checkpoint_ns, checkpoint_id)
-                ] = mapped
-            # Copy the blobs
-            for k in list(checkpointer.blobs):
-                if str(k[0]) == str(thread_id):
-                    new_key = (str(new_thread_id), *k[1:])
-                    checkpointer.blobs[new_key] = checkpointer.blobs[k]
-
-        async def row_generator() -> AsyncIterator[Thread]:
-            yield new_thread
-
-        return row_generator()
-
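# Annotation — a sketch, not from the deleted file, of the key rewrite done by
# Threads.copy above: checkpoint writes and blobs are keyed by
# (thread_id, checkpoint_ns, checkpoint_id) tuples, so copying a thread means
# duplicating every entry under a key whose first element is the new id.
writes = {("t-old", "", "ckpt-1"): {"task": b"payload"}}
copied = {("t-new", *k[1:]): v for k, v in writes.items() if k[0] == "t-old"}
assert ("t-new", "", "ckpt-1") in copied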
-    @staticmethod
-    async def sweep_ttl(
-        conn: InMemConnectionProto,
-        *,
-        limit: int | None = None,
-        batch_size: int = 100,
-    ) -> tuple[int, int]:
-        # Not implemented for inmem server
-        return (0, 0)
-
-    class State(Authenticated):
-        # We will treat this like a runs resource for now.
-        resource = "threads"
-
-        @staticmethod
-        async def get(
-            conn: InMemConnectionProto,
-            config: Config,
-            subgraphs: bool = False,
-            ctx: Auth.types.BaseAuthContext | None = None,
-        ) -> StateSnapshot:
-            """Get state for a thread."""
-            checkpointer = await asyncio.to_thread(
-                Checkpointer, conn, unpack_hook=_msgpack_ext_hook_to_json
-            )
-            thread_id = _ensure_uuid(config["configurable"]["thread_id"])
-            # Auth will be applied here so no need to use filters downstream
-            thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
-            thread = await anext(thread_iter)
-            checkpoint = await checkpointer.aget(config)
-
-            if not thread:
-                return StateSnapshot(
-                    values={},
-                    next=[],
-                    config=None,
-                    metadata=None,
-                    created_at=None,
-                    parent_config=None,
-                    tasks=tuple(),
-                )
-
-            metadata = thread.get("metadata", {})
-            thread_config = thread.get("config", {})
-
-            if graph_id := metadata.get("graph_id"):
-                # format latest checkpoint for response
-                checkpointer.latest_iter = checkpoint
-                async with get_graph(
-                    graph_id, thread_config, checkpointer=checkpointer
-                ) as graph:
-                    result = await graph.aget_state(config, subgraphs=subgraphs)
-                    if (
-                        result.metadata is not None
-                        and "checkpoint_ns" in result.metadata
-                        and result.metadata["checkpoint_ns"] == ""
-                    ):
-                        result.metadata.pop("checkpoint_ns")
-                    return result
-            else:
-                return StateSnapshot(
-                    values={},
-                    next=[],
-                    config=None,
-                    metadata=None,
-                    created_at=None,
-                    parent_config=None,
-                    tasks=tuple(),
-                )
-
-        @staticmethod
-        async def post(
-            conn: InMemConnectionProto,
-            config: Config,
-            values: Sequence[dict] | dict[str, Any] | None,
-            as_node: str | None = None,
-            ctx: Auth.types.BaseAuthContext | None = None,
-        ) -> ThreadUpdateResponse:
-            """Add state to a thread."""
-            thread_id = _ensure_uuid(config["configurable"]["thread_id"])
-            filters = await Threads.handle_event(
-                ctx,
-                "update",
-                Auth.types.ThreadsUpdate(thread_id=thread_id),
-            )
-
-            checkpointer = Checkpointer(conn)
-
-            thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
-            thread = await fetchone(
-                thread_iter, not_found_detail=f"Thread {thread_id} not found."
-            )
-            checkpoint = await checkpointer.aget(config)
-
-            if not thread:
-                raise HTTPException(status_code=404, detail="Thread not found")
-            if not _check_filter_match(thread["metadata"], filters):
-                raise HTTPException(status_code=403, detail="Forbidden")
-
-            metadata = thread["metadata"]
-            thread_config = thread["config"]
-
-            if graph_id := metadata.get("graph_id"):
-                config["configurable"].setdefault("graph_id", graph_id)
-
-                checkpointer.latest_iter = checkpoint
-                async with get_graph(
-                    graph_id, thread_config, checkpointer=checkpointer
-                ) as graph:
-                    update_config = config.copy()
-                    update_config["configurable"] = {
-                        **config["configurable"],
-                        "checkpoint_ns": config["configurable"].get(
-                            "checkpoint_ns", ""
-                        ),
-                    }
-                    next_config = await graph.aupdate_state(
-                        update_config, values, as_node=as_node
-                    )
-
-                # Get current state
-                state = await Threads.State.get(
-                    conn, config, subgraphs=False, ctx=ctx
-                )
-                # Update thread values
-                for thread in conn.store["threads"]:
-                    if thread["thread_id"] == thread_id:
-                        thread["values"] = state.values
-                        break
-
-                return ThreadUpdateResponse(
-                    checkpoint=next_config["configurable"],
-                    # Including deprecated fields
-                    configurable=next_config["configurable"],
-                    checkpoint_id=next_config["configurable"]["checkpoint_id"],
-                )
-            else:
-                raise HTTPException(status_code=400, detail="Thread has no graph ID.")
-
-        @staticmethod
-        async def bulk(
-            conn: InMemConnectionProto,
-            *,
-            config: Config,
-            supersteps: Sequence[dict],
-            ctx: Auth.types.BaseAuthContext | None = None,
-        ) -> ThreadUpdateResponse:
-            """Update a thread with a batch of state updates."""
-
-            from langgraph.pregel.types import StateUpdate
-
-            thread_id = _ensure_uuid(config["configurable"]["thread_id"])
-            filters = await Threads.handle_event(
-                ctx,
-                "update",
-                Auth.types.ThreadsUpdate(thread_id=thread_id),
-            )
-
-            thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
-            thread = await fetchone(
-                thread_iter, not_found_detail=f"Thread {thread_id} not found."
-            )
-
-            thread_config = thread["config"]
-            metadata = thread["metadata"]
-
-            if not thread:
-                raise HTTPException(status_code=404, detail="Thread not found")
-
-            if not _check_filter_match(metadata, filters):
-                raise HTTPException(status_code=403, detail="Forbidden")
-
-            if graph_id := metadata.get("graph_id"):
-                config["configurable"].setdefault("graph_id", graph_id)
-                config["configurable"].setdefault("checkpoint_ns", "")
-
-                async with get_graph(
-                    graph_id, thread_config, checkpointer=Checkpointer(conn)
-                ) as graph:
-                    next_config = await graph.abulk_update_state(
-                        config,
-                        [
-                            [
-                                StateUpdate(
-                                    map_cmd(update.get("command"))
-                                    if update.get("command")
-                                    else update.get("values"),
-                                    update.get("as_node"),
-                                )
-                                for update in superstep.get("updates", [])
-                            ]
-                            for superstep in supersteps
-                        ],
-                    )
-
-                state = await Threads.State.get(
-                    conn, config, subgraphs=False, ctx=ctx
-                )
-
-                # update thread values
-                for thread in conn.store["threads"]:
-                    if thread["thread_id"] == thread_id:
-                        thread["values"] = state.values
-                        break
-
-                return ThreadUpdateResponse(
-                    checkpoint=next_config["configurable"],
-                )
-            else:
-                raise HTTPException(status_code=400, detail="Thread has no graph ID")
-
-        @staticmethod
-        async def list(
-            conn: InMemConnectionProto,
-            *,
-            config: Config,
-            limit: int = 10,
-            before: str | Checkpoint | None = None,
-            metadata: MetadataInput = None,
-            ctx: Auth.types.BaseAuthContext | None = None,
-        ) -> list[StateSnapshot]:
-            """Get the history of a thread."""
-
-            thread_id = _ensure_uuid(config["configurable"]["thread_id"])
-            thread = None
-            filters = await Threads.handle_event(
-                ctx,
-                "read",
-                Auth.types.ThreadsRead(thread_id=thread_id),
-            )
-            thread = await fetchone(
-                await Threads.get(conn, config["configurable"]["thread_id"], ctx=ctx)
-            )
-
-            # Parse thread metadata and config
-            thread_metadata = thread["metadata"]
-            if not _check_filter_match(thread_metadata, filters):
-                return []
-
-            thread_config = thread["config"]
-            # If graph_id exists, get state history
-            if graph_id := thread_metadata.get("graph_id"):
-                async with get_graph(
-                    graph_id,
-                    thread_config,
-                    checkpointer=await asyncio.to_thread(
-                        Checkpointer, conn, unpack_hook=_msgpack_ext_hook_to_json
-                    ),
-                ) as graph:
-                    # Convert before parameter if it's a string
-                    before_param = (
-                        {"configurable": {"checkpoint_id": before}}
-                        if isinstance(before, str)
-                        else before
-                    )
-
-                    states = [
-                        state
-                        async for state in graph.aget_state_history(
-                            config, limit=limit, filter=metadata, before=before_param
-                        )
-                    ]
-
-                return states
-
-            return []
-
-
-RUN_LOCK = asyncio.Lock()
-
-
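# Annotation — context, not from the deleted file: RUN_LOCK serializes queue
# polling so concurrent workers cannot claim the same pending run; Runs.next
# below acquires it together with the store connection. A minimal model:
import asyncio

LOCK = asyncio.Lock()
runs = [{"run_id": 1, "status": "pending"}]

async def claim() -> dict | None:
    async with LOCK:  # only one worker inspects/mutates the queue at a time
        for run in runs:
            if run["status"] == "pending":
                run["status"] = "running"
                return run
    return None

asyncio.run(claim())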
-class Runs(Authenticated):
-    resource = "threads"
-
-    @staticmethod
-    async def stats(conn: InMemConnectionProto) -> QueueStats:
-        """Get stats about the queue."""
-        pending_runs = [run for run in conn.store["runs"] if run["status"] == "pending"]
-        running_runs = [run for run in conn.store["runs"] if run["status"] == "running"]
-
-        if not pending_runs and not running_runs:
-            return {
-                "n_pending": 0,
-                "max_age_secs": None,
-                "med_age_secs": None,
-                "n_running": 0,
-            }
-
-        # Get all creation timestamps
-        created_times = [run.get("created_at") for run in (pending_runs + running_runs)]
-        created_times = [
-            t for t in created_times if t is not None
-        ]  # Filter out None values
-
-        if not created_times:
-            return {
-                "n_pending": len(pending_runs),
-                "n_running": len(running_runs),
-                "max_age_secs": None,
-                "med_age_secs": None,
-            }
-
-        # Find oldest (max age)
-        oldest_time = min(created_times)  # Earliest timestamp = oldest run
-
-        # Find median age
-        sorted_times = sorted(created_times)
-        median_idx = len(sorted_times) // 2
-        median_time = sorted_times[median_idx]
-
-        return {
-            "n_pending": len(pending_runs),
-            "n_running": len(running_runs),
-            "max_age_secs": oldest_time,
-            "med_age_secs": median_time,
-        }
-
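# Annotation — an observation, not from the deleted file: stats() above fills
# "max_age_secs" and "med_age_secs" with raw datetimes (the minimum and median
# created_at timestamps), not ages. Converting a timestamp into an age in
# seconds would look like this sketch:
from datetime import UTC, datetime, timedelta

oldest_time = datetime.now(UTC) - timedelta(seconds=42)
max_age_secs = (datetime.now(UTC) - oldest_time).total_seconds()  # ~42.0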
-    @staticmethod
-    async def next(wait: bool, limit: int = 1) -> AsyncIterator[tuple[Run, int]]:
-        """Get the next run from the queue, and the attempt number.
-        1 is the first attempt, 2 is the first retry, etc."""
-        now = datetime.now(UTC)
-
-        if wait:
-            await asyncio.sleep(0.5)
-        else:
-            await asyncio.sleep(0)
-
-        async with connect() as conn, RUN_LOCK:
-            pending_runs = sorted(
-                [
-                    run
-                    for run in conn.store["runs"]
-                    if run["status"] == "pending" and run.get("created_at", now) < now
-                ],
-                key=lambda x: x.get("created_at", datetime.min),
-            )
-
-            if not pending_runs:
-                return
-
-            # Try to lock and get the first available run
-            for _, run in zip(range(limit), pending_runs, strict=False):
-                if run["status"] != "pending":
-                    continue
-
-                run_id = run["run_id"]
-                thread_id = run["thread_id"]
-                thread = next(
-                    (t for t in conn.store["threads"] if t["thread_id"] == thread_id),
-                    None,
-                )
-
-                if thread is None:
-                    await logger.awarning(
-                        "Unexpected missing thread in Runs.next",
-                        thread_id=run["thread_id"],
-                    )
-                    continue
-
-                if run["status"] != "pending":
-                    continue
-
-                if any(
-                    run["status"] == "running"
-                    for run in conn.store["runs"]
-                    if run["thread_id"] == thread_id
-                ):
-                    continue
-                # Increment attempt counter
-                attempt = await conn.retry_counter.increment(run_id)
-                # Set run as "running"
-                run["status"] = "running"
-                yield run, attempt
-
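# Annotation — a hedged usage sketch, not from the deleted file: a queue
# worker would drain Runs.next in a loop, roughly as below. The real worker
# wiring lives in langgraph_api/worker.py and may differ.
async def worker_loop() -> None:
    while True:
        async for run, attempt in Runs.next(wait=True, limit=1):
            ...  # execute the run, using `attempt` for retry/backoff decisions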
-    @asynccontextmanager
-    @staticmethod
-    async def enter(
-        run_id: UUID, loop: asyncio.AbstractEventLoop
-    ) -> AsyncIterator[ValueEvent]:
-        """Enter a run, listen for cancellation while running, signal when done.
-        This method should be called as a context manager by a worker executing a run.
-        """
-        stream_manager = get_stream_manager()
-        # Get queue for this run
-        queue = await Runs.Stream.subscribe(run_id)
-
-        async with SimpleTaskGroup(cancel=True, taskgroup_name="Runs.enter") as tg:
-            done = ValueEvent()
-            tg.create_task(listen_for_cancellation(queue, run_id, done))
-
-            # Give done event to caller
-            yield done
-        # Signal done to all subscribers
-        control_message = Message(
-            topic=f"run:{run_id}:control".encode(), data=b"done"
-        )
-
-        # Store the control message for late subscribers
-        await stream_manager.put(run_id, control_message)
-        stream_manager.control_queues[run_id].append(control_message)
-        # Clean up this queue
-        await stream_manager.remove_queue(run_id, queue)
-
-    @staticmethod
-    async def sweep(conn: InMemConnectionProto) -> list[UUID]:
-        """Sweep runs that are no longer running"""
-        return []
-
-    @staticmethod
-    def _merge_jsonb(*objects: dict) -> dict:
-        """Mimics PostgreSQL's JSONB merge behavior"""
-        result = {}
-        for obj in objects:
-            if obj is not None:
-                result.update(copy.deepcopy(obj))
-        return result
-
-    @staticmethod
-    def _get_configurable(config: dict) -> dict:
-        """Extract configurable from config, mimicking PostgreSQL's coalesce"""
-        return config.get("configurable", {})
-
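# Annotation — a worked example, not from the deleted file, of the shallow,
# last-writer-wins merge implemented by _merge_jsonb above; like Postgres'
# || operator, nested objects are replaced wholesale rather than deep-merged,
# which is why the code below merges the "configurable" sub-dicts explicitly.
merged = Runs._merge_jsonb({"configurable": {"x": 1}}, {"configurable": {"y": 2}})
assert merged == {"configurable": {"y": 2}}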
|
|
1401
|
-
@staticmethod
|
|
1402
|
-
async def put(
|
|
1403
|
-
conn: InMemConnectionProto,
|
|
1404
|
-
assistant_id: UUID,
|
|
1405
|
-
kwargs: dict,
|
|
1406
|
-
*,
|
|
1407
|
-
thread_id: UUID | None = None,
|
|
1408
|
-
user_id: str | None = None,
|
|
1409
|
-
run_id: UUID | None = None,
|
|
1410
|
-
status: RunStatus | None = "pending",
|
|
1411
|
-
metadata: MetadataInput,
|
|
1412
|
-
prevent_insert_if_inflight: bool,
|
|
1413
|
-
multitask_strategy: MultitaskStrategy = "reject",
|
|
1414
|
-
        if_not_exists: IfNotExists = "reject",
        after_seconds: int = 0,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Run]:
        """Create a run."""
        assistant_id = _ensure_uuid(assistant_id)
        assistant = next(
            (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
            None,
        )

        if not assistant:
            return _empty_generator()

        thread_id = _ensure_uuid(thread_id) if thread_id else None
        run_id = _ensure_uuid(run_id) if run_id else None
        metadata = metadata if metadata is not None else {}
        config = kwargs.get("config", {})

        # Handle thread creation/update
        existing_thread = next(
            (t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
        )
        filters = await Runs.handle_event(
            ctx,
            "create_run",
            Auth.types.RunsCreate(
                thread_id=thread_id,
                assistant_id=assistant_id,
                run_id=run_id,
                status=status,
                metadata=metadata,
                prevent_insert_if_inflight=prevent_insert_if_inflight,
                multitask_strategy=multitask_strategy,
                if_not_exists=if_not_exists,
                after_seconds=after_seconds,
                kwargs=kwargs,
            ),
        )
        if existing_thread and filters:
            # Reject if the user doesn't own the thread
            if not _check_filter_match(existing_thread["metadata"], filters):
                return _empty_generator()

        if not existing_thread and (thread_id is None or if_not_exists == "create"):
            # Create a new thread
            if thread_id is None:
                thread_id = uuid4()
            thread = Thread(
                thread_id=thread_id,
                status="busy",
                metadata={
                    "graph_id": assistant["graph_id"],
                    "assistant_id": str(assistant_id),
                    **metadata,
                },
                config=Runs._merge_jsonb(
                    assistant["config"],
                    config,
                    {
                        "configurable": Runs._merge_jsonb(
                            Runs._get_configurable(assistant["config"]),
                            Runs._get_configurable(config),
                        )
                    },
                ),
                created_at=datetime.now(UTC),
                updated_at=datetime.now(UTC),
                values=b"",
            )
            await logger.ainfo("Creating thread", thread_id=thread_id)
            conn.store["threads"].append(thread)
        elif existing_thread:
            # Update the existing thread
            if existing_thread["status"] != "busy":
                existing_thread["status"] = "busy"
                existing_thread["metadata"] = Runs._merge_jsonb(
                    existing_thread["metadata"],
                    {
                        "graph_id": assistant["graph_id"],
                        "assistant_id": str(assistant_id),
                    },
                )
                existing_thread["config"] = Runs._merge_jsonb(
                    assistant["config"],
                    existing_thread["config"],
                    config,
                    {
                        "configurable": Runs._merge_jsonb(
                            Runs._get_configurable(assistant["config"]),
                            Runs._get_configurable(existing_thread["config"]),
                            Runs._get_configurable(config),
                        )
                    },
                )
                existing_thread["updated_at"] = datetime.now(UTC)
        else:
            return _empty_generator()

        # Check for inflight runs if needed
        inflight_runs = [
            r
            for r in conn.store["runs"]
            if r["thread_id"] == thread_id and r["status"] in ("pending", "running")
        ]
        if prevent_insert_if_inflight and inflight_runs:

            async def _return_inflight():
                for run in inflight_runs:
                    yield run

            return _return_inflight()

        # Create the new run
        configurable = Runs._merge_jsonb(
            Runs._get_configurable(assistant["config"]),
            (
                Runs._get_configurable(existing_thread["config"])
                if existing_thread
                else {}
            ),
            Runs._get_configurable(config),
            {
                "run_id": str(run_id),
                "thread_id": str(thread_id),
                "graph_id": assistant["graph_id"],
                "assistant_id": str(assistant_id),
                "user_id": (
                    config.get("configurable", {}).get("user_id")
                    or (
                        existing_thread["config"].get("configurable", {}).get("user_id")
                        if existing_thread
                        else None
                    )
                    or assistant["config"].get("configurable", {}).get("user_id")
                    or user_id
                ),
            },
        )
        merged_metadata = Runs._merge_jsonb(
            assistant["metadata"],
            existing_thread["metadata"] if existing_thread else {},
            metadata,
        )
        new_run = Run(
            run_id=run_id,
            thread_id=thread_id,
            assistant_id=assistant_id,
            metadata=merged_metadata,
            status=status,
            kwargs=Runs._merge_jsonb(
                kwargs,
                {
                    "config": Runs._merge_jsonb(
                        assistant["config"],
                        config,
                        {"configurable": configurable},
                        {"metadata": merged_metadata},
                    )
                },
            ),
            multitask_strategy=multitask_strategy,
            created_at=datetime.now(UTC) + timedelta(seconds=after_seconds),
            updated_at=datetime.now(UTC),
        )
        conn.store["runs"].append(new_run)

        async def _yield_new():
            yield new_run
            for r in inflight_runs:
                yield r

        return _yield_new()

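    # Illustrative sketch (not part of the original file): the creation method
    # above (its `def` line falls outside this hunk) returns an async iterator
    # rather than a single Run. On success the first item is the newly
    # inserted run, followed by any runs already "pending" or "running" on the
    # same thread; with prevent_insert_if_inflight=True and inflight runs
    # present, nothing is inserted and only the inflight runs are yielded; an
    # empty iterator means the assistant or thread was missing or
    # unauthorized. A caller might drain it like so:
    #
    #     created_or_inflight = [run async for run in runs_iter]
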
    @staticmethod
    async def get(
        conn: InMemConnectionProto,
        run_id: UUID,
        *,
        thread_id: UUID,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Run]:
        """Get a run by ID."""
        run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
        filters = await Runs.handle_event(
            ctx,
            "read",
            Auth.types.ThreadsRead(thread_id=thread_id),
        )

        async def _yield_result():
            matching_run = None
            for run in conn.store["runs"]:
                if run["run_id"] == run_id and run["thread_id"] == thread_id:
                    matching_run = run
                    break
            if matching_run:
                if filters:
                    thread = await Threads._get_with_filters(
                        conn, matching_run["thread_id"], filters
                    )
                    if not thread:
                        return
                yield matching_run

        return _yield_result()

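    # Illustrative sketch (not part of the original file): since `Runs.get`
    # yields at most one run, callers can collapse the iterator with `anext`,
    # exactly as `Runs.Stream.join` does further below:
    #
    #     run_iter = await Runs.get(conn, run_id, thread_id=thread_id, ctx=ctx)
    #     run = await anext(run_iter, None)  # None if missing or unauthorized
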
    @staticmethod
    async def delete(
        conn: InMemConnectionProto,
        run_id: UUID,
        *,
        thread_id: UUID,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[UUID]:
        """Delete a run by ID."""
        run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
        filters = await Runs.handle_event(
            ctx,
            "delete",
            Auth.types.ThreadsDelete(run_id=run_id, thread_id=thread_id),
        )

        if filters:
            thread = await Threads._get_with_filters(conn, thread_id, filters)
            if not thread:
                return _empty_generator()
        _delete_checkpoints_for_thread(thread_id, conn, run_id=run_id)
        found = False
        for i, run in enumerate(conn.store["runs"]):
            if run["run_id"] == run_id and run["thread_id"] == thread_id:
                del conn.store["runs"][i]
                found = True
                break
        if not found:
            raise HTTPException(status_code=404, detail="Run not found")

        async def _yield_deleted():
            await logger.ainfo("Run deleted", run_id=run_id)
            yield run_id

        return _yield_deleted()

    @staticmethod
    async def join(
        run_id: UUID,
        *,
        thread_id: UUID,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> Fragment:
        """Wait for a run to complete. If already done, return immediately.

        Returns:
            The final state of the run.
        """
        async with connect() as conn:
            # Validate ownership
            thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
            await fetchone(thread_iter)
        last_chunk: bytes | None = None
        # Wait for the run to complete, relying on this join's auth
        async for mode, chunk in Runs.Stream.join(
            run_id, thread_id=thread_id, stream_mode="values", ctx=ctx, ignore_404=True
        ):
            if mode == b"values":
                last_chunk = chunk
        if last_chunk is not None:
            # i.e. the run completed while we were waiting for it
            return Fragment(last_chunk)
        else:
            # Otherwise the run had already finished, so fetch the state from the thread
            async with connect() as conn:
                thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
                thread = await fetchone(thread_iter)
                return thread["values"]

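    # Illustrative sketch (not part of the original file): `Runs.join` blocks
    # until the run finishes and returns only the final "values" payload, so a
    # caller that wants the end state without streaming could simply do:
    #
    #     values = await Runs.join(run_id, thread_id=thread_id, ctx=ctx)
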
    @staticmethod
    async def cancel(
        conn: InMemConnectionProto,
        run_ids: Sequence[UUID] | None = None,
        *,
        action: Literal["interrupt", "rollback"] = "interrupt",
        thread_id: UUID | None = None,
        status: Literal["pending", "running", "all"] | None = None,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> None:
        """Cancel runs in memory. Must provide either:

        1) thread_id + run_ids, or
        2) status in {"pending", "running", "all"}.

        Steps:
        - Validate arguments (one usage pattern or the other).
        - Auth check: 'update' event via handle_event().
        - Gather runs matching either the (thread_id, run_ids) set or the given status.
        - For each run found:
          * Send a cancellation message through the stream manager.
          * If 'pending', set to 'interrupted' or delete (if action='rollback'
            and not actively queued).
          * If 'running', the worker will pick up the message.
          * Otherwise, log a warning for non-cancelable states.
        - 404 if no runs are found or authorized.
        """
        # 1. Validate arguments
        if status is not None:
            # If status is set, the caller must NOT specify thread_id or run_ids
            if thread_id is not None or run_ids is not None:
                raise HTTPException(
                    status_code=422,
                    detail="Cannot specify 'thread_id' or 'run_ids' when using 'status'",
                )
        else:
            # If status is not set, the caller must specify both thread_id and run_ids
            if thread_id is None or run_ids is None:
                raise HTTPException(
                    status_code=422,
                    detail="Must provide either a status or both 'thread_id' and 'run_ids'",
                )

        # Convert and normalize inputs
        if run_ids is not None:
            run_ids = [_ensure_uuid(rid) for rid in run_ids]
        if thread_id is not None:
            thread_id = _ensure_uuid(thread_id)

        filters = await Runs.handle_event(
            ctx,
            "update",
            Auth.types.ThreadsUpdate(
                thread_id=thread_id,  # type: ignore
                action=action,
                metadata={
                    "run_ids": run_ids,
                    "status": status,
                },
            ),
        )

        status_list: tuple[str, ...] = ()
        if status is not None:
            if status == "all":
                status_list = ("pending", "running")
            elif status in ("pending", "running"):
                status_list = (status,)
            else:
                raise ValueError(f"Unsupported status: {status}")

        def is_run_match(r: dict) -> bool:
            """Check whether a run in `conn.store["runs"]` meets the selection criteria."""
            if status_list:
                return r["status"] in status_list
            else:
                return r["thread_id"] == thread_id and r["run_id"] in run_ids  # type: ignore

        candidate_runs = [r for r in conn.store["runs"] if is_run_match(r)]

        if filters:
            # If a run is found but not authorized by the thread filters, skip it
            thread = (
                await Threads._get_with_filters(conn, thread_id, filters)
                if thread_id
                else None
            )
            # If there's no matching thread, no runs are authorized.
            if thread_id and not thread:
                candidate_runs = []
            # Otherwise, trust that `_get_with_filters` is the only constraint on
            # the thread. If the filters also apply to runs, more checks belong here.

        if not candidate_runs:
            raise HTTPException(status_code=404, detail="No runs found to cancel.")

        stream_manager = get_stream_manager()
        coros = []
        for run in candidate_runs:
            run_id = run["run_id"]
            control_message = Message(
                topic=f"run:{run_id}:control".encode(),
                data=action.encode(),
            )
            coros.append(stream_manager.put(run_id, control_message))

            queues = stream_manager.get_queues(run_id)

            if run["status"] in ("pending", "running"):
                if queues or action != "rollback":
                    run["status"] = "interrupted"
                    run["updated_at"] = datetime.now(tz=UTC)
                else:
                    await logger.ainfo(
                        "Eagerly deleting pending run with rollback action",
                        run_id=str(run_id),
                        status=run["status"],
                    )
                    coros.append(Runs.delete(conn, run_id, thread_id=run["thread_id"]))
            else:
                await logger.awarning(
                    "Attempted to cancel non-pending run.",
                    run_id=str(run_id),
                    status=run["status"],
                )

        if coros:
            await asyncio.gather(*coros)

        await logger.ainfo(
            "Cancelled runs",
            run_ids=[str(r["run_id"]) for r in candidate_runs],
            thread_id=str(thread_id) if thread_id else None,
            status=status,
            action=action,
        )

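    # Illustrative sketch (not part of the original file): the two mutually
    # exclusive calling conventions of `Runs.cancel` documented above look
    # like this in practice:
    #
    #     # 1) cancel specific runs on one thread
    #     await Runs.cancel(conn, [run_id], thread_id=thread_id)
    #     # 2) roll back every queued or executing run, regardless of thread
    #     await Runs.cancel(conn, action="rollback", status="all")
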
    @staticmethod
    async def search(
        conn: InMemConnectionProto,
        thread_id: UUID,
        *,
        limit: int = 10,
        offset: int = 0,
        metadata: MetadataInput,
        status: RunStatus | None = None,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Run]:
        """List all runs by thread."""
        runs = conn.store["runs"]
        metadata = metadata if metadata is not None else {}
        thread_id = _ensure_uuid(thread_id)
        filters = await Runs.handle_event(
            ctx,
            "search",
            Auth.types.ThreadsSearch(thread_id=thread_id, metadata=metadata),
        )
        filtered_runs = [
            run
            for run in runs
            if run["thread_id"] == thread_id
            and is_jsonb_contained(run["metadata"], metadata)
            and (
                not filters
                or (await Threads._get_with_filters(conn, thread_id, filters))
            )
            and (status is None or run["status"] == status)
        ]
        sorted_runs = sorted(filtered_runs, key=lambda x: x["created_at"], reverse=True)
        sliced_runs = sorted_runs[offset : offset + limit]

        async def _return():
            for run in sliced_runs:
                yield run

        return _return()

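    # Illustrative sketch (not part of the original file): `Runs.search`
    # returns results newest-first, so fetching the second page of pending
    # runs on a thread might look like:
    #
    #     runs_iter = await Runs.search(
    #         conn, thread_id, limit=10, offset=10, metadata={}, status="pending"
    #     )
    #     page = [run async for run in runs_iter]
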
    @staticmethod
    async def set_status(
        conn: InMemConnectionProto, run_id: UUID, status: RunStatus
    ) -> Run | None:
        """Set the status of a run, returning the updated run (or None if missing)."""
        # Find the run in the store
        run_id = _ensure_uuid(run_id)
        run = next((run for run in conn.store["runs"] if run["run_id"] == run_id), None)

        if run:
            # Update the status and updated_at timestamp
            run["status"] = status
            run["updated_at"] = datetime.now(tz=UTC)
            return run
        return None

    class Stream:
        @staticmethod
        async def subscribe(
            run_id: UUID,
            *,
            stream_mode: "StreamMode | None" = None,
        ) -> asyncio.Queue:
            """Subscribe to the run stream, returning a queue."""
            stream_manager = get_stream_manager()
            queue = await stream_manager.add_queue(_ensure_uuid(run_id))

            # If there are control messages already stored, replay them to the
            # new subscriber
            if control_messages := stream_manager.control_queues.get(run_id):
                for control_msg in control_messages:
                    await queue.put(control_msg)
            return queue

        @staticmethod
        async def join(
            run_id: UUID,
            *,
            thread_id: UUID,
            ignore_404: bool = False,
            cancel_on_disconnect: bool = False,
            stream_mode: "StreamMode | asyncio.Queue | None" = None,
            ctx: Auth.types.BaseAuthContext | None = None,
        ) -> AsyncIterator[tuple[bytes, bytes]]:
            """Stream the run output."""
            log = logger.isEnabledFor(logging.DEBUG)
            queue = (
                stream_mode
                if isinstance(stream_mode, asyncio.Queue)
                else await Runs.Stream.subscribe(run_id, stream_mode=stream_mode)
            )

            try:
                async with connect() as conn:
                    filters = await Runs.handle_event(
                        ctx,
                        "read",
                        Auth.types.ThreadsRead(thread_id=thread_id),
                    )
                    if filters:
                        thread = await Threads._get_with_filters(
                            cast(InMemConnectionProto, conn), thread_id, filters
                        )
                        if not thread:
                            raise WrappedHTTPException(
                                HTTPException(status_code=404, detail="Thread not found")
                            )
                    channel_prefix = f"run:{run_id}:stream:"
                    len_prefix = len(channel_prefix.encode())

                    while True:
                        try:
                            # Wait for messages with a timeout
                            message = await asyncio.wait_for(queue.get(), timeout=0.5)
                            topic, data = message.topic, message.data

                            if topic.decode() == f"run:{run_id}:control":
                                if data == b"done":
                                    break
                            else:
                                # Extract the stream mode from the topic
                                yield topic[len_prefix:], data
                                if log:
                                    await logger.adebug(
                                        "Streamed run event",
                                        run_id=str(run_id),
                                        stream_mode=topic[len_prefix:],
                                        data=data,
                                    )
                        except TimeoutError:
                            # Check if the run is still pending
                            run_iter = await Runs.get(
                                conn, run_id, thread_id=thread_id, ctx=ctx
                            )
                            run = await anext(run_iter, None)

                            if ignore_404 and run is None:
                                break
                            elif run is None:
                                yield (
                                    b"error",
                                    HTTPException(status_code=404, detail="Run not found"),
                                )
                                break
                            elif run["status"] not in ("pending", "running"):
                                break
            except WrappedHTTPException as e:
                raise e.http_exception from None
            except BaseException:
                if cancel_on_disconnect:
                    create_task(cancel_run(thread_id, run_id))
                raise
            finally:
                stream_manager = get_stream_manager()
                await stream_manager.remove_queue(run_id, queue)

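        # Illustrative sketch (not part of the original file): each pair
        # yielded by `join` is the stream mode (the topic with its
        # "run:<run_id>:stream:" prefix stripped) plus the raw payload bytes:
        #
        #     async for mode, data in Runs.Stream.join(
        #         run_id, thread_id=thread_id, stream_mode="values"
        #     ):
        #         if mode == b"values":
        #             handle(data)  # `handle` is a hypothetical callback
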
        @staticmethod
        async def publish(
            run_id: UUID,
            event: str,
            message: bytes,
        ) -> None:
            """Publish a message to all subscribers of the run stream."""
            topic = f"run:{run_id}:stream:{event}".encode()

            stream_manager = get_stream_manager()
            # Send to all queues subscribed to this run_id
            await stream_manager.put(run_id, Message(topic=topic, data=message))


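# Illustrative sketch (not part of the original file): `publish` and
# `subscribe` are paired purely through topic naming. A message published with
# event "values" lands on topic b"run:<run_id>:stream:values", and `join`
# strips the "run:<run_id>:stream:" prefix before yielding, so subscribers see
# the bare mode b"values":
#
#     await Runs.Stream.publish(run_id, "values", b'{"state": "..."}')

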
async def listen_for_cancellation(
    queue: asyncio.Queue, run_id: UUID, done: "ValueEvent"
):
    """Listen for cancellation messages and set the done event accordingly."""
    stream_manager = get_stream_manager()
    control_key = f"run:{run_id}:control"

    if existing_queue := stream_manager.control_queues.get(run_id):
        for message in existing_queue:
            payload = message.data
            if payload == b"rollback":
                done.set(UserRollback())
            elif payload == b"interrupt":
                done.set(UserInterrupt())

    while not done.is_set():
        try:
            # This task gets cancelled when Runs.enter exits anyway,
            # so we can afford a fairly lengthy timeout here
            message = await asyncio.wait_for(queue.get(), timeout=240)
            payload = message.data
            if payload == b"rollback":
                done.set(UserRollback())
            elif payload == b"interrupt":
                done.set(UserInterrupt())
            elif payload == b"done":
                done.set()
                break

            # Store control messages for late subscribers
            if message.topic.decode() == control_key:
                stream_manager.control_queues[run_id].append(message)
        except TimeoutError:
            break


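# Illustrative sketch (not part of the original file): a worker would
# typically pair `listen_for_cancellation` with a subscription queue and a
# ValueEvent, then watch `done` while the graph executes. A b"rollback" or
# b"interrupt" control message stores a UserRollback/UserInterrupt in `done`
# for the waiting side to act on:
#
#     queue = await Runs.Stream.subscribe(run_id)
#     done = ValueEvent()
#     task = create_task(listen_for_cancellation(queue, run_id, done))

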
class Crons:
    @staticmethod
    async def put(
        conn: InMemConnectionProto,
        *,
        payload: dict,
        schedule: str,
        cron_id: UUID | None = None,
        thread_id: UUID | None = None,
        end_time: datetime | None = None,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Cron]:
        raise NotImplementedError

    @staticmethod
    async def delete(
        conn: InMemConnectionProto,
        cron_id: UUID,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[UUID]:
        raise NotImplementedError

    @staticmethod
    async def next(
        conn: InMemConnectionProto,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Cron]:
        yield
        raise NotImplementedError("The in-mem server does not implement Crons.")

    @staticmethod
    async def set_next_run_date(
        conn: InMemConnectionProto,
        cron_id: UUID,
        next_run_date: datetime,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> None:
        raise NotImplementedError

    @staticmethod
    async def search(
        conn: InMemConnectionProto,
        *,
        assistant_id: UUID | None,
        thread_id: UUID | None,
        limit: int,
        offset: int,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Cron]:
        raise NotImplementedError


async def cancel_run(
    thread_id: UUID, run_id: UUID, ctx: Auth.types.BaseAuthContext | None = None
) -> None:
    async with connect() as conn:
        await Runs.cancel(conn, [run_id], thread_id=thread_id, ctx=ctx)


def _delete_checkpoints_for_thread(
    thread_id: str | UUID,
    conn: InMemConnectionProto,
    run_id: str | UUID | None = None,
):
    checkpointer = Checkpointer(conn)
    thread_id = str(thread_id)
    if thread_id not in checkpointer.storage:
        return
    if run_id:
        # Look through checkpoint metadata for this run's checkpoints
        run_id = str(run_id)
        for checkpoint_ns, checkpoints in list(checkpointer.storage[thread_id].items()):
            for checkpoint_id, (_, metadata_b, _) in list(checkpoints.items()):
                metadata = checkpointer.serde.loads_typed(metadata_b)
                if metadata.get("run_id") == run_id:
                    del checkpointer.storage[thread_id][checkpoint_ns][checkpoint_id]
            if not checkpointer.storage[thread_id][checkpoint_ns]:
                del checkpointer.storage[thread_id][checkpoint_ns]
    else:
        del checkpointer.storage[thread_id]
        # Writes are keyed by (thread_id, checkpoint_ns, checkpoint_id)
        checkpointer.writes = defaultdict(
            dict, {k: v for k, v in checkpointer.writes.items() if k[0] != thread_id}
        )


def _check_filter_match(metadata: dict, filters: Auth.types.FilterType | None) -> bool:
    """Check if metadata matches the filter conditions.

    Args:
        metadata: The metadata to check
        filters: The filter conditions to apply

    Returns:
        True if the metadata matches all filter conditions, False otherwise
    """
    if not filters:
        return True

    for key, value in filters.items():
        if isinstance(value, dict):
            op = next(iter(value))
            filter_value = value[op]

            if op == "$eq":
                if key not in metadata or metadata[key] != filter_value:
                    return False
            elif op == "$contains":
                if (
                    key not in metadata
                    or not isinstance(metadata[key], list)
                    or filter_value not in metadata[key]
                ):
                    return False
        else:
            # Direct equality
            if key not in metadata or metadata[key] != value:
                return False

    return True


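# Illustrative truth table for _check_filter_match (not part of the original
# file; the metadata and filter literals are made up):
#
#     _check_filter_match({"owner": "u1"}, None)                                # True
#     _check_filter_match({"owner": "u1"}, {"owner": "u1"})                     # True
#     _check_filter_match({"owner": "u1"}, {"owner": {"$eq": "u2"}})            # False
#     _check_filter_match({"tags": ["a", "b"]}, {"tags": {"$contains": "a"}})   # True

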
async def _empty_generator():
    # An async generator that never yields; used above to return an empty
    # result set while keeping the AsyncIterator return type.
    if False:
        yield


__all__ = [
    "Assistants",
    "Crons",
    "Runs",
    "Threads",
]