langgraph-runtime-inmem 0.8.1__tar.gz → 0.9.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17)
  1. {langgraph_runtime_inmem-0.8.1 → langgraph_runtime_inmem-0.9.0}/PKG-INFO +1 -1
  2. {langgraph_runtime_inmem-0.8.1 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/__init__.py +1 -1
  3. {langgraph_runtime_inmem-0.8.1 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/database.py +1 -1
  4. langgraph_runtime_inmem-0.9.0/langgraph_runtime_inmem/inmem_stream.py +247 -0
  5. {langgraph_runtime_inmem-0.8.1 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/ops.py +191 -28
  6. langgraph_runtime_inmem-0.8.1/langgraph_runtime_inmem/inmem_stream.py +0 -159
  7. {langgraph_runtime_inmem-0.8.1 → langgraph_runtime_inmem-0.9.0}/.gitignore +0 -0
  8. {langgraph_runtime_inmem-0.8.1 → langgraph_runtime_inmem-0.9.0}/Makefile +0 -0
  9. {langgraph_runtime_inmem-0.8.1 → langgraph_runtime_inmem-0.9.0}/README.md +0 -0
  10. {langgraph_runtime_inmem-0.8.1 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/checkpoint.py +0 -0
  11. {langgraph_runtime_inmem-0.8.1 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/lifespan.py +0 -0
  12. {langgraph_runtime_inmem-0.8.1 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/metrics.py +0 -0
  13. {langgraph_runtime_inmem-0.8.1 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/queue.py +0 -0
  14. {langgraph_runtime_inmem-0.8.1 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/retry.py +0 -0
  15. {langgraph_runtime_inmem-0.8.1 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/store.py +0 -0
  16. {langgraph_runtime_inmem-0.8.1 → langgraph_runtime_inmem-0.9.0}/pyproject.toml +0 -0
  17. {langgraph_runtime_inmem-0.8.1 → langgraph_runtime_inmem-0.9.0}/uv.lock +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langgraph-runtime-inmem
3
- Version: 0.8.1
3
+ Version: 0.9.0
4
4
  Summary: Inmem implementation for the LangGraph API server.
5
5
  Author-email: Will Fu-Hinthorn <will@langchain.dev>
6
6
  License: Elastic-2.0
@@ -9,7 +9,7 @@ from langgraph_runtime_inmem import (
9
9
  store,
10
10
  )
11
11
 
12
- __version__ = "0.8.1"
12
+ __version__ = "0.9.0"
13
13
  __all__ = [
14
14
  "ops",
15
15
  "database",
@@ -208,6 +208,6 @@ async def healthcheck() -> None:
208
208
  pass
209
209
 
210
210
 
211
- def pool_stats() -> dict[str, dict[str, int]]:
211
+ def pool_stats(*args, **kwargs) -> dict[str, dict[str, int]]:
212
212
  # TODO??
213
213
  return {}
@@ -0,0 +1,247 @@
1
+ import asyncio
2
+ import logging
3
+ import time
4
+ from collections import defaultdict
5
+ from collections.abc import Iterator
6
+ from dataclasses import dataclass
7
+ from uuid import UUID
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ def _ensure_uuid(id: str | UUID) -> UUID:
13
+ return UUID(id) if isinstance(id, str) else id
14
+
15
+
16
+ def _generate_ms_seq_id() -> str:
17
+ """Generate a Redis-like millisecond-sequence ID (e.g., '1234567890123-0')"""
18
+ # Get current time in milliseconds
19
+ ms = int(time.time() * 1000)
20
+ # For simplicity, always use sequence 0 since we're not handling high throughput
21
+ return f"{ms}-0"
22
+
23
+
24
+ @dataclass
25
+ class Message:
26
+ topic: bytes
27
+ data: bytes
28
+ id: bytes | None = None
29
+
30
+
31
+ class ContextQueue(asyncio.Queue):
32
+ """Queue that supports async context manager protocol"""
33
+
34
+ async def __aenter__(self):
35
+ return self
36
+
37
+ async def __aexit__(
38
+ self,
39
+ exc_type: type[BaseException] | None,
40
+ exc_val: BaseException | None,
41
+ exc_tb: object | None,
42
+ ) -> None:
43
+ # Clear the queue
44
+ while not self.empty():
45
+ try:
46
+ self.get_nowait()
47
+ except asyncio.QueueEmpty:
48
+ break
49
+
50
+
51
+ THREADLESS_KEY = "no-thread"
52
+
53
+
54
+ class StreamManager:
55
+ def __init__(self):
56
+ self.queues = defaultdict(
57
+ lambda: defaultdict(list)
58
+ ) # Dict[str, List[asyncio.Queue]]
59
+ self.control_keys = defaultdict(lambda: defaultdict())
60
+ self.control_queues = defaultdict(lambda: defaultdict(list))
61
+
62
+ self.message_stores = defaultdict(
63
+ lambda: defaultdict(list[Message])
64
+ ) # Dict[str, List[Message]]
65
+
66
+ def get_queues(
67
+ self, run_id: UUID | str, thread_id: UUID | str | None
68
+ ) -> list[asyncio.Queue]:
69
+ run_id = _ensure_uuid(run_id)
70
+ if thread_id is None:
71
+ thread_id = THREADLESS_KEY
72
+ else:
73
+ thread_id = _ensure_uuid(thread_id)
74
+ return self.queues[thread_id][run_id]
75
+
76
+ def get_control_queues(
77
+ self, run_id: UUID | str, thread_id: UUID | str | None
78
+ ) -> list[asyncio.Queue]:
79
+ run_id = _ensure_uuid(run_id)
80
+ if thread_id is None:
81
+ thread_id = THREADLESS_KEY
82
+ else:
83
+ thread_id = _ensure_uuid(thread_id)
84
+ return self.control_queues[thread_id][run_id]
85
+
86
+ def get_control_key(
87
+ self, run_id: UUID | str, thread_id: UUID | str | None
88
+ ) -> Message | None:
89
+ run_id = _ensure_uuid(run_id)
90
+ if thread_id is None:
91
+ thread_id = THREADLESS_KEY
92
+ else:
93
+ thread_id = _ensure_uuid(thread_id)
94
+ return self.control_keys.get(thread_id, {}).get(run_id)
95
+
96
+ async def put(
97
+ self,
98
+ run_id: UUID | str,
99
+ thread_id: UUID | str | None,
100
+ message: Message,
101
+ resumable: bool = False,
102
+ ) -> None:
103
+ run_id = _ensure_uuid(run_id)
104
+ if thread_id is None:
105
+ thread_id = THREADLESS_KEY
106
+ else:
107
+ thread_id = _ensure_uuid(thread_id)
108
+
109
+ message.id = _generate_ms_seq_id().encode()
110
+ if resumable:
111
+ self.message_stores[thread_id][run_id].append(message)
112
+ topic = message.topic.decode()
113
+ if "control" in topic:
114
+ self.control_keys[thread_id][run_id] = message
115
+ queues = self.control_queues[thread_id][run_id]
116
+ else:
117
+ queues = self.queues[thread_id][run_id]
118
+ coros = [queue.put(message) for queue in queues]
119
+ results = await asyncio.gather(*coros, return_exceptions=True)
120
+ for result in results:
121
+ if isinstance(result, Exception):
122
+ logger.exception(f"Failed to put message in queue: {result}")
123
+
124
+ async def add_queue(
125
+ self, run_id: UUID | str, thread_id: UUID | str | None
126
+ ) -> asyncio.Queue:
127
+ run_id = _ensure_uuid(run_id)
128
+ queue = ContextQueue()
129
+ if thread_id is None:
130
+ thread_id = THREADLESS_KEY
131
+ else:
132
+ thread_id = _ensure_uuid(thread_id)
133
+ self.queues[thread_id][run_id].append(queue)
134
+ return queue
135
+
136
+ async def add_control_queue(
137
+ self, run_id: UUID | str, thread_id: UUID | str | None
138
+ ) -> asyncio.Queue:
139
+ run_id = _ensure_uuid(run_id)
140
+ if thread_id is None:
141
+ thread_id = THREADLESS_KEY
142
+ else:
143
+ thread_id = _ensure_uuid(thread_id)
144
+ queue = ContextQueue()
145
+ self.control_queues[thread_id][run_id].append(queue)
146
+ return queue
147
+
148
+ async def remove_queue(
149
+ self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
150
+ ):
151
+ run_id = _ensure_uuid(run_id)
152
+ if thread_id is None:
153
+ thread_id = THREADLESS_KEY
154
+ else:
155
+ thread_id = _ensure_uuid(thread_id)
156
+ if thread_id in self.queues and run_id in self.queues[thread_id]:
157
+ self.queues[thread_id][run_id].remove(queue)
158
+ if not self.queues[thread_id][run_id]:
159
+ del self.queues[thread_id][run_id]
160
+
161
+ async def remove_control_queue(
162
+ self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
163
+ ):
164
+ run_id = _ensure_uuid(run_id)
165
+ if thread_id is None:
166
+ thread_id = THREADLESS_KEY
167
+ else:
168
+ thread_id = _ensure_uuid(thread_id)
169
+ if (
170
+ thread_id in self.control_queues
171
+ and run_id in self.control_queues[thread_id]
172
+ ):
173
+ self.control_queues[thread_id][run_id].remove(queue)
174
+ if not self.control_queues[thread_id][run_id]:
175
+ del self.control_queues[thread_id][run_id]
176
+
177
+ def restore_messages(
178
+ self, run_id: UUID | str, thread_id: UUID | str | None, message_id: str | None
179
+ ) -> Iterator[Message]:
180
+ """Get a stored message by ID for resumable streams."""
181
+ run_id = _ensure_uuid(run_id)
182
+ if thread_id is None:
183
+ thread_id = THREADLESS_KEY
184
+ else:
185
+ thread_id = _ensure_uuid(thread_id)
186
+ if message_id is None:
187
+ return
188
+ try:
189
+ # Handle ms-seq format (e.g., "1234567890123-0")
190
+ if thread_id in self.message_stores:
191
+ for message in self.message_stores[thread_id][run_id]:
192
+ if message.id.decode() > message_id:
193
+ yield message
194
+ except TypeError:
195
+ # Try integer format if ms-seq fails
196
+ message_idx = int(message_id) + 1
197
+ if run_id in self.message_stores:
198
+ yield from self.message_stores[thread_id][run_id][message_idx:]
199
+
200
+ def get_queues_by_thread_id(self, thread_id: UUID | str) -> list[asyncio.Queue]:
201
+ """Get all queues for a specific thread_id across all runs."""
202
+ all_queues = []
203
+ # Search through all stored queue keys for ones ending with the thread_id
204
+ thread_id = _ensure_uuid(thread_id)
205
+ if thread_id in self.queues:
206
+ for run_id in self.queues[thread_id]:
207
+ all_queues.extend(self.queues[thread_id][run_id])
208
+
209
+ return all_queues
210
+
211
+
212
+ # Global instance
213
+ stream_manager = StreamManager()
214
+
215
+
216
+ async def start_stream() -> None:
217
+ """Initialize the queue system.
218
+ In this in-memory implementation, we just need to ensure we have a clean StreamManager instance.
219
+ """
220
+ global stream_manager
221
+ stream_manager = StreamManager()
222
+
223
+
224
+ async def stop_stream() -> None:
225
+ """Clean up the queue system.
226
+ Clear all queues and stored control messages."""
227
+ global stream_manager
228
+
229
+ # Send 'done' message to all active queues before clearing
230
+ for run_id in list(stream_manager.queues.keys()):
231
+ control_message = Message(topic=f"run:{run_id}:control".encode(), data=b"done")
232
+
233
+ for queue in stream_manager.queues[run_id]:
234
+ try:
235
+ await queue.put(control_message)
236
+ except (Exception, RuntimeError):
237
+ pass # Ignore errors during shutdown
238
+
239
+ # Clear all stored data
240
+ stream_manager.queues.clear()
241
+ stream_manager.control_queues.clear()
242
+ stream_manager.message_stores.clear()
243
+
244
+
245
+ def get_stream_manager() -> StreamManager:
246
+ """Get the global stream manager instance."""
247
+ return stream_manager
@@ -27,7 +27,11 @@ from starlette.exceptions import HTTPException
27
27
 
28
28
  from langgraph_runtime_inmem.checkpoint import Checkpointer
29
29
  from langgraph_runtime_inmem.database import InMemConnectionProto, connect
30
- from langgraph_runtime_inmem.inmem_stream import Message, get_stream_manager
30
+ from langgraph_runtime_inmem.inmem_stream import (
31
+ THREADLESS_KEY,
32
+ Message,
33
+ get_stream_manager,
34
+ )
31
35
 
32
36
  if typing.TYPE_CHECKING:
33
37
  from langgraph_api.asyncio import ValueEvent
@@ -406,19 +410,17 @@ class Assistants(Authenticated):
406
410
  else 1
407
411
  )
408
412
 
409
- # Update assistant_versions table
410
- if metadata:
411
- metadata = {
412
- **assistant["metadata"],
413
- **metadata,
414
- }
415
413
  new_version_entry = {
416
414
  "assistant_id": assistant_id,
417
415
  "version": new_version,
418
416
  "graph_id": graph_id if graph_id is not None else assistant["graph_id"],
419
417
  "config": config if config else assistant["config"],
420
418
  "context": context if context is not None else assistant.get("context", {}),
421
- "metadata": metadata if metadata is not None else assistant["metadata"],
419
+ "metadata": (
420
+ {**assistant["metadata"], **metadata}
421
+ if metadata is not None
422
+ else assistant["metadata"]
423
+ ),
422
424
  "created_at": now,
423
425
  "name": name if name is not None else assistant["name"],
424
426
  "description": (
@@ -1611,6 +1613,151 @@ class Threads(Authenticated):
1611
1613
 
1612
1614
  return []
1613
1615
 
1616
+ class Stream:
1617
+ @staticmethod
1618
+ async def subscribe(
1619
+ conn: InMemConnectionProto | AsyncConnectionProto,
1620
+ thread_id: UUID,
1621
+ seen_runs: set[UUID],
1622
+ ) -> list[tuple[UUID, asyncio.Queue]]:
1623
+ """Subscribe to the thread stream, creating queues for unseen runs."""
1624
+ stream_manager = get_stream_manager()
1625
+ queues = []
1626
+
1627
+ # Create new queues only for runs not yet seen
1628
+ thread_id = _ensure_uuid(thread_id)
1629
+ for run in conn.store["runs"]:
1630
+ if run["thread_id"] == thread_id:
1631
+ run_id = run["run_id"]
1632
+ if run_id not in seen_runs:
1633
+ queue = await stream_manager.add_queue(run_id, thread_id)
1634
+ queues.append((run_id, queue))
1635
+ seen_runs.add(run_id)
1636
+
1637
+ return queues
1638
+
1639
+ @staticmethod
1640
+ async def join(
1641
+ thread_id: UUID,
1642
+ *,
1643
+ last_event_id: str | None = None,
1644
+ ) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
1645
+ """Stream the thread output."""
1646
+ from langgraph_api.serde import json_loads
1647
+
1648
+ stream_manager = get_stream_manager()
1649
+ seen_runs: set[UUID] = set()
1650
+ created_queues: list[tuple[UUID, asyncio.Queue]] = []
1651
+
1652
+ try:
1653
+ async with connect() as conn:
1654
+ await logger.ainfo(
1655
+ "Joined thread stream",
1656
+ thread_id=str(thread_id),
1657
+ )
1658
+
1659
+ # Restore messages if resuming from a specific event
1660
+ if last_event_id is not None:
1661
+ # Collect all events from all message stores for this thread
1662
+ all_events = []
1663
+ for run_id in stream_manager.message_stores.get(
1664
+ str(thread_id), []
1665
+ ):
1666
+ for message in stream_manager.restore_messages(
1667
+ run_id, thread_id, last_event_id
1668
+ ):
1669
+ all_events.append((message, run_id))
1670
+
1671
+ # Sort by message ID (which is ms-seq format)
1672
+ all_events.sort(key=lambda x: x[0].id.decode())
1673
+
1674
+ # Yield sorted events
1675
+ for message, run_id in all_events:
1676
+ data = json_loads(message.data)
1677
+ event_name = data["event"]
1678
+ message_content = data["message"]
1679
+
1680
+ if event_name == "control":
1681
+ if message_content == b"done":
1682
+ yield (
1683
+ b"metadata",
1684
+ orjson.dumps(
1685
+ {"status": "run_done", "run_id": run_id}
1686
+ ),
1687
+ message.id,
1688
+ )
1689
+ else:
1690
+ yield (
1691
+ event_name.encode(),
1692
+ base64.b64decode(message_content),
1693
+ message.id,
1694
+ )
1695
+
1696
+ # Listen for live messages from all queues
1697
+ while True:
1698
+ # Refresh queues to pick up any new runs that joined this thread
1699
+ new_queue_tuples = await Threads.Stream.subscribe(
1700
+ conn, thread_id, seen_runs
1701
+ )
1702
+ # Track new queues for cleanup
1703
+ for run_id, queue in new_queue_tuples:
1704
+ created_queues.append((run_id, queue))
1705
+
1706
+ for run_id, queue in created_queues:
1707
+ try:
1708
+ message = await asyncio.wait_for(
1709
+ queue.get(), timeout=0.2
1710
+ )
1711
+ data = json_loads(message.data)
1712
+ event_name = data["event"]
1713
+ message_content = data["message"]
1714
+
1715
+ if event_name == "control":
1716
+ if message_content == b"done":
1717
+ # Extract run_id from topic
1718
+ topic = message.topic.decode()
1719
+ run_id = topic.split("run:")[1].split(":")[0]
1720
+ yield (
1721
+ b"metadata",
1722
+ orjson.dumps(
1723
+ {"status": "run_done", "run_id": run_id}
1724
+ ),
1725
+ message.id,
1726
+ )
1727
+ else:
1728
+ yield (
1729
+ event_name.encode(),
1730
+ base64.b64decode(message_content),
1731
+ message.id,
1732
+ )
1733
+
1734
+ except TimeoutError:
1735
+ continue
1736
+ except (ValueError, KeyError):
1737
+ continue
1738
+
1739
+ # Yield execution to other tasks to prevent event loop starvation
1740
+ await asyncio.sleep(0)
1741
+
1742
+ except WrappedHTTPException as e:
1743
+ raise e.http_exception from None
1744
+ except asyncio.CancelledError:
1745
+ await logger.awarning(
1746
+ "Thread stream client disconnected",
1747
+ thread_id=str(thread_id),
1748
+ )
1749
+ raise
1750
+ except:
1751
+ raise
1752
+ finally:
1753
+ # Clean up all created queues
1754
+ for run_id, queue in created_queues:
1755
+ try:
1756
+ await stream_manager.remove_queue(run_id, thread_id, queue)
1757
+ except Exception:
1758
+ # Ignore cleanup errors
1759
+ pass
1760
+
1614
1761
  @staticmethod
1615
1762
  async def count(
1616
1763
  conn: InMemConnectionProto,
@@ -1769,7 +1916,7 @@ class Runs(Authenticated):
1769
1916
  @asynccontextmanager
1770
1917
  @staticmethod
1771
1918
  async def enter(
1772
- run_id: UUID, loop: asyncio.AbstractEventLoop
1919
+ run_id: UUID, thread_id: UUID | None, loop: asyncio.AbstractEventLoop
1773
1920
  ) -> AsyncIterator[ValueEvent]:
1774
1921
  """Enter a run, listen for cancellation while running, signal when done."
1775
1922
  This method should be called as a context manager by a worker executing a run.
@@ -1777,12 +1924,14 @@ class Runs(Authenticated):
1777
1924
  from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent
1778
1925
 
1779
1926
  stream_manager = get_stream_manager()
1780
- # Get queue for this run
1781
- queue = await stream_manager.add_control_queue(run_id)
1927
+ # Get control queue for this run (normal queue is created during run creation)
1928
+ control_queue = await stream_manager.add_control_queue(run_id, thread_id)
1782
1929
 
1783
1930
  async with SimpleTaskGroup(cancel=True, taskgroup_name="Runs.enter") as tg:
1784
1931
  done = ValueEvent()
1785
- tg.create_task(listen_for_cancellation(queue, run_id, done))
1932
+ tg.create_task(
1933
+ listen_for_cancellation(control_queue, run_id, thread_id, done)
1934
+ )
1786
1935
 
1787
1936
  # Give done event to caller
1788
1937
  yield done
@@ -1790,17 +1939,17 @@ class Runs(Authenticated):
1790
1939
  control_message = Message(
1791
1940
  topic=f"run:{run_id}:control".encode(), data=b"done"
1792
1941
  )
1793
- await stream_manager.put(run_id, control_message)
1942
+ await stream_manager.put(run_id, thread_id, control_message)
1794
1943
 
1795
1944
  # Signal done to all subscribers
1796
1945
  stream_message = Message(
1797
1946
  topic=f"run:{run_id}:stream".encode(),
1798
1947
  data={"event": "control", "message": b"done"},
1799
1948
  )
1800
- await stream_manager.put(run_id, stream_message)
1949
+ await stream_manager.put(run_id, thread_id, stream_message)
1801
1950
 
1802
- # Remove the queue
1803
- await stream_manager.remove_control_queue(run_id, queue)
1951
+ # Remove the control_queue (normal queue is cleaned up during run deletion)
1952
+ await stream_manager.remove_control_queue(run_id, thread_id, control_queue)
1804
1953
 
1805
1954
  @staticmethod
1806
1955
  async def sweep() -> None:
@@ -2088,6 +2237,7 @@ class Runs(Authenticated):
2088
2237
  if not thread:
2089
2238
  return _empty_generator()
2090
2239
  _delete_checkpoints_for_thread(thread_id, conn, run_id=run_id)
2240
+
2091
2241
  found = False
2092
2242
  for i, run in enumerate(conn.store["runs"]):
2093
2243
  if run["run_id"] == run_id and run["thread_id"] == thread_id:
@@ -2270,9 +2420,9 @@ class Runs(Authenticated):
2270
2420
  topic=f"run:{run_id}:control".encode(),
2271
2421
  data=action.encode(),
2272
2422
  )
2273
- coros.append(stream_manager.put(run_id, control_message))
2423
+ coros.append(stream_manager.put(run_id, thread_id, control_message))
2274
2424
 
2275
- queues = stream_manager.get_queues(run_id)
2425
+ queues = stream_manager.get_queues(run_id, thread_id)
2276
2426
 
2277
2427
  if run["status"] in ("pending", "running"):
2278
2428
  cancelable_runs.append(run)
@@ -2387,15 +2537,25 @@ class Runs(Authenticated):
2387
2537
  @staticmethod
2388
2538
  async def subscribe(
2389
2539
  run_id: UUID,
2540
+ thread_id: UUID | None = None,
2390
2541
  ) -> asyncio.Queue:
2391
2542
  """Subscribe to the run stream, returning a queue."""
2392
2543
  stream_manager = get_stream_manager()
2393
- queue = await stream_manager.add_queue(_ensure_uuid(run_id))
2544
+ queue = await stream_manager.add_queue(_ensure_uuid(run_id), thread_id)
2394
2545
 
2395
2546
  # If there's a control message already stored, send it to the new subscriber
2396
- if control_messages := stream_manager.control_queues.get(run_id):
2397
- for control_msg in control_messages:
2398
- await queue.put(control_msg)
2547
+ if thread_id is None:
2548
+ thread_id = THREADLESS_KEY
2549
+ if control_queues := stream_manager.control_queues.get(thread_id, {}).get(
2550
+ run_id
2551
+ ):
2552
+ for control_queue in control_queues:
2553
+ try:
2554
+ while True:
2555
+ control_msg = control_queue.get()
2556
+ await queue.put(control_msg)
2557
+ except asyncio.QueueEmpty:
2558
+ pass
2399
2559
  return queue
2400
2560
 
2401
2561
  @staticmethod
@@ -2417,7 +2577,7 @@ class Runs(Authenticated):
2417
2577
  queue = (
2418
2578
  stream_channel
2419
2579
  if stream_channel
2420
- else await Runs.Stream.subscribe(run_id)
2580
+ else await Runs.Stream.subscribe(run_id, thread_id)
2421
2581
  )
2422
2582
 
2423
2583
  try:
@@ -2440,7 +2600,7 @@ class Runs(Authenticated):
2440
2600
  run = await Runs.get(conn, run_id, thread_id=thread_id, ctx=ctx)
2441
2601
 
2442
2602
  for message in get_stream_manager().restore_messages(
2443
- run_id, last_event_id
2603
+ run_id, thread_id, last_event_id
2444
2604
  ):
2445
2605
  data, id = message.data, message.id
2446
2606
 
@@ -2531,7 +2691,7 @@ class Runs(Authenticated):
2531
2691
  raise
2532
2692
  finally:
2533
2693
  stream_manager = get_stream_manager()
2534
- await stream_manager.remove_queue(run_id, queue)
2694
+ await stream_manager.remove_queue(run_id, thread_id, queue)
2535
2695
 
2536
2696
  @staticmethod
2537
2697
  async def publish(
@@ -2539,6 +2699,7 @@ class Runs(Authenticated):
2539
2699
  event: str,
2540
2700
  message: bytes,
2541
2701
  *,
2702
+ thread_id: UUID | str | None = None,
2542
2703
  resumable: bool = False,
2543
2704
  ) -> None:
2544
2705
  """Publish a message to all subscribers of the run stream."""
@@ -2555,17 +2716,19 @@ class Runs(Authenticated):
2555
2716
  }
2556
2717
  )
2557
2718
  await stream_manager.put(
2558
- run_id, Message(topic=topic, data=payload), resumable
2719
+ run_id, thread_id, Message(topic=topic, data=payload), resumable
2559
2720
  )
2560
2721
 
2561
2722
 
2562
- async def listen_for_cancellation(queue: asyncio.Queue, run_id: UUID, done: ValueEvent):
2723
+ async def listen_for_cancellation(
2724
+ queue: asyncio.Queue, run_id: UUID, thread_id: UUID | None, done: ValueEvent
2725
+ ):
2563
2726
  """Listen for cancellation messages and set the done event accordingly."""
2564
2727
  from langgraph_api.errors import UserInterrupt, UserRollback
2565
2728
 
2566
2729
  stream_manager = get_stream_manager()
2567
2730
 
2568
- if control_key := stream_manager.get_control_key(run_id):
2731
+ if control_key := stream_manager.get_control_key(run_id, thread_id):
2569
2732
  payload = control_key.data
2570
2733
  if payload == b"rollback":
2571
2734
  done.set(UserRollback())
@@ -1,159 +0,0 @@
1
- import asyncio
2
- import logging
3
- from collections import defaultdict
4
- from collections.abc import Iterator
5
- from dataclasses import dataclass
6
- from uuid import UUID
7
-
8
- logger = logging.getLogger(__name__)
9
-
10
-
11
- def _ensure_uuid(id: str | UUID) -> UUID:
12
- return UUID(id) if isinstance(id, str) else id
13
-
14
-
15
- @dataclass
16
- class Message:
17
- topic: bytes
18
- data: bytes
19
- id: bytes | None = None
20
-
21
-
22
- class ContextQueue(asyncio.Queue):
23
- """Queue that supports async context manager protocol"""
24
-
25
- async def __aenter__(self):
26
- return self
27
-
28
- async def __aexit__(
29
- self,
30
- exc_type: type[BaseException] | None,
31
- exc_val: BaseException | None,
32
- exc_tb: object | None,
33
- ) -> None:
34
- # Clear the queue
35
- while not self.empty():
36
- try:
37
- self.get_nowait()
38
- except asyncio.QueueEmpty:
39
- break
40
-
41
-
42
- class StreamManager:
43
- def __init__(self):
44
- self.queues = defaultdict(list) # Dict[UUID, List[asyncio.Queue]]
45
- self.control_keys = defaultdict()
46
- self.control_queues = defaultdict(list)
47
-
48
- self.message_stores = defaultdict(list) # Dict[UUID, List[Message]]
49
- self.message_next_idx = defaultdict(int) # Dict[UUID, int]
50
-
51
- def get_queues(self, run_id: UUID | str) -> list[asyncio.Queue]:
52
- run_id = _ensure_uuid(run_id)
53
- return self.queues[run_id]
54
-
55
- def get_control_queues(self, run_id: UUID | str) -> list[asyncio.Queue]:
56
- run_id = _ensure_uuid(run_id)
57
- return self.control_queues[run_id]
58
-
59
- def get_control_key(self, run_id: UUID | str) -> Message | None:
60
- run_id = _ensure_uuid(run_id)
61
- return self.control_keys.get(run_id)
62
-
63
- async def put(
64
- self, run_id: UUID | str, message: Message, resumable: bool = False
65
- ) -> None:
66
- run_id = _ensure_uuid(run_id)
67
- message.id = str(self.message_next_idx[run_id]).encode()
68
- self.message_next_idx[run_id] += 1
69
- if resumable:
70
- self.message_stores[run_id].append(message)
71
- topic = message.topic.decode()
72
- if "control" in topic:
73
- self.control_keys[run_id] = message
74
- queues = self.control_queues[run_id]
75
- else:
76
- queues = self.queues[run_id]
77
- coros = [queue.put(message) for queue in queues]
78
- results = await asyncio.gather(*coros, return_exceptions=True)
79
- for result in results:
80
- if isinstance(result, Exception):
81
- logger.exception(f"Failed to put message in queue: {result}")
82
-
83
- async def add_queue(self, run_id: UUID | str) -> asyncio.Queue:
84
- run_id = _ensure_uuid(run_id)
85
- queue = ContextQueue()
86
- self.queues[run_id].append(queue)
87
- return queue
88
-
89
- async def add_control_queue(self, run_id: UUID | str) -> asyncio.Queue:
90
- run_id = _ensure_uuid(run_id)
91
- queue = ContextQueue()
92
- self.control_queues[run_id].append(queue)
93
- return queue
94
-
95
- async def remove_queue(self, run_id: UUID | str, queue: asyncio.Queue):
96
- run_id = _ensure_uuid(run_id)
97
- if run_id in self.queues:
98
- self.queues[run_id].remove(queue)
99
- if not self.queues[run_id]:
100
- del self.queues[run_id]
101
-
102
- async def remove_control_queue(self, run_id: UUID | str, queue: asyncio.Queue):
103
- run_id = _ensure_uuid(run_id)
104
- if run_id in self.control_queues:
105
- self.control_queues[run_id].remove(queue)
106
- if not self.control_queues[run_id]:
107
- del self.control_queues[run_id]
108
-
109
- def restore_messages(
110
- self, run_id: UUID | str, message_id: str | None
111
- ) -> Iterator[Message]:
112
- """Get a stored message by ID for resumable streams."""
113
- run_id = _ensure_uuid(run_id)
114
- message_idx = int(message_id) + 1 if message_id else None
115
-
116
- if message_idx is None:
117
- yield from []
118
- return
119
-
120
- if run_id in self.message_stores:
121
- yield from self.message_stores[run_id][message_idx:]
122
-
123
-
124
- # Global instance
125
- stream_manager = StreamManager()
126
-
127
-
128
- async def start_stream() -> None:
129
- """Initialize the queue system.
130
- In this in-memory implementation, we just need to ensure we have a clean StreamManager instance.
131
- """
132
- global stream_manager
133
- stream_manager = StreamManager()
134
-
135
-
136
- async def stop_stream() -> None:
137
- """Clean up the queue system.
138
- Clear all queues and stored control messages."""
139
- global stream_manager
140
-
141
- # Send 'done' message to all active queues before clearing
142
- for run_id in list(stream_manager.queues.keys()):
143
- control_message = Message(topic=f"run:{run_id}:control".encode(), data=b"done")
144
-
145
- for queue in stream_manager.queues[run_id]:
146
- try:
147
- await queue.put(control_message)
148
- except (Exception, RuntimeError):
149
- pass # Ignore errors during shutdown
150
-
151
- # Clear all stored data
152
- stream_manager.queues.clear()
153
- stream_manager.control_queues.clear()
154
- stream_manager.message_stores.clear()
155
-
156
-
157
- def get_stream_manager() -> StreamManager:
158
- """Get the global stream manager instance."""
159
- return stream_manager