langgraph-api 0.2.130__py3-none-any.whl → 0.2.134__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. langgraph_api/__init__.py +1 -1
  2. langgraph_api/api/assistants.py +32 -6
  3. langgraph_api/api/meta.py +3 -1
  4. langgraph_api/api/openapi.py +1 -1
  5. langgraph_api/api/runs.py +50 -10
  6. langgraph_api/api/threads.py +27 -1
  7. langgraph_api/api/ui.py +2 -0
  8. langgraph_api/asgi_transport.py +2 -2
  9. langgraph_api/asyncio.py +10 -8
  10. langgraph_api/auth/custom.py +9 -4
  11. langgraph_api/auth/langsmith/client.py +1 -1
  12. langgraph_api/cli.py +5 -4
  13. langgraph_api/config.py +1 -1
  14. langgraph_api/executor_entrypoint.py +23 -0
  15. langgraph_api/graph.py +25 -9
  16. langgraph_api/http.py +10 -7
  17. langgraph_api/http_metrics.py +4 -1
  18. langgraph_api/js/build.mts +11 -2
  19. langgraph_api/js/client.http.mts +2 -0
  20. langgraph_api/js/client.mts +13 -3
  21. langgraph_api/js/package.json +2 -2
  22. langgraph_api/js/remote.py +17 -12
  23. langgraph_api/js/src/preload.mjs +9 -1
  24. langgraph_api/js/src/utils/files.mts +5 -2
  25. langgraph_api/js/sse.py +1 -1
  26. langgraph_api/js/yarn.lock +9 -9
  27. langgraph_api/logging.py +3 -3
  28. langgraph_api/middleware/http_logger.py +2 -1
  29. langgraph_api/models/run.py +19 -14
  30. langgraph_api/patch.py +2 -2
  31. langgraph_api/queue_entrypoint.py +33 -18
  32. langgraph_api/schema.py +88 -4
  33. langgraph_api/serde.py +32 -5
  34. langgraph_api/server.py +5 -3
  35. langgraph_api/state.py +8 -8
  36. langgraph_api/store.py +1 -1
  37. langgraph_api/stream.py +33 -20
  38. langgraph_api/traceblock.py +1 -1
  39. langgraph_api/utils/__init__.py +40 -5
  40. langgraph_api/utils/config.py +13 -4
  41. langgraph_api/utils/future.py +1 -1
  42. langgraph_api/utils/uuids.py +87 -0
  43. langgraph_api/validation.py +9 -0
  44. langgraph_api/webhook.py +20 -20
  45. langgraph_api/worker.py +8 -5
  46. {langgraph_api-0.2.130.dist-info → langgraph_api-0.2.134.dist-info}/METADATA +2 -2
  47. {langgraph_api-0.2.130.dist-info → langgraph_api-0.2.134.dist-info}/RECORD +51 -49
  48. openapi.json +331 -1
  49. {langgraph_api-0.2.130.dist-info → langgraph_api-0.2.134.dist-info}/WHEEL +0 -0
  50. {langgraph_api-0.2.130.dist-info → langgraph_api-0.2.134.dist-info}/entry_points.txt +0 -0
  51. {langgraph_api-0.2.130.dist-info → langgraph_api-0.2.134.dist-info}/licenses/LICENSE +0 -0
langgraph_api/__init__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.2.130"
1
+ __version__ = "0.2.134"
@@ -16,9 +16,16 @@ from langgraph_api.feature_flags import USE_RUNTIME_CONTEXT_API
16
16
  from langgraph_api.graph import get_assistant_id, get_graph
17
17
  from langgraph_api.js.base import BaseRemotePregel
18
18
  from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
19
+ from langgraph_api.schema import ASSISTANT_FIELDS
19
20
  from langgraph_api.serde import ajson_loads
