langgraph-runtime-inmem 0.6.4__py3-none-any.whl → 0.18.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_runtime_inmem/__init__.py +1 -1
- langgraph_runtime_inmem/database.py +6 -2
- langgraph_runtime_inmem/inmem_stream.py +160 -36
- langgraph_runtime_inmem/lifespan.py +41 -2
- langgraph_runtime_inmem/metrics.py +1 -1
- langgraph_runtime_inmem/ops.py +695 -206
- langgraph_runtime_inmem/queue.py +8 -18
- {langgraph_runtime_inmem-0.6.4.dist-info → langgraph_runtime_inmem-0.18.1.dist-info}/METADATA +3 -3
- langgraph_runtime_inmem-0.18.1.dist-info/RECORD +13 -0
- langgraph_runtime_inmem-0.6.4.dist-info/RECORD +0 -13
- {langgraph_runtime_inmem-0.6.4.dist-info → langgraph_runtime_inmem-0.18.1.dist-info}/WHEEL +0 -0
langgraph_runtime_inmem/ops.py
CHANGED
|
@@ -11,7 +11,6 @@ import uuid
|
|
|
11
11
|
from collections import defaultdict
|
|
12
12
|
from collections.abc import AsyncIterator, Sequence
|
|
13
13
|
from contextlib import asynccontextmanager
|
|
14
|
-
from copy import deepcopy
|
|
15
14
|
from datetime import UTC, datetime, timedelta
|
|
16
15
|
from typing import Any, Literal, cast
|
|
17
16
|
from uuid import UUID, uuid4
|
|
@@ -27,17 +26,24 @@ from starlette.exceptions import HTTPException
|
|
|
27
26
|
|
|
28
27
|
from langgraph_runtime_inmem.checkpoint import Checkpointer
|
|
29
28
|
from langgraph_runtime_inmem.database import InMemConnectionProto, connect
|
|
30
|
-
from langgraph_runtime_inmem.inmem_stream import
|
|
29
|
+
from langgraph_runtime_inmem.inmem_stream import (
|
|
30
|
+
THREADLESS_KEY,
|
|
31
|
+
ContextQueue,
|
|
32
|
+
Message,
|
|
33
|
+
get_stream_manager,
|
|
34
|
+
)
|
|
31
35
|
|
|
32
36
|
if typing.TYPE_CHECKING:
|
|
33
37
|
from langgraph_api.asyncio import ValueEvent
|
|
34
38
|
from langgraph_api.config import ThreadTTLConfig
|
|
35
39
|
from langgraph_api.schema import (
|
|
36
40
|
Assistant,
|
|
41
|
+
AssistantSelectField,
|
|
37
42
|
Checkpoint,
|
|
38
43
|
Config,
|
|
39
44
|
Context,
|
|
40
45
|
Cron,
|
|
46
|
+
CronSelectField,
|
|
41
47
|
DeprecatedInterrupt,
|
|
42
48
|
IfNotExists,
|
|
43
49
|
MetadataInput,
|
|
@@ -46,15 +52,19 @@ if typing.TYPE_CHECKING:
|
|
|
46
52
|
OnConflictBehavior,
|
|
47
53
|
QueueStats,
|
|
48
54
|
Run,
|
|
55
|
+
RunSelectField,
|
|
49
56
|
RunStatus,
|
|
50
57
|
StreamMode,
|
|
51
58
|
Thread,
|
|
59
|
+
ThreadSelectField,
|
|
52
60
|
ThreadStatus,
|
|
61
|
+
ThreadStreamMode,
|
|
53
62
|
ThreadUpdateResponse,
|
|
54
63
|
)
|
|
55
64
|
from langgraph_api.schema import Interrupt as InterruptSchema
|
|
56
|
-
from langgraph_api.
|
|
65
|
+
from langgraph_api.utils import AsyncConnectionProto
|
|
57
66
|
|
|
67
|
+
StreamHandler = ContextQueue
|
|
58
68
|
|
|
59
69
|
logger = structlog.stdlib.get_logger(__name__)
|
|
60
70
|
|
|
@@ -136,6 +146,7 @@ class Assistants(Authenticated):
|
|
|
136
146
|
offset: int,
|
|
137
147
|
sort_by: str | None = None,
|
|
138
148
|
sort_order: str | None = None,
|
|
149
|
+
select: list[AssistantSelectField] | None = None,
|
|
139
150
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
140
151
|
) -> tuple[AsyncIterator[Assistant], int]:
|
|
141
152
|
metadata = metadata if metadata is not None else {}
|
|
@@ -157,9 +168,6 @@ class Assistants(Authenticated):
|
|
|
157
168
|
and (not filters or _check_filter_match(assistant["metadata"], filters))
|
|
158
169
|
]
|
|
159
170
|
|
|
160
|
-
# Get total count before sorting and pagination
|
|
161
|
-
total_count = len(filtered_assistants)
|
|
162
|
-
|
|
163
171
|
# Sort based on sort_by and sort_order
|
|
164
172
|
sort_by = sort_by.lower() if sort_by else None
|
|
165
173
|
if sort_by and sort_by in (
|
|
@@ -181,22 +189,31 @@ class Assistants(Authenticated):
|
|
|
181
189
|
else:
|
|
182
190
|
filtered_assistants.sort(key=lambda x: x.get(sort_by), reverse=reverse)
|
|
183
191
|
else:
|
|
192
|
+
sort_by = "created_at"
|
|
184
193
|
# Default sorting by created_at in descending order
|
|
185
194
|
filtered_assistants.sort(key=lambda x: x["created_at"], reverse=True)
|
|
186
195
|
|
|
187
196
|
# Apply pagination
|
|
188
197
|
paginated_assistants = filtered_assistants[offset : offset + limit]
|
|
198
|
+
cur = offset + limit if len(filtered_assistants) > offset + limit else None
|
|
189
199
|
|
|
190
200
|
async def assistant_iterator() -> AsyncIterator[Assistant]:
|
|
191
201
|
for assistant in paginated_assistants:
|
|
192
|
-
|
|
202
|
+
if select:
|
|
203
|
+
# Filter to only selected fields
|
|
204
|
+
filtered_assistant = {
|
|
205
|
+
k: v for k, v in assistant.items() if k in select
|
|
206
|
+
}
|
|
207
|
+
yield filtered_assistant
|
|
208
|
+
else:
|
|
209
|
+
yield assistant
|
|
193
210
|
|
|
194
|
-
return assistant_iterator(),
|
|
211
|
+
return assistant_iterator(), cur
|
|
195
212
|
|
|
196
213
|
@staticmethod
|
|
197
214
|
async def get(
|
|
198
215
|
conn: InMemConnectionProto,
|
|
199
|
-
assistant_id: UUID,
|
|
216
|
+
assistant_id: UUID | str,
|
|
200
217
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
201
218
|
) -> AsyncIterator[Assistant]:
|
|
202
219
|
"""Get an assistant by ID."""
|
|
@@ -212,14 +229,14 @@ class Assistants(Authenticated):
|
|
|
212
229
|
if assistant["assistant_id"] == assistant_id and (
|
|
213
230
|
not filters or _check_filter_match(assistant["metadata"], filters)
|
|
214
231
|
):
|
|
215
|
-
yield assistant
|
|
232
|
+
yield copy.deepcopy(assistant)
|
|
216
233
|
|
|
217
234
|
return _yield_result()
|
|
218
235
|
|
|
219
236
|
@staticmethod
|
|
220
237
|
async def put(
|
|
221
238
|
conn: InMemConnectionProto,
|
|
222
|
-
assistant_id: UUID,
|
|
239
|
+
assistant_id: UUID | str,
|
|
223
240
|
*,
|
|
224
241
|
graph_id: str,
|
|
225
242
|
config: Config,
|
|
@@ -231,6 +248,8 @@ class Assistants(Authenticated):
|
|
|
231
248
|
description: str | None = None,
|
|
232
249
|
) -> AsyncIterator[Assistant]:
|
|
233
250
|
"""Insert an assistant."""
|
|
251
|
+
from langgraph_api.graph import GRAPHS
|
|
252
|
+
|
|
234
253
|
assistant_id = _ensure_uuid(assistant_id)
|
|
235
254
|
metadata = metadata if metadata is not None else {}
|
|
236
255
|
filters = await Assistants.handle_event(
|
|
@@ -245,6 +264,22 @@ class Assistants(Authenticated):
|
|
|
245
264
|
name=name,
|
|
246
265
|
),
|
|
247
266
|
)
|
|
267
|
+
|
|
268
|
+
if config.get("configurable") and context:
|
|
269
|
+
raise HTTPException(
|
|
270
|
+
status_code=400,
|
|
271
|
+
detail="Cannot specify both configurable and context. Prefer setting context alone. Context was introduced in LangGraph 0.6.0 and is the long term planned replacement for configurable.",
|
|
272
|
+
)
|
|
273
|
+
|
|
274
|
+
if graph_id not in GRAPHS:
|
|
275
|
+
raise HTTPException(status_code=404, detail=f"Graph {graph_id} not found")
|
|
276
|
+
|
|
277
|
+
# Keep config and context up to date with one another
|
|
278
|
+
if config.get("configurable"):
|
|
279
|
+
context = config["configurable"]
|
|
280
|
+
elif context:
|
|
281
|
+
config["configurable"] = context
|
|
282
|
+
|
|
248
283
|
existing_assistant = next(
|
|
249
284
|
(a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
|
|
250
285
|
None,
|
|
@@ -330,6 +365,7 @@ class Assistants(Authenticated):
|
|
|
330
365
|
"""
|
|
331
366
|
assistant_id = _ensure_uuid(assistant_id)
|
|
332
367
|
metadata = metadata if metadata is not None else {}
|
|
368
|
+
config = config if config is not None else {}
|
|
333
369
|
filters = await Assistants.handle_event(
|
|
334
370
|
ctx,
|
|
335
371
|
"update",
|
|
@@ -342,6 +378,19 @@ class Assistants(Authenticated):
|
|
|
342
378
|
name=name,
|
|
343
379
|
),
|
|
344
380
|
)
|
|
381
|
+
|
|
382
|
+
if config.get("configurable") and context:
|
|
383
|
+
raise HTTPException(
|
|
384
|
+
status_code=400,
|
|
385
|
+
detail="Cannot specify both configurable and context. Prefer setting context alone. Context was introduced in LangGraph 0.6.0 and is the long term planned replacement for configurable.",
|
|
386
|
+
)
|
|
387
|
+
|
|
388
|
+
# Keep config and context up to date with one another
|
|
389
|
+
if config.get("configurable"):
|
|
390
|
+
context = config["configurable"]
|
|
391
|
+
elif context:
|
|
392
|
+
config["configurable"] = context
|
|
393
|
+
|
|
345
394
|
assistant = next(
|
|
346
395
|
(a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
|
|
347
396
|
None,
|
|
@@ -367,19 +416,17 @@ class Assistants(Authenticated):
|
|
|
367
416
|
else 1
|
|
368
417
|
)
|
|
369
418
|
|
|
370
|
-
# Update assistant_versions table
|
|
371
|
-
if metadata:
|
|
372
|
-
metadata = {
|
|
373
|
-
**assistant["metadata"],
|
|
374
|
-
**metadata,
|
|
375
|
-
}
|
|
376
419
|
new_version_entry = {
|
|
377
420
|
"assistant_id": assistant_id,
|
|
378
421
|
"version": new_version,
|
|
379
422
|
"graph_id": graph_id if graph_id is not None else assistant["graph_id"],
|
|
380
|
-
"config": config if config
|
|
381
|
-
"context": context if context is not None else assistant
|
|
382
|
-
"metadata":
|
|
423
|
+
"config": config if config else assistant["config"],
|
|
424
|
+
"context": context if context is not None else assistant.get("context", {}),
|
|
425
|
+
"metadata": (
|
|
426
|
+
{**assistant["metadata"], **metadata}
|
|
427
|
+
if metadata is not None
|
|
428
|
+
else assistant["metadata"]
|
|
429
|
+
),
|
|
383
430
|
"created_at": now,
|
|
384
431
|
"name": name if name is not None else assistant["name"],
|
|
385
432
|
"description": (
|
|
@@ -514,6 +561,8 @@ class Assistants(Authenticated):
|
|
|
514
561
|
"metadata": version_data["metadata"],
|
|
515
562
|
"version": version_data["version"],
|
|
516
563
|
"updated_at": datetime.now(UTC),
|
|
564
|
+
"name": version_data["name"],
|
|
565
|
+
"description": version_data["description"],
|
|
517
566
|
}
|
|
518
567
|
)
|
|
519
568
|
|
|
@@ -572,6 +621,37 @@ class Assistants(Authenticated):
|
|
|
572
621
|
|
|
573
622
|
return _yield_versions()
|
|
574
623
|
|
|
624
|
+
@staticmethod
|
|
625
|
+
async def count(
|
|
626
|
+
conn: InMemConnectionProto,
|
|
627
|
+
*,
|
|
628
|
+
graph_id: str | None = None,
|
|
629
|
+
metadata: MetadataInput = None,
|
|
630
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
631
|
+
) -> int:
|
|
632
|
+
"""Get count of assistants."""
|
|
633
|
+
metadata = metadata if metadata is not None else {}
|
|
634
|
+
filters = await Assistants.handle_event(
|
|
635
|
+
ctx,
|
|
636
|
+
"search",
|
|
637
|
+
Auth.types.AssistantsSearch(
|
|
638
|
+
graph_id=graph_id, metadata=metadata, limit=0, offset=0
|
|
639
|
+
),
|
|
640
|
+
)
|
|
641
|
+
|
|
642
|
+
count = 0
|
|
643
|
+
for assistant in conn.store["assistants"]:
|
|
644
|
+
if (
|
|
645
|
+
(not graph_id or assistant["graph_id"] == graph_id)
|
|
646
|
+
and (
|
|
647
|
+
not metadata or is_jsonb_contained(assistant["metadata"], metadata)
|
|
648
|
+
)
|
|
649
|
+
and (not filters or _check_filter_match(assistant["metadata"], filters))
|
|
650
|
+
):
|
|
651
|
+
count += 1
|
|
652
|
+
|
|
653
|
+
return count
|
|
654
|
+
|
|
575
655
|
|
|
576
656
|
def is_jsonb_contained(superset: dict[str, Any], subset: dict[str, Any]) -> bool:
|
|
577
657
|
"""
|
|
@@ -649,13 +729,13 @@ def _patch_interrupt(
|
|
|
649
729
|
interrupt = Interrupt(**interrupt)
|
|
650
730
|
|
|
651
731
|
return {
|
|
652
|
-
"id":
|
|
653
|
-
|
|
654
|
-
|
|
732
|
+
"id": (
|
|
733
|
+
interrupt.interrupt_id if hasattr(interrupt, "interrupt_id") else None
|
|
734
|
+
),
|
|
655
735
|
"value": interrupt.value,
|
|
656
736
|
"resumable": interrupt.resumable,
|
|
657
737
|
"ns": interrupt.ns,
|
|
658
|
-
"when": interrupt.when,
|
|
738
|
+
"when": interrupt.when, # type: ignore[unresolved-attribute]
|
|
659
739
|
}
|
|
660
740
|
|
|
661
741
|
|
|
@@ -666,6 +746,7 @@ class Threads(Authenticated):
|
|
|
666
746
|
async def search(
|
|
667
747
|
conn: InMemConnectionProto,
|
|
668
748
|
*,
|
|
749
|
+
ids: list[str] | list[UUID] | None = None,
|
|
669
750
|
metadata: MetadataInput,
|
|
670
751
|
values: MetadataInput,
|
|
671
752
|
status: ThreadStatus | None,
|
|
@@ -673,6 +754,7 @@ class Threads(Authenticated):
|
|
|
673
754
|
offset: int,
|
|
674
755
|
sort_by: str | None = None,
|
|
675
756
|
sort_order: str | None = None,
|
|
757
|
+
select: list[ThreadSelectField] | None = None,
|
|
676
758
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
677
759
|
) -> tuple[AsyncIterator[Thread], int]:
|
|
678
760
|
threads = conn.store["threads"]
|
|
@@ -692,7 +774,19 @@ class Threads(Authenticated):
|
|
|
692
774
|
)
|
|
693
775
|
|
|
694
776
|
# Apply filters
|
|
777
|
+
id_set: set[UUID] | None = None
|
|
778
|
+
if ids:
|
|
779
|
+
id_set = set()
|
|
780
|
+
for i in ids:
|
|
781
|
+
try:
|
|
782
|
+
id_set.add(_ensure_uuid(i))
|
|
783
|
+
except Exception:
|
|
784
|
+
raise HTTPException(
|
|
785
|
+
status_code=400, detail="Invalid thread ID " + str(i)
|
|
786
|
+
) from None
|
|
695
787
|
for thread in threads:
|
|
788
|
+
if id_set is not None and thread.get("thread_id") not in id_set:
|
|
789
|
+
continue
|
|
696
790
|
if filters and not _check_filter_match(thread["metadata"], filters):
|
|
697
791
|
continue
|
|
698
792
|
|
|
@@ -710,8 +804,6 @@ class Threads(Authenticated):
|
|
|
710
804
|
continue
|
|
711
805
|
|
|
712
806
|
filtered_threads.append(thread)
|
|
713
|
-
# Get total count before pagination
|
|
714
|
-
total_count = len(filtered_threads)
|
|
715
807
|
|
|
716
808
|
if sort_by and sort_by in [
|
|
717
809
|
"thread_id",
|
|
@@ -724,20 +816,26 @@ class Threads(Authenticated):
|
|
|
724
816
|
filtered_threads, key=lambda x: x.get(sort_by), reverse=reverse
|
|
725
817
|
)
|
|
726
818
|
else:
|
|
819
|
+
sort_by = "created_at"
|
|
727
820
|
# Default sorting by created_at in descending order
|
|
728
821
|
sorted_threads = sorted(
|
|
729
|
-
filtered_threads, key=lambda x: x["
|
|
822
|
+
filtered_threads, key=lambda x: x["updated_at"], reverse=True
|
|
730
823
|
)
|
|
731
824
|
|
|
732
825
|
# Apply limit and offset
|
|
733
826
|
paginated_threads = sorted_threads[offset : offset + limit]
|
|
827
|
+
cursor = offset + limit if len(sorted_threads) > offset + limit else None
|
|
734
828
|
|
|
735
829
|
async def thread_iterator() -> AsyncIterator[Thread]:
|
|
736
830
|
for thread in paginated_threads:
|
|
737
|
-
|
|
831
|
+
if select:
|
|
832
|
+
# Filter to only selected fields
|
|
833
|
+
filtered_thread = {k: v for k, v in thread.items() if k in select}
|
|
834
|
+
yield filtered_thread
|
|
835
|
+
else:
|
|
836
|
+
yield thread
|
|
738
837
|
|
|
739
|
-
|
|
740
|
-
return thread_iterator(), total_count
|
|
838
|
+
return thread_iterator(), cursor
|
|
741
839
|
|
|
742
840
|
@staticmethod
|
|
743
841
|
async def _get_with_filters(
|
|
@@ -799,7 +897,7 @@ class Threads(Authenticated):
|
|
|
799
897
|
@staticmethod
|
|
800
898
|
async def put(
|
|
801
899
|
conn: InMemConnectionProto,
|
|
802
|
-
thread_id: UUID,
|
|
900
|
+
thread_id: UUID | str,
|
|
803
901
|
*,
|
|
804
902
|
metadata: MetadataInput,
|
|
805
903
|
if_exists: OnConflictBehavior,
|
|
@@ -866,6 +964,7 @@ class Threads(Authenticated):
|
|
|
866
964
|
thread_id: UUID,
|
|
867
965
|
*,
|
|
868
966
|
metadata: MetadataValue,
|
|
967
|
+
ttl: ThreadTTLConfig | None = None,
|
|
869
968
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
870
969
|
) -> AsyncIterator[Thread]:
|
|
871
970
|
"""Update a thread."""
|
|
@@ -978,6 +1077,7 @@ class Threads(Authenticated):
|
|
|
978
1077
|
thread_id: UUID,
|
|
979
1078
|
run_id: UUID,
|
|
980
1079
|
run_status: RunStatus | Literal["rollback"],
|
|
1080
|
+
graph_id: str,
|
|
981
1081
|
checkpoint: CheckpointPayload | None = None,
|
|
982
1082
|
exception: BaseException | None = None,
|
|
983
1083
|
) -> None:
|
|
@@ -1057,6 +1157,7 @@ class Threads(Authenticated):
|
|
|
1057
1157
|
final_thread_status = "busy"
|
|
1058
1158
|
else:
|
|
1059
1159
|
final_thread_status = base_thread_status
|
|
1160
|
+
thread["metadata"]["graph_id"] = graph_id
|
|
1060
1161
|
thread.update(
|
|
1061
1162
|
{
|
|
1062
1163
|
"updated_at": now,
|
|
@@ -1136,13 +1237,23 @@ class Threads(Authenticated):
|
|
|
1136
1237
|
"""Create a copy of an existing thread."""
|
|
1137
1238
|
thread_id = _ensure_uuid(thread_id)
|
|
1138
1239
|
new_thread_id = uuid4()
|
|
1139
|
-
|
|
1240
|
+
read_filters = await Threads.handle_event(
|
|
1140
1241
|
ctx,
|
|
1141
1242
|
"read",
|
|
1142
1243
|
Auth.types.ThreadsRead(
|
|
1244
|
+
thread_id=thread_id,
|
|
1245
|
+
),
|
|
1246
|
+
)
|
|
1247
|
+
# Assert that the user has permissions to create a new thread.
|
|
1248
|
+
# (We don't actually need the filters.)
|
|
1249
|
+
await Threads.handle_event(
|
|
1250
|
+
ctx,
|
|
1251
|
+
"create",
|
|
1252
|
+
Auth.types.ThreadsCreate(
|
|
1143
1253
|
thread_id=new_thread_id,
|
|
1144
1254
|
),
|
|
1145
1255
|
)
|
|
1256
|
+
|
|
1146
1257
|
async with conn.pipeline():
|
|
1147
1258
|
# Find the original thread in our store
|
|
1148
1259
|
original_thread = next(
|
|
@@ -1151,8 +1262,8 @@ class Threads(Authenticated):
|
|
|
1151
1262
|
|
|
1152
1263
|
if not original_thread:
|
|
1153
1264
|
return _empty_generator()
|
|
1154
|
-
if
|
|
1155
|
-
original_thread["metadata"],
|
|
1265
|
+
if read_filters and not _check_filter_match(
|
|
1266
|
+
original_thread["metadata"], read_filters
|
|
1156
1267
|
):
|
|
1157
1268
|
return _empty_generator()
|
|
1158
1269
|
|
|
@@ -1161,7 +1272,7 @@ class Threads(Authenticated):
|
|
|
1161
1272
|
"thread_id": new_thread_id,
|
|
1162
1273
|
"created_at": datetime.now(tz=UTC),
|
|
1163
1274
|
"updated_at": datetime.now(tz=UTC),
|
|
1164
|
-
"metadata": deepcopy(original_thread["metadata"]),
|
|
1275
|
+
"metadata": copy.deepcopy(original_thread["metadata"]),
|
|
1165
1276
|
"status": "idle",
|
|
1166
1277
|
"config": {},
|
|
1167
1278
|
}
|
|
@@ -1248,9 +1359,24 @@ class Threads(Authenticated):
|
|
|
1248
1359
|
)
|
|
1249
1360
|
|
|
1250
1361
|
metadata = thread.get("metadata", {})
|
|
1251
|
-
thread_config = thread.get("config", {})
|
|
1362
|
+
thread_config = cast(dict[str, Any], thread.get("config", {}))
|
|
1363
|
+
thread_config = {
|
|
1364
|
+
**thread_config,
|
|
1365
|
+
"configurable": {
|
|
1366
|
+
**thread_config.get("configurable", {}),
|
|
1367
|
+
**config.get("configurable", {}),
|
|
1368
|
+
},
|
|
1369
|
+
}
|
|
1252
1370
|
|
|
1253
|
-
|
|
1371
|
+
# Fallback to graph_id from run if not in thread metadata
|
|
1372
|
+
graph_id = metadata.get("graph_id")
|
|
1373
|
+
if not graph_id:
|
|
1374
|
+
for run in conn.store["runs"]:
|
|
1375
|
+
if run["thread_id"] == thread_id:
|
|
1376
|
+
graph_id = run["kwargs"]["config"]["configurable"]["graph_id"]
|
|
1377
|
+
break
|
|
1378
|
+
|
|
1379
|
+
if graph_id:
|
|
1254
1380
|
# format latest checkpoint for response
|
|
1255
1381
|
checkpointer.latest_iter = checkpoint
|
|
1256
1382
|
async with get_graph(
|
|
@@ -1290,6 +1416,7 @@ class Threads(Authenticated):
|
|
|
1290
1416
|
"""Add state to a thread."""
|
|
1291
1417
|
from langgraph_api.graph import get_graph
|
|
1292
1418
|
from langgraph_api.schema import ThreadUpdateResponse
|
|
1419
|
+
from langgraph_api.state import state_snapshot_to_thread_state
|
|
1293
1420
|
from langgraph_api.store import get_store
|
|
1294
1421
|
from langgraph_api.utils import fetchone
|
|
1295
1422
|
|
|
@@ -1327,8 +1454,23 @@ class Threads(Authenticated):
|
|
|
1327
1454
|
status_code=409,
|
|
1328
1455
|
detail=f"Thread {thread_id} has in-flight runs: {pending_runs}",
|
|
1329
1456
|
)
|
|
1457
|
+
thread_config = {
|
|
1458
|
+
**thread_config,
|
|
1459
|
+
"configurable": {
|
|
1460
|
+
**thread_config.get("configurable", {}),
|
|
1461
|
+
**config.get("configurable", {}),
|
|
1462
|
+
},
|
|
1463
|
+
}
|
|
1330
1464
|
|
|
1331
|
-
|
|
1465
|
+
# Fallback to graph_id from run if not in thread metadata
|
|
1466
|
+
graph_id = metadata.get("graph_id")
|
|
1467
|
+
if not graph_id:
|
|
1468
|
+
for run in conn.store["runs"]:
|
|
1469
|
+
if run["thread_id"] == thread_id:
|
|
1470
|
+
graph_id = run["kwargs"]["config"]["configurable"]["graph_id"]
|
|
1471
|
+
break
|
|
1472
|
+
|
|
1473
|
+
if graph_id:
|
|
1332
1474
|
config["configurable"].setdefault("graph_id", graph_id)
|
|
1333
1475
|
|
|
1334
1476
|
checkpointer.latest_iter = checkpoint
|
|
@@ -1359,6 +1501,19 @@ class Threads(Authenticated):
|
|
|
1359
1501
|
thread["values"] = state.values
|
|
1360
1502
|
break
|
|
1361
1503
|
|
|
1504
|
+
# Publish state update event
|
|
1505
|
+
from langgraph_api.serde import json_dumpb
|
|
1506
|
+
|
|
1507
|
+
event_data = {
|
|
1508
|
+
"state": state_snapshot_to_thread_state(state),
|
|
1509
|
+
"thread_id": str(thread_id),
|
|
1510
|
+
}
|
|
1511
|
+
await Threads.Stream.publish(
|
|
1512
|
+
thread_id,
|
|
1513
|
+
"state_update",
|
|
1514
|
+
json_dumpb(event_data),
|
|
1515
|
+
)
|
|
1516
|
+
|
|
1362
1517
|
return ThreadUpdateResponse(
|
|
1363
1518
|
checkpoint=next_config["configurable"],
|
|
1364
1519
|
# Including deprecated fields
|
|
@@ -1366,7 +1521,11 @@ class Threads(Authenticated):
|
|
|
1366
1521
|
checkpoint_id=next_config["configurable"]["checkpoint_id"],
|
|
1367
1522
|
)
|
|
1368
1523
|
else:
|
|
1369
|
-
raise HTTPException(
|
|
1524
|
+
raise HTTPException(
|
|
1525
|
+
status_code=400,
|
|
1526
|
+
detail=f"Thread '{thread['thread_id']}' has no assigned graph ID. This usually occurs when no runs have been made on this particular thread."
|
|
1527
|
+
" This operation requires a graph ID. Please ensure a run has been made for the thread or manually update the thread metadata (by setting the 'graph_id' field) before running this operation.",
|
|
1528
|
+
)
|
|
1370
1529
|
|
|
1371
1530
|
@staticmethod
|
|
1372
1531
|
async def bulk(
|
|
@@ -1397,7 +1556,14 @@ class Threads(Authenticated):
|
|
|
1397
1556
|
thread_iter, not_found_detail=f"Thread {thread_id} not found."
|
|
1398
1557
|
)
|
|
1399
1558
|
|
|
1400
|
-
thread_config = thread["config"]
|
|
1559
|
+
thread_config = cast(dict[str, Any], thread["config"])
|
|
1560
|
+
thread_config = {
|
|
1561
|
+
**thread_config,
|
|
1562
|
+
"configurable": {
|
|
1563
|
+
**thread_config.get("configurable", {}),
|
|
1564
|
+
**config.get("configurable", {}),
|
|
1565
|
+
},
|
|
1566
|
+
}
|
|
1401
1567
|
metadata = thread["metadata"]
|
|
1402
1568
|
|
|
1403
1569
|
if not thread:
|
|
@@ -1444,18 +1610,35 @@ class Threads(Authenticated):
|
|
|
1444
1610
|
thread["values"] = state.values
|
|
1445
1611
|
break
|
|
1446
1612
|
|
|
1613
|
+
# Publish state update event
|
|
1614
|
+
from langgraph_api.serde import json_dumpb
|
|
1615
|
+
|
|
1616
|
+
event_data = {
|
|
1617
|
+
"state": state,
|
|
1618
|
+
"thread_id": str(thread_id),
|
|
1619
|
+
}
|
|
1620
|
+
await Threads.Stream.publish(
|
|
1621
|
+
thread_id,
|
|
1622
|
+
"state_update",
|
|
1623
|
+
json_dumpb(event_data),
|
|
1624
|
+
)
|
|
1625
|
+
|
|
1447
1626
|
return ThreadUpdateResponse(
|
|
1448
1627
|
checkpoint=next_config["configurable"],
|
|
1449
1628
|
)
|
|
1450
1629
|
else:
|
|
1451
|
-
raise HTTPException(
|
|
1630
|
+
raise HTTPException(
|
|
1631
|
+
status_code=400,
|
|
1632
|
+
detail=f"Thread '{thread['thread_id']}' has no assigned graph ID. This usually occurs when no runs have been made on this particular thread."
|
|
1633
|
+
" This operation requires a graph ID. Please ensure a run has been made for the thread or manually update the thread metadata (by setting the 'graph_id' field) before running this operation.",
|
|
1634
|
+
)
|
|
1452
1635
|
|
|
1453
1636
|
@staticmethod
|
|
1454
1637
|
async def list(
|
|
1455
1638
|
conn: InMemConnectionProto,
|
|
1456
1639
|
*,
|
|
1457
1640
|
config: Config,
|
|
1458
|
-
limit: int =
|
|
1641
|
+
limit: int = 1,
|
|
1459
1642
|
before: str | Checkpoint | None = None,
|
|
1460
1643
|
metadata: MetadataInput = None,
|
|
1461
1644
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
@@ -1481,7 +1664,14 @@ class Threads(Authenticated):
|
|
|
1481
1664
|
if not _check_filter_match(thread_metadata, filters):
|
|
1482
1665
|
return []
|
|
1483
1666
|
|
|
1484
|
-
thread_config = thread["config"]
|
|
1667
|
+
thread_config = cast(dict[str, Any], thread["config"])
|
|
1668
|
+
thread_config = {
|
|
1669
|
+
**thread_config,
|
|
1670
|
+
"configurable": {
|
|
1671
|
+
**thread_config.get("configurable", {}),
|
|
1672
|
+
**config.get("configurable", {}),
|
|
1673
|
+
},
|
|
1674
|
+
}
|
|
1485
1675
|
# If graph_id exists, get state history
|
|
1486
1676
|
if graph_id := thread_metadata.get("graph_id"):
|
|
1487
1677
|
async with get_graph(
|
|
@@ -1510,6 +1700,265 @@ class Threads(Authenticated):
|
|
|
1510
1700
|
|
|
1511
1701
|
return []
|
|
1512
1702
|
|
|
1703
|
+
class Stream(Authenticated):
|
|
1704
|
+
resource = "threads"
|
|
1705
|
+
|
|
1706
|
+
@staticmethod
|
|
1707
|
+
async def subscribe(
|
|
1708
|
+
conn: InMemConnectionProto | AsyncConnectionProto,
|
|
1709
|
+
thread_id: UUID,
|
|
1710
|
+
seen_runs: set[UUID],
|
|
1711
|
+
) -> list[tuple[UUID, asyncio.Queue]]:
|
|
1712
|
+
"""Subscribe to the thread stream, creating queues for unseen runs."""
|
|
1713
|
+
stream_manager = get_stream_manager()
|
|
1714
|
+
queues = []
|
|
1715
|
+
|
|
1716
|
+
# Create new queues only for runs not yet seen
|
|
1717
|
+
thread_id = _ensure_uuid(thread_id)
|
|
1718
|
+
|
|
1719
|
+
# Add thread stream queue
|
|
1720
|
+
if thread_id not in seen_runs:
|
|
1721
|
+
queue = await stream_manager.add_thread_stream(thread_id)
|
|
1722
|
+
queues.append((thread_id, queue))
|
|
1723
|
+
seen_runs.add(thread_id)
|
|
1724
|
+
|
|
1725
|
+
for run in conn.store["runs"]:
|
|
1726
|
+
if run["thread_id"] == thread_id:
|
|
1727
|
+
run_id = run["run_id"]
|
|
1728
|
+
if run_id not in seen_runs:
|
|
1729
|
+
queue = await stream_manager.add_queue(run_id, thread_id)
|
|
1730
|
+
queues.append((run_id, queue))
|
|
1731
|
+
seen_runs.add(run_id)
|
|
1732
|
+
|
|
1733
|
+
return queues
|
|
1734
|
+
|
|
1735
|
+
@staticmethod
|
|
1736
|
+
async def join(
|
|
1737
|
+
thread_id: UUID,
|
|
1738
|
+
*,
|
|
1739
|
+
last_event_id: str | None = None,
|
|
1740
|
+
stream_modes: list[ThreadStreamMode],
|
|
1741
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1742
|
+
) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
|
|
1743
|
+
"""Stream the thread output."""
|
|
1744
|
+
await Threads.Stream.check_thread_stream_auth(thread_id, ctx)
|
|
1745
|
+
|
|
1746
|
+
from langgraph_api.utils.stream_codec import (
|
|
1747
|
+
decode_stream_message,
|
|
1748
|
+
)
|
|
1749
|
+
|
|
1750
|
+
def should_filter_event(event_name: str, message_bytes: bytes) -> bool:
|
|
1751
|
+
"""Check if an event should be filtered out based on stream_modes."""
|
|
1752
|
+
if "run_modes" in stream_modes and event_name != "state_update":
|
|
1753
|
+
return False
|
|
1754
|
+
if "state_update" in stream_modes and event_name == "state_update":
|
|
1755
|
+
return False
|
|
1756
|
+
if "lifecycle" in stream_modes and event_name == "metadata":
|
|
1757
|
+
try:
|
|
1758
|
+
message_data = orjson.loads(message_bytes)
|
|
1759
|
+
if message_data.get("status") == "run_done":
|
|
1760
|
+
return False
|
|
1761
|
+
if "attempt" in message_data and "run_id" in message_data:
|
|
1762
|
+
return False
|
|
1763
|
+
except (orjson.JSONDecodeError, TypeError):
|
|
1764
|
+
pass
|
|
1765
|
+
return True
|
|
1766
|
+
|
|
1767
|
+
stream_manager = get_stream_manager()
|
|
1768
|
+
seen_runs: set[UUID] = set()
|
|
1769
|
+
created_queues: list[tuple[UUID, asyncio.Queue]] = []
|
|
1770
|
+
|
|
1771
|
+
try:
|
|
1772
|
+
async with connect() as conn:
|
|
1773
|
+
await logger.ainfo(
|
|
1774
|
+
"Joined thread stream",
|
|
1775
|
+
thread_id=str(thread_id),
|
|
1776
|
+
)
|
|
1777
|
+
|
|
1778
|
+
# Restore messages if resuming from a specific event
|
|
1779
|
+
if last_event_id is not None:
|
|
1780
|
+
# Collect all events from all message stores for this thread
|
|
1781
|
+
all_events = []
|
|
1782
|
+
for run_id in stream_manager.message_stores.get(
|
|
1783
|
+
str(thread_id), []
|
|
1784
|
+
):
|
|
1785
|
+
for message in stream_manager.restore_messages(
|
|
1786
|
+
run_id, thread_id, last_event_id
|
|
1787
|
+
):
|
|
1788
|
+
all_events.append((message, run_id))
|
|
1789
|
+
|
|
1790
|
+
# Sort by message ID (which is ms-seq format)
|
|
1791
|
+
all_events.sort(key=lambda x: x[0].id.decode())
|
|
1792
|
+
|
|
1793
|
+
# Yield sorted events
|
|
1794
|
+
for message, run_id in all_events:
|
|
1795
|
+
decoded = decode_stream_message(
|
|
1796
|
+
message.data, channel=message.topic
|
|
1797
|
+
)
|
|
1798
|
+
event_bytes = decoded.event_bytes
|
|
1799
|
+
message_bytes = decoded.message_bytes
|
|
1800
|
+
|
|
1801
|
+
if event_bytes == b"control":
|
|
1802
|
+
if message_bytes == b"done":
|
|
1803
|
+
event_bytes = b"metadata"
|
|
1804
|
+
message_bytes = orjson.dumps(
|
|
1805
|
+
{"status": "run_done", "run_id": run_id}
|
|
1806
|
+
)
|
|
1807
|
+
if not should_filter_event(
|
|
1808
|
+
event_bytes.decode("utf-8"), message_bytes
|
|
1809
|
+
):
|
|
1810
|
+
yield (
|
|
1811
|
+
event_bytes,
|
|
1812
|
+
message_bytes,
|
|
1813
|
+
message.id,
|
|
1814
|
+
)
|
|
1815
|
+
|
|
1816
|
+
# Listen for live messages from all queues
|
|
1817
|
+
while True:
|
|
1818
|
+
# Refresh queues to pick up any new runs that joined this thread
|
|
1819
|
+
new_queue_tuples = await Threads.Stream.subscribe(
|
|
1820
|
+
conn, thread_id, seen_runs
|
|
1821
|
+
)
|
|
1822
|
+
# Track new queues for cleanup
|
|
1823
|
+
for run_id, queue in new_queue_tuples:
|
|
1824
|
+
created_queues.append((run_id, queue))
|
|
1825
|
+
|
|
1826
|
+
for run_id, queue in created_queues:
|
|
1827
|
+
try:
|
|
1828
|
+
message = await asyncio.wait_for(
|
|
1829
|
+
queue.get(), timeout=0.2
|
|
1830
|
+
)
|
|
1831
|
+
decoded = decode_stream_message(
|
|
1832
|
+
message.data, channel=message.topic
|
|
1833
|
+
)
|
|
1834
|
+
event = decoded.event_bytes
|
|
1835
|
+
event_name = event.decode("utf-8")
|
|
1836
|
+
payload = decoded.message_bytes
|
|
1837
|
+
|
|
1838
|
+
if event == b"control" and payload == b"done":
|
|
1839
|
+
topic = message.topic.decode()
|
|
1840
|
+
run_id = topic.split("run:")[1].split(":")[0]
|
|
1841
|
+
meta_event = b"metadata"
|
|
1842
|
+
meta_payload = orjson.dumps(
|
|
1843
|
+
{"status": "run_done", "run_id": run_id}
|
|
1844
|
+
)
|
|
1845
|
+
if not should_filter_event(
|
|
1846
|
+
"metadata", meta_payload
|
|
1847
|
+
):
|
|
1848
|
+
yield (meta_event, meta_payload, message.id)
|
|
1849
|
+
else:
|
|
1850
|
+
if not should_filter_event(event_name, payload):
|
|
1851
|
+
yield (event, payload, message.id)
|
|
1852
|
+
|
|
1853
|
+
except TimeoutError:
|
|
1854
|
+
continue
|
|
1855
|
+
except (ValueError, KeyError):
|
|
1856
|
+
continue
|
|
1857
|
+
|
|
1858
|
+
# Yield execution to other tasks to prevent event loop starvation
|
|
1859
|
+
await asyncio.sleep(0)
|
|
1860
|
+
|
|
1861
|
+
except WrappedHTTPException as e:
|
|
1862
|
+
raise e.http_exception from None
|
|
1863
|
+
except asyncio.CancelledError:
|
|
1864
|
+
await logger.awarning(
|
|
1865
|
+
"Thread stream client disconnected",
|
|
1866
|
+
thread_id=str(thread_id),
|
|
1867
|
+
)
|
|
1868
|
+
raise
|
|
1869
|
+
except:
|
|
1870
|
+
raise
|
|
1871
|
+
finally:
|
|
1872
|
+
# Clean up all created queues
|
|
1873
|
+
for run_id, queue in created_queues:
|
|
1874
|
+
try:
|
|
1875
|
+
await stream_manager.remove_queue(run_id, thread_id, queue)
|
|
1876
|
+
except Exception:
|
|
1877
|
+
# Ignore cleanup errors
|
|
1878
|
+
pass
|
|
1879
|
+
|
|
1880
|
+
@staticmethod
|
|
1881
|
+
async def publish(
|
|
1882
|
+
thread_id: UUID | str,
|
|
1883
|
+
event: str,
|
|
1884
|
+
message: bytes,
|
|
1885
|
+
) -> None:
|
|
1886
|
+
"""Publish a thread-level event to the thread stream."""
|
|
1887
|
+
from langgraph_api.utils.stream_codec import STREAM_CODEC
|
|
1888
|
+
|
|
1889
|
+
topic = f"thread:{thread_id}:stream".encode()
|
|
1890
|
+
|
|
1891
|
+
stream_manager = get_stream_manager()
|
|
1892
|
+
payload = STREAM_CODEC.encode(event, message)
|
|
1893
|
+
await stream_manager.put_thread(
|
|
1894
|
+
str(thread_id), Message(topic=topic, data=payload)
|
|
1895
|
+
)
|
|
1896
|
+
|
|
1897
|
+
@staticmethod
|
|
1898
|
+
async def check_thread_stream_auth(
|
|
1899
|
+
thread_id: UUID,
|
|
1900
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1901
|
+
) -> None:
|
|
1902
|
+
async with connect() as conn:
|
|
1903
|
+
filters = await Threads.Stream.handle_event(
|
|
1904
|
+
ctx,
|
|
1905
|
+
"read",
|
|
1906
|
+
Auth.types.ThreadsRead(thread_id=thread_id),
|
|
1907
|
+
)
|
|
1908
|
+
if filters:
|
|
1909
|
+
thread = await Threads._get_with_filters(
|
|
1910
|
+
cast(InMemConnectionProto, conn), thread_id, filters
|
|
1911
|
+
)
|
|
1912
|
+
if not thread:
|
|
1913
|
+
raise HTTPException(status_code=404, detail="Thread not found")
|
|
1914
|
+
|
|
1915
|
+
@staticmethod
|
|
1916
|
+
async def count(
|
|
1917
|
+
conn: InMemConnectionProto,
|
|
1918
|
+
*,
|
|
1919
|
+
metadata: MetadataInput = None,
|
|
1920
|
+
values: MetadataInput = None,
|
|
1921
|
+
status: ThreadStatus | None = None,
|
|
1922
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1923
|
+
) -> int:
|
|
1924
|
+
"""Get count of threads."""
|
|
1925
|
+
threads = conn.store["threads"]
|
|
1926
|
+
metadata = metadata if metadata is not None else {}
|
|
1927
|
+
values = values if values is not None else {}
|
|
1928
|
+
filters = await Threads.handle_event(
|
|
1929
|
+
ctx,
|
|
1930
|
+
"search",
|
|
1931
|
+
Auth.types.ThreadsSearch(
|
|
1932
|
+
metadata=metadata,
|
|
1933
|
+
values=values,
|
|
1934
|
+
status=status,
|
|
1935
|
+
limit=0,
|
|
1936
|
+
offset=0,
|
|
1937
|
+
),
|
|
1938
|
+
)
|
|
1939
|
+
|
|
1940
|
+
count = 0
|
|
1941
|
+
for thread in threads:
|
|
1942
|
+
if filters and not _check_filter_match(thread["metadata"], filters):
|
|
1943
|
+
continue
|
|
1944
|
+
|
|
1945
|
+
if metadata and not is_jsonb_contained(thread["metadata"], metadata):
|
|
1946
|
+
continue
|
|
1947
|
+
|
|
1948
|
+
if (
|
|
1949
|
+
values
|
|
1950
|
+
and "values" in thread
|
|
1951
|
+
and not is_jsonb_contained(thread["values"], values)
|
|
1952
|
+
):
|
|
1953
|
+
continue
|
|
1954
|
+
|
|
1955
|
+
if status and thread.get("status") != status:
|
|
1956
|
+
continue
|
|
1957
|
+
|
|
1958
|
+
count += 1
|
|
1959
|
+
|
|
1960
|
+
return count
|
|
1961
|
+
|
|
1513
1962
|
|
|
1514
1963
|
RUN_LOCK = asyncio.Lock()
|
|
1515
1964
|
|
|
@@ -1526,38 +1975,37 @@ class Runs(Authenticated):
|
|
|
1526
1975
|
if not pending_runs and not running_runs:
|
|
1527
1976
|
return {
|
|
1528
1977
|
"n_pending": 0,
|
|
1529
|
-
"
|
|
1530
|
-
"
|
|
1978
|
+
"pending_runs_wait_time_max_secs": None,
|
|
1979
|
+
"pending_runs_wait_time_med_secs": None,
|
|
1531
1980
|
"n_running": 0,
|
|
1532
1981
|
}
|
|
1533
1982
|
|
|
1534
|
-
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
|
|
1539
|
-
|
|
1540
|
-
|
|
1541
|
-
|
|
1542
|
-
|
|
1543
|
-
|
|
1544
|
-
|
|
1545
|
-
|
|
1546
|
-
|
|
1547
|
-
|
|
1548
|
-
|
|
1549
|
-
|
|
1550
|
-
|
|
1551
|
-
|
|
1552
|
-
|
|
1553
|
-
|
|
1554
|
-
median_time = sorted_times[median_idx]
|
|
1983
|
+
now = datetime.now(UTC)
|
|
1984
|
+
pending_waits: list[float] = []
|
|
1985
|
+
for run in pending_runs:
|
|
1986
|
+
created_at = run.get("created_at")
|
|
1987
|
+
if not isinstance(created_at, datetime):
|
|
1988
|
+
continue
|
|
1989
|
+
if created_at.tzinfo is None:
|
|
1990
|
+
created_at = created_at.replace(tzinfo=UTC)
|
|
1991
|
+
pending_waits.append((now - created_at).total_seconds())
|
|
1992
|
+
|
|
1993
|
+
max_pending_wait = max(pending_waits) if pending_waits else None
|
|
1994
|
+
if pending_waits:
|
|
1995
|
+
sorted_waits = sorted(pending_waits)
|
|
1996
|
+
half = len(sorted_waits) // 2
|
|
1997
|
+
if len(sorted_waits) % 2 == 1:
|
|
1998
|
+
med_pending_wait = sorted_waits[half]
|
|
1999
|
+
else:
|
|
2000
|
+
med_pending_wait = (sorted_waits[half - 1] + sorted_waits[half]) / 2
|
|
2001
|
+
else:
|
|
2002
|
+
med_pending_wait = None
|
|
1555
2003
|
|
|
1556
2004
|
return {
|
|
1557
2005
|
"n_pending": len(pending_runs),
|
|
1558
2006
|
"n_running": len(running_runs),
|
|
1559
|
-
"
|
|
1560
|
-
"
|
|
2007
|
+
"pending_runs_wait_time_max_secs": max_pending_wait,
|
|
2008
|
+
"pending_runs_wait_time_med_secs": med_pending_wait,
|
|
1561
2009
|
}
|
|
1562
2010
|
|
|
1563
2011
|
@staticmethod
|
|
@@ -1621,38 +2069,51 @@ class Runs(Authenticated):
|
|
|
1621
2069
|
@asynccontextmanager
|
|
1622
2070
|
@staticmethod
|
|
1623
2071
|
async def enter(
|
|
1624
|
-
run_id: UUID,
|
|
2072
|
+
run_id: UUID,
|
|
2073
|
+
thread_id: UUID | None,
|
|
2074
|
+
loop: asyncio.AbstractEventLoop,
|
|
2075
|
+
resumable: bool,
|
|
1625
2076
|
) -> AsyncIterator[ValueEvent]:
|
|
1626
2077
|
"""Enter a run, listen for cancellation while running, signal when done."
|
|
1627
2078
|
This method should be called as a context manager by a worker executing a run.
|
|
1628
2079
|
"""
|
|
1629
2080
|
from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent
|
|
2081
|
+
from langgraph_api.utils.stream_codec import STREAM_CODEC
|
|
1630
2082
|
|
|
1631
2083
|
stream_manager = get_stream_manager()
|
|
1632
|
-
# Get queue for this run
|
|
1633
|
-
|
|
2084
|
+
# Get control queue for this run (normal queue is created during run creation)
|
|
2085
|
+
control_queue = await stream_manager.add_control_queue(run_id, thread_id)
|
|
1634
2086
|
|
|
1635
2087
|
async with SimpleTaskGroup(cancel=True, taskgroup_name="Runs.enter") as tg:
|
|
1636
2088
|
done = ValueEvent()
|
|
1637
|
-
tg.create_task(
|
|
2089
|
+
tg.create_task(
|
|
2090
|
+
listen_for_cancellation(control_queue, run_id, thread_id, done)
|
|
2091
|
+
)
|
|
1638
2092
|
|
|
1639
2093
|
# Give done event to caller
|
|
1640
2094
|
yield done
|
|
1641
|
-
#
|
|
2095
|
+
# Store the control message for late subscribers
|
|
1642
2096
|
control_message = Message(
|
|
1643
2097
|
topic=f"run:{run_id}:control".encode(), data=b"done"
|
|
1644
2098
|
)
|
|
2099
|
+
await stream_manager.put(run_id, thread_id, control_message)
|
|
1645
2100
|
|
|
1646
|
-
#
|
|
1647
|
-
|
|
1648
|
-
|
|
1649
|
-
|
|
1650
|
-
|
|
2101
|
+
# Signal done to all subscribers using stream codec
|
|
2102
|
+
stream_message = Message(
|
|
2103
|
+
topic=f"run:{run_id}:stream".encode(),
|
|
2104
|
+
data=STREAM_CODEC.encode("control", b"done"),
|
|
2105
|
+
)
|
|
2106
|
+
await stream_manager.put(
|
|
2107
|
+
run_id, thread_id, stream_message, resumable=resumable
|
|
2108
|
+
)
|
|
2109
|
+
|
|
2110
|
+
# Remove the control_queue (normal queue is cleaned up during run deletion)
|
|
2111
|
+
await stream_manager.remove_control_queue(run_id, thread_id, control_queue)
|
|
1651
2112
|
|
|
1652
2113
|
@staticmethod
|
|
1653
|
-
async def sweep(
|
|
2114
|
+
async def sweep() -> None:
|
|
1654
2115
|
"""Sweep runs that are no longer running"""
|
|
1655
|
-
|
|
2116
|
+
pass
|
|
1656
2117
|
|
|
1657
2118
|
@staticmethod
|
|
1658
2119
|
def _merge_jsonb(*objects: dict) -> dict:
|
|
@@ -1670,7 +2131,7 @@ class Runs(Authenticated):
|
|
|
1670
2131
|
|
|
1671
2132
|
@staticmethod
|
|
1672
2133
|
async def put(
|
|
1673
|
-
conn: InMemConnectionProto,
|
|
2134
|
+
conn: InMemConnectionProto | AsyncConnectionProto,
|
|
1674
2135
|
assistant_id: UUID,
|
|
1675
2136
|
kwargs: dict,
|
|
1676
2137
|
*,
|
|
@@ -1701,6 +2162,7 @@ class Runs(Authenticated):
|
|
|
1701
2162
|
run_id = _ensure_uuid(run_id) if run_id else None
|
|
1702
2163
|
metadata = metadata if metadata is not None else {}
|
|
1703
2164
|
config = kwargs.get("config", {})
|
|
2165
|
+
temporary = kwargs.get("temporary", False)
|
|
1704
2166
|
|
|
1705
2167
|
# Handle thread creation/update
|
|
1706
2168
|
existing_thread = next(
|
|
@@ -1710,7 +2172,7 @@ class Runs(Authenticated):
|
|
|
1710
2172
|
ctx,
|
|
1711
2173
|
"create_run",
|
|
1712
2174
|
Auth.types.RunsCreate(
|
|
1713
|
-
thread_id=thread_id,
|
|
2175
|
+
thread_id=None if temporary else thread_id,
|
|
1714
2176
|
assistant_id=assistant_id,
|
|
1715
2177
|
run_id=run_id,
|
|
1716
2178
|
status=status,
|
|
@@ -1731,6 +2193,7 @@ class Runs(Authenticated):
|
|
|
1731
2193
|
# Create new thread
|
|
1732
2194
|
if thread_id is None:
|
|
1733
2195
|
thread_id = uuid4()
|
|
2196
|
+
|
|
1734
2197
|
thread = Thread(
|
|
1735
2198
|
thread_id=thread_id,
|
|
1736
2199
|
status="busy",
|
|
@@ -1746,7 +2209,6 @@ class Runs(Authenticated):
|
|
|
1746
2209
|
{
|
|
1747
2210
|
"configurable": Runs._merge_jsonb(
|
|
1748
2211
|
Runs._get_configurable(assistant["config"]),
|
|
1749
|
-
Runs._get_configurable(config),
|
|
1750
2212
|
)
|
|
1751
2213
|
},
|
|
1752
2214
|
),
|
|
@@ -1754,6 +2216,7 @@ class Runs(Authenticated):
|
|
|
1754
2216
|
updated_at=datetime.now(UTC),
|
|
1755
2217
|
values=b"",
|
|
1756
2218
|
)
|
|
2219
|
+
|
|
1757
2220
|
await logger.ainfo("Creating thread", thread_id=thread_id)
|
|
1758
2221
|
conn.store["threads"].append(thread)
|
|
1759
2222
|
elif existing_thread:
|
|
@@ -1775,7 +2238,6 @@ class Runs(Authenticated):
|
|
|
1775
2238
|
"configurable": Runs._merge_jsonb(
|
|
1776
2239
|
Runs._get_configurable(assistant["config"]),
|
|
1777
2240
|
Runs._get_configurable(existing_thread["config"]),
|
|
1778
|
-
Runs._get_configurable(config),
|
|
1779
2241
|
)
|
|
1780
2242
|
},
|
|
1781
2243
|
)
|
|
@@ -1920,6 +2382,7 @@ class Runs(Authenticated):
|
|
|
1920
2382
|
if not thread:
|
|
1921
2383
|
return _empty_generator()
|
|
1922
2384
|
_delete_checkpoints_for_thread(thread_id, conn, run_id=run_id)
|
|
2385
|
+
|
|
1923
2386
|
found = False
|
|
1924
2387
|
for i, run in enumerate(conn.store["runs"]):
|
|
1925
2388
|
if run["run_id"] == run_id and run["thread_id"] == thread_id:
|
|
@@ -1935,54 +2398,10 @@ class Runs(Authenticated):
|
|
|
1935
2398
|
|
|
1936
2399
|
return _yield_deleted()
|
|
1937
2400
|
|
|
1938
|
-
@staticmethod
|
|
1939
|
-
async def join(
|
|
1940
|
-
run_id: UUID,
|
|
1941
|
-
*,
|
|
1942
|
-
thread_id: UUID,
|
|
1943
|
-
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1944
|
-
) -> Fragment:
|
|
1945
|
-
"""Wait for a run to complete. If already done, return immediately.
|
|
1946
|
-
|
|
1947
|
-
Returns:
|
|
1948
|
-
the final state of the run.
|
|
1949
|
-
"""
|
|
1950
|
-
from langgraph_api.serde import Fragment
|
|
1951
|
-
from langgraph_api.utils import fetchone
|
|
1952
|
-
|
|
1953
|
-
async with connect() as conn:
|
|
1954
|
-
# Validate ownership
|
|
1955
|
-
thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
|
|
1956
|
-
await fetchone(thread_iter)
|
|
1957
|
-
last_chunk: bytes | None = None
|
|
1958
|
-
# wait for the run to complete
|
|
1959
|
-
# Rely on this join's auth
|
|
1960
|
-
async for mode, chunk, _ in Runs.Stream.join(
|
|
1961
|
-
run_id, thread_id=thread_id, ctx=ctx, ignore_404=True
|
|
1962
|
-
):
|
|
1963
|
-
if mode == b"values":
|
|
1964
|
-
last_chunk = chunk
|
|
1965
|
-
elif mode == b"error":
|
|
1966
|
-
last_chunk = orjson.dumps({"__error__": orjson.Fragment(chunk)})
|
|
1967
|
-
# if we received a final chunk, return it
|
|
1968
|
-
if last_chunk is not None:
|
|
1969
|
-
# ie. if the run completed while we were waiting for it
|
|
1970
|
-
return Fragment(last_chunk)
|
|
1971
|
-
else:
|
|
1972
|
-
# otherwise, the run had already finished, so fetch the state from thread
|
|
1973
|
-
async with connect() as conn:
|
|
1974
|
-
thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
|
|
1975
|
-
thread = await fetchone(thread_iter)
|
|
1976
|
-
if thread["status"] == "error":
|
|
1977
|
-
return Fragment(
|
|
1978
|
-
orjson.dumps({"__error__": orjson.Fragment(thread["error"])})
|
|
1979
|
-
)
|
|
1980
|
-
return thread["values"]
|
|
1981
|
-
|
|
1982
2401
|
@staticmethod
|
|
1983
2402
|
async def cancel(
|
|
1984
|
-
conn: InMemConnectionProto,
|
|
1985
|
-
run_ids: Sequence[UUID] | None = None,
|
|
2403
|
+
conn: InMemConnectionProto | AsyncConnectionProto,
|
|
2404
|
+
run_ids: Sequence[UUID | str] | None = None,
|
|
1986
2405
|
*,
|
|
1987
2406
|
action: Literal["interrupt", "rollback"] = "interrupt",
|
|
1988
2407
|
thread_id: UUID | None = None,
|
|
@@ -2086,9 +2505,9 @@ class Runs(Authenticated):
|
|
|
2086
2505
|
topic=f"run:{run_id}:control".encode(),
|
|
2087
2506
|
data=action.encode(),
|
|
2088
2507
|
)
|
|
2089
|
-
coros.append(stream_manager.put(run_id, control_message))
|
|
2508
|
+
coros.append(stream_manager.put(run_id, thread_id, control_message))
|
|
2090
2509
|
|
|
2091
|
-
queues = stream_manager.get_queues(run_id)
|
|
2510
|
+
queues = stream_manager.get_queues(run_id, thread_id)
|
|
2092
2511
|
|
|
2093
2512
|
if run["status"] in ("pending", "running"):
|
|
2094
2513
|
cancelable_runs.append(run)
|
|
@@ -2146,6 +2565,7 @@ class Runs(Authenticated):
|
|
|
2146
2565
|
limit: int = 10,
|
|
2147
2566
|
offset: int = 0,
|
|
2148
2567
|
status: RunStatus | None = None,
|
|
2568
|
+
select: list[RunSelectField] | None = None,
|
|
2149
2569
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
2150
2570
|
) -> AsyncIterator[Run]:
|
|
2151
2571
|
"""List all runs by thread."""
|
|
@@ -2173,7 +2593,12 @@ class Runs(Authenticated):
|
|
|
2173
2593
|
|
|
2174
2594
|
async def _return():
|
|
2175
2595
|
for run in sliced_runs:
|
|
2176
|
-
|
|
2596
|
+
if select:
|
|
2597
|
+
# Filter to only selected fields
|
|
2598
|
+
filtered_run = {k: v for k, v in run.items() if k in select}
|
|
2599
|
+
yield filtered_run
|
|
2600
|
+
else:
|
|
2601
|
+
yield run
|
|
2177
2602
|
|
|
2178
2603
|
return _return()
|
|
2179
2604
|
|
|
@@ -2197,73 +2622,81 @@ class Runs(Authenticated):
|
|
|
2197
2622
|
@staticmethod
|
|
2198
2623
|
async def subscribe(
|
|
2199
2624
|
run_id: UUID,
|
|
2200
|
-
|
|
2201
|
-
|
|
2202
|
-
) -> asyncio.Queue:
|
|
2625
|
+
thread_id: UUID | None = None,
|
|
2626
|
+
) -> ContextQueue:
|
|
2203
2627
|
"""Subscribe to the run stream, returning a queue."""
|
|
2204
2628
|
stream_manager = get_stream_manager()
|
|
2205
|
-
queue = await stream_manager.add_queue(_ensure_uuid(run_id))
|
|
2629
|
+
queue = await stream_manager.add_queue(_ensure_uuid(run_id), thread_id)
|
|
2206
2630
|
|
|
2207
2631
|
# If there's a control message already stored, send it to the new subscriber
|
|
2208
|
-
if
|
|
2209
|
-
|
|
2210
|
-
|
|
2632
|
+
if thread_id is None:
|
|
2633
|
+
thread_id = THREADLESS_KEY
|
|
2634
|
+
if control_queues := stream_manager.control_queues.get(thread_id, {}).get(
|
|
2635
|
+
run_id
|
|
2636
|
+
):
|
|
2637
|
+
for control_queue in control_queues:
|
|
2638
|
+
try:
|
|
2639
|
+
while True:
|
|
2640
|
+
control_msg = control_queue.get()
|
|
2641
|
+
await queue.put(control_msg)
|
|
2642
|
+
except asyncio.QueueEmpty:
|
|
2643
|
+
pass
|
|
2211
2644
|
return queue
|
|
2212
2645
|
|
|
2213
2646
|
@staticmethod
|
|
2214
2647
|
async def join(
|
|
2215
2648
|
run_id: UUID,
|
|
2216
2649
|
*,
|
|
2650
|
+
stream_channel: asyncio.Queue,
|
|
2217
2651
|
thread_id: UUID,
|
|
2218
2652
|
ignore_404: bool = False,
|
|
2219
2653
|
cancel_on_disconnect: bool = False,
|
|
2220
|
-
stream_mode: StreamMode |
|
|
2654
|
+
stream_mode: list[StreamMode] | StreamMode | None = None,
|
|
2221
2655
|
last_event_id: str | None = None,
|
|
2222
2656
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
2223
2657
|
) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
|
|
2224
2658
|
"""Stream the run output."""
|
|
2225
2659
|
from langgraph_api.asyncio import create_task
|
|
2660
|
+
from langgraph_api.serde import json_dumpb
|
|
2661
|
+
from langgraph_api.utils.stream_codec import decode_stream_message
|
|
2226
2662
|
|
|
2227
|
-
queue =
|
|
2228
|
-
stream_mode
|
|
2229
|
-
if isinstance(stream_mode, asyncio.Queue)
|
|
2230
|
-
else await Runs.Stream.subscribe(run_id, stream_mode=stream_mode)
|
|
2231
|
-
)
|
|
2232
|
-
|
|
2663
|
+
queue = stream_channel
|
|
2233
2664
|
try:
|
|
2234
2665
|
async with connect() as conn:
|
|
2235
|
-
|
|
2236
|
-
ctx
|
|
2237
|
-
|
|
2238
|
-
|
|
2239
|
-
)
|
|
2240
|
-
if filters:
|
|
2241
|
-
thread = await Threads._get_with_filters(
|
|
2242
|
-
cast(InMemConnectionProto, conn), thread_id, filters
|
|
2243
|
-
)
|
|
2244
|
-
if not thread:
|
|
2245
|
-
raise WrappedHTTPException(
|
|
2246
|
-
HTTPException(
|
|
2247
|
-
status_code=404, detail="Thread not found"
|
|
2248
|
-
)
|
|
2249
|
-
)
|
|
2250
|
-
channel_prefix = f"run:{run_id}:stream:"
|
|
2251
|
-
len_prefix = len(channel_prefix.encode())
|
|
2666
|
+
try:
|
|
2667
|
+
await Runs.Stream.check_run_stream_auth(run_id, thread_id, ctx)
|
|
2668
|
+
except HTTPException as e:
|
|
2669
|
+
raise WrappedHTTPException(e) from None
|
|
2670
|
+
run = await Runs.get(conn, run_id, thread_id=thread_id, ctx=ctx)
|
|
2252
2671
|
|
|
2253
2672
|
for message in get_stream_manager().restore_messages(
|
|
2254
|
-
run_id, last_event_id
|
|
2673
|
+
run_id, thread_id, last_event_id
|
|
2255
2674
|
):
|
|
2256
|
-
|
|
2257
|
-
|
|
2258
|
-
|
|
2675
|
+
data, id = message.data, message.id
|
|
2676
|
+
decoded = decode_stream_message(data, channel=message.topic)
|
|
2677
|
+
mode = decoded.event_bytes.decode("utf-8")
|
|
2678
|
+
payload = decoded.message_bytes
|
|
2679
|
+
|
|
2680
|
+
if mode == "control":
|
|
2681
|
+
if payload == b"done":
|
|
2259
2682
|
return
|
|
2260
|
-
|
|
2261
|
-
|
|
2683
|
+
elif (
|
|
2684
|
+
not stream_mode
|
|
2685
|
+
or mode in stream_mode
|
|
2686
|
+
or (
|
|
2687
|
+
(
|
|
2688
|
+
"messages" in stream_mode
|
|
2689
|
+
or "messages-tuple" in stream_mode
|
|
2690
|
+
)
|
|
2691
|
+
and mode.startswith("messages")
|
|
2692
|
+
)
|
|
2693
|
+
):
|
|
2694
|
+
yield mode.encode(), payload, id
|
|
2262
2695
|
logger.debug(
|
|
2263
2696
|
"Replayed run event",
|
|
2264
2697
|
run_id=str(run_id),
|
|
2265
2698
|
message_id=id,
|
|
2266
|
-
stream_mode=
|
|
2699
|
+
stream_mode=mode,
|
|
2267
2700
|
data=data,
|
|
2268
2701
|
)
|
|
2269
2702
|
|
|
@@ -2271,20 +2704,32 @@ class Runs(Authenticated):
|
|
|
2271
2704
|
try:
|
|
2272
2705
|
# Wait for messages with a timeout
|
|
2273
2706
|
message = await asyncio.wait_for(queue.get(), timeout=0.5)
|
|
2274
|
-
|
|
2707
|
+
data, id = message.data, message.id
|
|
2708
|
+
decoded = decode_stream_message(data, channel=message.topic)
|
|
2709
|
+
mode = decoded.event_bytes.decode("utf-8")
|
|
2710
|
+
payload = decoded.message_bytes
|
|
2275
2711
|
|
|
2276
|
-
if
|
|
2277
|
-
if
|
|
2712
|
+
if mode == "control":
|
|
2713
|
+
if payload == b"done":
|
|
2278
2714
|
break
|
|
2279
|
-
|
|
2280
|
-
|
|
2281
|
-
|
|
2715
|
+
elif (
|
|
2716
|
+
not stream_mode
|
|
2717
|
+
or mode in stream_mode
|
|
2718
|
+
or (
|
|
2719
|
+
(
|
|
2720
|
+
"messages" in stream_mode
|
|
2721
|
+
or "messages-tuple" in stream_mode
|
|
2722
|
+
)
|
|
2723
|
+
and mode.startswith("messages")
|
|
2724
|
+
)
|
|
2725
|
+
):
|
|
2726
|
+
yield mode.encode(), payload, id
|
|
2282
2727
|
logger.debug(
|
|
2283
2728
|
"Streamed run event",
|
|
2284
2729
|
run_id=str(run_id),
|
|
2285
|
-
stream_mode=
|
|
2730
|
+
stream_mode=mode,
|
|
2286
2731
|
message_id=id,
|
|
2287
|
-
data=
|
|
2732
|
+
data=payload,
|
|
2288
2733
|
)
|
|
2289
2734
|
except TimeoutError:
|
|
2290
2735
|
# Check if the run is still pending
|
|
@@ -2298,8 +2743,10 @@ class Runs(Authenticated):
|
|
|
2298
2743
|
elif run is None:
|
|
2299
2744
|
yield (
|
|
2300
2745
|
b"error",
|
|
2301
|
-
|
|
2302
|
-
|
|
2746
|
+
json_dumpb(
|
|
2747
|
+
HTTPException(
|
|
2748
|
+
status_code=404, detail="Run not found"
|
|
2749
|
+
)
|
|
2303
2750
|
),
|
|
2304
2751
|
None,
|
|
2305
2752
|
)
|
|
@@ -2314,45 +2761,68 @@ class Runs(Authenticated):
|
|
|
2314
2761
|
raise
|
|
2315
2762
|
finally:
|
|
2316
2763
|
stream_manager = get_stream_manager()
|
|
2317
|
-
await stream_manager.remove_queue(run_id, queue)
|
|
2764
|
+
await stream_manager.remove_queue(run_id, thread_id, queue)
|
|
2318
2765
|
|
|
2319
2766
|
@staticmethod
|
|
2320
|
-
async def
|
|
2767
|
+
async def check_run_stream_auth(
|
|
2321
2768
|
run_id: UUID,
|
|
2769
|
+
thread_id: UUID,
|
|
2770
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
2771
|
+
) -> None:
|
|
2772
|
+
async with connect() as conn:
|
|
2773
|
+
filters = await Runs.handle_event(
|
|
2774
|
+
ctx,
|
|
2775
|
+
"read",
|
|
2776
|
+
Auth.types.ThreadsRead(thread_id=thread_id),
|
|
2777
|
+
)
|
|
2778
|
+
if filters:
|
|
2779
|
+
thread = await Threads._get_with_filters(
|
|
2780
|
+
cast(InMemConnectionProto, conn), thread_id, filters
|
|
2781
|
+
)
|
|
2782
|
+
if not thread:
|
|
2783
|
+
raise HTTPException(status_code=404, detail="Thread not found")
|
|
2784
|
+
|
|
2785
|
+
@staticmethod
|
|
2786
|
+
async def publish(
|
|
2787
|
+
run_id: UUID | str,
|
|
2322
2788
|
event: str,
|
|
2323
2789
|
message: bytes,
|
|
2324
2790
|
*,
|
|
2791
|
+
thread_id: UUID | str | None = None,
|
|
2325
2792
|
resumable: bool = False,
|
|
2326
2793
|
) -> None:
|
|
2327
2794
|
"""Publish a message to all subscribers of the run stream."""
|
|
2328
|
-
|
|
2795
|
+
from langgraph_api.utils.stream_codec import STREAM_CODEC
|
|
2796
|
+
|
|
2797
|
+
topic = f"run:{run_id}:stream".encode()
|
|
2329
2798
|
|
|
2330
2799
|
stream_manager = get_stream_manager()
|
|
2331
|
-
# Send to all queues subscribed to this run_id
|
|
2800
|
+
# Send to all queues subscribed to this run_id using protocol frame
|
|
2801
|
+
payload = STREAM_CODEC.encode(event, message)
|
|
2332
2802
|
await stream_manager.put(
|
|
2333
|
-
run_id, Message(topic=topic, data=
|
|
2803
|
+
run_id, thread_id, Message(topic=topic, data=payload), resumable
|
|
2334
2804
|
)
|
|
2335
2805
|
|
|
2336
2806
|
|
|
2337
|
-
async def listen_for_cancellation(
|
|
2807
|
+
async def listen_for_cancellation(
|
|
2808
|
+
queue: asyncio.Queue, run_id: UUID, thread_id: UUID | None, done: ValueEvent
|
|
2809
|
+
):
|
|
2338
2810
|
"""Listen for cancellation messages and set the done event accordingly."""
|
|
2339
2811
|
from langgraph_api.errors import UserInterrupt, UserRollback
|
|
2340
2812
|
|
|
2341
2813
|
stream_manager = get_stream_manager()
|
|
2342
|
-
control_key = f"run:{run_id}:control"
|
|
2343
2814
|
|
|
2344
|
-
if
|
|
2345
|
-
|
|
2346
|
-
|
|
2347
|
-
|
|
2348
|
-
|
|
2349
|
-
|
|
2350
|
-
done.set(UserInterrupt())
|
|
2815
|
+
if control_key := stream_manager.get_control_key(run_id, thread_id):
|
|
2816
|
+
payload = control_key.data
|
|
2817
|
+
if payload == b"rollback":
|
|
2818
|
+
done.set(UserRollback())
|
|
2819
|
+
elif payload == b"interrupt":
|
|
2820
|
+
done.set(UserInterrupt())
|
|
2351
2821
|
|
|
2352
2822
|
while not done.is_set():
|
|
2353
2823
|
try:
|
|
2354
2824
|
# This task gets cancelled when Runs.enter exits anyway,
|
|
2355
|
-
# so we can have a pretty
|
|
2825
|
+
# so we can have a pretty lengthy timeout here
|
|
2356
2826
|
message = await asyncio.wait_for(queue.get(), timeout=240)
|
|
2357
2827
|
payload = message.data
|
|
2358
2828
|
if payload == b"rollback":
|
|
@@ -2362,10 +2832,6 @@ async def listen_for_cancellation(queue: asyncio.Queue, run_id: UUID, done: Valu
|
|
|
2362
2832
|
elif payload == b"done":
|
|
2363
2833
|
done.set()
|
|
2364
2834
|
break
|
|
2365
|
-
|
|
2366
|
-
# Store control messages for late subscribers
|
|
2367
|
-
if message.topic.decode() == control_key:
|
|
2368
|
-
stream_manager.control_queues[run_id].append(message)
|
|
2369
2835
|
except TimeoutError:
|
|
2370
2836
|
break
|
|
2371
2837
|
|
|
@@ -2379,6 +2845,7 @@ class Crons:
|
|
|
2379
2845
|
schedule: str,
|
|
2380
2846
|
cron_id: UUID | None = None,
|
|
2381
2847
|
thread_id: UUID | None = None,
|
|
2848
|
+
on_run_completed: Literal["delete", "keep"] | None = None,
|
|
2382
2849
|
end_time: datetime | None = None,
|
|
2383
2850
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
2384
2851
|
) -> AsyncIterator[Cron]:
|
|
@@ -2417,10 +2884,24 @@ class Crons:
|
|
|
2417
2884
|
thread_id: UUID | None,
|
|
2418
2885
|
limit: int,
|
|
2419
2886
|
offset: int,
|
|
2887
|
+
select: list[CronSelectField] | None = None,
|
|
2420
2888
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
2421
|
-
|
|
2889
|
+
sort_by: str | None = None,
|
|
2890
|
+
sort_order: Literal["asc", "desc"] | None = None,
|
|
2891
|
+
) -> tuple[AsyncIterator[Cron], int]:
|
|
2422
2892
|
raise NotImplementedError
|
|
2423
2893
|
|
|
2894
|
+
@staticmethod
|
|
2895
|
+
async def count(
|
|
2896
|
+
conn: InMemConnectionProto,
|
|
2897
|
+
*,
|
|
2898
|
+
assistant_id: UUID | None = None,
|
|
2899
|
+
thread_id: UUID | None = None,
|
|
2900
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
2901
|
+
) -> int:
|
|
2902
|
+
"""Get count of crons."""
|
|
2903
|
+
raise NotImplementedError("The in-mem server does not implement Crons.")
|
|
2904
|
+
|
|
2424
2905
|
|
|
2425
2906
|
async def cancel_run(
|
|
2426
2907
|
thread_id: UUID, run_id: UUID, ctx: Auth.types.BaseAuthContext | None = None
|
|
@@ -2478,11 +2959,18 @@ def _check_filter_match(metadata: dict, filters: Auth.types.FilterType | None) -
|
|
|
2478
2959
|
if key not in metadata or metadata[key] != filter_value:
|
|
2479
2960
|
return False
|
|
2480
2961
|
elif op == "$contains":
|
|
2481
|
-
if (
|
|
2482
|
-
|
|
2483
|
-
|
|
2484
|
-
|
|
2485
|
-
|
|
2962
|
+
if key not in metadata or not isinstance(metadata[key], list):
|
|
2963
|
+
return False
|
|
2964
|
+
|
|
2965
|
+
if isinstance(filter_value, list):
|
|
2966
|
+
# Mimick Postgres containment operator behavior.
|
|
2967
|
+
# It would be more efficient to use set operations here,
|
|
2968
|
+
# but we can't assume that elements are hashable.
|
|
2969
|
+
# The Postgres algorithm is also O(n^2).
|
|
2970
|
+
for filter_element in filter_value:
|
|
2971
|
+
if filter_element not in metadata[key]:
|
|
2972
|
+
return False
|
|
2973
|
+
elif filter_value not in metadata[key]:
|
|
2486
2974
|
return False
|
|
2487
2975
|
else:
|
|
2488
2976
|
# Direct equality
|
|
@@ -2498,6 +2986,7 @@ async def _empty_generator():
|
|
|
2498
2986
|
|
|
2499
2987
|
|
|
2500
2988
|
__all__ = [
|
|
2989
|
+
"StreamHandler",
|
|
2501
2990
|
"Assistants",
|
|
2502
2991
|
"Crons",
|
|
2503
2992
|
"Runs",
|