langgraph-runtime-inmem 0.6.4__py3-none-any.whl → 0.18.1__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
@@ -9,7 +9,7 @@ from langgraph_runtime_inmem import (
     store,
 )

-__version__ = "0.6.4"
+__version__ = "0.18.1"
 __all__ = [
     "ops",
     "database",
@@ -142,7 +142,9 @@ class InMemConnectionProto:


 @asynccontextmanager
-async def connect(*, __test__: bool = False) -> AsyncIterator["AsyncConnectionProto"]:
+async def connect(
+    *, supports_core_api: bool = False, __test__: bool = False
+) -> AsyncIterator["AsyncConnectionProto"]:
     yield InMemConnectionProto()


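The new `supports_core_api` flag is keyword-only, so existing call sites are unaffected and new callers must pass it by name. A minimal, self-contained sketch of how the widened signature behaves (a stand-alone mimic of the stub above, not an import from the package):

```python
import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager


class InMemConnectionProto:
    """Stand-in for the package's no-op in-memory connection."""


@asynccontextmanager
async def connect(
    *, supports_core_api: bool = False, __test__: bool = False
) -> AsyncIterator[InMemConnectionProto]:
    # The in-memory runtime ignores both flags and always yields the same stub.
    yield InMemConnectionProto()


async def main() -> None:
    async with connect() as conn:  # old call sites keep working
        print(type(conn).__name__)
    async with connect(supports_core_api=True) as conn:  # new opt-in flag
        print(type(conn).__name__)


asyncio.run(main())
```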
@@ -182,6 +184,8 @@ async def start_pool() -> None:
     for a in GLOBAL_STORE["assistants"]:
         if a["metadata"].get("created_by") == "system":
             GLOBAL_STORE["assistants"].remove(a)
+        if "context" not in a:
+            a["context"] = {}
     for k in ["crons"]:
         if not GLOBAL_STORE.get(k):
             GLOBAL_STORE[k] = {}
@@ -206,6 +210,6 @@ async def healthcheck() -> None:
     pass


-def pool_stats() -> dict[str, dict[str, int]]:
+def pool_stats(*args, **kwargs) -> dict[str, dict[str, int]]:
     # TODO??
     return {}
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+import time
 from collections import defaultdict
 from collections.abc import Iterator
 from dataclasses import dataclass
@@ -12,6 +13,14 @@ def _ensure_uuid(id: str | UUID) -> UUID:
     return UUID(id) if isinstance(id, str) else id


+def _generate_ms_seq_id() -> str:
+    """Generate a Redis-like millisecond-sequence ID (e.g., '1234567890123-0')"""
+    # Get current time in milliseconds
+    ms = int(time.time() * 1000)
+    # For simplicity, always use sequence 0 since we're not handling high throughput
+    return f"{ms}-0"
+
+
 @dataclass
 class Message:
     topic: bytes
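The integer per-run counter is replaced by Redis-stream-style IDs. Because every ID is a 13-digit millisecond timestamp plus a fixed `-0` sequence, plain string comparison orders them chronologically, which is what the reworked `restore_messages` below relies on. A small self-contained sketch (the helper is copied from the hunk above):

```python
import time


def _generate_ms_seq_id() -> str:
    # Copied from the hunk above: millisecond timestamp plus a fixed "-0" sequence.
    return f"{int(time.time() * 1000)}-0"


first = _generate_ms_seq_id()
time.sleep(0.002)
second = _generate_ms_seq_id()

# Lexicographic comparison matches chronological order while the millisecond
# part keeps the same number of digits (13 digits until roughly the year 2286).
assert second > first
print(first, second)
```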
@@ -39,72 +48,187 @@ class ContextQueue(asyncio.Queue):
                 break


+THREADLESS_KEY = "no-thread"
+
+
 class StreamManager:
     def __init__(self):
-        self.queues = defaultdict(list)  # Dict[UUID, List[asyncio.Queue]]
-        self.control_queues = defaultdict(list)
+        self.queues = defaultdict(
+            lambda: defaultdict(list)
+        )  # Dict[str, List[asyncio.Queue]]
+        self.control_keys = defaultdict(lambda: defaultdict())
+        self.control_queues = defaultdict(lambda: defaultdict(list))
+        self.thread_streams = defaultdict(list)

-        self.message_stores = defaultdict(list)  # Dict[UUID, List[Message]]
-        self.message_next_idx = defaultdict(int)  # Dict[UUID, int]
+        self.message_stores = defaultdict(
+            lambda: defaultdict(list[Message])
+        )  # Dict[str, List[Message]]
+
+    def get_queues(
+        self, run_id: UUID | str, thread_id: UUID | str | None
+    ) -> list[asyncio.Queue]:
+        run_id = _ensure_uuid(run_id)
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+        return self.queues[thread_id][run_id]
+
+    def get_control_queues(
+        self, run_id: UUID | str, thread_id: UUID | str | None
+    ) -> list[asyncio.Queue]:
+        run_id = _ensure_uuid(run_id)
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+        return self.control_queues[thread_id][run_id]

-    def get_queues(self, run_id: UUID | str) -> list[asyncio.Queue]:
+    def get_control_key(
+        self, run_id: UUID | str, thread_id: UUID | str | None
+    ) -> Message | None:
         run_id = _ensure_uuid(run_id)
-        return self.queues[run_id]
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+        return self.control_keys.get(thread_id, {}).get(run_id)

     async def put(
-        self, run_id: UUID | str, message: Message, resumable: bool = False
+        self,
+        run_id: UUID | str | None,
+        thread_id: UUID | str | None,
+        message: Message,
+        resumable: bool = False,
     ) -> None:
         run_id = _ensure_uuid(run_id)
-        message.id = str(self.message_next_idx[run_id]).encode()
-        self.message_next_idx[run_id] += 1
-        if resumable:
-            self.message_stores[run_id].append(message)
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+
+        message.id = _generate_ms_seq_id().encode()
+        # For resumable run streams, embed the generated message ID into the frame
         topic = message.topic.decode()
+        if resumable:
+            self.message_stores[thread_id][run_id].append(message)
         if "control" in topic:
-            self.control_queues[run_id].append(message)
-        queues = self.queues.get(run_id, [])
+            self.control_keys[thread_id][run_id] = message
+            queues = self.control_queues[thread_id][run_id]
+        else:
+            queues = self.queues[thread_id][run_id]
         coros = [queue.put(message) for queue in queues]
         results = await asyncio.gather(*coros, return_exceptions=True)
         for result in results:
             if isinstance(result, Exception):
                 logger.exception(f"Failed to put message in queue: {result}")

-    async def add_queue(self, run_id: UUID | str) -> asyncio.Queue:
+    async def put_thread(
+        self,
+        thread_id: UUID | str,
+        message: Message,
+    ) -> None:
+        thread_id = _ensure_uuid(thread_id)
+        message.id = _generate_ms_seq_id().encode()
+        queues = self.thread_streams[thread_id]
+        coros = [queue.put(message) for queue in queues]
+        results = await asyncio.gather(*coros, return_exceptions=True)
+        for result in results:
+            if isinstance(result, Exception):
+                logger.exception(f"Failed to put message in queue: {result}")
+
+    async def add_queue(
+        self, run_id: UUID | str, thread_id: UUID | str | None
+    ) -> asyncio.Queue:
         run_id = _ensure_uuid(run_id)
         queue = ContextQueue()
-        self.queues[run_id].append(queue)
-        for control_msg in self.control_queues[run_id]:
-            try:
-                await queue.put(control_msg)
-            except Exception:
-                logger.exception(
-                    f"Failed to put control message in queue: {control_msg}"
-                )
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+        self.queues[thread_id][run_id].append(queue)
+        return queue

+    async def add_control_queue(
+        self, run_id: UUID | str, thread_id: UUID | str | None
+    ) -> asyncio.Queue:
+        run_id = _ensure_uuid(run_id)
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+        queue = ContextQueue()
+        self.control_queues[thread_id][run_id].append(queue)
         return queue

-    async def remove_queue(self, run_id: UUID | str, queue: asyncio.Queue):
+    async def add_thread_stream(self, thread_id: UUID | str) -> asyncio.Queue:
+        thread_id = _ensure_uuid(thread_id)
+        queue = ContextQueue()
+        self.thread_streams[thread_id].append(queue)
+        return queue
+
+    async def remove_queue(
+        self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
+    ):
         run_id = _ensure_uuid(run_id)
-        if run_id in self.queues:
-            self.queues[run_id].remove(queue)
-            if not self.queues[run_id]:
-                del self.queues[run_id]
-        if run_id in self.message_stores:
-            del self.message_stores[run_id]
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+        if thread_id in self.queues and run_id in self.queues[thread_id]:
+            self.queues[thread_id][run_id].remove(queue)
+            if not self.queues[thread_id][run_id]:
+                del self.queues[thread_id][run_id]
+
+    async def remove_control_queue(
+        self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
+    ):
+        run_id = _ensure_uuid(run_id)
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+        if (
+            thread_id in self.control_queues
+            and run_id in self.control_queues[thread_id]
+        ):
+            self.control_queues[thread_id][run_id].remove(queue)
+            if not self.control_queues[thread_id][run_id]:
+                del self.control_queues[thread_id][run_id]

     def restore_messages(
-        self, run_id: UUID | str, message_id: str | None
+        self, run_id: UUID | str, thread_id: UUID | str | None, message_id: str | None
     ) -> Iterator[Message]:
         """Get a stored message by ID for resumable streams."""
         run_id = _ensure_uuid(run_id)
-        message_idx = int(message_id) + 1 if message_id else None
-
-        if message_idx is None:
-            yield from []
+        if thread_id is None:
+            thread_id = THREADLESS_KEY
+        else:
+            thread_id = _ensure_uuid(thread_id)
+        if message_id is None:
             return
+        try:
+            # Handle ms-seq format (e.g., "1234567890123-0")
+            if thread_id in self.message_stores:
+                for message in self.message_stores[thread_id][run_id]:
+                    if message.id.decode() > message_id:
+                        yield message
+        except TypeError:
+            # Try integer format if ms-seq fails
+            message_idx = int(message_id) + 1
+            if run_id in self.message_stores:
+                yield from self.message_stores[thread_id][run_id][message_idx:]
+
+    def get_queues_by_thread_id(self, thread_id: UUID | str) -> list[asyncio.Queue]:
+        """Get all queues for a specific thread_id across all runs."""
+        all_queues = []
+        # Search through all stored queue keys for ones ending with the thread_id
+        thread_id = _ensure_uuid(thread_id)
+        if thread_id in self.queues:
+            for run_id in self.queues[thread_id]:
+                all_queues.extend(self.queues[thread_id][run_id])

-        if run_id in self.message_stores:
-            yield from self.message_stores[run_id][message_idx:]
+        return all_queues


 # Global instance
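The manager is now keyed by `(thread_id, run_id)` rather than by `run_id` alone, with `THREADLESS_KEY` bucketing runs that have no thread, separate registries for control queues and whole-thread streams, and message stores scoped the same way. A condensed, self-contained sketch of the new addressing scheme (a miniature stand-in, not the package's `StreamManager`):

```python
import asyncio
from collections import defaultdict
from dataclasses import dataclass
from uuid import UUID, uuid4

THREADLESS_KEY = "no-thread"


@dataclass
class Message:
    topic: bytes
    data: bytes
    id: bytes = b""


class MiniStreamManager:
    """Miniature stand-in showing the two-level (thread_id -> run_id) keying."""

    def __init__(self) -> None:
        self.queues = defaultdict(lambda: defaultdict(list))

    @staticmethod
    def _thread_key(thread_id: UUID | None) -> UUID | str:
        # Runs without a thread all land in the sentinel bucket.
        return THREADLESS_KEY if thread_id is None else thread_id

    async def add_queue(self, run_id: UUID, thread_id: UUID | None) -> asyncio.Queue:
        queue: asyncio.Queue = asyncio.Queue()
        self.queues[self._thread_key(thread_id)][run_id].append(queue)
        return queue

    async def put(self, run_id: UUID, thread_id: UUID | None, message: Message) -> None:
        for queue in self.queues[self._thread_key(thread_id)][run_id]:
            await queue.put(message)


async def main() -> None:
    manager = MiniStreamManager()
    thread_id, run_id = uuid4(), uuid4()

    # One subscriber for a run that belongs to a thread...
    threaded = await manager.add_queue(run_id, thread_id)
    # ...and one for a threadless run that happens to share the same run_id.
    threadless = await manager.add_queue(run_id, None)

    await manager.put(run_id, thread_id, Message(b"events", b"threaded"))
    await manager.put(run_id, None, Message(b"events", b"threadless"))

    print((await threaded.get()).data)    # b'threaded'
    print((await threadless.get()).data)  # b'threadless'


asyncio.run(main())
```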
@@ -14,9 +14,17 @@ from langgraph_runtime_inmem.database import start_pool, stop_pool
 logger = structlog.stdlib.get_logger(__name__)


+_LAST_LIFESPAN_ERROR: BaseException | None = None
+
+
+def get_last_error() -> BaseException | None:
+    return _LAST_LIFESPAN_ERROR
+
+
 @asynccontextmanager
 async def lifespan(
     app: Starlette | None = None,
+    cancel_event: asyncio.Event | None = None,
     taskset: set[asyncio.Task] | None = None,
     **kwargs: Any,
 ):
@@ -41,13 +49,31 @@ async def lifespan(
     except RuntimeError:
         await logger.aerror("Failed to set loop")

+    global _LAST_LIFESPAN_ERROR
+    _LAST_LIFESPAN_ERROR = None
+
     await start_http_client()
     await start_pool()
     await start_ui_bundler()
+
+    async def _log_graph_load_failure(err: graph.GraphLoadError) -> None:
+        cause = err.__cause__ or err.cause
+        log_fields = err.log_fields()
+        log_fields["action"] = "fix_user_graph"
+        await logger.aerror(
+            f"Graph '{err.spec.id}' failed to load: {err.cause_message}",
+            **log_fields,
+        )
+        await logger.adebug(
+            "Full graph load failure traceback (internal)",
+            **{k: v for k, v in log_fields.items() if k != "user_traceback"},
+            exc_info=cause,
+        )
+
     try:
         async with SimpleTaskGroup(
             cancel=True,
-            taskset=taskset,
+            cancel_event=cancel_event,
             taskgroup_name="Lifespan",
         ) as tg:
             tg.create_task(metadata_loop())
@@ -76,11 +102,21 @@ async def lifespan(
             var_child_runnable_config.set(langgraph_config)

             # Keep after the setter above so users can access the store from within the factory function
-            await graph.collect_graphs_from_env(True)
+            try:
+                await graph.collect_graphs_from_env(True)
+            except graph.GraphLoadError as exc:
+                _LAST_LIFESPAN_ERROR = exc
+                await _log_graph_load_failure(exc)
+                raise
             if config.N_JOBS_PER_WORKER > 0:
                 tg.create_task(queue_with_signal())

             yield
+    except graph.GraphLoadError as exc:
+        _LAST_LIFESPAN_ERROR = exc
+        raise
+    except asyncio.CancelledError:
+        pass
     finally:
         await api_store.exit_store()
         await stop_ui_bundler()
@@ -97,3 +133,6 @@ async def queue_with_signal():
     except Exception as exc:
         logger.exception("Queue failed. Signaling shutdown", exc_info=exc)
         signal.raise_signal(signal.SIGINT)
+
+
+lifespan.get_last_error = get_last_error  # type: ignore[attr-defined]
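Startup failures are now remembered instead of only logged: the last `GraphLoadError` (or other lifespan error) is stored in a module-level slot and exposed through `get_last_error`, which is also attached to the `lifespan` context manager itself. A minimal self-contained sketch of that pattern, with a plain `RuntimeError` and a `fail` flag standing in for the real graph-load failure:

```python
import asyncio
from contextlib import asynccontextmanager

_LAST_LIFESPAN_ERROR: BaseException | None = None


def get_last_error() -> BaseException | None:
    return _LAST_LIFESPAN_ERROR


@asynccontextmanager
async def lifespan(*, fail: bool = True):
    global _LAST_LIFESPAN_ERROR
    _LAST_LIFESPAN_ERROR = None  # reset on every startup
    try:
        if fail:  # stand-in for graph loading raising GraphLoadError
            raise RuntimeError("graph failed to load")
        yield
    except RuntimeError as exc:
        _LAST_LIFESPAN_ERROR = exc
        raise


lifespan.get_last_error = get_last_error  # type: ignore[attr-defined]


async def main() -> None:
    try:
        async with lifespan():
            pass
    except RuntimeError:
        # Callers (e.g. a dev server) can report why startup failed.
        print("startup failed:", lifespan.get_last_error())


asyncio.run(main())
```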
@@ -1,7 +1,7 @@
 from langgraph_runtime_inmem.queue import get_num_workers


-def get_metrics() -> dict[str, int]:
+def get_metrics() -> dict[str, dict[str, int]]:
     from langgraph_api import config

     workers_max = config.N_JOBS_PER_WORKER
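The metrics payload is now nested one level deeper: each top-level key names a category whose value is a mapping of integer counters. A hypothetical sketch of such a shape (the concrete key names are not visible in this hunk and are assumptions):

```python
def get_metrics() -> dict[str, dict[str, int]]:
    # Hypothetical shape only: counters grouped under a category name to match
    # the new nested return annotation; the real key names are not in this diff.
    workers_max = 10
    workers_active = 3
    return {
        "workers": {
            "max": workers_max,
            "active": workers_active,
            "available": workers_max - workers_active,
        }
    }


print(get_metrics()["workers"]["available"])  # 7
```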