langgraph-runtime-inmem 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,7 @@ from langgraph_runtime_inmem import (
9
9
  store,
10
10
  )
11
11
 
12
- __version__ = "0.9.0"
12
+ __version__ = "0.10.0"
13
13
  __all__ = [
14
14
  "ops",
15
15
  "database",
@@ -58,6 +58,7 @@ class StreamManager:
58
58
  ) # Dict[str, List[asyncio.Queue]]
59
59
  self.control_keys = defaultdict(lambda: defaultdict())
60
60
  self.control_queues = defaultdict(lambda: defaultdict(list))
61
+ self.thread_streams = defaultdict(list)
61
62
 
62
63
  self.message_stores = defaultdict(
63
64
  lambda: defaultdict(list[Message])
@@ -95,7 +96,7 @@ class StreamManager:
95
96
 
96
97
  async def put(
97
98
  self,
98
- run_id: UUID | str,
99
+ run_id: UUID | str | None,
99
100
  thread_id: UUID | str | None,
100
101
  message: Message,
101
102
  resumable: bool = False,
@@ -121,6 +122,20 @@ class StreamManager:
121
122
  if isinstance(result, Exception):
122
123
  logger.exception(f"Failed to put message in queue: {result}")
123
124
 
125
+ async def put_thread(
126
+ self,
127
+ thread_id: UUID | str,
128
+ message: Message,
129
+ ) -> None:
130
+ thread_id = _ensure_uuid(thread_id)
131
+ message.id = _generate_ms_seq_id().encode()
132
+ queues = self.thread_streams[thread_id]
133
+ coros = [queue.put(message) for queue in queues]
134
+ results = await asyncio.gather(*coros, return_exceptions=True)
135
+ for result in results:
136
+ if isinstance(result, Exception):
137
+ logger.exception(f"Failed to put message in queue: {result}")
138
+
124
139
  async def add_queue(
125
140
  self, run_id: UUID | str, thread_id: UUID | str | None
126
141
  ) -> asyncio.Queue:
@@ -145,6 +160,12 @@ class StreamManager:
145
160
  self.control_queues[thread_id][run_id].append(queue)
146
161
  return queue
147
162
 
163
+ async def add_thread_stream(self, thread_id: UUID | str) -> asyncio.Queue:
164
+ thread_id = _ensure_uuid(thread_id)
165
+ queue = ContextQueue()
166
+ self.thread_streams[thread_id].append(queue)
167
+ return queue
168
+
148
169
  async def remove_queue(
149
170
  self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
150
171
  ):
@@ -29,6 +29,7 @@ from langgraph_runtime_inmem.checkpoint import Checkpointer
29
29
  from langgraph_runtime_inmem.database import InMemConnectionProto, connect
30
30
  from langgraph_runtime_inmem.inmem_stream import (
31
31
  THREADLESS_KEY,
32
+ ContextQueue,
32
33
  Message,
33
34
  get_stream_manager,
34
35
  )
@@ -58,6 +59,7 @@ if typing.TYPE_CHECKING:
58
59
  Thread,
59
60
  ThreadSelectField,
60
61
  ThreadStatus,
62
+ ThreadStreamMode,
61
63
  ThreadUpdateResponse,
62
64
  )
63
65
  from langgraph_api.schema import Interrupt as InterruptSchema
@@ -738,6 +740,7 @@ class Threads(Authenticated):
738
740
  async def search(
739
741
  conn: InMemConnectionProto,
740
742
  *,
743
+ ids: list[str] | list[UUID] | None = None,
741
744
  metadata: MetadataInput,
742
745
  values: MetadataInput,
743
746
  status: ThreadStatus | None,
@@ -765,7 +768,19 @@ class Threads(Authenticated):
765
768
  )
766
769
 
767
770
  # Apply filters
771
+ id_set: set[UUID] | None = None
772
+ if ids:
773
+ id_set = set()
774
+ for i in ids:
775
+ try:
776
+ id_set.add(_ensure_uuid(i))
777
+ except Exception:
778
+ raise HTTPException(
779
+ status_code=400, detail="Invalid thread ID " + str(i)
780
+ ) from None
768
781
  for thread in threads:
782
+ if id_set is not None and thread.get("thread_id") not in id_set:
783
+ continue
769
784
  if filters and not _check_filter_match(thread["metadata"], filters):
770
785
  continue
771
786
 
@@ -1327,7 +1342,14 @@ class Threads(Authenticated):
1327
1342
  )
1328
1343
 
1329
1344
  metadata = thread.get("metadata", {})
1330
- thread_config = thread.get("config", {})
1345
+ thread_config = cast(dict[str, Any], thread.get("config", {}))
1346
+ thread_config = {
1347
+ **thread_config,
1348
+ "configurable": {
1349
+ **thread_config.get("configurable", {}),
1350
+ **config.get("configurable", {}),
1351
+ },
1352
+ }
1331
1353
 
