langgraph-runtime-inmem 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_runtime_inmem/__init__.py +1 -1
- langgraph_runtime_inmem/database.py +1 -1
- langgraph_runtime_inmem/inmem_stream.py +131 -43
- langgraph_runtime_inmem/ops.py +191 -28
- {langgraph_runtime_inmem-0.8.1.dist-info → langgraph_runtime_inmem-0.9.0.dist-info}/METADATA +1 -1
- langgraph_runtime_inmem-0.9.0.dist-info/RECORD +13 -0
- langgraph_runtime_inmem-0.8.1.dist-info/RECORD +0 -13
- {langgraph_runtime_inmem-0.8.1.dist-info → langgraph_runtime_inmem-0.9.0.dist-info}/WHEEL +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import logging
|
|
3
|
+
import time
|
|
3
4
|
from collections import defaultdict
|
|
4
5
|
from collections.abc import Iterator
|
|
5
6
|
from dataclasses import dataclass
|
|
@@ -12,6 +13,14 @@ def _ensure_uuid(id: str | UUID) -> UUID:
|
|
|
12
13
|
return UUID(id) if isinstance(id, str) else id
|
|
13
14
|
|
|
14
15
|
|
|
16
|
+
# Millisecond component of the most recently issued event id (see below).
_last_ms_seq_id_ms = 0


def _generate_ms_seq_id() -> str:
    """Generate a Redis-like millisecond-sequence ID (e.g. '1234567890123-0').

    IDs are strictly increasing within the process. When two IDs are requested
    in the same millisecond (or the wall clock moves backwards), the millisecond
    part is bumped past the previous ID instead of repeating it. The sequence
    part stays "0" so IDs keep sorting correctly as plain strings, which is how
    resumable-stream code compares and orders them.

    Returns:
        The new event id as a ``"<ms>-0"`` string.
    """
    global _last_ms_seq_id_ms
    ms = time.time_ns() // 1_000_000
    if ms <= _last_ms_seq_id_ms:
        # Same millisecond as the previous id, or the clock went backwards:
        # never reuse or go below the last issued id, otherwise a resume with
        # "id > last_event_id" would skip messages sharing that id.
        ms = _last_ms_seq_id_ms + 1
    _last_ms_seq_id_ms = ms
    return f"{ms}-0"
