langgraph-runtime-inmem 0.9.0__py3-none-any.whl → 0.20.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_runtime_inmem/__init__.py +1 -1
- langgraph_runtime_inmem/database.py +3 -1
- langgraph_runtime_inmem/inmem_stream.py +24 -2
- langgraph_runtime_inmem/lifespan.py +41 -2
- langgraph_runtime_inmem/metrics.py +1 -1
- langgraph_runtime_inmem/ops.py +463 -230
- langgraph_runtime_inmem/queue.py +16 -16
- {langgraph_runtime_inmem-0.9.0.dist-info → langgraph_runtime_inmem-0.20.1.dist-info}/METADATA +3 -3
- langgraph_runtime_inmem-0.20.1.dist-info/RECORD +13 -0
- {langgraph_runtime_inmem-0.9.0.dist-info → langgraph_runtime_inmem-0.20.1.dist-info}/WHEEL +1 -1
- langgraph_runtime_inmem-0.9.0.dist-info/RECORD +0 -13
langgraph_runtime_inmem/ops.py
CHANGED
@@ -11,7 +11,6 @@ import uuid
 from collections import defaultdict
 from collections.abc import AsyncIterator, Sequence
 from contextlib import asynccontextmanager
-from copy import deepcopy
 from datetime import UTC, datetime, timedelta
 from typing import Any, Literal, cast
 from uuid import UUID, uuid4

@@ -29,6 +28,7 @@ from langgraph_runtime_inmem.checkpoint import Checkpointer
 from langgraph_runtime_inmem.database import InMemConnectionProto, connect
 from langgraph_runtime_inmem.inmem_stream import (
     THREADLESS_KEY,
+    ContextQueue,
     Message,
     get_stream_manager,
 )

@@ -58,12 +58,13 @@ if typing.TYPE_CHECKING:
         Thread,
         ThreadSelectField,
         ThreadStatus,
+        ThreadStreamMode,
         ThreadUpdateResponse,
     )
     from langgraph_api.schema import Interrupt as InterruptSchema
-    from langgraph_api.serde import Fragment
     from langgraph_api.utils import AsyncConnectionProto

+StreamHandler = ContextQueue

 logger = structlog.stdlib.get_logger(__name__)

@@ -140,6 +141,7 @@ class Assistants(Authenticated):
         conn: InMemConnectionProto,
         *,
         graph_id: str | None,
+        name: str | None,
         metadata: MetadataInput,
         limit: int,
         offset: int,

@@ -163,6 +165,7 @@ class Assistants(Authenticated):
             assistant
             for assistant in assistants
             if (not graph_id or assistant["graph_id"] == graph_id)
+            and (not name or name.lower() in assistant["name"].lower())
             and (not metadata or is_jsonb_contained(assistant["metadata"], metadata))
             and (not filters or _check_filter_match(assistant["metadata"], filters))
         ]
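
Note: the new name filter is a case-insensitive substring match, not exact equality. A minimal standalone sketch of the same predicate (the assistant records here are illustrative, not the package's store):

    def matches_name(assistant: dict, name: str | None) -> bool:
        # Mirrors the predicate added above: no name means match everything.
        return not name or name.lower() in assistant["name"].lower()

    assistants = [{"name": "Email Triage Agent"}, {"name": "Summarizer"}]
    print([a["name"] for a in assistants if matches_name(a, "triage")])
    # -> ['Email Triage Agent']
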
@@ -228,7 +231,7 @@ class Assistants(Authenticated):
             if assistant["assistant_id"] == assistant_id and (
                 not filters or _check_filter_match(assistant["metadata"], filters)
             ):
-                yield assistant
+                yield copy.deepcopy(assistant)

         return _yield_result()

@@ -247,6 +250,8 @@ class Assistants(Authenticated):
         description: str | None = None,
     ) -> AsyncIterator[Assistant]:
         """Insert an assistant."""
+        from langgraph_api.graph import GRAPHS
+
         assistant_id = _ensure_uuid(assistant_id)
         metadata = metadata if metadata is not None else {}
         filters = await Assistants.handle_event(

@@ -268,6 +273,9 @@ class Assistants(Authenticated):
                 detail="Cannot specify both configurable and context. Prefer setting context alone. Context was introduced in LangGraph 0.6.0 and is the long term planned replacement for configurable.",
             )

+        if graph_id not in GRAPHS:
+            raise HTTPException(status_code=404, detail=f"Graph {graph_id} not found")
+
         # Keep config and context up to date with one another
         if config.get("configurable"):
             context = config["configurable"]

@@ -555,6 +563,8 @@ class Assistants(Authenticated):
                 "metadata": version_data["metadata"],
                 "version": version_data["version"],
                 "updated_at": datetime.now(UTC),
+                "name": version_data["name"],
+                "description": version_data["description"],
             }
         )

@@ -618,6 +628,7 @@ class Assistants(Authenticated):
         conn: InMemConnectionProto,
         *,
         graph_id: str | None = None,
+        name: str | None = None,
         metadata: MetadataInput = None,
         ctx: Auth.types.BaseAuthContext | None = None,
     ) -> int:

@@ -635,6 +646,7 @@ class Assistants(Authenticated):
         for assistant in conn.store["assistants"]:
             if (
                 (not graph_id or assistant["graph_id"] == graph_id)
+                and (not name or name.lower() in assistant["name"].lower())
                 and (
                     not metadata or is_jsonb_contained(assistant["metadata"], metadata)
                 )

@@ -738,6 +750,7 @@ class Threads(Authenticated):
     async def search(
         conn: InMemConnectionProto,
         *,
+        ids: list[str] | list[UUID] | None = None,
         metadata: MetadataInput,
         values: MetadataInput,
         status: ThreadStatus | None,

@@ -765,7 +778,19 @@ class Threads(Authenticated):
         )

         # Apply filters
+        id_set: set[UUID] | None = None
+        if ids:
+            id_set = set()
+            for i in ids:
+                try:
+                    id_set.add(_ensure_uuid(i))
+                except Exception:
+                    raise HTTPException(
+                        status_code=400, detail="Invalid thread ID " + str(i)
+                    ) from None
         for thread in threads:
+            if id_set is not None and thread.get("thread_id") not in id_set:
+                continue
             if filters and not _check_filter_match(thread["metadata"], filters):
                 continue

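
Note: thread search now takes an optional ids allow-list; every entry is coerced to a UUID up front, and a single bad value fails the whole request with a 400 instead of being silently skipped. A rough standalone equivalent of the coercion step, using uuid.UUID in place of the package's _ensure_uuid helper (an assumption about what that helper does):

    from uuid import UUID

    def coerce_ids(ids: list[str | UUID] | None) -> set[UUID] | None:
        if not ids:
            return None
        id_set: set[UUID] = set()
        for i in ids:
            try:
                # _ensure_uuid presumably accepts both str and UUID inputs.
                id_set.add(i if isinstance(i, UUID) else UUID(i))
            except Exception:
                raise ValueError(f"Invalid thread ID {i}") from None
        return id_set

    print(coerce_ids(["8c5a1f9e-2b3c-4d5e-8f9a-0b1c2d3e4f5a"]))
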
@@ -943,6 +968,7 @@ class Threads(Authenticated):
         thread_id: UUID,
         *,
         metadata: MetadataValue,
+        ttl: ThreadTTLConfig | None = None,
         ctx: Auth.types.BaseAuthContext | None = None,
     ) -> AsyncIterator[Thread]:
         """Update a thread."""

@@ -1215,13 +1241,23 @@ class Threads(Authenticated):
         """Create a copy of an existing thread."""
         thread_id = _ensure_uuid(thread_id)
         new_thread_id = uuid4()
-
+        read_filters = await Threads.handle_event(
             ctx,
             "read",
             Auth.types.ThreadsRead(
+                thread_id=thread_id,
+            ),
+        )
+        # Assert that the user has permissions to create a new thread.
+        # (We don't actually need the filters.)
+        await Threads.handle_event(
+            ctx,
+            "create",
+            Auth.types.ThreadsCreate(
                 thread_id=new_thread_id,
             ),
         )
+
         async with conn.pipeline():
             # Find the original thread in our store
             original_thread = next(

@@ -1230,8 +1266,8 @@ class Threads(Authenticated):

             if not original_thread:
                 return _empty_generator()
-            if
-                original_thread["metadata"],
+            if read_filters and not _check_filter_match(
+                original_thread["metadata"], read_filters
             ):
                 return _empty_generator()

@@ -1240,7 +1276,7 @@ class Threads(Authenticated):
                 "thread_id": new_thread_id,
                 "created_at": datetime.now(tz=UTC),
                 "updated_at": datetime.now(tz=UTC),
-                "metadata": deepcopy(original_thread["metadata"]),
+                "metadata": copy.deepcopy(original_thread["metadata"]),
                 "status": "idle",
                 "config": {},
             }

@@ -1327,7 +1363,14 @@ class Threads(Authenticated):
             )

         metadata = thread.get("metadata", {})
-        thread_config = thread.get("config", {})
+        thread_config = cast(dict[str, Any], thread.get("config", {}))
+        thread_config = {
+            **thread_config,
+            "configurable": {
+                **thread_config.get("configurable", {}),
+                **config.get("configurable", {}),
+            },
+        }

         # Fallback to graph_id from run if not in thread metadata
         graph_id = metadata.get("graph_id")
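
Note: this configurable merge recurs in several Threads methods below. Dict unpacking makes the precedence explicit: later spreads win, so the caller's config["configurable"] overrides the thread's stored values key by key, while everything else on the thread config is preserved. A runnable illustration with made-up keys:

    thread_config = {"configurable": {"thread_ts": "t0", "user_id": "u1"}, "tags": ["a"]}
    config = {"configurable": {"user_id": "u2"}}

    merged = {
        **thread_config,
        "configurable": {
            **thread_config.get("configurable", {}),
            **config.get("configurable", {}),
        },
    }
    print(merged)
    # -> {'configurable': {'thread_ts': 't0', 'user_id': 'u2'}, 'tags': ['a']}
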
@@ -1377,6 +1420,7 @@ class Threads(Authenticated):
         """Add state to a thread."""
         from langgraph_api.graph import get_graph
         from langgraph_api.schema import ThreadUpdateResponse
+        from langgraph_api.state import state_snapshot_to_thread_state
         from langgraph_api.store import get_store
         from langgraph_api.utils import fetchone

@@ -1414,6 +1458,13 @@ class Threads(Authenticated):
                 status_code=409,
                 detail=f"Thread {thread_id} has in-flight runs: {pending_runs}",
             )
+        thread_config = {
+            **thread_config,
+            "configurable": {
+                **thread_config.get("configurable", {}),
+                **config.get("configurable", {}),
+            },
+        }

         # Fallback to graph_id from run if not in thread metadata
         graph_id = metadata.get("graph_id")

@@ -1454,6 +1505,19 @@ class Threads(Authenticated):
                     thread["values"] = state.values
                     break

+        # Publish state update event
+        from langgraph_api.serde import json_dumpb
+
+        event_data = {
+            "state": state_snapshot_to_thread_state(state),
+            "thread_id": str(thread_id),
+        }
+        await Threads.Stream.publish(
+            thread_id,
+            "state_update",
+            json_dumpb(event_data),
+        )
+
         return ThreadUpdateResponse(
             checkpoint=next_config["configurable"],
             # Including deprecated fields
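
Note: state updates are now observable without polling: each successful update is serialized and pushed to the thread stream under a state_update event, which Threads.Stream.join (below) can forward to subscribers. The payload mirrors the event_data built above; the concrete "state" shape shown here is an assumption for illustration only:

    import json

    event_data = {
        "state": {"values": {"messages": []}, "next": []},  # assumed snapshot shape
        "thread_id": "8c5a1f9e-2b3c-4d5e-8f9a-0b1c2d3e4f5a",
    }
    print(json.dumps(event_data))
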
@@ -1496,7 +1560,14 @@ class Threads(Authenticated):
             thread_iter, not_found_detail=f"Thread {thread_id} not found."
         )

-        thread_config = thread["config"]
+        thread_config = cast(dict[str, Any], thread["config"])
+        thread_config = {
+            **thread_config,
+            "configurable": {
+                **thread_config.get("configurable", {}),
+                **config.get("configurable", {}),
+            },
+        }
         metadata = thread["metadata"]

         if not thread:

@@ -1543,6 +1614,19 @@ class Threads(Authenticated):
                     thread["values"] = state.values
                     break

+        # Publish state update event
+        from langgraph_api.serde import json_dumpb
+
+        event_data = {
+            "state": state,
+            "thread_id": str(thread_id),
+        }
+        await Threads.Stream.publish(
+            thread_id,
+            "state_update",
+            json_dumpb(event_data),
+        )
+
         return ThreadUpdateResponse(
             checkpoint=next_config["configurable"],
         )

@@ -1584,7 +1668,14 @@ class Threads(Authenticated):
         if not _check_filter_match(thread_metadata, filters):
             return []

-        thread_config = thread["config"]
+        thread_config = cast(dict[str, Any], thread["config"])
+        thread_config = {
+            **thread_config,
+            "configurable": {
+                **thread_config.get("configurable", {}),
+                **config.get("configurable", {}),
+            },
+        }
         # If graph_id exists, get state history
         if graph_id := thread_metadata.get("graph_id"):
             async with get_graph(

@@ -1613,7 +1704,9 @@ class Threads(Authenticated):

         return []

-    class Stream:
+    class Stream(Authenticated):
+        resource = "threads"
+
         @staticmethod
         async def subscribe(
             conn: InMemConnectionProto | AsyncConnectionProto,

@@ -1626,6 +1719,13 @@ class Threads(Authenticated):

             # Create new queues only for runs not yet seen
             thread_id = _ensure_uuid(thread_id)
+
+            # Add thread stream queue
+            if thread_id not in seen_runs:
+                queue = await stream_manager.add_thread_stream(thread_id)
+                queues.append((thread_id, queue))
+                seen_runs.add(thread_id)
+
             for run in conn.store["runs"]:
                 if run["thread_id"] == thread_id:
                     run_id = run["run_id"]

@@ -1641,9 +1741,32 @@ class Threads(Authenticated):
             thread_id: UUID,
             *,
             last_event_id: str | None = None,
+            stream_modes: list[ThreadStreamMode],
+            ctx: Auth.types.BaseAuthContext | None = None,
         ) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
             """Stream the thread output."""
-
+            await Threads.Stream.check_thread_stream_auth(thread_id, ctx)
+
+            from langgraph_api.utils.stream_codec import (
+                decode_stream_message,
+            )
+
+            def should_filter_event(event_name: str, message_bytes: bytes) -> bool:
+                """Check if an event should be filtered out based on stream_modes."""
+                if "run_modes" in stream_modes and event_name != "state_update":
+                    return False
+                if "state_update" in stream_modes and event_name == "state_update":
+                    return False
+                if "lifecycle" in stream_modes and event_name == "metadata":
+                    try:
+                        message_data = orjson.loads(message_bytes)
+                        if message_data.get("status") == "run_done":
+                            return False
+                        if "attempt" in message_data and "run_id" in message_data:
+                            return False
+                    except (orjson.JSONDecodeError, TypeError):
+                        pass
+                return True

             stream_manager = get_stream_manager()
             seen_runs: set[UUID] = set()
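
Note: should_filter_event uses an inverted convention: it returns False when an event should pass through to the client. A standalone re-implementation of the closure (json standing in for orjson) with a couple of representative checks:

    import json

    def should_filter_event(stream_modes, event_name, message_bytes):
        if "run_modes" in stream_modes and event_name != "state_update":
            return False  # pass through run events
        if "state_update" in stream_modes and event_name == "state_update":
            return False  # pass through state updates
        if "lifecycle" in stream_modes and event_name == "metadata":
            try:
                data = json.loads(message_bytes)
                if data.get("status") == "run_done":
                    return False
                if "attempt" in data and "run_id" in data:
                    return False
            except (json.JSONDecodeError, TypeError):
                pass
        return True  # filter out

    print(should_filter_event(["state_update"], "state_update", b"{}"))  # False (kept)
    print(should_filter_event(["lifecycle"], "values", b"{}"))           # True (dropped)
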
@@ -1673,23 +1796,24 @@ class Threads(Authenticated):

                 # Yield sorted events
                 for message, run_id in all_events:
-… (11 removed lines elided)
-                            message.id,
+                    decoded = decode_stream_message(
+                        message.data, channel=message.topic
+                    )
+                    event_bytes = decoded.event_bytes
+                    message_bytes = decoded.message_bytes
+
+                    if event_bytes == b"control":
+                        if message_bytes == b"done":
+                            event_bytes = b"metadata"
+                            message_bytes = orjson.dumps(
+                                {"status": "run_done", "run_id": run_id}
                             )
-
+                    if not should_filter_event(
+                        event_bytes.decode("utf-8"), message_bytes
+                    ):
                         yield (
-
-
+                            event_bytes,
+                            message_bytes,
                             message.id,
                         )

@@ -1708,28 +1832,27 @@ class Threads(Authenticated):
                             message = await asyncio.wait_for(
                                 queue.get(), timeout=0.2
                             )
-… (13 removed lines elided)
-                                ),
-                                message.id,
-                            )
-                        else:
-                            yield (
-                                event_name.encode(),
-                                base64.b64decode(message_content),
-                                message.id,
+                            decoded = decode_stream_message(
+                                message.data, channel=message.topic
+                            )
+                            event = decoded.event_bytes
+                            event_name = event.decode("utf-8")
+                            payload = decoded.message_bytes
+
+                            if event == b"control" and payload == b"done":
+                                topic = message.topic.decode()
+                                run_id = topic.split("run:")[1].split(":")[0]
+                                meta_event = b"metadata"
+                                meta_payload = orjson.dumps(
+                                    {"status": "run_done", "run_id": run_id}
                                 )
+                                if not should_filter_event(
+                                    "metadata", meta_payload
+                                ):
+                                    yield (meta_event, meta_payload, message.id)
+                            else:
+                                if not should_filter_event(event_name, payload):
+                                    yield (event, payload, message.id)

                     except TimeoutError:
                         continue
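
Note: when a control "done" frame arrives on the combined thread stream, the run ID is recovered by parsing the topic string, which assumes topics of the form run:{run_id}:stream:

    topic = "run:8c5a1f9e-2b3c-4d5e-8f9a-0b1c2d3e4f5a:stream"
    run_id = topic.split("run:")[1].split(":")[0]
    print(run_id)  # -> 8c5a1f9e-2b3c-4d5e-8f9a-0b1c2d3e4f5a
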
@@ -1758,6 +1881,41 @@ class Threads(Authenticated):
                         # Ignore cleanup errors
                         pass

+        @staticmethod
+        async def publish(
+            thread_id: UUID | str,
+            event: str,
+            message: bytes,
+        ) -> None:
+            """Publish a thread-level event to the thread stream."""
+            from langgraph_api.utils.stream_codec import STREAM_CODEC
+
+            topic = f"thread:{thread_id}:stream".encode()
+
+            stream_manager = get_stream_manager()
+            payload = STREAM_CODEC.encode(event, message)
+            await stream_manager.put_thread(
+                str(thread_id), Message(topic=topic, data=payload)
+            )
+
+        @staticmethod
+        async def check_thread_stream_auth(
+            thread_id: UUID,
+            ctx: Auth.types.BaseAuthContext | None = None,
+        ) -> None:
+            async with connect() as conn:
+                filters = await Threads.Stream.handle_event(
+                    ctx,
+                    "read",
+                    Auth.types.ThreadsRead(thread_id=thread_id),
+                )
+                if filters:
+                    thread = await Threads._get_with_filters(
+                        cast(InMemConnectionProto, conn), thread_id, filters
+                    )
+                    if not thread:
+                        raise HTTPException(status_code=404, detail="Thread not found")
+
     @staticmethod
     async def count(
         conn: InMemConnectionProto,

@@ -1821,38 +1979,37 @@ class Runs(Authenticated):
         if not pending_runs and not running_runs:
             return {
                 "n_pending": 0,
-                "
-                "
+                "pending_runs_wait_time_max_secs": None,
+                "pending_runs_wait_time_med_secs": None,
                 "n_running": 0,
             }

-… (20 removed lines elided)
-        median_time = sorted_times[median_idx]
+        now = datetime.now(UTC)
+        pending_waits: list[float] = []
+        for run in pending_runs:
+            created_at = run.get("created_at")
+            if not isinstance(created_at, datetime):
+                continue
+            if created_at.tzinfo is None:
+                created_at = created_at.replace(tzinfo=UTC)
+            pending_waits.append((now - created_at).total_seconds())
+
+        max_pending_wait = max(pending_waits) if pending_waits else None
+        if pending_waits:
+            sorted_waits = sorted(pending_waits)
+            half = len(sorted_waits) // 2
+            if len(sorted_waits) % 2 == 1:
+                med_pending_wait = sorted_waits[half]
+            else:
+                med_pending_wait = (sorted_waits[half - 1] + sorted_waits[half]) / 2
+        else:
+            med_pending_wait = None

         return {
             "n_pending": len(pending_runs),
             "n_running": len(running_runs),
-            "
-            "
+            "pending_runs_wait_time_max_secs": max_pending_wait,
+            "pending_runs_wait_time_med_secs": med_pending_wait,
         }

     @staticmethod
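
Note: the wait-time statistics use the standard midpoint rule for the median: the middle element for an odd count, the mean of the two middle elements for an even count. statistics.median from the standard library gives the same result and makes a handy cross-check:

    from statistics import median

    waits = [1.0, 4.0, 2.0, 8.0]           # even count
    sorted_waits = sorted(waits)           # [1.0, 2.0, 4.0, 8.0]
    half = len(sorted_waits) // 2
    med = (sorted_waits[half - 1] + sorted_waits[half]) / 2
    assert med == median(waits) == 3.0
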
@@ -1916,12 +2073,16 @@ class Runs(Authenticated):
     @asynccontextmanager
     @staticmethod
     async def enter(
-        run_id: UUID,
+        run_id: UUID,
+        thread_id: UUID | None,
+        loop: asyncio.AbstractEventLoop,
+        resumable: bool,
     ) -> AsyncIterator[ValueEvent]:
         """Enter a run, listen for cancellation while running, signal when done.
         This method should be called as a context manager by a worker executing a run.
         """
         from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent
+        from langgraph_api.utils.stream_codec import STREAM_CODEC

         stream_manager = get_stream_manager()
         # Get control queue for this run (normal queue is created during run creation)

@@ -1941,12 +2102,14 @@ class Runs(Authenticated):
             )
             await stream_manager.put(run_id, thread_id, control_message)

-            # Signal done to all subscribers
+            # Signal done to all subscribers using stream codec
             stream_message = Message(
                 topic=f"run:{run_id}:stream".encode(),
-                data=
+                data=STREAM_CODEC.encode("control", b"done"),
+            )
+            await stream_manager.put(
+                run_id, thread_id, stream_message, resumable=resumable
             )
-            await stream_manager.put(run_id, thread_id, stream_message)

             # Remove the control_queue (normal queue is cleaned up during run deletion)
             await stream_manager.remove_control_queue(run_id, thread_id, control_queue)

@@ -1988,7 +2151,6 @@ class Runs(Authenticated):
         ctx: Auth.types.BaseAuthContext | None = None,
     ) -> AsyncIterator[Run]:
         """Create a run."""
-        from langgraph_api.config import FF_RICH_THREADS
         from langgraph_api.schema import Run, Thread

         assistant_id = _ensure_uuid(assistant_id)

@@ -2004,6 +2166,7 @@ class Runs(Authenticated):
         run_id = _ensure_uuid(run_id) if run_id else None
         metadata = metadata if metadata is not None else {}
         config = kwargs.get("config", {})
+        temporary = kwargs.get("temporary", False)

         # Handle thread creation/update
         existing_thread = next(

@@ -2013,7 +2176,7 @@ class Runs(Authenticated):
             ctx,
             "create_run",
             Auth.types.RunsCreate(
-                thread_id=thread_id,
+                thread_id=None if temporary else thread_id,
                 assistant_id=assistant_id,
                 run_id=run_id,
                 status=status,

@@ -2034,49 +2197,35 @@ class Runs(Authenticated):
         # Create new thread
         if thread_id is None:
             thread_id = uuid4()
-… (17 removed lines elided)
-                    },
-                ),
-                created_at=datetime.now(UTC),
-                updated_at=datetime.now(UTC),
-                values=b"",
-            )
-        else:
-            thread = Thread(
-                thread_id=thread_id,
-                status="idle",
-                metadata={
-                    "graph_id": assistant["graph_id"],
-                    "assistant_id": str(assistant_id),
-                    **(config.get("metadata") or {}),
-                    **metadata,
+
+            thread = Thread(
+                thread_id=thread_id,
+                status="busy",
+                metadata={
+                    "graph_id": assistant["graph_id"],
+                    "assistant_id": str(assistant_id),
+                    **(config.get("metadata") or {}),
+                    **metadata,
+                },
+                config=Runs._merge_jsonb(
+                    assistant["config"],
+                    config,
+                    {
+                        "configurable": Runs._merge_jsonb(
+                            Runs._get_configurable(assistant["config"]),
+                        )
                 },
-… (5 removed lines elided)
+                ),
+                created_at=datetime.now(UTC),
+                updated_at=datetime.now(UTC),
+                values=b"",
+            )
+
             await logger.ainfo("Creating thread", thread_id=thread_id)
             conn.store["threads"].append(thread)
         elif existing_thread:
             # Update existing thread
-            if
+            if existing_thread["status"] != "busy":
                 existing_thread["status"] = "busy"
                 existing_thread["metadata"] = Runs._merge_jsonb(
                     existing_thread["metadata"],

@@ -2253,66 +2402,6 @@ class Runs(Authenticated):

         return _yield_deleted()

-    @staticmethod
-    async def join(
-        run_id: UUID,
-        *,
-        thread_id: UUID,
-        ctx: Auth.types.BaseAuthContext | None = None,
-    ) -> Fragment:
-        """Wait for a run to complete. If already done, return immediately.
-
-        Returns:
-            the final state of the run.
-        """
-        from langgraph_api.serde import Fragment
-        from langgraph_api.utils import fetchone
-
-        async with connect() as conn:
-            # Validate ownership
-            thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
-            await fetchone(thread_iter)
-        last_chunk: bytes | None = None
-        # wait for the run to complete
-        # Rely on this join's auth
-        async for mode, chunk, _ in Runs.Stream.join(
-            run_id,
-            thread_id=thread_id,
-            ctx=ctx,
-            ignore_404=True,
-            stream_mode=["values", "updates", "error"],
-        ):
-            if mode == b"values":
-                last_chunk = chunk
-            elif mode == b"updates" and b"__interrupt__" in chunk:
-                last_chunk = chunk
-            elif mode == b"error":
-                last_chunk = orjson.dumps({"__error__": orjson.Fragment(chunk)})
-        # if we received a final chunk, return it
-        if last_chunk is not None:
-            # ie. if the run completed while we were waiting for it
-            return Fragment(last_chunk)
-        else:
-            # otherwise, the run had already finished, so fetch the state from thread
-            async with connect() as conn:
-                thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
-                thread = await fetchone(thread_iter)
-                if thread["status"] == "error":
-                    return Fragment(
-                        orjson.dumps({"__error__": orjson.Fragment(thread["error"])})
-                    )
-                if thread["status"] == "interrupted":
-                    # Get an interrupt for the thread. There is the case where there are multiple interrupts for the same run and we may not show the same
-                    # interrupt, but we'll always show one. Long term we should show all of them.
-                    try:
-                        interrupt_map = thread["interrupts"]
-                        interrupt = [next(iter(interrupt_map.values()))[0]]
-                        return Fragment(orjson.dumps({"__interrupt__": interrupt}))
-                    except Exception:
-                        # No interrupt, but status is interrupted from a before/after block. Default back to values.
-                        pass
-                return thread["values"]
-
     @staticmethod
     async def cancel(
         conn: InMemConnectionProto | AsyncConnectionProto,

@@ -2538,7 +2627,7 @@ class Runs(Authenticated):
         async def subscribe(
             run_id: UUID,
             thread_id: UUID | None = None,
-        ) ->
+        ) -> ContextQueue:
             """Subscribe to the run stream, returning a queue."""
             stream_manager = get_stream_manager()
             queue = await stream_manager.add_queue(_ensure_uuid(run_id), thread_id)

@@ -2562,54 +2651,38 @@ class Runs(Authenticated):
         async def join(
             run_id: UUID,
             *,
+            stream_channel: asyncio.Queue,
             thread_id: UUID,
             ignore_404: bool = False,
             cancel_on_disconnect: bool = False,
-            stream_channel: asyncio.Queue | None = None,
             stream_mode: list[StreamMode] | StreamMode | None = None,
             last_event_id: str | None = None,
             ctx: Auth.types.BaseAuthContext | None = None,
         ) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
             """Stream the run output."""
             from langgraph_api.asyncio import create_task
-            from langgraph_api.serde import
-
-            queue = (
-                stream_channel
-                if stream_channel
-                else await Runs.Stream.subscribe(run_id, thread_id)
-            )
+            from langgraph_api.serde import json_dumpb
+            from langgraph_api.utils.stream_codec import decode_stream_message

+            queue = stream_channel
             try:
                 async with connect() as conn:
-
-                    ctx
-
-
-                    )
-                    if filters:
-                        thread = await Threads._get_with_filters(
-                            cast(InMemConnectionProto, conn), thread_id, filters
-                        )
-                        if not thread:
-                            raise WrappedHTTPException(
-                                HTTPException(
-                                    status_code=404, detail="Thread not found"
-                                )
-                            )
+                    try:
+                        await Runs.Stream.check_run_stream_auth(run_id, thread_id, ctx)
+                    except HTTPException as e:
+                        raise WrappedHTTPException(e) from None
                     run = await Runs.get(conn, run_id, thread_id=thread_id, ctx=ctx)

                     for message in get_stream_manager().restore_messages(
                         run_id, thread_id, last_event_id
                     ):
                         data, id = message.data, message.id
-
-
-
-                        message = data["message"]
+                        decoded = decode_stream_message(data, channel=message.topic)
+                        mode = decoded.event_bytes.decode("utf-8")
+                        payload = decoded.message_bytes

                         if mode == "control":
-                            if
+                            if payload == b"done":
                                 return
                             elif (
                                 not stream_mode

@@ -2622,7 +2695,7 @@ class Runs(Authenticated):
                             and mode.startswith("messages")
                         )
                     ):
-                        yield mode.encode(),
+                        yield mode.encode(), payload, id
                         logger.debug(
                             "Replayed run event",
                             run_id=str(run_id),

@@ -2636,13 +2709,12 @@ class Runs(Authenticated):
                         # Wait for messages with a timeout
                         message = await asyncio.wait_for(queue.get(), timeout=0.5)
                         data, id = message.data, message.id
-
-
-
-                        message = data["message"]
+                        decoded = decode_stream_message(data, channel=message.topic)
+                        mode = decoded.event_bytes.decode("utf-8")
+                        payload = decoded.message_bytes

                         if mode == "control":
-                            if
+                            if payload == b"done":
                                 break
                             elif (
                                 not stream_mode

@@ -2655,13 +2727,13 @@ class Runs(Authenticated):
                             and mode.startswith("messages")
                         )
                     ):
-                        yield mode.encode(),
+                        yield mode.encode(), payload, id
                         logger.debug(
                             "Streamed run event",
                             run_id=str(run_id),
                             stream_mode=mode,
                             message_id=id,
-                            data=
+                            data=payload,
                         )
             except TimeoutError:
                 # Check if the run is still pending

@@ -2675,8 +2747,10 @@ class Runs(Authenticated):
             elif run is None:
                 yield (
                     b"error",
-
-
+                    json_dumpb(
+                        HTTPException(
+                            status_code=404, detail="Run not found"
+                        )
                     ),
                     None,
                 )

@@ -2693,6 +2767,25 @@ class Runs(Authenticated):
             stream_manager = get_stream_manager()
             await stream_manager.remove_queue(run_id, thread_id, queue)

+        @staticmethod
+        async def check_run_stream_auth(
+            run_id: UUID,
+            thread_id: UUID,
+            ctx: Auth.types.BaseAuthContext | None = None,
+        ) -> None:
+            async with connect() as conn:
+                filters = await Runs.handle_event(
+                    ctx,
+                    "read",
+                    Auth.types.ThreadsRead(thread_id=thread_id),
+                )
+                if filters:
+                    thread = await Threads._get_with_filters(
+                        cast(InMemConnectionProto, conn), thread_id, filters
+                    )
+                    if not thread:
+                        raise HTTPException(status_code=404, detail="Thread not found")
+
         @staticmethod
         async def publish(
             run_id: UUID | str,

@@ -2703,18 +2796,13 @@ class Runs(Authenticated):
             resumable: bool = False,
         ) -> None:
             """Publish a message to all subscribers of the run stream."""
-            from langgraph_api.
+            from langgraph_api.utils.stream_codec import STREAM_CODEC

             topic = f"run:{run_id}:stream".encode()

             stream_manager = get_stream_manager()
-            # Send to all queues subscribed to this run_id
-            payload =
-            {
-                "event": event,
-                "message": message,
-            }
-            )
+            # Send to all queues subscribed to this run_id using protocol frame
+            payload = STREAM_CODEC.encode(event, message)
             await stream_manager.put(
                 run_id, thread_id, Message(topic=topic, data=payload), resumable
             )
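
Note: both publish paths now frame (event, message) pairs with STREAM_CODEC instead of an ad-hoc dict, and every consumer in this file decodes with decode_stream_message. The diff does not show the codec's wire format; the toy length-prefixed framing below only illustrates the encode/decode symmetry the pair has to satisfy, and is not the package's actual format:

    def encode(event: str, message: bytes) -> bytes:
        # Hypothetical frame: 2-byte event length, event name, then payload.
        ev = event.encode()
        return len(ev).to_bytes(2, "big") + ev + message

    def decode(data: bytes) -> tuple[bytes, bytes]:
        n = int.from_bytes(data[:2], "big")
        return data[2 : 2 + n], data[2 + n :]

    frame = encode("control", b"done")
    assert decode(frame) == (b"control", b"done")
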
@@ -2761,7 +2849,9 @@ class Crons:
         schedule: str,
         cron_id: UUID | None = None,
         thread_id: UUID | None = None,
+        on_run_completed: Literal["delete", "keep"] | None = None,
         end_time: datetime | None = None,
+        metadata: dict | None = None,
         ctx: Auth.types.BaseAuthContext | None = None,
     ) -> AsyncIterator[Cron]:
         raise NotImplementedError

@@ -2852,19 +2942,154 @@ def _delete_checkpoints_for_thread(
     )


-def
+def _validate_filter_structure(
+    filters: Auth.types.FilterType | None,
+    nesting_level: int = 0,
+) -> None:
+    """Validate the structure of filter conditions without checking matches.
+
+    Args:
+        filters: The filter conditions to validate
+        nesting_level: Current depth of nested operators (max 2)
+
+    Raises:
+        HTTPException: If the filter structure is invalid
+    """
+    if nesting_level > 2:
+        raise HTTPException(
+            status_code=500,
+            detail="Your auth handler returned a filter with too many nested operators. The maximum depth for nested operators is 2. Please simplify your filter.",
+        )
+
+    if not filters:
+        return
+
+    # Handle $or operator
+    if "$or" in filters:
+        or_groups = filters["$or"]
+        if not isinstance(or_groups, list) or not len(or_groups) >= 2:
+            raise HTTPException(
+                status_code=500,
+                detail="Your auth handler returned a filter with an invalid $or operator. The $or operator must be a list of at least 2 filter objects. Check the filter returned by your auth handler.",
+            )
+
+        # Recursively validate all groups
+        for group in or_groups:
+            _validate_filter_structure(group, nesting_level=nesting_level + 1)
+
+        # Validate remaining filters (implicit AND with the $or)
+        remaining_filters = {k: v for k, v in filters.items() if k != "$or"}
+        if remaining_filters:
+            _validate_filter_structure(
+                remaining_filters, nesting_level=nesting_level + 1
+            )
+
+    # Handle $and operator
+    if "$and" in filters:
+        and_groups = filters["$and"]
+        if not isinstance(and_groups, list) or not len(and_groups) >= 2:
+            raise HTTPException(
+                status_code=500,
+                detail="Your auth handler returned a filter with an invalid $and operator. The $and operator must be a list of at least 2 filter objects. Check the filter returned by your auth handler.",
+            )
+
+        # Recursively validate all groups
+        for group in and_groups:
+            _validate_filter_structure(group, nesting_level=nesting_level + 1)
+
+        # Validate remaining filters (implicit AND with the $and)
+        remaining_filters = {k: v for k, v in filters.items() if k != "$and"}
+        if remaining_filters:
+            _validate_filter_structure(
+                remaining_filters, nesting_level=nesting_level + 1
+            )
+
+
+def _check_filter_match(
+    metadata: dict,
+    filters: Auth.types.FilterType | None,
+    nesting_level: int = 0,
+) -> bool:
     """Check if metadata matches the filter conditions.

     Args:
         metadata: The metadata to check
         filters: The filter conditions to apply
+        nesting_level: Current depth of nested operators (max 2)

     Returns:
         True if the metadata matches all filter conditions, False otherwise
     """
+    if nesting_level > 2:
+        raise HTTPException(
+            status_code=500,
+            detail="Your auth handler returned a filter with too many nested operators. The maximum depth for nested operators is 2. Please simplify your filter.",
+        )
+
     if not filters:
         return True

+    # Handle $or operator
+    if "$or" in filters:
+        or_groups = filters["$or"]
+        if not isinstance(or_groups, list) or not len(or_groups) >= 2:
+            raise HTTPException(
+                status_code=500,
+                detail="Your auth handler returned a filter with an invalid $or operator. The $or operator must be a list of at least 2 filter objects. Check the filter returned by your auth handler.",
+            )
+
+        # Validate all groups first to ensure nesting limits are respected
+        # (even if we short-circuit during matching)
+        for group in or_groups:
+            _validate_filter_structure(group, nesting_level=nesting_level + 1)
+
+        # At least one group must match
+        or_match = False
+        for group in or_groups:
+            if _check_filter_match(metadata, group, nesting_level=nesting_level + 1):
+                or_match = True
+                break
+
+        if not or_match:
+            return False
+
+        # Check remaining filters (implicit AND with the $or)
+        remaining_filters = {k: v for k, v in filters.items() if k != "$or"}
+        if remaining_filters:
+            return _check_filter_match(
+                metadata, remaining_filters, nesting_level=nesting_level + 1
+            )
+        return True
+
+    # Handle $and operator
+    if "$and" in filters:
+        and_groups = filters["$and"]
+        if not isinstance(and_groups, list) or not len(and_groups) >= 2:
+            raise HTTPException(
+                status_code=500,
+                detail="Your auth handler returned a filter with an invalid $and operator. The $and operator must be a list of at least 2 filter objects. Check the filter returned by your auth handler.",
+            )
+
+        # Validate all groups first to ensure nesting limits are respected
+        for group in and_groups:
+            _validate_filter_structure(group, nesting_level=nesting_level + 1)
+
+        # All groups must match
+        for group in and_groups:
+            if not _check_filter_match(
+                metadata, group, nesting_level=nesting_level + 1
+            ):
+                return False
+
+        # Check remaining filters (implicit AND with the $and)
+        remaining_filters = {k: v for k, v in filters.items() if k != "$and"}
+        if remaining_filters:
+            return _check_filter_match(
+                metadata, remaining_filters, nesting_level=nesting_level + 1
+            )
+        return True
+
+    # Regular filter logic (implicit AND)
     for key, value in filters.items():
         if isinstance(value, dict):
             op = next(iter(value))
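
Note: together these functions give auth filters a small boolean algebra: top-level keys combine as an implicit AND, $or and $and each take a list of at least two sub-filters, and operator nesting is capped at depth 2 (anything deeper raises a 500). Some illustrative filter values:

    # Matches org_1 threads owned by either user (implicit AND with the $or):
    valid = {"org_id": "org_1", "$or": [{"owner": "user_a"}, {"owner": "user_b"}]}

    # Rejected: $or needs at least 2 groups.
    invalid = {"$or": [{"owner": "user_a"}]}

    # Rejected: $or -> $and -> $or exceeds the nesting limit of 2.
    too_deep = {"$or": [{"$and": [{"$or": [{"a": 1}, {"b": 2}]}, {"c": 3}]}, {"d": 4}]}
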
@@ -2874,11 +3099,18 @@ def _check_filter_match(metadata: dict, filters: Auth.types.FilterType | None) -
             if key not in metadata or metadata[key] != filter_value:
                 return False
         elif op == "$contains":
-            if (
-… (4 removed lines elided)
+            if key not in metadata or not isinstance(metadata[key], list):
+                return False
+
+            if isinstance(filter_value, list):
+                # Mimick Postgres containment operator behavior.
+                # It would be more efficient to use set operations here,
+                # but we can't assume that elements are hashable.
+                # The Postgres algorithm is also O(n^2).
+                for filter_element in filter_value:
+                    if filter_element not in metadata[key]:
+                        return False
+            elif filter_value not in metadata[key]:
                 return False
         else:
             # Direct equality
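
Note: $contains with a list filter value now follows Postgres @> containment: every element of the filter list must appear in the metadata list, regardless of order or duplicates. A standalone restatement of the branch above:

    def contains(haystack: list, needle) -> bool:
        if isinstance(needle, list):
            # List containment: all filter elements must be present.
            return all(element in haystack for element in needle)
        return needle in haystack  # scalar membership

    print(contains(["a", "b", "c"], ["c", "a"]))  # True
    print(contains(["a", "b"], ["a", "x"]))       # False
    print(contains(["a", "b"], "b"))              # True
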
@@ -2894,6 +3126,7 @@ async def _empty_generator():


 __all__ = [
+    "StreamHandler",
     "Assistants",
     "Crons",
     "Runs",