langgraph-api 0.5.4__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. langgraph_api/__init__.py +1 -1
  2. langgraph_api/api/__init__.py +93 -27
  3. langgraph_api/api/a2a.py +36 -32
  4. langgraph_api/api/assistants.py +114 -26
  5. langgraph_api/api/mcp.py +3 -3
  6. langgraph_api/api/meta.py +15 -2
  7. langgraph_api/api/openapi.py +27 -17
  8. langgraph_api/api/profile.py +108 -0
  9. langgraph_api/api/runs.py +114 -57
  10. langgraph_api/api/store.py +19 -2
  11. langgraph_api/api/threads.py +133 -10
  12. langgraph_api/asgi_transport.py +14 -9
  13. langgraph_api/auth/custom.py +23 -13
  14. langgraph_api/cli.py +86 -41
  15. langgraph_api/command.py +2 -2
  16. langgraph_api/config/__init__.py +532 -0
  17. langgraph_api/config/_parse.py +58 -0
  18. langgraph_api/config/schemas.py +431 -0
  19. langgraph_api/cron_scheduler.py +17 -1
  20. langgraph_api/encryption/__init__.py +15 -0
  21. langgraph_api/encryption/aes_json.py +158 -0
  22. langgraph_api/encryption/context.py +35 -0
  23. langgraph_api/encryption/custom.py +280 -0
  24. langgraph_api/encryption/middleware.py +632 -0
  25. langgraph_api/encryption/shared.py +63 -0
  26. langgraph_api/errors.py +12 -1
  27. langgraph_api/executor_entrypoint.py +11 -6
  28. langgraph_api/feature_flags.py +19 -0
  29. langgraph_api/graph.py +163 -64
  30. langgraph_api/{grpc_ops → grpc}/client.py +142 -12
  31. langgraph_api/{grpc_ops → grpc}/config_conversion.py +16 -10
  32. langgraph_api/grpc/generated/__init__.py +29 -0
  33. langgraph_api/grpc/generated/checkpointer_pb2.py +63 -0
  34. langgraph_api/grpc/generated/checkpointer_pb2.pyi +99 -0
  35. langgraph_api/grpc/generated/checkpointer_pb2_grpc.py +329 -0
  36. langgraph_api/grpc/generated/core_api_pb2.py +216 -0
  37. langgraph_api/{grpc_ops → grpc}/generated/core_api_pb2.pyi +292 -372
  38. langgraph_api/{grpc_ops → grpc}/generated/core_api_pb2_grpc.py +252 -31
  39. langgraph_api/grpc/generated/engine_common_pb2.py +219 -0
  40. langgraph_api/{grpc_ops → grpc}/generated/engine_common_pb2.pyi +178 -104
  41. langgraph_api/grpc/generated/enum_cancel_run_action_pb2.py +37 -0
  42. langgraph_api/grpc/generated/enum_cancel_run_action_pb2.pyi +12 -0
  43. langgraph_api/grpc/generated/enum_cancel_run_action_pb2_grpc.py +24 -0
  44. langgraph_api/grpc/generated/enum_control_signal_pb2.py +37 -0
  45. langgraph_api/grpc/generated/enum_control_signal_pb2.pyi +16 -0
  46. langgraph_api/grpc/generated/enum_control_signal_pb2_grpc.py +24 -0
  47. langgraph_api/grpc/generated/enum_durability_pb2.py +37 -0
  48. langgraph_api/grpc/generated/enum_durability_pb2.pyi +16 -0
  49. langgraph_api/grpc/generated/enum_durability_pb2_grpc.py +24 -0
  50. langgraph_api/grpc/generated/enum_multitask_strategy_pb2.py +37 -0
  51. langgraph_api/grpc/generated/enum_multitask_strategy_pb2.pyi +16 -0
  52. langgraph_api/grpc/generated/enum_multitask_strategy_pb2_grpc.py +24 -0
  53. langgraph_api/grpc/generated/enum_run_status_pb2.py +37 -0
  54. langgraph_api/grpc/generated/enum_run_status_pb2.pyi +22 -0
  55. langgraph_api/grpc/generated/enum_run_status_pb2_grpc.py +24 -0
  56. langgraph_api/grpc/generated/enum_stream_mode_pb2.py +37 -0
  57. langgraph_api/grpc/generated/enum_stream_mode_pb2.pyi +28 -0
  58. langgraph_api/grpc/generated/enum_stream_mode_pb2_grpc.py +24 -0
  59. langgraph_api/grpc/generated/enum_thread_status_pb2.py +37 -0
  60. langgraph_api/grpc/generated/enum_thread_status_pb2.pyi +16 -0
  61. langgraph_api/grpc/generated/enum_thread_status_pb2_grpc.py +24 -0
  62. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.py +37 -0
  63. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.pyi +16 -0
  64. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2_grpc.py +24 -0
  65. langgraph_api/grpc/generated/errors_pb2.py +39 -0
  66. langgraph_api/grpc/generated/errors_pb2.pyi +21 -0
  67. langgraph_api/grpc/generated/errors_pb2_grpc.py +24 -0
  68. langgraph_api/grpc/ops/__init__.py +370 -0
  69. langgraph_api/grpc/ops/assistants.py +424 -0
  70. langgraph_api/grpc/ops/runs.py +792 -0
  71. langgraph_api/grpc/ops/threads.py +1013 -0
  72. langgraph_api/http.py +16 -5
  73. langgraph_api/js/client.mts +1 -4
  74. langgraph_api/js/package.json +28 -27
  75. langgraph_api/js/remote.py +39 -17
  76. langgraph_api/js/sse.py +2 -2
  77. langgraph_api/js/ui.py +1 -1
  78. langgraph_api/js/yarn.lock +1139 -869
  79. langgraph_api/metadata.py +29 -3
  80. langgraph_api/middleware/http_logger.py +1 -1
  81. langgraph_api/middleware/private_network.py +7 -7
  82. langgraph_api/models/run.py +44 -26
  83. langgraph_api/otel_context.py +205 -0
  84. langgraph_api/patch.py +2 -2
  85. langgraph_api/queue_entrypoint.py +34 -35
  86. langgraph_api/route.py +33 -1
  87. langgraph_api/schema.py +84 -9
  88. langgraph_api/self_hosted_logs.py +2 -2
  89. langgraph_api/self_hosted_metrics.py +73 -3
  90. langgraph_api/serde.py +16 -4
  91. langgraph_api/server.py +33 -31
  92. langgraph_api/state.py +3 -2
  93. langgraph_api/store.py +25 -16
  94. langgraph_api/stream.py +20 -16
  95. langgraph_api/thread_ttl.py +28 -13
  96. langgraph_api/timing/__init__.py +25 -0
  97. langgraph_api/timing/profiler.py +200 -0
  98. langgraph_api/timing/timer.py +318 -0
  99. langgraph_api/utils/__init__.py +53 -8
  100. langgraph_api/utils/config.py +2 -1
  101. langgraph_api/utils/future.py +10 -6
  102. langgraph_api/utils/uuids.py +29 -62
  103. langgraph_api/validation.py +6 -0
  104. langgraph_api/webhook.py +120 -6
  105. langgraph_api/worker.py +54 -24
  106. {langgraph_api-0.5.4.dist-info → langgraph_api-0.7.3.dist-info}/METADATA +8 -6
  107. langgraph_api-0.7.3.dist-info/RECORD +168 -0
  108. {langgraph_api-0.5.4.dist-info → langgraph_api-0.7.3.dist-info}/WHEEL +1 -1
  109. langgraph_runtime/__init__.py +1 -0
  110. langgraph_runtime/routes.py +11 -0
  111. logging.json +1 -3
  112. openapi.json +635 -537
  113. langgraph_api/config.py +0 -523
  114. langgraph_api/grpc_ops/generated/__init__.py +0 -5
  115. langgraph_api/grpc_ops/generated/core_api_pb2.py +0 -275
  116. langgraph_api/grpc_ops/generated/engine_common_pb2.py +0 -194
  117. langgraph_api/grpc_ops/ops.py +0 -1045
  118. langgraph_api-0.5.4.dist-info/RECORD +0 -121
  119. /langgraph_api/{grpc_ops → grpc}/__init__.py +0 -0
  120. /langgraph_api/{grpc_ops → grpc}/generated/engine_common_pb2_grpc.py +0 -0
  121. {langgraph_api-0.5.4.dist-info → langgraph_api-0.7.3.dist-info}/entry_points.txt +0 -0
  122. {langgraph_api-0.5.4.dist-info → langgraph_api-0.7.3.dist-info}/licenses/LICENSE +0 -0
langgraph_api/errors.py CHANGED
@@ -17,8 +17,19 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
     headers = getattr(exc, "headers", None)
     if not is_body_allowed_for_status_code(exc.status_code):
         return Response(status_code=exc.status_code, headers=headers)
+
+    detail = exc.detail
+    if not detail or not isinstance(detail, str):
+        logger.warning(
+            "HTTPException detail is not a string or was not set",
+            detail_type=type(detail).__name__,
+            status_code=exc.status_code,
+        )
+        # Use safe fallback that won't fail or leak sensitive info
+        detail = "unknown error"
+
     return JSONResponse(
-        {"detail": exc.detail}, status_code=exc.status_code, headers=headers
+        {"detail": detail}, status_code=exc.status_code, headers=headers
     )
 
 
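The fallback above masks non-string `detail` values rather than serializing them verbatim. A minimal standalone sketch of that behavior (not the packaged handler, which also emits the structured `logger.warning` shown in the diff):

```python
from starlette.exceptions import HTTPException
from starlette.responses import JSONResponse


def render_detail(exc: HTTPException) -> JSONResponse:
    # Mirror the handler's fallback: anything that isn't a non-empty string
    # is replaced before serialization, so arbitrary objects never leak
    # into the JSON error body.
    detail = exc.detail
    if not detail or not isinstance(detail, str):
        detail = "unknown error"
    return JSONResponse({"detail": detail}, status_code=exc.status_code)


resp = render_detail(HTTPException(status_code=500, detail={"oops": True}))
print(resp.body)  # b'{"detail":"unknown error"}'
```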
langgraph_api/executor_entrypoint.py CHANGED
@@ -4,13 +4,10 @@ import json
 import logging.config
 import pathlib
 
-from langgraph_api.queue_entrypoint import main
+from langgraph_api.queue_entrypoint import main as queue_main
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
 
-    parser.add_argument("--grpc-port", type=int, default=50051)
-    args = parser.parse_args()
+async def main(grpc_port: int = 50051):
     with open(pathlib.Path(__file__).parent.parent / "logging.json") as file:
         loaded_config = json.load(file)
     logging.config.dictConfig(loaded_config)
@@ -23,4 +20,12 @@ if __name__ == "__main__":
     from langgraph_api import config
 
     config.IS_EXECUTOR_ENTRYPOINT = True
-    asyncio.run(main(grpc_port=args.grpc_port, entrypoint_name="python-executor"))
+    await queue_main(grpc_port=grpc_port, entrypoint_name="python-executor")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--grpc-port", type=int, default=50051)
+    args = parser.parse_args()
+    asyncio.run(main(grpc_port=args.grpc_port))
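With the entrypoint body hoisted into an `async def main()` and argument parsing left under `if __name__ == "__main__":`, the executor can now be awaited from an existing event loop as well as launched from the CLI. A sketch of programmatic use, assuming the module is importable as `langgraph_api.executor_entrypoint`:

```python
import asyncio

from langgraph_api.executor_entrypoint import main


async def serve() -> None:
    # main() configures logging from logging.json, sets
    # config.IS_EXECUTOR_ENTRYPOINT, and delegates to the queue
    # entrypoint with entrypoint_name="python-executor".
    await main(grpc_port=50051)


if __name__ == "__main__":
    asyncio.run(serve())
```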
langgraph_api/feature_flags.py CHANGED
@@ -16,3 +16,22 @@ FF_USE_CORE_API = os.getenv("FF_USE_CORE_API", "false").lower() in (
     "1",
     "yes",
 )
+
+# Runtime edition detection
+_RUNTIME_EDITION = os.getenv("LANGGRAPH_RUNTIME_EDITION", "inmem")
+IS_POSTGRES_BACKEND = _RUNTIME_EDITION == "postgres"
+IS_POSTGRES_OR_GRPC_BACKEND = IS_POSTGRES_BACKEND or FF_USE_CORE_API
+# Feature flag for using the JS native API
+FF_USE_JS_API = os.getenv("FF_USE_JS_API", "false").lower() in (
+    "true",
+    "1",
+    "yes",
+)
+
+# In langgraph <= 1.0.3, we automatically subscribed to updates stream events to surface interrupts. In langgraph 1.0.4 we include interrupts in values events (which we are automatically subscribed to), so we no longer need to implicitly subscribe to updates stream events
+# If the version is not valid, e.g. rc/alpha/etc., we default to 0.0.0
+try:
+    LANGGRAPH_PY_PATCH = tuple(map(int, __version__.split(".")[:3]))
+except ValueError:
+    LANGGRAPH_PY_PATCH = (0, 0, 0)
+UPDATES_NEEDED_FOR_INTERRUPTS = LANGGRAPH_PY_PATCH <= (1, 0, 3)
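The `UPDATES_NEEDED_FOR_INTERRUPTS` gate hinges on the patch-tuple parse above; pre-release strings deliberately fall back to `(0, 0, 0)`, which keeps the implicit `updates` subscription on (the fail-safe side). The same logic, standalone:

```python
def parse_patch(version: str) -> tuple[int, ...]:
    # Same parse as LANGGRAPH_PY_PATCH: first three dotted components,
    # with (0, 0, 0) as the fallback for rc/alpha/dev strings.
    try:
        return tuple(map(int, version.split(".")[:3]))
    except ValueError:
        return (0, 0, 0)


assert parse_patch("1.0.3") <= (1, 0, 3)        # still subscribes to `updates`
assert not (parse_patch("1.0.4") <= (1, 0, 3))  # interrupts arrive in `values`
assert parse_patch("1.1.0rc1") == (0, 0, 0)     # invalid patch -> fail-safe default
```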
langgraph_api/graph.py CHANGED
@@ -3,17 +3,20 @@ import functools
 import glob
 import importlib.util
 import inspect
+import logging
 import os
 import sys
+import time
 import warnings
 from collections.abc import AsyncIterator, Callable
 from contextlib import asynccontextmanager
 from itertools import filterfalse
-from typing import TYPE_CHECKING, Any, NamedTuple, TypeGuard, cast
+from typing import Any, NamedTuple, TypeGuard, cast
 from uuid import UUID, uuid5
 
 import orjson
 import structlog
+from langchain_core.embeddings import Embeddings  # noqa: TC002
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.constants import CONFIG_KEY_CHECKPOINTER
 from langgraph.graph import StateGraph
@@ -22,15 +25,17 @@ from langgraph.store.base import BaseStore
 from starlette.exceptions import HTTPException
 
 from langgraph_api import config as lg_api_config
-from langgraph_api.feature_flags import FF_USE_CORE_API, USE_RUNTIME_CONTEXT_API
+from langgraph_api import timing
+from langgraph_api.feature_flags import (
+    IS_POSTGRES_OR_GRPC_BACKEND,
+    USE_RUNTIME_CONTEXT_API,
+)
 from langgraph_api.js.base import BaseRemotePregel, is_js_path
 from langgraph_api.schema import Config
+from langgraph_api.timing import profiled_import
 from langgraph_api.utils.config import run_in_executor, var_child_runnable_config
 from langgraph_api.utils.errors import GraphLoadError
 
-if TYPE_CHECKING:
-    from langchain_core.embeddings import Embeddings
-
 logger = structlog.stdlib.get_logger(__name__)
 
 GraphFactoryFromConfig = Callable[[Config], Pregel | StateGraph]
@@ -51,13 +56,13 @@ async def register_graph(
     description: str | None = None,
 ) -> None:
     """Register a graph."""
-    from langgraph_api.grpc_ops.ops import Assistants as AssistantsGrpc
     from langgraph_runtime.database import connect
-    from langgraph_runtime.ops import Assistants as AssistantsRuntime
 
-    Assistants = AssistantsGrpc if FF_USE_CORE_API else AssistantsRuntime
+    if IS_POSTGRES_OR_GRPC_BACKEND:
+        from langgraph_api.grpc.ops import Assistants
+    else:
+        from langgraph_runtime.ops import Assistants
 
-    await logger.ainfo(f"Registering graph with id '{graph_id}'", graph_id=graph_id)
     GRAPHS[graph_id] = graph
     if callable(graph):
         FACTORY_ACCEPTS_CONFIG[graph_id] = len(inspect.signature(graph).parameters) > 0
@@ -91,19 +96,70 @@ async def register_graph(
     await register_graph_db()
 
 
+def _validate_assistant_id(assistant_id: str) -> None:
+    """Validate an assistant ID is either a graph_id or a valid UUID. Throw an error if not valid."""
+    if assistant_id and assistant_id not in GRAPHS:
+        # Not a graph_id, must be a valid UUID
+        try:
+            UUID(assistant_id)
+        except ValueError:
+            # Invalid format - return 404 to match test expectations
+            raise HTTPException(
+                status_code=404,
+                detail=f"Assistant '{assistant_id}' not found",
+            ) from None
+
+
+def _log_slow_graph_generation(
+    start: float,
+    value_type: str,
+    graph_id: str,
+    warn_threshold_ms: float = 100,
+    error_threshold_ms: float = 250,
+) -> None:
+    """Log warning/error if graph generation was slow."""
+    elapsed_secs = time.perf_counter() - start
+    elapsed_ms = elapsed_secs * 1000
+    elapsed_ms_rounded = round(elapsed_ms, 2)
+    log_level = None
+    if elapsed_ms > error_threshold_ms:
+        log_level = logging.ERROR
+    elif elapsed_ms > warn_threshold_ms:
+        log_level = logging.WARNING
+    if log_level is not None:
+        logger.log(
+            log_level,
+            f"Slow graph load. Accessing graph '{graph_id}' took {elapsed_ms_rounded}ms."
+            " Move expensive initialization (API clients, DB connections, model loading)"
+            " from graph factory if you are seeing API slowness.",
+            elapsed_ms=elapsed_ms_rounded,
+            value_type=value_type,
+            graph_id=graph_id,
+        )
+
+
 @asynccontextmanager
-async def _generate_graph(value: Any) -> AsyncIterator[Any]:
-    """Yield a graph object regardless of its type."""
+async def _generate_graph(value: Any, graph_id: str) -> AsyncIterator[Any]:
+    """Yield a graph object regardless of its type.
+
+    Logs a warning if graph generation takes >100ms, error if >250ms.
+    """
+    start = time.perf_counter()
+    value_type = type(value).__name__
     if isinstance(value, Pregel | BaseRemotePregel):
         yield value
     elif hasattr(value, "__aenter__") and hasattr(value, "__aexit__"):
         async with value as ctx_value:
+            _log_slow_graph_generation(start, value_type, graph_id)
             yield ctx_value
     elif hasattr(value, "__enter__") and hasattr(value, "__exit__"):
         with value as ctx_value:
+            _log_slow_graph_generation(start, value_type, graph_id)
            yield ctx_value
     elif asyncio.iscoroutine(value):
-        yield await value
+        result = await value
+        _log_slow_graph_generation(start, value_type, graph_id)
+        yield result
     else:
         yield value
 
@@ -132,14 +188,18 @@ async def get_graph(
     *,
     checkpointer: BaseCheckpointSaver | None = None,
     store: BaseStore | None = None,
+    is_for_execution: bool = True,
 ) -> AsyncIterator[Pregel]:
     """Return the runnable."""
     from langgraph_api.utils import config as lg_config
+    from langgraph_api.utils import merge_auth
 
     assert_graph_exists(graph_id)
     value = GRAPHS[graph_id]
     if is_factory(value, graph_id):
         config = lg_config.ensure_config(config)
+        config["configurable"]["__is_for_execution__"] = is_for_execution
+        config = merge_auth(config)
 
         if store is not None:
             if USE_RUNTIME_CONTEXT_API:
@@ -152,7 +212,7 @@
             elif isinstance(runtime, dict):
                 patched_runtime = Runtime(**(runtime | {"store": store}))
             elif runtime.store is None:
-                patched_runtime = cast(Runtime, runtime).override(store=store)
+                patched_runtime = cast("Runtime", runtime).override(store=store)
             else:
                 patched_runtime = runtime
 
@@ -170,7 +230,7 @@
         var_child_runnable_config.set(config)
         value = value(config) if factory_accepts_config(value, graph_id) else value()
     try:
-        async with _generate_graph(value) as graph_obj:
+        async with _generate_graph(value, graph_id) as graph_obj:
            if isinstance(graph_obj, StateGraph):
                graph_obj = graph_obj.compile()
            if not isinstance(graph_obj, Pregel | BaseRemotePregel):
@@ -232,9 +292,9 @@ class GraphSpec(NamedTuple):
     variable: str | None = None
     config: dict | None = None
     """The configuration for the graph.
-
+
     Contains information such as: tags, recursion_limit and configurable.
-
+
     Configurable is a dict containing user defined values for the graph.
     """
     description: str | None = None
@@ -429,7 +489,7 @@ async def collect_graphs_from_env(register: bool = False) -> None:
     def _handle_exception(task: asyncio.Task) -> None:
         try:
             task.result()
-        except asyncio.CancelledError:
+        except (asyncio.CancelledError, SystemExit):
             pass
         except Exception as e:
             logger.exception("Task failed", exc_info=e)
@@ -448,42 +508,59 @@ def verify_graphs() -> None:
     asyncio.run(collect_graphs_from_env())
 
 
+def _metadata_fn(spec: GraphSpec) -> dict[str, Any]:
+    return {"graph_id": spec.id, "module": spec.module, "path": spec.path}
+
+
+@timing.timer(
+    message="Importing graph with id {graph_id}",
+    metadata_fn=_metadata_fn,
+    warn_threshold_secs=3,
+    warn_message=(
+        "Import for graph {graph_id} exceeded the expected startup time. "
+        "Slow initialization (often due to work executed at import time) can delay readiness, "
+        "reduce scale-out capacity, and may cause deployments to be marked unhealthy."
+    ),
+    error_threshold_secs=30,
+)
 def _graph_from_spec(spec: GraphSpec) -> GraphValue:
     """Return a graph from a spec."""
     # import the graph module
-    if spec.module:
-        module = importlib.import_module(spec.module)
-    elif spec.path:
-        try:
-            modname = (
-                spec.path.replace("/", "__")
-                .replace(".py", "")
-                .replace(" ", "_")
-                .lstrip(".")
-            )
-            modspec = importlib.util.spec_from_file_location(modname, spec.path)
-            if modspec is None:
-                raise ValueError(f"Could not find python file for graph: {spec}")
-            module = importlib.util.module_from_spec(modspec)
-            sys.modules[modname] = module
-            modspec.loader.exec_module(module)  # type: ignore[possibly-unbound-attribute]
-        except ImportError as e:
-            e.add_note(f"Could not import python module for graph:\n{spec}")
-            if lg_api_config.API_VARIANT == "local_dev":
-                e.add_note(
-                    "This error likely means you haven't installed your project and its dependencies yet. Before running the server, install your project:\n\n"
-                    "If you are using requirements.txt:\n"
-                    "python -m pip install -r requirements.txt\n\n"
-                    "If you are using pyproject.toml or setuptools:\n"
-                    "python -m pip install -e .\n\n"
-                    "Make sure to run this command from your project's root directory (where your setup.py or pyproject.toml is located)"
-                )
-            raise
-        except FileNotFoundError as e:
-            e.add_note(f"Could not find python file for graph: {spec}")
-            raise
-    else:
-        raise ValueError("Graph specification must have a path or module")
+    import_path = f"{spec.module or spec.path}:{spec.variable or '<auto>'}"
+    with profiled_import(import_path):
+        if spec.module:
+            module = importlib.import_module(spec.module)
+        elif spec.path:
+            try:
+                modname = (
+                    spec.path.replace("/", "__")
+                    .replace(".py", "")
+                    .replace(" ", "_")
+                    .lstrip(".")
+                )
+                modspec = importlib.util.spec_from_file_location(modname, spec.path)
+                if modspec is None:
+                    raise ValueError(f"Could not find python file for graph: {spec}")
+                module = importlib.util.module_from_spec(modspec)
+                sys.modules[modname] = module
+                modspec.loader.exec_module(module)  # type: ignore[possibly-unbound-attribute]
+            except ImportError as e:
+                e.add_note(f"Could not import python module for graph:\n{spec}")
+                if lg_api_config.API_VARIANT == "local_dev":
+                    e.add_note(
+                        "This error likely means you haven't installed your project and its dependencies yet. Before running the server, install your project:\n\n"
+                        "If you are using requirements.txt:\n"
+                        "python -m pip install -r requirements.txt\n\n"
+                        "If you are using pyproject.toml or setuptools:\n"
+                        "python -m pip install -e .\n\n"
+                        "Make sure to run this command from your project's root directory (where your setup.py or pyproject.toml is located)"
+                    )
+                raise
+            except FileNotFoundError as e:
+                e.add_note(f"Could not find python file for graph: {spec}")
+                raise
+        else:
+            raise ValueError("Graph specification must have a path or module")
 
     if spec.variable:
         try:
@@ -589,6 +666,13 @@ def _get_init_embeddings() -> Callable[[str, ...], "Embeddings"] | None:
         return None
 
 
+@timing.timer(
+    message="Loading embeddings {embeddings_path}",
+    metadata_fn=lambda index_config: {"embeddings_path": index_config.get("embed")},
+    warn_threshold_secs=5,
+    warn_message="Loading embeddings '{embeddings_path}' took longer than expected",
+    error_threshold_secs=10,
+)
 def resolve_embeddings(index_config: dict) -> "Embeddings":
     """Return embeddings from config.
 
@@ -607,26 +691,41 @@ def resolve_embeddings(index_config: dict) -> "Embeddings":
     from langchain_core.embeddings import Embeddings
     from langgraph.store.base import ensure_embeddings
 
-    embed: str = index_config["embed"]
+    embed = index_config["embed"]
+    if isinstance(embed, Embeddings):
+        return embed
+    if callable(embed):
+        return ensure_embeddings(embed)
+    if not isinstance(embed, str):
+        raise ValueError(
+            f"Embeddings config must be a string or callable, got: {type(embed).__name__}"
+        )
     if ".py:" in embed:
         module_name, function = embed.rsplit(":", 1)
         module_name = module_name.rstrip(":")
 
         try:
-            if "/" in module_name:
-                # Load from file path
-                modname = (
-                    module_name.replace("/", "__").replace(".py", "").replace(" ", "_")
-                )
-                modspec = importlib.util.spec_from_file_location(modname, module_name)
-                if modspec is None:
-                    raise ValueError(f"Could not find embeddings file: {module_name}")
-                module = importlib.util.module_from_spec(modspec)
-                sys.modules[modname] = module
-                modspec.loader.exec_module(module)  # type: ignore[possibly-unbound-attribute]
-            else:
-                # Load from Python module
-                module = importlib.import_module(module_name)
+            with profiled_import(embed):
+                if "/" in module_name:
+                    # Load from file path
+                    modname = (
+                        module_name.replace("/", "__")
+                        .replace(".py", "")
+                        .replace(" ", "_")
+                    )
+                    modspec = importlib.util.spec_from_file_location(
+                        modname, module_name
+                    )
+                    if modspec is None:
+                        raise ValueError(
+                            f"Could not find embeddings file: {module_name}"
+                        )
+                    module = importlib.util.module_from_spec(modspec)
+                    sys.modules[modname] = module
+                    modspec.loader.exec_module(module)  # type: ignore[possibly-unbound-attribute]
+                else:
+                    # Load from Python module
+                    module = importlib.import_module(module_name)
 
             embedding_fn = getattr(module, function, None)
             if embedding_fn is None:
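Taken together, the timing changes put hard numbers on factory hygiene: graph imports are flagged past 3s (error at 30s) and per-request graph generation past 100ms (error at 250ms). A hypothetical graph module showing the pattern the warnings push toward:

```python
from typing import TypedDict

from langgraph.graph import START, StateGraph


class State(TypedDict):
    value: str


# Heavyweight setup (API clients, DB pools, model handles) belongs at module
# scope: it runs once, under the import timer, instead of on every
# get_graph() call inside the request path.
# expensive_client = make_api_client()  # hypothetical


def make_graph(config: dict) -> StateGraph:
    # Keep the factory cheap: wire nodes and reference pre-built objects.
    builder = StateGraph(State)
    builder.add_node("echo", lambda state: {"value": state["value"]})
    builder.add_edge(START, "echo")
    return builder  # get_graph() compiles StateGraph instances itself
```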
langgraph_api/{grpc_ops → grpc}/client.py CHANGED
@@ -1,18 +1,34 @@
 """gRPC client wrapper for LangGraph persistence services."""
 
 import asyncio
-import os
+import threading
+import time
 
 import structlog
 from grpc import aio  # type: ignore[import]
+from grpc_health.v1 import health_pb2, health_pb2_grpc  # type: ignore[import]
 
-from .generated.core_api_pb2_grpc import AdminStub, AssistantsStub, ThreadsStub
+from langgraph_api import config
+
+from .generated.checkpointer_pb2_grpc import CheckpointerStub
+from .generated.core_api_pb2_grpc import (
+    AdminStub,
+    AssistantsStub,
+    RunsStub,
+    ThreadsStub,
+)
 
 logger = structlog.stdlib.get_logger(__name__)
 
 
-# Shared global client pool
+# Shared gRPC client pools (main thread + thread-local for isolated loops).
 _client_pool: "GrpcClientPool | None" = None
+_thread_local = threading.local()
+
+
+GRPC_HEALTHCHECK_TIMEOUT = 5.0
+GRPC_INIT_TIMEOUT = 10.0
+GRPC_INIT_PROBE_INTERVAL = 0.5
 
 
 class GrpcClient:
@@ -27,13 +43,14 @@ class GrpcClient:
         Args:
             server_address: The gRPC server address (default: localhost:50051)
         """
-        self.server_address = server_address or os.getenv(
-            "GRPC_SERVER_ADDRESS", "localhost:50051"
-        )
+        self.server_address = server_address or config.GRPC_SERVER_ADDRESS
         self._channel: aio.Channel | None = None
         self._assistants_stub: AssistantsStub | None = None
+        self._runs_stub: RunsStub | None = None
         self._threads_stub: ThreadsStub | None = None
         self._admin_stub: AdminStub | None = None
+        self._checkpointer_stub: CheckpointerStub | None = None
+        self._health_stub: health_pb2_grpc.HealthStub | None = None
 
     async def __aenter__(self):
         """Async context manager entry."""
@@ -49,11 +66,19 @@
         if self._channel is not None:
             return
 
-        self._channel = aio.insecure_channel(self.server_address)
+        options = [
+            ("grpc.max_receive_message_length", config.GRPC_CLIENT_MAX_RECV_MSG_BYTES),
+            ("grpc.max_send_message_length", config.GRPC_CLIENT_MAX_SEND_MSG_BYTES),
+        ]
+
+        self._channel = aio.insecure_channel(self.server_address, options=options)
 
         self._assistants_stub = AssistantsStub(self._channel)
+        self._runs_stub = RunsStub(self._channel)
         self._threads_stub = ThreadsStub(self._channel)
         self._admin_stub = AdminStub(self._channel)
+        self._checkpointer_stub = CheckpointerStub(self._channel)
+        self._health_stub = health_pb2_grpc.HealthStub(self._channel)
 
         await logger.adebug(
             "Connected to gRPC server", server_address=self.server_address
@@ -65,10 +90,37 @@
             await self._channel.close()
             self._channel = None
             self._assistants_stub = None
+            self._runs_stub = None
             self._threads_stub = None
             self._admin_stub = None
+            self._checkpointer_stub = None
+            self._health_stub = None
             await logger.adebug("Closed gRPC connection")
 
+    async def healthcheck(self) -> bool:
+        """Check if the gRPC server is healthy.
+
+        Returns:
+            True if the server is healthy and serving.
+
+        Raises:
+            RuntimeError: If the client is not connected or the server is unhealthy.
+        """
+        if self._health_stub is None:
+            raise RuntimeError(
+                "Client not connected. Use async context manager or call connect() first."
+            )
+
+        request = health_pb2.HealthCheckRequest(service="")
+        response = await self._health_stub.Check(
+            request, timeout=GRPC_HEALTHCHECK_TIMEOUT
+        )
+
+        if response.status != health_pb2.HealthCheckResponse.SERVING:
+            raise RuntimeError(f"gRPC server is not healthy. Status: {response.status}")
+
+        return True
+
     @property
     def assistants(self) -> AssistantsStub:
         """Get the assistants service stub."""
@@ -87,6 +139,15 @@ class GrpcClient:
             )
         return self._threads_stub
 
+    @property
+    def runs(self) -> RunsStub:
+        """Get the runs service stub."""
+        if self._runs_stub is None:
+            raise RuntimeError(
+                "Client not connected. Use async context manager or call connect() first."
+            )
+        return self._runs_stub
+
     @property
     def admin(self) -> AdminStub:
         """Get the admin service stub."""
@@ -96,6 +157,15 @@ class GrpcClient:
             )
         return self._admin_stub
 
+    @property
+    def checkpointer(self) -> CheckpointerStub:
+        """Get the checkpointer service stub."""
+        if self._checkpointer_stub is None:
+            raise RuntimeError(
+                "Client not connected. Use async context manager or call connect() first."
+            )
+        return self._checkpointer_stub
+
 
 class GrpcClientPool:
     """Pool of gRPC clients for load distribution."""
@@ -158,25 +228,85 @@ async def get_shared_client() -> GrpcClient:
 
     Uses a pool of channels for better performance under high concurrency.
     Each channel is a separate TCP connection that can handle ~100-200
-    concurrent streams effectively.
+    concurrent streams effectively. Pools are scoped per thread/loop to
+    avoid cross-loop gRPC channel usage.
 
     Returns:
         A GrpcClient instance from the pool
     """
+    if threading.current_thread() is not threading.main_thread():
+        pool = getattr(_thread_local, "grpc_pool", None)
+        if pool is None:
+            pool = GrpcClientPool(
+                pool_size=1,
+                server_address=config.GRPC_SERVER_ADDRESS,
+            )
+            _thread_local.grpc_pool = pool
+        return await pool.get_client()
+
     global _client_pool
     if _client_pool is None:
-        from langgraph_api import config
-
         _client_pool = GrpcClientPool(
             pool_size=config.GRPC_CLIENT_POOL_SIZE,
-            server_address=os.getenv("GRPC_SERVER_ADDRESS"),
+            server_address=config.GRPC_SERVER_ADDRESS,
        )
-
     return await _client_pool.get_client()
 
 
+async def wait_until_grpc_ready(
+    timeout_seconds: float = GRPC_INIT_TIMEOUT,
+    interval_seconds: float = GRPC_INIT_PROBE_INTERVAL,
+):
+    """Wait for the gRPC server to be ready with retries during startup.
+
+    Args:
+        timeout_seconds: Maximum time to wait for the server to be ready.
+        interval_seconds: Time to wait between health check attempts.
+    Raises:
+        RuntimeError: If the server is not ready within the timeout period.
+    """
+    client = await get_shared_client()
+    max_attempts = int(timeout_seconds / interval_seconds)
+
+    await logger.ainfo(
+        "Waiting for gRPC server to be ready",
+        timeout_seconds=timeout_seconds,
+        interval_seconds=interval_seconds,
+        max_attempts=max_attempts,
+    )
+    start_time = time.time()
+    for attempt in range(max_attempts):
+        try:
+            await client.healthcheck()
+            await logger.ainfo(
+                "gRPC server is ready",
+                attempt=attempt + 1,
+                elapsed_seconds=round(time.time() - start_time, 3),
+            )
+            return
+        except Exception as exc:
+            if attempt >= max_attempts - 1:
+                raise RuntimeError(
+                    f"gRPC server not ready after {timeout_seconds}s (reached max attempts: {max_attempts})"
+                ) from exc
+            else:
+                await logger.adebug(
+                    "Waiting for gRPC server to be ready",
+                    attempt=attempt + 1,
+                    max_attempts=max_attempts,
+                )
+                await asyncio.sleep(interval_seconds)
+
+
 async def close_shared_client():
     """Close the shared gRPC client pool."""
+    if threading.current_thread() is not threading.main_thread():
+        pool = getattr(_thread_local, "grpc_pool", None)
+        if pool is not None:
+            await pool.close()
+            delattr(_thread_local, "grpc_pool")
+        return
+
     global _client_pool
     if _client_pool is not None:
         await _client_pool.close()
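Startup code can lean on the new health plumbing instead of racing the backend. A hypothetical wiring, assuming the post-rename import path `langgraph_api.grpc.client`:

```python
import asyncio

from langgraph_api.grpc.client import (
    close_shared_client,
    get_shared_client,
    wait_until_grpc_ready,
)


async def main() -> None:
    try:
        # Polls the standard grpc.health.v1 Check endpoint every 0.5s and
        # raises RuntimeError if the server isn't SERVING within 10s.
        await wait_until_grpc_ready(timeout_seconds=10.0, interval_seconds=0.5)
        client = await get_shared_client()
        # assistants/runs/threads/admin/checkpointer stubs are now usable.
        print("connected to", client.server_address)
    finally:
        await close_shared_client()


if __name__ == "__main__":
    asyncio.run(main())
```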