langgraph-runtime-inmem 0.8.2__tar.gz → 0.9.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/PKG-INFO +1 -1
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/__init__.py +1 -1
- langgraph_runtime_inmem-0.9.0/langgraph_runtime_inmem/inmem_stream.py +247 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/ops.py +186 -21
- langgraph_runtime_inmem-0.8.2/langgraph_runtime_inmem/inmem_stream.py +0 -159
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/.gitignore +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/Makefile +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/README.md +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/checkpoint.py +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/database.py +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/lifespan.py +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/metrics.py +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/queue.py +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/retry.py +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/store.py +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/pyproject.toml +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/uv.lock +0 -0
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
import time
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from collections.abc import Iterator
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from uuid import UUID
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _ensure_uuid(id: str | UUID) -> UUID:
|
|
13
|
+
return UUID(id) if isinstance(id, str) else id
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _generate_ms_seq_id() -> str:
|
|
17
|
+
"""Generate a Redis-like millisecond-sequence ID (e.g., '1234567890123-0')"""
|
|
18
|
+
# Get current time in milliseconds
|
|
19
|
+
ms = int(time.time() * 1000)
|
|
20
|
+
# For simplicity, always use sequence 0 since we're not handling high throughput
|
|
21
|
+
return f"{ms}-0"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
|
|
25
|
+
class Message:
|
|
26
|
+
topic: bytes
|
|
27
|
+
data: bytes
|
|
28
|
+
id: bytes | None = None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ContextQueue(asyncio.Queue):
|
|
32
|
+
"""Queue that supports async context manager protocol"""
|
|
33
|
+
|
|
34
|
+
async def __aenter__(self):
|
|
35
|
+
return self
|
|
36
|
+
|
|
37
|
+
async def __aexit__(
|
|
38
|
+
self,
|
|
39
|
+
exc_type: type[BaseException] | None,
|
|
40
|
+
exc_val: BaseException | None,
|
|
41
|
+
exc_tb: object | None,
|
|
42
|
+
) -> None:
|
|
43
|
+
# Clear the queue
|
|
44
|
+
while not self.empty():
|
|
45
|
+
try:
|
|
46
|
+
self.get_nowait()
|
|
47
|
+
except asyncio.QueueEmpty:
|
|
48
|
+
break
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
THREADLESS_KEY = "no-thread"
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class StreamManager:
    """In-memory registry of stream queues, control queues and replayable messages.

    Every mapping is keyed first by thread and then by run. Runs that have no
    thread are filed under ``THREADLESS_KEY``.
    """

    def __init__(self):
        # thread key -> run_id -> queues receiving non-control messages
        self.queues = defaultdict(lambda: defaultdict(list))
        # thread key -> run_id -> most recent control message
        self.control_keys = defaultdict(dict)
        # thread key -> run_id -> queues receiving control messages
        self.control_queues = defaultdict(lambda: defaultdict(list))
        # thread key -> run_id -> messages kept for resumable streams
        self.message_stores = defaultdict(lambda: defaultdict(list))

    @staticmethod
    def _keys(run_id: UUID | str, thread_id: UUID | str | None):
        """Return the normalised ``(thread_key, run_key)`` pair used by every mapping."""
        run_key = _ensure_uuid(run_id)
        thread_key = THREADLESS_KEY if thread_id is None else _ensure_uuid(thread_id)
        return thread_key, run_key

    @staticmethod
    def _parse_ms_seq(value: str) -> tuple[int, int]:
        """Parse ``'<ms>-<seq>'`` (or a bare ``'<ms>'``) into a comparable tuple.

        Raises:
            ValueError: if either part is not an integer.
        """
        ms, _, seq = value.partition("-")
        return int(ms), int(seq) if seq else 0

    @staticmethod
    def _discard(registry, thread_key, run_key, queue: asyncio.Queue) -> None:
        """Remove *queue* from ``registry[thread_key][run_key]`` and prune empty levels.

        Does nothing when the thread or run is unknown. As before, removing a
        queue that is not registered for a known run raises ``ValueError``.
        """
        runs = registry.get(thread_key)
        if runs is None or run_key not in runs:
            return
        runs[run_key].remove(queue)
        if not runs[run_key]:
            del runs[run_key]
            # Also drop the now-empty per-thread dict so finished threads do
            # not accumulate in the registry.
            if not runs:
                del registry[thread_key]

    def get_queues(
        self, run_id: UUID | str, thread_id: UUID | str | None
    ) -> list[asyncio.Queue]:
        """Return the live list of stream queues for the run (created if missing)."""
        thread_key, run_key = self._keys(run_id, thread_id)
        return self.queues[thread_key][run_key]

    def get_control_queues(
        self, run_id: UUID | str, thread_id: UUID | str | None
    ) -> list[asyncio.Queue]:
        """Return the live list of control queues for the run (created if missing)."""
        thread_key, run_key = self._keys(run_id, thread_id)
        return self.control_queues[thread_key][run_key]

    def get_control_key(
        self, run_id: UUID | str, thread_id: UUID | str | None
    ) -> Message | None:
        """Return the last control message recorded for the run, or ``None``."""
        thread_key, run_key = self._keys(run_id, thread_id)
        return self.control_keys.get(thread_key, {}).get(run_key)

    async def put(
        self,
        run_id: UUID | str,
        thread_id: UUID | str | None,
        message: Message,
        resumable: bool = False,
    ) -> None:
        """Assign *message* an ID and deliver it to the run's queues.

        Messages whose topic contains ``"control"`` are remembered as the run's
        control key and sent to the control queues; all others go to the
        stream queues. When *resumable* is true the message is also stored so
        that ``restore_messages`` can replay it.
        """
        thread_key, run_key = self._keys(run_id, thread_id)

        # NOTE(review): IDs generated within the same millisecond are equal,
        # so a resume from such an ID skips its same-millisecond siblings.
        message.id = _generate_ms_seq_id().encode()
        if resumable:
            self.message_stores[thread_key][run_key].append(message)

        if "control" in message.topic.decode():
            self.control_keys[thread_key][run_key] = message
            registry = self.control_queues
        else:
            registry = self.queues
        # Read-only lookup: publishing to a run nobody listens to must not
        # leave an empty entry behind.
        targets = list(registry.get(thread_key, {}).get(run_key, ()))

        results = await asyncio.gather(
            *(queue.put(message) for queue in targets), return_exceptions=True
        )
        for result in results:
            if isinstance(result, Exception):
                # We are outside an ``except`` block, so pass the exception
                # explicitly to get its traceback into the log record.
                logger.error(
                    "Failed to put message in queue: %s", result, exc_info=result
                )

    async def add_queue(
        self, run_id: UUID | str, thread_id: UUID | str | None
    ) -> asyncio.Queue:
        """Create, register and return a new stream queue for the run."""
        thread_key, run_key = self._keys(run_id, thread_id)
        queue = ContextQueue()
        self.queues[thread_key][run_key].append(queue)
        return queue

    async def add_control_queue(
        self, run_id: UUID | str, thread_id: UUID | str | None
    ) -> asyncio.Queue:
        """Create, register and return a new control queue for the run."""
        thread_key, run_key = self._keys(run_id, thread_id)
        queue = ContextQueue()
        self.control_queues[thread_key][run_key].append(queue)
        return queue

    async def remove_queue(
        self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
    ):
        """Unregister a stream queue previously returned by ``add_queue``."""
        thread_key, run_key = self._keys(run_id, thread_id)
        self._discard(self.queues, thread_key, run_key, queue)

    async def remove_control_queue(
        self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
    ):
        """Unregister a control queue previously returned by ``add_control_queue``."""
        thread_key, run_key = self._keys(run_id, thread_id)
        self._discard(self.control_queues, thread_key, run_key, queue)

    def restore_messages(
        self, run_id: UUID | str, thread_id: UUID | str | None, message_id: str | None
    ) -> Iterator[Message]:
        """Yield the stored messages of a run that come after *message_id*.

        *message_id* is interpreted as follows:

        * ``None`` - nothing is yielded.
        * ``'<ms>-<seq>'`` or a bare ``'<ms>'`` - messages whose ID is
          numerically greater are yielded.
        * a negative integer such as ``'-1'`` - treated as an index; every
          stored message is yielded.
        * anything else - a warning is logged and nothing is yielded.
        """
        thread_key, run_key = self._keys(run_id, thread_id)
        if message_id is None:
            return

        # Read-only lookup so that probing an unknown run creates no entries.
        stored = self.message_stores.get(thread_key, {}).get(run_key, [])

        try:
            cursor = self._parse_ms_seq(message_id)
        except ValueError:
            cursor = None

        if cursor is not None:
            # Compare numerically: plain string comparison misorders IDs
            # whose parts have different digit counts.
            for message in stored:
                if self._parse_ms_seq(message.id.decode()) > cursor:
                    yield message
            return

        # Not an ms-seq ID: fall back to the integer (index) format.
        try:
            start = int(message_id) + 1
        except ValueError:
            logger.warning("Ignoring unrecognised message ID %r", message_id)
            return
        yield from stored[max(start, 0):]

    def get_queues_by_thread_id(self, thread_id: UUID | str) -> list[asyncio.Queue]:
        """Get all queues for a specific thread_id across all runs."""
        runs = self.queues.get(_ensure_uuid(thread_id))
        if not runs:
            return []
        return [queue for run_queues in runs.values() for queue in run_queues]
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
# Global instance
|
|
213
|
+
stream_manager = StreamManager()
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
async def start_stream() -> None:
    """Reset the module-level stream manager.

    The in-memory backend needs no setup beyond replacing the global
    ``stream_manager`` with a fresh ``StreamManager`` instance.
    """
    global stream_manager
    fresh_manager = StreamManager()
    stream_manager = fresh_manager
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
async def stop_stream() -> None:
    """Clean up the queue system.

    Sends a ``done`` control message to every registered stream queue, then
    clears all queues, stored control messages and stored stream messages.
    """
    manager = stream_manager

    # ``queues`` is keyed by thread and then by run, so walk both levels to
    # reach the actual queue objects. Snapshots guard against the mappings
    # changing while we await.
    for runs in list(manager.queues.values()):
        for run_id, run_queues in list(runs.items()):
            control_message = Message(
                topic=f"run:{run_id}:control".encode(), data=b"done"
            )
            for queue in list(run_queues):
                try:
                    await queue.put(control_message)
                except Exception:
                    # Best effort only: a failing queue must not block shutdown.
                    logger.debug(
                        "Failed to send 'done' to a queue during shutdown",
                        exc_info=True,
                    )

    # Clear all stored data
    manager.queues.clear()
    manager.control_keys.clear()
    manager.control_queues.clear()
    manager.message_stores.clear()
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def get_stream_manager() -> StreamManager:
    """Return the current module-level ``StreamManager``."""
    manager = stream_manager
    return manager