1332
1354
  # Fallback to graph_id from run if not in thread metadata
1333
1355
  graph_id = metadata.get("graph_id")
@@ -1414,6 +1436,13 @@ class Threads(Authenticated):
1414
1436
  status_code=409,
1415
1437
  detail=f"Thread {thread_id} has in-flight runs: {pending_runs}",
1416
1438
  )
1439
+ thread_config = {
1440
+ **thread_config,
1441
+ "configurable": {
1442
+ **thread_config.get("configurable", {}),
1443
+ **config.get("configurable", {}),
1444
+ },
1445
+ }
1417
1446
 
1418
1447
  # Fallback to graph_id from run if not in thread metadata
1419
1448
  graph_id = metadata.get("graph_id")
@@ -1454,6 +1483,19 @@ class Threads(Authenticated):
1454
1483
  thread["values"] = state.values
1455
1484
  break
1456
1485
 
1486
+ # Publish state update event
1487
+ from langgraph_api.serde import json_dumpb
1488
+
1489
+ event_data = {
1490
+ "state": state,
1491
+ "thread_id": str(thread_id),
1492
+ }
1493
+ await Threads.Stream.publish(
1494
+ thread_id,
1495
+ "state_update",
1496
+ json_dumpb(event_data),
1497
+ )
1498
+
1457
1499
  return ThreadUpdateResponse(
1458
1500
  checkpoint=next_config["configurable"],
1459
1501
  # Including deprecated fields
@@ -1496,7 +1538,14 @@ class Threads(Authenticated):
1496
1538
  thread_iter, not_found_detail=f"Thread {thread_id} not found."
1497
1539
  )
1498
1540
 
1499
- thread_config = thread["config"]
1541
+ thread_config = cast(dict[str, Any], thread["config"])
1542
+ thread_config = {
1543
+ **thread_config,
1544
+ "configurable": {
1545
+ **thread_config.get("configurable", {}),
1546
+ **config.get("configurable", {}),
1547
+ },
1548
+ }
1500
1549
  metadata = thread["metadata"]
1501
1550
 
1502
1551
  if not thread:
@@ -1543,6 +1592,19 @@ class Threads(Authenticated):
1543
1592
  thread["values"] = state.values
1544
1593
  break
1545
1594
 
1595
+ # Publish state update event
1596
+ from langgraph_api.serde import json_dumpb
1597
+
1598
+ event_data = {
1599
+ "state": state,
1600
+ "thread_id": str(thread_id),
1601
+ }
1602
+ await Threads.Stream.publish(
1603
+ thread_id,
1604
+ "state_update",
1605
+ json_dumpb(event_data),
1606
+ )
1607
+
1546
1608
  return ThreadUpdateResponse(
1547
1609
  checkpoint=next_config["configurable"],
1548
1610
  )
@@ -1584,7 +1646,14 @@ class Threads(Authenticated):
1584
1646
  if not _check_filter_match(thread_metadata, filters):
1585
1647
  return []
1586
1648
 
1587
- thread_config = thread["config"]
1649
+ thread_config = cast(dict[str, Any], thread["config"])
1650
+ thread_config = {
1651
+ **thread_config,
1652
+ "configurable": {
1653
+ **thread_config.get("configurable", {}),
1654
+ **config.get("configurable", {}),
1655
+ },
1656
+ }
1588
1657
  # If graph_id exists, get state history
1589
1658
  if graph_id := thread_metadata.get("graph_id"):
1590
1659
  async with get_graph(
@@ -1626,6 +1695,13 @@ class Threads(Authenticated):
1626
1695
 
1627
1696
  # Create new queues only for runs not yet seen
1628
1697
  thread_id = _ensure_uuid(thread_id)
1698
+
1699
+ # Add thread stream queue
1700
+ if thread_id not in seen_runs:
1701
+ queue = await stream_manager.add_thread_stream(thread_id)
1702
+ queues.append((thread_id, queue))
1703
+ seen_runs.add(thread_id)
1704
+
1629
1705
  for run in conn.store["runs"]:
1630
1706
  if run["thread_id"] == thread_id:
1631
1707
  run_id = run["run_id"]
@@ -1641,8 +1717,27 @@ class Threads(Authenticated):
1641
1717
  thread_id: UUID,
1642
1718
  *,
1643
1719
  last_event_id: str | None = None,
1720
+ stream_modes: list[ThreadStreamMode],
1644
1721
  ) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
1645
1722
  """Stream the thread output."""
