langgraph-runtime-inmem 0.6.4__py3-none-any.whl → 0.18.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_runtime_inmem/__init__.py +1 -1
- langgraph_runtime_inmem/database.py +6 -2
- langgraph_runtime_inmem/inmem_stream.py +160 -36
- langgraph_runtime_inmem/lifespan.py +41 -2
- langgraph_runtime_inmem/metrics.py +1 -1
- langgraph_runtime_inmem/ops.py +695 -206
- langgraph_runtime_inmem/queue.py +8 -18
- {langgraph_runtime_inmem-0.6.4.dist-info → langgraph_runtime_inmem-0.18.1.dist-info}/METADATA +3 -3
- langgraph_runtime_inmem-0.18.1.dist-info/RECORD +13 -0
- langgraph_runtime_inmem-0.6.4.dist-info/RECORD +0 -13
- {langgraph_runtime_inmem-0.6.4.dist-info → langgraph_runtime_inmem-0.18.1.dist-info}/WHEEL +0 -0
|
@@ -142,7 +142,9 @@ class InMemConnectionProto:
|
|
|
142
142
|
|
|
143
143
|
|
|
144
144
|
@asynccontextmanager
|
|
145
|
-
async def connect(
|
|
145
|
+
async def connect(
|
|
146
|
+
*, supports_core_api: bool = False, __test__: bool = False
|
|
147
|
+
) -> AsyncIterator["AsyncConnectionProto"]:
|
|
146
148
|
yield InMemConnectionProto()
|
|
147
149
|
|
|
148
150
|
|
|
@@ -182,6 +184,8 @@ async def start_pool() -> None:
|
|
|
182
184
|
for a in GLOBAL_STORE["assistants"]:
|
|
183
185
|
if a["metadata"].get("created_by") == "system":
|
|
184
186
|
GLOBAL_STORE["assistants"].remove(a)
|
|
187
|
+
if "context" not in a:
|
|
188
|
+
a["context"] = {}
|
|
185
189
|
for k in ["crons"]:
|
|
186
190
|
if not GLOBAL_STORE.get(k):
|
|
187
191
|
GLOBAL_STORE[k] = {}
|
|
@@ -206,6 +210,6 @@ async def healthcheck() -> None:
|
|
|
206
210
|
pass
|
|
207
211
|
|
|
208
212
|
|
|
209
|
-
def pool_stats() -> dict[str, dict[str, int]]:
|
|
213
|
+
def pool_stats(*args, **kwargs) -> dict[str, dict[str, int]]:
|
|
210
214
|
# TODO??
|
|
211
215
|
return {}
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import logging
|
|
3
|
+
import time
|
|
3
4
|
from collections import defaultdict
|
|
4
5
|
from collections.abc import Iterator
|
|
5
6
|
from dataclasses import dataclass
|
|
@@ -12,6 +13,14 @@ def _ensure_uuid(id: str | UUID) -> UUID:
|
|
|
12
13
|
return UUID(id) if isinstance(id, str) else id
|
|
13
14
|
|
|
14
15
|
|
|
16
|
+
def _generate_ms_seq_id() -> str:
|
|
17
|
+
"""Generate a Redis-like millisecond-sequence ID (e.g., '1234567890123-0')"""
|
|
18
|
+
# Get current time in milliseconds
|
|
19
|
+
ms = int(time.time() * 1000)
|
|
20
|
+
# For simplicity, always use sequence 0 since we're not handling high throughput
|
|
21
|
+
return f"{ms}-0"
|
|
22
|
+
|
|
23
|
+
|
|
15
24
|
@dataclass
|
|
16
25
|
class Message:
|
|
17
26
|
topic: bytes
|
|
@@ -39,72 +48,187 @@ class ContextQueue(asyncio.Queue):
|
|
|
39
48
|
break
|
|
40
49
|
|
|
41
50
|
|
|
51
|
+
THREADLESS_KEY = "no-thread"
|
|
52
|
+
|
|
53
|
+
|
|
42
54
|
class StreamManager:
|
|
43
55
|
def __init__(self):
|
|
44
|
-
self.queues = defaultdict(
|
|
45
|
-
|
|
56
|
+
self.queues = defaultdict(
|
|
57
|
+
lambda: defaultdict(list)
|
|
58
|
+
) # Dict[str, List[asyncio.Queue]]
|
|
59
|
+
self.control_keys = defaultdict(lambda: defaultdict())
|
|
60
|
+
self.control_queues = defaultdict(lambda: defaultdict(list))
|
|
61
|
+
self.thread_streams = defaultdict(list)
|
|
46
62
|
|
|
47
|
-
self.message_stores = defaultdict(
|
|
48
|
-
|
|
63
|
+
self.message_stores = defaultdict(
|
|
64
|
+
lambda: defaultdict(list[Message])
|
|
65
|
+
) # Dict[str, List[Message]]
|
|
66
|
+
|
|
67
|
+
def get_queues(
|
|
68
|
+
self, run_id: UUID | str, thread_id: UUID | str | None
|
|
69
|
+
) -> list[asyncio.Queue]:
|
|
70
|
+
run_id = _ensure_uuid(run_id)
|
|
71
|
+
if thread_id is None:
|
|
72
|
+
thread_id = THREADLESS_KEY
|
|
73
|
+
else:
|
|
74
|
+
thread_id = _ensure_uuid(thread_id)
|
|
75
|
+
return self.queues[thread_id][run_id]
|
|
76
|
+
|
|
77
|
+
def get_control_queues(
|
|
78
|
+
self, run_id: UUID | str, thread_id: UUID | str | None
|
|
79
|
+
) -> list[asyncio.Queue]:
|
|
80
|
+
run_id = _ensure_uuid(run_id)
|
|
81
|
+
if thread_id is None:
|
|
82
|
+
thread_id = THREADLESS_KEY
|
|
83
|
+
else:
|
|
84
|
+
thread_id = _ensure_uuid(thread_id)
|
|
85
|
+
return self.control_queues[thread_id][run_id]
|
|
49
86
|
|
|
50
|
-
def
|
|
87
|
+
def get_control_key(
|
|
88
|
+
self, run_id: UUID | str, thread_id: UUID | str | None
|
|
89
|
+
) -> Message | None:
|
|
51
90
|
run_id = _ensure_uuid(run_id)
|
|
52
|
-
|
|
91
|
+
if thread_id is None:
|
|
92
|
+
thread_id = THREADLESS_KEY
|
|
93
|
+
else:
|
|
94
|
+
thread_id = _ensure_uuid(thread_id)
|
|
95
|
+
return self.control_keys.get(thread_id, {}).get(run_id)
|
|
53
96
|
|
|
54
97
|
async def put(
|
|
55
|
-
self,
|
|
98
|
+
self,
|
|
99
|
+
run_id: UUID | str | None,
|
|
100
|
+
thread_id: UUID | str | None,
|
|
101
|
+
message: Message,
|
|
102
|
+
resumable: bool = False,
|
|
56
103
|
) -> None:
|
|
57
104
|
run_id = _ensure_uuid(run_id)
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
105
|
+
if thread_id is None:
|
|
106
|
+
thread_id = THREADLESS_KEY
|
|
107
|
+
else:
|
|
108
|
+
thread_id = _ensure_uuid(thread_id)
|
|
109
|
+
|
|
110
|
+
message.id = _generate_ms_seq_id().encode()
|
|
111
|
+
# For resumable run streams, embed the generated message ID into the frame
|
|
62
112
|
topic = message.topic.decode()
|
|
113
|
+
if resumable:
|
|
114
|
+
self.message_stores[thread_id][run_id].append(message)
|
|
63
115
|
if "control" in topic:
|
|
64
|
-
self.
|
|
65
|
-
|
|
116
|
+
self.control_keys[thread_id][run_id] = message
|
|
117
|
+
queues = self.control_queues[thread_id][run_id]
|
|
118
|
+
else:
|
|
119
|
+
queues = self.queues[thread_id][run_id]
|
|
66
120
|
coros = [queue.put(message) for queue in queues]
|
|
67
121
|
results = await asyncio.gather(*coros, return_exceptions=True)
|
|
68
122
|
for result in results:
|
|
69
123
|
if isinstance(result, Exception):
|
|
70
124
|
logger.exception(f"Failed to put message in queue: {result}")
|
|
71
125
|
|
|
72
|
-
async def
|
|
126
|
+
async def put_thread(
|
|
127
|
+
self,
|
|
128
|
+
thread_id: UUID | str,
|
|
129
|
+
message: Message,
|
|
130
|
+
) -> None:
|
|
131
|
+
thread_id = _ensure_uuid(thread_id)
|
|
132
|
+
message.id = _generate_ms_seq_id().encode()
|
|
133
|
+
queues = self.thread_streams[thread_id]
|
|
134
|
+
coros = [queue.put(message) for queue in queues]
|
|
135
|
+
results = await asyncio.gather(*coros, return_exceptions=True)
|
|
136
|
+
for result in results:
|
|
137
|
+
if isinstance(result, Exception):
|
|
138
|
+
logger.exception(f"Failed to put message in queue: {result}")
|
|
139
|
+
|
|
140
|
+
async def add_queue(
|
|
141
|
+
self, run_id: UUID | str, thread_id: UUID | str | None
|
|
142
|
+
) -> asyncio.Queue:
|
|
73
143
|
run_id = _ensure_uuid(run_id)
|
|
74
144
|
queue = ContextQueue()
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
f"Failed to put control message in queue: {control_msg}"
|
|
82
|
-
)
|
|
145
|
+
if thread_id is None:
|
|
146
|
+
thread_id = THREADLESS_KEY
|
|
147
|
+
else:
|
|
148
|
+
thread_id = _ensure_uuid(thread_id)
|
|
149
|
+
self.queues[thread_id][run_id].append(queue)
|
|
150
|
+
return queue
|
|
83
151
|
|
|
152
|
+
async def add_control_queue(
|
|
153
|
+
self, run_id: UUID | str, thread_id: UUID | str | None
|
|
154
|
+
) -> asyncio.Queue:
|
|
155
|
+
run_id = _ensure_uuid(run_id)
|
|
156
|
+
if thread_id is None:
|
|
157
|
+
thread_id = THREADLESS_KEY
|
|
158
|
+
else:
|
|
159
|
+
thread_id = _ensure_uuid(thread_id)
|
|
160
|
+
queue = ContextQueue()
|
|
161
|
+
self.control_queues[thread_id][run_id].append(queue)
|
|
84
162
|
return queue
|
|
85
163
|
|
|
86
|
-
async def
|
|
164
|
+
async def add_thread_stream(self, thread_id: UUID | str) -> asyncio.Queue:
|
|
165
|
+
thread_id = _ensure_uuid(thread_id)
|
|
166
|
+
queue = ContextQueue()
|
|
167
|
+
self.thread_streams[thread_id].append(queue)
|
|
168
|
+
return queue
|
|
169
|
+
|
|
170
|
+
async def remove_queue(
|
|
171
|
+
self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
|
|
172
|
+
):
|
|
87
173
|
run_id = _ensure_uuid(run_id)
|
|
88
|
-
if
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
174
|
+
if thread_id is None:
|
|
175
|
+
thread_id = THREADLESS_KEY
|
|
176
|
+
else:
|
|
177
|
+
thread_id = _ensure_uuid(thread_id)
|
|
178
|
+
if thread_id in self.queues and run_id in self.queues[thread_id]:
|
|
179
|
+
self.queues[thread_id][run_id].remove(queue)
|
|
180
|
+
if not self.queues[thread_id][run_id]:
|
|
181
|
+
del self.queues[thread_id][run_id]
|
|
182
|
+
|
|
183
|
+
async def remove_control_queue(
|
|
184
|
+
self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
|
|
185
|
+
):
|
|
186
|
+
run_id = _ensure_uuid(run_id)
|
|
187
|
+
if thread_id is None:
|
|
188
|
+
thread_id = THREADLESS_KEY
|
|
189
|
+
else:
|
|
190
|
+
thread_id = _ensure_uuid(thread_id)
|
|
191
|
+
if (
|
|
192
|
+
thread_id in self.control_queues
|
|
193
|
+
and run_id in self.control_queues[thread_id]
|
|
194
|
+
):
|
|
195
|
+
self.control_queues[thread_id][run_id].remove(queue)
|
|
196
|
+
if not self.control_queues[thread_id][run_id]:
|
|
197
|
+
del self.control_queues[thread_id][run_id]
|
|
94
198
|
|
|
95
199
|
def restore_messages(
|
|
96
|
-
self, run_id: UUID | str, message_id: str | None
|
|
200
|
+
self, run_id: UUID | str, thread_id: UUID | str | None, message_id: str | None
|
|
97
201
|
) -> Iterator[Message]:
|
|
98
202
|
"""Get a stored message by ID for resumable streams."""
|
|
99
203
|
run_id = _ensure_uuid(run_id)
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
204
|
+
if thread_id is None:
|
|
205
|
+
thread_id = THREADLESS_KEY
|
|
206
|
+
else:
|
|
207
|
+
thread_id = _ensure_uuid(thread_id)
|
|
208
|
+
if message_id is None:
|
|
104
209
|
return
|
|
210
|
+
try:
|
|
211
|
+
# Handle ms-seq format (e.g., "1234567890123-0")
|
|
212
|
+
if thread_id in self.message_stores:
|
|
213
|
+
for message in self.message_stores[thread_id][run_id]:
|
|
214
|
+
if message.id.decode() > message_id:
|
|
215
|
+
yield message
|
|
216
|
+
except TypeError:
|
|
217
|
+
# Try integer format if ms-seq fails
|
|
218
|
+
message_idx = int(message_id) + 1
|
|
219
|
+
if run_id in self.message_stores:
|
|
220
|
+
yield from self.message_stores[thread_id][run_id][message_idx:]
|
|
221
|
+
|
|
222
|
+
def get_queues_by_thread_id(self, thread_id: UUID | str) -> list[asyncio.Queue]:
|
|
223
|
+
"""Get all queues for a specific thread_id across all runs."""
|
|
224
|
+
all_queues = []
|
|
225
|
+
# Search through all stored queue keys for ones ending with the thread_id
|
|
226
|
+
thread_id = _ensure_uuid(thread_id)
|
|
227
|
+
if thread_id in self.queues:
|
|
228
|
+
for run_id in self.queues[thread_id]:
|
|
229
|
+
all_queues.extend(self.queues[thread_id][run_id])
|
|
105
230
|
|
|
106
|
-
|
|
107
|
-
yield from self.message_stores[run_id][message_idx:]
|
|
231
|
+
return all_queues
|
|
108
232
|
|
|
109
233
|
|
|
110
234
|
# Global instance
|
|
@@ -14,9 +14,17 @@ from langgraph_runtime_inmem.database import start_pool, stop_pool
|
|
|
14
14
|
logger = structlog.stdlib.get_logger(__name__)
|
|
15
15
|
|
|
16
16
|
|
|
17
|
+
_LAST_LIFESPAN_ERROR: BaseException | None = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_last_error() -> BaseException | None:
|
|
21
|
+
return _LAST_LIFESPAN_ERROR
|
|
22
|
+
|
|
23
|
+
|
|
17
24
|
@asynccontextmanager
|
|
18
25
|
async def lifespan(
|
|
19
26
|
app: Starlette | None = None,
|
|
27
|
+
cancel_event: asyncio.Event | None = None,
|
|
20
28
|
taskset: set[asyncio.Task] | None = None,
|
|
21
29
|
**kwargs: Any,
|
|
22
30
|
):
|
|
@@ -41,13 +49,31 @@ async def lifespan(
|
|
|
41
49
|
except RuntimeError:
|
|
42
50
|
await logger.aerror("Failed to set loop")
|
|
43
51
|
|
|
52
|
+
global _LAST_LIFESPAN_ERROR
|
|
53
|
+
_LAST_LIFESPAN_ERROR = None
|
|
54
|
+
|
|
44
55
|
await start_http_client()
|
|
45
56
|
await start_pool()
|
|
46
57
|
await start_ui_bundler()
|
|
58
|
+
|
|
59
|
+
async def _log_graph_load_failure(err: graph.GraphLoadError) -> None:
|
|
60
|
+
cause = err.__cause__ or err.cause
|
|
61
|
+
log_fields = err.log_fields()
|
|
62
|
+
log_fields["action"] = "fix_user_graph"
|
|
63
|
+
await logger.aerror(
|
|
64
|
+
f"Graph '{err.spec.id}' failed to load: {err.cause_message}",
|
|
65
|
+
**log_fields,
|
|
66
|
+
)
|
|
67
|
+
await logger.adebug(
|
|
68
|
+
"Full graph load failure traceback (internal)",
|
|
69
|
+
**{k: v for k, v in log_fields.items() if k != "user_traceback"},
|
|
70
|
+
exc_info=cause,
|
|
71
|
+
)
|
|
72
|
+
|
|
47
73
|
try:
|
|
48
74
|
async with SimpleTaskGroup(
|
|
49
75
|
cancel=True,
|
|
50
|
-
|
|
76
|
+
cancel_event=cancel_event,
|
|
51
77
|
taskgroup_name="Lifespan",
|
|
52
78
|
) as tg:
|
|
53
79
|
tg.create_task(metadata_loop())
|
|
@@ -76,11 +102,21 @@ async def lifespan(
|
|
|
76
102
|
var_child_runnable_config.set(langgraph_config)
|
|
77
103
|
|
|
78
104
|
# Keep after the setter above so users can access the store from within the factory function
|
|
79
|
-
|
|
105
|
+
try:
|
|
106
|
+
await graph.collect_graphs_from_env(True)
|
|
107
|
+
except graph.GraphLoadError as exc:
|
|
108
|
+
_LAST_LIFESPAN_ERROR = exc
|
|
109
|
+
await _log_graph_load_failure(exc)
|
|
110
|
+
raise
|
|
80
111
|
if config.N_JOBS_PER_WORKER > 0:
|
|
81
112
|
tg.create_task(queue_with_signal())
|
|
82
113
|
|
|
83
114
|
yield
|
|
115
|
+
except graph.GraphLoadError as exc:
|
|
116
|
+
_LAST_LIFESPAN_ERROR = exc
|
|
117
|
+
raise
|
|
118
|
+
except asyncio.CancelledError:
|
|
119
|
+
pass
|
|
84
120
|
finally:
|
|
85
121
|
await api_store.exit_store()
|
|
86
122
|
await stop_ui_bundler()
|
|
@@ -97,3 +133,6 @@ async def queue_with_signal():
|
|
|
97
133
|
except Exception as exc:
|
|
98
134
|
logger.exception("Queue failed. Signaling shutdown", exc_info=exc)
|
|
99
135
|
signal.raise_signal(signal.SIGINT)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
lifespan.get_last_error = get_last_error # type: ignore[attr-defined]
|