|
|
22
|
+
|
|
23
|
+
|
|
15
24
|
@dataclass
|
|
16
25
|
class Message:
|
|
17
26
|
topic: bytes
|
|
@@ -39,86 +48,165 @@ class ContextQueue(asyncio.Queue):
|
|
|
39
48
|
break
|
|
40
49
|
|
|
41
50
|
|
|
42
|
-
|
|
43
|
-
def __init__(self):
|
|
44
|
-
self.queues = defaultdict(list) # Dict[UUID, List[asyncio.Queue]]
|
|
45
|
-
self.control_keys = defaultdict()
|
|
46
|
-
self.control_queues = defaultdict(list)
|
|
51
|
+
# Registry key StreamManager uses in place of a thread id when a run has no
# thread (i.e. when callers pass thread_id=None).
THREADLESS_KEY = "no-thread"
|
|
47
52
|
|
|
48
|
-
self.message_stores = defaultdict(list) # Dict[UUID, List[Message]]
|
|
49
|
-
self.message_next_idx = defaultdict(int) # Dict[UUID, int]
|
|
50
53
|
|
|
51
|
-
|
|
54
|
+
class StreamManager:
    """In-memory registry of run stream subscribers and resumable messages.

    Each registry is a two-level mapping ``{thread_key: {run_id: value}}``:
    ``thread_key`` is the thread's UUID, or ``THREADLESS_KEY`` for runs that
    were given no thread, and ``run_id`` is always normalized to a UUID.
    """

    def __init__(self):
        # thread_key -> run_id -> list of subscriber queues for stream messages
        self.queues = defaultdict(lambda: defaultdict(list))
        # thread_key -> run_id -> most recent control Message published for the run
        self.control_keys = defaultdict(dict)
        # thread_key -> run_id -> list of queues that receive control messages
        self.control_queues = defaultdict(lambda: defaultdict(list))
        # thread_key -> run_id -> list of Messages kept for resumable streams
        self.message_stores = defaultdict(lambda: defaultdict(list))

    @staticmethod
    def _thread_key(thread_id: UUID | str | None) -> UUID | str:
        """Return the registry key for ``thread_id`` (``THREADLESS_KEY`` if None)."""
        if thread_id is None:
            return THREADLESS_KEY
        return _ensure_uuid(thread_id)

    @staticmethod
    def _parse_event_id(event_id: str) -> tuple[int, int] | None:
        """Parse a ``"<ms>-<seq>"`` event id into ``(ms, seq)``; None if malformed."""
        ms, sep, seq = event_id.partition("-")
        if sep and ms.isdecimal() and seq.isdecimal():
            return int(ms), int(seq)
        return None

    @staticmethod
    def _discard_queue(
        registry: dict,
        thread_key: UUID | str,
        run_id: UUID,
        queue: asyncio.Queue,
    ) -> None:
        """Remove ``queue`` from ``registry[thread_key][run_id]``, pruning empty levels.

        Removing a queue that is not registered is a no-op.
        """
        runs = registry.get(thread_key)
        if runs is None or run_id not in runs:
            return
        subscribers = runs[run_id]
        if queue in subscribers:
            subscribers.remove(queue)
        if not subscribers:
            del runs[run_id]
        if not runs:
            del registry[thread_key]

    def get_queues(
        self, run_id: UUID | str, thread_id: UUID | str | None
    ) -> list[asyncio.Queue]:
        """Return the live list of stream queues for a run, creating it if absent."""
        run_id = _ensure_uuid(run_id)
        return self.queues[self._thread_key(thread_id)][run_id]

    def get_control_queues(
        self, run_id: UUID | str, thread_id: UUID | str | None
    ) -> list[asyncio.Queue]:
        """Return the live list of control queues for a run, creating it if absent."""
        run_id = _ensure_uuid(run_id)
        return self.control_queues[self._thread_key(thread_id)][run_id]

    def get_control_key(
        self, run_id: UUID | str, thread_id: UUID | str | None
    ) -> Message | None:
        """Return the latest control message published for a run, if any."""
        run_id = _ensure_uuid(run_id)
        return self.control_keys.get(self._thread_key(thread_id), {}).get(run_id)

    async def put(
        self,
        run_id: UUID | str,
        thread_id: UUID | str | None,
        message: Message,
        resumable: bool = False,
    ) -> None:
        """Stamp ``message`` with a fresh event id and deliver it to subscribers.

        Messages whose topic contains ``"control"`` are recorded as the run's
        latest control key and delivered to its control queues; all other
        messages go to its stream queues. With ``resumable=True`` the message
        is also stored for replay by ``restore_messages``. Delivery failures
        are logged, not raised.
        """
        run_id = _ensure_uuid(run_id)
        thread_key = self._thread_key(thread_id)

        message.id = _generate_ms_seq_id().encode()
        if resumable:
            self.message_stores[thread_key][run_id].append(message)

        if "control" in message.topic.decode():
            self.control_keys[thread_key][run_id] = message
            registry = self.control_queues
        else:
            registry = self.queues
        # .get() rather than indexing: publishing to a run nobody listens to
        # must not leave empty entries behind in the defaultdict registries.
        subscribers = list(registry.get(thread_key, {}).get(run_id, ()))
        results = await asyncio.gather(
            *(queue.put(message) for queue in subscribers), return_exceptions=True
        )
        for result in results:
            if isinstance(result, Exception):
                # logger.exception() outside an except block has no active
                # exception to attach, so pass the failure explicitly.
                logger.error(
                    "Failed to put message in queue: %s", result, exc_info=result
                )

    async def add_queue(
        self, run_id: UUID | str, thread_id: UUID | str | None
    ) -> asyncio.Queue:
        """Register and return a new stream queue for a run."""
        run_id = _ensure_uuid(run_id)
        queue = ContextQueue()
        self.queues[self._thread_key(thread_id)][run_id].append(queue)
        return queue

    async def add_control_queue(
        self, run_id: UUID | str, thread_id: UUID | str | None
    ) -> asyncio.Queue:
        """Register and return a new control queue for a run."""
        run_id = _ensure_uuid(run_id)
        thread_key = self._thread_key(thread_id)
        queue = ContextQueue()
        self.control_queues[thread_key][run_id].append(queue)
        return queue

    async def remove_queue(
        self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
    ):
        """Unregister a stream queue; unknown queues are ignored."""
        run_id = _ensure_uuid(run_id)
        self._discard_queue(self.queues, self._thread_key(thread_id), run_id, queue)

    async def remove_control_queue(
        self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
    ):
        """Unregister a control queue; unknown queues are ignored."""
        run_id = _ensure_uuid(run_id)
        self._discard_queue(
            self.control_queues, self._thread_key(thread_id), run_id, queue
        )

    def restore_messages(
        self, run_id: UUID | str, thread_id: UUID | str | None, message_id: str | None
    ) -> Iterator[Message]:
        """Yield stored resumable messages of a run that come after ``message_id``.

        ``message_id`` may be an ``"<ms>-<seq>"`` event id (messages with a
        greater id are yielded) or a legacy integer index (messages after that
        position are yielded). Any other string falls back to plain string
        comparison of ids. Yields nothing when ``message_id`` is None.
        """
        run_id = _ensure_uuid(run_id)
        thread_key = self._thread_key(thread_id)
        if message_id is None:
            return
        # Snapshot the list: a caller that suspends between items must not
        # pick up messages appended meanwhile (those reach it via its queue).
        # .get() avoids creating empty entries for unknown threads/runs.
        stored = list(self.message_stores.get(thread_key, {}).get(run_id, ()))

        target = self._parse_event_id(message_id)
        if target is not None:
            for message in stored:
                key = self._parse_event_id(message.id.decode())
                if key is not None and key > target:
                    yield message
            return

        try:
            index = int(message_id)
        except ValueError:
            # Unrecognized format: keep the plain string comparison.
            for message in stored:
                if message.id.decode() > message_id:
                    yield message
            return
        # Legacy integer ids are list positions; replay what follows them.
        yield from stored[max(index + 1, 0) :]

    def get_queues_by_thread_id(self, thread_id: UUID | str) -> list[asyncio.Queue]:
        """Get all stream queues registered for a thread, across all of its runs."""
        runs = self.queues.get(_ensure_uuid(thread_id), {})
        return [queue for subscribers in runs.values() for queue in subscribers]
