langgraph_runtime_inmem-0.6.7-py3-none-any.whl → langgraph_runtime_inmem-0.6.10-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/langgraph_runtime_inmem/__init__.py
+++ b/langgraph_runtime_inmem/__init__.py
@@ -9,7 +9,7 @@ from langgraph_runtime_inmem import (
     store,
 )
 
-__version__ = "0.6.7"
+__version__ = "0.6.10"
 __all__ = [
     "ops",
     "database",
--- a/langgraph_runtime_inmem/inmem_stream.py
+++ b/langgraph_runtime_inmem/inmem_stream.py
@@ -42,6 +42,7 @@ class ContextQueue(asyncio.Queue):
 class StreamManager:
     def __init__(self):
         self.queues = defaultdict(list)  # Dict[UUID, List[asyncio.Queue]]
+        self.control_keys = defaultdict()
         self.control_queues = defaultdict(list)
 
         self.message_stores = defaultdict(list)  # Dict[UUID, List[Message]]
@@ -51,6 +52,14 @@ class StreamManager:
         run_id = _ensure_uuid(run_id)
         return self.queues[run_id]
 
+    def get_control_queues(self, run_id: UUID | str) -> list[asyncio.Queue]:
+        run_id = _ensure_uuid(run_id)
+        return self.control_queues[run_id]
+
+    def get_control_key(self, run_id: UUID | str) -> Message | None:
+        run_id = _ensure_uuid(run_id)
+        return self.control_keys.get(run_id)
+
     async def put(
         self, run_id: UUID | str, message: Message, resumable: bool = False
     ) -> None:
@@ -61,8 +70,10 @@
             self.message_stores[run_id].append(message)
         topic = message.topic.decode()
         if "control" in topic:
-            self.control_queues[run_id].append(message)
-        queues = self.queues.get(run_id, [])
+            self.control_keys[run_id] = message
+            queues = self.control_queues[run_id]
+        else:
+            queues = self.queues[run_id]
         coros = [queue.put(message) for queue in queues]
         results = await asyncio.gather(*coros, return_exceptions=True)
         for result in results:
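
Note: put() now routes by topic. A control message replaces control_keys[run_id] (only the latest is retained) and fans out to the dedicated control queues, while every other message goes to the regular stream queues. A minimal sketch of the new behavior, assuming a Message(topic=..., data=...) type as used above and that the method appending to self.queues is named add_queue (its name is not visible in this hunk):

    mgr = StreamManager()
    ctl = await mgr.add_control_queue(run_id)  # receives only control messages
    out = await mgr.add_queue(run_id)          # receives only stream messages
    await mgr.put(
        run_id, Message(topic=f"run:{run_id}:control".encode(), data=b"interrupt")
    )
    assert mgr.get_control_key(run_id).data == b"interrupt"  # latest control kept
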
@@ -73,14 +84,12 @@
         run_id = _ensure_uuid(run_id)
         queue = ContextQueue()
         self.queues[run_id].append(queue)
-        for control_msg in self.control_queues[run_id]:
-            try:
-                await queue.put(control_msg)
-            except Exception:
-                logger.exception(
-                    f"Failed to put control message in queue: {control_msg}"
-                )
+        return queue
 
+    async def add_control_queue(self, run_id: UUID | str) -> asyncio.Queue:
+        run_id = _ensure_uuid(run_id)
+        queue = ContextQueue()
+        self.control_queues[run_id].append(queue)
         return queue
 
     async def remove_queue(self, run_id: UUID | str, queue: asyncio.Queue):
@@ -89,8 +98,13 @@
         self.queues[run_id].remove(queue)
         if not self.queues[run_id]:
             del self.queues[run_id]
-            if run_id in self.message_stores:
-                del self.message_stores[run_id]
+
+    async def remove_control_queue(self, run_id: UUID | str, queue: asyncio.Queue):
+        run_id = _ensure_uuid(run_id)
+        if run_id in self.control_queues:
+            self.control_queues[run_id].remove(queue)
+            if not self.control_queues[run_id]:
+                del self.control_queues[run_id]
 
     def restore_messages(
         self, run_id: UUID | str, message_id: str | None
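
Note: remove_queue no longer deletes message_stores[run_id] when the last queue for a run goes away, so resumable messages survive subscriber disconnects and remain available to restore_messages.
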
--- a/langgraph_runtime_inmem/ops.py
+++ b/langgraph_runtime_inmem/ops.py
@@ -157,9 +157,6 @@ class Assistants(Authenticated):
             and (not filters or _check_filter_match(assistant["metadata"], filters))
         ]
 
-        # Get total count before sorting and pagination
-        total_count = len(filtered_assistants)
-
         # Sort based on sort_by and sort_order
         sort_by = sort_by.lower() if sort_by else None
         if sort_by and sort_by in (
@@ -186,12 +183,13 @@
 
         # Apply pagination
         paginated_assistants = filtered_assistants[offset : offset + limit]
+        cur = offset + limit if len(filtered_assistants) > offset + limit else None
 
         async def assistant_iterator() -> AsyncIterator[Assistant]:
             for assistant in paginated_assistants:
                 yield assistant
 
-        return assistant_iterator(), total_count
+        return assistant_iterator(), cur
 
     @staticmethod
     async def get(
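
Note: the second element of the search return value changes meaning, from a total count to an offset-style cursor: offset + limit when more results remain, otherwise None. A quick worked check of that arithmetic (hypothetical sizes):

    # 25 matches, offset=0,  limit=10 -> next page starts at 10
    assert (0 + 10 if 25 > 0 + 10 else None) == 10
    # 25 matches, offset=20, limit=10 -> page covers 20..24, so no cursor
    assert (20 + 10 if 25 > 20 + 10 else None) is None
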
@@ -724,7 +722,7 @@ class Threads(Authenticated):
         else:
             # Default sorting by created_at in descending order
             sorted_threads = sorted(
-                filtered_threads, key=lambda x: x["created_at"], reverse=True
+                filtered_threads, key=lambda x: x["updated_at"], reverse=True
             )
 
         # Apply limit and offset
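
Note: only the sort key changed here; the "Default sorting by created_at" comment above it now disagrees with the updated_at key the code actually uses.
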
@@ -1453,7 +1451,7 @@
         conn: InMemConnectionProto,
         *,
         config: Config,
-        limit: int = 10,
+        limit: int = 1,
         before: str | Checkpoint | None = None,
         metadata: MetadataInput = None,
         ctx: Auth.types.BaseAuthContext | None = None,
@@ -1628,7 +1626,7 @@ class Runs(Authenticated):
 
         stream_manager = get_stream_manager()
         # Get queue for this run
-        queue = await Runs.Stream.subscribe(run_id)
+        queue = await stream_manager.add_control_queue(run_id)
         async with SimpleTaskGroup(cancel=True, taskgroup_name="Runs.enter") as tg:
             done = ValueEvent()
@@ -1636,16 +1634,21 @@
 
             # Give done event to caller
             yield done
-            # Signal done to all subscribers
+            # Store the control message for late subscribers
             control_message = Message(
                 topic=f"run:{run_id}:control".encode(), data=b"done"
             )
-
-            # Store the control message for late subscribers
             await stream_manager.put(run_id, control_message)
-            stream_manager.control_queues[run_id].append(control_message)
-            # Clean up this queue
-            await stream_manager.remove_queue(run_id, queue)
+
+            # Signal done to all subscribers
+            stream_message = Message(
+                topic=f"run:{run_id}:stream".encode(),
+                data={"event": "control", "message": b"done"},
+            )
+            await stream_manager.put(run_id, stream_message)
+
+            # Remove the queue
+            await stream_manager.remove_control_queue(run_id, queue)
 
     @staticmethod
     async def sweep(conn: InMemConnectionProto) -> list[UUID]:
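
Note: completion is now signaled on two paths: the run:{run_id}:control topic, consumed by cancellation listeners, and an enveloped event on the shared run:{run_id}:stream topic, consumed by stream subscribers. A consumer-side sketch of detecting the new envelope, mirroring the checks join_stream performs below:

    data = json_loads(message.data)
    if data["event"] == "control" and data["message"] == b"done":
        ...  # run finished; stop streaming
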
@@ -1981,6 +1984,16 @@
             return Fragment(
                 orjson.dumps({"__error__": orjson.Fragment(thread["error"])})
             )
+        if thread["status"] == "interrupted":
+            # Get an interrupt for the thread. There is the case where there are multiple interrupts for the same run and we may not show the same
+            # interrupt, but we'll always show one. Long term we should show all of them.
+            try:
+                interrupt_map = thread["interrupts"]
+                interrupt = [next(iter(interrupt_map.values()))[0]]
+                return Fragment(orjson.dumps({"__interrupt__": interrupt}))
+            except Exception:
+                # No interrupt, but status is interrupted from a before/after block. Default back to values.
+                pass
         return thread["values"]
 
     @staticmethod
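
Note: interrupted threads now surface one interrupt in place of their values. A small sketch of the lookup above, assuming thread["interrupts"] maps task ids to lists of interrupts (the exact shape is not shown in this diff):

    interrupt_map = {"task-1": [{"value": "approve?", "when": "during"}]}
    interrupt = [next(iter(interrupt_map.values()))[0]]
    # -> [{"value": "approve?", "when": "during"}]
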
@@ -2201,8 +2214,6 @@
     @staticmethod
     async def subscribe(
         run_id: UUID,
-        *,
-        stream_mode: StreamMode | None = None,
     ) -> asyncio.Queue:
         """Subscribe to the run stream, returning a queue."""
         stream_manager = get_stream_manager()
@@ -2222,20 +2233,18 @@
         ignore_404: bool = False,
         cancel_on_disconnect: bool = False,
         stream_channel: asyncio.Queue | None = None,
-        stream_mode: list[StreamMode] | StreamMode,
+        stream_mode: list[StreamMode] | StreamMode | None = None,
         last_event_id: str | None = None,
         ctx: Auth.types.BaseAuthContext | None = None,
     ) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
         """Stream the run output."""
         from langgraph_api.asyncio import create_task
-
-        if stream_mode and not isinstance(stream_mode, list):
-            stream_mode = [stream_mode]
+        from langgraph_api.serde import json_loads
 
         queue = (
             stream_channel
             if stream_channel
-            else await Runs.Stream.subscribe(run_id, stream_mode=stream_mode)
+            else await Runs.Stream.subscribe(run_id)
         )
 
         try:
@@ -2256,53 +2265,71 @@ class Runs(Authenticated):
                 )
             )
             run = await Runs.get(conn, run_id, thread_id=thread_id, ctx=ctx)
-            channel_prefix = f"run:{run_id}:stream:"
-            len_prefix = len(channel_prefix.encode())
 
             for message in get_stream_manager().restore_messages(
                 run_id, last_event_id
             ):
-                topic, data, id = message.topic, message.data, message.id
-                if topic.decode() == f"run:{run_id}:control":
-                    if data == b"done":
+                data, id = message.data, message.id
+
+                data = json_loads(data)
+                mode = data["event"]
+                message = data["message"]
+
+                if mode == "control":
+                    if message == b"done":
                         return
-                else:
-                    mode = topic[len_prefix:]
-                    if mode == b"updates" and "updates" not in stream_mode:
-                        continue
-                    else:
-                        yield mode, data, id
-                        logger.debug(
-                            "Replayed run event",
-                            run_id=str(run_id),
-                            message_id=id,
-                            stream_mode=mode,
-                            data=data,
+                elif (
+                    not stream_mode
+                    or mode in stream_mode
+                    or (
+                        (
+                            "messages" in stream_mode
+                            or "messages-tuple" in stream_mode
                         )
+                        and mode.startswith("messages")
+                    )
+                ):
+                    yield mode.encode(), base64.b64decode(message), id
+                    logger.debug(
+                        "Replayed run event",
+                        run_id=str(run_id),
+                        message_id=id,
+                        stream_mode=mode,
+                        data=data,
+                    )
 
             while True:
                 try:
                     # Wait for messages with a timeout
                     message = await asyncio.wait_for(queue.get(), timeout=0.5)
-                    topic, data, id = message.topic, message.data, message.id
+                    data, id = message.data, message.id
 
-                    if topic.decode() == f"run:{run_id}:control":
-                        if data == b"done":
+                    data = json_loads(data)
+                    mode = data["event"]
+                    message = data["message"]
+
+                    if mode == "control":
+                        if message == b"done":
                             break
-                    else:
-                        # Extract mode from topic
-                        mode = topic[len_prefix:]
-                        if mode == b"updates" and "updates" not in stream_mode:
-                            continue
-                        else:
-                            yield mode, data, id
-                            logger.debug(
-                                "Streamed run event",
-                                run_id=str(run_id),
-                                stream_mode=mode,
-                                message_id=id,
-                                data=data,
+                    elif (
+                        not stream_mode
+                        or mode in stream_mode
+                        or (
+                            (
+                                "messages" in stream_mode
+                                or "messages-tuple" in stream_mode
                             )
+                            and mode.startswith("messages")
+                        )
+                    ):
+                        yield mode.encode(), base64.b64decode(message), id
+                        logger.debug(
+                            "Streamed run event",
+                            run_id=str(run_id),
+                            stream_mode=mode,
+                            message_id=id,
+                            data=message,
+                        )
                 except TimeoutError:
                     # Check if the run is still pending
                     run_iter = await Runs.get(
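
The replay loop and the live loop above now apply the same mode filter to the decoded envelope. Factored into a standalone helper for readability (a sketch, not code from the package):

    def should_emit(mode: str, stream_mode: list[str] | None) -> bool:
        if not stream_mode:
            return True  # no filter requested: emit every event
        if mode in stream_mode:
            return True
        # "messages" / "messages-tuple" subscribers also receive the
        # namespaced "messages/..." events
        wants_messages = "messages" in stream_mode or "messages-tuple" in stream_mode
        return wants_messages and mode.startswith("messages")
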
@@ -2342,12 +2369,20 @@
         resumable: bool = False,
     ) -> None:
         """Publish a message to all subscribers of the run stream."""
-        topic = f"run:{run_id}:stream:{event}".encode()
+        from langgraph_api.serde import json_dumpb
+
+        topic = f"run:{run_id}:stream".encode()
 
         stream_manager = get_stream_manager()
         # Send to all queues subscribed to this run_id
+        payload = json_dumpb(
+            {
+                "event": event,
+                "message": message,
+            }
+        )
         await stream_manager.put(
-            run_id, Message(topic=topic, data=message), resumable
+            run_id, Message(topic=topic, data=payload), resumable
         )
 
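
Note: events lose their per-mode topics (run:{run_id}:stream:{event}); the mode now travels inside a JSON envelope on a single topic. Since join_stream base64-decodes the message field after json_loads, json_dumpb presumably base64-encodes bytes payloads; a stdlib round-trip sketch of that assumption:

    import base64
    import json

    payload = json.dumps(
        {"event": "values", "message": base64.b64encode(b'{"x": 1}').decode()}
    ).encode()
    decoded = json.loads(payload)
    assert decoded["event"] == "values"
    assert base64.b64decode(decoded["message"]) == b'{"x": 1}'
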
@@ -2356,20 +2391,18 @@ async def listen_for_cancellation(queue: asyncio.Queue, run_id: UUID, done: Valu
     from langgraph_api.errors import UserInterrupt, UserRollback
 
     stream_manager = get_stream_manager()
-    control_key = f"run:{run_id}:control"
 
-    if existing_queue := stream_manager.control_queues.get(run_id):
-        for message in existing_queue:
-            payload = message.data
-            if payload == b"rollback":
-                done.set(UserRollback())
-            elif payload == b"interrupt":
-                done.set(UserInterrupt())
+    if control_key := stream_manager.get_control_key(run_id):
+        payload = control_key.data
+        if payload == b"rollback":
+            done.set(UserRollback())
+        elif payload == b"interrupt":
+            done.set(UserInterrupt())
 
     while not done.is_set():
         try:
             # This task gets cancelled when Runs.enter exits anyway,
-            # so we can have a pretty length timeout here
+            # so we can have a pretty lengthy timeout here
             message = await asyncio.wait_for(queue.get(), timeout=240)
             payload = message.data
             if payload == b"rollback":
@@ -2379,10 +2412,6 @@
             elif payload == b"done":
                 done.set()
                 break
-
-            # Store control messages for late subscribers
-            if message.topic.decode() == control_key:
-                stream_manager.control_queues[run_id].append(message)
         except TimeoutError:
             break
 
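Note: late-arriving cancellation listeners now replay only the most recent control message via get_control_key, rather than walking a stored per-run control history, and consumed control messages are no longer re-appended for later subscribers.
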
--- a/langgraph_runtime_inmem-0.6.7.dist-info/METADATA
+++ b/langgraph_runtime_inmem-0.6.10.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: langgraph-runtime-inmem
-Version: 0.6.7
+Version: 0.6.10
 Summary: Inmem implementation for the LangGraph API server.
 Author-email: Will Fu-Hinthorn <will@langchain.dev>
 License: Elastic-2.0
--- a/langgraph_runtime_inmem-0.6.7.dist-info/RECORD
+++ b/langgraph_runtime_inmem-0.6.10.dist-info/RECORD
@@ -1,13 +1,13 @@
-langgraph_runtime_inmem/__init__.py,sha256=IO8abGzSnjJCCIfUCYcUfMMQUyqIIQ6oHu2-2yJJCg0,310
+langgraph_runtime_inmem/__init__.py,sha256=3R6DOc_wOQRDAv1aNNMMGjRofpJi6B_qIBkZ373AG44,311
 langgraph_runtime_inmem/checkpoint.py,sha256=nc1G8DqVdIu-ibjKTqXfbPfMbAsKjPObKqegrSzo6Po,4432
 langgraph_runtime_inmem/database.py,sha256=G_6L2khpRDSpS2Vs_SujzHayODcwG5V2IhFP7LLBXgw,6349
-langgraph_runtime_inmem/inmem_stream.py,sha256=65z_2mBNJ0-yJsXWnlYwRc71039_y6Sa0MN8fL_U3Ko,4581
+langgraph_runtime_inmem/inmem_stream.py,sha256=UWk1srLF44HZPPbRdArGGhsy0MY0UOJKSIxBSO7Hosc,5138
 langgraph_runtime_inmem/lifespan.py,sha256=t0w2MX2dGxe8yNtSX97Z-d2pFpllSLS4s1rh2GJDw5M,3557
 langgraph_runtime_inmem/metrics.py,sha256=HhO0RC2bMDTDyGBNvnd2ooLebLA8P1u5oq978Kp_nAA,392
-langgraph_runtime_inmem/ops.py,sha256=kx0f6pKqYTypbcoMByM6L0YBF4O0b1TEYpOqdBnJMao,89354
+langgraph_runtime_inmem/ops.py,sha256=CSH5vi7AsaeaWSngZ_DCtsPY-M7ah3cz8acTVi_UbUw,90559
 langgraph_runtime_inmem/queue.py,sha256=nqfgz7j_Jkh5Ek5-RsHB2Uvwbxguu9IUPkGXIxvFPns,10037
 langgraph_runtime_inmem/retry.py,sha256=XmldOP4e_H5s264CagJRVnQMDFcEJR_dldVR1Hm5XvM,763
 langgraph_runtime_inmem/store.py,sha256=rTfL1JJvd-j4xjTrL8qDcynaWF6gUJ9-GDVwH0NBD_I,3506
-langgraph_runtime_inmem-0.6.7.dist-info/METADATA,sha256=83cPJjL678QXBM5qzhYS5CRQZUObUqyGjQRLMI-1TBA,565
-langgraph_runtime_inmem-0.6.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-langgraph_runtime_inmem-0.6.7.dist-info/RECORD,,
+langgraph_runtime_inmem-0.6.10.dist-info/METADATA,sha256=5YlbIvGy2fpU9gydQgbaSfu34loSdodD2PUtPlcn_1c,566
+langgraph_runtime_inmem-0.6.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+langgraph_runtime_inmem-0.6.10.dist-info/RECORD,,