langgraph-runtime-inmem 0.12.1__py3-none-any.whl → 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,7 @@ from langgraph_runtime_inmem import (
9
9
  store,
10
10
  )
11
11
 
12
- __version__ = "0.12.1"
12
+ __version__ = "0.14.0"
13
13
  __all__ = [
14
14
  "ops",
15
15
  "database",
@@ -108,9 +108,10 @@ class StreamManager:
108
108
  thread_id = _ensure_uuid(thread_id)
109
109
 
110
110
  message.id = _generate_ms_seq_id().encode()
111
+ # For resumable run streams, embed the generated message ID into the frame
112
+ topic = message.topic.decode()
111
113
  if resumable:
112
114
  self.message_stores[thread_id][run_id].append(message)
113
- topic = message.topic.decode()
114
115
  if "control" in topic:
115
116
  self.control_keys[thread_id][run_id] = message
116
117
  queues = self.control_queues[thread_id][run_id]
@@ -62,9 +62,9 @@ if typing.TYPE_CHECKING:
62
62
  ThreadUpdateResponse,
63
63
  )
64
64
  from langgraph_api.schema import Interrupt as InterruptSchema
65
- from langgraph_api.serde import Fragment
66
65
  from langgraph_api.utils import AsyncConnectionProto
67
66
 
67
+ StreamHandler = ContextQueue
68
68
 
69
69
  logger = structlog.stdlib.get_logger(__name__)
70
70
 
@@ -1720,6 +1720,9 @@ class Threads(Authenticated):
1720
1720
  stream_modes: list[ThreadStreamMode],
1721
1721
  ) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
1722
1722
  """Stream the thread output."""
1723
+ from langgraph_api.utils.stream_codec import (
1724
+ decode_stream_message,
1725
+ )
1723
1726
 
1724
1727
  def should_filter_event(event_name: str, message_bytes: bytes) -> bool:
1725
1728
  """Check if an event should be filtered out based on stream_modes."""
@@ -1738,8 +1741,6 @@ class Threads(Authenticated):
1738
1741
  pass
1739
1742
  return True
1740
1743
 
1741
- from langgraph_api.serde import json_loads
1742
-
1743
1744
  stream_manager = get_stream_manager()
1744
1745
  seen_runs: set[UUID] = set()
1745
1746
  created_queues: list[tuple[UUID, asyncio.Queue]] = []
@@ -1768,35 +1769,26 @@ class Threads(Authenticated):
1768
1769
 
1769
1770
  # Yield sorted events
1770
1771
  for message, run_id in all_events:
1771
- data = json_loads(message.data)
1772
- event_name = data["event"]
1773
- message_content = data["message"]
1772
+ decoded = decode_stream_message(
1773
+ message.data, channel=message.topic
1774
+ )
1775
+ event_bytes = decoded.event_bytes
1776
+ message_bytes = decoded.message_bytes
1774
1777
 
1775
- if event_name == "control":
1776
- if message_content == b"done":
1778
+ if event_bytes == b"control":
1779
+ if message_bytes == b"done":
1777
1780
  event_bytes = b"metadata"
1778
1781
  message_bytes = orjson.dumps(
1779
1782
  {"status": "run_done", "run_id": run_id}
1780
1783
  )
1781
- # Filter events based on stream_modes
1782
- if not should_filter_event(
1783
- "metadata", message_bytes
1784
- ):
1785
- yield (
1786
- event_bytes,
1787
- message_bytes,
1788
- message.id,
1789
- )
1790
- else:
1791
- event_bytes = event_name.encode()
1792
- message_bytes = base64.b64decode(message_content)
1793
- # Filter events based on stream_modes
1794
- if not should_filter_event(event_name, message_bytes):
1795
- yield (
1796
- event_bytes,
1797
- message_bytes,
1798
- message.id,
1799
- )
1784
+ if not should_filter_event(
1785
+ event_bytes.decode("utf-8"), message_bytes
1786
+ ):
1787
+ yield (
1788
+ event_bytes,
1789
+ message_bytes,
1790
+ message.id,
1791
+ )
1800
1792
 
1801
1793
  # Listen for live messages from all queues
1802
1794
  while True:
@@ -1813,40 +1805,27 @@ class Threads(Authenticated):
1813
1805
  message = await asyncio.wait_for(
1814
1806
  queue.get(), timeout=0.2
1815
1807
  )
