langgraph-runtime-inmem 0.8.2__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langgraph_runtime_inmem/__init__.py
@@ -9,7 +9,7 @@ from langgraph_runtime_inmem import (
     store,
 )
 
-__version__ = "0.8.2"
+__version__ = "0.10.0"
 __all__ = [
     "ops",
     "database",
langgraph_runtime_inmem/inmem_stream.py
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+import time
 from collections import defaultdict
 from collections.abc import Iterator
 from dataclasses import dataclass
@@ -12,6 +13,14 @@ def _ensure_uuid(id: str | UUID) -> UUID:
     return UUID(id) if isinstance(id, str) else id
 
 
+def _generate_ms_seq_id() -> str:
+    """Generate a Redis-like millisecond-sequence ID (e.g., '1234567890123-0')"""
+    # Get current time in milliseconds
+    ms = int(time.time() * 1000)
+    # For simplicity, always use sequence 0 since we're not handling high throughput
+    return f"{ms}-0"
+
+
 @dataclass
 class Message:
     topic: bytes
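
Note: this helper replaces the per-run integer counter that 0.8.2 used for message IDs. A minimal standalone sketch of the ID shape (no package import, illustrative names only): IDs of this form keep the same digit width until the year 2286, so they order correctly under plain string comparison, which is what the reworked restore_messages further down relies on.

    import time

    def generate_ms_seq_id() -> str:
        # Same shape as the package helper: "<epoch ms>-<sequence>", sequence fixed at 0.
        return f"{int(time.time() * 1000)}-0"

    # A resume cursor from some earlier event (September 2020 in this example).
    last_seen = "1600000000000-0"
    # Any ID generated now sorts after it under ordinary string comparison,
    # the same comparison restore_messages applies to message.id.decode().
    assert generate_ms_seq_id() > last_seen
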
@@ -39,86 +48,186 @@ class ContextQueue(asyncio.Queue):
                 break
 
 
-class StreamManager:
-    def __init__(self):
-        self.queues = defaultdict(list)  # Dict[UUID, List[asyncio.Queue]]
-        self.control_keys = defaultdict()
-        self.control_queues = defaultdict(list)
+THREADLESS_KEY = "no-thread"
 
-        self.message_stores = defaultdict(list)  # Dict[UUID, List[Message]]
-        self.message_next_idx = defaultdict(int)  # Dict[UUID, int]
 
-    def get_queues(self, run_id: UUID | str) -> list[asyncio.Queue]:
+class StreamManager:
+    def __init__(self):
+        self.queues = defaultdict(
+            lambda: defaultdict(list)
+        )  # Dict[str, List[asyncio.Queue]]
+        self.control_keys = defaultdict(lambda: defaultdict())
+        self.control_queues = defaultdict(lambda: defaultdict(list))
+        self.thread_streams = defaultdict(list)
+
+        self.message_stores = defaultdict(
+            lambda: defaultdict(list[Message])
+        )  # Dict[str, List[Message]]
+
+    def get_queues(
+        self, run_id: UUID | str, thread_id: UUID | str | None
+    ) -> list[asyncio.Queue]:
         run_id = _ensure_uuid(run_id)
-        return self.queues[run_id]
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+        return self.queues[thread_id][run_id]
 
-    def get_control_queues(self, run_id: UUID | str) -> list[asyncio.Queue]:
+    def get_control_queues(
+        self, run_id: UUID | str, thread_id: UUID | str | None
+    ) -> list[asyncio.Queue]:
         run_id = _ensure_uuid(run_id)
-        return self.control_queues[run_id]
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+        return self.control_queues[thread_id][run_id]
 
-    def get_control_key(self, run_id: UUID | str) -> Message | None:
+    def get_control_key(
+        self, run_id: UUID | str, thread_id: UUID | str | None
+    ) -> Message | None:
         run_id = _ensure_uuid(run_id)
-        return self.control_keys.get(run_id)
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+        return self.control_keys.get(thread_id, {}).get(run_id)
 
     async def put(
-        self, run_id: UUID | str, message: Message, resumable: bool = False
+        self,
+        run_id: UUID | str | None,
+        thread_id: UUID | str | None,
+        message: Message,
+        resumable: bool = False,
     ) -> None:
         run_id = _ensure_uuid(run_id)
-        message.id = str(self.message_next_idx[run_id]).encode()
-        self.message_next_idx[run_id] += 1
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+
+        message.id = _generate_ms_seq_id().encode()
         if resumable:
-            self.message_stores[run_id].append(message)
+            self.message_stores[thread_id][run_id].append(message)
         topic = message.topic.decode()
         if "control" in topic:
-            self.control_keys[run_id] = message
-            queues = self.control_queues[run_id]
+            self.control_keys[thread_id][run_id] = message
+            queues = self.control_queues[thread_id][run_id]
         else:
-            queues = self.queues[run_id]
+            queues = self.queues[thread_id][run_id]
         coros = [queue.put(message) for queue in queues]
         results = await asyncio.gather(*coros, return_exceptions=True)
         for result in results:
             if isinstance(result, Exception):
                 logger.exception(f"Failed to put message in queue: {result}")
 
-    async def add_queue(self, run_id: UUID | str) -> asyncio.Queue:
+    async def put_thread(
+        self,
+        thread_id: UUID | str,
+        message: Message,
+    ) -> None:
+        thread_id = _ensure_uuid(thread_id)
+        message.id = _generate_ms_seq_id().encode()
+        queues = self.thread_streams[thread_id]
+        coros = [queue.put(message) for queue in queues]
+        results = await asyncio.gather(*coros, return_exceptions=True)
+        for result in results:
+            if isinstance(result, Exception):
+                logger.exception(f"Failed to put message in queue: {result}")
+
+    async def add_queue(
+        self, run_id: UUID | str, thread_id: UUID | str | None
+    ) -> asyncio.Queue:
         run_id = _ensure_uuid(run_id)
         queue = ContextQueue()
-        self.queues[run_id].append(queue)
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+        self.queues[thread_id][run_id].append(queue)
         return queue
 
-    async def add_control_queue(self, run_id: UUID | str) -> asyncio.Queue:
+    async def add_control_queue(
+        self, run_id: UUID | str, thread_id: UUID | str | None
+    ) -> asyncio.Queue:
         run_id = _ensure_uuid(run_id)
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
         queue = ContextQueue()
-        self.control_queues[run_id].append(queue)
+        self.control_queues[thread_id][run_id].append(queue)
         return queue
 
-    async def remove_queue(self, run_id: UUID | str, queue: asyncio.Queue):
-        run_id = _ensure_uuid(run_id)
-        if run_id in self.queues:
-            self.queues[run_id].remove(queue)
-            if not self.queues[run_id]:
-                del self.queues[run_id]
+    async def add_thread_stream(self, thread_id: UUID | str) -> asyncio.Queue:
+        thread_id = _ensure_uuid(thread_id)
+        queue = ContextQueue()
+        self.thread_streams[thread_id].append(queue)
+        return queue
 
-    async def remove_control_queue(self, run_id: UUID | str, queue: asyncio.Queue):
+    async def remove_queue(
+        self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
+    ):
+        run_id = _ensure_uuid(run_id)
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+        if thread_id in self.queues and run_id in self.queues[thread_id]:
+            self.queues[thread_id][run_id].remove(queue)
+            if not self.queues[thread_id][run_id]:
+                del self.queues[thread_id][run_id]
+
+    async def remove_control_queue(
+        self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
+    ):
         run_id = _ensure_uuid(run_id)
-        if run_id in self.control_queues:
-            self.control_queues[run_id].remove(queue)
-            if not self.control_queues[run_id]:
-                del self.control_queues[run_id]
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+        if (
+            thread_id in self.control_queues
+            and run_id in self.control_queues[thread_id]
+        ):
+            self.control_queues[thread_id][run_id].remove(queue)
+            if not self.control_queues[thread_id][run_id]:
+                del self.control_queues[thread_id][run_id]
 
     def restore_messages(
-        self, run_id: UUID | str, message_id: str | None
+        self, run_id: UUID | str, thread_id: UUID | str | None, message_id: str | None
    ) -> Iterator[Message]:
         """Get a stored message by ID for resumable streams."""
         run_id = _ensure_uuid(run_id)
-        message_idx = int(message_id) + 1 if message_id else None
-
-        if message_idx is None:
-            yield from []
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+        if message_id is None:
             return
-
-        if run_id in self.message_stores:
-            yield from self.message_stores[run_id][message_idx:]
+        try:
+            # Handle ms-seq format (e.g., "1234567890123-0")
+            if thread_id in self.message_stores:
+                for message in self.message_stores[thread_id][run_id]:
+                    if message.id.decode() > message_id:
+                        yield message
+        except TypeError:
+            # Try integer format if ms-seq fails
+            message_idx = int(message_id) + 1
+            if run_id in self.message_stores:
+                yield from self.message_stores[thread_id][run_id][message_idx:]
+
+    def get_queues_by_thread_id(self, thread_id: UUID | str) -> list[asyncio.Queue]:
+        """Get all queues for a specific thread_id across all runs."""
+        all_queues = []
+        # Search through all stored queue keys for ones ending with the thread_id
+        thread_id = _ensure_uuid(thread_id)
+        if thread_id in self.queues:
+            for run_id in self.queues[thread_id]:
+                all_queues.extend(self.queues[thread_id][run_id])
+
+        return all_queues
 
 
 # Global instance
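
Note: every StreamManager lookup is now keyed by thread first and run second, with thread-less runs bucketed under THREADLESS_KEY ("no-thread"). A minimal sketch of the new subscribe/publish round trip, assuming the manager is exercised directly rather than through the ops layer as the server normally does; the printed ID is the ms-seq format rather than the old integer counter.

    import asyncio
    from uuid import uuid4

    from langgraph_runtime_inmem.inmem_stream import Message, get_stream_manager

    async def main() -> None:
        manager = get_stream_manager()
        run_id, thread_id = uuid4(), uuid4()

        # Queues are bucketed per thread, then per run; passing thread_id=None
        # would fall back to the THREADLESS_KEY bucket instead.
        queue = await manager.add_queue(run_id, thread_id)

        await manager.put(
            run_id,
            thread_id,
            Message(topic=f"run:{run_id}:stream".encode(), data=b"hello"),
            resumable=True,
        )

        message = await queue.get()
        print(message.id)  # e.g. b"1755012345678-0"

        await manager.remove_queue(run_id, thread_id, queue)

    asyncio.run(main())
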
langgraph_runtime_inmem/ops.py
@@ -27,7 +27,12 @@ from starlette.exceptions import HTTPException
 
 from langgraph_runtime_inmem.checkpoint import Checkpointer
 from langgraph_runtime_inmem.database import InMemConnectionProto, connect
-from langgraph_runtime_inmem.inmem_stream import Message, get_stream_manager
+from langgraph_runtime_inmem.inmem_stream import (
+    THREADLESS_KEY,
+    ContextQueue,
+    Message,
+    get_stream_manager,
+)
 
 if typing.TYPE_CHECKING:
     from langgraph_api.asyncio import ValueEvent
@@ -54,6 +59,7 @@ if typing.TYPE_CHECKING:
         Thread,
         ThreadSelectField,
         ThreadStatus,
+        ThreadStreamMode,
         ThreadUpdateResponse,
     )
     from langgraph_api.schema import Interrupt as InterruptSchema
@@ -734,6 +740,7 @@ class Threads(Authenticated):
     async def search(
         conn: InMemConnectionProto,
         *,
+        ids: list[str] | list[UUID] | None = None,
         metadata: MetadataInput,
         values: MetadataInput,
         status: ThreadStatus | None,
@@ -761,7 +768,19 @@ class Threads(Authenticated):
         )
 
         # Apply filters
+        id_set: set[UUID] | None = None
+        if ids:
+            id_set = set()
+            for i in ids:
+                try:
+                    id_set.add(_ensure_uuid(i))
+                except Exception:
+                    raise HTTPException(
+                        status_code=400, detail="Invalid thread ID " + str(i)
+                    ) from None
         for thread in threads:
+            if id_set is not None and thread.get("thread_id") not in id_set:
+                continue
             if filters and not _check_filter_match(thread["metadata"], filters):
                 continue
 
@@ -1323,7 +1342,14 @@ class Threads(Authenticated):
         )
 
         metadata = thread.get("metadata", {})
-        thread_config = thread.get("config", {})
+        thread_config = cast(dict[str, Any], thread.get("config", {}))
+        thread_config = {
+            **thread_config,
+            "configurable": {
+                **thread_config.get("configurable", {}),
+                **config.get("configurable", {}),
+            },
+        }
 
         # Fallback to graph_id from run if not in thread metadata
         graph_id = metadata.get("graph_id")
@@ -1410,6 +1436,13 @@ class Threads(Authenticated):
                 status_code=409,
                 detail=f"Thread {thread_id} has in-flight runs: {pending_runs}",
             )
+        thread_config = {
+            **thread_config,
+            "configurable": {
+                **thread_config.get("configurable", {}),
+                **config.get("configurable", {}),
+            },
+        }
 
         # Fallback to graph_id from run if not in thread metadata
         graph_id = metadata.get("graph_id")
@@ -1450,6 +1483,19 @@ class Threads(Authenticated):
                 thread["values"] = state.values
                 break
 
+        # Publish state update event
+        from langgraph_api.serde import json_dumpb
+
+        event_data = {
+            "state": state,
+            "thread_id": str(thread_id),
+        }
+        await Threads.Stream.publish(
+            thread_id,
+            "state_update",
+            json_dumpb(event_data),
+        )
+
         return ThreadUpdateResponse(
             checkpoint=next_config["configurable"],
             # Including deprecated fields
@@ -1492,7 +1538,14 @@ class Threads(Authenticated):
             thread_iter, not_found_detail=f"Thread {thread_id} not found."
         )
 
-        thread_config = thread["config"]
+        thread_config = cast(dict[str, Any], thread["config"])
+        thread_config = {
+            **thread_config,
+            "configurable": {
+                **thread_config.get("configurable", {}),
+                **config.get("configurable", {}),
+            },
+        }
         metadata = thread["metadata"]
 
         if not thread:
@@ -1539,6 +1592,19 @@ class Threads(Authenticated):
                 thread["values"] = state.values
                 break
 
+        # Publish state update event
+        from langgraph_api.serde import json_dumpb
+
+        event_data = {
+            "state": state,
+            "thread_id": str(thread_id),
+        }
+        await Threads.Stream.publish(
+            thread_id,
+            "state_update",
+            json_dumpb(event_data),
+        )
+
         return ThreadUpdateResponse(
             checkpoint=next_config["configurable"],
         )
@@ -1580,7 +1646,14 @@ class Threads(Authenticated):
         if not _check_filter_match(thread_metadata, filters):
             return []
 
-        thread_config = thread["config"]
+        thread_config = cast(dict[str, Any], thread["config"])
+        thread_config = {
+            **thread_config,
+            "configurable": {
+                **thread_config.get("configurable", {}),
+                **config.get("configurable", {}),
+            },
+        }
         # If graph_id exists, get state history
         if graph_id := thread_metadata.get("graph_id"):
             async with get_graph(
@@ -1609,6 +1682,222 @@
 
         return []
 
+    class Stream:
+        @staticmethod
+        async def subscribe(
+            conn: InMemConnectionProto | AsyncConnectionProto,
+            thread_id: UUID,
+            seen_runs: set[UUID],
+        ) -> list[tuple[UUID, asyncio.Queue]]:
+            """Subscribe to the thread stream, creating queues for unseen runs."""
+            stream_manager = get_stream_manager()
+            queues = []
+
+            # Create new queues only for runs not yet seen
+            thread_id = _ensure_uuid(thread_id)
+
+            # Add thread stream queue
+            if thread_id not in seen_runs:
+                queue = await stream_manager.add_thread_stream(thread_id)
+                queues.append((thread_id, queue))
+                seen_runs.add(thread_id)
+
+            for run in conn.store["runs"]:
+                if run["thread_id"] == thread_id:
+                    run_id = run["run_id"]
+                    if run_id not in seen_runs:
+                        queue = await stream_manager.add_queue(run_id, thread_id)
+                        queues.append((run_id, queue))
+                        seen_runs.add(run_id)
+
+            return queues
+
+        @staticmethod
+        async def join(
+            thread_id: UUID,
+            *,
+            last_event_id: str | None = None,
+            stream_modes: list[ThreadStreamMode],
+        ) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
+            """Stream the thread output."""
+
+            def should_filter_event(event_name: str, message_bytes: bytes) -> bool:
+                """Check if an event should be filtered out based on stream_modes."""
+                if "run_modes" in stream_modes and event_name != "state_update":
+                    return False
+                if "state_update" in stream_modes and event_name == "state_update":
+                    return False
+                if "lifecycle" in stream_modes and event_name == "metadata":
+                    try:
+                        message_data = orjson.loads(message_bytes)
+                        if message_data.get("status") == "run_done":
+                            return False
+                        if "attempt" in message_data and "run_id" in message_data:
+                            return False
+                    except (orjson.JSONDecodeError, TypeError):
+                        pass
+                return True
+
+            from langgraph_api.serde import json_loads
+
+            stream_manager = get_stream_manager()
+            seen_runs: set[UUID] = set()
+            created_queues: list[tuple[UUID, asyncio.Queue]] = []
+
+            try:
+                async with connect() as conn:
+                    await logger.ainfo(
+                        "Joined thread stream",
+                        thread_id=str(thread_id),
+                    )
+
+                    # Restore messages if resuming from a specific event
+                    if last_event_id is not None:
+                        # Collect all events from all message stores for this thread
+                        all_events = []
+                        for run_id in stream_manager.message_stores.get(
+                            str(thread_id), []
+                        ):
+                            for message in stream_manager.restore_messages(
+                                run_id, thread_id, last_event_id
+                            ):
+                                all_events.append((message, run_id))
+
+                        # Sort by message ID (which is ms-seq format)
+                        all_events.sort(key=lambda x: x[0].id.decode())
+
+                        # Yield sorted events
+                        for message, run_id in all_events:
+                            data = json_loads(message.data)
+                            event_name = data["event"]
+                            message_content = data["message"]
+
+                            if event_name == "control":
+                                if message_content == b"done":
+                                    event_bytes = b"metadata"
+                                    message_bytes = orjson.dumps(
+                                        {"status": "run_done", "run_id": run_id}
+                                    )
+                                    # Filter events based on stream_modes
+                                    if not should_filter_event(
+                                        "metadata", message_bytes
+                                    ):
+                                        yield (
+                                            event_bytes,
+                                            message_bytes,
+                                            message.id,
+                                        )
+                            else:
+                                event_bytes = event_name.encode()
+                                message_bytes = base64.b64decode(message_content)
+                                # Filter events based on stream_modes
+                                if not should_filter_event(event_name, message_bytes):
+                                    yield (
+                                        event_bytes,
+                                        message_bytes,
+                                        message.id,
+                                    )
+
+                    # Listen for live messages from all queues
+                    while True:
+                        # Refresh queues to pick up any new runs that joined this thread
+                        new_queue_tuples = await Threads.Stream.subscribe(
+                            conn, thread_id, seen_runs
+                        )
+                        # Track new queues for cleanup
+                        for run_id, queue in new_queue_tuples:
+                            created_queues.append((run_id, queue))
+
+                        for run_id, queue in created_queues:
+                            try:
+                                message = await asyncio.wait_for(
+                                    queue.get(), timeout=0.2
+                                )
+                                data = json_loads(message.data)
+                                event_name = data["event"]
+                                message_content = data["message"]
+
+                                if event_name == "control":
+                                    if message_content == b"done":
+                                        # Extract run_id from topic
+                                        topic = message.topic.decode()
+                                        run_id = topic.split("run:")[1].split(":")[0]
+                                        event_bytes = b"metadata"
+                                        message_bytes = orjson.dumps(
+                                            {"status": "run_done", "run_id": run_id}
+                                        )
+                                        # Filter events based on stream_modes
+                                        if not should_filter_event(
+                                            "metadata", message_bytes
+                                        ):
+                                            yield (
+                                                event_bytes,
+                                                message_bytes,
+                                                message.id,
+                                            )
+                                else:
+                                    event_bytes = event_name.encode()
+                                    message_bytes = base64.b64decode(message_content)
+                                    # Filter events based on stream_modes
+                                    if not should_filter_event(
+                                        event_name, message_bytes
+                                    ):
+                                        yield (
+                                            event_bytes,
+                                            message_bytes,
+                                            message.id,
+                                        )
+
+                            except TimeoutError:
+                                continue
+                            except (ValueError, KeyError):
+                                continue
+
+                        # Yield execution to other tasks to prevent event loop starvation
+                        await asyncio.sleep(0)
+
+            except WrappedHTTPException as e:
+                raise e.http_exception from None
+            except asyncio.CancelledError:
+                await logger.awarning(
+                    "Thread stream client disconnected",
+                    thread_id=str(thread_id),
+                )
+                raise
+            except:
+                raise
+            finally:
+                # Clean up all created queues
+                for run_id, queue in created_queues:
+                    try:
+                        await stream_manager.remove_queue(run_id, thread_id, queue)
+                    except Exception:
+                        # Ignore cleanup errors
+                        pass
+
+        @staticmethod
+        async def publish(
+            thread_id: UUID | str,
+            event: str,
+            message: bytes,
+        ) -> None:
+            """Publish a thread-level event to the thread stream."""
+            from langgraph_api.serde import json_dumpb
+
+            topic = f"thread:{thread_id}:stream".encode()
+
+            stream_manager = get_stream_manager()
+            # Send to thread stream topic
+            payload = json_dumpb(
+                {
+                    "event": event,
+                    "message": message,
+                }
+            )
+            await stream_manager.put_thread(
+                str(thread_id), Message(topic=topic, data=payload)
+            )
+
     @staticmethod
     async def count(
         conn: InMemConnectionProto,
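
Note: the new Threads.Stream class gives each thread a fan-out channel on top of the per-run queues: join() first replays stored events newer than last_event_id, then polls every queue attached to the thread and yields (event, payload, event_id) byte tuples, and publish() is what the state-update changes earlier in this diff call. A minimal consumer sketch, assuming the surrounding server wiring (SSE framing, auth) lives elsewhere and taking the stream-mode strings from the filter logic above:

    from uuid import UUID

    from langgraph_runtime_inmem.ops import Threads

    async def relay(thread_id: UUID) -> None:
        # Runs until cancelled: join() loops forever, picking up new runs as they appear.
        async for event, payload, event_id in Threads.Stream.join(
            thread_id,
            last_event_id=None,  # or a previously seen ms-seq ID to resume
            stream_modes=["run_modes", "lifecycle"],
        ):
            # event/payload arrive as bytes; event_id feeds Last-Event-ID on reconnect.
            print(event.decode(), payload[:60], event_id)
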
@@ -1767,7 +2056,7 @@ class Runs(Authenticated):
     @asynccontextmanager
     @staticmethod
     async def enter(
-        run_id: UUID, loop: asyncio.AbstractEventLoop
+        run_id: UUID, thread_id: UUID | None, loop: asyncio.AbstractEventLoop
     ) -> AsyncIterator[ValueEvent]:
         """Enter a run, listen for cancellation while running, signal when done."
         This method should be called as a context manager by a worker executing a run.
@@ -1775,12 +2064,14 @@
         from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent
 
         stream_manager = get_stream_manager()
-        # Get queue for this run
-        queue = await stream_manager.add_control_queue(run_id)
+        # Get control queue for this run (normal queue is created during run creation)
+        control_queue = await stream_manager.add_control_queue(run_id, thread_id)
 
         async with SimpleTaskGroup(cancel=True, taskgroup_name="Runs.enter") as tg:
             done = ValueEvent()
-            tg.create_task(listen_for_cancellation(queue, run_id, done))
+            tg.create_task(
+                listen_for_cancellation(control_queue, run_id, thread_id, done)
+            )
 
             # Give done event to caller
             yield done
@@ -1788,17 +2079,17 @@
             control_message = Message(
                 topic=f"run:{run_id}:control".encode(), data=b"done"
             )
-            await stream_manager.put(run_id, control_message)
+            await stream_manager.put(run_id, thread_id, control_message)
 
             # Signal done to all subscribers
             stream_message = Message(
                 topic=f"run:{run_id}:stream".encode(),
                 data={"event": "control", "message": b"done"},
             )
-            await stream_manager.put(run_id, stream_message)
+            await stream_manager.put(run_id, thread_id, stream_message)
 
-            # Remove the queue
-            await stream_manager.remove_control_queue(run_id, queue)
+            # Remove the control_queue (normal queue is cleaned up during run deletion)
+            await stream_manager.remove_control_queue(run_id, thread_id, control_queue)
 
     @staticmethod
     async def sweep() -> None:
@@ -1853,6 +2144,7 @@
         run_id = _ensure_uuid(run_id) if run_id else None
         metadata = metadata if metadata is not None else {}
         config = kwargs.get("config", {})
+        temporary = kwargs.get("temporary", False)
 
         # Handle thread creation/update
         existing_thread = next(
@@ -1862,7 +2154,7 @@
             ctx,
             "create_run",
             Auth.types.RunsCreate(
-                thread_id=thread_id,
+                thread_id=None if temporary else thread_id,
                 assistant_id=assistant_id,
                 run_id=run_id,
                 status=status,
2086
2378
  if not thread:
2087
2379
  return _empty_generator()
2088
2380
  _delete_checkpoints_for_thread(thread_id, conn, run_id=run_id)
2381
+
2089
2382
  found = False
2090
2383
  for i, run in enumerate(conn.store["runs"]):
2091
2384
  if run["run_id"] == run_id and run["thread_id"] == thread_id:
@@ -2268,9 +2561,9 @@
                 topic=f"run:{run_id}:control".encode(),
                 data=action.encode(),
             )
-            coros.append(stream_manager.put(run_id, control_message))
+            coros.append(stream_manager.put(run_id, thread_id, control_message))
 
-            queues = stream_manager.get_queues(run_id)
+            queues = stream_manager.get_queues(run_id, thread_id)
 
             if run["status"] in ("pending", "running"):
                 cancelable_runs.append(run)
@@ -2385,15 +2678,25 @@
         @staticmethod
         async def subscribe(
             run_id: UUID,
-        ) -> asyncio.Queue:
+            thread_id: UUID | None = None,
+        ) -> ContextQueue:
             """Subscribe to the run stream, returning a queue."""
             stream_manager = get_stream_manager()
-            queue = await stream_manager.add_queue(_ensure_uuid(run_id))
+            queue = await stream_manager.add_queue(_ensure_uuid(run_id), thread_id)
 
             # If there's a control message already stored, send it to the new subscriber
-            if control_messages := stream_manager.control_queues.get(run_id):
-                for control_msg in control_messages:
-                    await queue.put(control_msg)
+            if thread_id is None:
+                thread_id = THREADLESS_KEY
+            if control_queues := stream_manager.control_queues.get(thread_id, {}).get(
+                run_id
+            ):
+                for control_queue in control_queues:
+                    try:
+                        while True:
+                            control_msg = control_queue.get()
+                            await queue.put(control_msg)
+                    except asyncio.QueueEmpty:
+                        pass
             return queue
 
         @staticmethod
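
Note: Runs.Stream.subscribe() now takes an optional thread_id and returns a ContextQueue rather than a bare asyncio.Queue; omitting thread_id files the subscription under the "no-thread" bucket. A minimal consumer sketch, assuming the run already exists and that queue cleanup (stream_manager.remove_queue) is handled by the caller, as Runs.Stream.join does in its finally block:

    import asyncio
    from uuid import UUID

    from langgraph_runtime_inmem.ops import Runs

    async def drain(run_id: UUID, thread_id: UUID | None) -> None:
        queue = await Runs.Stream.subscribe(run_id, thread_id)
        while True:
            try:
                message = await asyncio.wait_for(queue.get(), timeout=1.0)
            except TimeoutError:
                break  # no traffic for a second; stop for this sketch
            print(message.topic, message.data)
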
@@ -2415,7 +2718,7 @@
             queue = (
                 stream_channel
                 if stream_channel
-                else await Runs.Stream.subscribe(run_id)
+                else await Runs.Stream.subscribe(run_id, thread_id)
             )
 
             try:
@@ -2438,7 +2741,7 @@
                 run = await Runs.get(conn, run_id, thread_id=thread_id, ctx=ctx)
 
                 for message in get_stream_manager().restore_messages(
-                    run_id, last_event_id
+                    run_id, thread_id, last_event_id
                 ):
                     data, id = message.data, message.id
 
@@ -2529,7 +2832,7 @@
                 raise
             finally:
                 stream_manager = get_stream_manager()
-                await stream_manager.remove_queue(run_id, queue)
+                await stream_manager.remove_queue(run_id, thread_id, queue)
 
         @staticmethod
         async def publish(
  async def publish(
@@ -2537,6 +2840,7 @@ class Runs(Authenticated):
2537
2840
  event: str,
2538
2841
  message: bytes,
2539
2842
  *,
2843
+ thread_id: UUID | str | None = None,
2540
2844
  resumable: bool = False,
2541
2845
  ) -> None:
2542
2846
  """Publish a message to all subscribers of the run stream."""
@@ -2553,17 +2857,19 @@
                 }
             )
             await stream_manager.put(
-                run_id, Message(topic=topic, data=payload), resumable
+                run_id, thread_id, Message(topic=topic, data=payload), resumable
             )
 
 
-async def listen_for_cancellation(queue: asyncio.Queue, run_id: UUID, done: ValueEvent):
+async def listen_for_cancellation(
+    queue: asyncio.Queue, run_id: UUID, thread_id: UUID | None, done: ValueEvent
+):
     """Listen for cancellation messages and set the done event accordingly."""
     from langgraph_api.errors import UserInterrupt, UserRollback
 
     stream_manager = get_stream_manager()
 
-    if control_key := stream_manager.get_control_key(run_id):
+    if control_key := stream_manager.get_control_key(run_id, thread_id):
         payload = control_key.data
         if payload == b"rollback":
             done.set(UserRollback())
langgraph_runtime_inmem-0.8.2.dist-info/METADATA → langgraph_runtime_inmem-0.10.0.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: langgraph-runtime-inmem
-Version: 0.8.2
+Version: 0.10.0
 Summary: Inmem implementation for the LangGraph API server.
 Author-email: Will Fu-Hinthorn <will@langchain.dev>
 License: Elastic-2.0
langgraph_runtime_inmem-0.8.2.dist-info/RECORD → langgraph_runtime_inmem-0.10.0.dist-info/RECORD
@@ -1,13 +1,13 @@
-langgraph_runtime_inmem/__init__.py,sha256=AgGhozyDAnBy1osTaFV8oxSvO7Is7Rx0ASBE6XpUMDE,310
+langgraph_runtime_inmem/__init__.py,sha256=4xhdO3o6RduCHDXSNh42I51Wwq7Kcnt3JK1U1IhP-BU,311
 langgraph_runtime_inmem/checkpoint.py,sha256=nc1G8DqVdIu-ibjKTqXfbPfMbAsKjPObKqegrSzo6Po,4432
 langgraph_runtime_inmem/database.py,sha256=QgaA_WQo1IY6QioYd8r-e6-0B0rnC5anS0muIEJWby0,6364
-langgraph_runtime_inmem/inmem_stream.py,sha256=UWk1srLF44HZPPbRdArGGhsy0MY0UOJKSIxBSO7Hosc,5138
+langgraph_runtime_inmem/inmem_stream.py,sha256=utL1OlOJsy6VDkSGAA6eX9nETreZlM6K6nhfNoubmRQ,9011
 langgraph_runtime_inmem/lifespan.py,sha256=t0w2MX2dGxe8yNtSX97Z-d2pFpllSLS4s1rh2GJDw5M,3557
 langgraph_runtime_inmem/metrics.py,sha256=HhO0RC2bMDTDyGBNvnd2ooLebLA8P1u5oq978Kp_nAA,392
-langgraph_runtime_inmem/ops.py,sha256=Qtu4rSNd6uFjpYrrLxXMFbDDK4Z7PMYtqcRK_ikPmQA,97862
+langgraph_runtime_inmem/ops.py,sha256=54jiyWhfbSu9z9pca6AQdNuaIBmD0WMrQ7xGQcLPDF4,111183
 langgraph_runtime_inmem/queue.py,sha256=33qfFKPhQicZ1qiibllYb-bTFzUNSN2c4bffPACP5es,9952
 langgraph_runtime_inmem/retry.py,sha256=XmldOP4e_H5s264CagJRVnQMDFcEJR_dldVR1Hm5XvM,763
 langgraph_runtime_inmem/store.py,sha256=rTfL1JJvd-j4xjTrL8qDcynaWF6gUJ9-GDVwH0NBD_I,3506
-langgraph_runtime_inmem-0.8.2.dist-info/METADATA,sha256=gcjizfhnsByvf34DwFigAdThjuyZxVD09uMbA1yQB9U,565
-langgraph_runtime_inmem-0.8.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-langgraph_runtime_inmem-0.8.2.dist-info/RECORD,,
+langgraph_runtime_inmem-0.10.0.dist-info/METADATA,sha256=gdjdQjZF2KjDtwA9rDiW53pG4FYNfv8TkT1U8t2lftQ,566
+langgraph_runtime_inmem-0.10.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+langgraph_runtime_inmem-0.10.0.dist-info/RECORD,,