langgraph-runtime-inmem 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,7 @@ from langgraph_runtime_inmem import (
9
9
  store,
10
10
  )
11
11
 
12
- __version__ = "0.9.0"
12
+ __version__ = "0.11.0"
13
13
  __all__ = [
14
14
  "ops",
15
15
  "database",
@@ -58,6 +58,7 @@ class StreamManager:
58
58
  ) # Dict[str, List[asyncio.Queue]]
59
59
  self.control_keys = defaultdict(lambda: defaultdict())
60
60
  self.control_queues = defaultdict(lambda: defaultdict(list))
61
+ self.thread_streams = defaultdict(list)
61
62
 
62
63
  self.message_stores = defaultdict(
63
64
  lambda: defaultdict(list[Message])
@@ -95,7 +96,7 @@ class StreamManager:
95
96
 
96
97
  async def put(
97
98
  self,
98
- run_id: UUID | str,
99
+ run_id: UUID | str | None,
99
100
  thread_id: UUID | str | None,
100
101
  message: Message,
101
102
  resumable: bool = False,
@@ -121,6 +122,20 @@ class StreamManager:
121
122
  if isinstance(result, Exception):
122
123
  logger.exception(f"Failed to put message in queue: {result}")
123
124
 
125
+ async def put_thread(
126
+ self,
127
+ thread_id: UUID | str,
128
+ message: Message,
129
+ ) -> None:
130
+ thread_id = _ensure_uuid(thread_id)
131
+ message.id = _generate_ms_seq_id().encode()
132
+ queues = self.thread_streams[thread_id]
133
+ coros = [queue.put(message) for queue in queues]
134
+ results = await asyncio.gather(*coros, return_exceptions=True)
135
+ for result in results:
136
+ if isinstance(result, Exception):
137
+ logger.exception(f"Failed to put message in queue: {result}")
138
+
124
139
  async def add_queue(
125
140
  self, run_id: UUID | str, thread_id: UUID | str | None
126
141
  ) -> asyncio.Queue:
@@ -145,6 +160,12 @@ class StreamManager:
145
160
  self.control_queues[thread_id][run_id].append(queue)
146
161
  return queue
147
162
 
163
+ async def add_thread_stream(self, thread_id: UUID | str) -> asyncio.Queue:
164
+ thread_id = _ensure_uuid(thread_id)
165
+ queue = ContextQueue()
166
+ self.thread_streams[thread_id].append(queue)
167
+ return queue
168
+
148
169
  async def remove_queue(
149
170
  self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
150
171
  ):
@@ -29,6 +29,7 @@ from langgraph_runtime_inmem.checkpoint import Checkpointer
29
29
  from langgraph_runtime_inmem.database import InMemConnectionProto, connect
30
30
  from langgraph_runtime_inmem.inmem_stream import (
31
31
  THREADLESS_KEY,
32
+ ContextQueue,
32
33
  Message,
33
34
  get_stream_manager,
34
35
  )
@@ -58,6 +59,7 @@ if typing.TYPE_CHECKING:
58
59
  Thread,
59
60
  ThreadSelectField,
60
61
  ThreadStatus,
62
+ ThreadStreamMode,
61
63
  ThreadUpdateResponse,
62
64
  )
63
65
  from langgraph_api.schema import Interrupt as InterruptSchema
@@ -738,6 +740,7 @@ class Threads(Authenticated):
738
740
  async def search(
739
741
  conn: InMemConnectionProto,
740
742
  *,
743
+ ids: list[str] | list[UUID] | None = None,
741
744
  metadata: MetadataInput,
742
745
  values: MetadataInput,
743
746
  status: ThreadStatus | None,
@@ -765,7 +768,19 @@ class Threads(Authenticated):
765
768
  )
766
769
 
767
770
  # Apply filters
771
+ id_set: set[UUID] | None = None
772
+ if ids:
773
+ id_set = set()
774
+ for i in ids:
775
+ try:
776
+ id_set.add(_ensure_uuid(i))
777
+ except Exception:
778
+ raise HTTPException(
779
+ status_code=400, detail="Invalid thread ID " + str(i)
780
+ ) from None
768
781
  for thread in threads:
782
+ if id_set is not None and thread.get("thread_id") not in id_set:
783
+ continue
769
784
  if filters and not _check_filter_match(thread["metadata"], filters):
770
785
  continue
771
786
 
@@ -943,6 +958,7 @@ class Threads(Authenticated):
943
958
  thread_id: UUID,
944
959
  *,
945
960
  metadata: MetadataValue,
961
+ ttl: ThreadTTLConfig | None = None,
946
962
  ctx: Auth.types.BaseAuthContext | None = None,
947
963
  ) -> AsyncIterator[Thread]:
948
964
  """Update a thread."""
@@ -1327,7 +1343,14 @@ class Threads(Authenticated):
1327
1343
  )