1816
- data = json_loads(message.data)
1817
- event_name = data["event"]
1818
- message_content = data["message"]
1819
-
1820
- if event_name == "control":
1821
- if message_content == b"done":
1822
- # Extract run_id from topic
1823
- topic = message.topic.decode()
1824
- run_id = topic.split("run:")[1].split(":")[0]
1825
- event_bytes = b"metadata"
1826
- message_bytes = orjson.dumps(
1827
- {"status": "run_done", "run_id": run_id}
1828
- )
1829
- # Filter events based on stream_modes
1830
- if not should_filter_event(
1831
- "metadata", message_bytes
1832
- ):
1833
- yield (
1834
- event_bytes,
1835
- message_bytes,
1836
- message.id,
1837
- )
1838
- else:
1839
- event_bytes = event_name.encode()
1840
- message_bytes = base64.b64decode(message_content)
1841
- # Filter events based on stream_modes
1808
+ decoded = decode_stream_message(
1809
+ message.data, channel=message.topic
1810
+ )
1811
+ event = decoded.event_bytes
1812
+ event_name = event.decode("utf-8")
1813
+ payload = decoded.message_bytes
1814
+
1815
+ if event == b"control" and payload == b"done":
1816
+ topic = message.topic.decode()
1817
+ run_id = topic.split("run:")[1].split(":")[0]
1818
+ meta_event = b"metadata"
1819
+ meta_payload = orjson.dumps(
1820
+ {"status": "run_done", "run_id": run_id}
1821
+ )
1842
1822
  if not should_filter_event(
1843
- event_name, message_bytes
1823
+ "metadata", meta_payload
1844
1824
  ):
1845
- yield (
1846
- event_bytes,
1847
- message_bytes,
1848
- message.id,
1849
- )
1825
+ yield (meta_event, meta_payload, message.id)
1826
+ else:
1827
+ if not should_filter_event(event_name, payload):
1828
+ yield (event, payload, message.id)
1850
1829
 
1851
1830
  except TimeoutError:
1852
1831
  continue
@@ -1882,18 +1861,12 @@ class Threads(Authenticated):
1882
1861
  message: bytes,
1883
1862
  ) -> None:
1884
1863
  """Publish a thread-level event to the thread stream."""
1885
- from langgraph_api.serde import json_dumpb
1864
+ from langgraph_api.utils.stream_codec import STREAM_CODEC
1886
1865
 
1887
1866
  topic = f"thread:{thread_id}:stream".encode()
1888
1867
 
1889
1868
  stream_manager = get_stream_manager()
1890
- # Send to thread stream topic
1891
- payload = json_dumpb(
1892
- {
1893
- "event": event,
1894
- "message": message,
1895
- }
1896
- )
1869
+ payload = STREAM_CODEC.encode(event, message)
1897
1870
  await stream_manager.put_thread(
1898
1871
  str(thread_id), Message(topic=topic, data=payload)
1899
1872
  )
@@ -2065,6 +2038,7 @@ class Runs(Authenticated):
2065
2038
  This method should be called as a context manager by a worker executing a run.