20
- from langgraph_api.utils import fetchone, get_pagination_headers, validate_uuid
21
+ from langgraph_api.utils import (
22
+ fetchone,
23
+ get_pagination_headers,
24
+ validate_select_columns,
25
+ validate_uuid,
26
+ )
21
27
  from langgraph_api.validation import (
28
+ AssistantCountRequest,
22
29
  AssistantCreate,
23
30
  AssistantPatch,
24
31
  AssistantSearchRequest,
@@ -61,7 +68,8 @@ def _get_configurable_jsonschema(graph: Pregel) -> dict:
61
68
  in favor of graph.get_context_jsonschema().
62
69
  """
63
70
  # Otherwise, use the config_schema method.
64
- config_schema = graph.config_schema()
71
+ # TODO: Remove this when we no longer support langgraph < 0.6
72
+ config_schema = graph.config_schema() # type: ignore[deprecated]
65
73
  model_fields = getattr(config_schema, "model_fields", None) or getattr(
66
74
  config_schema, "__fields__", None
67
75
  )
@@ -87,11 +95,11 @@ def _state_jsonschema(graph: Pregel) -> dict | None:
87
95
  for k in graph.stream_channels_list:
88
96
  v = graph.channels[k]
89
97
  try:
90
- create_model(k, __root__=(v.UpdateType, None)).schema()
98
+ create_model(k, __root__=(v.UpdateType, None)).model_json_schema()
91
99
  fields[k] = (v.UpdateType, None)
92
100
  except Exception:
93
101
  fields[k] = (Any, None)
94
- return create_model(graph.get_name("State"), **fields).schema()
102
+ return create_model(graph.get_name("State"), **fields).model_json_schema()
95
103
 
96
104
 
97
105
  def _graph_schemas(graph: Pregel) -> dict:
@@ -132,7 +140,7 @@ def _graph_schemas(graph: Pregel) -> dict:
132
140
  logger.warning(
133
141
  f"Failed to get context schema for graph {graph.name} with error: `{str(e)}`"
134
142
  )
135
- context_schema = graph.config_schema()
143
+ context_schema = graph.config_schema() # type: ignore[deprecated]
136
144
  else:
137
145
  context_schema = None
138
146
 
@@ -172,6 +180,7 @@ async def search_assistants(
172
180
  ) -> ApiResponse:
173
181
  """List assistants."""
174
182
  payload = await request.json(AssistantSearchRequest)
183
+ select = validate_select_columns(payload.get("select") or None, ASSISTANT_FIELDS)
175
184
  offset = int(payload.get("offset") or 0)
176
185
  async with connect() as conn:
177
186
  assistants_iter, next_offset = await Assistants.search(
@@ -182,6 +191,7 @@ async def search_assistants(
182
191
  offset=offset,
183
192
  sort_by=payload.get("sort_by"),
184
193
  sort_order=payload.get("sort_order"),
194
+ select=select,
185
195
  )
186
196
  assistants, response_headers = await get_pagination_headers(
187
197
  assistants_iter, next_offset, offset
@@ -189,6 +199,21 @@ async def search_assistants(
189
199
  return ApiResponse(assistants, headers=response_headers)
190
200
 
191
201
 
202
+ @retry_db
203
+ async def count_assistants(
204
+ request: ApiRequest,
205
+ ) -> ApiResponse:
206
+ """Count assistants."""
207
+ payload = await request.json(AssistantCountRequest)
208
+ async with connect() as conn:
209
+ count = await Assistants.count(
210
+ conn,
211
+ graph_id=payload.get("graph_id"),
212
+ metadata=payload.get("metadata"),
213
+ )
214
+ return ApiResponse(count)
215
+
216
+
192
217
  @retry_db
193
218
  async def get_assistant(
194
219
  request: ApiRequest,
@@ -366,7 +391,7 @@ async def patch_assistant(
366
391
 
367
392
 
368
393
  @retry_db
369
- async def delete_assistant(request: ApiRequest) -> ApiResponse:
394
+ async def delete_assistant(request: ApiRequest) -> Response:
370
395
  """Delete an assistant by ID."""
371
396
  assistant_id = request.path_params["assistant_id"]
372
397
  validate_uuid(assistant_id, "Invalid assistant ID: must be a UUID")
@@ -414,6 +439,7 @@ async def set_latest_assistant_version(request: ApiRequest) -> ApiResponse:
414
439
  assistants_routes: list[BaseRoute] = [
415
440
  ApiRoute("/assistants", create_assistant, methods=["POST"]),
416
441
  ApiRoute("/assistants/search", search_assistants, methods=["POST"]),
442
+ ApiRoute("/assistants/count", count_assistants, methods=["POST"]),
417
443
  ApiRoute(
418
444
  "/assistants/{assistant_id}/latest",
419
445
  set_latest_assistant_version,
langgraph_api/api/meta.py CHANGED
@@ -1,3 +1,5 @@
1
+ from typing import cast
2
+
1
3
  import langgraph.version
2
4
  from starlette.responses import JSONResponse, PlainTextResponse
3
5
 
@@ -43,7 +45,7 @@ async def meta_metrics(request: ApiRequest):
43
45
 
44
46
  # collect stats
45
47
  metrics = get_metrics()
46
- worker_metrics = metrics["workers"]
48
+ worker_metrics = cast(dict[str, int], metrics["workers"])
47
49
  workers_max = worker_metrics["max"]
48
50
  workers_active = worker_metrics["active"]
49
51
  workers_available = worker_metrics["available"]
@@ -25,7 +25,7 @@ def set_custom_spec(spec: dict):
25
25
 
26
26
 
27
27
  @lru_cache(maxsize=1)
28
- def get_openapi_spec() -> str:
28
+ def get_openapi_spec() -> bytes:
29
29
  # patch the graph_id enums
30
30
  graph_ids = list(GRAPHS.keys())
31
31
  for schema in (
langgraph_api/api/runs.py CHANGED
@@ -1,9 +1,8 @@
1
1
  import asyncio
2
2
  from collections.abc import AsyncIterator
3
- from typing import Literal
3
+ from typing import Literal, cast
4
4
 
5
5
  import orjson
6
- from langgraph.checkpoint.base.id import uuid6
7
6
  from starlette.exceptions import HTTPException
8
7
  from starlette.responses import Response, StreamingResponse
9
8
 
@@ -11,9 +10,17 @@ from langgraph_api import config
11
10
  from langgraph_api.asyncio import ValueEvent, aclosing
12
11
  from langgraph_api.models.run import create_valid_run
13
12
  from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
13
+ from langgraph_api.schema import CRON_FIELDS, RUN_FIELDS
14
14
  from langgraph_api.sse import EventSourceResponse
15
- from langgraph_api.utils import fetchone, get_pagination_headers, validate_uuid
15
+ from langgraph_api.utils import (
16
+ fetchone,
17
+ get_pagination_headers,
18
+ uuid7,
19
+ validate_select_columns,
20
+ validate_uuid,
21
+ )
16
22
  from langgraph_api.validation import (
23
+ CronCountRequest,
17
24
  CronCreate,
18
25
  CronSearch,
19
26
  RunBatchCreate,
@@ -92,7 +99,7 @@ async def stream_run(
92
99
  thread_id = request.path_params["thread_id"]
93
100
  payload = await request.json(RunCreateStateful)
94
101
  on_disconnect = payload.get("on_disconnect", "continue")
95
- run_id = uuid6()
102
+ run_id = uuid7()
96
103
  sub = asyncio.create_task(Runs.Stream.subscribe(run_id))
97
104
 
98
105
  try:
@@ -132,7 +139,7 @@ async def stream_run_stateless(
132
139
  """Create a stateless run."""
133
140
  payload = await request.json(RunCreateStateless)
134
141
  on_disconnect = payload.get("on_disconnect", "continue")
135
- run_id = uuid6()
142
+ run_id = uuid7()
136
143
  sub = asyncio.create_task(Runs.Stream.subscribe(run_id))
137
144
 
138
145
  try:
@@ -173,7 +180,7 @@ async def wait_run(request: ApiRequest):
173
180
  thread_id = request.path_params["thread_id"]
174
181
  payload = await request.json(RunCreateStateful)
175
182
  on_disconnect = payload.get("on_disconnect", "continue")
176
- run_id = uuid6()
183
+ run_id = uuid7()
177
184
  sub = asyncio.create_task(Runs.Stream.subscribe(run_id))
178
185
 
179
186
  try:
@@ -255,7 +262,7 @@ async def wait_run_stateless(request: ApiRequest):
255
262
  """Create a stateless run, wait for the output."""
256
263
  payload = await request.json(RunCreateStateless)
257
264
  on_disconnect = payload.get("on_disconnect", "continue")
258
- run_id = uuid6()
265
+ run_id = uuid7()
259
266
  sub = asyncio.create_task(Runs.Stream.subscribe(run_id))
260
267
 
261
268
  try:
@@ -338,6 +345,9 @@ async def list_runs(
338
345
  limit = int(request.query_params.get("limit", 10))
339
346
  offset = int(request.query_params.get("offset", 0))
340
347
  status = request.query_params.get("status")
348
+ select = validate_select_columns(
349
+ request.query_params.getlist("select") or None, RUN_FIELDS
350
+ )
341
351
 
342
352
  async with connect() as conn, conn.pipeline():
343
353
  thread, runs = await asyncio.gather(
@@ -348,6 +358,7 @@ async def list_runs(
348
358
  limit=limit,
349
359
  offset=offset,
350
360
  status=status,
361
+ select=select,
351
362
  ),
352
363
  )
353
364
  await fetchone(thread)
@@ -425,7 +436,10 @@ async def cancel_run(
425
436
  wait_str = request.query_params.get("wait", "false")
426
437
  wait = wait_str.lower() in {"true", "yes", "1"}
427
438
  action_str = request.query_params.get("action", "interrupt")
428
- action = action_str if action_str in {"interrupt", "rollback"} else "interrupt"
439
+ action = cast(
440
+ Literal["interrupt", "rollback"],
441
+ action_str if action_str in {"interrupt", "rollback"} else "interrupt",
442
+ )
429
443
 
430
444
  async with connect() as conn:
431
445
  await Runs.cancel(
@@ -471,8 +485,9 @@ async def cancel_runs(
471
485
  for rid in run_ids:
472
486
  validate_uuid(rid, "Invalid run ID: must be a UUID")
473
487
  action_str = request.query_params.get("action", "interrupt")
474
- action: Literal["interrupt", "rollback"] = (
475
- action_str if action_str in ("interrupt", "rollback") else "interrupt"
488
+ action = cast(
489
+ Literal["interrupt", "rollback"],
490
+ action_str if action_str in ("interrupt", "rollback") else "interrupt",
476
491
  )
477
492
 
478
493
  async with connect() as conn:
@@ -557,6 +572,7 @@ async def delete_cron(request: ApiRequest):
557
572
  async def search_crons(request: ApiRequest):
558
573
  """List all cron jobs for an assistant"""
559
574
  payload = await request.json(CronSearch)
575
+ select = validate_select_columns(payload.get("select") or None, CRON_FIELDS)
560
576
  if assistant_id := payload.get("assistant_id"):
561
577
  validate_uuid(assistant_id, "Invalid assistant ID: must be a UUID")
562
578
  if thread_id := payload.get("thread_id"):
@@ -572,6 +588,7 @@ async def search_crons(request: ApiRequest):
572
588
  offset=offset,
573
589
  sort_by=payload.get("sort_by"),
574
590
  sort_order=payload.get("sort_order"),
591
+ select=select,
575
592
  )
576
593
  crons, response_headers = await get_pagination_headers(
577
594
  crons_iter, next_offset, offset
@@ -579,6 +596,24 @@ async def search_crons(request: ApiRequest):
579
596
  return ApiResponse(crons, headers=response_headers)
580
597
 
581
598
 
599
+ @retry_db
600
+ async def count_crons(request: ApiRequest):
601
+ """Count cron jobs."""
602
+ payload = await request.json(CronCountRequest)
603
+ if assistant_id := payload.get("assistant_id"):
604
+ validate_uuid(assistant_id, "Invalid assistant ID: must be a UUID")
605
+ if thread_id := payload.get("thread_id"):
606
+ validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
607
+
608
+ async with connect() as conn:
609
+ count = await Crons.count(
610
+ conn,
611
+ assistant_id=assistant_id,
612
+ thread_id=thread_id,
613
+ )
614
+ return ApiResponse(count)
615
+
616
+
582
617
  runs_routes = [
583
618
  ApiRoute("/runs/stream", stream_run_stateless, methods=["POST"]),
584
619
  ApiRoute("/runs/wait", wait_run_stateless, methods=["POST"]),
@@ -595,6 +630,11 @@ runs_routes = [
595
630
  if config.FF_CRONS_ENABLED and plus_features_enabled()
596
631
  else None
597
632
  ),
633
+ (
634
+ ApiRoute("/runs/crons/count", count_crons, methods=["POST"])
635
+ if config.FF_CRONS_ENABLED and plus_features_enabled()
636
+ else None
637
+ ),
598
638
  ApiRoute("/threads/{thread_id}/runs/{run_id}/join", join_run, methods=["GET"]),
599
639
  ApiRoute(
600
640
  "/threads/{thread_id}/runs/{run_id}/stream",
@@ -5,9 +5,16 @@ from starlette.responses import Response
5
5
  from starlette.routing import BaseRoute
6
6
 
7
7
  from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
8
+ from langgraph_api.schema import THREAD_FIELDS
8
9
  from langgraph_api.state import state_snapshot_to_thread_state
9
- from langgraph_api.utils import fetchone, get_pagination_headers, validate_uuid
10
+ from langgraph_api.utils import (
11
+ fetchone,
12
+ get_pagination_headers,
13
+ validate_select_columns,
14
+ validate_uuid,
15
+ )
10
16
  from langgraph_api.validation import (
17
+ ThreadCountRequest,
11
18
  ThreadCreate,
12
19
  ThreadPatch,
13
20
  ThreadSearchRequest,
@@ -58,6 +65,7 @@ async def search_threads(
58
65
  ):
59
66
  """List threads."""
60
67
  payload = await request.json(ThreadSearchRequest)
68
+ select = validate_select_columns(payload.get("select") or None, THREAD_FIELDS)
61
69
  limit = int(payload.get("limit") or 10)
62
70
  offset = int(payload.get("offset") or 0)
63
71
  async with connect() as conn:
@@ -70,6 +78,7 @@ async def search_threads(
70
78
  offset=offset,
71
79
  sort_by=payload.get("sort_by"),
72
80
  sort_order=payload.get("sort_order"),
81
+ select=select,
73
82
  )
74
83
  threads, response_headers = await get_pagination_headers(
75
84
  threads_iter, next_offset, offset
@@ -77,6 +86,22 @@ async def search_threads(
77
86
  return ApiResponse(threads, headers=response_headers)
78
87
 
79
88
 
89
+ @retry_db
90
+ async def count_threads(
91
+ request: ApiRequest,
92
+ ):
93
+ """Count threads."""
94
+ payload = await request.json(ThreadCountRequest)
95
+ async with connect() as conn:
96
+ count = await Threads.count(
97
+ conn,
98
+ status=payload.get("status"),
99
+ values=payload.get("values"),
100
+ metadata=payload.get("metadata"),
101
+ )
102
+ return ApiResponse(count)
103
+
104
+
80
105
  @retry_db
81
106
  async def get_thread_state(
82
107
  request: ApiRequest,
@@ -260,6 +285,7 @@ async def copy_thread(request: ApiRequest):
260
285
  threads_routes: list[BaseRoute] = [
261
286
  ApiRoute("/threads", endpoint=create_thread, methods=["POST"]),
262
287
  ApiRoute("/threads/search", endpoint=search_threads, methods=["POST"]),
288
+ ApiRoute("/threads/count", endpoint=count_threads, methods=["POST"]),
263
289
  ApiRoute("/threads/{thread_id}", endpoint=get_thread, methods=["GET"]),
264
290
  ApiRoute("/threads/{thread_id}", endpoint=patch_thread, methods=["PATCH"]),
265
291
  ApiRoute("/threads/{thread_id}", endpoint=delete_thread, methods=["DELETE"]),
langgraph_api/api/ui.py CHANGED
@@ -56,6 +56,8 @@ async def handle_ui(request: ApiRequest) -> Response:
56
56
 
57
57
  # Use http:// protocol if accessing a localhost service
58
58
  def is_host(needle: str) -> bool:
59
+ if not isinstance(host, str):
60
+ return False
59
61
  return host.startswith(needle + ":") or host == needle
60
62
 
61
63
  protocol = "http:" if is_host("localhost") or is_host("127.0.0.1") else ""
@@ -13,7 +13,7 @@ from httpx import AsyncByteStream, Request, Response
13
13
  if typing.TYPE_CHECKING: # pragma: no cover
14
14
  import asyncio
15
15
 
16
- import trio
16
+ import trio # type: ignore[unresolved-import]
17
17
 
18
18
  Event = asyncio.Event | trio.Event
19
19
 
@@ -37,7 +37,7 @@ def is_running_trio() -> bool:
37
37
 
38
38
  def create_event() -> Event:
39
39
  if is_running_trio():
40
- import trio
40
+ import trio # type: ignore[unresolved-import]
41
41
 
42
42
  return trio.Event()
43
43
 
langgraph_api/asyncio.py CHANGED
@@ -119,7 +119,7 @@ def create_task(
119
119
 
120
120
  def run_coroutine_threadsafe(
121
121
  coro: Coroutine[Any, Any, T], ignore_exceptions: tuple[type[Exception], ...] = ()
122
- ) -> concurrent.futures.Future[T | None]:
122
+ ) -> concurrent.futures.Future[T] | concurrent.futures.Future[None]:
123
123
  if _MAIN_LOOP is None:
124
124
  raise RuntimeError("No event loop set")
125
125
  future = asyncio.run_coroutine_threadsafe(coro, _MAIN_LOOP)
@@ -226,7 +226,7 @@ def to_aiter(*args: T) -> AsyncIterator[T]:
226
226
  V = TypeVar("V")
227
227
 
228
228
 
229
- class aclosing(Generic[V], AbstractAsyncContextManager):
229
+ class aclosing(Generic[V], AbstractAsyncContextManager[V]):
230
230
  """Async context manager for safely finalizing an asynchronously cleaned-up
231
231
  resource such as an async generator, calling its ``aclose()`` method.
232
232
 
@@ -255,14 +255,16 @@ class aclosing(Generic[V], AbstractAsyncContextManager):
255
255
  await self.thing.aclose()
256
256
 
257
257
 
258
- async def aclosing_aiter(aiter: AsyncIterator[T]) -> AsyncIterator[T]:
259
- if hasattr(aiter, "__aenter__"):
260
- async with aiter:
261
- async for item in aiter:
258
+ async def aclosing_aiter(
259
+ aiterator: AsyncIterator[T],
260
+ ) -> AsyncIterator[T]:
261
+ if hasattr(aiterator, "__aenter__"):
262
+ async with aiterator: # type: ignore[invalid-context-manager]
263
+ async for item in aiterator:
262
264
  yield item
263
265
  else:
264
- async with aclosing(aiter):
265
- async for item in aiter:
266
+ async with aclosing(aiterator):
267
+ async for item in aiterator:
266
268
  yield item
267
269
 
268
270
 
@@ -251,14 +251,15 @@ def _get_auth_instance(path: str | None = None) -> Auth | Literal["js"] | None:
251
251
  deps := _get_dependencies(auth_instance._authenticate_handler)
252
252
  ):