1328
1344
 
1329
1345
  metadata = thread.get("metadata", {})
1330
- thread_config = thread.get("config", {})
1346
+ thread_config = cast(dict[str, Any], thread.get("config", {}))
1347
+ thread_config = {
1348
+ **thread_config,
1349
+ "configurable": {
1350
+ **thread_config.get("configurable", {}),
1351
+ **config.get("configurable", {}),
1352
+ },
1353
+ }
1331
1354
 
1332
1355
  # Fallback to graph_id from run if not in thread metadata
1333
1356
  graph_id = metadata.get("graph_id")
@@ -1414,6 +1437,13 @@ class Threads(Authenticated):
1414
1437
  status_code=409,
1415
1438
  detail=f"Thread {thread_id} has in-flight runs: {pending_runs}",
1416
1439
  )
1440
+ thread_config = {
1441
+ **thread_config,
1442
+ "configurable": {
1443
+ **thread_config.get("configurable", {}),
1444
+ **config.get("configurable", {}),
1445
+ },
1446
+ }
1417
1447
 
1418
1448
  # Fallback to graph_id from run if not in thread metadata
1419
1449
  graph_id = metadata.get("graph_id")
@@ -1454,6 +1484,19 @@ class Threads(Authenticated):
1454
1484
  thread["values"] = state.values
1455
1485
  break
1456
1486
 
1487
+ # Publish state update event
1488
+ from langgraph_api.serde import json_dumpb
1489
+
1490
+ event_data = {
1491
+ "state": state,
1492
+ "thread_id": str(thread_id),
1493
+ }
1494
+ await Threads.Stream.publish(
1495
+ thread_id,
1496
+ "state_update",
1497
+ json_dumpb(event_data),
1498
+ )
1499
+
1457
1500
  return ThreadUpdateResponse(
1458
1501
  checkpoint=next_config["configurable"],
1459
1502
  # Including deprecated fields
@@ -1496,7 +1539,14 @@ class Threads(Authenticated):
1496
1539
  thread_iter, not_found_detail=f"Thread {thread_id} not found."
1497
1540
  )
1498
1541
 
1499
- thread_config = thread["config"]
1542
+ thread_config = cast(dict[str, Any], thread["config"])
1543
+ thread_config = {
1544
+ **thread_config,
1545
+ "configurable": {
1546
+ **thread_config.get("configurable", {}),
1547
+ **config.get("configurable", {}),
1548
+ },
1549
+ }
1500
1550
  metadata = thread["metadata"]
1501
1551
 
1502
1552
  if not thread:
@@ -1543,6 +1593,19 @@ class Threads(Authenticated):
1543
1593
  thread["values"] = state.values
1544
1594
  break
1545
1595
 
1596
+ # Publish state update event
1597
+ from langgraph_api.serde import json_dumpb
1598
+
1599
+ event_data = {
1600
+ "state": state,
1601
+ "thread_id": str(thread_id),
1602
+ }
1603
+ await Threads.Stream.publish(
1604
+ thread_id,
1605
+ "state_update",
1606
+ json_dumpb(event_data),
1607
+ )
1608
+
1546
1609
  return ThreadUpdateResponse(
1547
1610
  checkpoint=next_config["configurable"],
1548
1611
  )
@@ -1584,7 +1647,14 @@ class Threads(Authenticated):
1584
1647
  if not _check_filter_match(thread_metadata, filters):
1585
1648
  return []
1586
1649
 
1587
- thread_config = thread["config"]
1650
+ thread_config = cast(dict[str, Any], thread["config"])
1651
+ thread_config = {
1652
+ **thread_config,
1653
+ "configurable": {
1654
+ **thread_config.get("configurable", {}),
1655
+ **config.get("configurable", {}),
1656
+ },
1657
+ }
1588
1658
  # If graph_id exists, get state history
1589
1659
  if graph_id := thread_metadata.get("graph_id"):
1590
1660
  async with get_graph(
@@ -1626,6 +1696,13 @@ class Threads(Authenticated):
1626
1696
 
1627
1697
  # Create new queues only for runs not yet seen
1628
1698
  thread_id = _ensure_uuid(thread_id)
1699
+
1700
+ # Add thread stream queue
1701
+ if thread_id not in seen_runs:
1702
+ queue = await stream_manager.add_thread_stream(thread_id)
1703
+ queues.append((thread_id, queue))
1704
+ seen_runs.add(thread_id)
1705
+
1629
1706
  for run in conn.store["runs"]:
1630
1707
  if run["thread_id"] == thread_id:
1631
1708
  run_id = run["run_id"]
@@ -1641,8 +1718,27 @@ class Threads(Authenticated):
1641
1718
  thread_id: UUID,
1642
1719
  *,
1643
1720
  last_event_id: str | None = None,
1721
+ stream_modes: list[ThreadStreamMode],
1644
1722
  ) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