2066
2039
  """
2067
2040
  from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent
2041
+ from langgraph_api.utils.stream_codec import STREAM_CODEC
2068
2042
 
2069
2043
  stream_manager = get_stream_manager()
2070
2044
  # Get control queue for this run (normal queue is created during run creation)
@@ -2084,10 +2058,10 @@ class Runs(Authenticated):
2084
2058
  )
2085
2059
  await stream_manager.put(run_id, thread_id, control_message)
2086
2060
 
2087
- # Signal done to all subscribers
2061
+ # Signal done to all subscribers using stream codec
2088
2062
  stream_message = Message(
2089
2063
  topic=f"run:{run_id}:stream".encode(),
2090
- data={"event": "control", "message": b"done"},
2064
+ data=STREAM_CODEC.encode("control", b"done"),
2091
2065
  )
2092
2066
  await stream_manager.put(
2093
2067
  run_id, thread_id, stream_message, resumable=resumable
@@ -2399,68 +2373,6 @@ class Runs(Authenticated):
2399
2373
 
2400
2374
  return _yield_deleted()
2401
2375
 
2402
- @staticmethod
2403
- async def join(
2404
- run_id: UUID,
2405
- *,
2406
- thread_id: UUID,
2407
- ctx: Auth.types.BaseAuthContext | None = None,
2408
- ) -> Fragment:
2409
- """Wait for a run to complete. If already done, return immediately.
2410
-
2411
- Returns:
2412
- the final state of the run.
2413
- """
2414
- from langgraph_api.serde import Fragment
2415
- from langgraph_api.utils import fetchone
2416
-
2417
- async with connect() as conn:
2418
- # Validate ownership
2419
- thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
2420
- await fetchone(thread_iter)
2421
- last_chunk: bytes | None = None
2422
- # wait for the run to complete
2423
- # Rely on this join's auth
2424
- async with await Runs.Stream.subscribe(run_id, thread_id) as sub:
2425
- async for mode, chunk, _ in Runs.Stream.join(
2426
- run_id,
2427
- thread_id=thread_id,
2428
- ctx=ctx,
2429
- ignore_404=True,
2430
- stream_channel=sub,
2431
- stream_mode=["values", "updates", "error"],
2432
- ):
2433
- if mode == b"values":
2434
- last_chunk = chunk
2435
- elif mode == b"updates" and b"__interrupt__" in chunk:
2436
- last_chunk = chunk
2437
- elif mode == b"error":
2438
- last_chunk = orjson.dumps({"__error__": orjson.Fragment(chunk)})
2439
- # if we received a final chunk, return it
2440
- if last_chunk is not None:
2441
- # ie. if the run completed while we were waiting for it
2442
- return Fragment(last_chunk)
2443
- else:
2444
- # otherwise, the run had already finished, so fetch the state from thread
2445
- async with connect() as conn:
2446
- thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
2447
- thread = await fetchone(thread_iter)
2448
- if thread["status"] == "error":
2449
- return Fragment(
2450
- orjson.dumps({"__error__": orjson.Fragment(thread["error"])})
2451
- )
2452
- if thread["status"] == "interrupted":
2453
- # Get an interrupt for the thread. There is the case where there are multiple interrupts for the same run and we may not show the same
2454
- # interrupt, but we'll always show one. Long term we should show all of them.
2455
- try:
2456
- interrupt_map = thread["interrupts"]
2457
- interrupt = [next(iter(interrupt_map.values()))[0]]
2458
- return Fragment(orjson.dumps({"__interrupt__": interrupt}))
2459
- except Exception:
2460
- # No interrupt, but status is interrupted from a before/after block. Default back to values.
2461
- pass
2462
- return thread["values"]
2463
-
2464
2376
  @staticmethod
2465
2377
  async def cancel(
2466
2378
  conn: InMemConnectionProto | AsyncConnectionProto,
@@ -2720,39 +2632,28 @@ class Runs(Authenticated):
2720
2632
  ) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
2721
2633
  """Stream the run output."""
2722
2634
  from langgraph_api.asyncio import create_task
2723
- from langgraph_api.serde import json_loads
2635
+ from langgraph_api.serde import json_dumpb
2636
+ from langgraph_api.utils.stream_codec import decode_stream_message
2724
2637
 
2725
2638
  queue = stream_channel
2726
2639
  try:
2727
2640
  async with connect() as conn:
2728
- filters = await Runs.handle_event(
2729
- ctx,
2730
- "read",
2731
- Auth.types.ThreadsRead(thread_id=thread_id),
2732
- )
2733
- if filters:
2734
- thread = await Threads._get_with_filters(
2735
- cast(InMemConnectionProto, conn), thread_id, filters
2736
- )
2737
- if not thread:
2738
- raise WrappedHTTPException(
2739
- HTTPException(
2740
- status_code=404, detail="Thread not found"
2741
- )
2742
- )
2641
+ try:
2642
+ await Runs.Stream.check_run_stream_auth(run_id, thread_id, ctx)
2643
+ except HTTPException as e:
2644
+ raise WrappedHTTPException(e) from None
2743
2645
  run = await Runs.get(conn, run_id, thread_id=thread_id, ctx=ctx)
2744
2646
 
2745
2647
  for message in get_stream_manager().restore_messages(
2746
2648
  run_id, thread_id, last_event_id
2747
2649
  ):
2748
2650
  data, id = message.data, message.id
2749
-
2750
- data = json_loads(data)
2751
- mode = data["event"]
2752
- message = data["message"]
2651
+ decoded = decode_stream_message(data, channel=message.topic)
2652
+ mode = decoded.event_bytes.decode("utf-8")
2653
+ payload = decoded.message_bytes
2753
2654
 
2754
2655
  if mode == "control":
2755
- if message == b"done":
2656
+ if payload == b"done":
2756
2657
  return
2757
2658
  elif (
2758
2659
  not stream_mode
@@ -2765,7 +2666,7 @@ class Runs(Authenticated):
2765
2666
  and mode.startswith("messages")
2766
2667
  )
2767
2668
  ):
2768
- yield mode.encode(), base64.b64decode(message), id
2669
+ yield mode.encode(), payload, id
2769
2670
  logger.debug(
2770
2671
  "Replayed run event",
2771
2672
  run_id=str(run_id),
@@ -2779,13 +2680,12 @@ class Runs(Authenticated):
2779
2680
  # Wait for messages with a timeout
2780
2681
  message = await asyncio.wait_for(queue.get(), timeout=0.5)
2781
2682
  data, id = message.data, message.id
2782
-
2783
- data = json_loads(data)
2784
- mode = data["event"]
2785
- message = data["message"]
2683
+ decoded = decode_stream_message(data, channel=message.topic)
2684
+ mode = decoded.event_bytes.decode("utf-8")
2685
+ payload = decoded.message_bytes
2786
2686
 
2787
2687
  if mode == "control":
2788
- if message == b"done":
2688
+ if payload == b"done":
2789
2689
  break
2790
2690
  elif (
2791
2691
  not stream_mode
@@ -2798,13 +2698,13 @@ class Runs(Authenticated):
2798
2698
  and mode.startswith("messages")
2799
2699
  )
2800
2700
  ):
2801
- yield mode.encode(), base64.b64decode(message), id
2701
+ yield mode.encode(), payload, id
2802
2702
  logger.debug(
2803
2703
  "Streamed run event",
2804
2704
  run_id=str(run_id),
2805
2705
  stream_mode=mode,
2806
2706
  message_id=id,
2807
- data=message,
2707
+ data=payload,
2808
2708
  )
2809
2709
  except TimeoutError:
2810
2710
  # Check if the run is still pending
@@ -2818,8 +2718,10 @@ class Runs(Authenticated):
2818
2718
  elif run is None:
2819
2719
  yield (
2820
2720
  b"error",
2821
- HTTPException(
2822
- status_code=404, detail="Run not found"
2721
+ json_dumpb(
2722
+ HTTPException(
2723
+ status_code=404, detail="Run not found"
2724
+ )
2823
2725
  ),
2824
2726
  None,
2825
2727
  )
@@ -2836,6 +2738,25 @@ class Runs(Authenticated):
2836
2738
  stream_manager = get_stream_manager()
2837
2739
  await stream_manager.remove_queue(run_id, thread_id, queue)
2838
2740
 
2741
+ @staticmethod
2742
+ async def check_run_stream_auth(
2743
+ run_id: UUID,
2744
+ thread_id: UUID,
2745
+ ctx: Auth.types.BaseAuthContext | None = None,
2746
+ ) -> None:
2747
+ async with connect() as conn:
2748
+ filters = await Runs.handle_event(
2749
+ ctx,
2750
+ "read",
2751
+ Auth.types.ThreadsRead(thread_id=thread_id),
2752
+ )
2753
+ if filters:
2754
+ thread = await Threads._get_with_filters(
2755
+ cast(InMemConnectionProto, conn), thread_id, filters
2756
+ )
2757
+ if not thread:
2758
+ raise HTTPException(status_code=404, detail="Thread not found")
2759
+
2839
2760
  @staticmethod
2840
2761
  async def publish(
2841
2762
  run_id: UUID | str,
@@ -2846,18 +2767,13 @@ class Runs(Authenticated):
2846
2767
  resumable: bool = False,
2847
2768
  ) -> None:
2848
2769
  """Publish a message to all subscribers of the run stream."""
2849
- from langgraph_api.serde import json_dumpb
2770
+ from langgraph_api.utils.stream_codec import STREAM_CODEC
2850
2771
 
2851
2772
  topic = f"run:{run_id}:stream".encode()
2852
2773
 
2853
2774
  stream_manager = get_stream_manager()
2854
- # Send to all queues subscribed to this run_id
2855
- payload = json_dumpb(
2856
- {
2857
- "event": event,
2858
- "message": message,
2859
- }
2860
- )
2775
+ # Send to all queues subscribed to this run_id using protocol frame
2776
+ payload = STREAM_CODEC.encode(event, message)
2861
2777
  await stream_manager.put(
2862
2778
  run_id, thread_id, Message(topic=topic, data=payload), resumable
2863
2779
  )
@@ -3037,6 +2953,7 @@ async def _empty_generator():
3037
2953
 
3038
2954
 
3039
2955
  __all__ = [
2956
+ "StreamHandler",
3040
2957
  "Assistants",
3041
2958
  "Crons",
3042
2959
  "Runs",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langgraph-runtime-inmem
3
- Version: 0.12.1
3
+ Version: 0.14.0
4
4
  Summary: Inmem implementation for the LangGraph API server.
5
5
  Author-email: Will Fu-Hinthorn <will@langchain.dev>
6
6
  License: Elastic-2.0
@@ -1,13 +1,13 @@
1
- langgraph_runtime_inmem/__init__.py,sha256=tv9_9neTzC-cY8ATa9lnqpF4WJ-YZW3fYhQ9rSL4Yu8,311
1
+ langgraph_runtime_inmem/__init__.py,sha256=csu7K0Iyy69kpS21MCa9q3MkfeJLSBXmsT02eK_hGXc,311
2
2
  langgraph_runtime_inmem/checkpoint.py,sha256=nc1G8DqVdIu-ibjKTqXfbPfMbAsKjPObKqegrSzo6Po,4432
3
3
  langgraph_runtime_inmem/database.py,sha256=QgaA_WQo1IY6QioYd8r-e6-0B0rnC5anS0muIEJWby0,6364
4
- langgraph_runtime_inmem/inmem_stream.py,sha256=utL1OlOJsy6VDkSGAA6eX9nETreZlM6K6nhfNoubmRQ,9011
4
+ langgraph_runtime_inmem/inmem_stream.py,sha256=PFLWbsxU8RqbT5mYJgNk6v5q6TWJRIY1hkZWhJF8nkI,9094
5
5
  langgraph_runtime_inmem/lifespan.py,sha256=tngIYHMhDwTFd2zgpq9CZOxcBLONYYnkhwv2d2T5WWQ,3614
6
6
  langgraph_runtime_inmem/metrics.py,sha256=HhO0RC2bMDTDyGBNvnd2ooLebLA8P1u5oq978Kp_nAA,392
7
- langgraph_runtime_inmem/ops.py,sha256=593xx2A5E7y2TY6nLpbkFSsODH6guwm1y9z-ars-seU,111327
7
+ langgraph_runtime_inmem/ops.py,sha256=63uV88PijGnNxzgWGL_SljeXIeHd8dAwowBrWi9X4Xo,107645
8
8
  langgraph_runtime_inmem/queue.py,sha256=33qfFKPhQicZ1qiibllYb-bTFzUNSN2c4bffPACP5es,9952
9
9
  langgraph_runtime_inmem/retry.py,sha256=XmldOP4e_H5s264CagJRVnQMDFcEJR_dldVR1Hm5XvM,763
10
10
  langgraph_runtime_inmem/store.py,sha256=rTfL1JJvd-j4xjTrL8qDcynaWF6gUJ9-GDVwH0NBD_I,3506
11
- langgraph_runtime_inmem-0.12.1.dist-info/METADATA,sha256=faLaXWGpAJnAK6Z5XS0TmGBtafg7Cef16_nc_Viw8yg,566
12
- langgraph_runtime_inmem-0.12.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
- langgraph_runtime_inmem-0.12.1.dist-info/RECORD,,
11
+ langgraph_runtime_inmem-0.14.0.dist-info/METADATA,sha256=jegaYI5exlmydXtt4oxMbgFBCIrKaV7HawwHKNr2MrU,566
12
+ langgraph_runtime_inmem-0.14.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
+ langgraph_runtime_inmem-0.14.0.dist-info/RECORD,,