langgraph-runtime-inmem 0.6.8-py3-none-any.whl → 0.6.10-py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
langgraph_runtime_inmem/__init__.py

@@ -9,7 +9,7 @@ from langgraph_runtime_inmem import (
     store,
 )
 
-__version__ = "0.6.8"
+__version__ = "0.6.10"
 __all__ = [
     "ops",
     "database",
langgraph_runtime_inmem/inmem_stream.py

@@ -42,6 +42,7 @@ class ContextQueue(asyncio.Queue):
 class StreamManager:
     def __init__(self):
         self.queues = defaultdict(list)  # Dict[UUID, List[asyncio.Queue]]
+        self.control_keys = defaultdict()
         self.control_queues = defaultdict(list)
 
         self.message_stores = defaultdict(list)  # Dict[UUID, List[Message]]
@@ -51,6 +52,14 @@ class StreamManager:
         run_id = _ensure_uuid(run_id)
         return self.queues[run_id]
 
+    def get_control_queues(self, run_id: UUID | str) -> list[asyncio.Queue]:
+        run_id = _ensure_uuid(run_id)
+        return self.control_queues[run_id]
+
+    def get_control_key(self, run_id: UUID | str) -> Message | None:
+        run_id = _ensure_uuid(run_id)
+        return self.control_keys.get(run_id)
+
     async def put(
         self, run_id: UUID | str, message: Message, resumable: bool = False
     ) -> None:
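Two accessors pair with the new state: get_control_queues mirrors get_queues for the control plane, and get_control_key returns the most recent control message, if any. Note that defaultdict() without a factory behaves like a plain dict (missing keys raise KeyError), which is why get_control_key reads through .get(). A minimal sketch of the caching semantics, with a NamedTuple standing in for the module's Message class:

    from collections import defaultdict
    from typing import NamedTuple

    class Message(NamedTuple):  # stand-in for the module's Message type
        topic: bytes
        data: bytes

    control_keys = defaultdict()  # no factory: missing keys raise KeyError

    control_keys["run-1"] = Message(b"run:run-1:control", b"interrupt")
    control_keys["run-1"] = Message(b"run:run-1:control", b"done")  # overwrites

    assert control_keys.get("run-2") is None      # .get() avoids KeyError
    assert control_keys["run-1"].data == b"done"  # only the latest is kept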
@@ -61,8 +70,10 @@ class StreamManager:
             self.message_stores[run_id].append(message)
         topic = message.topic.decode()
         if "control" in topic:
-            self.control_queues[run_id].append(message)
-        queues = self.queues.get(run_id, [])
+            self.control_keys[run_id] = message
+            queues = self.control_queues[run_id]
+        else:
+            queues = self.queues[run_id]
         coros = [queue.put(message) for queue in queues]
         results = await asyncio.gather(*coros, return_exceptions=True)
         for result in results:
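Routing in put() is now topic-based: a message whose topic contains "control" replaces the cached control key and fans out only to control-queue subscribers, while everything else goes to the regular stream queues. Roughly, with plain module-level dicts standing in for StreamManager state:

    import asyncio
    from collections import defaultdict

    queues = defaultdict(list)          # data-plane subscriber queues per run
    control_queues = defaultdict(list)  # control-plane subscriber queues per run
    control_keys = {}                   # latest control message per run

    async def put(run_id: str, topic: bytes, message: bytes) -> None:
        if "control" in topic.decode():
            control_keys[run_id] = message    # cache for late subscribers
            targets = control_queues[run_id]  # control-plane fan-out
        else:
            targets = queues[run_id]          # data-plane fan-out
        # return_exceptions=True so one failing queue cannot break the rest
        await asyncio.gather(
            *(q.put(message) for q in targets), return_exceptions=True
        )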
@@ -73,14 +84,12 @@ class StreamManager:
         run_id = _ensure_uuid(run_id)
         queue = ContextQueue()
         self.queues[run_id].append(queue)
-        for control_msg in self.control_queues[run_id]:
-            try:
-                await queue.put(control_msg)
-            except Exception:
-                logger.exception(
-                    f"Failed to put control message in queue: {control_msg}"
-                )
+        return queue
 
+    async def add_control_queue(self, run_id: UUID | str) -> asyncio.Queue:
+        run_id = _ensure_uuid(run_id)
+        queue = ContextQueue()
+        self.control_queues[run_id].append(queue)
         return queue
 
     async def remove_queue(self, run_id: UUID | str, queue: asyncio.Queue):
@@ -89,8 +98,13 @@ class StreamManager:
             self.queues[run_id].remove(queue)
             if not self.queues[run_id]:
                 del self.queues[run_id]
-            if run_id in self.message_stores:
-                del self.message_stores[run_id]
+
+    async def remove_control_queue(self, run_id: UUID | str, queue: asyncio.Queue):
+        run_id = _ensure_uuid(run_id)
+        if run_id in self.control_queues:
+            self.control_queues[run_id].remove(queue)
+            if not self.control_queues[run_id]:
+                del self.control_queues[run_id]
 
     def restore_messages(
         self, run_id: UUID | str, message_id: str | None
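add_control_queue and remove_control_queue give control subscribers the same attach/detach lifecycle that data queues already had, and new data queues no longer replay buffered control messages. Note also that remove_queue no longer deletes message_stores[run_id], so resumable messages stay available to restore_messages after the last subscriber detaches. A hedged usage sketch of the new pair (manager and run_id are placeholders for illustration):

    import asyncio

    async def watch_control(manager, run_id) -> bytes:
        # Hypothetical usage; `manager` is a StreamManager instance.
        queue = await manager.add_control_queue(run_id)
        try:
            message = await queue.get()  # e.g. b"rollback" / b"interrupt" / b"done"
            return message.data
        finally:
            # Detach so the per-run list (and its dict entry) is cleaned up.
            await manager.remove_control_queue(run_id, queue)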
langgraph_runtime_inmem/ops.py

@@ -722,7 +722,7 @@ class Threads(Authenticated):
         else:
             # Default sorting by created_at in descending order
             sorted_threads = sorted(
-                filtered_threads, key=lambda x: x["created_at"], reverse=True
+                filtered_threads, key=lambda x: x["updated_at"], reverse=True
             )
 
             # Apply limit and offset
@@ -1451,7 +1451,7 @@ class Threads(Authenticated):
         conn: InMemConnectionProto,
         *,
         config: Config,
-        limit: int = 10,
+        limit: int = 1,
         before: str | Checkpoint | None = None,
         metadata: MetadataInput = None,
         ctx: Auth.types.BaseAuthContext | None = None,
@@ -1626,7 +1626,7 @@ class Runs(Authenticated):
 
         stream_manager = get_stream_manager()
         # Get queue for this run
-        queue = await Runs.Stream.subscribe(run_id)
+        queue = await stream_manager.add_control_queue(run_id)
 
         async with SimpleTaskGroup(cancel=True, taskgroup_name="Runs.enter") as tg:
             done = ValueEvent()
@@ -1634,16 +1634,21 @@
 
             # Give done event to caller
             yield done
-            # Signal done to all subscribers
+            # Store the control message for late subscribers
             control_message = Message(
                 topic=f"run:{run_id}:control".encode(), data=b"done"
             )
-
-            # Store the control message for late subscribers
             await stream_manager.put(run_id, control_message)
-            stream_manager.control_queues[run_id].append(control_message)
-            # Clean up this queue
-            await stream_manager.remove_queue(run_id, queue)
+
+            # Signal done to all subscribers
+            stream_message = Message(
+                topic=f"run:{run_id}:stream".encode(),
+                data={"event": "control", "message": b"done"},
+            )
+            await stream_manager.put(run_id, stream_message)
+
+            # Remove the queue
+            await stream_manager.remove_control_queue(run_id, queue)
 
     @staticmethod
     async def sweep(conn: InMemConnectionProto) -> list[UUID]:
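Run completion is now signaled twice: as raw b"done" on the run:{run_id}:control topic (which put() caches in control_keys and fans out to control queues) and as an {"event": "control", "message": b"done"} envelope on the run:{run_id}:stream topic so data-plane subscribers observe the same terminal event. Note the envelope is passed as a plain dict here, while Runs.Stream.publish serializes the same shape with json_dumpb; presumably the consumer-side json_loads accepts both (an assumption). A sketch of the two shapes, with a NamedTuple standing in for the module's Message class:

    from typing import Any, NamedTuple

    class Message(NamedTuple):  # stand-in for the module's Message type
        topic: bytes
        data: Any

    run_id = "3f2b0c9e"  # hypothetical run id

    # Control plane: raw bytes, cached by StreamManager.put via control_keys.
    control_message = Message(f"run:{run_id}:control".encode(), b"done")

    # Data plane: enveloped control event; consumers check event == "control".
    stream_message = Message(
        f"run:{run_id}:stream".encode(),
        {"event": "control", "message": b"done"},
    )

    # put() routes on the topic: "control" in topic -> control queues,
    # otherwise -> regular stream queues.
    assert b"control" in control_message.topic
    assert b"control" not in stream_message.topic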
@@ -1979,6 +1984,16 @@ class Runs(Authenticated):
             return Fragment(
                 orjson.dumps({"__error__": orjson.Fragment(thread["error"])})
             )
+        if thread["status"] == "interrupted":
+            # Get an interrupt for the thread. There is the case where there are multiple interrupts for the same run and we may not show the same
+            # interrupt, but we'll always show one. Long term we should show all of them.
+            try:
+                interrupt_map = thread["interrupts"]
+                interrupt = [next(iter(interrupt_map.values()))[0]]
+                return Fragment(orjson.dumps({"__interrupt__": interrupt}))
+            except Exception:
+                # No interrupt, but status is interrupted from a before/after block. Default back to values.
+                pass
        return thread["values"]
 
     @staticmethod
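For interrupted threads, the values fragment now surfaces one interrupt under the __interrupt__ key; when several exist, only the first interrupt of the first entry is shown. Assuming thread["interrupts"] maps ids to lists of interrupt payloads (the exact shape is not visible in this diff), the picking logic reduces to:

    import orjson

    # Assumed shape of thread["interrupts"]: {task_id: [interrupt, ...], ...}
    interrupt_map = {
        "task-a": [{"value": "approve?", "when": "during"}],
        "task-b": [{"value": "retry?", "when": "during"}],
    }

    # First entry in iteration order, first interrupt in that entry's list.
    interrupt = [next(iter(interrupt_map.values()))[0]]
    print(orjson.dumps({"__interrupt__": interrupt}))
    # b'{"__interrupt__":[{"value":"approve?","when":"during"}]}'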
@@ -2199,8 +2214,6 @@ class Runs(Authenticated):
     @staticmethod
     async def subscribe(
         run_id: UUID,
-        *,
-        stream_mode: StreamMode | None = None,
     ) -> asyncio.Queue:
         """Subscribe to the run stream, returning a queue."""
         stream_manager = get_stream_manager()
@@ -2220,20 +2233,18 @@
         ignore_404: bool = False,
         cancel_on_disconnect: bool = False,
         stream_channel: asyncio.Queue | None = None,
-        stream_mode: list[StreamMode] | StreamMode,
+        stream_mode: list[StreamMode] | StreamMode | None = None,
         last_event_id: str | None = None,
         ctx: Auth.types.BaseAuthContext | None = None,
     ) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
         """Stream the run output."""
         from langgraph_api.asyncio import create_task
-
-        if stream_mode and not isinstance(stream_mode, list):
-            stream_mode = [stream_mode]
+        from langgraph_api.serde import json_loads
 
         queue = (
             stream_channel
             if stream_channel
-            else await Runs.Stream.subscribe(run_id, stream_mode=stream_mode)
+            else await Runs.Stream.subscribe(run_id)
         )
 
         try:
@@ -2254,53 +2265,71 @@
                 )
             )
             run = await Runs.get(conn, run_id, thread_id=thread_id, ctx=ctx)
-            channel_prefix = f"run:{run_id}:stream:"
-            len_prefix = len(channel_prefix.encode())
 
             for message in get_stream_manager().restore_messages(
                 run_id, last_event_id
             ):
-                topic, data, id = message.topic, message.data, message.id
-                if topic.decode() == f"run:{run_id}:control":
-                    if data == b"done":
+                data, id = message.data, message.id
+
+                data = json_loads(data)
+                mode = data["event"]
+                message = data["message"]
+
+                if mode == "control":
+                    if message == b"done":
                         return
-                else:
-                    mode = topic[len_prefix:]
-                    if mode == b"updates" and "updates" not in stream_mode:
-                        continue
-                    else:
-                        yield mode, data, id
-                        logger.debug(
-                            "Replayed run event",
-                            run_id=str(run_id),
-                            message_id=id,
-                            stream_mode=mode,
-                            data=data,
+                elif (
+                    not stream_mode
+                    or mode in stream_mode
+                    or (
+                        (
+                            "messages" in stream_mode
+                            or "messages-tuple" in stream_mode
                         )
+                        and mode.startswith("messages")
+                    )
+                ):
+                    yield mode.encode(), base64.b64decode(message), id
+                    logger.debug(
+                        "Replayed run event",
+                        run_id=str(run_id),
+                        message_id=id,
+                        stream_mode=mode,
+                        data=data,
+                    )
 
             while True:
                 try:
                     # Wait for messages with a timeout
                     message = await asyncio.wait_for(queue.get(), timeout=0.5)
-                    topic, data, id = message.topic, message.data, message.id
+                    data, id = message.data, message.id
 
-                    if topic.decode() == f"run:{run_id}:control":
-                        if data == b"done":
+                    data = json_loads(data)
+                    mode = data["event"]
+                    message = data["message"]
+
+                    if mode == "control":
+                        if message == b"done":
                             break
-                    else:
-                        # Extract mode from topic
-                        mode = topic[len_prefix:]
-                        if mode == b"updates" and "updates" not in stream_mode:
-                            continue
-                        else:
-                            yield mode, data, id
-                            logger.debug(
-                                "Streamed run event",
-                                run_id=str(run_id),
-                                stream_mode=mode,
-                                message_id=id,
-                                data=data,
+                    elif (
+                        not stream_mode
+                        or mode in stream_mode
+                        or (
+                            (
+                                "messages" in stream_mode
+                                or "messages-tuple" in stream_mode
                             )
+                            and mode.startswith("messages")
+                        )
+                    ):
+                        yield mode.encode(), base64.b64decode(message), id
+                        logger.debug(
+                            "Streamed run event",
+                            run_id=str(run_id),
+                            stream_mode=mode,
+                            message_id=id,
+                            data=message,
+                        )
                 except TimeoutError:
                     # Check if the run is still pending
                     run_iter = await Runs.get(
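Event filtering no longer parses the topic suffix; it reads the envelope's event field and applies a single predicate: pass everything when no stream_mode is requested, match listed modes exactly, and treat any mode starting with "messages" as matching when "messages" or "messages-tuple" was requested. Restated standalone:

    def wants(mode: str, stream_mode: list[str] | None) -> bool:
        """Restatement of the new replay/stream filter."""
        if not stream_mode:
            return True  # no filter requested: pass everything
        if mode in stream_mode:
            return True  # exact mode match
        wants_messages = "messages" in stream_mode or "messages-tuple" in stream_mode
        return wants_messages and mode.startswith("messages")

    assert wants("values", None)
    assert wants("updates", ["updates"])
    assert wants("messages/partial", ["messages-tuple"])  # hypothetical sub-mode name
    assert not wants("updates", ["values"])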
@@ -2340,12 +2369,20 @@
         resumable: bool = False,
     ) -> None:
         """Publish a message to all subscribers of the run stream."""
-        topic = f"run:{run_id}:stream:{event}".encode()
+        from langgraph_api.serde import json_dumpb
+
+        topic = f"run:{run_id}:stream".encode()
 
         stream_manager = get_stream_manager()
         # Send to all queues subscribed to this run_id
+        payload = json_dumpb(
+            {
+                "event": event,
+                "message": message,
+            }
+        )
         await stream_manager.put(
-            run_id, Message(topic=topic, data=message), resumable
+            run_id, Message(topic=topic, data=payload), resumable
         )
 
 
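On the publish side, every event is now wrapped in an {"event", "message"} envelope under a single run:{run_id}:stream topic, serialized with json_dumpb from langgraph_api.serde. Since the replay path applies base64.b64decode to the message field, json_dumpb presumably base64-encodes bytes values; that behavior is an assumption here, approximated with the stdlib:

    import base64
    import json

    def json_dumpb_sketch(obj: dict) -> bytes:
        """Approximation of the assumed json_dumpb behavior:
        JSON-encode, base64-encoding any bytes values."""
        encoded = {
            key: base64.b64encode(value).decode() if isinstance(value, bytes) else value
            for key, value in obj.items()
        }
        return json.dumps(encoded).encode()

    payload = json_dumpb_sketch({"event": "values", "message": b'{"x":1}'})
    envelope = json.loads(payload)
    assert envelope["event"] == "values"
    assert base64.b64decode(envelope["message"]) == b'{"x":1}'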
@@ -2354,20 +2391,18 @@ async def listen_for_cancellation(queue: asyncio.Queue, run_id: UUID, done: Valu
     from langgraph_api.errors import UserInterrupt, UserRollback
 
     stream_manager = get_stream_manager()
-    control_key = f"run:{run_id}:control"
 
-    if existing_queue := stream_manager.control_queues.get(run_id):
-        for message in existing_queue:
-            payload = message.data
-            if payload == b"rollback":
-                done.set(UserRollback())
-            elif payload == b"interrupt":
-                done.set(UserInterrupt())
+    if control_key := stream_manager.get_control_key(run_id):
+        payload = control_key.data
+        if payload == b"rollback":
+            done.set(UserRollback())
+        elif payload == b"interrupt":
+            done.set(UserInterrupt())
 
     while not done.is_set():
         try:
             # This task gets cancelled when Runs.enter exits anyway,
-            # so we can have a pretty length timeout here
+            # so we can have a pretty lengthy timeout here
             message = await asyncio.wait_for(queue.get(), timeout=240)
             payload = message.data
             if payload == b"rollback":
@@ -2377,10 +2412,6 @@ async def listen_for_cancellation(queue: asyncio.Queue, run_id: UUID, done: Valu
             elif payload == b"done":
                 done.set()
                 break
-
-            # Store control messages for late subscribers
-            if message.topic.decode() == control_key:
-                stream_manager.control_queues[run_id].append(message)
         except TimeoutError:
             break
 
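Because put() now caches the most recent control message, the listener no longer re-appends messages for late subscribers: catch-up is a single get_control_key lookup, after which the listener blocks on its own control queue. A condensed sketch of the flow (manager and done are stand-ins; the real code sets UserRollback/UserInterrupt exceptions on done):

    import asyncio

    async def listen(manager, run_id, done) -> None:
        # 1) Catch up from the cached control key if a signal fired before
        #    this listener attached.
        if cached := manager.get_control_key(run_id):
            if cached.data in (b"rollback", b"interrupt"):
                done.set(cached.data)
        # 2) Then wait on a dedicated control queue for live signals.
        queue = await manager.add_control_queue(run_id)
        try:
            while not done.is_set():
                message = await asyncio.wait_for(queue.get(), timeout=240)
                if message.data == b"done":
                    break
        finally:
            await manager.remove_control_queue(run_id, queue)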
METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: langgraph-runtime-inmem
-Version: 0.6.8
+Version: 0.6.10
 Summary: Inmem implementation for the LangGraph API server.
 Author-email: Will Fu-Hinthorn <will@langchain.dev>
 License: Elastic-2.0
RECORD

@@ -1,13 +1,13 @@
-langgraph_runtime_inmem/__init__.py,sha256=Zy5RCTPPvryu_HMij5RWuFUvkeY5gEeOV66pJgigHDM,310
+langgraph_runtime_inmem/__init__.py,sha256=3R6DOc_wOQRDAv1aNNMMGjRofpJi6B_qIBkZ373AG44,311
 langgraph_runtime_inmem/checkpoint.py,sha256=nc1G8DqVdIu-ibjKTqXfbPfMbAsKjPObKqegrSzo6Po,4432
 langgraph_runtime_inmem/database.py,sha256=G_6L2khpRDSpS2Vs_SujzHayODcwG5V2IhFP7LLBXgw,6349
-langgraph_runtime_inmem/inmem_stream.py,sha256=65z_2mBNJ0-yJsXWnlYwRc71039_y6Sa0MN8fL_U3Ko,4581
+langgraph_runtime_inmem/inmem_stream.py,sha256=UWk1srLF44HZPPbRdArGGhsy0MY0UOJKSIxBSO7Hosc,5138
 langgraph_runtime_inmem/lifespan.py,sha256=t0w2MX2dGxe8yNtSX97Z-d2pFpllSLS4s1rh2GJDw5M,3557
 langgraph_runtime_inmem/metrics.py,sha256=HhO0RC2bMDTDyGBNvnd2ooLebLA8P1u5oq978Kp_nAA,392
-langgraph_runtime_inmem/ops.py,sha256=CpicJwlu55cGm8WMWtfyse1Sy1rj8vLZqbLWVF45mB0,89326
+langgraph_runtime_inmem/ops.py,sha256=CSH5vi7AsaeaWSngZ_DCtsPY-M7ah3cz8acTVi_UbUw,90559
 langgraph_runtime_inmem/queue.py,sha256=nqfgz7j_Jkh5Ek5-RsHB2Uvwbxguu9IUPkGXIxvFPns,10037
 langgraph_runtime_inmem/retry.py,sha256=XmldOP4e_H5s264CagJRVnQMDFcEJR_dldVR1Hm5XvM,763
 langgraph_runtime_inmem/store.py,sha256=rTfL1JJvd-j4xjTrL8qDcynaWF6gUJ9-GDVwH0NBD_I,3506
-langgraph_runtime_inmem-0.6.8.dist-info/METADATA,sha256=v2xaQ-PTil64hCubv5PMmqlc-k1xQSExh4jisUNbTCk,565
-langgraph_runtime_inmem-0.6.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-langgraph_runtime_inmem-0.6.8.dist-info/RECORD,,
+langgraph_runtime_inmem-0.6.10.dist-info/METADATA,sha256=5YlbIvGy2fpU9gydQgbaSfu34loSdodD2PUtPlcn_1c,566
+langgraph_runtime_inmem-0.6.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+langgraph_runtime_inmem-0.6.10.dist-info/RECORD,,