langgraph-runtime-inmem 0.6.7__py3-none-any.whl → 0.6.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_runtime_inmem/__init__.py +1 -1
- langgraph_runtime_inmem/inmem_stream.py +25 -11
- langgraph_runtime_inmem/ops.py +97 -68
- {langgraph_runtime_inmem-0.6.7.dist-info → langgraph_runtime_inmem-0.6.10.dist-info}/METADATA +1 -1
- {langgraph_runtime_inmem-0.6.7.dist-info → langgraph_runtime_inmem-0.6.10.dist-info}/RECORD +6 -6
- {langgraph_runtime_inmem-0.6.7.dist-info → langgraph_runtime_inmem-0.6.10.dist-info}/WHEEL +0 -0
|
@@ -42,6 +42,7 @@ class ContextQueue(asyncio.Queue):
|
|
|
42
42
|
class StreamManager:
|
|
43
43
|
def __init__(self):
|
|
44
44
|
self.queues = defaultdict(list) # Dict[UUID, List[asyncio.Queue]]
|
|
45
|
+
self.control_keys = defaultdict()
|
|
45
46
|
self.control_queues = defaultdict(list)
|
|
46
47
|
|
|
47
48
|
self.message_stores = defaultdict(list) # Dict[UUID, List[Message]]
|
|
@@ -51,6 +52,14 @@ class StreamManager:
|
|
|
51
52
|
run_id = _ensure_uuid(run_id)
|
|
52
53
|
return self.queues[run_id]
|
|
53
54
|
|
|
55
|
+
def get_control_queues(self, run_id: UUID | str) -> list[asyncio.Queue]:
|
|
56
|
+
run_id = _ensure_uuid(run_id)
|
|
57
|
+
return self.control_queues[run_id]
|
|
58
|
+
|
|
59
|
+
def get_control_key(self, run_id: UUID | str) -> Message | None:
|
|
60
|
+
run_id = _ensure_uuid(run_id)
|
|
61
|
+
return self.control_keys.get(run_id)
|
|
62
|
+
|
|
54
63
|
async def put(
|
|
55
64
|
self, run_id: UUID | str, message: Message, resumable: bool = False
|
|
56
65
|
) -> None:
|
|
@@ -61,8 +70,10 @@ class StreamManager:
|
|
|
61
70
|
self.message_stores[run_id].append(message)
|
|
62
71
|
topic = message.topic.decode()
|
|
63
72
|
if "control" in topic:
|
|
64
|
-
self.
|
|
65
|
-
|
|
73
|
+
self.control_keys[run_id] = message
|
|
74
|
+
queues = self.control_queues[run_id]
|
|
75
|
+
else:
|
|
76
|
+
queues = self.queues[run_id]
|
|
66
77
|
coros = [queue.put(message) for queue in queues]
|
|
67
78
|
results = await asyncio.gather(*coros, return_exceptions=True)
|
|
68
79
|
for result in results:
|
|
@@ -73,14 +84,12 @@ class StreamManager:
|
|
|
73
84
|
run_id = _ensure_uuid(run_id)
|
|
74
85
|
queue = ContextQueue()
|
|
75
86
|
self.queues[run_id].append(queue)
|
|
76
|
-
|
|
77
|
-
try:
|
|
78
|
-
await queue.put(control_msg)
|
|
79
|
-
except Exception:
|
|
80
|
-
logger.exception(
|
|
81
|
-
f"Failed to put control message in queue: {control_msg}"
|
|
82
|
-
)
|
|
87
|
+
return queue
|
|
83
88
|
|
|
89
|
+
async def add_control_queue(self, run_id: UUID | str) -> asyncio.Queue:
|
|
90
|
+
run_id = _ensure_uuid(run_id)
|
|
91
|
+
queue = ContextQueue()
|
|
92
|
+
self.control_queues[run_id].append(queue)
|
|
84
93
|
return queue
|
|
85
94
|
|
|
86
95
|
async def remove_queue(self, run_id: UUID | str, queue: asyncio.Queue):
|
|
@@ -89,8 +98,13 @@ class StreamManager:
|
|
|
89
98
|
self.queues[run_id].remove(queue)
|
|
90
99
|
if not self.queues[run_id]:
|
|
91
100
|
del self.queues[run_id]
|
|
92
|
-
|
|
93
|
-
|
|
101
|
+
|
|
102
|
+
async def remove_control_queue(self, run_id: UUID | str, queue: asyncio.Queue):
|
|
103
|
+
run_id = _ensure_uuid(run_id)
|
|
104
|
+
if run_id in self.control_queues:
|
|
105
|
+
self.control_queues[run_id].remove(queue)
|
|
106
|
+
if not self.control_queues[run_id]:
|
|
107
|
+
del self.control_queues[run_id]
|
|
94
108
|
|
|
95
109
|
def restore_messages(
|
|
96
110
|
self, run_id: UUID | str, message_id: str | None
|
langgraph_runtime_inmem/ops.py
CHANGED
|
@@ -157,9 +157,6 @@ class Assistants(Authenticated):
|
|
|
157
157
|
and (not filters or _check_filter_match(assistant["metadata"], filters))
|
|
158
158
|
]
|
|
159
159
|
|
|
160
|
-
# Get total count before sorting and pagination
|
|
161
|
-
total_count = len(filtered_assistants)
|
|
162
|
-
|
|
163
160
|
# Sort based on sort_by and sort_order
|
|
164
161
|
sort_by = sort_by.lower() if sort_by else None
|
|
165
162
|
if sort_by and sort_by in (
|
|
@@ -186,12 +183,13 @@ class Assistants(Authenticated):
|
|
|
186
183
|
|
|
187
184
|
# Apply pagination
|
|
188
185
|
paginated_assistants = filtered_assistants[offset : offset + limit]
|
|
186
|
+
cur = offset + limit if len(filtered_assistants) > offset + limit else None
|
|
189
187
|
|
|
190
188
|
async def assistant_iterator() -> AsyncIterator[Assistant]:
|
|
191
189
|
for assistant in paginated_assistants:
|
|
192
190
|
yield assistant
|
|
193
191
|
|
|
194
|
-
return assistant_iterator(),
|
|
192
|
+
return assistant_iterator(), cur
|
|
195
193
|
|
|
196
194
|
@staticmethod
|
|
197
195
|
async def get(
|
|
@@ -724,7 +722,7 @@ class Threads(Authenticated):
|
|
|
724
722
|
else:
|
|
725
723
|
# Default sorting by created_at in descending order
|
|
726
724
|
sorted_threads = sorted(
|
|
727
|
-
filtered_threads, key=lambda x: x["
|
|
725
|
+
filtered_threads, key=lambda x: x["updated_at"], reverse=True
|
|
728
726
|
)
|
|
729
727
|
|
|
730
728
|
# Apply limit and offset
|
|
@@ -1453,7 +1451,7 @@ class Threads(Authenticated):
|
|
|
1453
1451
|
conn: InMemConnectionProto,
|
|
1454
1452
|
*,
|
|
1455
1453
|
config: Config,
|
|
1456
|
-
limit: int =
|
|
1454
|
+
limit: int = 1,
|
|
1457
1455
|
before: str | Checkpoint | None = None,
|
|
1458
1456
|
metadata: MetadataInput = None,
|
|
1459
1457
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
@@ -1628,7 +1626,7 @@ class Runs(Authenticated):
|
|
|
1628
1626
|
|
|
1629
1627
|
stream_manager = get_stream_manager()
|
|
1630
1628
|
# Get queue for this run
|
|
1631
|
-
queue = await
|
|
1629
|
+
queue = await stream_manager.add_control_queue(run_id)
|
|
1632
1630
|
|
|
1633
1631
|
async with SimpleTaskGroup(cancel=True, taskgroup_name="Runs.enter") as tg:
|
|
1634
1632
|
done = ValueEvent()
|
|
@@ -1636,16 +1634,21 @@ class Runs(Authenticated):
|
|
|
1636
1634
|
|
|
1637
1635
|
# Give done event to caller
|
|
1638
1636
|
yield done
|
|
1639
|
-
#
|
|
1637
|
+
# Store the control message for late subscribers
|
|
1640
1638
|
control_message = Message(
|
|
1641
1639
|
topic=f"run:{run_id}:control".encode(), data=b"done"
|
|
1642
1640
|
)
|
|
1643
|
-
|
|
1644
|
-
# Store the control message for late subscribers
|
|
1645
1641
|
await stream_manager.put(run_id, control_message)
|
|
1646
|
-
|
|
1647
|
-
#
|
|
1648
|
-
|
|
1642
|
+
|
|
1643
|
+
# Signal done to all subscribers
|
|
1644
|
+
stream_message = Message(
|
|
1645
|
+
topic=f"run:{run_id}:stream".encode(),
|
|
1646
|
+
data={"event": "control", "message": b"done"},
|
|
1647
|
+
)
|
|
1648
|
+
await stream_manager.put(run_id, stream_message)
|
|
1649
|
+
|
|
1650
|
+
# Remove the queue
|
|
1651
|
+
await stream_manager.remove_control_queue(run_id, queue)
|
|
1649
1652
|
|
|
1650
1653
|
@staticmethod
|
|
1651
1654
|
async def sweep(conn: InMemConnectionProto) -> list[UUID]:
|
|
@@ -1981,6 +1984,16 @@ class Runs(Authenticated):
|
|
|
1981
1984
|
return Fragment(
|
|
1982
1985
|
orjson.dumps({"__error__": orjson.Fragment(thread["error"])})
|
|
1983
1986
|
)
|
|
1987
|
+
if thread["status"] == "interrupted":
|
|
1988
|
+
# Get an interrupt for the thread. There is the case where there are multiple interrupts for the same run and we may not show the same
|
|
1989
|
+
# interrupt, but we'll always show one. Long term we should show all of them.
|
|
1990
|
+
try:
|
|
1991
|
+
interrupt_map = thread["interrupts"]
|
|
1992
|
+
interrupt = [next(iter(interrupt_map.values()))[0]]
|
|
1993
|
+
return Fragment(orjson.dumps({"__interrupt__": interrupt}))
|
|
1994
|
+
except Exception:
|
|
1995
|
+
# No interrupt, but status is interrupted from a before/after block. Default back to values.
|
|
1996
|
+
pass
|
|
1984
1997
|
return thread["values"]
|
|
1985
1998
|
|
|
1986
1999
|
@staticmethod
|
|
@@ -2201,8 +2214,6 @@ class Runs(Authenticated):
|
|
|
2201
2214
|
@staticmethod
|
|
2202
2215
|
async def subscribe(
|
|
2203
2216
|
run_id: UUID,
|
|
2204
|
-
*,
|
|
2205
|
-
stream_mode: StreamMode | None = None,
|
|
2206
2217
|
) -> asyncio.Queue:
|
|
2207
2218
|
"""Subscribe to the run stream, returning a queue."""
|
|
2208
2219
|
stream_manager = get_stream_manager()
|
|
@@ -2222,20 +2233,18 @@ class Runs(Authenticated):
|
|
|
2222
2233
|
ignore_404: bool = False,
|
|
2223
2234
|
cancel_on_disconnect: bool = False,
|
|
2224
2235
|
stream_channel: asyncio.Queue | None = None,
|
|
2225
|
-
stream_mode: list[StreamMode] | StreamMode,
|
|
2236
|
+
stream_mode: list[StreamMode] | StreamMode | None = None,
|
|
2226
2237
|
last_event_id: str | None = None,
|
|
2227
2238
|
ctx: Auth.types.BaseAuthContext | None = None,
|
|
2228
2239
|
) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
|
|
2229
2240
|
"""Stream the run output."""
|
|
2230
2241
|
from langgraph_api.asyncio import create_task
|
|
2231
|
-
|
|
2232
|
-
if stream_mode and not isinstance(stream_mode, list):
|
|
2233
|
-
stream_mode = [stream_mode]
|
|
2242
|
+
from langgraph_api.serde import json_loads
|
|
2234
2243
|
|
|
2235
2244
|
queue = (
|
|
2236
2245
|
stream_channel
|
|
2237
2246
|
if stream_channel
|
|
2238
|
-
else await Runs.Stream.subscribe(run_id
|
|
2247
|
+
else await Runs.Stream.subscribe(run_id)
|
|
2239
2248
|
)
|
|
2240
2249
|
|
|
2241
2250
|
try:
|
|
@@ -2256,53 +2265,71 @@ class Runs(Authenticated):
|
|
|
2256
2265
|
)
|
|
2257
2266
|
)
|
|
2258
2267
|
run = await Runs.get(conn, run_id, thread_id=thread_id, ctx=ctx)
|
|
2259
|
-
channel_prefix = f"run:{run_id}:stream:"
|
|
2260
|
-
len_prefix = len(channel_prefix.encode())
|
|
2261
2268
|
|
|
2262
2269
|
for message in get_stream_manager().restore_messages(
|
|
2263
2270
|
run_id, last_event_id
|
|
2264
2271
|
):
|
|
2265
|
-
|
|
2266
|
-
|
|
2267
|
-
|
|
2272
|
+
data, id = message.data, message.id
|
|
2273
|
+
|
|
2274
|
+
data = json_loads(data)
|
|
2275
|
+
mode = data["event"]
|
|
2276
|
+
message = data["message"]
|
|
2277
|
+
|
|
2278
|
+
if mode == "control":
|
|
2279
|
+
if message == b"done":
|
|
2268
2280
|
return
|
|
2269
|
-
|
|
2270
|
-
|
|
2271
|
-
|
|
2272
|
-
|
|
2273
|
-
|
|
2274
|
-
|
|
2275
|
-
|
|
2276
|
-
"Replayed run event",
|
|
2277
|
-
run_id=str(run_id),
|
|
2278
|
-
message_id=id,
|
|
2279
|
-
stream_mode=mode,
|
|
2280
|
-
data=data,
|
|
2281
|
+
elif (
|
|
2282
|
+
not stream_mode
|
|
2283
|
+
or mode in stream_mode
|
|
2284
|
+
or (
|
|
2285
|
+
(
|
|
2286
|
+
"messages" in stream_mode
|
|
2287
|
+
or "messages-tuple" in stream_mode
|
|
2281
2288
|
)
|
|
2289
|
+
and mode.startswith("messages")
|
|
2290
|
+
)
|
|
2291
|
+
):
|
|
2292
|
+
yield mode.encode(), base64.b64decode(message), id
|
|
2293
|
+
logger.debug(
|
|
2294
|
+
"Replayed run event",
|
|
2295
|
+
run_id=str(run_id),
|
|
2296
|
+
message_id=id,
|
|
2297
|
+
stream_mode=mode,
|
|
2298
|
+
data=data,
|
|
2299
|
+
)
|
|
2282
2300
|
|
|
2283
2301
|
while True:
|
|
2284
2302
|
try:
|
|
2285
2303
|
# Wait for messages with a timeout
|
|
2286
2304
|
message = await asyncio.wait_for(queue.get(), timeout=0.5)
|
|
2287
|
-
|
|
2305
|
+
data, id = message.data, message.id
|
|
2288
2306
|
|
|
2289
|
-
|
|
2290
|
-
|
|
2307
|
+
data = json_loads(data)
|
|
2308
|
+
mode = data["event"]
|
|
2309
|
+
message = data["message"]
|
|
2310
|
+
|
|
2311
|
+
if mode == "control":
|
|
2312
|
+
if message == b"done":
|
|
2291
2313
|
break
|
|
2292
|
-
|
|
2293
|
-
|
|
2294
|
-
mode
|
|
2295
|
-
|
|
2296
|
-
|
|
2297
|
-
|
|
2298
|
-
|
|
2299
|
-
logger.debug(
|
|
2300
|
-
"Streamed run event",
|
|
2301
|
-
run_id=str(run_id),
|
|
2302
|
-
stream_mode=mode,
|
|
2303
|
-
message_id=id,
|
|
2304
|
-
data=data,
|
|
2314
|
+
elif (
|
|
2315
|
+
not stream_mode
|
|
2316
|
+
or mode in stream_mode
|
|
2317
|
+
or (
|
|
2318
|
+
(
|
|
2319
|
+
"messages" in stream_mode
|
|
2320
|
+
or "messages-tuple" in stream_mode
|
|
2305
2321
|
)
|
|
2322
|
+
and mode.startswith("messages")
|
|
2323
|
+
)
|
|
2324
|
+
):
|
|
2325
|
+
yield mode.encode(), base64.b64decode(message), id
|
|
2326
|
+
logger.debug(
|
|
2327
|
+
"Streamed run event",
|
|
2328
|
+
run_id=str(run_id),
|
|
2329
|
+
stream_mode=mode,
|
|
2330
|
+
message_id=id,
|
|
2331
|
+
data=message,
|
|
2332
|
+
)
|
|
2306
2333
|
except TimeoutError:
|
|
2307
2334
|
# Check if the run is still pending
|
|
2308
2335
|
run_iter = await Runs.get(
|
|
@@ -2342,12 +2369,20 @@ class Runs(Authenticated):
|
|
|
2342
2369
|
resumable: bool = False,
|
|
2343
2370
|
) -> None:
|
|
2344
2371
|
"""Publish a message to all subscribers of the run stream."""
|
|
2345
|
-
|
|
2372
|
+
from langgraph_api.serde import json_dumpb
|
|
2373
|
+
|
|
2374
|
+
topic = f"run:{run_id}:stream".encode()
|
|
2346
2375
|
|
|
2347
2376
|
stream_manager = get_stream_manager()
|
|
2348
2377
|
# Send to all queues subscribed to this run_id
|
|
2378
|
+
payload = json_dumpb(
|
|
2379
|
+
{
|
|
2380
|
+
"event": event,
|
|
2381
|
+
"message": message,
|
|
2382
|
+
}
|
|
2383
|
+
)
|
|
2349
2384
|
await stream_manager.put(
|
|
2350
|
-
run_id, Message(topic=topic, data=
|
|
2385
|
+
run_id, Message(topic=topic, data=payload), resumable
|
|
2351
2386
|
)
|
|
2352
2387
|
|
|
2353
2388
|
|
|
@@ -2356,20 +2391,18 @@ async def listen_for_cancellation(queue: asyncio.Queue, run_id: UUID, done: Valu
|
|
|
2356
2391
|
from langgraph_api.errors import UserInterrupt, UserRollback
|
|
2357
2392
|
|
|
2358
2393
|
stream_manager = get_stream_manager()
|
|
2359
|
-
control_key = f"run:{run_id}:control"
|
|
2360
2394
|
|
|
2361
|
-
if
|
|
2362
|
-
|
|
2363
|
-
|
|
2364
|
-
|
|
2365
|
-
|
|
2366
|
-
|
|
2367
|
-
done.set(UserInterrupt())
|
|
2395
|
+
if control_key := stream_manager.get_control_key(run_id):
|
|
2396
|
+
payload = control_key.data
|
|
2397
|
+
if payload == b"rollback":
|
|
2398
|
+
done.set(UserRollback())
|
|
2399
|
+
elif payload == b"interrupt":
|
|
2400
|
+
done.set(UserInterrupt())
|
|
2368
2401
|
|
|
2369
2402
|
while not done.is_set():
|
|
2370
2403
|
try:
|
|
2371
2404
|
# This task gets cancelled when Runs.enter exits anyway,
|
|
2372
|
-
# so we can have a pretty
|
|
2405
|
+
# so we can have a pretty lengthy timeout here
|
|
2373
2406
|
message = await asyncio.wait_for(queue.get(), timeout=240)
|
|
2374
2407
|
payload = message.data
|
|
2375
2408
|
if payload == b"rollback":
|
|
@@ -2379,10 +2412,6 @@ async def listen_for_cancellation(queue: asyncio.Queue, run_id: UUID, done: Valu
|
|
|
2379
2412
|
elif payload == b"done":
|
|
2380
2413
|
done.set()
|
|
2381
2414
|
break
|
|
2382
|
-
|
|
2383
|
-
# Store control messages for late subscribers
|
|
2384
|
-
if message.topic.decode() == control_key:
|
|
2385
|
-
stream_manager.control_queues[run_id].append(message)
|
|
2386
2415
|
except TimeoutError:
|
|
2387
2416
|
break
|
|
2388
2417
|
|
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
langgraph_runtime_inmem/__init__.py,sha256=
|
|
1
|
+
langgraph_runtime_inmem/__init__.py,sha256=3R6DOc_wOQRDAv1aNNMMGjRofpJi6B_qIBkZ373AG44,311
|
|
2
2
|
langgraph_runtime_inmem/checkpoint.py,sha256=nc1G8DqVdIu-ibjKTqXfbPfMbAsKjPObKqegrSzo6Po,4432
|
|
3
3
|
langgraph_runtime_inmem/database.py,sha256=G_6L2khpRDSpS2Vs_SujzHayODcwG5V2IhFP7LLBXgw,6349
|
|
4
|
-
langgraph_runtime_inmem/inmem_stream.py,sha256=
|
|
4
|
+
langgraph_runtime_inmem/inmem_stream.py,sha256=UWk1srLF44HZPPbRdArGGhsy0MY0UOJKSIxBSO7Hosc,5138
|
|
5
5
|
langgraph_runtime_inmem/lifespan.py,sha256=t0w2MX2dGxe8yNtSX97Z-d2pFpllSLS4s1rh2GJDw5M,3557
|
|
6
6
|
langgraph_runtime_inmem/metrics.py,sha256=HhO0RC2bMDTDyGBNvnd2ooLebLA8P1u5oq978Kp_nAA,392
|
|
7
|
-
langgraph_runtime_inmem/ops.py,sha256=
|
|
7
|
+
langgraph_runtime_inmem/ops.py,sha256=CSH5vi7AsaeaWSngZ_DCtsPY-M7ah3cz8acTVi_UbUw,90559
|
|
8
8
|
langgraph_runtime_inmem/queue.py,sha256=nqfgz7j_Jkh5Ek5-RsHB2Uvwbxguu9IUPkGXIxvFPns,10037
|
|
9
9
|
langgraph_runtime_inmem/retry.py,sha256=XmldOP4e_H5s264CagJRVnQMDFcEJR_dldVR1Hm5XvM,763
|
|
10
10
|
langgraph_runtime_inmem/store.py,sha256=rTfL1JJvd-j4xjTrL8qDcynaWF6gUJ9-GDVwH0NBD_I,3506
|
|
11
|
-
langgraph_runtime_inmem-0.6.
|
|
12
|
-
langgraph_runtime_inmem-0.6.
|
|
13
|
-
langgraph_runtime_inmem-0.6.
|
|
11
|
+
langgraph_runtime_inmem-0.6.10.dist-info/METADATA,sha256=5YlbIvGy2fpU9gydQgbaSfu34loSdodD2PUtPlcn_1c,566
|
|
12
|
+
langgraph_runtime_inmem-0.6.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
13
|
+
langgraph_runtime_inmem-0.6.10.dist-info/RECORD,,
|
|
File without changes
|