1645
1723
  """Stream the thread output."""
1724
+
1725
+ def should_filter_event(event_name: str, message_bytes: bytes) -> bool:
1726
+ """Check if an event should be filtered out based on stream_modes."""
1727
+ if "run_modes" in stream_modes and event_name != "state_update":
1728
+ return False
1729
+ if "state_update" in stream_modes and event_name == "state_update":
1730
+ return False
1731
+ if "lifecycle" in stream_modes and event_name == "metadata":
1732
+ try:
1733
+ message_data = orjson.loads(message_bytes)
1734
+ if message_data.get("status") == "run_done":
1735
+ return False
1736
+ if "attempt" in message_data and "run_id" in message_data:
1737
+ return False
1738
+ except (orjson.JSONDecodeError, TypeError):
1739
+ pass
1740
+ return True
1741
+
1646
1742
  from langgraph_api.serde import json_loads
1647
1743
 
1648
1744
  stream_manager = get_stream_manager()
@@ -1679,19 +1775,29 @@ class Threads(Authenticated):
1679
1775
 
1680
1776
  if event_name == "control":
1681
1777
  if message_content == b"done":
1778
+ event_bytes = b"metadata"
1779
+ message_bytes = orjson.dumps(
1780
+ {"status": "run_done", "run_id": run_id}
1781
+ )
1782
+ # Filter events based on stream_modes
1783
+ if not should_filter_event(
1784
+ "metadata", message_bytes
1785
+ ):
1786
+ yield (
1787
+ event_bytes,
1788
+ message_bytes,
1789
+ message.id,
1790
+ )
1791
+ else:
1792
+ event_bytes = event_name.encode()
1793
+ message_bytes = base64.b64decode(message_content)
1794
+ # Filter events based on stream_modes
1795
+ if not should_filter_event(event_name, message_bytes):
1682
1796
  yield (
1683
- b"metadata",
1684
- orjson.dumps(
1685
- {"status": "run_done", "run_id": run_id}
1686
- ),
1797
+ event_bytes,
1798
+ message_bytes,
1687
1799
  message.id,
1688
1800
  )
1689
- else:
1690
- yield (
1691
- event_name.encode(),
1692
- base64.b64decode(message_content),
1693
- message.id,
1694
- )
1695
1801
 
1696
1802
  # Listen for live messages from all queues
1697
1803
  while True:
@@ -1717,19 +1823,31 @@ class Threads(Authenticated):
1717
1823
  # Extract run_id from topic
1718
1824
  topic = message.topic.decode()
1719
1825
  run_id = topic.split("run:")[1].split(":")[0]
1826
+ event_bytes = b"metadata"
1827
+ message_bytes = orjson.dumps(
1828
+ {"status": "run_done", "run_id": run_id}
1829
+ )
1830
+ # Filter events based on stream_modes
1831
+ if not should_filter_event(
1832
+ "metadata", message_bytes
1833
+ ):
1834
+ yield (
1835
+ event_bytes,
1836
+ message_bytes,
1837
+ message.id,
1838
+ )
1839
+ else:
1840
+ event_bytes = event_name.encode()
1841
+ message_bytes = base64.b64decode(message_content)
1842
+ # Filter events based on stream_modes
1843
+ if not should_filter_event(
1844
+ event_name, message_bytes
1845
+ ):
1720
1846
  yield (
1721
- b"metadata",
1722
- orjson.dumps(
1723
- {"status": "run_done", "run_id": run_id}
1724
- ),
1847
+ event_bytes,
1848
+ message_bytes,
1725
1849
  message.id,
1726
1850
  )
1727
- else:
1728
- yield (
1729
- event_name.encode(),
1730
- base64.b64decode(message_content),
1731
- message.id,
1732
- )
1733
1851
 
1734
1852
  except TimeoutError:
1735
1853
  continue
@@ -1758,6 +1876,29 @@ class Threads(Authenticated):
1758
1876
  # Ignore cleanup errors
1759
1877
  pass
1760
1878
 
1879
+ @staticmethod
1880
+ async def publish(
1881
+ thread_id: UUID | str,
1882
+ event: str,
1883
+ message: bytes,
1884
+ ) -> None:
1885
+ """Publish a thread-level event to the thread stream."""
1886
+ from langgraph_api.serde import json_dumpb
1887
+
1888
+ topic = f"thread:{thread_id}:stream".encode()
1889
+
1890
+ stream_manager = get_stream_manager()
1891
+ # Send to thread stream topic
1892
+ payload = json_dumpb(
1893
+ {
1894
+ "event": event,
1895
+ "message": message,
1896
+ }
1897
+ )
1898
+ await stream_manager.put_thread(
1899
+ str(thread_id), Message(topic=topic, data=payload)
1900
+ )
1901
+
1761
1902
  @staticmethod