253
253
  auth_instance._authenticate_handler = _solve_fastapi_dependencies(
254
- auth_instance._authenticate_handler, deps
254
+ auth_instance._authenticate_handler, # type: ignore[invalid-argument-type]
255
+ deps,
255
256
  )
256
257
  logger.info(f"Loaded auth instance from path {path}: {auth_instance}")
257
258
  return auth_instance
258
259
 
259
260
 
260
261
  def _extract_arguments_from_scope(
261
- scope: dict[str, Any],
262
+ scope: Mapping[str, Any],
262
263
  param_names: set[str],
263
264
  request: Request | None = None,
264
265
  response: Response | None = None,
@@ -283,7 +284,11 @@ def _extract_arguments_from_scope(
283
284
  if "path" in param_names:
284
285
  args["path"] = scope["path"]
285
286
  if "query_params" in param_names:
286
- args["query_params"] = QueryParams(scope.get("query_string"))
287
+ query_params = scope.get("query_string")
288
+ if query_params:
289
+ args["query_params"] = QueryParams(query_params)
290
+ else:
291
+ args["query_params"] = QueryParams()
287
292
  if "headers" in param_names:
288
293
  args["headers"] = dict(scope.get("headers", {}))
289
294
  if "authorization" in param_names:
@@ -595,7 +600,7 @@ def _load_auth_obj(path: str) -> Auth | Literal["js"]:
595
600
  raise ValueError(f"Could not load file: {module_name}")
596
601
  module = importlib.util.module_from_spec(modspec)
597
602
  sys.modules[modname] = module
598
- modspec.loader.exec_module(module)
603
+ modspec.loader.exec_module(module) # type: ignore[possibly-unbound-attribute]
599
604
  else:
600
605
  # Load from Python module
601
606
  module = importlib.import_module(module_name)
@@ -14,7 +14,7 @@ from langgraph_api.config import LANGSMITH_AUTH_ENDPOINT
14
14
  _client: "JsonHttpClient"
15
15
 
16
16
 
17
- def is_retriable_error(exception: Exception) -> bool:
17
+ def is_retriable_error(exception: BaseException) -> bool:
18
18
  if isinstance(exception, httpx.TransportError):
19
19
  return True
20
20
  if isinstance(exception, httpx.HTTPStatusError):
langgraph_api/cli.py CHANGED
@@ -204,7 +204,7 @@ def run_server(
204
204
  mount_prefix = os.environ.get("LANGGRAPH_MOUNT_PREFIX")
205
205
  if isinstance(env, str | pathlib.Path):
206
206
  try:
207
- from dotenv.main import DotEnv
207
+ from dotenv.main import DotEnv # type: ignore[unresolved-import]
208
208
 
209
209
  env_vars = DotEnv(dotenv_path=env).dict() or {}
210
210
  logger.debug(f"Loaded environment variables from {env}: {sorted(env_vars)}")
@@ -216,7 +216,7 @@ def run_server(
216
216
 
217
217
  if debug_port is not None:
218
218
  try:
219
- import debugpy
219
+ import debugpy # type: ignore[unresolved-import]
220
220
  except ImportError:
221
221
  logger.warning("debugpy is not installed. Debugging will not be available.")
222
222
  logger.info("To enable debugging, install debugpy: pip install debugpy")
@@ -301,6 +301,7 @@ def run_server(
301
301
  def _open_browser():
302
302
  nonlocal studio_origin, full_studio_url
303
303
  import time
304
+ import urllib.error
304
305
  import urllib.request
305
306
  import webbrowser
306
307
  from concurrent.futures import ThreadPoolExecutor
@@ -377,8 +378,8 @@ For production use, please use LangGraph Platform.
377
378
  reload=reload,
378
379
  env_file=env_file,
379
380
  access_log=False,
380
- reload_includes=reload_includes,
381
- reload_excludes=reload_excludes,
381
+ reload_includes=list(reload_includes) if reload_includes else None,
382
+ reload_excludes=list(reload_excludes) if reload_excludes else None,
382
383
  log_config={
383
384
  "version": 1,
384
385
  "incremental": False,
langgraph_api/config.py CHANGED
@@ -287,7 +287,7 @@ if THREAD_TTL is None and CHECKPOINTER_CONFIG is not None:
287
287
  N_JOBS_PER_WORKER = env("N_JOBS_PER_WORKER", cast=int, default=10)
288
288
  BG_JOB_TIMEOUT_SECS = env("BG_JOB_TIMEOUT_SECS", cast=float, default=3600)
289
289
  FF_CRONS_ENABLED = env("FF_CRONS_ENABLED", cast=bool, default=True)
290
- FF_RICH_THREADS = env("FF_RICH_THREADS", cast=bool, default=False)
290
+ FF_RICH_THREADS = env("FF_RICH_THREADS", cast=bool, default=True)
291
291
 
292
292
  # auth
293
293
 
@@ -0,0 +1,23 @@
1
+ import argparse
2
+ import asyncio
3
+ import json
4
+ import logging.config
5
+ import pathlib
6
+
7
+ from langgraph_api.queue_entrypoint import main
8
+
9
+ if __name__ == "__main__":
10
+ parser = argparse.ArgumentParser()
11
+
12
+ parser.add_argument("--grpc-port", type=int, default=50051)
13
+ args = parser.parse_args()
14
+ with open(pathlib.Path(__file__).parent.parent / "logging.json") as file:
15
+ loaded_config = json.load(file)
16
+ logging.config.dictConfig(loaded_config)
17
+ try:
18
+ import uvloop # type: ignore[unresolved-import]
19
+
20
+ uvloop.install()
21
+ except ImportError:
22
+ pass
23
+ asyncio.run(main(grpc_port=args.grpc_port, entrypoint_name="python-executor"))
langgraph_api/graph.py CHANGED
@@ -9,7 +9,7 @@ import warnings
9
9
  from collections.abc import AsyncIterator, Callable
10
10
  from contextlib import asynccontextmanager
11
11
  from itertools import filterfalse
12
- from typing import TYPE_CHECKING, Any, NamedTuple, cast
12
+ from typing import TYPE_CHECKING, Any, NamedTuple, TypeGuard, cast
13
13
  from uuid import UUID, uuid5
14
14
 
15
15
  import orjson
@@ -35,10 +35,10 @@ logger = structlog.stdlib.get_logger(__name__)
35
35
 
36
36
  GraphFactoryFromConfig = Callable[[Config], Pregel | StateGraph]
37
37
  GraphFactory = Callable[[], Pregel | StateGraph]
38
- GraphValue = Pregel | GraphFactory
38
+ GraphValue = Pregel | GraphFactory | GraphFactoryFromConfig
39
39
 
40
40
 
41
- GRAPHS: dict[str, Pregel | GraphFactoryFromConfig | GraphFactory] = {}
41
+ GRAPHS: dict[str, GraphValue] = {}
42
42
  NAMESPACE_GRAPH = UUID("6ba7b821-9dad-11d1-80b4-00c04fd430c8")
43
43
  FACTORY_ACCEPTS_CONFIG: dict[str, bool] = {}
44
44
 
@@ -110,11 +110,23 @@ async def _generate_graph(value: Any) -> AsyncIterator[Any]:
110
110
  yield value
111
111
 
112
112
 
113
- def is_js_graph(graph_id: str) -> bool:
113
+ def is_js_graph(graph_id: str) -> TypeGuard[BaseRemotePregel]:
114
114
  """Return whether a graph is a JS graph."""
115
115
  return graph_id in GRAPHS and isinstance(GRAPHS[graph_id], BaseRemotePregel)
116
116
 
117
117
 
118
+ def is_factory(
119
+ value: GraphValue, graph_id: str
120
+ ) -> TypeGuard[GraphFactoryFromConfig | GraphFactory]:
121
+ return graph_id in FACTORY_ACCEPTS_CONFIG
122
+
123
+
124
+ def factory_accepts_config(
125
+ value: GraphValue, graph_id: str
126
+ ) -> TypeGuard[GraphFactoryFromConfig]:
127
+ return FACTORY_ACCEPTS_CONFIG.get(graph_id, False)
128
+
129
+
118
130
  @asynccontextmanager
119
131
  async def get_graph(
120
132
  graph_id: str,
@@ -128,7 +140,7 @@ async def get_graph(
128
140
 
129
141
  assert_graph_exists(graph_id)
130
142
  value = GRAPHS[graph_id]
131
- if graph_id in FACTORY_ACCEPTS_CONFIG:
143
+ if is_factory(value, graph_id):
132
144
  config = lg_config.ensure_config(config)
133
145
 
134
146
  if store is not None:
@@ -139,6 +151,8 @@ async def get_graph(
139
151
  runtime = config["configurable"].get(CONFIG_KEY_RUNTIME)
140
152
  if runtime is None:
141
153
  patched_runtime = Runtime(store=store)
154
+ elif isinstance(runtime, dict):
155
+ patched_runtime = Runtime(**(runtime | {"store": store}))
142
156
  elif runtime.store is None:
143
157
  patched_runtime = cast(Runtime, runtime).override(store=store)
144
158
  else:
@@ -156,7 +170,7 @@ async def get_graph(
156
170
  ):
157
171
  config["configurable"][CONFIG_KEY_CHECKPOINTER] = checkpointer
158
172
  var_child_runnable_config.set(config)
159
- value = value(config) if FACTORY_ACCEPTS_CONFIG[graph_id] else value()
173
+ value = value(config) if factory_accepts_config(value, graph_id) else value()
160
174
  try:
161
175
  async with _generate_graph(value) as graph_obj:
162
176
  if isinstance(graph_obj, StateGraph):
@@ -451,7 +465,7 @@ def _graph_from_spec(spec: GraphSpec) -> GraphValue:
451
465
  raise ValueError(f"Could not find python file for graph: {spec}")
452
466
  module = importlib.util.module_from_spec(modspec)
453
467
  sys.modules[modname] = module
454
- modspec.loader.exec_module(module)
468
+ modspec.loader.exec_module(module) # type: ignore[possibly-unbound-attribute]
455
469
  except ImportError as e:
456
470
  e.add_note(f"Could not import python module for graph:\n{spec}")
457
471
  if config.API_VARIANT == "local_dev":
@@ -565,7 +579,9 @@ def _graph_from_spec(spec: GraphSpec) -> GraphValue:
565
579
  @functools.lru_cache(maxsize=1)
566
580
  def _get_init_embeddings() -> Callable[[str, ...], "Embeddings"] | None:
567
581
  try:
568
- from langchain.embeddings import init_embeddings
582
+ from langchain.embeddings import ( # type: ignore[unresolved-import]
583
+ init_embeddings,
584
+ )
569
585
 
570
586
  return init_embeddings
571
587
  except ImportError:
@@ -606,7 +622,7 @@ def resolve_embeddings(index_config: dict) -> "Embeddings":
606
622
  raise ValueError(f"Could not find embeddings file: {module_name}")
607
623
  module = importlib.util.module_from_spec(modspec)
608
624
  sys.modules[modname] = module
609
- modspec.loader.exec_module(module)
625
+ modspec.loader.exec_module(module) # type: ignore[possibly-unbound-attribute]
610
626
  else:
611
627
  # Load from Python module
612
628
  module = importlib.import_module(module_name)