1723
+
1724
+ def should_filter_event(event_name: str, message_bytes: bytes) -> bool:
1725
+ """Check if an event should be filtered out based on stream_modes."""
1726
+ if "run_modes" in stream_modes and event_name != "state_update":
1727
+ return False
1728
+ if "state_update" in stream_modes and event_name == "state_update":
1729
+ return False
1730
+ if "lifecycle" in stream_modes and event_name == "metadata":
1731
+ try:
1732
+ message_data = orjson.loads(message_bytes)
1733
+ if message_data.get("status") == "run_done":
1734
+ return False
1735
+ if "attempt" in message_data and "run_id" in message_data:
1736
+ return False
1737
+ except (orjson.JSONDecodeError, TypeError):
1738
+ pass
1739
+ return True
1740
+
1646
1741
  from langgraph_api.serde import json_loads
1647
1742
 
1648
1743
  stream_manager = get_stream_manager()
@@ -1679,19 +1774,29 @@ class Threads(Authenticated):
1679
1774
 
1680
1775
  if event_name == "control":
1681
1776
  if message_content == b"done":
1777
+ event_bytes = b"metadata"
1778
+ message_bytes = orjson.dumps(
1779
+ {"status": "run_done", "run_id": run_id}
1780
+ )
1781
+ # Filter events based on stream_modes
1782
+ if not should_filter_event(
1783
+ "metadata", message_bytes
1784
+ ):
1785
+ yield (
1786
+ event_bytes,
1787
+ message_bytes,
1788
+ message.id,
1789
+ )
1790
+ else:
1791
+ event_bytes = event_name.encode()
1792
+ message_bytes = base64.b64decode(message_content)
1793
+ # Filter events based on stream_modes
1794
+ if not should_filter_event(event_name, message_bytes):
1682
1795
  yield (
1683
- b"metadata",
1684
- orjson.dumps(
1685
- {"status": "run_done", "run_id": run_id}
1686
- ),
1796
+ event_bytes,
1797
+ message_bytes,
1687
1798
  message.id,
1688
1799
  )
1689
- else:
1690
- yield (
1691
- event_name.encode(),
1692
- base64.b64decode(message_content),
1693
- message.id,
1694
- )
1695
1800
 
1696
1801
  # Listen for live messages from all queues
1697
1802
  while True:
@@ -1717,19 +1822,31 @@ class Threads(Authenticated):
1717
1822
  # Extract run_id from topic
1718
1823
  topic = message.topic.decode()
1719
1824
  run_id = topic.split("run:")[1].split(":")[0]
1825
+ event_bytes = b"metadata"
1826
+ message_bytes = orjson.dumps(
1827
+ {"status": "run_done", "run_id": run_id}
1828
+ )
1829
+ # Filter events based on stream_modes
1830
+ if not should_filter_event(
1831
+ "metadata", message_bytes
1832
+ ):
1833
+ yield (
1834
+ event_bytes,
1835
+ message_bytes,
1836
+ message.id,
1837
+ )
1838
+ else:
1839
+ event_bytes = event_name.encode()
1840
+ message_bytes = base64.b64decode(message_content)
1841
+ # Filter events based on stream_modes
1842
+ if not should_filter_event(
1843
+ event_name, message_bytes
1844
+ ):
1720
1845
  yield (
1721
- b"metadata",
1722
- orjson.dumps(
1723
- {"status": "run_done", "run_id": run_id}
1724
- ),
1846
+ event_bytes,
1847
+ message_bytes,
1725
1848
  message.id,
1726
1849
  )
1727
- else:
1728
- yield (
1729
- event_name.encode(),
1730
- base64.b64decode(message_content),
1731
- message.id,
1732
- )
1733
1850
 
1734
1851
  except TimeoutError:
1735
1852
  continue
@@ -1758,6 +1875,29 @@ class Threads(Authenticated):
1758
1875
  # Ignore cleanup errors
1759
1876
  pass
1760
1877
 
1878
+ @staticmethod
1879
+ async def publish(
1880
+ thread_id: UUID | str,
1881
+ event: str,
1882
+ message: bytes,
1883
+ ) -> None:
1884
+ """Publish a thread-level event to the thread stream."""
1885
+ from langgraph_api.serde import json_dumpb
1886
+
1887
+ topic = f"thread:{thread_id}:stream".encode()
1888
+
1889
+ stream_manager = get_stream_manager()
1890
+ # Send to thread stream topic
1891
+ payload = json_dumpb(
1892
+ {
1893
+ "event": event,
1894
+ "message": message,
1895
+ }
1896
+ )
1897
+ await stream_manager.put_thread(
1898
+ str(thread_id), Message(topic=topic, data=payload)
1899
+ )
1900
+
1761
1901
  @staticmethod