1762
1903
  async def count(
1763
1904
  conn: InMemConnectionProto,
@@ -2004,6 +2145,7 @@ class Runs(Authenticated):
2004
2145
  run_id = _ensure_uuid(run_id) if run_id else None
2005
2146
  metadata = metadata if metadata is not None else {}
2006
2147
  config = kwargs.get("config", {})
2148
+ temporary = kwargs.get("temporary", False)
2007
2149
 
2008
2150
  # Handle thread creation/update
2009
2151
  existing_thread = next(
@@ -2013,7 +2155,7 @@ class Runs(Authenticated):
2013
2155
  ctx,
2014
2156
  "create_run",
2015
2157
  Auth.types.RunsCreate(
2016
- thread_id=thread_id,
2158
+ thread_id=None if temporary else thread_id,
2017
2159
  assistant_id=assistant_id,
2018
2160
  run_id=run_id,
2019
2161
  status=status,
@@ -2538,7 +2680,7 @@ class Runs(Authenticated):
2538
2680
  async def subscribe(
2539
2681
  run_id: UUID,
2540
2682
  thread_id: UUID | None = None,
2541
- ) -> asyncio.Queue:
2683
+ ) -> ContextQueue:
2542
2684
  """Subscribe to the run stream, returning a queue."""
2543
2685
  stream_manager = get_stream_manager()
2544
2686
  queue = await stream_manager.add_queue(_ensure_uuid(run_id), thread_id)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langgraph-runtime-inmem
3
- Version: 0.9.0
3
+ Version: 0.11.0
4
4
  Summary: Inmem implementation for the LangGraph API server.
5
5
  Author-email: Will Fu-Hinthorn <will@langchain.dev>
6
6
  License: Elastic-2.0
@@ -1,13 +1,13 @@
1
- langgraph_runtime_inmem/__init__.py,sha256=f-VPPHH1-hKFwEreffg7dNATe9IdcYwQedcSx2MiZog,310
1
+ langgraph_runtime_inmem/__init__.py,sha256=fsghjd8RgYa23k1SGZS_DHaq_7X0NDBhUIAT0ud9uy4,311
2
2
  langgraph_runtime_inmem/checkpoint.py,sha256=nc1G8DqVdIu-ibjKTqXfbPfMbAsKjPObKqegrSzo6Po,4432
3
3
  langgraph_runtime_inmem/database.py,sha256=QgaA_WQo1IY6QioYd8r-e6-0B0rnC5anS0muIEJWby0,6364
4
- langgraph_runtime_inmem/inmem_stream.py,sha256=pUEiHW-1uXQrVTcwEYPwO8YXaYm5qZbpRWawt67y6Lw,8187
4
+ langgraph_runtime_inmem/inmem_stream.py,sha256=utL1OlOJsy6VDkSGAA6eX9nETreZlM6K6nhfNoubmRQ,9011
5
5
  langgraph_runtime_inmem/lifespan.py,sha256=t0w2MX2dGxe8yNtSX97Z-d2pFpllSLS4s1rh2GJDw5M,3557
6
6
  langgraph_runtime_inmem/metrics.py,sha256=HhO0RC2bMDTDyGBNvnd2ooLebLA8P1u5oq978Kp_nAA,392
7
- langgraph_runtime_inmem/ops.py,sha256=0Jx65S3PCvvHlIpA0XYpl-UnDEo_AiGWXRE2QiFSocY,105165
7
+ langgraph_runtime_inmem/ops.py,sha256=4SBZ3LVQgKB53WeZMien4jb5WzNkMfZ48C4aS8rJX8k,111227
8
8
  langgraph_runtime_inmem/queue.py,sha256=33qfFKPhQicZ1qiibllYb-bTFzUNSN2c4bffPACP5es,9952
9
9
  langgraph_runtime_inmem/retry.py,sha256=XmldOP4e_H5s264CagJRVnQMDFcEJR_dldVR1Hm5XvM,763
10
10
  langgraph_runtime_inmem/store.py,sha256=rTfL1JJvd-j4xjTrL8qDcynaWF6gUJ9-GDVwH0NBD_I,3506
11
- langgraph_runtime_inmem-0.9.0.dist-info/METADATA,sha256=ptwW1Ei-Xln53P81eJK1aPcFozU8D192OCZBuC_y5EQ,565
12
- langgraph_runtime_inmem-0.9.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
- langgraph_runtime_inmem-0.9.0.dist-info/RECORD,,
11
+ langgraph_runtime_inmem-0.11.0.dist-info/METADATA,sha256=xtMG9LLstWO11EQLaOQIhqSvBoA-Nprs2m9upNA7TbE,566
12
+ langgraph_runtime_inmem-0.11.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
+ langgraph_runtime_inmem-0.11.0.dist-info/RECORD,,