langgraph-api 0.0.8__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langgraph-api might be problematic. Click here for more details.
- langgraph_api/api/openapi.py +38 -1
- langgraph_api/api/runs.py +70 -36
- langgraph_api/auth/custom.py +520 -0
- langgraph_api/auth/middleware.py +8 -3
- langgraph_api/cli.py +47 -1
- langgraph_api/config.py +24 -0
- langgraph_api/cron_scheduler.py +32 -27
- langgraph_api/graph.py +0 -11
- langgraph_api/models/run.py +77 -19
- langgraph_api/route.py +2 -0
- langgraph_api/utils.py +32 -0
- {langgraph_api-0.0.8.dist-info → langgraph_api-0.0.9.dist-info}/METADATA +2 -2
- {langgraph_api-0.0.8.dist-info → langgraph_api-0.0.9.dist-info}/RECORD +19 -18
- langgraph_storage/checkpoint.py +16 -0
- langgraph_storage/database.py +17 -1
- langgraph_storage/ops.py +495 -69
- {langgraph_api-0.0.8.dist-info → langgraph_api-0.0.9.dist-info}/LICENSE +0 -0
- {langgraph_api-0.0.8.dist-info → langgraph_api-0.0.9.dist-info}/WHEEL +0 -0
- {langgraph_api-0.0.8.dist-info → langgraph_api-0.0.9.dist-info}/entry_points.txt +0 -0
langgraph_storage/ops.py
CHANGED
|
@@ -11,15 +11,17 @@ from collections.abc import AsyncIterator, Sequence
|
|
|
11
11
|
from contextlib import asynccontextmanager
|
|
12
12
|
from copy import deepcopy
|
|
13
13
|
from datetime import UTC, datetime, timedelta
|
|
14
|
-
from typing import Any, Literal
|
|
14
|
+
from typing import Any, Literal, cast
|
|
15
15
|
from uuid import UUID, uuid4
|
|
16
16
|
|
|
17
17
|
import structlog
|
|
18
18
|
from langgraph.pregel.debug import CheckpointPayload
|
|
19
19
|
from langgraph.pregel.types import StateSnapshot
|
|
20
|
+
from langgraph_sdk import Auth
|
|
20
21
|
from starlette.exceptions import HTTPException
|
|
21
22
|
|
|
22
23
|
from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent, create_task
|
|
24
|
+
from langgraph_api.auth.custom import handle_event
|
|
23
25
|
from langgraph_api.errors import UserInterrupt, UserRollback
|
|
24
26
|
from langgraph_api.graph import get_graph
|
|
25
27
|
from langgraph_api.schema import (
|
|
@@ -41,7 +43,7 @@ from langgraph_api.schema import (
|
|
|
41
43
|
ThreadUpdateResponse,
|
|
42
44
|
)
|
|
43
45
|
from langgraph_api.serde import Fragment
|
|
44
|
-
from langgraph_api.utils import fetchone
|
|
46
|
+
from langgraph_api.utils import fetchone, get_auth_ctx
|
|
45
47
|
from langgraph_storage.checkpoint import Checkpointer
|
|
46
48
|
from langgraph_storage.database import InMemConnectionProto, connect
|
|
47
49
|
from langgraph_storage.queue import Message, get_stream_manager
|
|
@@ -57,13 +59,51 @@ def _ensure_uuid(id_: str | uuid.UUID | None) -> uuid.UUID:
|
|
|
57
59
|
return id_
|
|
58
60
|
|
|
59
61
|
|
|
62
|
+
class WrappedHTTPException(Exception):
|
|
63
|
+
def __init__(self, http_exception: HTTPException):
|
|
64
|
+
self.http_exception = http_exception
|
|
65
|
+
|
|
66
|
+
|
|
60
67
|
# Right now the whole API types as UUID but frequently passes a str
|
|
61
68
|
# We ensure UUIDs for eveerything EXCEPT the checkpoint storage/writes,
|
|
62
69
|
# which we leave as strings. This is because I'm too lazy to subclass fully
|
|
63
70
|
# and we use non-UUID examples in the OSS version
|
|
64
71
|
|
|
65
72
|
|
|
66
|
-
class
|
|
73
|
+
class Authenticated:
|
|
74
|
+
resource: Literal["threads", "crons", "assistants"]
|
|
75
|
+
|
|
76
|
+
@classmethod
|
|
77
|
+
def _context(
|
|
78
|
+
cls,
|
|
79
|
+
ctx: Auth.types.BaseAuthContext | None,
|
|
80
|
+
action: Literal["create", "read", "update", "delete", "create_run"],
|
|
81
|
+
) -> Auth.types.AuthContext | None:
|
|
82
|
+
if not ctx:
|
|
83
|
+
return
|
|
84
|
+
return Auth.types.AuthContext(
|
|
85
|
+
user=ctx.user,
|
|
86
|
+
scopes=ctx.scopes,
|
|
87
|
+
resource=cls.resource,
|
|
88
|
+
action=action,
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
@classmethod
|
|
92
|
+
async def handle_event(
|
|
93
|
+
cls,
|
|
94
|
+
ctx: Auth.types.BaseAuthContext | None,
|
|
95
|
+
action: Literal["create", "read", "update", "delete", "search"],
|
|
96
|
+
value: Any,
|
|
97
|
+
) -> Auth.types.FilterType | None:
|
|
98
|
+
ctx = ctx or get_auth_ctx()
|
|
99
|
+
if not ctx:
|
|
100
|
+
return
|
|
101
|
+
return await handle_event(cls._context(ctx, action), value)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class Assistants(Authenticated):
|
|
105
|
+
resource = "assistants"
|
|
106
|
+
|
|
67
107
|
@staticmethod
|
|
68
108
|
async def search(
|
|
69
109
|
conn: InMemConnectionProto,
|
|
@@ -72,7 +112,17 @@ class Assistants:
|
|
|
72
112
|
metadata: MetadataInput,
|
|
73
113
|
limit: int,
|
|
74
114
|
offset: int,
|
|
115
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
75
116
|
) -> AsyncIterator[Assistant]:
|
|
117
|
+
metadata = metadata if metadata is not None else {}
|
|
118
|
+
filters = await Assistants.handle_event(
|
|
119
|
+
ctx,
|
|
120
|
+
"search",
|
|
121
|
+
Auth.types.AssistantsSearch(
|
|
122
|
+
graph_id=graph_id, metadata=metadata, limit=limit, offset=offset
|
|
123
|
+
),
|
|
124
|
+
)
|
|
125
|
+
|
|
76
126
|
async def filter_and_yield() -> AsyncIterator[Assistant]:
|
|
77
127
|
assistants = conn.store["assistants"]
|
|
78
128
|
filtered_assistants = [
|
|
@@ -82,6 +132,7 @@ class Assistants:
|
|
|
82
132
|
and (
|
|
83
133
|
not metadata or is_jsonb_contained(assistant["metadata"], metadata)
|
|
84
134
|
)
|
|
135
|
+
and (not filters or _check_filter_match(assistant["metadata"], filters))
|
|
85
136
|
]
|
|
86
137
|
filtered_assistants.sort(key=lambda x: x["created_at"], reverse=True)
|
|
87
138
|
for assistant in filtered_assistants[offset : offset + limit]:
|
|
@@ -91,14 +142,23 @@ class Assistants:
|
|
|
91
142
|
|
|
92
143
|
@staticmethod
|
|
93
144
|
async def get(
|
|
94
|
-
conn: InMemConnectionProto,
|
|
145
|
+
conn: InMemConnectionProto,
|
|
146
|
+
assistant_id: UUID,
|
|
147
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
95
148
|
) -> AsyncIterator[Assistant]:
|
|
96
149
|
"""Get an assistant by ID."""
|
|
97
150
|
assistant_id = _ensure_uuid(assistant_id)
|
|
151
|
+
filters = await Assistants.handle_event(
|
|
152
|
+
ctx,
|
|
153
|
+
"read",
|
|
154
|
+
Auth.types.AssistantsRead(assistant_id=assistant_id),
|
|
155
|
+
)
|
|
98
156
|
|
|
99
157
|
async def _yield_result():
|
|
100
158
|
for assistant in conn.store["assistants"]:
|
|
101
|
-
if assistant["assistant_id"] == assistant_id
|
|
159
|
+
if assistant["assistant_id"] == assistant_id and (
|
|
160
|
+
not filters or _check_filter_match(assistant["metadata"], filters)
|
|
161
|
+
):
|
|
102
162
|
yield assistant
|
|
103
163
|
|
|
104
164
|
return _yield_result()
|
|
@@ -113,14 +173,33 @@ class Assistants:
|
|
|
113
173
|
metadata: MetadataInput,
|
|
114
174
|
if_exists: OnConflictBehavior,
|
|
115
175
|
name: str,
|
|
176
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
116
177
|
) -> AsyncIterator[Assistant]:
|
|
117
178
|
"""Insert an assistant."""
|
|
118
179
|
assistant_id = _ensure_uuid(assistant_id)
|
|
180
|
+
metadata = metadata if metadata is not None else {}
|
|
181
|
+
filters = await Assistants.handle_event(
|
|
182
|
+
ctx,
|
|
183
|
+
"create",
|
|
184
|
+
Auth.types.AssistantsCreate(
|
|
185
|
+
assistant_id=assistant_id,
|
|
186
|
+
graph_id=graph_id,
|
|
187
|
+
config=config,
|
|
188
|
+
metadata=metadata,
|
|
189
|
+
name=name,
|
|
190
|
+
),
|
|
191
|
+
)
|
|
119
192
|
existing_assistant = next(
|
|
120
193
|
(a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
|
|
121
194
|
None,
|
|
122
195
|
)
|
|
123
196
|
if existing_assistant:
|
|
197
|
+
if filters and not _check_filter_match(
|
|
198
|
+
existing_assistant["metadata"], filters
|
|
199
|
+
):
|
|
200
|
+
raise HTTPException(
|
|
201
|
+
status_code=409, detail=f"Assistant {assistant_id} already exists"
|
|
202
|
+
)
|
|
124
203
|
if if_exists == "raise":
|
|
125
204
|
raise HTTPException(
|
|
126
205
|
status_code=409, detail=f"Assistant {assistant_id} already exists"
|
|
@@ -168,6 +247,7 @@ class Assistants:
|
|
|
168
247
|
graph_id: str | None = None,
|
|
169
248
|
metadata: MetadataInput | None = None,
|
|
170
249
|
name: str | None = None,
|
|
250
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
171
251
|
) -> AsyncIterator[Assistant]:
|
|
172
252
|
"""Update an assistant.
|
|
173
253
|
|
|
@@ -182,6 +262,18 @@ class Assistants:
|
|
|
182
262
|
return the updated assistant model.
|
|
183
263
|
"""
|
|
184
264
|
assistant_id = _ensure_uuid(assistant_id)
|
|
265
|
+
metadata = metadata if metadata is not None else {}
|
|
266
|
+
filters = await Assistants.handle_event(
|
|
267
|
+
ctx,
|
|
268
|
+
"update",
|
|
269
|
+
Auth.types.AssistantsUpdate(
|
|
270
|
+
assistant_id=assistant_id,
|
|
271
|
+
graph_id=graph_id,
|
|
272
|
+
config=config,
|
|
273
|
+
metadata=metadata,
|
|
274
|
+
name=name,
|
|
275
|
+
),
|
|
276
|
+
)
|
|
185
277
|
assistant = next(
|
|
186
278
|
(a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
|
|
187
279
|
None,
|
|
@@ -190,6 +282,10 @@ class Assistants:
|
|
|
190
282
|
raise HTTPException(
|
|
191
283
|
status_code=404, detail=f"Assistant {assistant_id} not found"
|
|
192
284
|
)
|
|
285
|
+
elif filters and not _check_filter_match(assistant["metadata"], filters):
|
|
286
|
+
raise HTTPException(
|
|
287
|
+
status_code=404, detail=f"Assistant {assistant_id} not found"
|
|
288
|
+
)
|
|
193
289
|
|
|
194
290
|
now = datetime.now(UTC)
|
|
195
291
|
new_version = (
|
|
@@ -204,6 +300,11 @@ class Assistants:
|
|
|
204
300
|
)
|
|
205
301
|
|
|
206
302
|
# Update assistant_versions table
|
|
303
|
+
if metadata:
|
|
304
|
+
metadata = {
|
|
305
|
+
**assistant["metadata"],
|
|
306
|
+
**metadata,
|
|
307
|
+
}
|
|
207
308
|
new_version_entry = {
|
|
208
309
|
"assistant_id": assistant_id,
|
|
209
310
|
"version": new_version,
|
|
@@ -233,10 +334,33 @@ class Assistants:
|
|
|
233
334
|
|
|
234
335
|
@staticmethod
|
|
235
336
|
async def delete(
|
|
236
|
-
conn: InMemConnectionProto,
|
|
337
|
+
conn: InMemConnectionProto,
|
|
338
|
+
assistant_id: UUID,
|
|
339
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
237
340
|
) -> AsyncIterator[UUID]:
|
|
238
341
|
"""Delete an assistant by ID."""
|
|
239
342
|
assistant_id = _ensure_uuid(assistant_id)
|
|
343
|
+
filters = await Assistants.handle_event(
|
|
344
|
+
ctx,
|
|
345
|
+
"delete",
|
|
346
|
+
Auth.types.AssistantsDelete(
|
|
347
|
+
assistant_id=assistant_id,
|
|
348
|
+
),
|
|
349
|
+
)
|
|
350
|
+
assistant = next(
|
|
351
|
+
(a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
|
|
352
|
+
None,
|
|
353
|
+
)
|
|
354
|
+
|
|
355
|
+
if not assistant:
|
|
356
|
+
raise HTTPException(
|
|
357
|
+
status_code=404, detail=f"Assistant with ID {assistant_id} not found"
|
|
358
|
+
)
|
|
359
|
+
elif filters and not _check_filter_match(assistant["metadata"], filters):
|
|
360
|
+
raise HTTPException(
|
|
361
|
+
status_code=404, detail=f"Assistant with ID {assistant_id} not found"
|
|
362
|
+
)
|
|
363
|
+
|
|
240
364
|
conn.store["assistants"] = [
|
|
241
365
|
a for a in conn.store["assistants"] if a["assistant_id"] != assistant_id
|
|
242
366
|
]
|
|
@@ -249,9 +373,10 @@ class Assistants:
|
|
|
249
373
|
retained = []
|
|
250
374
|
for run in conn.store["runs"]:
|
|
251
375
|
if run["assistant_id"] == assistant_id:
|
|
252
|
-
res = await Runs.delete(
|
|
376
|
+
res = await Runs.delete(
|
|
377
|
+
conn, run["run_id"], thread_id=run["thread_id"], ctx=ctx
|
|
378
|
+
)
|
|
253
379
|
await anext(res)
|
|
254
|
-
|
|
255
380
|
else:
|
|
256
381
|
retained.append(run)
|
|
257
382
|
|
|
@@ -262,10 +387,21 @@ class Assistants:
|
|
|
262
387
|
|
|
263
388
|
@staticmethod
|
|
264
389
|
async def set_latest(
|
|
265
|
-
conn: InMemConnectionProto,
|
|
390
|
+
conn: InMemConnectionProto,
|
|
391
|
+
assistant_id: UUID,
|
|
392
|
+
version: int,
|
|
393
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
266
394
|
) -> AsyncIterator[Assistant]:
|
|
267
395
|
"""Change the version of an assistant."""
|
|
268
396
|
assistant_id = _ensure_uuid(assistant_id)
|
|
397
|
+
filters = await Assistants.handle_event(
|
|
398
|
+
ctx,
|
|
399
|
+
"update",
|
|
400
|
+
Auth.types.AssistantsUpdate(
|
|
401
|
+
assistant_id=assistant_id,
|
|
402
|
+
version=version,
|
|
403
|
+
),
|
|
404
|
+
)
|
|
269
405
|
assistant = next(
|
|
270
406
|
(a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
|
|
271
407
|
None,
|
|
@@ -274,6 +410,10 @@ class Assistants:
|
|
|
274
410
|
raise HTTPException(
|
|
275
411
|
status_code=404, detail=f"Assistant {assistant_id} not found"
|
|
276
412
|
)
|
|
413
|
+
elif filters and not _check_filter_match(assistant["metadata"], filters):
|
|
414
|
+
raise HTTPException(
|
|
415
|
+
status_code=404, detail=f"Assistant {assistant_id} not found"
|
|
416
|
+
)
|
|
277
417
|
|
|
278
418
|
version_data = next(
|
|
279
419
|
(
|
|
@@ -310,14 +450,21 @@ class Assistants:
|
|
|
310
450
|
metadata: MetadataInput,
|
|
311
451
|
limit: int,
|
|
312
452
|
offset: int,
|
|
453
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
313
454
|
) -> AsyncIterator[Assistant]:
|
|
314
455
|
"""Get all versions of an assistant."""
|
|
315
456
|
assistant_id = _ensure_uuid(assistant_id)
|
|
457
|
+
filters = await Assistants.handle_event(
|
|
458
|
+
ctx,
|
|
459
|
+
"read",
|
|
460
|
+
Auth.types.AssistantsRead(assistant_id=assistant_id),
|
|
461
|
+
)
|
|
316
462
|
versions = [
|
|
317
463
|
v
|
|
318
464
|
for v in conn.store["assistant_versions"]
|
|
319
465
|
if v["assistant_id"] == assistant_id
|
|
320
466
|
and (not metadata or is_jsonb_contained(v["metadata"], metadata))
|
|
467
|
+
and (not filters or _check_filter_match(v["metadata"], filters))
|
|
321
468
|
]
|
|
322
469
|
versions.sort(key=lambda x: x["version"], reverse=True)
|
|
323
470
|
|
|
@@ -379,7 +526,9 @@ def _replace_thread_id(data, new_thread_id, thread_id):
|
|
|
379
526
|
return d
|
|
380
527
|
|
|
381
528
|
|
|
382
|
-
class Threads:
|
|
529
|
+
class Threads(Authenticated):
|
|
530
|
+
resource = "threads"
|
|
531
|
+
|
|
383
532
|
@staticmethod
|
|
384
533
|
async def search(
|
|
385
534
|
conn: InMemConnectionProto,
|
|
@@ -389,29 +538,43 @@ class Threads:
|
|
|
389
538
|
status: ThreadStatus | None,
|
|
390
539
|
limit: int,
|
|
391
540
|
offset: int,
|
|
541
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
392
542
|
) -> AsyncIterator[Thread]:
|
|
393
543
|
threads = conn.store["threads"]
|
|
394
544
|
filtered_threads: list[Thread] = []
|
|
545
|
+
metadata = metadata if metadata is not None else {}
|
|
546
|
+
values = values if values is not None else {}
|
|
547
|
+
filters = await Threads.handle_event(
|
|
548
|
+
ctx,
|
|
549
|
+
"search",
|
|
550
|
+
Auth.types.ThreadsSearch(
|
|
551
|
+
metadata=metadata,
|
|
552
|
+
values=values,
|
|
553
|
+
status=status,
|
|
554
|
+
limit=limit,
|
|
555
|
+
offset=offset,
|
|
556
|
+
),
|
|
557
|
+
)
|
|
395
558
|
|
|
396
559
|
# Apply filters
|
|
397
560
|
for thread in threads:
|
|
398
|
-
|
|
561
|
+
if filters and not _check_filter_match(thread["metadata"], filters):
|
|
562
|
+
continue
|
|
399
563
|
|
|
400
564
|
if metadata and not is_jsonb_contained(thread["metadata"], metadata):
|
|
401
|
-
|
|
565
|
+
continue
|
|
402
566
|
|
|
403
567
|
if (
|
|
404
568
|
values
|
|
405
569
|
and "values" in thread
|
|
406
570
|
and not is_jsonb_contained(thread["values"], values)
|
|
407
571
|
):
|
|
408
|
-
|
|
572
|
+
continue
|
|
409
573
|
|
|
410
574
|
if status and thread.get("status") != status:
|
|
411
|
-
|
|
575
|
+
continue
|
|
412
576
|
|
|
413
|
-
|
|
414
|
-
filtered_threads.append(thread)
|
|
577
|
+
filtered_threads.append(thread)
|
|
415
578
|
|
|
416
579
|
# Sort by created_at in descending order
|
|
417
580
|
sorted_threads = sorted(
|
|
@@ -428,8 +591,11 @@ class Threads:
|
|
|
428
591
|
return thread_iterator()
|
|
429
592
|
|
|
430
593
|
@staticmethod
|
|
431
|
-
async def
|
|
432
|
-
|
|
594
|
+
async def _get_with_filters(
|
|
595
|
+
conn: InMemConnectionProto,
|
|
596
|
+
thread_id: UUID,
|
|
597
|
+
filters: Auth.types.FilterType | None,
|
|
598
|
+
) -> Thread | None:
|
|
433
599
|
thread_id = _ensure_uuid(thread_id)
|
|
434
600
|
matching_thread = next(
|
|
435
601
|
(
|
|
@@ -439,6 +605,37 @@ class Threads:
|
|
|
439
605
|
),
|
|
440
606
|
None,
|
|
441
607
|
)
|
|
608
|
+
if not matching_thread or (
|
|
609
|
+
filters and not _check_filter_match(matching_thread["metadata"], filters)
|
|
610
|
+
):
|
|
611
|
+
return
|
|
612
|
+
|
|
613
|
+
return matching_thread
|
|
614
|
+
|
|
615
|
+
@staticmethod
|
|
616
|
+
async def _get(
|
|
617
|
+
conn: InMemConnectionProto,
|
|
618
|
+
thread_id: UUID,
|
|
619
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
620
|
+
) -> Thread | None:
|
|
621
|
+
"""Get a thread by ID."""
|
|
622
|
+
thread_id = _ensure_uuid(thread_id)
|
|
623
|
+
filters = await Threads.handle_event(
|
|
624
|
+
ctx,
|
|
625
|
+
"read",
|
|
626
|
+
Auth.types.ThreadsRead(thread_id=thread_id),
|
|
627
|
+
)
|
|
628
|
+
return await Threads._get_with_filters(conn, thread_id, filters)
|
|
629
|
+
|
|
630
|
+
@staticmethod
|
|
631
|
+
async def get(
|
|
632
|
+
conn: InMemConnectionProto,
|
|
633
|
+
thread_id: UUID,
|
|
634
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
635
|
+
) -> AsyncIterator[Thread]:
|
|
636
|
+
"""Get a thread by ID."""
|
|
637
|
+
matching_thread = await Threads._get(conn, thread_id, ctx)
|
|
638
|
+
|
|
442
639
|
if not matching_thread:
|
|
443
640
|
raise HTTPException(
|
|
444
641
|
status_code=404, detail=f"Thread with ID {thread_id} not found"
|
|
@@ -457,6 +654,7 @@ class Threads:
|
|
|
457
654
|
*,
|
|
458
655
|
metadata: MetadataInput,
|
|
459
656
|
if_exists: OnConflictBehavior,
|
|
657
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
460
658
|
) -> AsyncIterator[Thread]:
|
|
461
659
|
"""Insert or update a thread."""
|
|
462
660
|
thread_id = _ensure_uuid(thread_id)
|
|
@@ -467,8 +665,22 @@ class Threads:
|
|
|
467
665
|
existing_thread = next(
|
|
468
666
|
(t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
|
|
469
667
|
)
|
|
668
|
+
filters = await Threads.handle_event(
|
|
669
|
+
ctx,
|
|
670
|
+
"create",
|
|
671
|
+
Auth.types.ThreadsCreate(
|
|
672
|
+
thread_id=thread_id, metadata=metadata, if_exists=if_exists
|
|
673
|
+
),
|
|
674
|
+
)
|
|
470
675
|
|
|
471
676
|
if existing_thread:
|
|
677
|
+
if filters and not _check_filter_match(
|
|
678
|
+
existing_thread["metadata"], filters
|
|
679
|
+
):
|
|
680
|
+
# Should we use a different status code here?
|
|
681
|
+
raise HTTPException(
|
|
682
|
+
status_code=409, detail=f"Thread with ID {thread_id} already exists"
|
|
683
|
+
)
|
|
472
684
|
if if_exists == "raise":
|
|
473
685
|
raise HTTPException(
|
|
474
686
|
status_code=409, detail=f"Thread with ID {thread_id} already exists"
|
|
@@ -479,7 +691,6 @@ class Threads:
|
|
|
479
691
|
yield existing_thread
|
|
480
692
|
|
|
481
693
|
return _yield_existing()
|
|
482
|
-
|
|
483
694
|
# Create new thread
|
|
484
695
|
new_thread: Thread = {
|
|
485
696
|
"thread_id": thread_id,
|
|
@@ -501,7 +712,11 @@ class Threads:
|
|
|
501
712
|
|
|
502
713
|
@staticmethod
|
|
503
714
|
async def patch(
|
|
504
|
-
conn: InMemConnectionProto,
|
|
715
|
+
conn: InMemConnectionProto,
|
|
716
|
+
thread_id: UUID,
|
|
717
|
+
*,
|
|
718
|
+
metadata: MetadataValue,
|
|
719
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
505
720
|
) -> AsyncIterator[Thread]:
|
|
506
721
|
"""Update a thread."""
|
|
507
722
|
thread_list = conn.store["threads"]
|
|
@@ -514,15 +729,23 @@ class Threads:
|
|
|
514
729
|
break
|
|
515
730
|
|
|
516
731
|
if thread_idx is not None:
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
732
|
+
filters = await Threads.handle_event(
|
|
733
|
+
ctx,
|
|
734
|
+
"update",
|
|
735
|
+
Auth.types.ThreadsUpdate(thread_id=thread_id, metadata=metadata),
|
|
736
|
+
)
|
|
737
|
+
if not filters or _check_filter_match(
|
|
738
|
+
thread_list[thread_idx]["metadata"], filters
|
|
739
|
+
):
|
|
740
|
+
thread = copy.deepcopy(thread_list[thread_idx])
|
|
741
|
+
thread["metadata"] = {**thread["metadata"], **metadata}
|
|
742
|
+
thread["updated_at"] = datetime.now(UTC)
|
|
743
|
+
thread_list[thread_idx] = thread
|
|
521
744
|
|
|
522
|
-
|
|
523
|
-
|
|
745
|
+
async def thread_iterator() -> AsyncIterator[Thread]:
|
|
746
|
+
yield thread
|
|
524
747
|
|
|
525
|
-
|
|
748
|
+
return thread_iterator()
|
|
526
749
|
|
|
527
750
|
async def empty_iterator() -> AsyncIterator[Thread]:
|
|
528
751
|
if False: # This ensures the iterator is empty
|
|
@@ -536,6 +759,7 @@ class Threads:
|
|
|
536
759
|
thread_id: UUID,
|
|
537
760
|
checkpoint: CheckpointPayload | None,
|
|
538
761
|
exception: BaseException | None,
|
|
762
|
+
# This does not accept the auth context since it's only used internally
|
|
539
763
|
) -> None:
|
|
540
764
|
"""Set the status of a thread."""
|
|
541
765
|
thread_id = _ensure_uuid(thread_id)
|
|
@@ -597,19 +821,34 @@ class Threads:
|
|
|
597
821
|
|
|
598
822
|
@staticmethod
|
|
599
823
|
async def delete(
|
|
600
|
-
conn: InMemConnectionProto,
|
|
824
|
+
conn: InMemConnectionProto,
|
|
825
|
+
thread_id: UUID,
|
|
826
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
601
827
|
) -> AsyncIterator[UUID]:
|
|
602
828
|
"""Delete a thread by ID and cascade delete all associated runs."""
|
|
603
829
|
thread_list = conn.store["threads"]
|
|
604
830
|
thread_idx = None
|
|
605
831
|
thread_id = _ensure_uuid(thread_id)
|
|
606
|
-
conn.locks.pop(thread_id, None)
|
|
607
832
|
|
|
608
833
|
# Find the thread to delete
|
|
609
834
|
for idx, thread in enumerate(thread_list):
|
|
610
835
|
if thread["thread_id"] == thread_id:
|
|
611
836
|
thread_idx = idx
|
|
612
837
|
break
|
|
838
|
+
filters = await Threads.handle_event(
|
|
839
|
+
ctx,
|
|
840
|
+
"delete",
|
|
841
|
+
Auth.types.ThreadsDelete(thread_id=thread_id),
|
|
842
|
+
)
|
|
843
|
+
if (filters and not _check_filter_match(thread["metadata"], filters)) or (
|
|
844
|
+
thread_idx is None
|
|
845
|
+
):
|
|
846
|
+
raise HTTPException(
|
|
847
|
+
status_code=404, detail=f"Thread with ID {thread_id} not found"
|
|
848
|
+
)
|
|
849
|
+
|
|
850
|
+
# Delete the thread
|
|
851
|
+
conn.locks.pop(thread_id, None)
|
|
613
852
|
# Cascade delete all runs associated with this thread
|
|
614
853
|
conn.store["runs"] = [
|
|
615
854
|
run for run in conn.store["runs"] if run["thread_id"] != thread_id
|
|
@@ -635,12 +874,20 @@ class Threads:
|
|
|
635
874
|
|
|
636
875
|
@staticmethod
|
|
637
876
|
async def copy(
|
|
638
|
-
conn: InMemConnectionProto,
|
|
877
|
+
conn: InMemConnectionProto,
|
|
878
|
+
thread_id: UUID,
|
|
879
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
639
880
|
) -> AsyncIterator[Thread]:
|
|
640
881
|
"""Create a copy of an existing thread."""
|
|
641
882
|
thread_id = _ensure_uuid(thread_id)
|
|
642
883
|
new_thread_id = uuid4()
|
|
643
|
-
|
|
884
|
+
filters = await Threads.handle_event(
|
|
885
|
+
ctx,
|
|
886
|
+
"read",
|
|
887
|
+
Auth.types.ThreadsRead(
|
|
888
|
+
thread_id=new_thread_id,
|
|
889
|
+
),
|
|
890
|
+
)
|
|
644
891
|
async with conn.pipeline():
|
|
645
892
|
# Find the original thread in our store
|
|
646
893
|
original_thread = next(
|
|
@@ -648,7 +895,11 @@ class Threads:
|
|
|
648
895
|
)
|
|
649
896
|
|
|
650
897
|
if not original_thread:
|
|
651
|
-
return
|
|
898
|
+
return _empty_generator()
|
|
899
|
+
if filters and not _check_filter_match(
|
|
900
|
+
original_thread["metadata"], filters
|
|
901
|
+
):
|
|
902
|
+
return _empty_generator()
|
|
652
903
|
|
|
653
904
|
# Create new thread with copied metadata
|
|
654
905
|
new_thread: Thread = {
|
|
@@ -690,15 +941,22 @@ class Threads:
|
|
|
690
941
|
|
|
691
942
|
return row_generator()
|
|
692
943
|
|
|
693
|
-
class State:
|
|
944
|
+
class State(Authenticated):
|
|
945
|
+
# We will treat this like a runs resource for now.
|
|
946
|
+
resource = "threads"
|
|
947
|
+
|
|
694
948
|
@staticmethod
|
|
695
949
|
async def get(
|
|
696
|
-
conn: InMemConnectionProto,
|
|
950
|
+
conn: InMemConnectionProto,
|
|
951
|
+
config: Config,
|
|
952
|
+
subgraphs: bool = False,
|
|
953
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
697
954
|
) -> StateSnapshot:
|
|
698
955
|
"""Get state for a thread."""
|
|
699
956
|
checkpointer = Checkpointer(conn)
|
|
700
957
|
thread_id = _ensure_uuid(config["configurable"]["thread_id"])
|
|
701
|
-
|
|
958
|
+
# Auth will be applied here so no need to use filters downstream
|
|
959
|
+
thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
|
|
702
960
|
thread = await anext(thread_iter)
|
|
703
961
|
checkpoint = await checkpointer.aget(config)
|
|
704
962
|
|
|
@@ -747,12 +1005,19 @@ class Threads:
|
|
|
747
1005
|
config: Config,
|
|
748
1006
|
values: Sequence[dict] | dict[str, Any] | None,
|
|
749
1007
|
as_node: str | None = None,
|
|
1008
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
750
1009
|
) -> ThreadUpdateResponse:
|
|
751
1010
|
"""Add state to a thread."""
|
|
1011
|
+
thread_id = _ensure_uuid(config["configurable"]["thread_id"])
|
|
1012
|
+
filters = await Threads.handle_event(
|
|
1013
|
+
ctx,
|
|
1014
|
+
"update",
|
|
1015
|
+
Auth.types.ThreadsUpdate(thread_id=thread_id),
|
|
1016
|
+
)
|
|
752
1017
|
|
|
753
1018
|
checkpointer = Checkpointer(conn)
|
|
754
|
-
|
|
755
|
-
thread_iter = await Threads.get(conn, thread_id)
|
|
1019
|
+
|
|
1020
|
+
thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
|
|
756
1021
|
thread = await fetchone(
|
|
757
1022
|
thread_iter, not_found_detail=f"Thread {thread_id} not found."
|
|
758
1023
|
)
|
|
@@ -760,6 +1025,8 @@ class Threads:
|
|
|
760
1025
|
|
|
761
1026
|
if not thread:
|
|
762
1027
|
raise HTTPException(status_code=404, detail="Thread not found")
|
|
1028
|
+
if not _check_filter_match(thread["metadata"], filters):
|
|
1029
|
+
raise HTTPException(status_code=403, detail="Forbidden")
|
|
763
1030
|
|
|
764
1031
|
metadata = thread["metadata"]
|
|
765
1032
|
thread_config = thread["config"]
|
|
@@ -781,7 +1048,7 @@ class Threads:
|
|
|
781
1048
|
)
|
|
782
1049
|
|
|
783
1050
|
# Get current state
|
|
784
|
-
state = await Threads.State.get(conn, config, subgraphs=False)
|
|
1051
|
+
state = await Threads.State.get(conn, config, subgraphs=False, ctx=ctx)
|
|
785
1052
|
# Update thread values
|
|
786
1053
|
for thread in conn.store["threads"]:
|
|
787
1054
|
if thread["thread_id"] == thread_id:
|
|
@@ -805,22 +1072,26 @@ class Threads:
|
|
|
805
1072
|
limit: int = 10,
|
|
806
1073
|
before: str | Checkpoint | None = None,
|
|
807
1074
|
metadata: MetadataInput = None,
|
|
1075
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
808
1076
|
) -> list[StateSnapshot]:
|
|
809
1077
|
"""Get the history of a thread."""
|
|
810
1078
|
|
|
811
1079
|
thread_id = _ensure_uuid(config["configurable"]["thread_id"])
|
|
812
1080
|
thread = None
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
1081
|
+
filters = await Threads.handle_event(
|
|
1082
|
+
ctx,
|
|
1083
|
+
"read",
|
|
1084
|
+
Auth.types.ThreadsRead(thread_id=thread_id),
|
|
1085
|
+
)
|
|
1086
|
+
thread = await fetchone(
|
|
1087
|
+
await Threads.get(conn, config["configurable"]["thread_id"], ctx=ctx)
|
|
1088
|
+
)
|
|
821
1089
|
|
|
822
1090
|
# Parse thread metadata and config
|
|
823
1091
|
thread_metadata = thread["metadata"]
|
|
1092
|
+
if not _check_filter_match(thread_metadata, filters):
|
|
1093
|
+
return []
|
|
1094
|
+
|
|
824
1095
|
thread_config = thread["config"]
|
|
825
1096
|
# If graph_id exists, get state history
|
|
826
1097
|
if graph_id := thread_metadata.get("graph_id"):
|
|
@@ -847,7 +1118,9 @@ class Threads:
|
|
|
847
1118
|
return []
|
|
848
1119
|
|
|
849
1120
|
|
|
850
|
-
class Runs:
|
|
1121
|
+
class Runs(Authenticated):
|
|
1122
|
+
resource = "threads"
|
|
1123
|
+
|
|
851
1124
|
@staticmethod
|
|
852
1125
|
async def stats(conn: InMemConnectionProto) -> QueueStats:
|
|
853
1126
|
"""Get stats about the queue."""
|
|
@@ -1001,6 +1274,7 @@ class Runs:
|
|
|
1001
1274
|
multitask_strategy: MultitaskStrategy = "reject",
|
|
1002
1275
|
if_not_exists: IfNotExists = "reject",
|
|
1003
1276
|
after_seconds: int = 0,
|
|
1277
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1004
1278
|
) -> AsyncIterator[Run]:
|
|
1005
1279
|
"""Create a run."""
|
|
1006
1280
|
assistant_id = _ensure_uuid(assistant_id)
|
|
@@ -1009,22 +1283,38 @@ class Runs:
|
|
|
1009
1283
|
None,
|
|
1010
1284
|
)
|
|
1011
1285
|
|
|
1012
|
-
async def empty_generator():
|
|
1013
|
-
if False:
|
|
1014
|
-
yield
|
|
1015
|
-
|
|
1016
1286
|
if not assistant:
|
|
1017
|
-
return
|
|
1287
|
+
return _empty_generator()
|
|
1018
1288
|
|
|
1019
1289
|
thread_id = _ensure_uuid(thread_id) if thread_id else None
|
|
1020
1290
|
run_id = _ensure_uuid(run_id) if run_id else None
|
|
1021
|
-
metadata = metadata
|
|
1291
|
+
metadata = metadata if metadata is not None else {}
|
|
1022
1292
|
config = kwargs.get("config", {})
|
|
1023
1293
|
|
|
1024
1294
|
# Handle thread creation/update
|
|
1025
1295
|
existing_thread = next(
|
|
1026
1296
|
(t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
|
|
1027
1297
|
)
|
|
1298
|
+
filters = await Runs.handle_event(
|
|
1299
|
+
ctx,
|
|
1300
|
+
"create",
|
|
1301
|
+
Auth.types.RunsCreate(
|
|
1302
|
+
thread_id=thread_id,
|
|
1303
|
+
assistant_id=assistant_id,
|
|
1304
|
+
run_id=run_id,
|
|
1305
|
+
status=status,
|
|
1306
|
+
metadata=metadata,
|
|
1307
|
+
prevent_insert_if_inflight=prevent_insert_if_inflight,
|
|
1308
|
+
multitask_strategy=multitask_strategy,
|
|
1309
|
+
if_not_exists=if_not_exists,
|
|
1310
|
+
after_seconds=after_seconds,
|
|
1311
|
+
kwargs=kwargs,
|
|
1312
|
+
),
|
|
1313
|
+
)
|
|
1314
|
+
if existing_thread and filters:
|
|
1315
|
+
# Reject if the user doesn't own the thread
|
|
1316
|
+
if not _check_filter_match(existing_thread["metadata"], filters):
|
|
1317
|
+
return _empty_generator()
|
|
1028
1318
|
|
|
1029
1319
|
if not existing_thread and (thread_id is None or if_not_exists == "create"):
|
|
1030
1320
|
# Create new thread
|
|
@@ -1069,7 +1359,7 @@ class Runs:
|
|
|
1069
1359
|
)
|
|
1070
1360
|
existing_thread["updated_at"] = datetime.now(UTC)
|
|
1071
1361
|
else:
|
|
1072
|
-
return
|
|
1362
|
+
return _empty_generator()
|
|
1073
1363
|
|
|
1074
1364
|
# Check for inflight runs if needed
|
|
1075
1365
|
inflight_runs = [
|
|
@@ -1089,9 +1379,11 @@ class Runs:
|
|
|
1089
1379
|
# Create new run
|
|
1090
1380
|
configurable = Runs._merge_jsonb(
|
|
1091
1381
|
Runs._get_configurable(assistant["config"]),
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1382
|
+
(
|
|
1383
|
+
Runs._get_configurable(existing_thread["config"])
|
|
1384
|
+
if existing_thread
|
|
1385
|
+
else {}
|
|
1386
|
+
),
|
|
1095
1387
|
Runs._get_configurable(config),
|
|
1096
1388
|
{
|
|
1097
1389
|
"run_id": str(run_id),
|
|
@@ -1149,11 +1441,20 @@ class Runs:
|
|
|
1149
1441
|
|
|
1150
1442
|
@staticmethod
|
|
1151
1443
|
async def get(
|
|
1152
|
-
conn: InMemConnectionProto,
|
|
1444
|
+
conn: InMemConnectionProto,
|
|
1445
|
+
run_id: UUID,
|
|
1446
|
+
*,
|
|
1447
|
+
thread_id: UUID,
|
|
1448
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1153
1449
|
) -> AsyncIterator[Run]:
|
|
1154
1450
|
"""Get a run by ID."""
|
|
1155
1451
|
|
|
1156
1452
|
run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
|
|
1453
|
+
filters = await Runs.handle_event(
|
|
1454
|
+
ctx,
|
|
1455
|
+
"read",
|
|
1456
|
+
Auth.types.ThreadsRead(thread_id=thread_id),
|
|
1457
|
+
)
|
|
1157
1458
|
|
|
1158
1459
|
async def _yield_result():
|
|
1159
1460
|
matching_run = None
|
|
@@ -1162,16 +1463,36 @@ class Runs:
|
|
|
1162
1463
|
matching_run = run
|
|
1163
1464
|
break
|
|
1164
1465
|
if matching_run:
|
|
1466
|
+
if filters:
|
|
1467
|
+
thread = await Threads._get_with_filters(
|
|
1468
|
+
conn, matching_run["thread_id"], filters
|
|
1469
|
+
)
|
|
1470
|
+
if not thread:
|
|
1471
|
+
return
|
|
1165
1472
|
yield matching_run
|
|
1166
1473
|
|
|
1167
1474
|
return _yield_result()
|
|
1168
1475
|
|
|
1169
1476
|
@staticmethod
|
|
1170
1477
|
async def delete(
|
|
1171
|
-
conn: InMemConnectionProto,
|
|
1478
|
+
conn: InMemConnectionProto,
|
|
1479
|
+
run_id: UUID,
|
|
1480
|
+
*,
|
|
1481
|
+
thread_id: UUID,
|
|
1482
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1172
1483
|
) -> AsyncIterator[UUID]:
|
|
1173
1484
|
"""Delete a run by ID."""
|
|
1174
1485
|
run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
|
|
1486
|
+
filters = await Runs.handle_event(
|
|
1487
|
+
ctx,
|
|
1488
|
+
"delete",
|
|
1489
|
+
Auth.types.ThreadsDelete(run_id=run_id, thread_id=thread_id),
|
|
1490
|
+
)
|
|
1491
|
+
|
|
1492
|
+
if filters:
|
|
1493
|
+
thread = await Threads._get_with_filters(conn, thread_id, filters)
|
|
1494
|
+
if not thread:
|
|
1495
|
+
return _empty_generator()
|
|
1175
1496
|
_delete_checkpoints_for_thread(thread_id, conn, run_id=run_id)
|
|
1176
1497
|
found = False
|
|
1177
1498
|
for i, run in enumerate(conn.store["runs"]):
|
|
@@ -1192,16 +1513,22 @@ class Runs:
|
|
|
1192
1513
|
run_id: UUID,
|
|
1193
1514
|
*,
|
|
1194
1515
|
thread_id: UUID,
|
|
1516
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1195
1517
|
) -> Fragment:
|
|
1196
1518
|
"""Wait for a run to complete. If already done, return immediately.
|
|
1197
1519
|
|
|
1198
1520
|
Returns:
|
|
1199
1521
|
the final state of the run.
|
|
1200
1522
|
"""
|
|
1523
|
+
async with connect() as conn:
|
|
1524
|
+
# Validate ownership
|
|
1525
|
+
thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
|
|
1526
|
+
await fetchone(thread_iter)
|
|
1201
1527
|
last_chunk: bytes | None = None
|
|
1202
1528
|
# wait for the run to complete
|
|
1529
|
+
# Rely on this join's auth
|
|
1203
1530
|
async for mode, chunk in Runs.Stream.join(
|
|
1204
|
-
run_id, thread_id=thread_id, stream_mode="values"
|
|
1531
|
+
run_id, thread_id=thread_id, stream_mode="values", ctx=ctx
|
|
1205
1532
|
):
|
|
1206
1533
|
if mode == b"values":
|
|
1207
1534
|
last_chunk = chunk
|
|
@@ -1212,7 +1539,7 @@ class Runs:
|
|
|
1212
1539
|
else:
|
|
1213
1540
|
# otherwise, the run had already finished, so fetch the state from thread
|
|
1214
1541
|
async with connect() as conn:
|
|
1215
|
-
thread_iter = await Threads.get(conn, thread_id)
|
|
1542
|
+
thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
|
|
1216
1543
|
thread = await fetchone(thread_iter)
|
|
1217
1544
|
return thread["values"]
|
|
1218
1545
|
|
|
@@ -1223,8 +1550,10 @@ class Runs:
|
|
|
1223
1550
|
*,
|
|
1224
1551
|
action: Literal["interrupt", "rollback"] = "interrupt",
|
|
1225
1552
|
thread_id: UUID,
|
|
1553
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1226
1554
|
) -> None:
|
|
1227
1555
|
"""Cancel a run."""
|
|
1556
|
+
# Authwise, this invokes the runs.update handler
|
|
1228
1557
|
# Cancellation tries to take two actions, to cover runs in different states:
|
|
1229
1558
|
# - For any run, send a cancellation message through the stream manager
|
|
1230
1559
|
# - For queued runs not yet picked up by a worker, update their status if interrupt,
|
|
@@ -1233,6 +1562,15 @@ class Runs:
|
|
|
1233
1562
|
# - For runs in any other state, we raise a 404
|
|
1234
1563
|
run_ids = [_ensure_uuid(run_id) for run_id in run_ids]
|
|
1235
1564
|
thread_id = _ensure_uuid(thread_id)
|
|
1565
|
+
filters = await Runs.handle_event(
|
|
1566
|
+
ctx,
|
|
1567
|
+
"update",
|
|
1568
|
+
Auth.types.ThreadsUpdate(
|
|
1569
|
+
thread_id=thread_id,
|
|
1570
|
+
action=action,
|
|
1571
|
+
metadata={"run_ids": run_ids},
|
|
1572
|
+
),
|
|
1573
|
+
)
|
|
1236
1574
|
|
|
1237
1575
|
stream_manager = get_stream_manager()
|
|
1238
1576
|
found_runs = []
|
|
@@ -1247,6 +1585,10 @@ class Runs:
|
|
|
1247
1585
|
None,
|
|
1248
1586
|
)
|
|
1249
1587
|
if run:
|
|
1588
|
+
if filters:
|
|
1589
|
+
thread = await Threads._get_with_filters(conn, thread_id, filters)
|
|
1590
|
+
if not thread:
|
|
1591
|
+
continue
|
|
1250
1592
|
found_runs.append(run)
|
|
1251
1593
|
# Send cancellation message through stream manager
|
|
1252
1594
|
control_message = Message(
|
|
@@ -1296,16 +1638,26 @@ class Runs:
|
|
|
1296
1638
|
offset: int = 0,
|
|
1297
1639
|
metadata: MetadataInput,
|
|
1298
1640
|
status: RunStatus | None = None,
|
|
1641
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1299
1642
|
) -> AsyncIterator[Run]:
|
|
1300
1643
|
"""List all runs by thread."""
|
|
1301
1644
|
runs = conn.store["runs"]
|
|
1302
|
-
metadata = metadata
|
|
1645
|
+
metadata = metadata if metadata is not None else {}
|
|
1303
1646
|
thread_id = _ensure_uuid(thread_id)
|
|
1647
|
+
filters = await Runs.handle_event(
|
|
1648
|
+
ctx,
|
|
1649
|
+
"search",
|
|
1650
|
+
Auth.types.ThreadsSearch(thread_id=thread_id, metadata=metadata),
|
|
1651
|
+
)
|
|
1304
1652
|
filtered_runs = [
|
|
1305
1653
|
run
|
|
1306
1654
|
for run in runs
|
|
1307
1655
|
if run["thread_id"] == thread_id
|
|
1308
1656
|
and is_jsonb_contained(run["metadata"], metadata)
|
|
1657
|
+
and (
|
|
1658
|
+
not filters
|
|
1659
|
+
or (await Threads._get_with_filters(conn, thread_id, filters))
|
|
1660
|
+
)
|
|
1309
1661
|
and (status is None or run["status"] == status)
|
|
1310
1662
|
]
|
|
1311
1663
|
sorted_runs = sorted(filtered_runs, key=lambda x: x["created_at"], reverse=True)
|
|
@@ -1358,6 +1710,7 @@ class Runs:
|
|
|
1358
1710
|
ignore_404: bool = False,
|
|
1359
1711
|
cancel_on_disconnect: bool = False,
|
|
1360
1712
|
stream_mode: "StreamMode | asyncio.Queue | None" = None,
|
|
1713
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1361
1714
|
) -> AsyncIterator[tuple[bytes, bytes]]:
|
|
1362
1715
|
"""Stream the run output."""
|
|
1363
1716
|
log = logger.isEnabledFor(logging.DEBUG)
|
|
@@ -1369,6 +1722,21 @@ class Runs:
|
|
|
1369
1722
|
|
|
1370
1723
|
try:
|
|
1371
1724
|
async with connect() as conn:
|
|
1725
|
+
filters = await Runs.handle_event(
|
|
1726
|
+
ctx,
|
|
1727
|
+
"read",
|
|
1728
|
+
Auth.types.ThreadsRead(thread_id=thread_id),
|
|
1729
|
+
)
|
|
1730
|
+
if filters:
|
|
1731
|
+
thread = await Threads._get_with_filters(
|
|
1732
|
+
cast(InMemConnectionProto, conn), thread_id, filters
|
|
1733
|
+
)
|
|
1734
|
+
if not thread:
|
|
1735
|
+
raise WrappedHTTPException(
|
|
1736
|
+
HTTPException(
|
|
1737
|
+
status_code=404, detail="Thread not found"
|
|
1738
|
+
)
|
|
1739
|
+
)
|
|
1372
1740
|
channel_prefix = f"run:{run_id}:stream:"
|
|
1373
1741
|
len_prefix = len(channel_prefix.encode())
|
|
1374
1742
|
|
|
@@ -1393,7 +1761,9 @@ class Runs:
|
|
|
1393
1761
|
)
|
|
1394
1762
|
except TimeoutError:
|
|
1395
1763
|
# Check if the run is still pending
|
|
1396
|
-
run_iter = await Runs.get(
|
|
1764
|
+
run_iter = await Runs.get(
|
|
1765
|
+
conn, run_id, thread_id=thread_id, ctx=ctx
|
|
1766
|
+
)
|
|
1397
1767
|
run = await anext(run_iter, None)
|
|
1398
1768
|
|
|
1399
1769
|
if ignore_404 and run is None:
|
|
@@ -1408,6 +1778,8 @@ class Runs:
|
|
|
1408
1778
|
break
|
|
1409
1779
|
elif run["status"] != "pending":
|
|
1410
1780
|
break
|
|
1781
|
+
except WrappedHTTPException as e:
|
|
1782
|
+
raise e.http_exception from None
|
|
1411
1783
|
except:
|
|
1412
1784
|
if cancel_on_disconnect:
|
|
1413
1785
|
create_task(cancel_run(thread_id, run_id))
|
|
@@ -1475,22 +1847,32 @@ class Crons:
|
|
|
1475
1847
|
schedule: str,
|
|
1476
1848
|
cron_id: UUID | None = None,
|
|
1477
1849
|
thread_id: UUID | None = None,
|
|
1478
|
-
user_id: str | None = None,
|
|
1479
1850
|
end_time: datetime | None = None,
|
|
1851
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1480
1852
|
) -> AsyncIterator[Cron]:
|
|
1481
1853
|
raise NotImplementedError
|
|
1482
1854
|
|
|
1483
1855
|
@staticmethod
|
|
1484
|
-
async def delete(
|
|
1856
|
+
async def delete(
|
|
1857
|
+
conn: InMemConnectionProto,
|
|
1858
|
+
cron_id: UUID,
|
|
1859
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1860
|
+
) -> AsyncIterator[UUID]:
|
|
1485
1861
|
raise NotImplementedError
|
|
1486
1862
|
|
|
1487
1863
|
@staticmethod
|
|
1488
|
-
async def next(
|
|
1864
|
+
async def next(
|
|
1865
|
+
conn: InMemConnectionProto,
|
|
1866
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1867
|
+
) -> AsyncIterator[Cron]:
|
|
1489
1868
|
raise NotImplementedError
|
|
1490
1869
|
|
|
1491
1870
|
@staticmethod
|
|
1492
1871
|
async def set_next_run_date(
|
|
1493
|
-
conn: InMemConnectionProto,
|
|
1872
|
+
conn: InMemConnectionProto,
|
|
1873
|
+
cron_id: UUID,
|
|
1874
|
+
next_run_date: datetime,
|
|
1875
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1494
1876
|
) -> None:
|
|
1495
1877
|
raise NotImplementedError
|
|
1496
1878
|
|
|
@@ -1502,13 +1884,16 @@ class Crons:
|
|
|
1502
1884
|
thread_id: UUID | None,
|
|
1503
1885
|
limit: int,
|
|
1504
1886
|
offset: int,
|
|
1887
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1505
1888
|
) -> AsyncIterator[Cron]:
|
|
1506
1889
|
raise NotImplementedError
|
|
1507
1890
|
|
|
1508
1891
|
|
|
1509
|
-
async def cancel_run(
|
|
1892
|
+
async def cancel_run(
|
|
1893
|
+
thread_id: UUID, run_id: UUID, ctx: Auth.types.BaseAuthContext | None = None
|
|
1894
|
+
) -> None:
|
|
1510
1895
|
async with connect() as conn:
|
|
1511
|
-
await Runs.cancel(conn, [run_id], thread_id=thread_id)
|
|
1896
|
+
await Runs.cancel(conn, [run_id], thread_id=thread_id, ctx=ctx)
|
|
1512
1897
|
|
|
1513
1898
|
|
|
1514
1899
|
def _delete_checkpoints_for_thread(
|
|
@@ -1538,6 +1923,47 @@ def _delete_checkpoints_for_thread(
|
|
|
1538
1923
|
)
|
|
1539
1924
|
|
|
1540
1925
|
|
|
1926
|
+
def _check_filter_match(metadata: dict, filters: Auth.types.FilterType | None) -> bool:
|
|
1927
|
+
"""Check if metadata matches the filter conditions.
|
|
1928
|
+
|
|
1929
|
+
Args:
|
|
1930
|
+
metadata: The metadata to check
|
|
1931
|
+
filters: The filter conditions to apply
|
|
1932
|
+
|
|
1933
|
+
Returns:
|
|
1934
|
+
True if the metadata matches all filter conditions, False otherwise
|
|
1935
|
+
"""
|
|
1936
|
+
if not filters:
|
|
1937
|
+
return True
|
|
1938
|
+
|
|
1939
|
+
for key, value in filters.items():
|
|
1940
|
+
if isinstance(value, dict):
|
|
1941
|
+
op = next(iter(value))
|
|
1942
|
+
filter_value = value[op]
|
|
1943
|
+
|
|
1944
|
+
if op == "$eq":
|
|
1945
|
+
if key not in metadata or metadata[key] != filter_value:
|
|
1946
|
+
return False
|
|
1947
|
+
elif op == "$contains":
|
|
1948
|
+
if (
|
|
1949
|
+
key not in metadata
|
|
1950
|
+
or not isinstance(metadata[key], list)
|
|
1951
|
+
or filter_value not in metadata[key]
|
|
1952
|
+
):
|
|
1953
|
+
return False
|
|
1954
|
+
else:
|
|
1955
|
+
# Direct equality
|
|
1956
|
+
if key not in metadata or metadata[key] != value:
|
|
1957
|
+
return False
|
|
1958
|
+
|
|
1959
|
+
return True
|
|
1960
|
+
|
|
1961
|
+
|
|
1962
|
+
async def _empty_generator():
|
|
1963
|
+
if False:
|
|
1964
|
+
yield
|
|
1965
|
+
|
|
1966
|
+
|
|
1541
1967
|
__all__ = [
|
|
1542
1968
|
"Assistants",
|
|
1543
1969
|
"Crons",
|