|
{langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/ops.py
RENAMED
|
@@ -27,7 +27,11 @@ from starlette.exceptions import HTTPException
|
|
|
27
27
|
|
|
28
28
|
from langgraph_runtime_inmem.checkpoint import Checkpointer
|
|
29
29
|
from langgraph_runtime_inmem.database import InMemConnectionProto, connect
|
|
30
|
-
from langgraph_runtime_inmem.inmem_stream import
|
|
30
|
+
from langgraph_runtime_inmem.inmem_stream import (
|
|
31
|
+
THREADLESS_KEY,
|
|
32
|
+
Message,
|
|
33
|
+
get_stream_manager,
|
|
34
|
+
)
|
|
31
35
|
|
|
32
36
|
if typing.TYPE_CHECKING:
|
|
33
37
|
from langgraph_api.asyncio import ValueEvent
|
|
@@ -1609,6 +1613,151 @@ class Threads(Authenticated):
|
|
|
1609
1613
|
|
|
1610
1614
|
return []
|
|
1611
1615
|
|
|
1616
|
+
class Stream:
    """Thread-level streaming: merges the streams of every run on a thread."""

    @staticmethod
    async def subscribe(
        conn: InMemConnectionProto | AsyncConnectionProto,
        thread_id: UUID,
        seen_runs: set[UUID],
    ) -> list[tuple[UUID, asyncio.Queue]]:
        """Subscribe to the thread stream, creating queues for unseen runs.

        Scans ``conn.store["runs"]`` for runs on *thread_id*, registers one
        new queue for each run not yet in *seen_runs*, and adds those run
        IDs to *seen_runs* (the set is mutated in place).

        Returns:
            ``(run_id, queue)`` pairs for the queues created by this call.
        """
        stream_manager = get_stream_manager()
        queues = []

        # Create new queues only for runs not yet seen
        thread_id = _ensure_uuid(thread_id)
        for run in conn.store["runs"]:
            if run["thread_id"] == thread_id:
                run_id = run["run_id"]
                if run_id not in seen_runs:
                    queue = await stream_manager.add_queue(run_id, thread_id)
                    queues.append((run_id, queue))
                    seen_runs.add(run_id)

        return queues

    @staticmethod
    async def join(
        thread_id: UUID,
        *,
        last_event_id: str | None = None,
    ) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
        """Stream the thread output.

        When *last_event_id* is given, stored messages newer than it are
        replayed first, ordered by message ID. The generator then polls the
        queues of every run on the thread indefinitely, picking up new runs
        on each pass.

        Yields:
            ``(event name, payload, message id)`` tuples. A control message
            whose content is ``b"done"`` is reported as a ``b"metadata"``
            event with status ``run_done``; other control messages are
            dropped.
        """
        from langgraph_api.serde import json_loads

        stream_manager = get_stream_manager()
        seen_runs: set[UUID] = set()
        created_queues: list[tuple[UUID, asyncio.Queue]] = []

        try:
            async with connect() as conn:
                await logger.ainfo(
                    "Joined thread stream",
                    thread_id=str(thread_id),
                )

                # Restore messages if resuming from a specific event
                if last_event_id is not None:
                    # The message store is keyed by UUID (or the threadless
                    # key), never by the string form of the thread ID.
                    store_key = _ensure_uuid(thread_id)

                    # Collect all events from all message stores for this thread
                    all_events = []
                    for run_id in list(
                        stream_manager.message_stores.get(store_key, {})
                    ):
                        for message in stream_manager.restore_messages(
                            run_id, thread_id, last_event_id
                        ):
                            all_events.append((message, run_id))

                    # Sort by message ID (which is ms-seq format)
                    all_events.sort(key=lambda x: x[0].id.decode())

                    # Yield sorted events
                    for message, run_id in all_events:
                        data = json_loads(message.data)
                        event_name = data["event"]
                        message_content = data["message"]

                        if event_name == "control":
                            if message_content == b"done":
                                yield (
                                    b"metadata",
                                    orjson.dumps(
                                        {"status": "run_done", "run_id": run_id}
                                    ),
                                    message.id,
                                )
                        else:
                            yield (
                                event_name.encode(),
                                base64.b64decode(message_content),
                                message.id,
                            )

                # Listen for live messages from all queues
                while True:
                    # Refresh queues to pick up any new runs that joined this thread
                    new_queue_tuples = await Threads.Stream.subscribe(
                        conn, thread_id, seen_runs
                    )
                    # Track new queues for cleanup
                    for run_id, queue in new_queue_tuples:
                        created_queues.append((run_id, queue))

                    for run_id, queue in created_queues:
                        try:
                            # Short timeout so one idle run cannot stall the others.
                            message = await asyncio.wait_for(
                                queue.get(), timeout=0.2
                            )
                            data = json_loads(message.data)
                            event_name = data["event"]
                            message_content = data["message"]

                            if event_name == "control":
                                if message_content == b"done":
                                    # Extract run_id from topic
                                    topic = message.topic.decode()
                                    run_id = topic.split("run:")[1].split(":")[0]
                                    yield (
                                        b"metadata",
                                        orjson.dumps(
                                            {"status": "run_done", "run_id": run_id}
                                        ),
                                        message.id,
                                    )
                            else:
                                yield (
                                    event_name.encode(),
                                    base64.b64decode(message_content),
                                    message.id,
                                )

                        except TimeoutError:
                            continue
                        except (ValueError, KeyError):
                            # Skip messages that cannot be decoded.
                            continue

                    # Yield execution to other tasks to prevent event loop starvation
                    await asyncio.sleep(0)

        except WrappedHTTPException as e:
            raise e.http_exception from None
        except asyncio.CancelledError:
            await logger.awarning(
                "Thread stream client disconnected",
                thread_id=str(thread_id),
            )
            raise
        finally:
            # Clean up all created queues
            for run_id, queue in created_queues:
                try:
                    await stream_manager.remove_queue(run_id, thread_id, queue)
                except Exception:
                    # Ignore cleanup errors
                    pass
|
|
1760
|
+
|
|
1612
1761
|
@staticmethod
|
|
1613
1762
|
async def count(
|
|
1614
1763
|
conn: InMemConnectionProto,
|
|
@@ -1767,7 +1916,7 @@ class Runs(Authenticated):
|
|
|
1767
1916
|
@asynccontextmanager
|
|
1768
1917
|
@staticmethod
|
|
1769
1918
|
async def enter(
|
|
1770
|
-
run_id: UUID, loop: asyncio.AbstractEventLoop
|
|
1919
|
+
run_id: UUID, thread_id: UUID | None, loop: asyncio.AbstractEventLoop
|
|
1771
1920
|
) -> AsyncIterator[ValueEvent]:
|
|
1772
1921
|
"""Enter a run, listen for cancellation while running, signal when done."
|
|
1773
1922
|
This method should be called as a context manager by a worker executing a run.
|
|
@@ -1775,12 +1924,14 @@ class Runs(Authenticated):
|
|
|
1775
1924
|
from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent
|
|
1776
1925
|
|
|
1777
1926
|
stream_manager = get_stream_manager()
|
|
1778
|
-
# Get queue for this run
|
|
1779
|
-
|
|
1927
|
+
# Get control queue for this run (normal queue is created during run creation)
|
|
1928
|
+
control_queue = await stream_manager.add_control_queue(run_id, thread_id)
|
|
1780
1929
|
|
|
1781
1930
|
async with SimpleTaskGroup(cancel=True, taskgroup_name="Runs.enter") as tg:
|
|
1782
1931
|
done = ValueEvent()
|
|
1783
|
-
tg.create_task(
|
|
1932
|
+
tg.create_task(
|
|
1933
|
+
listen_for_cancellation(control_queue, run_id, thread_id, done)
|
|
1934
|
+
)
|
|
1784
1935
|
|
|
1785
1936
|
# Give done event to caller
|
|
1786
1937
|
yield done
|
|
@@ -1788,17 +1939,17 @@ class Runs(Authenticated):
|
|
|
1788
1939
|
control_message = Message(
|
|
1789
1940
|
topic=f"run:{run_id}:control".encode(), data=b"done"
|
|
1790
1941
|
)
|
|
1791
|
-
await stream_manager.put(run_id, control_message)
|
|
1942
|
+
await stream_manager.put(run_id, thread_id, control_message)
|
|
1792
1943
|
|
|
1793
1944
|
# Signal done to all subscribers
|
|
1794
1945
|
stream_message = Message(
|
|
1795
1946
|
topic=f"run:{run_id}:stream".encode(),
|
|
1796
1947
|
data={"event": "control", "message": b"done"},
|
|
1797
1948
|
)
|
|
1798
|
-
await stream_manager.put(run_id, stream_message)
|
|
1949
|
+
await stream_manager.put(run_id, thread_id, stream_message)
|
|
1799
1950
|
|
|
1800
|
-
# Remove the queue
|
|
1801
|
-
await stream_manager.remove_control_queue(run_id,
|
|
1951
|
+
# Remove the control_queue (normal queue is cleaned up during run deletion)
|
|
1952
|
+
await stream_manager.remove_control_queue(run_id, thread_id, control_queue)
|
|
1802
1953
|
|
|
1803
1954
|
@staticmethod
|
|
1804
1955
|
async def sweep() -> None:
|
|
@@ -2086,6 +2237,7 @@ class Runs(Authenticated):
|
|
|
2086
2237
|
if not thread:
|
|
2087
2238
|
return _empty_generator()
|
|
2088
2239
|
_delete_checkpoints_for_thread(thread_id, conn, run_id=run_id)
|
|
2240
|
+
|
|
2089
2241
|
found = False
|
|
2090
2242
|
for i, run in enumerate(conn.store["runs"]):
|
|
2091
2243
|
if run["run_id"] == run_id and run["thread_id"] == thread_id:
|
|
@@ -2268,9 +2420,9 @@ class Runs(Authenticated):
|
|
|
2268
2420
|
topic=f"run:{run_id}:control".encode(),
|
|
2269
2421
|
data=action.encode(),
|
|
2270
2422
|
)
|
|
2271
|
-
coros.append(stream_manager.put(run_id, control_message))
|
|
2423
|
+
coros.append(stream_manager.put(run_id, thread_id, control_message))
|
|
2272
2424
|
|
|
2273
|
-
queues = stream_manager.get_queues(run_id)
|
|
2425
|
+
queues = stream_manager.get_queues(run_id, thread_id)
|
|
2274
2426
|
|
|
2275
2427
|
if run["status"] in ("pending", "running"):
|
|
2276
2428
|
cancelable_runs.append(run)
|
|
@@ -2385,15 +2537,25 @@ class Runs(Authenticated):
|
|
|
2385
2537
|
@staticmethod
async def subscribe(
    run_id: UUID,
    thread_id: UUID | None = None,
) -> asyncio.Queue:
    """Subscribe to the run stream, returning a queue.

    A new stream queue is registered for the run. Any control messages
    currently sitting in the run's control queues are moved into the new
    queue before it is returned.
    """
    stream_manager = get_stream_manager()
    run_id = _ensure_uuid(run_id)
    queue = await stream_manager.add_queue(run_id, thread_id)

    # Normalise the lookup keys the same way the stream manager does, so that
    # string IDs find the entries registered under UUID keys.
    thread_key = THREADLESS_KEY if thread_id is None else _ensure_uuid(thread_id)

    # If there's a control message already stored, send it to the new subscriber
    control_queues = stream_manager.control_queues.get(thread_key, {}).get(run_id)
    for control_queue in control_queues or ():
        # NOTE(review): draining removes the messages from the control queue,
        # so its other consumers will not see them - confirm this is intended.
        while True:
            try:
                # ``get_nowait`` raises ``QueueEmpty`` once drained; a bare
                # ``get()`` only returns a coroutine and never ends the loop.
                control_msg = control_queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            await queue.put(control_msg)
    return queue
|
|
2398
2560
|
|
|
2399
2561
|
@staticmethod
|
|
@@ -2415,7 +2577,7 @@ class Runs(Authenticated):
|
|
|
2415
2577
|
queue = (
|
|
2416
2578
|
stream_channel
|
|
2417
2579
|
if stream_channel
|
|
2418
|
-
else await Runs.Stream.subscribe(run_id)
|
|
2580
|
+
else await Runs.Stream.subscribe(run_id, thread_id)
|
|
2419
2581
|
)
|
|
2420
2582
|
|
|
2421
2583
|
try:
|
|
@@ -2438,7 +2600,7 @@ class Runs(Authenticated):
|
|
|
2438
2600
|
run = await Runs.get(conn, run_id, thread_id=thread_id, ctx=ctx)
|
|
2439
2601
|
|
|
2440
2602
|
for message in get_stream_manager().restore_messages(
|
|
2441
|
-
run_id, last_event_id
|
|
2603
|
+
run_id, thread_id, last_event_id
|
|
2442
2604
|
):
|
|
2443
2605
|
data, id = message.data, message.id
|
|
2444
2606
|
|
|
@@ -2529,7 +2691,7 @@ class Runs(Authenticated):
|
|
|
2529
2691
|
raise
|
|
2530
2692
|
finally:
|
|
2531
2693
|
stream_manager = get_stream_manager()
|
|
2532
|
-
await stream_manager.remove_queue(run_id, queue)
|
|
2694
|
+
await stream_manager.remove_queue(run_id, thread_id, queue)
|
|
2533
2695
|
|
|
2534
2696
|
@staticmethod
|
|
2535
2697
|
async def publish(
|
|
@@ -2537,6 +2699,7 @@ class Runs(Authenticated):
|
|
|
2537
2699
|
event: str,
|
|
2538
2700
|
message: bytes,
|
|
2539
2701
|
*,
|
|
2702
|
+
thread_id: UUID | str | None = None,
|
|
2540
2703
|
resumable: bool = False,
|
|
2541
2704
|
) -> None:
|
|
2542
2705
|
"""Publish a message to all subscribers of the run stream."""
|
|
@@ -2553,17 +2716,19 @@ class Runs(Authenticated):
|
|
|
2553
2716
|
}
|
|
2554
2717
|
)
|
|
2555
2718
|
await stream_manager.put(
|
|
2556
|
-
run_id, Message(topic=topic, data=payload), resumable
|
|
2719
|
+
run_id, thread_id, Message(topic=topic, data=payload), resumable
|
|
2557
2720
|
)
|
|
2558
2721
|
|
|
2559
2722
|
|
|
2560
|
-
async def listen_for_cancellation(
|
|
2723
|
+
async def listen_for_cancellation(
|
|
2724
|
+
queue: asyncio.Queue, run_id: UUID, thread_id: UUID | None, done: ValueEvent
|
|
2725
|
+
):
|
|
2561
2726
|
"""Listen for cancellation messages and set the done event accordingly."""
|
|
2562
2727
|
from langgraph_api.errors import UserInterrupt, UserRollback
|
|
2563
2728
|
|
|
2564
2729
|
stream_manager = get_stream_manager()
|
|
2565
2730
|
|
|
2566
|
-
if control_key := stream_manager.get_control_key(run_id):
|
|
2731
|
+
if control_key := stream_manager.get_control_key(run_id, thread_id):
|
|
2567
2732
|
payload = control_key.data
|
|
2568
2733
|
if payload == b"rollback":
|
|
2569
2734
|
done.set(UserRollback())
|
|
@@ -1,159 +0,0 @@
|
|
|
1
|
-
import asyncio
|
|
2
|
-
import logging
|
|
3
|
-
from collections import defaultdict
|
|
4
|
-
from collections.abc import Iterator
|
|
5
|
-
from dataclasses import dataclass
|
|
6
|
-
from uuid import UUID
|
|
7
|
-
|
|
8
|
-
logger = logging.getLogger(__name__)
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def _ensure_uuid(id: str | UUID) -> UUID:
|
|
12
|
-
return UUID(id) if isinstance(id, str) else id
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
@dataclass
|
|
16
|
-
class Message:
|
|
17
|
-
topic: bytes
|
|
18
|
-
data: bytes
|
|
19
|
-
id: bytes | None = None
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
class ContextQueue(asyncio.Queue):
|
|
23
|
-
"""Queue that supports async context manager protocol"""
|
|
24
|
-
|
|
25
|
-
async def __aenter__(self):
|
|
26
|
-
return self
|
|
27
|
-
|
|
28
|
-
async def __aexit__(
|
|
29
|
-
self,
|
|
30
|
-
exc_type: type[BaseException] | None,
|
|
31
|
-
exc_val: BaseException | None,
|
|
32
|
-
exc_tb: object | None,
|
|
33
|
-
) -> None:
|
|
34
|
-
# Clear the queue
|
|
35
|
-
while not self.empty():
|
|
36
|
-
try:
|
|
37
|
-
self.get_nowait()
|
|
38
|
-
except asyncio.QueueEmpty:
|
|
39
|
-
break
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
class StreamManager:
|
|
43
|
-
def __init__(self):
|
|
44
|
-
self.queues = defaultdict(list) # Dict[UUID, List[asyncio.Queue]]
|
|
45
|
-
self.control_keys = defaultdict()
|
|
46
|
-
self.control_queues = defaultdict(list)
|
|
47
|
-
|
|
48
|
-
self.message_stores = defaultdict(list) # Dict[UUID, List[Message]]
|
|
49
|
-
self.message_next_idx = defaultdict(int) # Dict[UUID, int]
|
|
50
|
-
|
|
51
|
-
def get_queues(self, run_id: UUID | str) -> list[asyncio.Queue]:
|
|
52
|
-
run_id = _ensure_uuid(run_id)
|
|
53
|
-
return self.queues[run_id]
|
|
54
|
-
|
|
55
|
-
def get_control_queues(self, run_id: UUID | str) -> list[asyncio.Queue]:
|
|
56
|
-
run_id = _ensure_uuid(run_id)
|
|
57
|
-
return self.control_queues[run_id]
|
|
58
|
-
|
|
59
|
-
def get_control_key(self, run_id: UUID | str) -> Message | None:
|
|
60
|
-
run_id = _ensure_uuid(run_id)
|
|
61
|
-
return self.control_keys.get(run_id)
|
|
62
|
-
|
|
63
|
-
async def put(
|
|
64
|
-
self, run_id: UUID | str, message: Message, resumable: bool = False
|
|
65
|
-
) -> None:
|
|
66
|
-
run_id = _ensure_uuid(run_id)
|
|
67
|
-
message.id = str(self.message_next_idx[run_id]).encode()
|
|
68
|
-
self.message_next_idx[run_id] += 1
|
|
69
|
-
if resumable:
|
|
70
|
-
self.message_stores[run_id].append(message)
|
|
71
|
-
topic = message.topic.decode()
|
|
72
|
-
if "control" in topic:
|
|
73
|
-
self.control_keys[run_id] = message
|
|
74
|
-
queues = self.control_queues[run_id]
|
|
75
|
-
else:
|
|
76
|
-
queues = self.queues[run_id]
|
|
77
|
-
coros = [queue.put(message) for queue in queues]
|
|
78
|
-
results = await asyncio.gather(*coros, return_exceptions=True)
|
|
79
|
-
for result in results:
|
|
80
|
-
if isinstance(result, Exception):
|
|
81
|
-
logger.exception(f"Failed to put message in queue: {result}")
|
|
82
|
-
|
|
83
|
-
async def add_queue(self, run_id: UUID | str) -> asyncio.Queue:
|
|
84
|
-
run_id = _ensure_uuid(run_id)
|
|
85
|
-
queue = ContextQueue()
|
|
86
|
-
self.queues[run_id].append(queue)
|
|
87
|
-
return queue
|
|
88
|
-
|
|
89
|
-
async def add_control_queue(self, run_id: UUID | str) -> asyncio.Queue:
|
|
90
|
-
run_id = _ensure_uuid(run_id)
|
|
91
|
-
queue = ContextQueue()
|
|
92
|
-
self.control_queues[run_id].append(queue)
|
|
93
|
-
return queue
|
|
94
|
-
|
|
95
|
-
async def remove_queue(self, run_id: UUID | str, queue: asyncio.Queue):
|
|
96
|
-
run_id = _ensure_uuid(run_id)
|
|
97
|
-
if run_id in self.queues:
|
|
98
|
-
self.queues[run_id].remove(queue)
|
|
99
|
-
if not self.queues[run_id]:
|
|
100
|
-
del self.queues[run_id]
|
|
101
|
-
|
|
102
|
-
async def remove_control_queue(self, run_id: UUID | str, queue: asyncio.Queue):
|
|
103
|
-
run_id = _ensure_uuid(run_id)
|
|
104
|
-
if run_id in self.control_queues:
|
|
105
|
-
self.control_queues[run_id].remove(queue)
|
|
106
|
-
if not self.control_queues[run_id]:
|
|
107
|
-
del self.control_queues[run_id]
|
|
108
|
-
|
|
109
|
-
def restore_messages(
|
|
110
|
-
self, run_id: UUID | str, message_id: str | None
|
|
111
|
-
) -> Iterator[Message]:
|
|
112
|
-
"""Get a stored message by ID for resumable streams."""
|
|
113
|
-
run_id = _ensure_uuid(run_id)
|
|
114
|
-
message_idx = int(message_id) + 1 if message_id else None
|
|
115
|
-
|
|
116
|
-
if message_idx is None:
|
|
117
|
-
yield from []
|
|
118
|
-
return
|
|
119
|
-
|
|
120
|
-
if run_id in self.message_stores:
|
|
121
|
-
yield from self.message_stores[run_id][message_idx:]
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
# Global instance
|
|
125
|
-
stream_manager = StreamManager()
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
async def start_stream() -> None:
|
|
129
|
-
"""Initialize the queue system.
|
|
130
|
-
In this in-memory implementation, we just need to ensure we have a clean StreamManager instance.
|
|
131
|
-
"""
|
|
132
|
-
global stream_manager
|
|
133
|
-
stream_manager = StreamManager()
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
async def stop_stream() -> None:
|
|
137
|
-
"""Clean up the queue system.
|
|
138
|
-
Clear all queues and stored control messages."""
|
|
139
|
-
global stream_manager
|
|
140
|
-
|
|
141
|
-
# Send 'done' message to all active queues before clearing
|
|
142
|
-
for run_id in list(stream_manager.queues.keys()):
|
|
143
|
-
control_message = Message(topic=f"run:{run_id}:control".encode(), data=b"done")
|
|
144
|
-
|
|
145
|
-
for queue in stream_manager.queues[run_id]:
|
|
146
|
-
try:
|
|
147
|
-
await queue.put(control_message)
|
|
148
|
-
except (Exception, RuntimeError):
|
|
149
|
-
pass # Ignore errors during shutdown
|
|
150
|
-
|
|
151
|
-
# Clear all stored data
|
|
152
|
-
stream_manager.queues.clear()
|
|
153
|
-
stream_manager.control_queues.clear()
|
|
154
|
-
stream_manager.message_stores.clear()
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
def get_stream_manager() -> StreamManager:
|
|
158
|
-
"""Get the global stream manager instance."""
|
|
159
|
-
return stream_manager
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/database.py
RENAMED
|
File without changes
|
{langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/lifespan.py
RENAMED
|
File without changes
|
{langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/metrics.py
RENAMED
|
File without changes
|
{langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/queue.py
RENAMED
|
File without changes
|
{langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/retry.py
RENAMED
|
File without changes
|
{langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.9.0}/langgraph_runtime_inmem/store.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|