1762
1902
  async def count(
1763
1903
  conn: InMemConnectionProto,
@@ -2004,6 +2144,7 @@ class Runs(Authenticated):
2004
2144
  run_id = _ensure_uuid(run_id) if run_id else None
2005
2145
  metadata = metadata if metadata is not None else {}
2006
2146
  config = kwargs.get("config", {})
2147
+ temporary = kwargs.get("temporary", False)
2007
2148
 
2008
2149
  # Handle thread creation/update
2009
2150
  existing_thread = next(
@@ -2013,7 +2154,7 @@ class Runs(Authenticated):
2013
2154
  ctx,
2014
2155
  "create_run",
2015
2156
  Auth.types.RunsCreate(
2016
- thread_id=thread_id,
2157
+ thread_id=None if temporary else thread_id,
2017
2158
  assistant_id=assistant_id,
2018
2159
  run_id=run_id,
2019
2160
  status=status,
@@ -2538,7 +2679,7 @@ class Runs(Authenticated):
2538
2679
  async def subscribe(
2539
2680
  run_id: UUID,
2540
2681
  thread_id: UUID | None = None,
2541
- ) -> asyncio.Queue:
2682
+ ) -> ContextQueue:
2542
2683
  """Subscribe to the run stream, returning a queue."""
2543
2684
  stream_manager = get_stream_manager()
2544
2685
  queue = await stream_manager.add_queue(_ensure_uuid(run_id), thread_id)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langgraph-runtime-inmem
3
- Version: 0.9.0
3
+ Version: 0.10.0
4
4
  Summary: Inmem implementation for the LangGraph API server.
5
5
  Author-email: Will Fu-Hinthorn <will@langchain.dev>
6
6
  License: Elastic-2.0
@@ -1,13 +1,13 @@
1
- langgraph_runtime_inmem/__init__.py,sha256=f-VPPHH1-hKFwEreffg7dNATe9IdcYwQedcSx2MiZog,310
1
+ langgraph_runtime_inmem/__init__.py,sha256=4xhdO3o6RduCHDXSNh42I51Wwq7Kcnt3JK1U1IhP-BU,311
2
2
  langgraph_runtime_inmem/checkpoint.py,sha256=nc1G8DqVdIu-ibjKTqXfbPfMbAsKjPObKqegrSzo6Po,4432
3
3
  langgraph_runtime_inmem/database.py,sha256=QgaA_WQo1IY6QioYd8r-e6-0B0rnC5anS0muIEJWby0,6364
4
- langgraph_runtime_inmem/inmem_stream.py,sha256=pUEiHW-1uXQrVTcwEYPwO8YXaYm5qZbpRWawt67y6Lw,8187
4
+ langgraph_runtime_inmem/inmem_stream.py,sha256=utL1OlOJsy6VDkSGAA6eX9nETreZlM6K6nhfNoubmRQ,9011
5
5
  langgraph_runtime_inmem/lifespan.py,sha256=t0w2MX2dGxe8yNtSX97Z-d2pFpllSLS4s1rh2GJDw5M,3557
6
6
  langgraph_runtime_inmem/metrics.py,sha256=HhO0RC2bMDTDyGBNvnd2ooLebLA8P1u5oq978Kp_nAA,392
7
- langgraph_runtime_inmem/ops.py,sha256=0Jx65S3PCvvHlIpA0XYpl-UnDEo_AiGWXRE2QiFSocY,105165
7
+ langgraph_runtime_inmem/ops.py,sha256=54jiyWhfbSu9z9pca6AQdNuaIBmD0WMrQ7xGQcLPDF4,111183
8
8
  langgraph_runtime_inmem/queue.py,sha256=33qfFKPhQicZ1qiibllYb-bTFzUNSN2c4bffPACP5es,9952
9
9
  langgraph_runtime_inmem/retry.py,sha256=XmldOP4e_H5s264CagJRVnQMDFcEJR_dldVR1Hm5XvM,763
10
10
  langgraph_runtime_inmem/store.py,sha256=rTfL1JJvd-j4xjTrL8qDcynaWF6gUJ9-GDVwH0NBD_I,3506
11
- langgraph_runtime_inmem-0.9.0.dist-info/METADATA,sha256=ptwW1Ei-Xln53P81eJK1aPcFozU8D192OCZBuC_y5EQ,565
12
- langgraph_runtime_inmem-0.9.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
- langgraph_runtime_inmem-0.9.0.dist-info/RECORD,,
11
+ langgraph_runtime_inmem-0.10.0.dist-info/METADATA,sha256=gdjdQjZF2KjDtwA9rDiW53pG4FYNfv8TkT1U8t2lftQ,566
12
+ langgraph_runtime_inmem-0.10.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
+ langgraph_runtime_inmem-0.10.0.dist-info/RECORD,,