langgraph-runtime-inmem 0.6.8__py3-none-any.whl → 0.6.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_runtime_inmem/__init__.py +1 -1
- langgraph_runtime_inmem/inmem_stream.py +25 -11
- langgraph_runtime_inmem/ops.py +95 -64
- {langgraph_runtime_inmem-0.6.8.dist-info → langgraph_runtime_inmem-0.6.10.dist-info}/METADATA +1 -1
- {langgraph_runtime_inmem-0.6.8.dist-info → langgraph_runtime_inmem-0.6.10.dist-info}/RECORD +6 -6
- {langgraph_runtime_inmem-0.6.8.dist-info → langgraph_runtime_inmem-0.6.10.dist-info}/WHEEL +0 -0
|
@@ -42,6 +42,7 @@ class ContextQueue(asyncio.Queue):
|
|
|
42
42
|
class StreamManager:
|
|
43
43
|
def __init__(self):
|
|
44
44
|
self.queues = defaultdict(list) # Dict[UUID, List[asyncio.Queue]]
|
|
45
|
+
self.control_keys = defaultdict()
|
|
45
46
|
self.control_queues = defaultdict(list)
|
|
46
47
|
|
|
47
48
|
self.message_stores = defaultdict(list) # Dict[UUID, List[Message]]
|
|
@@ -51,6 +52,14 @@ class StreamManager:
|
|
|
51
52
|
run_id = _ensure_uuid(run_id)
|
|
52
53
|
return self.queues[run_id]
|
|
53
54
|
|
|
55
|
+
def get_control_queues(self, run_id: UUID | str) -> list[asyncio.Queue]:
|
|
56
|
+
run_id = _ensure_uuid(run_id)
|
|
57
|
+
return self.control_queues[run_id]
|
|
58
|
+
|
|
59
|
+
def get_control_key(self, run_id: UUID | str) -> Message | None:
|
|
60
|
+
run_id = _ensure_uuid(run_id)
|
|
61
|
+
return self.control_keys.get(run_id)
|
|
62
|
+
|
|
54
63
|
async def put(
|
|
55
64
|
self, run_id: UUID | str, message: Message, resumable: bool = False
|
|
56
65
|
) -> None:
|
|
@@ -61,8 +70,10 @@ class StreamManager:
|
|
|
61
70
|
self.message_stores[run_id].append(message)
|
|
62
71
|
topic = message.topic.decode()
|
|
63
72
|
if "control" in topic:
|
|
64
|
-
self.
|
|
65
|
-
|
|
73
|
+
self.control_keys[run_id] = message
|
|
74
|
+
queues = self.control_queues[run_id]
|
|
75
|
+
else:
|
|
76
|
+
queues = self.queues[run_id]
|
|
66
77
|
coros = [queue.put(message) for queue in queues]
|
|
67
78
|
results = await asyncio.gather(*coros, return_exceptions=True)
|
|
68
79
|
for result in results:
|
|
@@ -73,14 +84,12 @@ class StreamManager:
|
|
|
73
84
|
run_id = _ensure_uuid(run_id)
|
|
74
85
|
queue = ContextQueue()
|
|
75
86
|
self.queues[run_id].append(queue)
|
|
76
|
-
|
|
77
|
-
try:
|
|
78
|
-
await queue.put(control_msg)
|
|
79
|
-
except Exception:
|
|
80
|
-
logger.exception(
|
|
81
|
-
f"Failed to put control message in queue: {control_msg}"
|
|
82
|
-
)
|
|
87
|
+
return queue
|
|
83
88
|
|
|
89
|
+
async def add_control_queue(self, run_id: UUID | str) -> asyncio.Queue:
|
|
90
|
+
run_id = _ensure_uuid(run_id)
|
|
91
|
+
queue = ContextQueue()
|
|
92
|
+
self.control_queues[run_id].append(queue)
|
|
84
93
|
return queue
|
|
85
94
|
|
|
86
95
|
async def remove_queue(self, run_id: UUID | str, queue: asyncio.Queue):
|
|
@@ -89,8 +98,13 @@ class StreamManager:
|
|
|
89
98
|
self.queues[run_id].remove(queue)
|
|
90
99
|
if not self.queues[run_id]:
|
|
91
100
|
del self.queues[run_id]
|
|
92
|
-
|
|
93
|
-
|
|
101
|
+
|
|
102
|
+
async def remove_control_queue(self, run_id: UUID | str, queue: asyncio.Queue):
|
|
103
|
+
run_id = _ensure_uuid(run_id)
|
|
104
|
+
if run_id in self.control_queues:
|
|
105
|
+
self.control_queues[run_id].remove(queue)
|
|
106
|
+
if not self.control_queues[run_id]:
|
|
107
|
+
del self.control_queues[run_id]
|
|
94
108
|
|
|
95
109
|
def restore_messages(
|
|
96
110
|
self, run_id: UUID | str, message_id: str | None
|
langgraph_runtime_inmem/ops.py
CHANGED
|
@@ -722,7 +722,7 @@ class Threads(Authenticated):
|
|
|
722
722
|
else:
|
|
723
723
|
# Default sorting by created_at in descending order
|
|
724
724
|
sorted_threads = sorted(
|
|
725
|
-
filtered_threads, key=lambda x: x["
|
|
725
|
+
filtered_threads, key=lambda x: x["updated_at"], reverse=True
|
|
726
726
|
)
|
|
727
727
|
|
|
728
728
|
# Apply limit and offset
|
|
@@ -1451,7 +1451,7 @@ class Threads(Authenticated):
|
|
|
1451
1451
|
conn: InMemConnectionProto,
|
|
1452
1452
|
*,
|
|
1453
1453
|
config: Config,
|
|
1454
|
-
limit: int =
|
|
1454
|
+
limit: int = 1,
|
|
1455
1455
|
before: str | Checkpoint | None = None,
|
|
1456
1456
|
metadata: MetadataInput = None,
|
|
1457
1457
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
@@ -1626,7 +1626,7 @@ class Runs(Authenticated):
|
|
|
1626
1626
|
|
|
1627
1627
|
stream_manager = get_stream_manager()
|
|
1628
1628
|
# Get queue for this run
|
|
1629
|
-
queue = await
|
|
1629
|
+
queue = await stream_manager.add_control_queue(run_id)
|
|
1630
1630
|
|
|
1631
1631
|
async with SimpleTaskGroup(cancel=True, taskgroup_name="Runs.enter") as tg:
|
|
1632
1632
|
done = ValueEvent()
|
|
@@ -1634,16 +1634,21 @@ class Runs(Authenticated):
|
|
|
1634
1634
|
|
|
1635
1635
|
# Give done event to caller
|
|
1636
1636
|
yield done
|
|
1637
|
-
#
|
|
1637
|
+
# Store the control message for late subscribers
|
|
1638
1638
|
control_message = Message(
|
|
1639
1639
|
topic=f"run:{run_id}:control".encode(), data=b"done"
|
|
1640
1640
|
)
|
|
1641
|
-
|
|
1642
|
-
# Store the control message for late subscribers
|
|
1643
1641
|
await stream_manager.put(run_id, control_message)
|
|
1644
|
-
|
|
1645
|
-
#
|
|
1646
|
-
|
|
1642
|
+
|
|
1643
|
+
# Signal done to all subscribers
|
|
1644
|
+
stream_message = Message(
|
|
1645
|
+
topic=f"run:{run_id}:stream".encode(),
|
|
1646
|
+
data={"event": "control", "message": b"done"},
|
|
1647
|
+
)
|
|
1648
|
+
await stream_manager.put(run_id, stream_message)
|
|
1649
|
+
|
|
1650
|
+
# Remove the queue
|
|
1651
|
+
await stream_manager.remove_control_queue(run_id, queue)
|
|
1647
1652
|
|
|
1648
1653
|
@staticmethod
|
|
1649
1654
|
async def sweep(conn: InMemConnectionProto) -> list[UUID]:
|
|
@@ -1979,6 +1984,16 @@ class Runs(Authenticated):
|
|
|
1979
1984
|
return Fragment(
|
|
1980
1985
|
orjson.dumps({"__error__": orjson.Fragment(thread["error"])})
|
|
1981
1986
|
)
|
|
1987
|
+
if thread["status"] == "interrupted":
|
|
1988
|
+
# Get an interrupt for the thread. There is the case where there are multiple interrupts for the same run and we may not show the same
|
|
1989
|
+
# interrupt, but we'll always show one. Long term we should show all of them.
|
|
1990
|
+
try:
|
|
1991
|
+
interrupt_map = thread["interrupts"]
|
|
1992
|
+
interrupt = [next(iter(interrupt_map.values()))[0]]
|
|
1993
|
+
return Fragment(orjson.dumps({"__interrupt__": interrupt}))
|
|
1994
|
+
except Exception:
|
|
1995
|
+
# No interrupt, but status is interrupted from a before/after block. Default back to values.
|
|
1996
|
+
pass
|
|
1982
1997
|
return thread["values"]
|
|
1983
1998
|
|
|
1984
1999
|
@staticmethod
|
|
@@ -2199,8 +2214,6 @@ class Runs(Authenticated):
|
|
|
2199
2214
|
@staticmethod
|
|
2200
2215
|
async def subscribe(
|
|
2201
2216
|
run_id: UUID,
|
|
2202
|
-
*,
|
|
2203
|
-
stream_mode: StreamMode | None = None,
|
|
2204
2217
|
) -> asyncio.Queue:
|
|
2205
2218
|
"""Subscribe to the run stream, returning a queue."""
|
|
2206
2219
|
stream_manager = get_stream_manager()
|
|
@@ -2220,20 +2233,18 @@ class Runs(Authenticated):
|
|
|
2220
2233
|
ignore_404: bool = False,
|
|
2221
2234
|
cancel_on_disconnect: bool = False,
|
|
2222
2235
|
stream_channel: asyncio.Queue | None = None,
|
|
2223
|
-
stream_mode: list[StreamMode] | StreamMode,
|
|
2236
|
+
stream_mode: list[StreamMode] | StreamMode | None = None,
|
|
2224
2237
|
last_event_id: str | None = None,
|
|
2225
2238
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
2226
2239
|
) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
|
|
2227
2240
|
"""Stream the run output."""
|
|
2228
2241
|
from langgraph_api.asyncio import create_task
|
|
2229
|
-
|
|
2230
|
-
if stream_mode and not isinstance(stream_mode, list):
|
|
2231
|
-
stream_mode = [stream_mode]
|
|
2242
|
+
from langgraph_api.serde import json_loads
|
|
2232
2243
|
|
|
2233
2244
|
queue = (
|
|
2234
2245
|
stream_channel
|
|
2235
2246
|
if stream_channel
|
|
2236
|
-
else await Runs.Stream.subscribe(run_id
|
|
2247
|
+
else await Runs.Stream.subscribe(run_id)
|
|
2237
2248
|
)
|
|
2238
2249
|
|
|
2239
2250
|
try:
|
|
@@ -2254,53 +2265,71 @@ class Runs(Authenticated):
|
|
|
2254
2265
|
)
|
|
2255
2266
|
)
|
|
2256
2267
|
run = await Runs.get(conn, run_id, thread_id=thread_id, ctx=ctx)
|
|
2257
|
-
channel_prefix = f"run:{run_id}:stream:"
|
|
2258
|
-
len_prefix = len(channel_prefix.encode())
|
|
2259
2268
|
|
|
2260
2269
|
for message in get_stream_manager().restore_messages(
|
|
2261
2270
|
run_id, last_event_id
|
|
2262
2271
|
):
|
|
2263
|
-
|
|
2264
|
-
|
|
2265
|
-
|
|
2272
|
+
data, id = message.data, message.id
|
|
2273
|
+
|
|
2274
|
+
data = json_loads(data)
|
|
2275
|
+
mode = data["event"]
|
|
2276
|
+
message = data["message"]
|
|
2277
|
+
|
|
2278
|
+
if mode == "control":
|
|
2279
|
+
if message == b"done":
|
|
2266
2280
|
return
|
|
2267
|
-
|
|
2268
|
-
|
|
2269
|
-
|
|
2270
|
-
|
|
2271
|
-
|
|
2272
|
-
|
|
2273
|
-
|
|
2274
|
-
"Replayed run event",
|
|
2275
|
-
run_id=str(run_id),
|
|
2276
|
-
message_id=id,
|
|
2277
|
-
stream_mode=mode,
|
|
2278
|
-
data=data,
|
|
2281
|
+
elif (
|
|
2282
|
+
not stream_mode
|
|
2283
|
+
or mode in stream_mode
|
|
2284
|
+
or (
|
|
2285
|
+
(
|
|
2286
|
+
"messages" in stream_mode
|
|
2287
|
+
or "messages-tuple" in stream_mode
|
|
2279
2288
|
)
|
|
2289
|
+
and mode.startswith("messages")
|
|
2290
|
+
)
|
|
2291
|
+
):
|
|
2292
|
+
yield mode.encode(), base64.b64decode(message), id
|
|
2293
|
+
logger.debug(
|
|
2294
|
+
"Replayed run event",
|
|
2295
|
+
run_id=str(run_id),
|
|
2296
|
+
message_id=id,
|
|
2297
|
+
stream_mode=mode,
|
|
2298
|
+
data=data,
|
|
2299
|
+
)
|
|
2280
2300
|
|
|
2281
2301
|
while True:
|
|
2282
2302
|
try:
|
|
2283
2303
|
# Wait for messages with a timeout
|
|
2284
2304
|
message = await asyncio.wait_for(queue.get(), timeout=0.5)
|
|
2285
|
-
|
|
2305
|
+
data, id = message.data, message.id
|
|
2286
2306
|
|
|
2287
|
-
|
|
2288
|
-
|
|
2307
|
+
data = json_loads(data)
|
|
2308
|
+
mode = data["event"]
|
|
2309
|
+
message = data["message"]
|
|
2310
|
+
|
|
2311
|
+
if mode == "control":
|
|
2312
|
+
if message == b"done":
|
|
2289
2313
|
break
|
|
2290
|
-
|
|
2291
|
-
|
|
2292
|
-
mode
|
|
2293
|
-
|
|
2294
|
-
|
|
2295
|
-
|
|
2296
|
-
|
|
2297
|
-
logger.debug(
|
|
2298
|
-
"Streamed run event",
|
|
2299
|
-
run_id=str(run_id),
|
|
2300
|
-
stream_mode=mode,
|
|
2301
|
-
message_id=id,
|
|
2302
|
-
data=data,
|
|
2314
|
+
elif (
|
|
2315
|
+
not stream_mode
|
|
2316
|
+
or mode in stream_mode
|
|
2317
|
+
or (
|
|
2318
|
+
(
|
|
2319
|
+
"messages" in stream_mode
|
|
2320
|
+
or "messages-tuple" in stream_mode
|
|
2303
2321
|
)
|
|
2322
|
+
and mode.startswith("messages")
|
|
2323
|
+
)
|
|
2324
|
+
):
|
|
2325
|
+
yield mode.encode(), base64.b64decode(message), id
|
|
2326
|
+
logger.debug(
|
|
2327
|
+
"Streamed run event",
|
|
2328
|
+
run_id=str(run_id),
|
|
2329
|
+
stream_mode=mode,
|
|
2330
|
+
message_id=id,
|
|
2331
|
+
data=message,
|
|
2332
|
+
)
|
|
2304
2333
|
except TimeoutError:
|
|
2305
2334
|
# Check if the run is still pending
|
|
2306
2335
|
run_iter = await Runs.get(
|
|
@@ -2340,12 +2369,20 @@ class Runs(Authenticated):
|
|
|
2340
2369
|
resumable: bool = False,
|
|
2341
2370
|
) -> None:
|
|
2342
2371
|
"""Publish a message to all subscribers of the run stream."""
|
|
2343
|
-
|
|
2372
|
+
from langgraph_api.serde import json_dumpb
|
|
2373
|
+
|
|
2374
|
+
topic = f"run:{run_id}:stream".encode()
|
|
2344
2375
|
|
|
2345
2376
|
stream_manager = get_stream_manager()
|
|
2346
2377
|
# Send to all queues subscribed to this run_id
|
|
2378
|
+
payload = json_dumpb(
|
|
2379
|
+
{
|
|
2380
|
+
"event": event,
|
|
2381
|
+
"message": message,
|
|
2382
|
+
}
|
|
2383
|
+
)
|
|
2347
2384
|
await stream_manager.put(
|
|
2348
|
-
run_id, Message(topic=topic, data=
|
|
2385
|
+
run_id, Message(topic=topic, data=payload), resumable
|
|
2349
2386
|
)
|
|
2350
2387
|
|
|
2351
2388
|
|
|
@@ -2354,20 +2391,18 @@ async def listen_for_cancellation(queue: asyncio.Queue, run_id: UUID, done: Valu
|
|
|
2354
2391
|
from langgraph_api.errors import UserInterrupt, UserRollback
|
|
2355
2392
|
|
|
2356
2393
|
stream_manager = get_stream_manager()
|
|
2357
|
-
control_key = f"run:{run_id}:control"
|
|
2358
2394
|
|
|
2359
|
-
if
|
|
2360
|
-
|
|
2361
|
-
|
|
2362
|
-
|
|
2363
|
-
|
|
2364
|
-
|
|
2365
|
-
done.set(UserInterrupt())
|
|
2395
|
+
if control_key := stream_manager.get_control_key(run_id):
|
|
2396
|
+
payload = control_key.data
|
|
2397
|
+
if payload == b"rollback":
|
|
2398
|
+
done.set(UserRollback())
|
|
2399
|
+
elif payload == b"interrupt":
|
|
2400
|
+
done.set(UserInterrupt())
|
|
2366
2401
|
|
|
2367
2402
|
while not done.is_set():
|
|
2368
2403
|
try:
|
|
2369
2404
|
# This task gets cancelled when Runs.enter exits anyway,
|
|
2370
|
-
# so we can have a pretty
|
|
2405
|
+
# so we can have a pretty lengthy timeout here
|
|
2371
2406
|
message = await asyncio.wait_for(queue.get(), timeout=240)
|
|
2372
2407
|
payload = message.data
|
|
2373
2408
|
if payload == b"rollback":
|
|
@@ -2377,10 +2412,6 @@ async def listen_for_cancellation(queue: asyncio.Queue, run_id: UUID, done: Valu
|
|
|
2377
2412
|
elif payload == b"done":
|
|
2378
2413
|
done.set()
|
|
2379
2414
|
break
|
|
2380
|
-
|
|
2381
|
-
# Store control messages for late subscribers
|
|
2382
|
-
if message.topic.decode() == control_key:
|
|
2383
|
-
stream_manager.control_queues[run_id].append(message)
|
|
2384
2415
|
except TimeoutError:
|
|
2385
2416
|
break
|
|
2386
2417
|
|
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
langgraph_runtime_inmem/__init__.py,sha256=
|
|
1
|
+
langgraph_runtime_inmem/__init__.py,sha256=3R6DOc_wOQRDAv1aNNMMGjRofpJi6B_qIBkZ373AG44,311
|
|
2
2
|
langgraph_runtime_inmem/checkpoint.py,sha256=nc1G8DqVdIu-ibjKTqXfbPfMbAsKjPObKqegrSzo6Po,4432
|
|
3
3
|
langgraph_runtime_inmem/database.py,sha256=G_6L2khpRDSpS2Vs_SujzHayODcwG5V2IhFP7LLBXgw,6349
|
|
4
|
-
langgraph_runtime_inmem/inmem_stream.py,sha256=
|
|
4
|
+
langgraph_runtime_inmem/inmem_stream.py,sha256=UWk1srLF44HZPPbRdArGGhsy0MY0UOJKSIxBSO7Hosc,5138
|
|
5
5
|
langgraph_runtime_inmem/lifespan.py,sha256=t0w2MX2dGxe8yNtSX97Z-d2pFpllSLS4s1rh2GJDw5M,3557
|
|
6
6
|
langgraph_runtime_inmem/metrics.py,sha256=HhO0RC2bMDTDyGBNvnd2ooLebLA8P1u5oq978Kp_nAA,392
|
|
7
|
-
langgraph_runtime_inmem/ops.py,sha256=
|
|
7
|
+
langgraph_runtime_inmem/ops.py,sha256=CSH5vi7AsaeaWSngZ_DCtsPY-M7ah3cz8acTVi_UbUw,90559
|
|
8
8
|
langgraph_runtime_inmem/queue.py,sha256=nqfgz7j_Jkh5Ek5-RsHB2Uvwbxguu9IUPkGXIxvFPns,10037
|
|
9
9
|
langgraph_runtime_inmem/retry.py,sha256=XmldOP4e_H5s264CagJRVnQMDFcEJR_dldVR1Hm5XvM,763
|
|
10
10
|
langgraph_runtime_inmem/store.py,sha256=rTfL1JJvd-j4xjTrL8qDcynaWF6gUJ9-GDVwH0NBD_I,3506
|
|
11
|
-
langgraph_runtime_inmem-0.6.
|
|
12
|
-
langgraph_runtime_inmem-0.6.
|
|
13
|
-
langgraph_runtime_inmem-0.6.
|
|
11
|
+
langgraph_runtime_inmem-0.6.10.dist-info/METADATA,sha256=5YlbIvGy2fpU9gydQgbaSfu34loSdodD2PUtPlcn_1c,566
|
|
12
|
+
langgraph_runtime_inmem-0.6.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
13
|
+
langgraph_runtime_inmem-0.6.10.dist-info/RECORD,,
|
|
File without changes
|