langgraph-api 0.0.8__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langgraph-api might be problematic. Click here for more details.
- langgraph_api/api/openapi.py +38 -1
- langgraph_api/api/runs.py +70 -36
- langgraph_api/auth/custom.py +535 -0
- langgraph_api/auth/middleware.py +12 -3
- langgraph_api/cli.py +47 -1
- langgraph_api/config.py +24 -0
- langgraph_api/cron_scheduler.py +32 -27
- langgraph_api/graph.py +0 -11
- langgraph_api/models/run.py +77 -19
- langgraph_api/route.py +2 -0
- langgraph_api/utils.py +32 -0
- {langgraph_api-0.0.8.dist-info → langgraph_api-0.0.10.dist-info}/METADATA +2 -2
- {langgraph_api-0.0.8.dist-info → langgraph_api-0.0.10.dist-info}/RECORD +19 -18
- langgraph_storage/checkpoint.py +16 -0
- langgraph_storage/database.py +17 -1
- langgraph_storage/ops.py +494 -69
- {langgraph_api-0.0.8.dist-info → langgraph_api-0.0.10.dist-info}/LICENSE +0 -0
- {langgraph_api-0.0.8.dist-info → langgraph_api-0.0.10.dist-info}/WHEEL +0 -0
- {langgraph_api-0.0.8.dist-info → langgraph_api-0.0.10.dist-info}/entry_points.txt +0 -0
langgraph_storage/ops.py
CHANGED
|
@@ -11,15 +11,17 @@ from collections.abc import AsyncIterator, Sequence
|
|
|
11
11
|
from contextlib import asynccontextmanager
|
|
12
12
|
from copy import deepcopy
|
|
13
13
|
from datetime import UTC, datetime, timedelta
|
|
14
|
-
from typing import Any, Literal
|
|
14
|
+
from typing import Any, Literal, cast
|
|
15
15
|
from uuid import UUID, uuid4
|
|
16
16
|
|
|
17
17
|
import structlog
|
|
18
18
|
from langgraph.pregel.debug import CheckpointPayload
|
|
19
19
|
from langgraph.pregel.types import StateSnapshot
|
|
20
|
+
from langgraph_sdk import Auth
|
|
20
21
|
from starlette.exceptions import HTTPException
|
|
21
22
|
|
|
22
23
|
from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent, create_task
|
|
24
|
+
from langgraph_api.auth.custom import handle_event
|
|
23
25
|
from langgraph_api.errors import UserInterrupt, UserRollback
|
|
24
26
|
from langgraph_api.graph import get_graph
|
|
25
27
|
from langgraph_api.schema import (
|
|
@@ -41,7 +43,7 @@ from langgraph_api.schema import (
|
|
|
41
43
|
ThreadUpdateResponse,
|
|
42
44
|
)
|
|
43
45
|
from langgraph_api.serde import Fragment
|
|
44
|
-
from langgraph_api.utils import fetchone
|
|
46
|
+
from langgraph_api.utils import fetchone, get_auth_ctx
|
|
45
47
|
from langgraph_storage.checkpoint import Checkpointer
|
|
46
48
|
from langgraph_storage.database import InMemConnectionProto, connect
|
|
47
49
|
from langgraph_storage.queue import Message, get_stream_manager
|
|
@@ -57,13 +59,51 @@ def _ensure_uuid(id_: str | uuid.UUID | None) -> uuid.UUID:
|
|
|
57
59
|
return id_
|
|
58
60
|
|
|
59
61
|
|
|
62
|
+
class WrappedHTTPException(Exception):
|
|
63
|
+
def __init__(self, http_exception: HTTPException):
|
|
64
|
+
self.http_exception = http_exception
|
|
65
|
+
|
|
66
|
+
|
|
60
67
|
# Right now the whole API types as UUID but frequently passes a str
|
|
61
68
|
# We ensure UUIDs for eveerything EXCEPT the checkpoint storage/writes,
|
|
62
69
|
# which we leave as strings. This is because I'm too lazy to subclass fully
|
|
63
70
|
# and we use non-UUID examples in the OSS version
|
|
64
71
|
|
|
65
72
|
|
|
66
|
-
class
|
|
73
|
+
class Authenticated:
|
|
74
|
+
resource: Literal["threads", "crons", "assistants"]
|
|
75
|
+
|
|
76
|
+
@classmethod
|
|
77
|
+
def _context(
|
|
78
|
+
cls,
|
|
79
|
+
ctx: Auth.types.BaseAuthContext | None,
|
|
80
|
+
action: Literal["create", "read", "update", "delete", "create_run"],
|
|
81
|
+
) -> Auth.types.AuthContext | None:
|
|
82
|
+
if not ctx:
|
|
83
|
+
return
|
|
84
|
+
return Auth.types.AuthContext(
|
|
85
|
+
user=ctx.user,
|
|
86
|
+
scopes=ctx.scopes,
|
|
87
|
+
resource=cls.resource,
|
|
88
|
+
action=action,
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
@classmethod
|
|
92
|
+
async def handle_event(
|
|
93
|
+
cls,
|
|
94
|
+
ctx: Auth.types.BaseAuthContext | None,
|
|
95
|
+
action: Literal["create", "read", "update", "delete", "search"],
|
|
96
|
+
value: Any,
|
|
97
|
+
) -> Auth.types.FilterType | None:
|
|
98
|
+
ctx = ctx or get_auth_ctx()
|
|
99
|
+
if not ctx:
|
|
100
|
+
return
|
|
101
|
+
return await handle_event(cls._context(ctx, action), value)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class Assistants(Authenticated):
|
|
105
|
+
resource = "assistants"
|
|
106
|
+
|
|
67
107
|
@staticmethod
|
|
68
108
|
async def search(
|
|
69
109
|
conn: InMemConnectionProto,
|
|
@@ -72,7 +112,17 @@ class Assistants:
|
|
|
72
112
|
metadata: MetadataInput,
|
|
73
113
|
limit: int,
|
|
74
114
|
offset: int,
|
|
115
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
75
116
|
) -> AsyncIterator[Assistant]:
|
|
117
|
+
metadata = metadata if metadata is not None else {}
|
|
118
|
+
filters = await Assistants.handle_event(
|
|
119
|
+
ctx,
|
|
120
|
+
"search",
|
|
121
|
+
Auth.types.AssistantsSearch(
|
|
122
|
+
graph_id=graph_id, metadata=metadata, limit=limit, offset=offset
|
|
123
|
+
),
|
|
124
|
+
)
|
|
125
|
+
|
|
76
126
|
async def filter_and_yield() -> AsyncIterator[Assistant]:
|
|
77
127
|
assistants = conn.store["assistants"]
|
|
78
128
|
filtered_assistants = [
|
|
@@ -82,6 +132,7 @@ class Assistants:
|
|
|
82
132
|
and (
|
|
83
133
|
not metadata or is_jsonb_contained(assistant["metadata"], metadata)
|
|
84
134
|
)
|
|
135
|
+
and (not filters or _check_filter_match(assistant["metadata"], filters))
|
|
85
136
|
]
|
|
86
137
|
filtered_assistants.sort(key=lambda x: x["created_at"], reverse=True)
|
|
87
138
|
for assistant in filtered_assistants[offset : offset + limit]:
|
|
@@ -91,14 +142,23 @@ class Assistants:
|
|
|
91
142
|
|
|
92
143
|
@staticmethod
|
|
93
144
|
async def get(
|
|
94
|
-
conn: InMemConnectionProto,
|
|
145
|
+
conn: InMemConnectionProto,
|
|
146
|
+
assistant_id: UUID,
|
|
147
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
95
148
|
) -> AsyncIterator[Assistant]:
|
|
96
149
|
"""Get an assistant by ID."""
|
|
97
150
|
assistant_id = _ensure_uuid(assistant_id)
|
|
151
|
+
filters = await Assistants.handle_event(
|
|
152
|
+
ctx,
|
|
153
|
+
"read",
|
|
154
|
+
Auth.types.AssistantsRead(assistant_id=assistant_id),
|
|
155
|
+
)
|
|
98
156
|
|
|
99
157
|
async def _yield_result():
|
|
100
158
|
for assistant in conn.store["assistants"]:
|
|
101
|
-
if assistant["assistant_id"] == assistant_id
|
|
159
|
+
if assistant["assistant_id"] == assistant_id and (
|
|
160
|
+
not filters or _check_filter_match(assistant["metadata"], filters)
|
|
161
|
+
):
|
|
102
162
|
yield assistant
|
|
103
163
|
|
|
104
164
|
return _yield_result()
|
|
@@ -113,14 +173,33 @@ class Assistants:
|
|
|
113
173
|
metadata: MetadataInput,
|
|
114
174
|
if_exists: OnConflictBehavior,
|
|
115
175
|
name: str,
|
|
176
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
116
177
|
) -> AsyncIterator[Assistant]:
|
|
117
178
|
"""Insert an assistant."""
|
|
118
179
|
assistant_id = _ensure_uuid(assistant_id)
|
|
180
|
+
metadata = metadata if metadata is not None else {}
|
|
181
|
+
filters = await Assistants.handle_event(
|
|
182
|
+
ctx,
|
|
183
|
+
"create",
|
|
184
|
+
Auth.types.AssistantsCreate(
|
|
185
|
+
assistant_id=assistant_id,
|
|
186
|
+
graph_id=graph_id,
|
|
187
|
+
config=config,
|
|
188
|
+
metadata=metadata,
|
|
189
|
+
name=name,
|
|
190
|
+
),
|
|
191
|
+
)
|
|
119
192
|
existing_assistant = next(
|
|
120
193
|
(a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
|
|
121
194
|
None,
|
|
122
195
|
)
|
|
123
196
|
if existing_assistant:
|
|
197
|
+
if filters and not _check_filter_match(
|
|
198
|
+
existing_assistant["metadata"], filters
|
|
199
|
+
):
|
|
200
|
+
raise HTTPException(
|
|
201
|
+
status_code=409, detail=f"Assistant {assistant_id} already exists"
|
|
202
|
+
)
|
|
124
203
|
if if_exists == "raise":
|
|
125
204
|
raise HTTPException(
|
|
126
205
|
status_code=409, detail=f"Assistant {assistant_id} already exists"
|
|
@@ -168,6 +247,7 @@ class Assistants:
|
|
|
168
247
|
graph_id: str | None = None,
|
|
169
248
|
metadata: MetadataInput | None = None,
|
|
170
249
|
name: str | None = None,
|
|
250
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
171
251
|
) -> AsyncIterator[Assistant]:
|
|
172
252
|
"""Update an assistant.
|
|
173
253
|
|
|
@@ -182,6 +262,18 @@ class Assistants:
|
|
|
182
262
|
return the updated assistant model.
|
|
183
263
|
"""
|
|
184
264
|
assistant_id = _ensure_uuid(assistant_id)
|
|
265
|
+
metadata = metadata if metadata is not None else {}
|
|
266
|
+
filters = await Assistants.handle_event(
|
|
267
|
+
ctx,
|
|
268
|
+
"update",
|
|
269
|
+
Auth.types.AssistantsUpdate(
|
|
270
|
+
assistant_id=assistant_id,
|
|
271
|
+
graph_id=graph_id,
|
|
272
|
+
config=config,
|
|
273
|
+
metadata=metadata,
|
|
274
|
+
name=name,
|
|
275
|
+
),
|
|
276
|
+
)
|
|
185
277
|
assistant = next(
|
|
186
278
|
(a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
|
|
187
279
|
None,
|
|
@@ -190,6 +282,10 @@ class Assistants:
|
|
|
190
282
|
raise HTTPException(
|
|
191
283
|
status_code=404, detail=f"Assistant {assistant_id} not found"
|
|
192
284
|
)
|
|
285
|
+
elif filters and not _check_filter_match(assistant["metadata"], filters):
|
|
286
|
+
raise HTTPException(
|
|
287
|
+
status_code=404, detail=f"Assistant {assistant_id} not found"
|
|
288
|
+
)
|
|
193
289
|
|
|
194
290
|
now = datetime.now(UTC)
|
|
195
291
|
new_version = (
|
|
@@ -204,6 +300,11 @@ class Assistants:
|
|
|
204
300
|
)
|
|
205
301
|
|
|
206
302
|
# Update assistant_versions table
|
|
303
|
+
if metadata:
|
|
304
|
+
metadata = {
|
|
305
|
+
**assistant["metadata"],
|
|
306
|
+
**metadata,
|
|
307
|
+
}
|
|
207
308
|
new_version_entry = {
|
|
208
309
|
"assistant_id": assistant_id,
|
|
209
310
|
"version": new_version,
|
|
@@ -233,10 +334,33 @@ class Assistants:
|
|
|
233
334
|
|
|
234
335
|
@staticmethod
|
|
235
336
|
async def delete(
|
|
236
|
-
conn: InMemConnectionProto,
|
|
337
|
+
conn: InMemConnectionProto,
|
|
338
|
+
assistant_id: UUID,
|
|
339
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
237
340
|
) -> AsyncIterator[UUID]:
|
|
238
341
|
"""Delete an assistant by ID."""
|
|
239
342
|
assistant_id = _ensure_uuid(assistant_id)
|
|
343
|
+
filters = await Assistants.handle_event(
|
|
344
|
+
ctx,
|
|
345
|
+
"delete",
|
|
346
|
+
Auth.types.AssistantsDelete(
|
|
347
|
+
assistant_id=assistant_id,
|
|
348
|
+
),
|
|
349
|
+
)
|
|
350
|
+
assistant = next(
|
|
351
|
+
(a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
|
|
352
|
+
None,
|
|
353
|
+
)
|
|
354
|
+
|
|
355
|
+
if not assistant:
|
|
356
|
+
raise HTTPException(
|
|
357
|
+
status_code=404, detail=f"Assistant with ID {assistant_id} not found"
|
|
358
|
+
)
|
|
359
|
+
elif filters and not _check_filter_match(assistant["metadata"], filters):
|
|
360
|
+
raise HTTPException(
|
|
361
|
+
status_code=404, detail=f"Assistant with ID {assistant_id} not found"
|
|
362
|
+
)
|
|
363
|
+
|
|
240
364
|
conn.store["assistants"] = [
|
|
241
365
|
a for a in conn.store["assistants"] if a["assistant_id"] != assistant_id
|
|
242
366
|
]
|
|
@@ -249,9 +373,10 @@ class Assistants:
|
|
|
249
373
|
retained = []
|
|
250
374
|
for run in conn.store["runs"]:
|
|
251
375
|
if run["assistant_id"] == assistant_id:
|
|
252
|
-
res = await Runs.delete(
|
|
376
|
+
res = await Runs.delete(
|
|
377
|
+
conn, run["run_id"], thread_id=run["thread_id"], ctx=ctx
|
|
378
|
+
)
|
|
253
379
|
await anext(res)
|
|
254
|
-
|
|
255
380
|
else:
|
|
256
381
|
retained.append(run)
|
|
257
382
|
|
|
@@ -262,10 +387,21 @@ class Assistants:
|
|
|
262
387
|
|
|
263
388
|
@staticmethod
|
|
264
389
|
async def set_latest(
|
|
265
|
-
conn: InMemConnectionProto,
|
|
390
|
+
conn: InMemConnectionProto,
|
|
391
|
+
assistant_id: UUID,
|
|
392
|
+
version: int,
|
|
393
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
266
394
|
) -> AsyncIterator[Assistant]:
|
|
267
395
|
"""Change the version of an assistant."""
|
|
268
396
|
assistant_id = _ensure_uuid(assistant_id)
|
|
397
|
+
filters = await Assistants.handle_event(
|
|
398
|
+
ctx,
|
|
399
|
+
"update",
|
|
400
|
+
Auth.types.AssistantsUpdate(
|
|
401
|
+
assistant_id=assistant_id,
|
|
402
|
+
version=version,
|
|
403
|
+
),
|
|
404
|
+
)
|
|
269
405
|
assistant = next(
|
|
270
406
|
(a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
|
|
271
407
|
None,
|
|
@@ -274,6 +410,10 @@ class Assistants:
|
|
|
274
410
|
raise HTTPException(
|
|
275
411
|
status_code=404, detail=f"Assistant {assistant_id} not found"
|
|
276
412
|
)
|
|
413
|
+
elif filters and not _check_filter_match(assistant["metadata"], filters):
|
|
414
|
+
raise HTTPException(
|
|
415
|
+
status_code=404, detail=f"Assistant {assistant_id} not found"
|
|
416
|
+
)
|
|
277
417
|
|
|
278
418
|
version_data = next(
|
|
279
419
|
(
|
|
@@ -310,14 +450,21 @@ class Assistants:
|
|
|
310
450
|
metadata: MetadataInput,
|
|
311
451
|
limit: int,
|
|
312
452
|
offset: int,
|
|
453
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
313
454
|
) -> AsyncIterator[Assistant]:
|
|
314
455
|
"""Get all versions of an assistant."""
|
|
315
456
|
assistant_id = _ensure_uuid(assistant_id)
|
|
457
|
+
filters = await Assistants.handle_event(
|
|
458
|
+
ctx,
|
|
459
|
+
"read",
|
|
460
|
+
Auth.types.AssistantsRead(assistant_id=assistant_id),
|
|
461
|
+
)
|
|
316
462
|
versions = [
|
|
317
463
|
v
|
|
318
464
|
for v in conn.store["assistant_versions"]
|
|
319
465
|
if v["assistant_id"] == assistant_id
|
|
320
466
|
and (not metadata or is_jsonb_contained(v["metadata"], metadata))
|
|
467
|
+
and (not filters or _check_filter_match(v["metadata"], filters))
|
|
321
468
|
]
|
|
322
469
|
versions.sort(key=lambda x: x["version"], reverse=True)
|
|
323
470
|
|
|
@@ -379,7 +526,9 @@ def _replace_thread_id(data, new_thread_id, thread_id):
|
|
|
379
526
|
return d
|
|
380
527
|
|
|
381
528
|
|
|
382
|
-
class Threads:
|
|
529
|
+
class Threads(Authenticated):
|
|
530
|
+
resource = "threads"
|
|
531
|
+
|
|
383
532
|
@staticmethod
|
|
384
533
|
async def search(
|
|
385
534
|
conn: InMemConnectionProto,
|
|
@@ -389,29 +538,43 @@ class Threads:
|
|
|
389
538
|
status: ThreadStatus | None,
|
|
390
539
|
limit: int,
|
|
391
540
|
offset: int,
|
|
541
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
392
542
|
) -> AsyncIterator[Thread]:
|
|
393
543
|
threads = conn.store["threads"]
|
|
394
544
|
filtered_threads: list[Thread] = []
|
|
545
|
+
metadata = metadata if metadata is not None else {}
|
|
546
|
+
values = values if values is not None else {}
|
|
547
|
+
filters = await Threads.handle_event(
|
|
548
|
+
ctx,
|
|
549
|
+
"search",
|
|
550
|
+
Auth.types.ThreadsSearch(
|
|
551
|
+
metadata=metadata,
|
|
552
|
+
values=values,
|
|
553
|
+
status=status,
|
|
554
|
+
limit=limit,
|
|
555
|
+
offset=offset,
|
|
556
|
+
),
|
|
557
|
+
)
|
|
395
558
|
|
|
396
559
|
# Apply filters
|
|
397
560
|
for thread in threads:
|
|
398
|
-
|
|
561
|
+
if filters and not _check_filter_match(thread["metadata"], filters):
|
|
562
|
+
continue
|
|
399
563
|
|
|
400
564
|
if metadata and not is_jsonb_contained(thread["metadata"], metadata):
|
|
401
|
-
|
|
565
|
+
continue
|
|
402
566
|
|
|
403
567
|
if (
|
|
404
568
|
values
|
|
405
569
|
and "values" in thread
|
|
406
570
|
and not is_jsonb_contained(thread["values"], values)
|
|
407
571
|
):
|
|
408
|
-
|
|
572
|
+
continue
|
|
409
573
|
|
|
410
574
|
if status and thread.get("status") != status:
|
|
411
|
-
|
|
575
|
+
continue
|
|
412
576
|
|
|
413
|
-
|
|
414
|
-
filtered_threads.append(thread)
|
|
577
|
+
filtered_threads.append(thread)
|
|
415
578
|
|
|
416
579
|
# Sort by created_at in descending order
|
|
417
580
|
sorted_threads = sorted(
|
|
@@ -428,8 +591,11 @@ class Threads:
|
|
|
428
591
|
return thread_iterator()
|
|
429
592
|
|
|
430
593
|
@staticmethod
|
|
431
|
-
async def
|
|
432
|
-
|
|
594
|
+
async def _get_with_filters(
|
|
595
|
+
conn: InMemConnectionProto,
|
|
596
|
+
thread_id: UUID,
|
|
597
|
+
filters: Auth.types.FilterType | None,
|
|
598
|
+
) -> Thread | None:
|
|
433
599
|
thread_id = _ensure_uuid(thread_id)
|
|
434
600
|
matching_thread = next(
|
|
435
601
|
(
|
|
@@ -439,6 +605,37 @@ class Threads:
|
|
|
439
605
|
),
|
|
440
606
|
None,
|
|
441
607
|
)
|
|
608
|
+
if not matching_thread or (
|
|
609
|
+
filters and not _check_filter_match(matching_thread["metadata"], filters)
|
|
610
|
+
):
|
|
611
|
+
return
|
|
612
|
+
|
|
613
|
+
return matching_thread
|
|
614
|
+
|
|
615
|
+
@staticmethod
|
|
616
|
+
async def _get(
|
|
617
|
+
conn: InMemConnectionProto,
|
|
618
|
+
thread_id: UUID,
|
|
619
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
620
|
+
) -> Thread | None:
|
|
621
|
+
"""Get a thread by ID."""
|
|
622
|
+
thread_id = _ensure_uuid(thread_id)
|
|
623
|
+
filters = await Threads.handle_event(
|
|
624
|
+
ctx,
|
|
625
|
+
"read",
|
|
626
|
+
Auth.types.ThreadsRead(thread_id=thread_id),
|
|
627
|
+
)
|
|
628
|
+
return await Threads._get_with_filters(conn, thread_id, filters)
|
|
629
|
+
|
|
630
|
+
@staticmethod
|
|
631
|
+
async def get(
|
|
632
|
+
conn: InMemConnectionProto,
|
|
633
|
+
thread_id: UUID,
|
|
634
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
635
|
+
) -> AsyncIterator[Thread]:
|
|
636
|
+
"""Get a thread by ID."""
|
|
637
|
+
matching_thread = await Threads._get(conn, thread_id, ctx)
|
|
638
|
+
|
|
442
639
|
if not matching_thread:
|
|
443
640
|
raise HTTPException(
|
|
444
641
|
status_code=404, detail=f"Thread with ID {thread_id} not found"
|
|
@@ -457,6 +654,7 @@ class Threads:
|
|
|
457
654
|
*,
|
|
458
655
|
metadata: MetadataInput,
|
|
459
656
|
if_exists: OnConflictBehavior,
|
|
657
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
460
658
|
) -> AsyncIterator[Thread]:
|
|
461
659
|
"""Insert or update a thread."""
|
|
462
660
|
thread_id = _ensure_uuid(thread_id)
|
|
@@ -467,8 +665,22 @@ class Threads:
|
|
|
467
665
|
existing_thread = next(
|
|
468
666
|
(t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
|
|
469
667
|
)
|
|
668
|
+
filters = await Threads.handle_event(
|
|
669
|
+
ctx,
|
|
670
|
+
"create",
|
|
671
|
+
Auth.types.ThreadsCreate(
|
|
672
|
+
thread_id=thread_id, metadata=metadata, if_exists=if_exists
|
|
673
|
+
),
|
|
674
|
+
)
|
|
470
675
|
|
|
471
676
|
if existing_thread:
|
|
677
|
+
if filters and not _check_filter_match(
|
|
678
|
+
existing_thread["metadata"], filters
|
|
679
|
+
):
|
|
680
|
+
# Should we use a different status code here?
|
|
681
|
+
raise HTTPException(
|
|
682
|
+
status_code=409, detail=f"Thread with ID {thread_id} already exists"
|
|
683
|
+
)
|
|
472
684
|
if if_exists == "raise":
|
|
473
685
|
raise HTTPException(
|
|
474
686
|
status_code=409, detail=f"Thread with ID {thread_id} already exists"
|
|
@@ -479,7 +691,6 @@ class Threads:
|
|
|
479
691
|
yield existing_thread
|
|
480
692
|
|
|
481
693
|
return _yield_existing()
|
|
482
|
-
|
|
483
694
|
# Create new thread
|
|
484
695
|
new_thread: Thread = {
|
|
485
696
|
"thread_id": thread_id,
|
|
@@ -501,7 +712,11 @@ class Threads:
|
|
|
501
712
|
|
|
502
713
|
@staticmethod
|
|
503
714
|
async def patch(
|
|
504
|
-
conn: InMemConnectionProto,
|
|
715
|
+
conn: InMemConnectionProto,
|
|
716
|
+
thread_id: UUID,
|
|
717
|
+
*,
|
|
718
|
+
metadata: MetadataValue,
|
|
719
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
505
720
|
) -> AsyncIterator[Thread]:
|
|
506
721
|
"""Update a thread."""
|
|
507
722
|
thread_list = conn.store["threads"]
|
|
@@ -514,15 +729,23 @@ class Threads:
|
|
|
514
729
|
break
|
|
515
730
|
|
|
516
731
|
if thread_idx is not None:
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
732
|
+
filters = await Threads.handle_event(
|
|
733
|
+
ctx,
|
|
734
|
+
"update",
|
|
735
|
+
Auth.types.ThreadsUpdate(thread_id=thread_id, metadata=metadata),
|
|
736
|
+
)
|
|
737
|
+
if not filters or _check_filter_match(
|
|
738
|
+
thread_list[thread_idx]["metadata"], filters
|
|
739
|
+
):
|
|
740
|
+
thread = copy.deepcopy(thread_list[thread_idx])
|
|
741
|
+
thread["metadata"] = {**thread["metadata"], **metadata}
|
|
742
|
+
thread["updated_at"] = datetime.now(UTC)
|
|
743
|
+
thread_list[thread_idx] = thread
|
|
521
744
|
|
|
522
|
-
|
|
523
|
-
|
|
745
|
+
async def thread_iterator() -> AsyncIterator[Thread]:
|
|
746
|
+
yield thread
|
|
524
747
|
|
|
525
|
-
|
|
748
|
+
return thread_iterator()
|
|
526
749
|
|
|
527
750
|
async def empty_iterator() -> AsyncIterator[Thread]:
|
|
528
751
|
if False: # This ensures the iterator is empty
|
|
@@ -536,6 +759,7 @@ class Threads:
|
|
|
536
759
|
thread_id: UUID,
|
|
537
760
|
checkpoint: CheckpointPayload | None,
|
|
538
761
|
exception: BaseException | None,
|
|
762
|
+
# This does not accept the auth context since it's only used internally
|
|
539
763
|
) -> None:
|
|
540
764
|
"""Set the status of a thread."""
|
|
541
765
|
thread_id = _ensure_uuid(thread_id)
|
|
@@ -597,19 +821,33 @@ class Threads:
|
|
|
597
821
|
|
|
598
822
|
@staticmethod
|
|
599
823
|
async def delete(
|
|
600
|
-
conn: InMemConnectionProto,
|
|
824
|
+
conn: InMemConnectionProto,
|
|
825
|
+
thread_id: UUID,
|
|
826
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
601
827
|
) -> AsyncIterator[UUID]:
|
|
602
828
|
"""Delete a thread by ID and cascade delete all associated runs."""
|
|
603
829
|
thread_list = conn.store["threads"]
|
|
604
830
|
thread_idx = None
|
|
605
831
|
thread_id = _ensure_uuid(thread_id)
|
|
606
|
-
conn.locks.pop(thread_id, None)
|
|
607
832
|
|
|
608
833
|
# Find the thread to delete
|
|
609
834
|
for idx, thread in enumerate(thread_list):
|
|
610
835
|
if thread["thread_id"] == thread_id:
|
|
611
836
|
thread_idx = idx
|
|
612
837
|
break
|
|
838
|
+
filters = await Threads.handle_event(
|
|
839
|
+
ctx,
|
|
840
|
+
"delete",
|
|
841
|
+
Auth.types.ThreadsDelete(thread_id=thread_id),
|
|
842
|
+
)
|
|
843
|
+
if (filters and not _check_filter_match(thread["metadata"], filters)) or (
|
|
844
|
+
thread_idx is None
|
|
845
|
+
):
|
|
846
|
+
raise HTTPException(
|
|
847
|
+
status_code=404, detail=f"Thread with ID {thread_id} not found"
|
|
848
|
+
)
|
|
849
|
+
# Delete the thread
|
|
850
|
+
conn.locks.pop(thread_id, None)
|
|
613
851
|
# Cascade delete all runs associated with this thread
|
|
614
852
|
conn.store["runs"] = [
|
|
615
853
|
run for run in conn.store["runs"] if run["thread_id"] != thread_id
|
|
@@ -635,12 +873,20 @@ class Threads:
|
|
|
635
873
|
|
|
636
874
|
@staticmethod
|
|
637
875
|
async def copy(
|
|
638
|
-
conn: InMemConnectionProto,
|
|
876
|
+
conn: InMemConnectionProto,
|
|
877
|
+
thread_id: UUID,
|
|
878
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
639
879
|
) -> AsyncIterator[Thread]:
|
|
640
880
|
"""Create a copy of an existing thread."""
|
|
641
881
|
thread_id = _ensure_uuid(thread_id)
|
|
642
882
|
new_thread_id = uuid4()
|
|
643
|
-
|
|
883
|
+
filters = await Threads.handle_event(
|
|
884
|
+
ctx,
|
|
885
|
+
"read",
|
|
886
|
+
Auth.types.ThreadsRead(
|
|
887
|
+
thread_id=new_thread_id,
|
|
888
|
+
),
|
|
889
|
+
)
|
|
644
890
|
async with conn.pipeline():
|
|
645
891
|
# Find the original thread in our store
|
|
646
892
|
original_thread = next(
|
|
@@ -648,7 +894,11 @@ class Threads:
|
|
|
648
894
|
)
|
|
649
895
|
|
|
650
896
|
if not original_thread:
|
|
651
|
-
return
|
|
897
|
+
return _empty_generator()
|
|
898
|
+
if filters and not _check_filter_match(
|
|
899
|
+
original_thread["metadata"], filters
|
|
900
|
+
):
|
|
901
|
+
return _empty_generator()
|
|
652
902
|
|
|
653
903
|
# Create new thread with copied metadata
|
|
654
904
|
new_thread: Thread = {
|
|
@@ -690,15 +940,22 @@ class Threads:
|
|
|
690
940
|
|
|
691
941
|
return row_generator()
|
|
692
942
|
|
|
693
|
-
class State:
|
|
943
|
+
class State(Authenticated):
|
|
944
|
+
# We will treat this like a runs resource for now.
|
|
945
|
+
resource = "threads"
|
|
946
|
+
|
|
694
947
|
@staticmethod
|
|
695
948
|
async def get(
|
|
696
|
-
conn: InMemConnectionProto,
|
|
949
|
+
conn: InMemConnectionProto,
|
|
950
|
+
config: Config,
|
|
951
|
+
subgraphs: bool = False,
|
|
952
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
697
953
|
) -> StateSnapshot:
|
|
698
954
|
"""Get state for a thread."""
|
|
699
955
|
checkpointer = Checkpointer(conn)
|
|
700
956
|
thread_id = _ensure_uuid(config["configurable"]["thread_id"])
|
|
701
|
-
|
|
957
|
+
# Auth will be applied here so no need to use filters downstream
|
|
958
|
+
thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
|
|
702
959
|
thread = await anext(thread_iter)
|
|
703
960
|
checkpoint = await checkpointer.aget(config)
|
|
704
961
|
|
|
@@ -747,12 +1004,19 @@ class Threads:
|
|
|
747
1004
|
config: Config,
|
|
748
1005
|
values: Sequence[dict] | dict[str, Any] | None,
|
|
749
1006
|
as_node: str | None = None,
|
|
1007
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
750
1008
|
) -> ThreadUpdateResponse:
|
|
751
1009
|
"""Add state to a thread."""
|
|
1010
|
+
thread_id = _ensure_uuid(config["configurable"]["thread_id"])
|
|
1011
|
+
filters = await Threads.handle_event(
|
|
1012
|
+
ctx,
|
|
1013
|
+
"update",
|
|
1014
|
+
Auth.types.ThreadsUpdate(thread_id=thread_id),
|
|
1015
|
+
)
|
|
752
1016
|
|
|
753
1017
|
checkpointer = Checkpointer(conn)
|
|
754
|
-
|
|
755
|
-
thread_iter = await Threads.get(conn, thread_id)
|
|
1018
|
+
|
|
1019
|
+
thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
|
|
756
1020
|
thread = await fetchone(
|
|
757
1021
|
thread_iter, not_found_detail=f"Thread {thread_id} not found."
|
|
758
1022
|
)
|
|
@@ -760,6 +1024,8 @@ class Threads:
|
|
|
760
1024
|
|
|
761
1025
|
if not thread:
|
|
762
1026
|
raise HTTPException(status_code=404, detail="Thread not found")
|
|
1027
|
+
if not _check_filter_match(thread["metadata"], filters):
|
|
1028
|
+
raise HTTPException(status_code=403, detail="Forbidden")
|
|
763
1029
|
|
|
764
1030
|
metadata = thread["metadata"]
|
|
765
1031
|
thread_config = thread["config"]
|
|
@@ -781,7 +1047,7 @@ class Threads:
|
|
|
781
1047
|
)
|
|
782
1048
|
|
|
783
1049
|
# Get current state
|
|
784
|
-
state = await Threads.State.get(conn, config, subgraphs=False)
|
|
1050
|
+
state = await Threads.State.get(conn, config, subgraphs=False, ctx=ctx)
|
|
785
1051
|
# Update thread values
|
|
786
1052
|
for thread in conn.store["threads"]:
|
|
787
1053
|
if thread["thread_id"] == thread_id:
|
|
@@ -805,22 +1071,26 @@ class Threads:
|
|
|
805
1071
|
limit: int = 10,
|
|
806
1072
|
before: str | Checkpoint | None = None,
|
|
807
1073
|
metadata: MetadataInput = None,
|
|
1074
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
808
1075
|
) -> list[StateSnapshot]:
|
|
809
1076
|
"""Get the history of a thread."""
|
|
810
1077
|
|
|
811
1078
|
thread_id = _ensure_uuid(config["configurable"]["thread_id"])
|
|
812
1079
|
thread = None
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
1080
|
+
filters = await Threads.handle_event(
|
|
1081
|
+
ctx,
|
|
1082
|
+
"read",
|
|
1083
|
+
Auth.types.ThreadsRead(thread_id=thread_id),
|
|
1084
|
+
)
|
|
1085
|
+
thread = await fetchone(
|
|
1086
|
+
await Threads.get(conn, config["configurable"]["thread_id"], ctx=ctx)
|
|
1087
|
+
)
|
|
821
1088
|
|
|
822
1089
|
# Parse thread metadata and config
|
|
823
1090
|
thread_metadata = thread["metadata"]
|
|
1091
|
+
if not _check_filter_match(thread_metadata, filters):
|
|
1092
|
+
return []
|
|
1093
|
+
|
|
824
1094
|
thread_config = thread["config"]
|
|
825
1095
|
# If graph_id exists, get state history
|
|
826
1096
|
if graph_id := thread_metadata.get("graph_id"):
|
|
@@ -847,7 +1117,9 @@ class Threads:
|
|
|
847
1117
|
return []
|
|
848
1118
|
|
|
849
1119
|
|
|
850
|
-
class Runs:
|
|
1120
|
+
class Runs(Authenticated):
|
|
1121
|
+
resource = "threads"
|
|
1122
|
+
|
|
851
1123
|
@staticmethod
|
|
852
1124
|
async def stats(conn: InMemConnectionProto) -> QueueStats:
|
|
853
1125
|
"""Get stats about the queue."""
|
|
@@ -1001,6 +1273,7 @@ class Runs:
|
|
|
1001
1273
|
multitask_strategy: MultitaskStrategy = "reject",
|
|
1002
1274
|
if_not_exists: IfNotExists = "reject",
|
|
1003
1275
|
after_seconds: int = 0,
|
|
1276
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1004
1277
|
) -> AsyncIterator[Run]:
|
|
1005
1278
|
"""Create a run."""
|
|
1006
1279
|
assistant_id = _ensure_uuid(assistant_id)
|
|
@@ -1009,22 +1282,38 @@ class Runs:
|
|
|
1009
1282
|
None,
|
|
1010
1283
|
)
|
|
1011
1284
|
|
|
1012
|
-
async def empty_generator():
|
|
1013
|
-
if False:
|
|
1014
|
-
yield
|
|
1015
|
-
|
|
1016
1285
|
if not assistant:
|
|
1017
|
-
return
|
|
1286
|
+
return _empty_generator()
|
|
1018
1287
|
|
|
1019
1288
|
thread_id = _ensure_uuid(thread_id) if thread_id else None
|
|
1020
1289
|
run_id = _ensure_uuid(run_id) if run_id else None
|
|
1021
|
-
metadata = metadata
|
|
1290
|
+
metadata = metadata if metadata is not None else {}
|
|
1022
1291
|
config = kwargs.get("config", {})
|
|
1023
1292
|
|
|
1024
1293
|
# Handle thread creation/update
|
|
1025
1294
|
existing_thread = next(
|
|
1026
1295
|
(t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
|
|
1027
1296
|
)
|
|
1297
|
+
filters = await Runs.handle_event(
|
|
1298
|
+
ctx,
|
|
1299
|
+
"create",
|
|
1300
|
+
Auth.types.RunsCreate(
|
|
1301
|
+
thread_id=thread_id,
|
|
1302
|
+
assistant_id=assistant_id,
|
|
1303
|
+
run_id=run_id,
|
|
1304
|
+
status=status,
|
|
1305
|
+
metadata=metadata,
|
|
1306
|
+
prevent_insert_if_inflight=prevent_insert_if_inflight,
|
|
1307
|
+
multitask_strategy=multitask_strategy,
|
|
1308
|
+
if_not_exists=if_not_exists,
|
|
1309
|
+
after_seconds=after_seconds,
|
|
1310
|
+
kwargs=kwargs,
|
|
1311
|
+
),
|
|
1312
|
+
)
|
|
1313
|
+
if existing_thread and filters:
|
|
1314
|
+
# Reject if the user doesn't own the thread
|
|
1315
|
+
if not _check_filter_match(existing_thread["metadata"], filters):
|
|
1316
|
+
return _empty_generator()
|
|
1028
1317
|
|
|
1029
1318
|
if not existing_thread and (thread_id is None or if_not_exists == "create"):
|
|
1030
1319
|
# Create new thread
|
|
@@ -1069,7 +1358,7 @@ class Runs:
|
|
|
1069
1358
|
)
|
|
1070
1359
|
existing_thread["updated_at"] = datetime.now(UTC)
|
|
1071
1360
|
else:
|
|
1072
|
-
return
|
|
1361
|
+
return _empty_generator()
|
|
1073
1362
|
|
|
1074
1363
|
# Check for inflight runs if needed
|
|
1075
1364
|
inflight_runs = [
|
|
@@ -1089,9 +1378,11 @@ class Runs:
|
|
|
1089
1378
|
# Create new run
|
|
1090
1379
|
configurable = Runs._merge_jsonb(
|
|
1091
1380
|
Runs._get_configurable(assistant["config"]),
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1381
|
+
(
|
|
1382
|
+
Runs._get_configurable(existing_thread["config"])
|
|
1383
|
+
if existing_thread
|
|
1384
|
+
else {}
|
|
1385
|
+
),
|
|
1095
1386
|
Runs._get_configurable(config),
|
|
1096
1387
|
{
|
|
1097
1388
|
"run_id": str(run_id),
|
|
@@ -1149,11 +1440,20 @@ class Runs:
|
|
|
1149
1440
|
|
|
1150
1441
|
@staticmethod
|
|
1151
1442
|
async def get(
|
|
1152
|
-
conn: InMemConnectionProto,
|
|
1443
|
+
conn: InMemConnectionProto,
|
|
1444
|
+
run_id: UUID,
|
|
1445
|
+
*,
|
|
1446
|
+
thread_id: UUID,
|
|
1447
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1153
1448
|
) -> AsyncIterator[Run]:
|
|
1154
1449
|
"""Get a run by ID."""
|
|
1155
1450
|
|
|
1156
1451
|
run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
|
|
1452
|
+
filters = await Runs.handle_event(
|
|
1453
|
+
ctx,
|
|
1454
|
+
"read",
|
|
1455
|
+
Auth.types.ThreadsRead(thread_id=thread_id),
|
|
1456
|
+
)
|
|
1157
1457
|
|
|
1158
1458
|
async def _yield_result():
|
|
1159
1459
|
matching_run = None
|
|
@@ -1162,16 +1462,36 @@ class Runs:
|
|
|
1162
1462
|
matching_run = run
|
|
1163
1463
|
break
|
|
1164
1464
|
if matching_run:
|
|
1465
|
+
if filters:
|
|
1466
|
+
thread = await Threads._get_with_filters(
|
|
1467
|
+
conn, matching_run["thread_id"], filters
|
|
1468
|
+
)
|
|
1469
|
+
if not thread:
|
|
1470
|
+
return
|
|
1165
1471
|
yield matching_run
|
|
1166
1472
|
|
|
1167
1473
|
return _yield_result()
|
|
1168
1474
|
|
|
1169
1475
|
@staticmethod
|
|
1170
1476
|
async def delete(
|
|
1171
|
-
conn: InMemConnectionProto,
|
|
1477
|
+
conn: InMemConnectionProto,
|
|
1478
|
+
run_id: UUID,
|
|
1479
|
+
*,
|
|
1480
|
+
thread_id: UUID,
|
|
1481
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1172
1482
|
) -> AsyncIterator[UUID]:
|
|
1173
1483
|
"""Delete a run by ID."""
|
|
1174
1484
|
run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
|
|
1485
|
+
filters = await Runs.handle_event(
|
|
1486
|
+
ctx,
|
|
1487
|
+
"delete",
|
|
1488
|
+
Auth.types.ThreadsDelete(run_id=run_id, thread_id=thread_id),
|
|
1489
|
+
)
|
|
1490
|
+
|
|
1491
|
+
if filters:
|
|
1492
|
+
thread = await Threads._get_with_filters(conn, thread_id, filters)
|
|
1493
|
+
if not thread:
|
|
1494
|
+
return _empty_generator()
|
|
1175
1495
|
_delete_checkpoints_for_thread(thread_id, conn, run_id=run_id)
|
|
1176
1496
|
found = False
|
|
1177
1497
|
for i, run in enumerate(conn.store["runs"]):
|
|
@@ -1192,16 +1512,22 @@ class Runs:
|
|
|
1192
1512
|
run_id: UUID,
|
|
1193
1513
|
*,
|
|
1194
1514
|
thread_id: UUID,
|
|
1515
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1195
1516
|
) -> Fragment:
|
|
1196
1517
|
"""Wait for a run to complete. If already done, return immediately.
|
|
1197
1518
|
|
|
1198
1519
|
Returns:
|
|
1199
1520
|
the final state of the run.
|
|
1200
1521
|
"""
|
|
1522
|
+
async with connect() as conn:
|
|
1523
|
+
# Validate ownership
|
|
1524
|
+
thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
|
|
1525
|
+
await fetchone(thread_iter)
|
|
1201
1526
|
last_chunk: bytes | None = None
|
|
1202
1527
|
# wait for the run to complete
|
|
1528
|
+
# Rely on this join's auth
|
|
1203
1529
|
async for mode, chunk in Runs.Stream.join(
|
|
1204
|
-
run_id, thread_id=thread_id, stream_mode="values"
|
|
1530
|
+
run_id, thread_id=thread_id, stream_mode="values", ctx=ctx
|
|
1205
1531
|
):
|
|
1206
1532
|
if mode == b"values":
|
|
1207
1533
|
last_chunk = chunk
|
|
@@ -1212,7 +1538,7 @@ class Runs:
|
|
|
1212
1538
|
else:
|
|
1213
1539
|
# otherwise, the run had already finished, so fetch the state from thread
|
|
1214
1540
|
async with connect() as conn:
|
|
1215
|
-
thread_iter = await Threads.get(conn, thread_id)
|
|
1541
|
+
thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
|
|
1216
1542
|
thread = await fetchone(thread_iter)
|
|
1217
1543
|
return thread["values"]
|
|
1218
1544
|
|
|
@@ -1223,8 +1549,10 @@ class Runs:
|
|
|
1223
1549
|
*,
|
|
1224
1550
|
action: Literal["interrupt", "rollback"] = "interrupt",
|
|
1225
1551
|
thread_id: UUID,
|
|
1552
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1226
1553
|
) -> None:
|
|
1227
1554
|
"""Cancel a run."""
|
|
1555
|
+
# Authwise, this invokes the runs.update handler
|
|
1228
1556
|
# Cancellation tries to take two actions, to cover runs in different states:
|
|
1229
1557
|
# - For any run, send a cancellation message through the stream manager
|
|
1230
1558
|
# - For queued runs not yet picked up by a worker, update their status if interrupt,
|
|
@@ -1233,6 +1561,15 @@ class Runs:
|
|
|
1233
1561
|
# - For runs in any other state, we raise a 404
|
|
1234
1562
|
run_ids = [_ensure_uuid(run_id) for run_id in run_ids]
|
|
1235
1563
|
thread_id = _ensure_uuid(thread_id)
|
|
1564
|
+
filters = await Runs.handle_event(
|
|
1565
|
+
ctx,
|
|
1566
|
+
"update",
|
|
1567
|
+
Auth.types.ThreadsUpdate(
|
|
1568
|
+
thread_id=thread_id,
|
|
1569
|
+
action=action,
|
|
1570
|
+
metadata={"run_ids": run_ids},
|
|
1571
|
+
),
|
|
1572
|
+
)
|
|
1236
1573
|
|
|
1237
1574
|
stream_manager = get_stream_manager()
|
|
1238
1575
|
found_runs = []
|
|
@@ -1247,6 +1584,10 @@ class Runs:
|
|
|
1247
1584
|
None,
|
|
1248
1585
|
)
|
|
1249
1586
|
if run:
|
|
1587
|
+
if filters:
|
|
1588
|
+
thread = await Threads._get_with_filters(conn, thread_id, filters)
|
|
1589
|
+
if not thread:
|
|
1590
|
+
continue
|
|
1250
1591
|
found_runs.append(run)
|
|
1251
1592
|
# Send cancellation message through stream manager
|
|
1252
1593
|
control_message = Message(
|
|
@@ -1296,16 +1637,26 @@ class Runs:
|
|
|
1296
1637
|
offset: int = 0,
|
|
1297
1638
|
metadata: MetadataInput,
|
|
1298
1639
|
status: RunStatus | None = None,
|
|
1640
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1299
1641
|
) -> AsyncIterator[Run]:
|
|
1300
1642
|
"""List all runs by thread."""
|
|
1301
1643
|
runs = conn.store["runs"]
|
|
1302
|
-
metadata = metadata
|
|
1644
|
+
metadata = metadata if metadata is not None else {}
|
|
1303
1645
|
thread_id = _ensure_uuid(thread_id)
|
|
1646
|
+
filters = await Runs.handle_event(
|
|
1647
|
+
ctx,
|
|
1648
|
+
"search",
|
|
1649
|
+
Auth.types.ThreadsSearch(thread_id=thread_id, metadata=metadata),
|
|
1650
|
+
)
|
|
1304
1651
|
filtered_runs = [
|
|
1305
1652
|
run
|
|
1306
1653
|
for run in runs
|
|
1307
1654
|
if run["thread_id"] == thread_id
|
|
1308
1655
|
and is_jsonb_contained(run["metadata"], metadata)
|
|
1656
|
+
and (
|
|
1657
|
+
not filters
|
|
1658
|
+
or (await Threads._get_with_filters(conn, thread_id, filters))
|
|
1659
|
+
)
|
|
1309
1660
|
and (status is None or run["status"] == status)
|
|
1310
1661
|
]
|
|
1311
1662
|
sorted_runs = sorted(filtered_runs, key=lambda x: x["created_at"], reverse=True)
|
|
@@ -1358,6 +1709,7 @@ class Runs:
|
|
|
1358
1709
|
ignore_404: bool = False,
|
|
1359
1710
|
cancel_on_disconnect: bool = False,
|
|
1360
1711
|
stream_mode: "StreamMode | asyncio.Queue | None" = None,
|
|
1712
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1361
1713
|
) -> AsyncIterator[tuple[bytes, bytes]]:
|
|
1362
1714
|
"""Stream the run output."""
|
|
1363
1715
|
log = logger.isEnabledFor(logging.DEBUG)
|
|
@@ -1369,6 +1721,21 @@ class Runs:
|
|
|
1369
1721
|
|
|
1370
1722
|
try:
|
|
1371
1723
|
async with connect() as conn:
|
|
1724
|
+
filters = await Runs.handle_event(
|
|
1725
|
+
ctx,
|
|
1726
|
+
"read",
|
|
1727
|
+
Auth.types.ThreadsRead(thread_id=thread_id),
|
|
1728
|
+
)
|
|
1729
|
+
if filters:
|
|
1730
|
+
thread = await Threads._get_with_filters(
|
|
1731
|
+
cast(InMemConnectionProto, conn), thread_id, filters
|
|
1732
|
+
)
|
|
1733
|
+
if not thread:
|
|
1734
|
+
raise WrappedHTTPException(
|
|
1735
|
+
HTTPException(
|
|
1736
|
+
status_code=404, detail="Thread not found"
|
|
1737
|
+
)
|
|
1738
|
+
)
|
|
1372
1739
|
channel_prefix = f"run:{run_id}:stream:"
|
|
1373
1740
|
len_prefix = len(channel_prefix.encode())
|
|
1374
1741
|
|
|
@@ -1393,7 +1760,9 @@ class Runs:
|
|
|
1393
1760
|
)
|
|
1394
1761
|
except TimeoutError:
|
|
1395
1762
|
# Check if the run is still pending
|
|
1396
|
-
run_iter = await Runs.get(
|
|
1763
|
+
run_iter = await Runs.get(
|
|
1764
|
+
conn, run_id, thread_id=thread_id, ctx=ctx
|
|
1765
|
+
)
|
|
1397
1766
|
run = await anext(run_iter, None)
|
|
1398
1767
|
|
|
1399
1768
|
if ignore_404 and run is None:
|
|
@@ -1408,6 +1777,8 @@ class Runs:
|
|
|
1408
1777
|
break
|
|
1409
1778
|
elif run["status"] != "pending":
|
|
1410
1779
|
break
|
|
1780
|
+
except WrappedHTTPException as e:
|
|
1781
|
+
raise e.http_exception from None
|
|
1411
1782
|
except:
|
|
1412
1783
|
if cancel_on_disconnect:
|
|
1413
1784
|
create_task(cancel_run(thread_id, run_id))
|
|
@@ -1475,22 +1846,32 @@ class Crons:
|
|
|
1475
1846
|
schedule: str,
|
|
1476
1847
|
cron_id: UUID | None = None,
|
|
1477
1848
|
thread_id: UUID | None = None,
|
|
1478
|
-
user_id: str | None = None,
|
|
1479
1849
|
end_time: datetime | None = None,
|
|
1850
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1480
1851
|
) -> AsyncIterator[Cron]:
|
|
1481
1852
|
raise NotImplementedError
|
|
1482
1853
|
|
|
1483
1854
|
@staticmethod
|
|
1484
|
-
async def delete(
|
|
1855
|
+
async def delete(
|
|
1856
|
+
conn: InMemConnectionProto,
|
|
1857
|
+
cron_id: UUID,
|
|
1858
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1859
|
+
) -> AsyncIterator[UUID]:
|
|
1485
1860
|
raise NotImplementedError
|
|
1486
1861
|
|
|
1487
1862
|
@staticmethod
|
|
1488
|
-
async def next(
|
|
1863
|
+
async def next(
|
|
1864
|
+
conn: InMemConnectionProto,
|
|
1865
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1866
|
+
) -> AsyncIterator[Cron]:
|
|
1489
1867
|
raise NotImplementedError
|
|
1490
1868
|
|
|
1491
1869
|
@staticmethod
|
|
1492
1870
|
async def set_next_run_date(
|
|
1493
|
-
conn: InMemConnectionProto,
|
|
1871
|
+
conn: InMemConnectionProto,
|
|
1872
|
+
cron_id: UUID,
|
|
1873
|
+
next_run_date: datetime,
|
|
1874
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1494
1875
|
) -> None:
|
|
1495
1876
|
raise NotImplementedError
|
|
1496
1877
|
|
|
@@ -1502,13 +1883,16 @@ class Crons:
|
|
|
1502
1883
|
thread_id: UUID | None,
|
|
1503
1884
|
limit: int,
|
|
1504
1885
|
offset: int,
|
|
1886
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1505
1887
|
) -> AsyncIterator[Cron]:
|
|
1506
1888
|
raise NotImplementedError
|
|
1507
1889
|
|
|
1508
1890
|
|
|
1509
|
-
async def cancel_run(
|
|
1891
|
+
async def cancel_run(
|
|
1892
|
+
thread_id: UUID, run_id: UUID, ctx: Auth.types.BaseAuthContext | None = None
|
|
1893
|
+
) -> None:
|
|
1510
1894
|
async with connect() as conn:
|
|
1511
|
-
await Runs.cancel(conn, [run_id], thread_id=thread_id)
|
|
1895
|
+
await Runs.cancel(conn, [run_id], thread_id=thread_id, ctx=ctx)
|
|
1512
1896
|
|
|
1513
1897
|
|
|
1514
1898
|
def _delete_checkpoints_for_thread(
|
|
@@ -1538,6 +1922,47 @@ def _delete_checkpoints_for_thread(
|
|
|
1538
1922
|
)
|
|
1539
1923
|
|
|
1540
1924
|
|
|
1925
|
+
def _check_filter_match(metadata: dict, filters: Auth.types.FilterType | None) -> bool:
|
|
1926
|
+
"""Check if metadata matches the filter conditions.
|
|
1927
|
+
|
|
1928
|
+
Args:
|
|
1929
|
+
metadata: The metadata to check
|
|
1930
|
+
filters: The filter conditions to apply
|
|
1931
|
+
|
|
1932
|
+
Returns:
|
|
1933
|
+
True if the metadata matches all filter conditions, False otherwise
|
|
1934
|
+
"""
|
|
1935
|
+
if not filters:
|
|
1936
|
+
return True
|
|
1937
|
+
|
|
1938
|
+
for key, value in filters.items():
|
|
1939
|
+
if isinstance(value, dict):
|
|
1940
|
+
op = next(iter(value))
|
|
1941
|
+
filter_value = value[op]
|
|
1942
|
+
|
|
1943
|
+
if op == "$eq":
|
|
1944
|
+
if key not in metadata or metadata[key] != filter_value:
|
|
1945
|
+
return False
|
|
1946
|
+
elif op == "$contains":
|
|
1947
|
+
if (
|
|
1948
|
+
key not in metadata
|
|
1949
|
+
or not isinstance(metadata[key], list)
|
|
1950
|
+
or filter_value not in metadata[key]
|
|
1951
|
+
):
|
|
1952
|
+
return False
|
|
1953
|
+
else:
|
|
1954
|
+
# Direct equality
|
|
1955
|
+
if key not in metadata or metadata[key] != value:
|
|
1956
|
+
return False
|
|
1957
|
+
|
|
1958
|
+
return True
|
|
1959
|
+
|
|
1960
|
+
|
|
1961
|
+
async def _empty_generator():
|
|
1962
|
+
if False:
|
|
1963
|
+
yield
|
|
1964
|
+
|
|
1965
|
+
|
|
1541
1966
|
__all__ = [
|
|
1542
1967
|
"Assistants",
|
|
1543
1968
|
"Crons",
|