langgraph-runtime-inmem 0.9.0__py3-none-any.whl → 0.18.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_runtime_inmem/__init__.py +1 -1
- langgraph_runtime_inmem/database.py +3 -1
- langgraph_runtime_inmem/inmem_stream.py +24 -2
- langgraph_runtime_inmem/lifespan.py +41 -2
- langgraph_runtime_inmem/metrics.py +1 -1
- langgraph_runtime_inmem/ops.py +322 -229
- langgraph_runtime_inmem/queue.py +7 -16
- {langgraph_runtime_inmem-0.9.0.dist-info → langgraph_runtime_inmem-0.18.1.dist-info}/METADATA +3 -3
- langgraph_runtime_inmem-0.18.1.dist-info/RECORD +13 -0
- langgraph_runtime_inmem-0.9.0.dist-info/RECORD +0 -13
- {langgraph_runtime_inmem-0.9.0.dist-info → langgraph_runtime_inmem-0.18.1.dist-info}/WHEEL +0 -0
|
@@ -142,7 +142,9 @@ class InMemConnectionProto:
|
|
|
142
142
|
|
|
143
143
|
|
|
144
144
|
@asynccontextmanager
|
|
145
|
-
async def connect(
|
|
145
|
+
async def connect(
|
|
146
|
+
*, supports_core_api: bool = False, __test__: bool = False
|
|
147
|
+
) -> AsyncIterator["AsyncConnectionProto"]:
|
|
146
148
|
yield InMemConnectionProto()
|
|
147
149
|
|
|
148
150
|
|
|
@@ -58,6 +58,7 @@ class StreamManager:
|
|
|
58
58
|
) # Dict[str, List[asyncio.Queue]]
|
|
59
59
|
self.control_keys = defaultdict(lambda: defaultdict())
|
|
60
60
|
self.control_queues = defaultdict(lambda: defaultdict(list))
|
|
61
|
+
self.thread_streams = defaultdict(list)
|
|
61
62
|
|
|
62
63
|
self.message_stores = defaultdict(
|
|
63
64
|
lambda: defaultdict(list[Message])
|
|
@@ -95,7 +96,7 @@ class StreamManager:
|
|
|
95
96
|
|
|
96
97
|
async def put(
|
|
97
98
|
self,
|
|
98
|
-
run_id: UUID | str,
|
|
99
|
+
run_id: UUID | str | None,
|
|
99
100
|
thread_id: UUID | str | None,
|
|
100
101
|
message: Message,
|
|
101
102
|
resumable: bool = False,
|
|
@@ -107,9 +108,10 @@ class StreamManager:
|
|
|
107
108
|
thread_id = _ensure_uuid(thread_id)
|
|
108
109
|
|
|
109
110
|
message.id = _generate_ms_seq_id().encode()
|
|
111
|
+
# For resumable run streams, embed the generated message ID into the frame
|
|
112
|
+
topic = message.topic.decode()
|
|
110
113
|
if resumable:
|
|
111
114
|
self.message_stores[thread_id][run_id].append(message)
|
|
112
|
-
topic = message.topic.decode()
|
|
113
115
|
if "control" in topic:
|
|
114
116
|
self.control_keys[thread_id][run_id] = message
|
|
115
117
|
queues = self.control_queues[thread_id][run_id]
|
|
@@ -121,6 +123,20 @@ class StreamManager:
|
|
|
121
123
|
if isinstance(result, Exception):
|
|
122
124
|
logger.exception(f"Failed to put message in queue: {result}")
|
|
123
125
|
|
|
126
|
+
async def put_thread(
|
|
127
|
+
self,
|
|
128
|
+
thread_id: UUID | str,
|
|
129
|
+
message: Message,
|
|
130
|
+
) -> None:
|
|
131
|
+
thread_id = _ensure_uuid(thread_id)
|
|
132
|
+
message.id = _generate_ms_seq_id().encode()
|
|
133
|
+
queues = self.thread_streams[thread_id]
|
|
134
|
+
coros = [queue.put(message) for queue in queues]
|
|
135
|
+
results = await asyncio.gather(*coros, return_exceptions=True)
|
|
136
|
+
for result in results:
|
|
137
|
+
if isinstance(result, Exception):
|
|
138
|
+
logger.exception(f"Failed to put message in queue: {result}")
|
|
139
|
+
|
|
124
140
|
async def add_queue(
|
|
125
141
|
self, run_id: UUID | str, thread_id: UUID | str | None
|
|
126
142
|
) -> asyncio.Queue:
|
|
@@ -145,6 +161,12 @@ class StreamManager:
|
|
|
145
161
|
self.control_queues[thread_id][run_id].append(queue)
|
|
146
162
|
return queue
|
|
147
163
|
|
|
164
|
+
async def add_thread_stream(self, thread_id: UUID | str) -> asyncio.Queue:
|
|
165
|
+
thread_id = _ensure_uuid(thread_id)
|
|
166
|
+
queue = ContextQueue()
|
|
167
|
+
self.thread_streams[thread_id].append(queue)
|
|
168
|
+
return queue
|
|
169
|
+
|
|
148
170
|
async def remove_queue(
|
|
149
171
|
self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
|
|
150
172
|
):
|
|
@@ -14,9 +14,17 @@ from langgraph_runtime_inmem.database import start_pool, stop_pool
|
|
|
14
14
|
logger = structlog.stdlib.get_logger(__name__)
|
|
15
15
|
|
|
16
16
|
|
|
17
|
+
_LAST_LIFESPAN_ERROR: BaseException | None = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_last_error() -> BaseException | None:
|
|
21
|
+
return _LAST_LIFESPAN_ERROR
|
|
22
|
+
|
|
23
|
+
|
|
17
24
|
@asynccontextmanager
|
|
18
25
|
async def lifespan(
|
|
19
26
|
app: Starlette | None = None,
|
|
27
|
+
cancel_event: asyncio.Event | None = None,
|
|
20
28
|
taskset: set[asyncio.Task] | None = None,
|
|
21
29
|
**kwargs: Any,
|
|
22
30
|
):
|
|
@@ -41,13 +49,31 @@ async def lifespan(
|
|
|
41
49
|
except RuntimeError:
|
|
42
50
|
await logger.aerror("Failed to set loop")
|
|
43
51
|
|
|
52
|
+
global _LAST_LIFESPAN_ERROR
|
|
53
|
+
_LAST_LIFESPAN_ERROR = None
|
|
54
|
+
|
|
44
55
|
await start_http_client()
|
|
45
56
|
await start_pool()
|
|
46
57
|
await start_ui_bundler()
|
|
58
|
+
|
|
59
|
+
async def _log_graph_load_failure(err: graph.GraphLoadError) -> None:
|
|
60
|
+
cause = err.__cause__ or err.cause
|
|
61
|
+
log_fields = err.log_fields()
|
|
62
|
+
log_fields["action"] = "fix_user_graph"
|
|
63
|
+
await logger.aerror(
|
|
64
|
+
f"Graph '{err.spec.id}' failed to load: {err.cause_message}",
|
|
65
|
+
**log_fields,
|
|
66
|
+
)
|
|
67
|
+
await logger.adebug(
|
|
68
|
+
"Full graph load failure traceback (internal)",
|
|
69
|
+
**{k: v for k, v in log_fields.items() if k != "user_traceback"},
|
|
70
|
+
exc_info=cause,
|
|
71
|
+
)
|
|
72
|
+
|
|
47
73
|
try:
|
|
48
74
|
async with SimpleTaskGroup(
|
|
49
75
|
cancel=True,
|
|
50
|
-
|
|
76
|
+
cancel_event=cancel_event,
|
|
51
77
|
taskgroup_name="Lifespan",
|
|
52
78
|
) as tg:
|
|
53
79
|
tg.create_task(metadata_loop())
|
|
@@ -76,11 +102,21 @@ async def lifespan(
|
|
|
76
102
|
var_child_runnable_config.set(langgraph_config)
|
|
77
103
|
|
|
78
104
|
# Keep after the setter above so users can access the store from within the factory function
|
|
79
|
-
|
|
105
|
+
try:
|
|
106
|
+
await graph.collect_graphs_from_env(True)
|
|
107
|
+
except graph.GraphLoadError as exc:
|
|
108
|
+
_LAST_LIFESPAN_ERROR = exc
|
|
109
|
+
await _log_graph_load_failure(exc)
|
|
110
|
+
raise
|
|
80
111
|
if config.N_JOBS_PER_WORKER > 0:
|
|
81
112
|
tg.create_task(queue_with_signal())
|
|
82
113
|
|
|
83
114
|
yield
|
|
115
|
+
except graph.GraphLoadError as exc:
|
|
116
|
+
_LAST_LIFESPAN_ERROR = exc
|
|
117
|
+
raise
|
|
118
|
+
except asyncio.CancelledError:
|
|
119
|
+
pass
|
|
84
120
|
finally:
|
|
85
121
|
await api_store.exit_store()
|
|
86
122
|
await stop_ui_bundler()
|
|
@@ -97,3 +133,6 @@ async def queue_with_signal():
|
|
|
97
133
|
except Exception as exc:
|
|
98
134
|
logger.exception("Queue failed. Signaling shutdown", exc_info=exc)
|
|
99
135
|
signal.raise_signal(signal.SIGINT)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
lifespan.get_last_error = get_last_error # type: ignore[attr-defined]
|
langgraph_runtime_inmem/ops.py
CHANGED
|
@@ -11,7 +11,6 @@ import uuid
|
|
|
11
11
|
from collections import defaultdict
|
|
12
12
|
from collections.abc import AsyncIterator, Sequence
|
|
13
13
|
from contextlib import asynccontextmanager
|
|
14
|
-
from copy import deepcopy
|
|
15
14
|
from datetime import UTC, datetime, timedelta
|
|
16
15
|
from typing import Any, Literal, cast
|
|
17
16
|
from uuid import UUID, uuid4
|
|
@@ -29,6 +28,7 @@ from langgraph_runtime_inmem.checkpoint import Checkpointer
|
|
|
29
28
|
from langgraph_runtime_inmem.database import InMemConnectionProto, connect
|
|
30
29
|
from langgraph_runtime_inmem.inmem_stream import (
|
|
31
30
|
THREADLESS_KEY,
|
|
31
|
+
ContextQueue,
|
|
32
32
|
Message,
|
|
33
33
|
get_stream_manager,
|
|
34
34
|
)
|
|
@@ -58,12 +58,13 @@ if typing.TYPE_CHECKING:
|
|
|
58
58
|
Thread,
|
|
59
59
|
ThreadSelectField,
|
|
60
60
|
ThreadStatus,
|
|
61
|
+
ThreadStreamMode,
|
|
61
62
|
ThreadUpdateResponse,
|
|
62
63
|
)
|
|
63
64
|
from langgraph_api.schema import Interrupt as InterruptSchema
|
|
64
|
-
from langgraph_api.serde import Fragment
|
|
65
65
|
from langgraph_api.utils import AsyncConnectionProto
|
|
66
66
|
|
|
67
|
+
StreamHandler = ContextQueue
|
|
67
68
|
|
|
68
69
|
logger = structlog.stdlib.get_logger(__name__)
|
|
69
70
|
|
|
@@ -228,7 +229,7 @@ class Assistants(Authenticated):
|
|
|
228
229
|
if assistant["assistant_id"] == assistant_id and (
|
|
229
230
|
not filters or _check_filter_match(assistant["metadata"], filters)
|
|
230
231
|
):
|
|
231
|
-
yield assistant
|
|
232
|
+
yield copy.deepcopy(assistant)
|
|
232
233
|
|
|
233
234
|
return _yield_result()
|
|
234
235
|
|
|
@@ -247,6 +248,8 @@ class Assistants(Authenticated):
|
|
|
247
248
|
description: str | None = None,
|
|
248
249
|
) -> AsyncIterator[Assistant]:
|
|
249
250
|
"""Insert an assistant."""
|
|
251
|
+
from langgraph_api.graph import GRAPHS
|
|
252
|
+
|
|
250
253
|
assistant_id = _ensure_uuid(assistant_id)
|
|
251
254
|
metadata = metadata if metadata is not None else {}
|
|
252
255
|
filters = await Assistants.handle_event(
|
|
@@ -268,6 +271,9 @@ class Assistants(Authenticated):
|
|
|
268
271
|
detail="Cannot specify both configurable and context. Prefer setting context alone. Context was introduced in LangGraph 0.6.0 and is the long term planned replacement for configurable.",
|
|
269
272
|
)
|
|
270
273
|
|
|
274
|
+
if graph_id not in GRAPHS:
|
|
275
|
+
raise HTTPException(status_code=404, detail=f"Graph {graph_id} not found")
|
|
276
|
+
|
|
271
277
|
# Keep config and context up to date with one another
|
|
272
278
|
if config.get("configurable"):
|
|
273
279
|
context = config["configurable"]
|
|
@@ -555,6 +561,8 @@ class Assistants(Authenticated):
|
|
|
555
561
|
"metadata": version_data["metadata"],
|
|
556
562
|
"version": version_data["version"],
|
|
557
563
|
"updated_at": datetime.now(UTC),
|
|
564
|
+
"name": version_data["name"],
|
|
565
|
+
"description": version_data["description"],
|
|
558
566
|
}
|
|
559
567
|
)
|
|
560
568
|
|
|
@@ -738,6 +746,7 @@ class Threads(Authenticated):
|
|
|
738
746
|
async def search(
|
|
739
747
|
conn: InMemConnectionProto,
|
|
740
748
|
*,
|
|
749
|
+
ids: list[str] | list[UUID] | None = None,
|
|
741
750
|
metadata: MetadataInput,
|
|
742
751
|
values: MetadataInput,
|
|
743
752
|
status: ThreadStatus | None,
|
|
@@ -765,7 +774,19 @@ class Threads(Authenticated):
|
|
|
765
774
|
)
|
|
766
775
|
|
|
767
776
|
# Apply filters
|
|
777
|
+
id_set: set[UUID] | None = None
|
|
778
|
+
if ids:
|
|
779
|
+
id_set = set()
|
|
780
|
+
for i in ids:
|
|
781
|
+
try:
|
|
782
|
+
id_set.add(_ensure_uuid(i))
|
|
783
|
+
except Exception:
|
|
784
|
+
raise HTTPException(
|
|
785
|
+
status_code=400, detail="Invalid thread ID " + str(i)
|
|
786
|
+
) from None
|
|
768
787
|
for thread in threads:
|
|
788
|
+
if id_set is not None and thread.get("thread_id") not in id_set:
|
|
789
|
+
continue
|
|
769
790
|
if filters and not _check_filter_match(thread["metadata"], filters):
|
|
770
791
|
continue
|
|
771
792
|
|
|
@@ -943,6 +964,7 @@ class Threads(Authenticated):
|
|
|
943
964
|
thread_id: UUID,
|
|
944
965
|
*,
|
|
945
966
|
metadata: MetadataValue,
|
|
967
|
+
ttl: ThreadTTLConfig | None = None,
|
|
946
968
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
947
969
|
) -> AsyncIterator[Thread]:
|
|
948
970
|
"""Update a thread."""
|
|
@@ -1215,13 +1237,23 @@ class Threads(Authenticated):
|
|
|
1215
1237
|
"""Create a copy of an existing thread."""
|
|
1216
1238
|
thread_id = _ensure_uuid(thread_id)
|
|
1217
1239
|
new_thread_id = uuid4()
|
|
1218
|
-
|
|
1240
|
+
read_filters = await Threads.handle_event(
|
|
1219
1241
|
ctx,
|
|
1220
1242
|
"read",
|
|
1221
1243
|
Auth.types.ThreadsRead(
|
|
1244
|
+
thread_id=thread_id,
|
|
1245
|
+
),
|
|
1246
|
+
)
|
|
1247
|
+
# Assert that the user has permissions to create a new thread.
|
|
1248
|
+
# (We don't actually need the filters.)
|
|
1249
|
+
await Threads.handle_event(
|
|
1250
|
+
ctx,
|
|
1251
|
+
"create",
|
|
1252
|
+
Auth.types.ThreadsCreate(
|
|
1222
1253
|
thread_id=new_thread_id,
|
|
1223
1254
|
),
|
|
1224
1255
|
)
|
|
1256
|
+
|
|
1225
1257
|
async with conn.pipeline():
|
|
1226
1258
|
# Find the original thread in our store
|
|
1227
1259
|
original_thread = next(
|
|
@@ -1230,8 +1262,8 @@ class Threads(Authenticated):
|
|
|
1230
1262
|
|
|
1231
1263
|
if not original_thread:
|
|
1232
1264
|
return _empty_generator()
|
|
1233
|
-
if
|
|
1234
|
-
original_thread["metadata"],
|
|
1265
|
+
if read_filters and not _check_filter_match(
|
|
1266
|
+
original_thread["metadata"], read_filters
|
|
1235
1267
|
):
|
|
1236
1268
|
return _empty_generator()
|
|
1237
1269
|
|
|
@@ -1240,7 +1272,7 @@ class Threads(Authenticated):
|
|
|
1240
1272
|
"thread_id": new_thread_id,
|
|
1241
1273
|
"created_at": datetime.now(tz=UTC),
|
|
1242
1274
|
"updated_at": datetime.now(tz=UTC),
|
|
1243
|
-
"metadata": deepcopy(original_thread["metadata"]),
|
|
1275
|
+
"metadata": copy.deepcopy(original_thread["metadata"]),
|
|
1244
1276
|
"status": "idle",
|
|
1245
1277
|
"config": {},
|
|
1246
1278
|
}
|
|
@@ -1327,7 +1359,14 @@ class Threads(Authenticated):
|
|
|
1327
1359
|
)
|
|
1328
1360
|
|
|
1329
1361
|
metadata = thread.get("metadata", {})
|
|
1330
|
-
thread_config = thread.get("config", {})
|
|
1362
|
+
thread_config = cast(dict[str, Any], thread.get("config", {}))
|
|
1363
|
+
thread_config = {
|
|
1364
|
+
**thread_config,
|
|
1365
|
+
"configurable": {
|
|
1366
|
+
**thread_config.get("configurable", {}),
|
|
1367
|
+
**config.get("configurable", {}),
|
|
1368
|
+
},
|
|
1369
|
+
}
|
|
1331
1370
|
|
|
1332
1371
|
# Fallback to graph_id from run if not in thread metadata
|
|
1333
1372
|
graph_id = metadata.get("graph_id")
|
|
@@ -1377,6 +1416,7 @@ class Threads(Authenticated):
|
|
|
1377
1416
|
"""Add state to a thread."""
|
|
1378
1417
|
from langgraph_api.graph import get_graph
|
|
1379
1418
|
from langgraph_api.schema import ThreadUpdateResponse
|
|
1419
|
+
from langgraph_api.state import state_snapshot_to_thread_state
|
|
1380
1420
|
from langgraph_api.store import get_store
|
|
1381
1421
|
from langgraph_api.utils import fetchone
|
|
1382
1422
|
|
|
@@ -1414,6 +1454,13 @@ class Threads(Authenticated):
|
|
|
1414
1454
|
status_code=409,
|
|
1415
1455
|
detail=f"Thread {thread_id} has in-flight runs: {pending_runs}",
|
|
1416
1456
|
)
|
|
1457
|
+
thread_config = {
|
|
1458
|
+
**thread_config,
|
|
1459
|
+
"configurable": {
|
|
1460
|
+
**thread_config.get("configurable", {}),
|
|
1461
|
+
**config.get("configurable", {}),
|
|
1462
|
+
},
|
|
1463
|
+
}
|
|
1417
1464
|
|
|
1418
1465
|
# Fallback to graph_id from run if not in thread metadata
|
|
1419
1466
|
graph_id = metadata.get("graph_id")
|
|
@@ -1454,6 +1501,19 @@ class Threads(Authenticated):
|
|
|
1454
1501
|
thread["values"] = state.values
|
|
1455
1502
|
break
|
|
1456
1503
|
|
|
1504
|
+
# Publish state update event
|
|
1505
|
+
from langgraph_api.serde import json_dumpb
|
|
1506
|
+
|
|
1507
|
+
event_data = {
|
|
1508
|
+
"state": state_snapshot_to_thread_state(state),
|
|
1509
|
+
"thread_id": str(thread_id),
|
|
1510
|
+
}
|
|
1511
|
+
await Threads.Stream.publish(
|
|
1512
|
+
thread_id,
|
|
1513
|
+
"state_update",
|
|
1514
|
+
json_dumpb(event_data),
|
|
1515
|
+
)
|
|
1516
|
+
|
|
1457
1517
|
return ThreadUpdateResponse(
|
|
1458
1518
|
checkpoint=next_config["configurable"],
|
|
1459
1519
|
# Including deprecated fields
|
|
@@ -1496,7 +1556,14 @@ class Threads(Authenticated):
|
|
|
1496
1556
|
thread_iter, not_found_detail=f"Thread {thread_id} not found."
|
|
1497
1557
|
)
|
|
1498
1558
|
|
|
1499
|
-
thread_config = thread["config"]
|
|
1559
|
+
thread_config = cast(dict[str, Any], thread["config"])
|
|
1560
|
+
thread_config = {
|
|
1561
|
+
**thread_config,
|
|
1562
|
+
"configurable": {
|
|
1563
|
+
**thread_config.get("configurable", {}),
|
|
1564
|
+
**config.get("configurable", {}),
|
|
1565
|
+
},
|
|
1566
|
+
}
|
|
1500
1567
|
metadata = thread["metadata"]
|
|
1501
1568
|
|
|
1502
1569
|
if not thread:
|
|
@@ -1543,6 +1610,19 @@ class Threads(Authenticated):
|
|
|
1543
1610
|
thread["values"] = state.values
|
|
1544
1611
|
break
|
|
1545
1612
|
|
|
1613
|
+
# Publish state update event
|
|
1614
|
+
from langgraph_api.serde import json_dumpb
|
|
1615
|
+
|
|
1616
|
+
event_data = {
|
|
1617
|
+
"state": state,
|
|
1618
|
+
"thread_id": str(thread_id),
|
|
1619
|
+
}
|
|
1620
|
+
await Threads.Stream.publish(
|
|
1621
|
+
thread_id,
|
|
1622
|
+
"state_update",
|
|
1623
|
+
json_dumpb(event_data),
|
|
1624
|
+
)
|
|
1625
|
+
|
|
1546
1626
|
return ThreadUpdateResponse(
|
|
1547
1627
|
checkpoint=next_config["configurable"],
|
|
1548
1628
|
)
|
|
@@ -1584,7 +1664,14 @@ class Threads(Authenticated):
|
|
|
1584
1664
|
if not _check_filter_match(thread_metadata, filters):
|
|
1585
1665
|
return []
|
|
1586
1666
|
|
|
1587
|
-
thread_config = thread["config"]
|
|
1667
|
+
thread_config = cast(dict[str, Any], thread["config"])
|
|
1668
|
+
thread_config = {
|
|
1669
|
+
**thread_config,
|
|
1670
|
+
"configurable": {
|
|
1671
|
+
**thread_config.get("configurable", {}),
|
|
1672
|
+
**config.get("configurable", {}),
|
|
1673
|
+
},
|
|
1674
|
+
}
|
|
1588
1675
|
# If graph_id exists, get state history
|
|
1589
1676
|
if graph_id := thread_metadata.get("graph_id"):
|
|
1590
1677
|
async with get_graph(
|
|
@@ -1613,7 +1700,9 @@ class Threads(Authenticated):
|
|
|
1613
1700
|
|
|
1614
1701
|
return []
|
|
1615
1702
|
|
|
1616
|
-
class Stream:
|
|
1703
|
+
class Stream(Authenticated):
|
|
1704
|
+
resource = "threads"
|
|
1705
|
+
|
|
1617
1706
|
@staticmethod
|
|
1618
1707
|
async def subscribe(
|
|
1619
1708
|
conn: InMemConnectionProto | AsyncConnectionProto,
|
|
@@ -1626,6 +1715,13 @@ class Threads(Authenticated):
|
|
|
1626
1715
|
|
|
1627
1716
|
# Create new queues only for runs not yet seen
|
|
1628
1717
|
thread_id = _ensure_uuid(thread_id)
|
|
1718
|
+
|
|
1719
|
+
# Add thread stream queue
|
|
1720
|
+
if thread_id not in seen_runs:
|
|
1721
|
+
queue = await stream_manager.add_thread_stream(thread_id)
|
|
1722
|
+
queues.append((thread_id, queue))
|
|
1723
|
+
seen_runs.add(thread_id)
|
|
1724
|
+
|
|
1629
1725
|
for run in conn.store["runs"]:
|
|
1630
1726
|
if run["thread_id"] == thread_id:
|
|
1631
1727
|
run_id = run["run_id"]
|
|
@@ -1641,9 +1737,32 @@ class Threads(Authenticated):
|
|
|
1641
1737
|
thread_id: UUID,
|
|
1642
1738
|
*,
|
|
1643
1739
|
last_event_id: str | None = None,
|
|
1740
|
+
stream_modes: list[ThreadStreamMode],
|
|
1741
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1644
1742
|
) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
|
|
1645
1743
|
"""Stream the thread output."""
|
|
1646
|
-
|
|
1744
|
+
await Threads.Stream.check_thread_stream_auth(thread_id, ctx)
|
|
1745
|
+
|
|
1746
|
+
from langgraph_api.utils.stream_codec import (
|
|
1747
|
+
decode_stream_message,
|
|
1748
|
+
)
|
|
1749
|
+
|
|
1750
|
+
def should_filter_event(event_name: str, message_bytes: bytes) -> bool:
|
|
1751
|
+
"""Check if an event should be filtered out based on stream_modes."""
|
|
1752
|
+
if "run_modes" in stream_modes and event_name != "state_update":
|
|
1753
|
+
return False
|
|
1754
|
+
if "state_update" in stream_modes and event_name == "state_update":
|
|
1755
|
+
return False
|
|
1756
|
+
if "lifecycle" in stream_modes and event_name == "metadata":
|
|
1757
|
+
try:
|
|
1758
|
+
message_data = orjson.loads(message_bytes)
|
|
1759
|
+
if message_data.get("status") == "run_done":
|
|
1760
|
+
return False
|
|
1761
|
+
if "attempt" in message_data and "run_id" in message_data:
|
|
1762
|
+
return False
|
|
1763
|
+
except (orjson.JSONDecodeError, TypeError):
|
|
1764
|
+
pass
|
|
1765
|
+
return True
|
|
1647
1766
|
|
|
1648
1767
|
stream_manager = get_stream_manager()
|
|
1649
1768
|
seen_runs: set[UUID] = set()
|
|
@@ -1673,23 +1792,24 @@ class Threads(Authenticated):
|
|
|
1673
1792
|
|
|
1674
1793
|
# Yield sorted events
|
|
1675
1794
|
for message, run_id in all_events:
|
|
1676
|
-
|
|
1677
|
-
|
|
1678
|
-
|
|
1679
|
-
|
|
1680
|
-
|
|
1681
|
-
|
|
1682
|
-
|
|
1683
|
-
|
|
1684
|
-
|
|
1685
|
-
|
|
1686
|
-
|
|
1687
|
-
message.id,
|
|
1795
|
+
decoded = decode_stream_message(
|
|
1796
|
+
message.data, channel=message.topic
|
|
1797
|
+
)
|
|
1798
|
+
event_bytes = decoded.event_bytes
|
|
1799
|
+
message_bytes = decoded.message_bytes
|
|
1800
|
+
|
|
1801
|
+
if event_bytes == b"control":
|
|
1802
|
+
if message_bytes == b"done":
|
|
1803
|
+
event_bytes = b"metadata"
|
|
1804
|
+
message_bytes = orjson.dumps(
|
|
1805
|
+
{"status": "run_done", "run_id": run_id}
|
|
1688
1806
|
)
|
|
1689
|
-
|
|
1807
|
+
if not should_filter_event(
|
|
1808
|
+
event_bytes.decode("utf-8"), message_bytes
|
|
1809
|
+
):
|
|
1690
1810
|
yield (
|
|
1691
|
-
|
|
1692
|
-
|
|
1811
|
+
event_bytes,
|
|
1812
|
+
message_bytes,
|
|
1693
1813
|
message.id,
|
|
1694
1814
|
)
|
|
1695
1815
|
|
|
@@ -1708,28 +1828,27 @@ class Threads(Authenticated):
|
|
|
1708
1828
|
message = await asyncio.wait_for(
|
|
1709
1829
|
queue.get(), timeout=0.2
|
|
1710
1830
|
)
|
|
1711
|
-
|
|
1712
|
-
|
|
1713
|
-
|
|
1714
|
-
|
|
1715
|
-
|
|
1716
|
-
|
|
1717
|
-
|
|
1718
|
-
|
|
1719
|
-
|
|
1720
|
-
|
|
1721
|
-
|
|
1722
|
-
|
|
1723
|
-
|
|
1724
|
-
),
|
|
1725
|
-
message.id,
|
|
1726
|
-
)
|
|
1727
|
-
else:
|
|
1728
|
-
yield (
|
|
1729
|
-
event_name.encode(),
|
|
1730
|
-
base64.b64decode(message_content),
|
|
1731
|
-
message.id,
|
|
1831
|
+
decoded = decode_stream_message(
|
|
1832
|
+
message.data, channel=message.topic
|
|
1833
|
+
)
|
|
1834
|
+
event = decoded.event_bytes
|
|
1835
|
+
event_name = event.decode("utf-8")
|
|
1836
|
+
payload = decoded.message_bytes
|
|
1837
|
+
|
|
1838
|
+
if event == b"control" and payload == b"done":
|
|
1839
|
+
topic = message.topic.decode()
|
|
1840
|
+
run_id = topic.split("run:")[1].split(":")[0]
|
|
1841
|
+
meta_event = b"metadata"
|
|
1842
|
+
meta_payload = orjson.dumps(
|
|
1843
|
+
{"status": "run_done", "run_id": run_id}
|
|
1732
1844
|
)
|
|
1845
|
+
if not should_filter_event(
|
|
1846
|
+
"metadata", meta_payload
|
|
1847
|
+
):
|
|
1848
|
+
yield (meta_event, meta_payload, message.id)
|
|
1849
|
+
else:
|
|
1850
|
+
if not should_filter_event(event_name, payload):
|
|
1851
|
+
yield (event, payload, message.id)
|
|
1733
1852
|
|
|
1734
1853
|
except TimeoutError:
|
|
1735
1854
|
continue
|
|
@@ -1758,6 +1877,41 @@ class Threads(Authenticated):
|
|
|
1758
1877
|
# Ignore cleanup errors
|
|
1759
1878
|
pass
|
|
1760
1879
|
|
|
1880
|
+
@staticmethod
|
|
1881
|
+
async def publish(
|
|
1882
|
+
thread_id: UUID | str,
|
|
1883
|
+
event: str,
|
|
1884
|
+
message: bytes,
|
|
1885
|
+
) -> None:
|
|
1886
|
+
"""Publish a thread-level event to the thread stream."""
|
|
1887
|
+
from langgraph_api.utils.stream_codec import STREAM_CODEC
|
|
1888
|
+
|
|
1889
|
+
topic = f"thread:{thread_id}:stream".encode()
|
|
1890
|
+
|
|
1891
|
+
stream_manager = get_stream_manager()
|
|
1892
|
+
payload = STREAM_CODEC.encode(event, message)
|
|
1893
|
+
await stream_manager.put_thread(
|
|
1894
|
+
str(thread_id), Message(topic=topic, data=payload)
|
|
1895
|
+
)
|
|
1896
|
+
|
|
1897
|
+
@staticmethod
|
|
1898
|
+
async def check_thread_stream_auth(
|
|
1899
|
+
thread_id: UUID,
|
|
1900
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1901
|
+
) -> None:
|
|
1902
|
+
async with connect() as conn:
|
|
1903
|
+
filters = await Threads.Stream.handle_event(
|
|
1904
|
+
ctx,
|
|
1905
|
+
"read",
|
|
1906
|
+
Auth.types.ThreadsRead(thread_id=thread_id),
|
|
1907
|
+
)
|
|
1908
|
+
if filters:
|
|
1909
|
+
thread = await Threads._get_with_filters(
|
|
1910
|
+
cast(InMemConnectionProto, conn), thread_id, filters
|
|
1911
|
+
)
|
|
1912
|
+
if not thread:
|
|
1913
|
+
raise HTTPException(status_code=404, detail="Thread not found")
|
|
1914
|
+
|
|
1761
1915
|
@staticmethod
|
|
1762
1916
|
async def count(
|
|
1763
1917
|
conn: InMemConnectionProto,
|
|
@@ -1821,38 +1975,37 @@ class Runs(Authenticated):
|
|
|
1821
1975
|
if not pending_runs and not running_runs:
|
|
1822
1976
|
return {
|
|
1823
1977
|
"n_pending": 0,
|
|
1824
|
-
"
|
|
1825
|
-
"
|
|
1978
|
+
"pending_runs_wait_time_max_secs": None,
|
|
1979
|
+
"pending_runs_wait_time_med_secs": None,
|
|
1826
1980
|
"n_running": 0,
|
|
1827
1981
|
}
|
|
1828
1982
|
|
|
1829
|
-
|
|
1830
|
-
|
|
1831
|
-
|
|
1832
|
-
|
|
1833
|
-
|
|
1834
|
-
|
|
1835
|
-
|
|
1836
|
-
|
|
1837
|
-
|
|
1838
|
-
|
|
1839
|
-
|
|
1840
|
-
|
|
1841
|
-
|
|
1842
|
-
|
|
1843
|
-
|
|
1844
|
-
|
|
1845
|
-
|
|
1846
|
-
|
|
1847
|
-
|
|
1848
|
-
|
|
1849
|
-
median_time = sorted_times[median_idx]
|
|
1983
|
+
now = datetime.now(UTC)
|
|
1984
|
+
pending_waits: list[float] = []
|
|
1985
|
+
for run in pending_runs:
|
|
1986
|
+
created_at = run.get("created_at")
|
|
1987
|
+
if not isinstance(created_at, datetime):
|
|
1988
|
+
continue
|
|
1989
|
+
if created_at.tzinfo is None:
|
|
1990
|
+
created_at = created_at.replace(tzinfo=UTC)
|
|
1991
|
+
pending_waits.append((now - created_at).total_seconds())
|
|
1992
|
+
|
|
1993
|
+
max_pending_wait = max(pending_waits) if pending_waits else None
|
|
1994
|
+
if pending_waits:
|
|
1995
|
+
sorted_waits = sorted(pending_waits)
|
|
1996
|
+
half = len(sorted_waits) // 2
|
|
1997
|
+
if len(sorted_waits) % 2 == 1:
|
|
1998
|
+
med_pending_wait = sorted_waits[half]
|
|
1999
|
+
else:
|
|
2000
|
+
med_pending_wait = (sorted_waits[half - 1] + sorted_waits[half]) / 2
|
|
2001
|
+
else:
|
|
2002
|
+
med_pending_wait = None
|
|
1850
2003
|
|
|
1851
2004
|
return {
|
|
1852
2005
|
"n_pending": len(pending_runs),
|
|
1853
2006
|
"n_running": len(running_runs),
|
|
1854
|
-
"
|
|
1855
|
-
"
|
|
2007
|
+
"pending_runs_wait_time_max_secs": max_pending_wait,
|
|
2008
|
+
"pending_runs_wait_time_med_secs": med_pending_wait,
|
|
1856
2009
|
}
|
|
1857
2010
|
|
|
1858
2011
|
@staticmethod
|
|
@@ -1916,12 +2069,16 @@ class Runs(Authenticated):
|
|
|
1916
2069
|
@asynccontextmanager
|
|
1917
2070
|
@staticmethod
|
|
1918
2071
|
async def enter(
|
|
1919
|
-
run_id: UUID,
|
|
2072
|
+
run_id: UUID,
|
|
2073
|
+
thread_id: UUID | None,
|
|
2074
|
+
loop: asyncio.AbstractEventLoop,
|
|
2075
|
+
resumable: bool,
|
|
1920
2076
|
) -> AsyncIterator[ValueEvent]:
|
|
1921
2077
|
"""Enter a run, listen for cancellation while running, signal when done."
|
|
1922
2078
|
This method should be called as a context manager by a worker executing a run.
|
|
1923
2079
|
"""
|
|
1924
2080
|
from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent
|
|
2081
|
+
from langgraph_api.utils.stream_codec import STREAM_CODEC
|
|
1925
2082
|
|
|
1926
2083
|
stream_manager = get_stream_manager()
|
|
1927
2084
|
# Get control queue for this run (normal queue is created during run creation)
|
|
@@ -1941,12 +2098,14 @@ class Runs(Authenticated):
|
|
|
1941
2098
|
)
|
|
1942
2099
|
await stream_manager.put(run_id, thread_id, control_message)
|
|
1943
2100
|
|
|
1944
|
-
# Signal done to all subscribers
|
|
2101
|
+
# Signal done to all subscribers using stream codec
|
|
1945
2102
|
stream_message = Message(
|
|
1946
2103
|
topic=f"run:{run_id}:stream".encode(),
|
|
1947
|
-
data=
|
|
2104
|
+
data=STREAM_CODEC.encode("control", b"done"),
|
|
2105
|
+
)
|
|
2106
|
+
await stream_manager.put(
|
|
2107
|
+
run_id, thread_id, stream_message, resumable=resumable
|
|
1948
2108
|
)
|
|
1949
|
-
await stream_manager.put(run_id, thread_id, stream_message)
|
|
1950
2109
|
|
|
1951
2110
|
# Remove the control_queue (normal queue is cleaned up during run deletion)
|
|
1952
2111
|
await stream_manager.remove_control_queue(run_id, thread_id, control_queue)
|
|
@@ -1988,7 +2147,6 @@ class Runs(Authenticated):
|
|
|
1988
2147
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
1989
2148
|
) -> AsyncIterator[Run]:
|
|
1990
2149
|
"""Create a run."""
|
|
1991
|
-
from langgraph_api.config import FF_RICH_THREADS
|
|
1992
2150
|
from langgraph_api.schema import Run, Thread
|
|
1993
2151
|
|
|
1994
2152
|
assistant_id = _ensure_uuid(assistant_id)
|
|
@@ -2004,6 +2162,7 @@ class Runs(Authenticated):
|
|
|
2004
2162
|
run_id = _ensure_uuid(run_id) if run_id else None
|
|
2005
2163
|
metadata = metadata if metadata is not None else {}
|
|
2006
2164
|
config = kwargs.get("config", {})
|
|
2165
|
+
temporary = kwargs.get("temporary", False)
|
|
2007
2166
|
|
|
2008
2167
|
# Handle thread creation/update
|
|
2009
2168
|
existing_thread = next(
|
|
@@ -2013,7 +2172,7 @@ class Runs(Authenticated):
|
|
|
2013
2172
|
ctx,
|
|
2014
2173
|
"create_run",
|
|
2015
2174
|
Auth.types.RunsCreate(
|
|
2016
|
-
thread_id=thread_id,
|
|
2175
|
+
thread_id=None if temporary else thread_id,
|
|
2017
2176
|
assistant_id=assistant_id,
|
|
2018
2177
|
run_id=run_id,
|
|
2019
2178
|
status=status,
|
|
@@ -2034,49 +2193,35 @@ class Runs(Authenticated):
|
|
|
2034
2193
|
# Create new thread
|
|
2035
2194
|
if thread_id is None:
|
|
2036
2195
|
thread_id = uuid4()
|
|
2037
|
-
|
|
2038
|
-
|
|
2039
|
-
|
|
2040
|
-
|
|
2041
|
-
|
|
2042
|
-
|
|
2043
|
-
|
|
2044
|
-
|
|
2045
|
-
|
|
2046
|
-
|
|
2047
|
-
|
|
2048
|
-
|
|
2049
|
-
|
|
2050
|
-
|
|
2051
|
-
|
|
2052
|
-
|
|
2053
|
-
|
|
2054
|
-
},
|
|
2055
|
-
),
|
|
2056
|
-
created_at=datetime.now(UTC),
|
|
2057
|
-
updated_at=datetime.now(UTC),
|
|
2058
|
-
values=b"",
|
|
2059
|
-
)
|
|
2060
|
-
else:
|
|
2061
|
-
thread = Thread(
|
|
2062
|
-
thread_id=thread_id,
|
|
2063
|
-
status="idle",
|
|
2064
|
-
metadata={
|
|
2065
|
-
"graph_id": assistant["graph_id"],
|
|
2066
|
-
"assistant_id": str(assistant_id),
|
|
2067
|
-
**(config.get("metadata") or {}),
|
|
2068
|
-
**metadata,
|
|
2196
|
+
|
|
2197
|
+
thread = Thread(
|
|
2198
|
+
thread_id=thread_id,
|
|
2199
|
+
status="busy",
|
|
2200
|
+
metadata={
|
|
2201
|
+
"graph_id": assistant["graph_id"],
|
|
2202
|
+
"assistant_id": str(assistant_id),
|
|
2203
|
+
**(config.get("metadata") or {}),
|
|
2204
|
+
**metadata,
|
|
2205
|
+
},
|
|
2206
|
+
config=Runs._merge_jsonb(
|
|
2207
|
+
assistant["config"],
|
|
2208
|
+
config,
|
|
2209
|
+
{
|
|
2210
|
+
"configurable": Runs._merge_jsonb(
|
|
2211
|
+
Runs._get_configurable(assistant["config"]),
|
|
2212
|
+
)
|
|
2069
2213
|
},
|
|
2070
|
-
|
|
2071
|
-
|
|
2072
|
-
|
|
2073
|
-
|
|
2074
|
-
|
|
2214
|
+
),
|
|
2215
|
+
created_at=datetime.now(UTC),
|
|
2216
|
+
updated_at=datetime.now(UTC),
|
|
2217
|
+
values=b"",
|
|
2218
|
+
)
|
|
2219
|
+
|
|
2075
2220
|
await logger.ainfo("Creating thread", thread_id=thread_id)
|
|
2076
2221
|
conn.store["threads"].append(thread)
|
|
2077
2222
|
elif existing_thread:
|
|
2078
2223
|
# Update existing thread
|
|
2079
|
-
if
|
|
2224
|
+
if existing_thread["status"] != "busy":
|
|
2080
2225
|
existing_thread["status"] = "busy"
|
|
2081
2226
|
existing_thread["metadata"] = Runs._merge_jsonb(
|
|
2082
2227
|
existing_thread["metadata"],
|
|
@@ -2253,66 +2398,6 @@ class Runs(Authenticated):
|
|
|
2253
2398
|
|
|
2254
2399
|
return _yield_deleted()
|
|
2255
2400
|
|
|
2256
|
-
@staticmethod
|
|
2257
|
-
async def join(
|
|
2258
|
-
run_id: UUID,
|
|
2259
|
-
*,
|
|
2260
|
-
thread_id: UUID,
|
|
2261
|
-
ctx: Auth.types.BaseAuthContext | None = None,
|
|
2262
|
-
) -> Fragment:
|
|
2263
|
-
"""Wait for a run to complete. If already done, return immediately.
|
|
2264
|
-
|
|
2265
|
-
Returns:
|
|
2266
|
-
the final state of the run.
|
|
2267
|
-
"""
|
|
2268
|
-
from langgraph_api.serde import Fragment
|
|
2269
|
-
from langgraph_api.utils import fetchone
|
|
2270
|
-
|
|
2271
|
-
async with connect() as conn:
|
|
2272
|
-
# Validate ownership
|
|
2273
|
-
thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
|
|
2274
|
-
await fetchone(thread_iter)
|
|
2275
|
-
last_chunk: bytes | None = None
|
|
2276
|
-
# wait for the run to complete
|
|
2277
|
-
# Rely on this join's auth
|
|
2278
|
-
async for mode, chunk, _ in Runs.Stream.join(
|
|
2279
|
-
run_id,
|
|
2280
|
-
thread_id=thread_id,
|
|
2281
|
-
ctx=ctx,
|
|
2282
|
-
ignore_404=True,
|
|
2283
|
-
stream_mode=["values", "updates", "error"],
|
|
2284
|
-
):
|
|
2285
|
-
if mode == b"values":
|
|
2286
|
-
last_chunk = chunk
|
|
2287
|
-
elif mode == b"updates" and b"__interrupt__" in chunk:
|
|
2288
|
-
last_chunk = chunk
|
|
2289
|
-
elif mode == b"error":
|
|
2290
|
-
last_chunk = orjson.dumps({"__error__": orjson.Fragment(chunk)})
|
|
2291
|
-
# if we received a final chunk, return it
|
|
2292
|
-
if last_chunk is not None:
|
|
2293
|
-
# ie. if the run completed while we were waiting for it
|
|
2294
|
-
return Fragment(last_chunk)
|
|
2295
|
-
else:
|
|
2296
|
-
# otherwise, the run had already finished, so fetch the state from thread
|
|
2297
|
-
async with connect() as conn:
|
|
2298
|
-
thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
|
|
2299
|
-
thread = await fetchone(thread_iter)
|
|
2300
|
-
if thread["status"] == "error":
|
|
2301
|
-
return Fragment(
|
|
2302
|
-
orjson.dumps({"__error__": orjson.Fragment(thread["error"])})
|
|
2303
|
-
)
|
|
2304
|
-
if thread["status"] == "interrupted":
|
|
2305
|
-
# Get an interrupt for the thread. There is the case where there are multiple interrupts for the same run and we may not show the same
|
|
2306
|
-
# interrupt, but we'll always show one. Long term we should show all of them.
|
|
2307
|
-
try:
|
|
2308
|
-
interrupt_map = thread["interrupts"]
|
|
2309
|
-
interrupt = [next(iter(interrupt_map.values()))[0]]
|
|
2310
|
-
return Fragment(orjson.dumps({"__interrupt__": interrupt}))
|
|
2311
|
-
except Exception:
|
|
2312
|
-
# No interrupt, but status is interrupted from a before/after block. Default back to values.
|
|
2313
|
-
pass
|
|
2314
|
-
return thread["values"]
|
|
2315
|
-
|
|
2316
2401
|
@staticmethod
|
|
2317
2402
|
async def cancel(
|
|
2318
2403
|
conn: InMemConnectionProto | AsyncConnectionProto,
|
|
@@ -2538,7 +2623,7 @@ class Runs(Authenticated):
|
|
|
2538
2623
|
async def subscribe(
|
|
2539
2624
|
run_id: UUID,
|
|
2540
2625
|
thread_id: UUID | None = None,
|
|
2541
|
-
) ->
|
|
2626
|
+
) -> ContextQueue:
|
|
2542
2627
|
"""Subscribe to the run stream, returning a queue."""
|
|
2543
2628
|
stream_manager = get_stream_manager()
|
|
2544
2629
|
queue = await stream_manager.add_queue(_ensure_uuid(run_id), thread_id)
|
|
@@ -2562,54 +2647,38 @@ class Runs(Authenticated):
|
|
|
2562
2647
|
async def join(
|
|
2563
2648
|
run_id: UUID,
|
|
2564
2649
|
*,
|
|
2650
|
+
stream_channel: asyncio.Queue,
|
|
2565
2651
|
thread_id: UUID,
|
|
2566
2652
|
ignore_404: bool = False,
|
|
2567
2653
|
cancel_on_disconnect: bool = False,
|
|
2568
|
-
stream_channel: asyncio.Queue | None = None,
|
|
2569
2654
|
stream_mode: list[StreamMode] | StreamMode | None = None,
|
|
2570
2655
|
last_event_id: str | None = None,
|
|
2571
2656
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
2572
2657
|
) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
|
|
2573
2658
|
"""Stream the run output."""
|
|
2574
2659
|
from langgraph_api.asyncio import create_task
|
|
2575
|
-
from langgraph_api.serde import
|
|
2576
|
-
|
|
2577
|
-
queue = (
|
|
2578
|
-
stream_channel
|
|
2579
|
-
if stream_channel
|
|
2580
|
-
else await Runs.Stream.subscribe(run_id, thread_id)
|
|
2581
|
-
)
|
|
2660
|
+
from langgraph_api.serde import json_dumpb
|
|
2661
|
+
from langgraph_api.utils.stream_codec import decode_stream_message
|
|
2582
2662
|
|
|
2663
|
+
queue = stream_channel
|
|
2583
2664
|
try:
|
|
2584
2665
|
async with connect() as conn:
|
|
2585
|
-
|
|
2586
|
-
ctx
|
|
2587
|
-
|
|
2588
|
-
|
|
2589
|
-
)
|
|
2590
|
-
if filters:
|
|
2591
|
-
thread = await Threads._get_with_filters(
|
|
2592
|
-
cast(InMemConnectionProto, conn), thread_id, filters
|
|
2593
|
-
)
|
|
2594
|
-
if not thread:
|
|
2595
|
-
raise WrappedHTTPException(
|
|
2596
|
-
HTTPException(
|
|
2597
|
-
status_code=404, detail="Thread not found"
|
|
2598
|
-
)
|
|
2599
|
-
)
|
|
2666
|
+
try:
|
|
2667
|
+
await Runs.Stream.check_run_stream_auth(run_id, thread_id, ctx)
|
|
2668
|
+
except HTTPException as e:
|
|
2669
|
+
raise WrappedHTTPException(e) from None
|
|
2600
2670
|
run = await Runs.get(conn, run_id, thread_id=thread_id, ctx=ctx)
|
|
2601
2671
|
|
|
2602
2672
|
for message in get_stream_manager().restore_messages(
|
|
2603
2673
|
run_id, thread_id, last_event_id
|
|
2604
2674
|
):
|
|
2605
2675
|
data, id = message.data, message.id
|
|
2606
|
-
|
|
2607
|
-
|
|
2608
|
-
|
|
2609
|
-
message = data["message"]
|
|
2676
|
+
decoded = decode_stream_message(data, channel=message.topic)
|
|
2677
|
+
mode = decoded.event_bytes.decode("utf-8")
|
|
2678
|
+
payload = decoded.message_bytes
|
|
2610
2679
|
|
|
2611
2680
|
if mode == "control":
|
|
2612
|
-
if
|
|
2681
|
+
if payload == b"done":
|
|
2613
2682
|
return
|
|
2614
2683
|
elif (
|
|
2615
2684
|
not stream_mode
|
|
@@ -2622,7 +2691,7 @@ class Runs(Authenticated):
|
|
|
2622
2691
|
and mode.startswith("messages")
|
|
2623
2692
|
)
|
|
2624
2693
|
):
|
|
2625
|
-
yield mode.encode(),
|
|
2694
|
+
yield mode.encode(), payload, id
|
|
2626
2695
|
logger.debug(
|
|
2627
2696
|
"Replayed run event",
|
|
2628
2697
|
run_id=str(run_id),
|
|
@@ -2636,13 +2705,12 @@ class Runs(Authenticated):
|
|
|
2636
2705
|
# Wait for messages with a timeout
|
|
2637
2706
|
message = await asyncio.wait_for(queue.get(), timeout=0.5)
|
|
2638
2707
|
data, id = message.data, message.id
|
|
2639
|
-
|
|
2640
|
-
|
|
2641
|
-
|
|
2642
|
-
message = data["message"]
|
|
2708
|
+
decoded = decode_stream_message(data, channel=message.topic)
|
|
2709
|
+
mode = decoded.event_bytes.decode("utf-8")
|
|
2710
|
+
payload = decoded.message_bytes
|
|
2643
2711
|
|
|
2644
2712
|
if mode == "control":
|
|
2645
|
-
if
|
|
2713
|
+
if payload == b"done":
|
|
2646
2714
|
break
|
|
2647
2715
|
elif (
|
|
2648
2716
|
not stream_mode
|
|
@@ -2655,13 +2723,13 @@ class Runs(Authenticated):
|
|
|
2655
2723
|
and mode.startswith("messages")
|
|
2656
2724
|
)
|
|
2657
2725
|
):
|
|
2658
|
-
yield mode.encode(),
|
|
2726
|
+
yield mode.encode(), payload, id
|
|
2659
2727
|
logger.debug(
|
|
2660
2728
|
"Streamed run event",
|
|
2661
2729
|
run_id=str(run_id),
|
|
2662
2730
|
stream_mode=mode,
|
|
2663
2731
|
message_id=id,
|
|
2664
|
-
data=
|
|
2732
|
+
data=payload,
|
|
2665
2733
|
)
|
|
2666
2734
|
except TimeoutError:
|
|
2667
2735
|
# Check if the run is still pending
|
|
@@ -2675,8 +2743,10 @@ class Runs(Authenticated):
|
|
|
2675
2743
|
elif run is None:
|
|
2676
2744
|
yield (
|
|
2677
2745
|
b"error",
|
|
2678
|
-
|
|
2679
|
-
|
|
2746
|
+
json_dumpb(
|
|
2747
|
+
HTTPException(
|
|
2748
|
+
status_code=404, detail="Run not found"
|
|
2749
|
+
)
|
|
2680
2750
|
),
|
|
2681
2751
|
None,
|
|
2682
2752
|
)
|
|
@@ -2693,6 +2763,25 @@ class Runs(Authenticated):
|
|
|
2693
2763
|
stream_manager = get_stream_manager()
|
|
2694
2764
|
await stream_manager.remove_queue(run_id, thread_id, queue)
|
|
2695
2765
|
|
|
2766
|
+
@staticmethod
|
|
2767
|
+
async def check_run_stream_auth(
|
|
2768
|
+
run_id: UUID,
|
|
2769
|
+
thread_id: UUID,
|
|
2770
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
2771
|
+
) -> None:
|
|
2772
|
+
async with connect() as conn:
|
|
2773
|
+
filters = await Runs.handle_event(
|
|
2774
|
+
ctx,
|
|
2775
|
+
"read",
|
|
2776
|
+
Auth.types.ThreadsRead(thread_id=thread_id),
|
|
2777
|
+
)
|
|
2778
|
+
if filters:
|
|
2779
|
+
thread = await Threads._get_with_filters(
|
|
2780
|
+
cast(InMemConnectionProto, conn), thread_id, filters
|
|
2781
|
+
)
|
|
2782
|
+
if not thread:
|
|
2783
|
+
raise HTTPException(status_code=404, detail="Thread not found")
|
|
2784
|
+
|
|
2696
2785
|
@staticmethod
|
|
2697
2786
|
async def publish(
|
|
2698
2787
|
run_id: UUID | str,
|
|
@@ -2703,18 +2792,13 @@ class Runs(Authenticated):
|
|
|
2703
2792
|
resumable: bool = False,
|
|
2704
2793
|
) -> None:
|
|
2705
2794
|
"""Publish a message to all subscribers of the run stream."""
|
|
2706
|
-
from langgraph_api.
|
|
2795
|
+
from langgraph_api.utils.stream_codec import STREAM_CODEC
|
|
2707
2796
|
|
|
2708
2797
|
topic = f"run:{run_id}:stream".encode()
|
|
2709
2798
|
|
|
2710
2799
|
stream_manager = get_stream_manager()
|
|
2711
|
-
# Send to all queues subscribed to this run_id
|
|
2712
|
-
payload =
|
|
2713
|
-
{
|
|
2714
|
-
"event": event,
|
|
2715
|
-
"message": message,
|
|
2716
|
-
}
|
|
2717
|
-
)
|
|
2800
|
+
# Send to all queues subscribed to this run_id using protocol frame
|
|
2801
|
+
payload = STREAM_CODEC.encode(event, message)
|
|
2718
2802
|
await stream_manager.put(
|
|
2719
2803
|
run_id, thread_id, Message(topic=topic, data=payload), resumable
|
|
2720
2804
|
)
|
|
@@ -2761,6 +2845,7 @@ class Crons:
|
|
|
2761
2845
|
schedule: str,
|
|
2762
2846
|
cron_id: UUID | None = None,
|
|
2763
2847
|
thread_id: UUID | None = None,
|
|
2848
|
+
on_run_completed: Literal["delete", "keep"] | None = None,
|
|
2764
2849
|
end_time: datetime | None = None,
|
|
2765
2850
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
2766
2851
|
) -> AsyncIterator[Cron]:
|
|
@@ -2874,11 +2959,18 @@ def _check_filter_match(metadata: dict, filters: Auth.types.FilterType | None) -
|
|
|
2874
2959
|
if key not in metadata or metadata[key] != filter_value:
|
|
2875
2960
|
return False
|
|
2876
2961
|
elif op == "$contains":
|
|
2877
|
-
if (
|
|
2878
|
-
|
|
2879
|
-
|
|
2880
|
-
|
|
2881
|
-
|
|
2962
|
+
if key not in metadata or not isinstance(metadata[key], list):
|
|
2963
|
+
return False
|
|
2964
|
+
|
|
2965
|
+
if isinstance(filter_value, list):
|
|
2966
|
+
# Mimick Postgres containment operator behavior.
|
|
2967
|
+
# It would be more efficient to use set operations here,
|
|
2968
|
+
# but we can't assume that elements are hashable.
|
|
2969
|
+
# The Postgres algorithm is also O(n^2).
|
|
2970
|
+
for filter_element in filter_value:
|
|
2971
|
+
if filter_element not in metadata[key]:
|
|
2972
|
+
return False
|
|
2973
|
+
elif filter_value not in metadata[key]:
|
|
2882
2974
|
return False
|
|
2883
2975
|
else:
|
|
2884
2976
|
# Direct equality
|
|
@@ -2894,6 +2986,7 @@ async def _empty_generator():
|
|
|
2894
2986
|
|
|
2895
2987
|
|
|
2896
2988
|
__all__ = [
|
|
2989
|
+
"StreamHandler",
|
|
2897
2990
|
"Assistants",
|
|
2898
2991
|
"Crons",
|
|
2899
2992
|
"Runs",
|
langgraph_runtime_inmem/queue.py
CHANGED
|
@@ -154,17 +154,6 @@ def _enable_blockbuster():
|
|
|
154
154
|
|
|
155
155
|
ls_env.get_runtime_environment() # this gets cached
|
|
156
156
|
bb = BlockBuster(excluded_modules=[])
|
|
157
|
-
for module, func in (
|
|
158
|
-
# Note, we've cached this call in langsmith==0.3.21 so it shouldn't raise anyway
|
|
159
|
-
# but we don't want to raise teh minbound just for that.
|
|
160
|
-
("langsmith/client.py", "_default_retry_config"),
|
|
161
|
-
# Only triggers in python 3.11 for getting subgraphs
|
|
162
|
-
# Will be unnecessary once we cache the assistant schemas
|
|
163
|
-
("langgraph/pregel/utils.py", "get_function_nonlocals"),
|
|
164
|
-
("importlib/metadata/__init__.py", "metadata"),
|
|
165
|
-
("importlib/metadata/__init__.py", "read_text"),
|
|
166
|
-
):
|
|
167
|
-
bb.functions["io.TextIOWrapper.read"].can_block_in(module, func)
|
|
168
157
|
|
|
169
158
|
bb.functions["os.path.abspath"].can_block_in("inspect.py", "getmodule")
|
|
170
159
|
|
|
@@ -172,12 +161,8 @@ def _enable_blockbuster():
|
|
|
172
161
|
("memory/__init__.py", "sync"),
|
|
173
162
|
("memory/__init__.py", "load"),
|
|
174
163
|
("memory/__init__.py", "dump"),
|
|
164
|
+
("pydantic/main.py", "__init__"),
|
|
175
165
|
):
|
|
176
|
-
bb.functions["io.TextIOWrapper.read"].can_block_in(module, func)
|
|
177
|
-
bb.functions["io.TextIOWrapper.write"].can_block_in(module, func)
|
|
178
|
-
bb.functions["io.BufferedWriter.write"].can_block_in(module, func)
|
|
179
|
-
bb.functions["io.BufferedReader.read"].can_block_in(module, func)
|
|
180
|
-
|
|
181
166
|
bb.functions["os.remove"].can_block_in(module, func)
|
|
182
167
|
bb.functions["os.rename"].can_block_in(module, func)
|
|
183
168
|
|
|
@@ -199,6 +184,12 @@ def _enable_blockbuster():
|
|
|
199
184
|
# as well as importlib.metadata.
|
|
200
185
|
"os.listdir",
|
|
201
186
|
"os.remove",
|
|
187
|
+
# We used to block the IO things but people use them so often that
|
|
188
|
+
# we've decided to just let people make bad decisions for themselves.
|
|
189
|
+
"io.BufferedReader.read",
|
|
190
|
+
"io.BufferedWriter.write",
|
|
191
|
+
"io.TextIOWrapper.read",
|
|
192
|
+
"io.TextIOWrapper.write",
|
|
202
193
|
# If people are using threadpoolexecutor, etc. they'd be using this.
|
|
203
194
|
"threading.Lock.acquire",
|
|
204
195
|
]
|
{langgraph_runtime_inmem-0.9.0.dist-info → langgraph_runtime_inmem-0.18.1.dist-info}/METADATA
RENAMED
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: langgraph-runtime-inmem
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.18.1
|
|
4
4
|
Summary: Inmem implementation for the LangGraph API server.
|
|
5
5
|
Author-email: Will Fu-Hinthorn <will@langchain.dev>
|
|
6
6
|
License: Elastic-2.0
|
|
7
7
|
Requires-Python: >=3.11.0
|
|
8
8
|
Requires-Dist: blockbuster<2.0.0,>=1.5.24
|
|
9
|
-
Requires-Dist: langgraph-checkpoint
|
|
10
|
-
Requires-Dist: langgraph
|
|
9
|
+
Requires-Dist: langgraph-checkpoint<4,>=3
|
|
10
|
+
Requires-Dist: langgraph<2,>=0.4.10
|
|
11
11
|
Requires-Dist: sse-starlette>=2
|
|
12
12
|
Requires-Dist: starlette>=0.37
|
|
13
13
|
Requires-Dist: structlog>23
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
langgraph_runtime_inmem/__init__.py,sha256=8LwgexYJfUTj5uFimXddpKAdiLBMFWKrf71glUqQkTc,311
|
|
2
|
+
langgraph_runtime_inmem/checkpoint.py,sha256=nc1G8DqVdIu-ibjKTqXfbPfMbAsKjPObKqegrSzo6Po,4432
|
|
3
|
+
langgraph_runtime_inmem/database.py,sha256=g2XYa5KN-T8MbDeFH9sfUApDG62Wp4BACumVnDtxYhI,6403
|
|
4
|
+
langgraph_runtime_inmem/inmem_stream.py,sha256=PFLWbsxU8RqbT5mYJgNk6v5q6TWJRIY1hkZWhJF8nkI,9094
|
|
5
|
+
langgraph_runtime_inmem/lifespan.py,sha256=fCoYcN_h0cxmj6-muC-f0csPdSpyepZuGRD1yBrq4XM,4755
|
|
6
|
+
langgraph_runtime_inmem/metrics.py,sha256=_YiSkLnhQvHpMktk38SZo0abyL-5GihfVAtBo0-lFIc,403
|
|
7
|
+
langgraph_runtime_inmem/ops.py,sha256=s_3MN5f4uecR7FaSo4WTjeeUqD0fNgB0QhokiV6y8Hg,109178
|
|
8
|
+
langgraph_runtime_inmem/queue.py,sha256=17HBZrYaxJg_k4NoabToYD_J6cqVzyHpWIz3VzGg_14,9363
|
|
9
|
+
langgraph_runtime_inmem/retry.py,sha256=XmldOP4e_H5s264CagJRVnQMDFcEJR_dldVR1Hm5XvM,763
|
|
10
|
+
langgraph_runtime_inmem/store.py,sha256=rTfL1JJvd-j4xjTrL8qDcynaWF6gUJ9-GDVwH0NBD_I,3506
|
|
11
|
+
langgraph_runtime_inmem-0.18.1.dist-info/METADATA,sha256=JJWTv1Yhr5Fx83aOApdJOXkKMSJ3fomwb00xqfK_cnA,570
|
|
12
|
+
langgraph_runtime_inmem-0.18.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
13
|
+
langgraph_runtime_inmem-0.18.1.dist-info/RECORD,,
|
|
@@ -1,13 +0,0 @@
|
|
|
1
|
-
langgraph_runtime_inmem/__init__.py,sha256=f-VPPHH1-hKFwEreffg7dNATe9IdcYwQedcSx2MiZog,310
|
|
2
|
-
langgraph_runtime_inmem/checkpoint.py,sha256=nc1G8DqVdIu-ibjKTqXfbPfMbAsKjPObKqegrSzo6Po,4432
|
|
3
|
-
langgraph_runtime_inmem/database.py,sha256=QgaA_WQo1IY6QioYd8r-e6-0B0rnC5anS0muIEJWby0,6364
|
|
4
|
-
langgraph_runtime_inmem/inmem_stream.py,sha256=pUEiHW-1uXQrVTcwEYPwO8YXaYm5qZbpRWawt67y6Lw,8187
|
|
5
|
-
langgraph_runtime_inmem/lifespan.py,sha256=t0w2MX2dGxe8yNtSX97Z-d2pFpllSLS4s1rh2GJDw5M,3557
|
|
6
|
-
langgraph_runtime_inmem/metrics.py,sha256=HhO0RC2bMDTDyGBNvnd2ooLebLA8P1u5oq978Kp_nAA,392
|
|
7
|
-
langgraph_runtime_inmem/ops.py,sha256=0Jx65S3PCvvHlIpA0XYpl-UnDEo_AiGWXRE2QiFSocY,105165
|
|
8
|
-
langgraph_runtime_inmem/queue.py,sha256=33qfFKPhQicZ1qiibllYb-bTFzUNSN2c4bffPACP5es,9952
|
|
9
|
-
langgraph_runtime_inmem/retry.py,sha256=XmldOP4e_H5s264CagJRVnQMDFcEJR_dldVR1Hm5XvM,763
|
|
10
|
-
langgraph_runtime_inmem/store.py,sha256=rTfL1JJvd-j4xjTrL8qDcynaWF6gUJ9-GDVwH0NBD_I,3506
|
|
11
|
-
langgraph_runtime_inmem-0.9.0.dist-info/METADATA,sha256=ptwW1Ei-Xln53P81eJK1aPcFozU8D192OCZBuC_y5EQ,565
|
|
12
|
-
langgraph_runtime_inmem-0.9.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
13
|
-
langgraph_runtime_inmem-0.9.0.dist-info/RECORD,,
|
|
File without changes
|