langgraph-runtime-inmem 0.8.2__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_runtime_inmem/__init__.py +1 -1
- langgraph_runtime_inmem/inmem_stream.py +152 -43
- langgraph_runtime_inmem/ops.py +332 -26
- {langgraph_runtime_inmem-0.8.2.dist-info → langgraph_runtime_inmem-0.10.0.dist-info}/METADATA +1 -1
- {langgraph_runtime_inmem-0.8.2.dist-info → langgraph_runtime_inmem-0.10.0.dist-info}/RECORD +6 -6
- {langgraph_runtime_inmem-0.8.2.dist-info → langgraph_runtime_inmem-0.10.0.dist-info}/WHEEL +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import logging
|
|
3
|
+
import time
|
|
3
4
|
from collections import defaultdict
|
|
4
5
|
from collections.abc import Iterator
|
|
5
6
|
from dataclasses import dataclass
|
|
@@ -12,6 +13,14 @@ def _ensure_uuid(id: str | UUID) -> UUID:
|
|
|
12
13
|
return UUID(id) if isinstance(id, str) else id
|
|
13
14
|
|
|
14
15
|
|
|
16
|
+
def _generate_ms_seq_id() -> str:
|
|
17
|
+
"""Generate a Redis-like millisecond-sequence ID (e.g., '1234567890123-0')"""
|
|
18
|
+
# Get current time in milliseconds
|
|
19
|
+
ms = int(time.time() * 1000)
|
|
20
|
+
# For simplicity, always use sequence 0 since we're not handling high throughput
|
|
21
|
+
return f"{ms}-0"
|
|
22
|
+
|
|
23
|
+
|
|
15
24
|
@dataclass
|
|
16
25
|
class Message:
|
|
17
26
|
topic: bytes
|
|
@@ -39,86 +48,186 @@ class ContextQueue(asyncio.Queue):
|
|
|
39
48
|
break
|
|
40
49
|
|
|
41
50
|
|
|
42
|
-
|
|
43
|
-
def __init__(self):
|
|
44
|
-
self.queues = defaultdict(list) # Dict[UUID, List[asyncio.Queue]]
|
|
45
|
-
self.control_keys = defaultdict()
|
|
46
|
-
self.control_queues = defaultdict(list)
|
|
51
|
+
THREADLESS_KEY = "no-thread"
|
|
47
52
|
|
|
48
|
-
self.message_stores = defaultdict(list) # Dict[UUID, List[Message]]
|
|
49
|
-
self.message_next_idx = defaultdict(int) # Dict[UUID, int]
|
|
50
53
|
|
|
51
|
-
|
|
54
|
+
class StreamManager:
|
|
55
|
+
def __init__(self):
|
|
56
|
+
self.queues = defaultdict(
|
|
57
|
+
lambda: defaultdict(list)
|
|
58
|
+
) # Dict[str, List[asyncio.Queue]]
|
|
59
|
+
self.control_keys = defaultdict(lambda: defaultdict())
|
|
60
|
+
self.control_queues = defaultdict(lambda: defaultdict(list))
|
|
61
|
+
self.thread_streams = defaultdict(list)
|
|
62
|
+
|
|
63
|
+
self.message_stores = defaultdict(
|
|
64
|
+
lambda: defaultdict(list[Message])
|
|
65
|
+
) # Dict[str, List[Message]]
|
|
66
|
+
|
|
67
|
+
def get_queues(
|
|
68
|
+
self, run_id: UUID | str, thread_id: UUID | str | None
|
|
69
|
+
) -> list[asyncio.Queue]:
|
|
52
70
|
run_id = _ensure_uuid(run_id)
|
|
53
|
-
|
|
71
|
+
if thread_id is None:
|
|
72
|
+
thread_id = THREADLESS_KEY
|
|
73
|
+
else:
|
|
74
|
+
thread_id = _ensure_uuid(thread_id)
|
|
75
|
+
return self.queues[thread_id][run_id]
|
|
54
76
|
|
|
55
|
-
def get_control_queues(
|
|
77
|
+
def get_control_queues(
|
|
78
|
+
self, run_id: UUID | str, thread_id: UUID | str | None
|
|
79
|
+
) -> list[asyncio.Queue]:
|
|
56
80
|
run_id = _ensure_uuid(run_id)
|
|
57
|
-
|
|
81
|
+
if thread_id is None:
|
|
82
|
+
thread_id = THREADLESS_KEY
|
|
83
|
+
else:
|
|
84
|
+
thread_id = _ensure_uuid(thread_id)
|
|
85
|
+
return self.control_queues[thread_id][run_id]
|
|
58
86
|
|
|
59
|
-
def get_control_key(
|
|
87
|
+
def get_control_key(
|
|
88
|
+
self, run_id: UUID | str, thread_id: UUID | str | None
|
|
89
|
+
) -> Message | None:
|
|
60
90
|
run_id = _ensure_uuid(run_id)
|
|
61
|
-
|
|
91
|
+
if thread_id is None:
|
|
92
|
+
thread_id = THREADLESS_KEY
|
|
93
|
+
else:
|
|
94
|
+
thread_id = _ensure_uuid(thread_id)
|
|
95
|
+
return self.control_keys.get(thread_id, {}).get(run_id)
|
|
62
96
|
|
|
63
97
|
async def put(
|
|
64
|
-
self,
|
|
98
|
+
self,
|
|
99
|
+
run_id: UUID | str | None,
|
|
100
|
+
thread_id: UUID | str | None,
|
|
101
|
+
message: Message,
|
|
102
|
+
resumable: bool = False,
|
|
65
103
|
) -> None:
|
|
66
104
|
run_id = _ensure_uuid(run_id)
|
|
67
|
-
|
|
68
|
-
|
|
105
|
+
if thread_id is None:
|
|
106
|
+
thread_id = THREADLESS_KEY
|
|
107
|
+
else:
|
|
108
|
+
thread_id = _ensure_uuid(thread_id)
|
|
109
|
+
|
|
110
|
+
message.id = _generate_ms_seq_id().encode()
|
|
69
111
|
if resumable:
|
|
70
|
-
self.message_stores[run_id].append(message)
|
|
112
|
+
self.message_stores[thread_id][run_id].append(message)
|
|
71
113
|
topic = message.topic.decode()
|
|
72
114
|
if "control" in topic:
|
|
73
|
-
self.control_keys[run_id] = message
|
|
74
|
-
queues = self.control_queues[run_id]
|
|
115
|
+
self.control_keys[thread_id][run_id] = message
|
|
116
|
+
queues = self.control_queues[thread_id][run_id]
|
|
75
117
|
else:
|
|
76
|
-
queues = self.queues[run_id]
|
|
118
|
+
queues = self.queues[thread_id][run_id]
|
|
77
119
|
coros = [queue.put(message) for queue in queues]
|
|
78
120
|
results = await asyncio.gather(*coros, return_exceptions=True)
|
|
79
121
|
for result in results:
|
|
80
122
|
if isinstance(result, Exception):
|
|
81
123
|
logger.exception(f"Failed to put message in queue: {result}")
|
|
82
124
|
|
|
83
|
-
async def
|
|
125
|
+
async def put_thread(
|
|
126
|
+
self,
|
|
127
|
+
thread_id: UUID | str,
|
|
128
|
+
message: Message,
|
|
129
|
+
) -> None:
|
|
130
|
+
thread_id = _ensure_uuid(thread_id)
|
|
131
|
+
message.id = _generate_ms_seq_id().encode()
|
|
132
|
+
queues = self.thread_streams[thread_id]
|
|
133
|
+
coros = [queue.put(message) for queue in queues]
|
|
134
|
+
results = await asyncio.gather(*coros, return_exceptions=True)
|
|
135
|
+
for result in results:
|
|
136
|
+
if isinstance(result, Exception):
|
|
137
|
+
logger.exception(f"Failed to put message in queue: {result}")
|
|
138
|
+
|
|
139
|
+
async def add_queue(
|
|
140
|
+
self, run_id: UUID | str, thread_id: UUID | str | None
|
|
141
|
+
) -> asyncio.Queue:
|
|
84
142
|
run_id = _ensure_uuid(run_id)
|
|
85
143
|
queue = ContextQueue()
|
|
86
|
-
|
|
144
|
+
if thread_id is None:
|
|
145
|
+
thread_id = THREADLESS_KEY
|
|
146
|
+
else:
|
|
147
|
+
thread_id = _ensure_uuid(thread_id)
|
|
148
|
+
self.queues[thread_id][run_id].append(queue)
|
|
87
149
|
return queue
|
|
88
150
|
|
|
89
|
-
async def add_control_queue(
|
|
151
|
+
async def add_control_queue(
|
|
152
|
+
self, run_id: UUID | str, thread_id: UUID | str | None
|
|
153
|
+
) -> asyncio.Queue:
|
|
90
154
|
run_id = _ensure_uuid(run_id)
|
|
155
|
+
if thread_id is None:
|
|
156
|
+
thread_id = THREADLESS_KEY
|
|
157
|
+
else:
|
|
158
|
+
thread_id = _ensure_uuid(thread_id)
|
|
91
159
|
queue = ContextQueue()
|
|
92
|
-
self.control_queues[run_id].append(queue)
|
|
160
|
+
self.control_queues[thread_id][run_id].append(queue)
|
|
93
161
|
return queue
|
|
94
162
|
|
|
95
|
-
async def
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
del self.queues[run_id]
|
|
163
|
+
async def add_thread_stream(self, thread_id: UUID | str) -> asyncio.Queue:
|
|
164
|
+
thread_id = _ensure_uuid(thread_id)
|
|
165
|
+
queue = ContextQueue()
|
|
166
|
+
self.thread_streams[thread_id].append(queue)
|
|
167
|
+
return queue
|
|
101
168
|
|
|
102
|
-
async def
|
|
169
|
+
async def remove_queue(
|
|
170
|
+
self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
|
|
171
|
+
):
|
|
172
|
+
run_id = _ensure_uuid(run_id)
|
|
173
|
+
if thread_id is None:
|
|
174
|
+
thread_id = THREADLESS_KEY
|
|
175
|
+
else:
|
|
176
|
+
thread_id = _ensure_uuid(thread_id)
|
|
177
|
+
if thread_id in self.queues and run_id in self.queues[thread_id]:
|
|
178
|
+
self.queues[thread_id][run_id].remove(queue)
|
|
179
|
+
if not self.queues[thread_id][run_id]:
|
|
180
|
+
del self.queues[thread_id][run_id]
|
|
181
|
+
|
|
182
|
+
async def remove_control_queue(
|
|
183
|
+
self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
|
|
184
|
+
):
|
|
103
185
|
run_id = _ensure_uuid(run_id)
|
|
104
|
-
if
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
186
|
+
if thread_id is None:
|
|
187
|
+
thread_id = THREADLESS_KEY
|
|
188
|
+
else:
|
|
189
|
+
thread_id = _ensure_uuid(thread_id)
|
|
190
|
+
if (
|
|
191
|
+
thread_id in self.control_queues
|
|
192
|
+
and run_id in self.control_queues[thread_id]
|
|
193
|
+
):
|
|
194
|
+
self.control_queues[thread_id][run_id].remove(queue)
|
|
195
|
+
if not self.control_queues[thread_id][run_id]:
|
|
196
|
+
del self.control_queues[thread_id][run_id]
|
|
108
197
|
|
|
109
198
|
def restore_messages(
|
|
110
|
-
self, run_id: UUID | str, message_id: str | None
|
|
199
|
+
self, run_id: UUID | str, thread_id: UUID | str | None, message_id: str | None
|
|
111
200
|
) -> Iterator[Message]:
|
|
112
201
|
"""Get a stored message by ID for resumable streams."""
|
|
113
202
|
run_id = _ensure_uuid(run_id)
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
203
|
+
if thread_id is None:
|
|
204
|
+
thread_id = THREADLESS_KEY
|
|
205
|
+
else:
|
|
206
|
+
thread_id = _ensure_uuid(thread_id)
|
|
207
|
+
if message_id is None:
|
|
118
208
|
return
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
209
|
+
try:
|
|
210
|
+
# Handle ms-seq format (e.g., "1234567890123-0")
|
|
211
|
+
if thread_id in self.message_stores:
|
|
212
|
+
for message in self.message_stores[thread_id][run_id]:
|
|
213
|
+
if message.id.decode() > message_id:
|
|
214
|
+
yield message
|
|
215
|
+
except TypeError:
|
|
216
|
+
# Try integer format if ms-seq fails
|
|
217
|
+
message_idx = int(message_id) + 1
|
|
218
|
+
if run_id in self.message_stores:
|
|
219
|
+
yield from self.message_stores[thread_id][run_id][message_idx:]
|
|
220
|
+
|
|
221
|
+
def get_queues_by_thread_id(self, thread_id: UUID | str) -> list[asyncio.Queue]:
|
|
222
|
+
"""Get all queues for a specific thread_id across all runs."""
|
|
223
|
+
all_queues = []
|
|
224
|
+
# Search through all stored queue keys for ones ending with the thread_id
|
|
225
|
+
thread_id = _ensure_uuid(thread_id)
|
|
226
|
+
if thread_id in self.queues:
|
|
227
|
+
for run_id in self.queues[thread_id]:
|
|
228
|
+
all_queues.extend(self.queues[thread_id][run_id])
|
|
229
|
+
|
|
230
|
+
return all_queues
|
|
122
231
|
|
|
123
232
|
|
|
124
233
|
# Global instance
|
langgraph_runtime_inmem/ops.py
CHANGED
|
@@ -27,7 +27,12 @@ from starlette.exceptions import HTTPException
|
|
|
27
27
|
|
|
28
28
|
from langgraph_runtime_inmem.checkpoint import Checkpointer
|
|
29
29
|
from langgraph_runtime_inmem.database import InMemConnectionProto, connect
|
|
30
|
-
from langgraph_runtime_inmem.inmem_stream import
|
|
30
|
+
from langgraph_runtime_inmem.inmem_stream import (
|
|
31
|
+
THREADLESS_KEY,
|
|
32
|
+
ContextQueue,
|
|
33
|
+
Message,
|
|
34
|
+
get_stream_manager,
|
|
35
|
+
)
|
|
31
36
|
|
|
32
37
|
if typing.TYPE_CHECKING:
|
|
33
38
|
from langgraph_api.asyncio import ValueEvent
|
|
@@ -54,6 +59,7 @@ if typing.TYPE_CHECKING:
|
|
|
54
59
|
Thread,
|
|
55
60
|
ThreadSelectField,
|
|
56
61
|
ThreadStatus,
|
|
62
|
+
ThreadStreamMode,
|
|
57
63
|
ThreadUpdateResponse,
|
|
58
64
|
)
|
|
59
65
|
from langgraph_api.schema import Interrupt as InterruptSchema
|
|
@@ -734,6 +740,7 @@ class Threads(Authenticated):
|
|
|
734
740
|
async def search(
|
|
735
741
|
conn: InMemConnectionProto,
|
|
736
742
|
*,
|
|
743
|
+
ids: list[str] | list[UUID] | None = None,
|
|
737
744
|
metadata: MetadataInput,
|
|
738
745
|
values: MetadataInput,
|
|
739
746
|
status: ThreadStatus | None,
|
|
@@ -761,7 +768,19 @@ class Threads(Authenticated):
|
|
|
761
768
|
)
|
|
762
769
|
|
|
763
770
|
# Apply filters
|
|
771
|
+
id_set: set[UUID] | None = None
|
|
772
|
+
if ids:
|
|
773
|
+
id_set = set()
|
|
774
|
+
for i in ids:
|
|
775
|
+
try:
|
|
776
|
+
id_set.add(_ensure_uuid(i))
|
|
777
|
+
except Exception:
|
|
778
|
+
raise HTTPException(
|
|
779
|
+
status_code=400, detail="Invalid thread ID " + str(i)
|
|
780
|
+
) from None
|
|
764
781
|
for thread in threads:
|
|
782
|
+
if id_set is not None and thread.get("thread_id") not in id_set:
|
|
783
|
+
continue
|
|
765
784
|
if filters and not _check_filter_match(thread["metadata"], filters):
|
|
766
785
|
continue
|
|
767
786
|
|
|
@@ -1323,7 +1342,14 @@ class Threads(Authenticated):
|
|
|
1323
1342
|
)
|
|
1324
1343
|
|
|
1325
1344
|
metadata = thread.get("metadata", {})
|
|
1326
|
-
thread_config = thread.get("config", {})
|
|
1345
|
+
thread_config = cast(dict[str, Any], thread.get("config", {}))
|
|
1346
|
+
thread_config = {
|
|
1347
|
+
**thread_config,
|
|
1348
|
+
"configurable": {
|
|
1349
|
+
**thread_config.get("configurable", {}),
|
|
1350
|
+
**config.get("configurable", {}),
|
|
1351
|
+
},
|
|
1352
|
+
}
|
|
1327
1353
|
|
|
1328
1354
|
# Fallback to graph_id from run if not in thread metadata
|
|
1329
1355
|
graph_id = metadata.get("graph_id")
|
|
@@ -1410,6 +1436,13 @@ class Threads(Authenticated):
|
|
|
1410
1436
|
status_code=409,
|
|
1411
1437
|
detail=f"Thread {thread_id} has in-flight runs: {pending_runs}",
|
|
1412
1438
|
)
|
|
1439
|
+
thread_config = {
|
|
1440
|
+
**thread_config,
|
|
1441
|
+
"configurable": {
|
|
1442
|
+
**thread_config.get("configurable", {}),
|
|
1443
|
+
**config.get("configurable", {}),
|
|
1444
|
+
},
|
|
1445
|
+
}
|
|
1413
1446
|
|
|
1414
1447
|
# Fallback to graph_id from run if not in thread metadata
|
|
1415
1448
|
graph_id = metadata.get("graph_id")
|
|
@@ -1450,6 +1483,19 @@ class Threads(Authenticated):
|
|
|
1450
1483
|
thread["values"] = state.values
|
|
1451
1484
|
break
|
|
1452
1485
|
|
|
1486
|
+
# Publish state update event
|
|
1487
|
+
from langgraph_api.serde import json_dumpb
|
|
1488
|
+
|
|
1489
|
+
event_data = {
|
|
1490
|
+
"state": state,
|
|
1491
|
+
"thread_id": str(thread_id),
|
|
1492
|
+
}
|
|
1493
|
+
await Threads.Stream.publish(
|
|
1494
|
+
thread_id,
|
|
1495
|
+
"state_update",
|
|
1496
|
+
json_dumpb(event_data),
|
|
1497
|
+
)
|
|
1498
|
+
|
|
1453
1499
|
return ThreadUpdateResponse(
|
|
1454
1500
|
checkpoint=next_config["configurable"],
|
|
1455
1501
|
# Including deprecated fields
|
|
@@ -1492,7 +1538,14 @@ class Threads(Authenticated):
|
|
|
1492
1538
|
thread_iter, not_found_detail=f"Thread {thread_id} not found."
|
|
1493
1539
|
)
|
|
1494
1540
|
|
|
1495
|
-
thread_config = thread["config"]
|
|
1541
|
+
thread_config = cast(dict[str, Any], thread["config"])
|
|
1542
|
+
thread_config = {
|
|
1543
|
+
**thread_config,
|
|
1544
|
+
"configurable": {
|
|
1545
|
+
**thread_config.get("configurable", {}),
|
|
1546
|
+
**config.get("configurable", {}),
|
|
1547
|
+
},
|
|
1548
|
+
}
|
|
1496
1549
|
metadata = thread["metadata"]
|
|
1497
1550
|
|
|
1498
1551
|
if not thread:
|
|
@@ -1539,6 +1592,19 @@ class Threads(Authenticated):
|
|
|
1539
1592
|
thread["values"] = state.values
|
|
1540
1593
|
break
|
|
1541
1594
|
|
|
1595
|
+
# Publish state update event
|
|
1596
|
+
from langgraph_api.serde import json_dumpb
|
|
1597
|
+
|
|
1598
|
+
event_data = {
|
|
1599
|
+
"state": state,
|
|
1600
|
+
"thread_id": str(thread_id),
|
|
1601
|
+
}
|
|
1602
|
+
await Threads.Stream.publish(
|
|
1603
|
+
thread_id,
|
|
1604
|
+
"state_update",
|
|
1605
|
+
json_dumpb(event_data),
|
|
1606
|
+
)
|
|
1607
|
+
|
|
1542
1608
|
return ThreadUpdateResponse(
|
|
1543
1609
|
checkpoint=next_config["configurable"],
|
|
1544
1610
|
)
|
|
@@ -1580,7 +1646,14 @@ class Threads(Authenticated):
|
|
|
1580
1646
|
if not _check_filter_match(thread_metadata, filters):
|
|
1581
1647
|
return []
|
|
1582
1648
|
|
|
1583
|
-
thread_config = thread["config"]
|
|
1649
|
+
thread_config = cast(dict[str, Any], thread["config"])
|
|
1650
|
+
thread_config = {
|
|
1651
|
+
**thread_config,
|
|
1652
|
+
"configurable": {
|
|
1653
|
+
**thread_config.get("configurable", {}),
|
|
1654
|
+
**config.get("configurable", {}),
|
|
1655
|
+
},
|
|
1656
|
+
}
|
|
1584
1657
|
# If graph_id exists, get state history
|
|
1585
1658
|
if graph_id := thread_metadata.get("graph_id"):
|
|
1586
1659
|
async with get_graph(
|
|
@@ -1609,6 +1682,222 @@ class Threads(Authenticated):
|
|
|
1609
1682
|
|
|
1610
1683
|
return []
|
|
1611
1684
|
|
|
1685
|
+
class Stream:
|
|
1686
|
+
@staticmethod
|
|
1687
|
+
async def subscribe(
|
|
1688
|
+
conn: InMemConnectionProto | AsyncConnectionProto,
|
|
1689
|
+
thread_id: UUID,
|
|
1690
|
+
seen_runs: set[UUID],
|
|
1691
|
+
) -> list[tuple[UUID, asyncio.Queue]]:
|
|
1692
|
+
"""Subscribe to the thread stream, creating queues for unseen runs."""
|
|
1693
|
+
stream_manager = get_stream_manager()
|
|
1694
|
+
queues = []
|
|
1695
|
+
|
|
1696
|
+
# Create new queues only for runs not yet seen
|
|
1697
|
+
thread_id = _ensure_uuid(thread_id)
|
|
1698
|
+
|
|
1699
|
+
# Add thread stream queue
|
|
1700
|
+
if thread_id not in seen_runs:
|
|
1701
|
+
queue = await stream_manager.add_thread_stream(thread_id)
|
|
1702
|
+
queues.append((thread_id, queue))
|
|
1703
|
+
seen_runs.add(thread_id)
|
|
1704
|
+
|
|
1705
|
+
for run in conn.store["runs"]:
|
|
1706
|
+
if run["thread_id"] == thread_id:
|
|
1707
|
+
run_id = run["run_id"]
|
|
1708
|
+
if run_id not in seen_runs:
|
|
1709
|
+
queue = await stream_manager.add_queue(run_id, thread_id)
|
|
1710
|
+
queues.append((run_id, queue))
|
|
1711
|
+
seen_runs.add(run_id)
|
|
1712
|
+
|
|
1713
|
+
return queues
|
|
1714
|
+
|
|
1715
|
+
@staticmethod
|
|
1716
|
+
async def join(
|
|
1717
|
+
thread_id: UUID,
|
|
1718
|
+
*,
|
|
1719
|
+
last_event_id: str | None = None,
|
|
1720
|
+
stream_modes: list[ThreadStreamMode],
|
|
1721
|
+
) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
|
|
1722
|
+
"""Stream the thread output."""
|
|
1723
|
+
|
|
1724
|
+
def should_filter_event(event_name: str, message_bytes: bytes) -> bool:
|
|
1725
|
+
"""Check if an event should be filtered out based on stream_modes."""
|
|
1726
|
+
if "run_modes" in stream_modes and event_name != "state_update":
|
|
1727
|
+
return False
|
|
1728
|
+
if "state_update" in stream_modes and event_name == "state_update":
|
|
1729
|
+
return False
|
|
1730
|
+
if "lifecycle" in stream_modes and event_name == "metadata":
|
|
1731
|
+
try:
|
|
1732
|
+
message_data = orjson.loads(message_bytes)
|
|
1733
|
+
if message_data.get("status") == "run_done":
|
|
1734
|
+
return False
|
|
1735
|
+
if "attempt" in message_data and "run_id" in message_data:
|
|
1736
|
+
return False
|
|
1737
|
+
except (orjson.JSONDecodeError, TypeError):
|
|
1738
|
+
pass
|
|
1739
|
+
return True
|
|
1740
|
+
|
|
1741
|
+
from langgraph_api.serde import json_loads
|
|
1742
|
+
|
|
1743
|
+
stream_manager = get_stream_manager()
|
|
1744
|
+
seen_runs: set[UUID] = set()
|
|
1745
|
+
created_queues: list[tuple[UUID, asyncio.Queue]] = []
|
|
1746
|
+
|
|
1747
|
+
try:
|
|
1748
|
+
async with connect() as conn:
|
|
1749
|
+
await logger.ainfo(
|
|
1750
|
+
"Joined thread stream",
|
|
1751
|
+
thread_id=str(thread_id),
|
|
1752
|
+
)
|
|
1753
|
+
|
|
1754
|
+
# Restore messages if resuming from a specific event
|
|
1755
|
+
if last_event_id is not None:
|
|
1756
|
+
# Collect all events from all message stores for this thread
|
|
1757
|
+
all_events = []
|
|
1758
|
+
for run_id in stream_manager.message_stores.get(
|
|
1759
|
+
str(thread_id), []
|
|
1760
|
+
):
|
|
1761
|
+
for message in stream_manager.restore_messages(
|
|
1762
|
+
run_id, thread_id, last_event_id
|
|
1763
|
+
):
|
|
1764
|
+
all_events.append((message, run_id))
|
|
1765
|
+
|
|
1766
|
+
# Sort by message ID (which is ms-seq format)
|
|
1767
|
+
all_events.sort(key=lambda x: x[0].id.decode())
|
|
1768
|
+
|
|
1769
|
+
# Yield sorted events
|
|
1770
|
+
for message, run_id in all_events:
|
|
1771
|
+
data = json_loads(message.data)
|
|
1772
|
+
event_name = data["event"]
|
|
1773
|
+
message_content = data["message"]
|
|
1774
|
+
|
|
1775
|
+
if event_name == "control":
|
|
1776
|
+
if message_content == b"done":
|
|
1777
|
+
event_bytes = b"metadata"
|
|
1778
|
+
message_bytes = orjson.dumps(
|
|
1779
|
+
{"status": "run_done", "run_id": run_id}
|
|
1780
|
+
)
|
|
1781
|
+
# Filter events based on stream_modes
|
|
1782
|
+
if not should_filter_event(
|
|
1783
|
+
"metadata", message_bytes
|
|
1784
|
+
):
|
|
1785
|
+
yield (
|
|
1786
|
+
event_bytes,
|
|
1787
|
+
message_bytes,
|
|
1788
|
+
message.id,
|
|
1789
|
+
)
|
|
1790
|
+
else:
|
|
1791
|
+
event_bytes = event_name.encode()
|
|
1792
|
+
message_bytes = base64.b64decode(message_content)
|
|
1793
|
+
# Filter events based on stream_modes
|
|
1794
|
+
if not should_filter_event(event_name, message_bytes):
|
|
1795
|
+
yield (
|
|
1796
|
+
event_bytes,
|
|
1797
|
+
message_bytes,
|
|
1798
|
+
message.id,
|
|
1799
|
+
)
|
|
1800
|
+
|
|
1801
|
+
# Listen for live messages from all queues
|
|
1802
|
+
while True:
|
|
1803
|
+
# Refresh queues to pick up any new runs that joined this thread
|
|
1804
|
+
new_queue_tuples = await Threads.Stream.subscribe(
|
|
1805
|
+
conn, thread_id, seen_runs
|
|
1806
|
+
)
|
|
1807
|
+
# Track new queues for cleanup
|
|
1808
|
+
for run_id, queue in new_queue_tuples:
|
|
1809
|
+
created_queues.append((run_id, queue))
|
|
1810
|
+
|
|
1811
|
+
for run_id, queue in created_queues:
|
|
1812
|
+
try:
|
|
1813
|
+
message = await asyncio.wait_for(
|
|
1814
|
+
queue.get(), timeout=0.2
|
|
1815
|
+
)
|
|
1816
|
+
data = json_loads(message.data)
|
|
1817
|
+
event_name = data["event"]
|
|
1818
|
+
message_content = data["message"]
|
|
1819
|
+
|
|
1820
|
+
if event_name == "control":
|
|
1821
|
+
if message_content == b"done":
|
|
1822
|
+
# Extract run_id from topic
|
|
1823
|
+
topic = message.topic.decode()
|
|
1824
|
+
run_id = topic.split("run:")[1].split(":")[0]
|
|
1825
|
+
event_bytes = b"metadata"
|
|
1826
|
+
message_bytes = orjson.dumps(
|
|
1827
|
+
{"status": "run_done", "run_id": run_id}
|
|
1828
|
+
)
|
|
1829
|
+
# Filter events based on stream_modes
|
|
1830
|
+
if not should_filter_event(
|
|
1831
|
+
"metadata", message_bytes
|
|
1832
|
+
):
|
|
1833
|
+
yield (
|
|
1834
|
+
event_bytes,
|
|
1835
|
+
message_bytes,
|
|
1836
|
+
message.id,
|
|
1837
|
+
)
|
|
1838
|
+
else:
|
|
1839
|
+
event_bytes = event_name.encode()
|
|
1840
|
+
message_bytes = base64.b64decode(message_content)
|
|
1841
|
+
# Filter events based on stream_modes
|
|
1842
|
+
if not should_filter_event(
|
|
1843
|
+
event_name, message_bytes
|
|
1844
|
+
):
|
|
1845
|
+
yield (
|
|
1846
|
+
event_bytes,
|
|
1847
|
+
message_bytes,
|
|
1848
|
+
message.id,
|
|
1849
|
+
)
|
|
1850
|
+
|
|
1851
|
+
except TimeoutError:
|
|
1852
|
+
continue
|
|
1853
|
+
except (ValueError, KeyError):
|
|
1854
|
+
continue
|
|
1855
|
+
|
|
1856
|
+
# Yield execution to other tasks to prevent event loop starvation
|
|
1857
|
+
await asyncio.sleep(0)
|
|
1858
|
+
|
|
1859
|
+
except WrappedHTTPException as e:
|
|
1860
|
+
raise e.http_exception from None
|
|
1861
|
+
except asyncio.CancelledError:
|
|
1862
|
+
await logger.awarning(
|
|
1863
|
+
"Thread stream client disconnected",
|
|
1864
|
+
thread_id=str(thread_id),
|
|
1865
|
+
)
|
|
1866
|
+
raise
|
|
1867
|
+
except:
|
|
1868
|
+
raise
|
|
1869
|
+
finally:
|
|
1870
|
+
# Clean up all created queues
|
|
1871
|
+
for run_id, queue in created_queues:
|
|
1872
|
+
try:
|
|
1873
|
+
await stream_manager.remove_queue(run_id, thread_id, queue)
|
|
1874
|
+
except Exception:
|
|
1875
|
+
# Ignore cleanup errors
|
|
1876
|
+
pass
|
|
1877
|
+
|
|
1878
|
+
@staticmethod
|
|
1879
|
+
async def publish(
|
|
1880
|
+
thread_id: UUID | str,
|
|
1881
|
+
event: str,
|
|
1882
|
+
message: bytes,
|
|
1883
|
+
) -> None:
|
|
1884
|
+
"""Publish a thread-level event to the thread stream."""
|
|
1885
|
+
from langgraph_api.serde import json_dumpb
|
|
1886
|
+
|
|
1887
|
+
topic = f"thread:{thread_id}:stream".encode()
|
|
1888
|
+
|
|
1889
|
+
stream_manager = get_stream_manager()
|
|
1890
|
+
# Send to thread stream topic
|
|
1891
|
+
payload = json_dumpb(
|
|
1892
|
+
{
|
|
1893
|
+
"event": event,
|
|
1894
|
+
"message": message,
|
|
1895
|
+
}
|
|
1896
|
+
)
|
|
1897
|
+
await stream_manager.put_thread(
|
|
1898
|
+
str(thread_id), Message(topic=topic, data=payload)
|
|
1899
|
+
)
|
|
1900
|
+
|
|
1612
1901
|
@staticmethod
|
|
1613
1902
|
async def count(
|
|
1614
1903
|
conn: InMemConnectionProto,
|
|
@@ -1767,7 +2056,7 @@ class Runs(Authenticated):
|
|
|
1767
2056
|
@asynccontextmanager
|
|
1768
2057
|
@staticmethod
|
|
1769
2058
|
async def enter(
|
|
1770
|
-
run_id: UUID, loop: asyncio.AbstractEventLoop
|
|
2059
|
+
run_id: UUID, thread_id: UUID | None, loop: asyncio.AbstractEventLoop
|
|
1771
2060
|
) -> AsyncIterator[ValueEvent]:
|
|
1772
2061
|
"""Enter a run, listen for cancellation while running, signal when done."
|
|
1773
2062
|
This method should be called as a context manager by a worker executing a run.
|
|
@@ -1775,12 +2064,14 @@ class Runs(Authenticated):
|
|
|
1775
2064
|
from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent
|
|
1776
2065
|
|
|
1777
2066
|
stream_manager = get_stream_manager()
|
|
1778
|
-
# Get queue for this run
|
|
1779
|
-
|
|
2067
|
+
# Get control queue for this run (normal queue is created during run creation)
|
|
2068
|
+
control_queue = await stream_manager.add_control_queue(run_id, thread_id)
|
|
1780
2069
|
|
|
1781
2070
|
async with SimpleTaskGroup(cancel=True, taskgroup_name="Runs.enter") as tg:
|
|
1782
2071
|
done = ValueEvent()
|
|
1783
|
-
tg.create_task(
|
|
2072
|
+
tg.create_task(
|
|
2073
|
+
listen_for_cancellation(control_queue, run_id, thread_id, done)
|
|
2074
|
+
)
|
|
1784
2075
|
|
|
1785
2076
|
# Give done event to caller
|
|
1786
2077
|
yield done
|
|
@@ -1788,17 +2079,17 @@ class Runs(Authenticated):
|
|
|
1788
2079
|
control_message = Message(
|
|
1789
2080
|
topic=f"run:{run_id}:control".encode(), data=b"done"
|
|
1790
2081
|
)
|
|
1791
|
-
await stream_manager.put(run_id, control_message)
|
|
2082
|
+
await stream_manager.put(run_id, thread_id, control_message)
|
|
1792
2083
|
|
|
1793
2084
|
# Signal done to all subscribers
|
|
1794
2085
|
stream_message = Message(
|
|
1795
2086
|
topic=f"run:{run_id}:stream".encode(),
|
|
1796
2087
|
data={"event": "control", "message": b"done"},
|
|
1797
2088
|
)
|
|
1798
|
-
await stream_manager.put(run_id, stream_message)
|
|
2089
|
+
await stream_manager.put(run_id, thread_id, stream_message)
|
|
1799
2090
|
|
|
1800
|
-
# Remove the queue
|
|
1801
|
-
await stream_manager.remove_control_queue(run_id,
|
|
2091
|
+
# Remove the control_queue (normal queue is cleaned up during run deletion)
|
|
2092
|
+
await stream_manager.remove_control_queue(run_id, thread_id, control_queue)
|
|
1802
2093
|
|
|
1803
2094
|
@staticmethod
|
|
1804
2095
|
async def sweep() -> None:
|
|
@@ -1853,6 +2144,7 @@ class Runs(Authenticated):
|
|
|
1853
2144
|
run_id = _ensure_uuid(run_id) if run_id else None
|
|
1854
2145
|
metadata = metadata if metadata is not None else {}
|
|
1855
2146
|
config = kwargs.get("config", {})
|
|
2147
|
+
temporary = kwargs.get("temporary", False)
|
|
1856
2148
|
|
|
1857
2149
|
# Handle thread creation/update
|
|
1858
2150
|
existing_thread = next(
|
|
@@ -1862,7 +2154,7 @@ class Runs(Authenticated):
|
|
|
1862
2154
|
ctx,
|
|
1863
2155
|
"create_run",
|
|
1864
2156
|
Auth.types.RunsCreate(
|
|
1865
|
-
thread_id=thread_id,
|
|
2157
|
+
thread_id=None if temporary else thread_id,
|
|
1866
2158
|
assistant_id=assistant_id,
|
|
1867
2159
|
run_id=run_id,
|
|
1868
2160
|
status=status,
|
|
@@ -2086,6 +2378,7 @@ class Runs(Authenticated):
|
|
|
2086
2378
|
if not thread:
|
|
2087
2379
|
return _empty_generator()
|
|
2088
2380
|
_delete_checkpoints_for_thread(thread_id, conn, run_id=run_id)
|
|
2381
|
+
|
|
2089
2382
|
found = False
|
|
2090
2383
|
for i, run in enumerate(conn.store["runs"]):
|
|
2091
2384
|
if run["run_id"] == run_id and run["thread_id"] == thread_id:
|
|
@@ -2268,9 +2561,9 @@ class Runs(Authenticated):
|
|
|
2268
2561
|
topic=f"run:{run_id}:control".encode(),
|
|
2269
2562
|
data=action.encode(),
|
|
2270
2563
|
)
|
|
2271
|
-
coros.append(stream_manager.put(run_id, control_message))
|
|
2564
|
+
coros.append(stream_manager.put(run_id, thread_id, control_message))
|
|
2272
2565
|
|
|
2273
|
-
queues = stream_manager.get_queues(run_id)
|
|
2566
|
+
queues = stream_manager.get_queues(run_id, thread_id)
|
|
2274
2567
|
|
|
2275
2568
|
if run["status"] in ("pending", "running"):
|
|
2276
2569
|
cancelable_runs.append(run)
|
|
@@ -2385,15 +2678,25 @@ class Runs(Authenticated):
|
|
|
2385
2678
|
@staticmethod
|
|
2386
2679
|
async def subscribe(
|
|
2387
2680
|
run_id: UUID,
|
|
2388
|
-
|
|
2681
|
+
thread_id: UUID | None = None,
|
|
2682
|
+
) -> ContextQueue:
|
|
2389
2683
|
"""Subscribe to the run stream, returning a queue."""
|
|
2390
2684
|
stream_manager = get_stream_manager()
|
|
2391
|
-
queue = await stream_manager.add_queue(_ensure_uuid(run_id))
|
|
2685
|
+
queue = await stream_manager.add_queue(_ensure_uuid(run_id), thread_id)
|
|
2392
2686
|
|
|
2393
2687
|
# If there's a control message already stored, send it to the new subscriber
|
|
2394
|
-
if
|
|
2395
|
-
|
|
2396
|
-
|
|
2688
|
+
if thread_id is None:
|
|
2689
|
+
thread_id = THREADLESS_KEY
|
|
2690
|
+
if control_queues := stream_manager.control_queues.get(thread_id, {}).get(
|
|
2691
|
+
run_id
|
|
2692
|
+
):
|
|
2693
|
+
for control_queue in control_queues:
|
|
2694
|
+
try:
|
|
2695
|
+
while True:
|
|
2696
|
+
control_msg = control_queue.get()
|
|
2697
|
+
await queue.put(control_msg)
|
|
2698
|
+
except asyncio.QueueEmpty:
|
|
2699
|
+
pass
|
|
2397
2700
|
return queue
|
|
2398
2701
|
|
|
2399
2702
|
@staticmethod
|
|
@@ -2415,7 +2718,7 @@ class Runs(Authenticated):
|
|
|
2415
2718
|
queue = (
|
|
2416
2719
|
stream_channel
|
|
2417
2720
|
if stream_channel
|
|
2418
|
-
else await Runs.Stream.subscribe(run_id)
|
|
2721
|
+
else await Runs.Stream.subscribe(run_id, thread_id)
|
|
2419
2722
|
)
|
|
2420
2723
|
|
|
2421
2724
|
try:
|
|
@@ -2438,7 +2741,7 @@ class Runs(Authenticated):
|
|
|
2438
2741
|
run = await Runs.get(conn, run_id, thread_id=thread_id, ctx=ctx)
|
|
2439
2742
|
|
|
2440
2743
|
for message in get_stream_manager().restore_messages(
|
|
2441
|
-
run_id, last_event_id
|
|
2744
|
+
run_id, thread_id, last_event_id
|
|
2442
2745
|
):
|
|
2443
2746
|
data, id = message.data, message.id
|
|
2444
2747
|
|
|
@@ -2529,7 +2832,7 @@ class Runs(Authenticated):
|
|
|
2529
2832
|
raise
|
|
2530
2833
|
finally:
|
|
2531
2834
|
stream_manager = get_stream_manager()
|
|
2532
|
-
await stream_manager.remove_queue(run_id, queue)
|
|
2835
|
+
await stream_manager.remove_queue(run_id, thread_id, queue)
|
|
2533
2836
|
|
|
2534
2837
|
@staticmethod
|
|
2535
2838
|
async def publish(
|
|
@@ -2537,6 +2840,7 @@ class Runs(Authenticated):
|
|
|
2537
2840
|
event: str,
|
|
2538
2841
|
message: bytes,
|
|
2539
2842
|
*,
|
|
2843
|
+
thread_id: UUID | str | None = None,
|
|
2540
2844
|
resumable: bool = False,
|
|
2541
2845
|
) -> None:
|
|
2542
2846
|
"""Publish a message to all subscribers of the run stream."""
|
|
@@ -2553,17 +2857,19 @@ class Runs(Authenticated):
|
|
|
2553
2857
|
}
|
|
2554
2858
|
)
|
|
2555
2859
|
await stream_manager.put(
|
|
2556
|
-
run_id, Message(topic=topic, data=payload), resumable
|
|
2860
|
+
run_id, thread_id, Message(topic=topic, data=payload), resumable
|
|
2557
2861
|
)
|
|
2558
2862
|
|
|
2559
2863
|
|
|
2560
|
-
async def listen_for_cancellation(
|
|
2864
|
+
async def listen_for_cancellation(
|
|
2865
|
+
queue: asyncio.Queue, run_id: UUID, thread_id: UUID | None, done: ValueEvent
|
|
2866
|
+
):
|
|
2561
2867
|
"""Listen for cancellation messages and set the done event accordingly."""
|
|
2562
2868
|
from langgraph_api.errors import UserInterrupt, UserRollback
|
|
2563
2869
|
|
|
2564
2870
|
stream_manager = get_stream_manager()
|
|
2565
2871
|
|
|
2566
|
-
if control_key := stream_manager.get_control_key(run_id):
|
|
2872
|
+
if control_key := stream_manager.get_control_key(run_id, thread_id):
|
|
2567
2873
|
payload = control_key.data
|
|
2568
2874
|
if payload == b"rollback":
|
|
2569
2875
|
done.set(UserRollback())
|
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
langgraph_runtime_inmem/__init__.py,sha256=
|
|
1
|
+
langgraph_runtime_inmem/__init__.py,sha256=4xhdO3o6RduCHDXSNh42I51Wwq7Kcnt3JK1U1IhP-BU,311
|
|
2
2
|
langgraph_runtime_inmem/checkpoint.py,sha256=nc1G8DqVdIu-ibjKTqXfbPfMbAsKjPObKqegrSzo6Po,4432
|
|
3
3
|
langgraph_runtime_inmem/database.py,sha256=QgaA_WQo1IY6QioYd8r-e6-0B0rnC5anS0muIEJWby0,6364
|
|
4
|
-
langgraph_runtime_inmem/inmem_stream.py,sha256=
|
|
4
|
+
langgraph_runtime_inmem/inmem_stream.py,sha256=utL1OlOJsy6VDkSGAA6eX9nETreZlM6K6nhfNoubmRQ,9011
|
|
5
5
|
langgraph_runtime_inmem/lifespan.py,sha256=t0w2MX2dGxe8yNtSX97Z-d2pFpllSLS4s1rh2GJDw5M,3557
|
|
6
6
|
langgraph_runtime_inmem/metrics.py,sha256=HhO0RC2bMDTDyGBNvnd2ooLebLA8P1u5oq978Kp_nAA,392
|
|
7
|
-
langgraph_runtime_inmem/ops.py,sha256=
|
|
7
|
+
langgraph_runtime_inmem/ops.py,sha256=54jiyWhfbSu9z9pca6AQdNuaIBmD0WMrQ7xGQcLPDF4,111183
|
|
8
8
|
langgraph_runtime_inmem/queue.py,sha256=33qfFKPhQicZ1qiibllYb-bTFzUNSN2c4bffPACP5es,9952
|
|
9
9
|
langgraph_runtime_inmem/retry.py,sha256=XmldOP4e_H5s264CagJRVnQMDFcEJR_dldVR1Hm5XvM,763
|
|
10
10
|
langgraph_runtime_inmem/store.py,sha256=rTfL1JJvd-j4xjTrL8qDcynaWF6gUJ9-GDVwH0NBD_I,3506
|
|
11
|
-
langgraph_runtime_inmem-0.
|
|
12
|
-
langgraph_runtime_inmem-0.
|
|
13
|
-
langgraph_runtime_inmem-0.
|
|
11
|
+
langgraph_runtime_inmem-0.10.0.dist-info/METADATA,sha256=gdjdQjZF2KjDtwA9rDiW53pG4FYNfv8TkT1U8t2lftQ,566
|
|
12
|
+
langgraph_runtime_inmem-0.10.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
13
|
+
langgraph_runtime_inmem-0.10.0.dist-info/RECORD,,
|
|
File without changes
|