|
|
122
210
|
|
|
123
211
|
|
|
124
212
|
# Global instance
|
langgraph_runtime_inmem/ops.py
CHANGED
|
@@ -27,7 +27,11 @@ from starlette.exceptions import HTTPException
|
|
|
27
27
|
|
|
28
28
|
from langgraph_runtime_inmem.checkpoint import Checkpointer
|
|
29
29
|
from langgraph_runtime_inmem.database import InMemConnectionProto, connect
|
|
30
|
-
from langgraph_runtime_inmem.inmem_stream import
|
|
30
|
+
from langgraph_runtime_inmem.inmem_stream import (
|
|
31
|
+
THREADLESS_KEY,
|
|
32
|
+
Message,
|
|
33
|
+
get_stream_manager,
|
|
34
|
+
)
|
|
31
35
|
|
|
32
36
|
if typing.TYPE_CHECKING:
|
|
33
37
|
from langgraph_api.asyncio import ValueEvent
|
|
@@ -406,19 +410,17 @@ class Assistants(Authenticated):
|
|
|
406
410
|
else 1
|
|
407
411
|
)
|
|
408
412
|
|
|
409
|
-
# Update assistant_versions table
|
|
410
|
-
if metadata:
|
|
411
|
-
metadata = {
|
|
412
|
-
**assistant["metadata"],
|
|
413
|
-
**metadata,
|
|
414
|
-
}
|
|
415
413
|
new_version_entry = {
|
|
416
414
|
"assistant_id": assistant_id,
|
|
417
415
|
"version": new_version,
|
|
418
416
|
"graph_id": graph_id if graph_id is not None else assistant["graph_id"],
|
|
419
417
|
"config": config if config else assistant["config"],
|
|
420
418
|
"context": context if context is not None else assistant.get("context", {}),
|
|
421
|
-
"metadata":
|
|
419
|
+
"metadata": (
|
|
420
|
+
{**assistant["metadata"], **metadata}
|
|
421
|
+
if metadata is not None
|
|
422
|
+
else assistant["metadata"]
|
|
423
|
+
),
|
|
422
424
|
"created_at": now,
|
|
423
425
|
"name": name if name is not None else assistant["name"],
|
|
424
426
|
"description": (
|
|
@@ -1611,6 +1613,151 @@ class Threads(Authenticated):
|
|
|
1611
1613
|
|
|
1612
1614
|
return []
|
|
1613
1615
|
|
|
1616
|
+
class Stream:
    """Thread-level stream that fans in the output of every run on a thread."""

    @staticmethod
    async def subscribe(
        conn: InMemConnectionProto | AsyncConnectionProto,
        thread_id: UUID,
        seen_runs: set[UUID],
    ) -> list[tuple[UUID, asyncio.Queue]]:
        """Create stream queues for runs of ``thread_id`` not yet in ``seen_runs``.

        Args:
            conn: Connection whose ``store["runs"]`` is scanned for the thread's runs.
            thread_id: Thread whose runs should be followed.
            seen_runs: Run ids already subscribed to; updated in place with
                every run subscribed by this call.

        Returns:
            ``(run_id, queue)`` pairs for the newly subscribed runs.
        """
        stream_manager = get_stream_manager()
        thread_id = _ensure_uuid(thread_id)
        new_queues: list[tuple[UUID, asyncio.Queue]] = []
        for run in conn.store["runs"]:
            if run["thread_id"] != thread_id:
                continue
            run_id = run["run_id"]
            if run_id in seen_runs:
                continue
            queue = await stream_manager.add_queue(run_id, thread_id)
            new_queues.append((run_id, queue))
            seen_runs.add(run_id)
        return new_queues

    @staticmethod
    def _to_event(
        message: Message, run_id: UUID | str
    ) -> tuple[bytes, bytes, bytes | None] | None:
        """Convert a stream message into an ``(event, data, id)`` tuple.

        A ``control``/``done`` message becomes a ``metadata`` event announcing
        that ``run_id`` finished; other control messages map to None. Any other
        event has its base64 ``message`` payload decoded.
        """
        from langgraph_api.serde import json_loads

        data = json_loads(message.data)
        event_name = data["event"]
        content = data["message"]
        if event_name == "control":
            if content == b"done":
                return (
                    b"metadata",
                    orjson.dumps({"status": "run_done", "run_id": run_id}),
                    message.id,
                )
            return None
        return (event_name.encode(), base64.b64decode(content), message.id)

    @staticmethod
    async def join(
        thread_id: UUID,
        *,
        last_event_id: str | None = None,
    ) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
        """Stream the output of every run on a thread.

        When ``last_event_id`` is given, stored resumable messages newer than it
        are replayed first, ordered by event id. After that, live messages from
        all of the thread's runs are yielded as they arrive. Runs added to the
        thread while streaming are picked up on each ~0.2s wait cycle.

        Yields:
            ``(event, data, message_id)`` tuples; a finished run is reported as
            a ``b"metadata"`` event carrying ``{"status": "run_done", ...}``.
        """
        stream_manager = get_stream_manager()
        thread_id = _ensure_uuid(thread_id)
        seen_runs: set[UUID] = set()
        created_queues: list[tuple[UUID, asyncio.Queue]] = []
        # One in-flight queue.get() task per subscribed queue, so all runs are
        # awaited concurrently instead of polled one after another.
        pending: dict[asyncio.Task, tuple[UUID, asyncio.Queue]] = {}

        try:
            async with connect() as conn:
                await logger.ainfo(
                    "Joined thread stream",
                    thread_id=str(thread_id),
                )

                async def follow_new_runs() -> None:
                    for run_id, queue in await Threads.Stream.subscribe(
                        conn, thread_id, seen_runs
                    ):
                        created_queues.append((run_id, queue))
                        pending[asyncio.create_task(queue.get())] = (run_id, queue)

                # Subscribe before snapshotting the replay. There is no
                # suspension point between the two (StreamManager.add_queue
                # never awaits), so every later message reaches a queue and
                # nothing published while the replay is yielded is lost.
                await follow_new_runs()

                if last_event_id is not None:
                    replay: list[tuple[Message, UUID]] = []
                    # message_stores is keyed by the normalized UUID, not by
                    # str(thread_id).
                    for run_id in list(
                        stream_manager.message_stores.get(thread_id, {})
                    ):
                        for message in stream_manager.restore_messages(
                            run_id, thread_id, last_event_id
                        ):
                            replay.append((message, run_id))
                    # Event ids are "<ms>-<seq>" strings; sort oldest first.
                    replay.sort(key=lambda item: item[0].id.decode())
                    for message, run_id in replay:
                        event = Threads.Stream._to_event(message, run_id)
                        if event is not None:
                            yield event

                while True:
                    await follow_new_runs()
                    if not pending:
                        # No runs on the thread yet: back off instead of
                        # spinning the event loop.
                        await asyncio.sleep(0.2)
                        continue

                    done, _ = await asyncio.wait(
                        set(pending),
                        timeout=0.2,
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                    received: list[tuple[Message, UUID]] = []
                    for task in done:
                        run_id, queue = pending.pop(task)
                        received.append((task.result(), run_id))
                        # Re-arm the queue for its next message.
                        pending[asyncio.create_task(queue.get())] = (run_id, queue)
                    received.sort(key=lambda item: item[0].id or b"")

                    for message, run_id in received:
                        try:
                            event = Threads.Stream._to_event(message, run_id)
                        except (ValueError, KeyError):
                            # Skip malformed live messages, as before.
                            continue
                        if event is not None:
                            yield event

        except WrappedHTTPException as e:
            raise e.http_exception from None
        except asyncio.CancelledError:
            await logger.awarning(
                "Thread stream client disconnected",
                thread_id=str(thread_id),
            )
            raise
        finally:
            for task in pending:
                task.cancel()
            # Clean up all created queues
            for run_id, queue in created_queues:
                try:
                    await stream_manager.remove_queue(run_id, thread_id, queue)
                except Exception:
                    # Best-effort cleanup; the queue may already be gone.
                    pass
|
|
1760
|
+
|
|
1614
1761
|
@staticmethod
|
|
1615
1762
|
async def count(
|
|
1616
1763
|
conn: InMemConnectionProto,
|
|
@@ -1769,7 +1916,7 @@ class Runs(Authenticated):
|
|
|
1769
1916
|
@asynccontextmanager
|
|
1770
1917
|
@staticmethod
|
|
1771
1918
|
async def enter(
|
|
1772
|
-
run_id: UUID, loop: asyncio.AbstractEventLoop
|
|
1919
|
+
run_id: UUID, thread_id: UUID | None, loop: asyncio.AbstractEventLoop
|
|
1773
1920
|
) -> AsyncIterator[ValueEvent]:
|
|
1774
1921
|
"""Enter a run, listen for cancellation while running, signal when done."
|
|
1775
1922
|
This method should be called as a context manager by a worker executing a run.
|
|
@@ -1777,12 +1924,14 @@ class Runs(Authenticated):
|
|
|
1777
1924
|
from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent
|
|
1778
1925
|
|
|
1779
1926
|
stream_manager = get_stream_manager()
|
|
1780
|
-
# Get queue for this run
|
|
1781
|
-
|
|
1927
|
+
# Get control queue for this run (normal queue is created during run creation)
|
|
1928
|
+
control_queue = await stream_manager.add_control_queue(run_id, thread_id)
|
|
1782
1929
|
|
|
1783
1930
|
async with SimpleTaskGroup(cancel=True, taskgroup_name="Runs.enter") as tg:
|
|
1784
1931
|
done = ValueEvent()
|
|
1785
|
-
tg.create_task(
|
|
1932
|
+
tg.create_task(
|
|
1933
|
+
listen_for_cancellation(control_queue, run_id, thread_id, done)
|
|
1934
|
+
)
|
|
1786
1935
|
|
|
1787
1936
|
# Give done event to caller
|
|
1788
1937
|
yield done
|
|
@@ -1790,17 +1939,17 @@ class Runs(Authenticated):
|
|
|
1790
1939
|
control_message = Message(
|
|
1791
1940
|
topic=f"run:{run_id}:control".encode(), data=b"done"
|
|
1792
1941
|
)
|
|
1793
|
-
await stream_manager.put(run_id, control_message)
|
|
1942
|
+
await stream_manager.put(run_id, thread_id, control_message)
|
|
1794
1943
|
|
|
1795
1944
|
# Signal done to all subscribers
|
|
1796
1945
|
stream_message = Message(
|
|
1797
1946
|
topic=f"run:{run_id}:stream".encode(),
|
|
1798
1947
|
data={"event": "control", "message": b"done"},
|
|
1799
1948
|
)
|
|
1800
|
-
await stream_manager.put(run_id, stream_message)
|
|
1949
|
+
await stream_manager.put(run_id, thread_id, stream_message)
|
|
1801
1950
|
|
|
1802
|
-
# Remove the queue
|
|
1803
|
-
await stream_manager.remove_control_queue(run_id,
|
|
1951
|
+
# Remove the control_queue (normal queue is cleaned up during run deletion)
|
|
1952
|
+
await stream_manager.remove_control_queue(run_id, thread_id, control_queue)
|
|
1804
1953
|
|
|
1805
1954
|
@staticmethod
|
|
1806
1955
|
async def sweep() -> None:
|
|
@@ -2088,6 +2237,7 @@ class Runs(Authenticated):
|
|
|
2088
2237
|
if not thread:
|
|
2089
2238
|
return _empty_generator()
|
|
2090
2239
|
_delete_checkpoints_for_thread(thread_id, conn, run_id=run_id)
|
|
2240
|
+
|
|
2091
2241
|
found = False
|
|
2092
2242
|
for i, run in enumerate(conn.store["runs"]):
|
|
2093
2243
|
if run["run_id"] == run_id and run["thread_id"] == thread_id:
|
|
@@ -2270,9 +2420,9 @@ class Runs(Authenticated):
|
|
|
2270
2420
|
topic=f"run:{run_id}:control".encode(),
|
|
2271
2421
|
data=action.encode(),
|
|
2272
2422
|
)
|
|
2273
|
-
coros.append(stream_manager.put(run_id, control_message))
|
|
2423
|
+
coros.append(stream_manager.put(run_id, thread_id, control_message))
|
|
2274
2424
|
|
|
2275
|
-
queues = stream_manager.get_queues(run_id)
|
|
2425
|
+
queues = stream_manager.get_queues(run_id, thread_id)
|
|
2276
2426
|
|
|
2277
2427
|
if run["status"] in ("pending", "running"):
|
|
2278
2428
|
cancelable_runs.append(run)
|
|
@@ -2387,15 +2537,25 @@ class Runs(Authenticated):
|
|
|
2387
2537
|
@staticmethod
async def subscribe(
    run_id: UUID,
    thread_id: UUID | None = None,
) -> asyncio.Queue:
    """Subscribe to the run stream, returning a queue.

    Registers a new stream queue for ``run_id`` (under ``thread_id``, or under
    ``THREADLESS_KEY`` when it is None). Any control messages currently waiting
    in the run's control queues are then moved into the new queue.

    Args:
        run_id: Run to subscribe to (UUID or its string form).
        thread_id: Thread the run belongs to, or None for threadless runs.

    Returns:
        The new subscriber queue.
    """
    stream_manager = get_stream_manager()
    run_id = _ensure_uuid(run_id)
    queue = await stream_manager.add_queue(run_id, thread_id)

    # If there's a control message already stored, send it to the new subscriber.
    # Normalize ids the same way StreamManager does, so that str and UUID
    # inputs hit the same registry keys.
    thread_key = THREADLESS_KEY if thread_id is None else _ensure_uuid(thread_id)
    control_queues = stream_manager.control_queues.get(thread_key, {}).get(run_id, [])
    # NOTE(review): these are the queues Runs.enter passes to
    # listen_for_cancellation, so draining them here consumes messages that
    # listener has not read yet. Confirm this hand-off is intended.
    for control_queue in list(control_queues):
        # get_nowait() raises QueueEmpty once the queue is drained. A bare,
        # un-awaited get() returns a coroutine and never raises, which made
        # this loop spin forever.
        while True:
            try:
                control_msg = control_queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            await queue.put(control_msg)
    return queue
|
|
2400
2560
|
|
|
2401
2561
|
@staticmethod
|
|
@@ -2417,7 +2577,7 @@ class Runs(Authenticated):
|
|
|
2417
2577
|
queue = (
|
|
2418
2578
|
stream_channel
|
|
2419
2579
|
if stream_channel
|
|
2420
|
-
else await Runs.Stream.subscribe(run_id)
|
|
2580
|
+
else await Runs.Stream.subscribe(run_id, thread_id)
|
|
2421
2581
|
)
|
|
2422
2582
|
|
|
2423
2583
|
try:
|
|
@@ -2440,7 +2600,7 @@ class Runs(Authenticated):
|
|
|
2440
2600
|
run = await Runs.get(conn, run_id, thread_id=thread_id, ctx=ctx)
|
|
2441
2601
|
|
|
2442
2602
|
for message in get_stream_manager().restore_messages(
|
|
2443
|
-
run_id, last_event_id
|
|
2603
|
+
run_id, thread_id, last_event_id
|
|
2444
2604
|
):
|
|
2445
2605
|
data, id = message.data, message.id
|
|
2446
2606
|
|
|
@@ -2531,7 +2691,7 @@ class Runs(Authenticated):
|
|
|
2531
2691
|
raise
|
|
2532
2692
|
finally:
|
|
2533
2693
|
stream_manager = get_stream_manager()
|
|
2534
|
-
await stream_manager.remove_queue(run_id, queue)
|
|
2694
|
+
await stream_manager.remove_queue(run_id, thread_id, queue)
|
|
2535
2695
|
|
|
2536
2696
|
@staticmethod
|
|
2537
2697
|
async def publish(
|
|
@@ -2539,6 +2699,7 @@ class Runs(Authenticated):
|
|
|
2539
2699
|
event: str,
|
|
2540
2700
|
message: bytes,
|
|
2541
2701
|
*,
|
|
2702
|
+
thread_id: UUID | str | None = None,
|
|
2542
2703
|
resumable: bool = False,
|
|
2543
2704
|
) -> None:
|
|
2544
2705
|
"""Publish a message to all subscribers of the run stream."""
|
|
@@ -2555,17 +2716,19 @@ class Runs(Authenticated):
|
|
|
2555
2716
|
}
|
|
2556
2717
|
)
|
|
2557
2718
|
await stream_manager.put(
|
|
2558
|
-
run_id, Message(topic=topic, data=payload), resumable
|
|
2719
|
+
run_id, thread_id, Message(topic=topic, data=payload), resumable
|
|
2559
2720
|
)
|
|
2560
2721
|
|
|
2561
2722
|
|
|
2562
|
-
async def listen_for_cancellation(
|
|
2723
|
+
async def listen_for_cancellation(
|
|
2724
|
+
queue: asyncio.Queue, run_id: UUID, thread_id: UUID | None, done: ValueEvent
|
|
2725
|
+
):
|
|
2563
2726
|
"""Listen for cancellation messages and set the done event accordingly."""
|
|
2564
2727
|
from langgraph_api.errors import UserInterrupt, UserRollback
|
|
2565
2728
|
|
|
2566
2729
|
stream_manager = get_stream_manager()
|
|
2567
2730
|
|
|
2568
|
-
if control_key := stream_manager.get_control_key(run_id):
|
|
2731
|
+
if control_key := stream_manager.get_control_key(run_id, thread_id):
|
|
2569
2732
|
payload = control_key.data
|
|
2570
2733
|
if payload == b"rollback":
|
|
2571
2734
|
done.set(UserRollback())
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
langgraph_runtime_inmem/__init__.py,sha256=f-VPPHH1-hKFwEreffg7dNATe9IdcYwQedcSx2MiZog,310
|
|
2
|
+
langgraph_runtime_inmem/checkpoint.py,sha256=nc1G8DqVdIu-ibjKTqXfbPfMbAsKjPObKqegrSzo6Po,4432
|
|
3
|
+
langgraph_runtime_inmem/database.py,sha256=QgaA_WQo1IY6QioYd8r-e6-0B0rnC5anS0muIEJWby0,6364
|
|
4
|
+
langgraph_runtime_inmem/inmem_stream.py,sha256=pUEiHW-1uXQrVTcwEYPwO8YXaYm5qZbpRWawt67y6Lw,8187
|
|
5
|
+
langgraph_runtime_inmem/lifespan.py,sha256=t0w2MX2dGxe8yNtSX97Z-d2pFpllSLS4s1rh2GJDw5M,3557
|
|
6
|
+
langgraph_runtime_inmem/metrics.py,sha256=HhO0RC2bMDTDyGBNvnd2ooLebLA8P1u5oq978Kp_nAA,392
|
|
7
|
+
langgraph_runtime_inmem/ops.py,sha256=0Jx65S3PCvvHlIpA0XYpl-UnDEo_AiGWXRE2QiFSocY,105165
|
|
8
|
+
langgraph_runtime_inmem/queue.py,sha256=33qfFKPhQicZ1qiibllYb-bTFzUNSN2c4bffPACP5es,9952
|
|
9
|
+
langgraph_runtime_inmem/retry.py,sha256=XmldOP4e_H5s264CagJRVnQMDFcEJR_dldVR1Hm5XvM,763
|
|
10
|
+
langgraph_runtime_inmem/store.py,sha256=rTfL1JJvd-j4xjTrL8qDcynaWF6gUJ9-GDVwH0NBD_I,3506
|
|
11
|
+
langgraph_runtime_inmem-0.9.0.dist-info/METADATA,sha256=ptwW1Ei-Xln53P81eJK1aPcFozU8D192OCZBuC_y5EQ,565
|
|
12
|
+
langgraph_runtime_inmem-0.9.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
13
|
+
langgraph_runtime_inmem-0.9.0.dist-info/RECORD,,
|
|
@@ -1,13 +0,0 @@
|
|
|
1
|
-
langgraph_runtime_inmem/__init__.py,sha256=HSPTGiVB69XNTkwTDcmNR5AmVYBGvgbwoW_RmOWec8g,310
|
|
2
|
-
langgraph_runtime_inmem/checkpoint.py,sha256=nc1G8DqVdIu-ibjKTqXfbPfMbAsKjPObKqegrSzo6Po,4432
|
|
3
|
-
langgraph_runtime_inmem/database.py,sha256=G_6L2khpRDSpS2Vs_SujzHayODcwG5V2IhFP7LLBXgw,6349
|
|
4
|
-
langgraph_runtime_inmem/inmem_stream.py,sha256=UWk1srLF44HZPPbRdArGGhsy0MY0UOJKSIxBSO7Hosc,5138
|
|
5
|
-
langgraph_runtime_inmem/lifespan.py,sha256=t0w2MX2dGxe8yNtSX97Z-d2pFpllSLS4s1rh2GJDw5M,3557
|
|
6
|
-
langgraph_runtime_inmem/metrics.py,sha256=HhO0RC2bMDTDyGBNvnd2ooLebLA8P1u5oq978Kp_nAA,392
|
|
7
|
-
langgraph_runtime_inmem/ops.py,sha256=rtO-dgPQnJEymF_yvzxpynUNse-lq1flb0B112pg6pk,97940
|
|
8
|
-
langgraph_runtime_inmem/queue.py,sha256=33qfFKPhQicZ1qiibllYb-bTFzUNSN2c4bffPACP5es,9952
|
|
9
|
-
langgraph_runtime_inmem/retry.py,sha256=XmldOP4e_H5s264CagJRVnQMDFcEJR_dldVR1Hm5XvM,763
|
|
10
|
-
langgraph_runtime_inmem/store.py,sha256=rTfL1JJvd-j4xjTrL8qDcynaWF6gUJ9-GDVwH0NBD_I,3506
|
|
11
|
-
langgraph_runtime_inmem-0.8.1.dist-info/METADATA,sha256=WfRHwBTIUfr1Ux1T1gYgGE5QojW_83T91KELEwub2Bg,565
|
|
12
|
-
langgraph_runtime_inmem-0.8.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
13
|
-
langgraph_runtime_inmem-0.8.1.dist-info/RECORD,,
|
|
File without changes
|