langgraph-runtime-inmem 0.9.0__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,7 @@ from langgraph_runtime_inmem import (
9
9
  store,
10
10
  )
11
11
 
12
- __version__ = "0.9.0"
12
+ __version__ = "0.20.1"
13
13
  __all__ = [
14
14
  "ops",
15
15
  "database",
@@ -142,7 +142,9 @@ class InMemConnectionProto:
142
142
 
143
143
 
144
144
  @asynccontextmanager
145
- async def connect(*, __test__: bool = False) -> AsyncIterator["AsyncConnectionProto"]:
145
+ async def connect(
146
+ *, supports_core_api: bool = False, __test__: bool = False
147
+ ) -> AsyncIterator["AsyncConnectionProto"]:
146
148
  yield InMemConnectionProto()
147
149
 
148
150
 
@@ -58,6 +58,7 @@ class StreamManager:
58
58
  ) # Dict[str, List[asyncio.Queue]]
59
59
  self.control_keys = defaultdict(lambda: defaultdict())
60
60
  self.control_queues = defaultdict(lambda: defaultdict(list))
61
+ self.thread_streams = defaultdict(list)
61
62
 
62
63
  self.message_stores = defaultdict(
63
64
  lambda: defaultdict(list[Message])
@@ -95,7 +96,7 @@ class StreamManager:
95
96
 
96
97
  async def put(
97
98
  self,
98
- run_id: UUID | str,
99
+ run_id: UUID | str | None,
99
100
  thread_id: UUID | str | None,
100
101
  message: Message,
101
102
  resumable: bool = False,
@@ -107,9 +108,10 @@ class StreamManager:
107
108
  thread_id = _ensure_uuid(thread_id)
108
109
 
109
110
  message.id = _generate_ms_seq_id().encode()
111
+ # For resumable run streams, embed the generated message ID into the frame
112
+ topic = message.topic.decode()
110
113
  if resumable:
111
114
  self.message_stores[thread_id][run_id].append(message)
112
- topic = message.topic.decode()
113
115
  if "control" in topic:
114
116
  self.control_keys[thread_id][run_id] = message
115
117
  queues = self.control_queues[thread_id][run_id]
@@ -121,6 +123,20 @@ class StreamManager:
121
123
  if isinstance(result, Exception):
122
124
  logger.exception(f"Failed to put message in queue: {result}")
123
125
 
126
+ async def put_thread(
127
+ self,
128
+ thread_id: UUID | str,
129
+ message: Message,
130
+ ) -> None:
131
+ thread_id = _ensure_uuid(thread_id)
132
+ message.id = _generate_ms_seq_id().encode()
133
+ queues = self.thread_streams[thread_id]
134
+ coros = [queue.put(message) for queue in queues]
135
+ results = await asyncio.gather(*coros, return_exceptions=True)
136
+ for result in results:
137
+ if isinstance(result, Exception):
138
+ logger.exception(f"Failed to put message in queue: {result}")
139
+
124
140
  async def add_queue(
125
141
  self, run_id: UUID | str, thread_id: UUID | str | None
126
142
  ) -> asyncio.Queue:
@@ -145,6 +161,12 @@ class StreamManager:
145
161
  self.control_queues[thread_id][run_id].append(queue)
146
162
  return queue
147
163
 
164
+ async def add_thread_stream(self, thread_id: UUID | str) -> asyncio.Queue:
165
+ thread_id = _ensure_uuid(thread_id)
166
+ queue = ContextQueue()
167
+ self.thread_streams[thread_id].append(queue)
168
+ return queue
169
+
148
170
  async def remove_queue(
149
171
  self, run_id: UUID | str, thread_id: UUID | str | None, queue: asyncio.Queue
150
172
  ):
@@ -14,9 +14,17 @@ from langgraph_runtime_inmem.database import start_pool, stop_pool
14
14
  logger = structlog.stdlib.get_logger(__name__)
15
15
 
16
16
 
17
+ _LAST_LIFESPAN_ERROR: BaseException | None = None
18
+
19
+
20
+ def get_last_error() -> BaseException | None:
21
+ return _LAST_LIFESPAN_ERROR
22
+
23
+
17
24
  @asynccontextmanager
18
25
  async def lifespan(
19
26
  app: Starlette | None = None,
27
+ cancel_event: asyncio.Event | None = None,
20
28
  taskset: set[asyncio.Task] | None = None,
21
29
  **kwargs: Any,
22
30
  ):
@@ -41,13 +49,31 @@ async def lifespan(
41
49
  except RuntimeError:
42
50
  await logger.aerror("Failed to set loop")
43
51
 
52
+ global _LAST_LIFESPAN_ERROR
53
+ _LAST_LIFESPAN_ERROR = None
54
+
44
55
  await start_http_client()
45
56
  await start_pool()
46
57
  await start_ui_bundler()
58
+
59
+ async def _log_graph_load_failure(err: graph.GraphLoadError) -> None:
60
+ cause = err.__cause__ or err.cause
61
+ log_fields = err.log_fields()
62
+ log_fields["action"] = "fix_user_graph"
63
+ await logger.aerror(
64
+ f"Graph '{err.spec.id}' failed to load: {err.cause_message}",
65
+ **log_fields,
66
+ )
67
+ await logger.adebug(
68
+ "Full graph load failure traceback (internal)",
69
+ **{k: v for k, v in log_fields.items() if k != "user_traceback"},
70
+ exc_info=cause,
71
+ )
72
+
47
73
  try:
48
74
  async with SimpleTaskGroup(
49
75
  cancel=True,
50
- taskset=taskset,
76
+ cancel_event=cancel_event,
51
77
  taskgroup_name="Lifespan",
52
78
  ) as tg:
53
79
  tg.create_task(metadata_loop())
@@ -76,11 +102,21 @@ async def lifespan(
76
102
  var_child_runnable_config.set(langgraph_config)
77
103
 
78
104
  # Keep after the setter above so users can access the store from within the factory function
79
- await graph.collect_graphs_from_env(True)
105
+ try:
106
+ await graph.collect_graphs_from_env(True)
107
+ except graph.GraphLoadError as exc:
108
+ _LAST_LIFESPAN_ERROR = exc
109
+ await _log_graph_load_failure(exc)
110
+ raise
80
111
  if config.N_JOBS_PER_WORKER > 0:
81
112
  tg.create_task(queue_with_signal())
82
113
 
83
114
  yield
115
+ except graph.GraphLoadError as exc:
116
+ _LAST_LIFESPAN_ERROR = exc
117
+ raise
118
+ except asyncio.CancelledError:
119
+ pass
84
120
  finally:
85
121
  await api_store.exit_store()
86
122
  await stop_ui_bundler()
@@ -97,3 +133,6 @@ async def queue_with_signal():
97
133
  except Exception as exc:
98
134
  logger.exception("Queue failed. Signaling shutdown", exc_info=exc)
99
135
  signal.raise_signal(signal.SIGINT)
136
+
137
+
138
+ lifespan.get_last_error = get_last_error # type: ignore[attr-defined]
@@ -1,7 +1,7 @@
1
1
  from langgraph_runtime_inmem.queue import get_num_workers
2
2
 
3
3
 
4
- def get_metrics() -> dict[str, int]:
4
+ def get_metrics() -> dict[str, dict[str, int]]:
5
5
  from langgraph_api import config
6
6
 
7
7
  workers_max = config.N_JOBS_PER_WORKER