langgraph-runtime-inmem 0.8.2__tar.gz → 0.10.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.10.0}/PKG-INFO +1 -1
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.10.0}/langgraph_runtime_inmem/__init__.py +1 -1
- langgraph_runtime_inmem-0.10.0/langgraph_runtime_inmem/inmem_stream.py +268 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.10.0}/langgraph_runtime_inmem/ops.py +332 -26
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.10.0}/uv.lock +104 -104
- langgraph_runtime_inmem-0.8.2/langgraph_runtime_inmem/inmem_stream.py +0 -159
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.10.0}/.gitignore +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.10.0}/Makefile +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.10.0}/README.md +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.10.0}/langgraph_runtime_inmem/checkpoint.py +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.10.0}/langgraph_runtime_inmem/database.py +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.10.0}/langgraph_runtime_inmem/lifespan.py +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.10.0}/langgraph_runtime_inmem/metrics.py +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.10.0}/langgraph_runtime_inmem/queue.py +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.10.0}/langgraph_runtime_inmem/retry.py +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.10.0}/langgraph_runtime_inmem/store.py +0 -0
- {langgraph_runtime_inmem-0.8.2 → langgraph_runtime_inmem-0.10.0}/pyproject.toml +0 -0
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
import time
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from collections.abc import Iterator
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from uuid import UUID
|
|
8
|
+
|
|
9
|
+
# Module-level logger; StreamManager uses it to report failed queue deliveries.
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _ensure_uuid(id: str | UUID) -> UUID:
|
|
13
|
+
return UUID(id) if isinstance(id, str) else id
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _generate_ms_seq_id() -> str:
|
|
17
|
+
"""Generate a Redis-like millisecond-sequence ID (e.g., '1234567890123-0')"""
|
|
18
|
+
# Get current time in milliseconds
|
|
19
|
+
ms = int(time.time() * 1000)
|
|
20
|
+
# For simplicity, always use sequence 0 since we're not handling high throughput
|
|
21
|
+
return f"{ms}-0"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
|
|
25
|
+
class Message:
|
|
26
|
+
topic: bytes
|
|
27
|
+
data: bytes
|
|
28
|
+
id: bytes | None = None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ContextQueue(asyncio.Queue):
|
|
32
|
+
"""Queue that supports async context manager protocol"""
|
|
33
|
+
|
|
34
|
+
async def __aenter__(self):
|
|
35
|
+
return self
|
|
36
|
+
|
|
37
|
+
async def __aexit__(
|
|
38
|
+
self,
|
|
39
|
+
exc_type: type[BaseException] | None,
|
|
40
|
+
exc_val: BaseException | None,
|
|
41
|
+
exc_tb: object | None,
|
|
42
|
+
) -> None:
|
|
43
|
+
# Clear the queue
|
|
44
|
+
while not self.empty():
|
|
45
|
+
try:
|
|
46
|
+
self.get_nowait()
|
|
47
|
+
except asyncio.QueueEmpty:
|
|
48
|
+
break
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# Registry key used in place of a thread id for runs published without a thread.
THREADLESS_KEY = "no-thread"
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class StreamManager:
    """In-memory publish/subscribe registry for run and thread streams.

    Subscriber queues are indexed by thread id (``THREADLESS_KEY`` for runs
    without a thread) and then by run id. ``put`` routes messages whose topic
    contains ``"control"`` to the control queues and remembers the most recent
    one per run in ``control_keys``; every other message goes to the regular
    stream queues. Messages published with ``resumable=True`` are also kept in
    ``message_stores`` so a reconnecting subscriber can replay them with
    :meth:`restore_messages`.
    """

    def __init__(self):
        # thread key -> run_id -> subscriber queues for non-control messages.
        self.queues = defaultdict(lambda: defaultdict(list))
        # thread key -> run_id -> most recently published control message.
        self.control_keys = defaultdict(dict)
        # thread key -> run_id -> subscriber queues for control messages.
        self.control_queues = defaultdict(lambda: defaultdict(list))
        # thread_id -> subscriber queues for thread-level messages.
        self.thread_streams = defaultdict(list)
        # thread key -> run_id -> messages retained for resumable streams.
        self.message_stores = defaultdict(lambda: defaultdict(list))

    @staticmethod
    def _thread_key(thread_id: UUID | str | None) -> UUID | str:
        """Normalize a thread id to the key used by the per-thread registries."""
        return THREADLESS_KEY if thread_id is None else _ensure_uuid(thread_id)

    @staticmethod
    def _parse_stream_id(value: str) -> tuple[int, int] | None:
        """Parse a ``"<ms>-<seq>"`` id into ``(ms, seq)``; None if malformed."""
        ms, sep, seq = value.partition("-")
        if not sep:
            return None
        try:
            return int(ms), int(seq)
        except ValueError:
            return None

    @staticmethod
    async def _broadcast(queues: list[asyncio.Queue], message: Message) -> None:
        """Put ``message`` on every queue, logging (not raising) per-queue failures."""
        results = await asyncio.gather(
            *(queue.put(message) for queue in queues), return_exceptions=True
        )
        for result in results:
            if isinstance(result, Exception):
                # logger.exception() outside an ``except`` block has no active
                # exception to report, so attach the captured one explicitly.
                logger.error(
                    "Failed to put message in queue: %s", result, exc_info=result
                )

    @staticmethod
    def _discard(registry, thread_key, run_id, queue: asyncio.Queue) -> None:
        """Unregister ``queue`` from ``registry`` and prune entries left empty.

        Raises ValueError (from ``list.remove``) if the run is registered but
        ``queue`` is not one of its subscribers.
        """
        runs = registry.get(thread_key)
        if runs is None or run_id not in runs:
            return
        subscribers = runs[run_id]
        subscribers.remove(queue)
        if not subscribers:
            del runs[run_id]
        if not runs:
            # Drop the thread entry too so finished threads do not accumulate.
            del registry[thread_key]

    def get_queues(
        self, run_id: UUID | str, thread_id: UUID | str | None
    ) -> list[asyncio.Queue]:
        """Return the live list of stream queues for a run (created if missing)."""
        run_id = _ensure_uuid(run_id)
        return self.queues[self._thread_key(thread_id)][run_id]

    def get_control_queues(
        self, run_id: UUID | str, thread_id: UUID | str | None
    ) -> list[asyncio.Queue]:
        """Return the live list of control queues for a run (created if missing)."""
        run_id = _ensure_uuid(run_id)
        return self.control_queues[self._thread_key(thread_id)][run_id]

    def get_control_key(
        self, run_id: UUID | str, thread_id: UUID | str | None
    ) -> Message | None:
        """Return the most recent control message published for a run, if any."""
        run_id = _ensure_uuid(run_id)
        return self.control_keys.get(self._thread_key(thread_id), {}).get(run_id)

    async def put(
        self,
        run_id: UUID | str | None,
        thread_id: UUID | str | None,
        message: Message,
        resumable: bool = False,
    ) -> None:
        """Stamp ``message`` with a stream id and deliver it to a run's subscribers.

        Control-topic messages also become the run's latest control key. When
        ``resumable`` is true the message is retained for later replay.
        """
        run_id = _ensure_uuid(run_id)
        thread_key = self._thread_key(thread_id)

        message.id = _generate_ms_seq_id().encode()
        if resumable:
            self.message_stores[thread_key][run_id].append(message)

        # Look subscribers up without defaultdict insertion so publishing to a
        # run nobody listens to does not leave a permanent empty entry behind.
        if "control" in message.topic.decode():
            self.control_keys[thread_key][run_id] = message
            queues = self.control_queues.get(thread_key, {}).get(run_id, [])
        else:
            queues = self.queues.get(thread_key, {}).get(run_id, [])
        await self._broadcast(queues, message)

    async def put_thread(
        self,
        thread_id: UUID | str,
        message: Message,
    ) -> None:
        """Stamp ``message`` with a stream id and deliver it to a thread's subscribers."""
        thread_id = _ensure_uuid(thread_id)
        message.id = _generate_ms_seq_id().encode()
        await self._broadcast(self.thread_streams.get(thread_id, []), message)

    async def add_queue(
        self, run_id: UUID | str, thread_id: UUID | str | None
    ) -> asyncio.Queue:
        """Register and return a new stream subscriber queue for a run."""
        run_id = _ensure_uuid(run_id)
        queue = ContextQueue()
        self.queues[self._thread_key(thread_id)][run_id].append(queue)
        return queue

    async def add_control_queue(
        self, run_id: UUID | str, thread_id: UUID | str | None
    ) -> asyncio.Queue:
        """Register and return a new control subscriber queue for a run."""
        run_id = _ensure_uuid(run_id)
        queue = ContextQueue()
        self.control_queues[self._thread_key(thread_id)][run_id].append(queue)
        return queue

    async def add_thread_stream(self, thread_id: UUID | str) -> asyncio.Queue:
        """Register and return a new subscriber queue for thread-level messages."""
        thread_id = _ensure_uuid(thread_id)
        queue = ContextQueue()
        self.thread_streams[thread_id].append(queue)
        return queue

    async def remove_queue(
        self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
    ):
        """Unregister a stream subscriber queue previously returned by add_queue."""
        run_id = _ensure_uuid(run_id)
        self._discard(self.queues, self._thread_key(thread_id), run_id, queue)

    async def remove_control_queue(
        self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
    ):
        """Unregister a control subscriber queue previously returned by add_control_queue."""
        run_id = _ensure_uuid(run_id)
        self._discard(
            self.control_queues, self._thread_key(thread_id), run_id, queue
        )

    def restore_messages(
        self, run_id: UUID | str, thread_id: UUID | str | None, message_id: str | None
    ) -> Iterator[Message]:
        """Yield retained messages published after ``message_id``, in publish order.

        ``message_id`` is normally a ``"<ms>-<seq>"`` stream id; ids are compared
        numerically as ``(ms, seq)`` pairs rather than as strings. A plain
        integer is treated as the index of the last message the caller saw.
        Nothing is yielded when ``message_id`` is None or unparseable.
        """
        run_id = _ensure_uuid(run_id)
        thread_key = self._thread_key(thread_id)
        if message_id is None:
            return
        stored = self.message_stores.get(thread_key, {}).get(run_id, [])

        cursor = self._parse_stream_id(message_id)
        if cursor is not None:
            for message in stored:
                position = (
                    self._parse_stream_id(message.id.decode())
                    if message.id is not None
                    else None
                )
                if position is not None and position > cursor:
                    yield message
            return

        # Fallback: integer index of the last message already delivered.
        try:
            start = int(message_id) + 1
        except ValueError:
            return
        yield from stored[max(start, 0):]

    def get_queues_by_thread_id(self, thread_id: UUID | str) -> list[asyncio.Queue]:
        """Return every stream (non-control) queue subscribed to any run of a thread."""
        thread_id = _ensure_uuid(thread_id)
        runs = self.queues.get(thread_id, {})
        return [queue for queues in runs.values() for queue in queues]
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
# Module-wide StreamManager used by the helpers below; start_stream() rebinds
# this name to a fresh instance, so prefer get_stream_manager() over caching it.
stream_manager = StreamManager()
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
async def start_stream() -> None:
    """Reset the in-memory stream system.

    The in-memory backend has nothing external to initialise; it just installs
    a brand-new module-level StreamManager. Code that captured the previous
    instance directly keeps that old object, while get_stream_manager() returns
    the new one.
    """
    global stream_manager
    fresh_manager = StreamManager()
    stream_manager = fresh_manager
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
async def stop_stream() -> None:
    """Shut down the in-memory stream system.

    Delivers a ``done`` control message (topic ``run:<run_id>:control``) to
    every registered stream queue so subscribers can finish, then clears all
    subscriber queues, retained messages, and stored control messages held by
    the module-level StreamManager. Delivery is best effort: a failing queue is
    logged at debug level and skipped.
    """
    # ``queues`` is nested as thread key -> run_id -> [queue, ...], so both
    # levels must be walked to reach the actual subscriber queues.
    for runs in list(stream_manager.queues.values()):
        for run_id, queues in list(runs.items()):
            control_message = Message(
                topic=f"run:{run_id}:control".encode(), data=b"done"
            )
            for queue in list(queues):
                try:
                    await queue.put(control_message)
                except Exception:
                    # Ignore errors during shutdown, but leave a trace.
                    logger.debug(
                        "Failed to deliver shutdown message to queue", exc_info=True
                    )

    # Clear all stored data, including the remembered control messages.
    stream_manager.queues.clear()
    stream_manager.control_queues.clear()
    stream_manager.control_keys.clear()
    stream_manager.thread_streams.clear()
    stream_manager.message_stores.clear()
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def get_stream_manager() -> StreamManager:
    """Return the StreamManager currently installed at module level.

    The global is read at call time, so the result reflects any replacement
    made by start_stream().
    """
    current = stream_manager
    return current
|