langgraph-api 0.2.130__py3-none-any.whl → 0.2.134__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_api/__init__.py +1 -1
- langgraph_api/api/assistants.py +32 -6
- langgraph_api/api/meta.py +3 -1
- langgraph_api/api/openapi.py +1 -1
- langgraph_api/api/runs.py +50 -10
- langgraph_api/api/threads.py +27 -1
- langgraph_api/api/ui.py +2 -0
- langgraph_api/asgi_transport.py +2 -2
- langgraph_api/asyncio.py +10 -8
- langgraph_api/auth/custom.py +9 -4
- langgraph_api/auth/langsmith/client.py +1 -1
- langgraph_api/cli.py +5 -4
- langgraph_api/config.py +1 -1
- langgraph_api/executor_entrypoint.py +23 -0
- langgraph_api/graph.py +25 -9
- langgraph_api/http.py +10 -7
- langgraph_api/http_metrics.py +4 -1
- langgraph_api/js/build.mts +11 -2
- langgraph_api/js/client.http.mts +2 -0
- langgraph_api/js/client.mts +13 -3
- langgraph_api/js/package.json +2 -2
- langgraph_api/js/remote.py +17 -12
- langgraph_api/js/src/preload.mjs +9 -1
- langgraph_api/js/src/utils/files.mts +5 -2
- langgraph_api/js/sse.py +1 -1
- langgraph_api/js/yarn.lock +9 -9
- langgraph_api/logging.py +3 -3
- langgraph_api/middleware/http_logger.py +2 -1
- langgraph_api/models/run.py +19 -14
- langgraph_api/patch.py +2 -2
- langgraph_api/queue_entrypoint.py +33 -18
- langgraph_api/schema.py +88 -4
- langgraph_api/serde.py +32 -5
- langgraph_api/server.py +5 -3
- langgraph_api/state.py +8 -8
- langgraph_api/store.py +1 -1
- langgraph_api/stream.py +33 -20
- langgraph_api/traceblock.py +1 -1
- langgraph_api/utils/__init__.py +40 -5
- langgraph_api/utils/config.py +13 -4
- langgraph_api/utils/future.py +1 -1
- langgraph_api/utils/uuids.py +87 -0
- langgraph_api/validation.py +9 -0
- langgraph_api/webhook.py +20 -20
- langgraph_api/worker.py +8 -5
- {langgraph_api-0.2.130.dist-info → langgraph_api-0.2.134.dist-info}/METADATA +2 -2
- {langgraph_api-0.2.130.dist-info → langgraph_api-0.2.134.dist-info}/RECORD +51 -49
- openapi.json +331 -1
- {langgraph_api-0.2.130.dist-info → langgraph_api-0.2.134.dist-info}/WHEEL +0 -0
- {langgraph_api-0.2.130.dist-info → langgraph_api-0.2.134.dist-info}/entry_points.txt +0 -0
- {langgraph_api-0.2.130.dist-info → langgraph_api-0.2.134.dist-info}/licenses/LICENSE +0 -0
langgraph_api/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.2.
|
|
1
|
+
__version__ = "0.2.134"
|
langgraph_api/api/assistants.py
CHANGED
|
@@ -16,9 +16,16 @@ from langgraph_api.feature_flags import USE_RUNTIME_CONTEXT_API
|
|
|
16
16
|
from langgraph_api.graph import get_assistant_id, get_graph
|
|
17
17
|
from langgraph_api.js.base import BaseRemotePregel
|
|
18
18
|
from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
|
|
19
|
+
from langgraph_api.schema import ASSISTANT_FIELDS
|
|
19
20
|
from langgraph_api.serde import ajson_loads
|
|
20
|
-
from langgraph_api.utils import
|
|
21
|
+
from langgraph_api.utils import (
|
|
22
|
+
fetchone,
|
|
23
|
+
get_pagination_headers,
|
|
24
|
+
validate_select_columns,
|
|
25
|
+
validate_uuid,
|
|
26
|
+
)
|
|
21
27
|
from langgraph_api.validation import (
|
|
28
|
+
AssistantCountRequest,
|
|
22
29
|
AssistantCreate,
|
|
23
30
|
AssistantPatch,
|
|
24
31
|
AssistantSearchRequest,
|
|
@@ -61,7 +68,8 @@ def _get_configurable_jsonschema(graph: Pregel) -> dict:
|
|
|
61
68
|
in favor of graph.get_context_jsonschema().
|
|
62
69
|
"""
|
|
63
70
|
# Otherwise, use the config_schema method.
|
|
64
|
-
|
|
71
|
+
# TODO: Remove this when we no longer support langgraph < 0.6
|
|
72
|
+
config_schema = graph.config_schema() # type: ignore[deprecated]
|
|
65
73
|
model_fields = getattr(config_schema, "model_fields", None) or getattr(
|
|
66
74
|
config_schema, "__fields__", None
|
|
67
75
|
)
|
|
@@ -87,11 +95,11 @@ def _state_jsonschema(graph: Pregel) -> dict | None:
|
|
|
87
95
|
for k in graph.stream_channels_list:
|
|
88
96
|
v = graph.channels[k]
|
|
89
97
|
try:
|
|
90
|
-
create_model(k, __root__=(v.UpdateType, None)).
|
|
98
|
+
create_model(k, __root__=(v.UpdateType, None)).model_json_schema()
|
|
91
99
|
fields[k] = (v.UpdateType, None)
|
|
92
100
|
except Exception:
|
|
93
101
|
fields[k] = (Any, None)
|
|
94
|
-
return create_model(graph.get_name("State"), **fields).
|
|
102
|
+
return create_model(graph.get_name("State"), **fields).model_json_schema()
|
|
95
103
|
|
|
96
104
|
|
|
97
105
|
def _graph_schemas(graph: Pregel) -> dict:
|
|
@@ -132,7 +140,7 @@ def _graph_schemas(graph: Pregel) -> dict:
|
|
|
132
140
|
logger.warning(
|
|
133
141
|
f"Failed to get context schema for graph {graph.name} with error: `{str(e)}`"
|
|
134
142
|
)
|
|
135
|
-
context_schema = graph.config_schema()
|
|
143
|
+
context_schema = graph.config_schema() # type: ignore[deprecated]
|
|
136
144
|
else:
|
|
137
145
|
context_schema = None
|
|
138
146
|
|
|
@@ -172,6 +180,7 @@ async def search_assistants(
|
|
|
172
180
|
) -> ApiResponse:
|
|
173
181
|
"""List assistants."""
|
|
174
182
|
payload = await request.json(AssistantSearchRequest)
|
|
183
|
+
select = validate_select_columns(payload.get("select") or None, ASSISTANT_FIELDS)
|
|
175
184
|
offset = int(payload.get("offset") or 0)
|
|
176
185
|
async with connect() as conn:
|
|
177
186
|
assistants_iter, next_offset = await Assistants.search(
|
|
@@ -182,6 +191,7 @@ async def search_assistants(
|
|
|
182
191
|
offset=offset,
|
|
183
192
|
sort_by=payload.get("sort_by"),
|
|
184
193
|
sort_order=payload.get("sort_order"),
|
|
194
|
+
select=select,
|
|
185
195
|
)
|
|
186
196
|
assistants, response_headers = await get_pagination_headers(
|
|
187
197
|
assistants_iter, next_offset, offset
|
|
@@ -189,6 +199,21 @@ async def search_assistants(
|
|
|
189
199
|
return ApiResponse(assistants, headers=response_headers)
|
|
190
200
|
|
|
191
201
|
|
|
202
|
+
@retry_db
|
|
203
|
+
async def count_assistants(
|
|
204
|
+
request: ApiRequest,
|
|
205
|
+
) -> ApiResponse:
|
|
206
|
+
"""Count assistants."""
|
|
207
|
+
payload = await request.json(AssistantCountRequest)
|
|
208
|
+
async with connect() as conn:
|
|
209
|
+
count = await Assistants.count(
|
|
210
|
+
conn,
|
|
211
|
+
graph_id=payload.get("graph_id"),
|
|
212
|
+
metadata=payload.get("metadata"),
|
|
213
|
+
)
|
|
214
|
+
return ApiResponse(count)
|
|
215
|
+
|
|
216
|
+
|
|
192
217
|
@retry_db
|
|
193
218
|
async def get_assistant(
|
|
194
219
|
request: ApiRequest,
|
|
@@ -366,7 +391,7 @@ async def patch_assistant(
|
|
|
366
391
|
|
|
367
392
|
|
|
368
393
|
@retry_db
|
|
369
|
-
async def delete_assistant(request: ApiRequest) ->
|
|
394
|
+
async def delete_assistant(request: ApiRequest) -> Response:
|
|
370
395
|
"""Delete an assistant by ID."""
|
|
371
396
|
assistant_id = request.path_params["assistant_id"]
|
|
372
397
|
validate_uuid(assistant_id, "Invalid assistant ID: must be a UUID")
|
|
@@ -414,6 +439,7 @@ async def set_latest_assistant_version(request: ApiRequest) -> ApiResponse:
|
|
|
414
439
|
assistants_routes: list[BaseRoute] = [
|
|
415
440
|
ApiRoute("/assistants", create_assistant, methods=["POST"]),
|
|
416
441
|
ApiRoute("/assistants/search", search_assistants, methods=["POST"]),
|
|
442
|
+
ApiRoute("/assistants/count", count_assistants, methods=["POST"]),
|
|
417
443
|
ApiRoute(
|
|
418
444
|
"/assistants/{assistant_id}/latest",
|
|
419
445
|
set_latest_assistant_version,
|
langgraph_api/api/meta.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import cast
|
|
2
|
+
|
|
1
3
|
import langgraph.version
|
|
2
4
|
from starlette.responses import JSONResponse, PlainTextResponse
|
|
3
5
|
|
|
@@ -43,7 +45,7 @@ async def meta_metrics(request: ApiRequest):
|
|
|
43
45
|
|
|
44
46
|
# collect stats
|
|
45
47
|
metrics = get_metrics()
|
|
46
|
-
worker_metrics = metrics["workers"]
|
|
48
|
+
worker_metrics = cast(dict[str, int], metrics["workers"])
|
|
47
49
|
workers_max = worker_metrics["max"]
|
|
48
50
|
workers_active = worker_metrics["active"]
|
|
49
51
|
workers_available = worker_metrics["available"]
|
langgraph_api/api/openapi.py
CHANGED
langgraph_api/api/runs.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
from collections.abc import AsyncIterator
|
|
3
|
-
from typing import Literal
|
|
3
|
+
from typing import Literal, cast
|
|
4
4
|
|
|
5
5
|
import orjson
|
|
6
|
-
from langgraph.checkpoint.base.id import uuid6
|
|
7
6
|
from starlette.exceptions import HTTPException
|
|
8
7
|
from starlette.responses import Response, StreamingResponse
|
|
9
8
|
|
|
@@ -11,9 +10,17 @@ from langgraph_api import config
|
|
|
11
10
|
from langgraph_api.asyncio import ValueEvent, aclosing
|
|
12
11
|
from langgraph_api.models.run import create_valid_run
|
|
13
12
|
from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
|
|
13
|
+
from langgraph_api.schema import CRON_FIELDS, RUN_FIELDS
|
|
14
14
|
from langgraph_api.sse import EventSourceResponse
|
|
15
|
-
from langgraph_api.utils import
|
|
15
|
+
from langgraph_api.utils import (
|
|
16
|
+
fetchone,
|
|
17
|
+
get_pagination_headers,
|
|
18
|
+
uuid7,
|
|
19
|
+
validate_select_columns,
|
|
20
|
+
validate_uuid,
|
|
21
|
+
)
|
|
16
22
|
from langgraph_api.validation import (
|
|
23
|
+
CronCountRequest,
|
|
17
24
|
CronCreate,
|
|
18
25
|
CronSearch,
|
|
19
26
|
RunBatchCreate,
|
|
@@ -92,7 +99,7 @@ async def stream_run(
|
|
|
92
99
|
thread_id = request.path_params["thread_id"]
|
|
93
100
|
payload = await request.json(RunCreateStateful)
|
|
94
101
|
on_disconnect = payload.get("on_disconnect", "continue")
|
|
95
|
-
run_id =
|
|
102
|
+
run_id = uuid7()
|
|
96
103
|
sub = asyncio.create_task(Runs.Stream.subscribe(run_id))
|
|
97
104
|
|
|
98
105
|
try:
|
|
@@ -132,7 +139,7 @@ async def stream_run_stateless(
|
|
|
132
139
|
"""Create a stateless run."""
|
|
133
140
|
payload = await request.json(RunCreateStateless)
|
|
134
141
|
on_disconnect = payload.get("on_disconnect", "continue")
|
|
135
|
-
run_id =
|
|
142
|
+
run_id = uuid7()
|
|
136
143
|
sub = asyncio.create_task(Runs.Stream.subscribe(run_id))
|
|
137
144
|
|
|
138
145
|
try:
|
|
@@ -173,7 +180,7 @@ async def wait_run(request: ApiRequest):
|
|
|
173
180
|
thread_id = request.path_params["thread_id"]
|
|
174
181
|
payload = await request.json(RunCreateStateful)
|
|
175
182
|
on_disconnect = payload.get("on_disconnect", "continue")
|
|
176
|
-
run_id =
|
|
183
|
+
run_id = uuid7()
|
|
177
184
|
sub = asyncio.create_task(Runs.Stream.subscribe(run_id))
|
|
178
185
|
|
|
179
186
|
try:
|
|
@@ -255,7 +262,7 @@ async def wait_run_stateless(request: ApiRequest):
|
|
|
255
262
|
"""Create a stateless run, wait for the output."""
|
|
256
263
|
payload = await request.json(RunCreateStateless)
|
|
257
264
|
on_disconnect = payload.get("on_disconnect", "continue")
|
|
258
|
-
run_id =
|
|
265
|
+
run_id = uuid7()
|
|
259
266
|
sub = asyncio.create_task(Runs.Stream.subscribe(run_id))
|
|
260
267
|
|
|
261
268
|
try:
|
|
@@ -338,6 +345,9 @@ async def list_runs(
|
|
|
338
345
|
limit = int(request.query_params.get("limit", 10))
|
|
339
346
|
offset = int(request.query_params.get("offset", 0))
|
|
340
347
|
status = request.query_params.get("status")
|
|
348
|
+
select = validate_select_columns(
|
|
349
|
+
request.query_params.getlist("select") or None, RUN_FIELDS
|
|
350
|
+
)
|
|
341
351
|
|
|
342
352
|
async with connect() as conn, conn.pipeline():
|
|
343
353
|
thread, runs = await asyncio.gather(
|
|
@@ -348,6 +358,7 @@ async def list_runs(
|
|
|
348
358
|
limit=limit,
|
|
349
359
|
offset=offset,
|
|
350
360
|
status=status,
|
|
361
|
+
select=select,
|
|
351
362
|
),
|
|
352
363
|
)
|
|
353
364
|
await fetchone(thread)
|
|
@@ -425,7 +436,10 @@ async def cancel_run(
|
|
|
425
436
|
wait_str = request.query_params.get("wait", "false")
|
|
426
437
|
wait = wait_str.lower() in {"true", "yes", "1"}
|
|
427
438
|
action_str = request.query_params.get("action", "interrupt")
|
|
428
|
-
action =
|
|
439
|
+
action = cast(
|
|
440
|
+
Literal["interrupt", "rollback"],
|
|
441
|
+
action_str if action_str in {"interrupt", "rollback"} else "interrupt",
|
|
442
|
+
)
|
|
429
443
|
|
|
430
444
|
async with connect() as conn:
|
|
431
445
|
await Runs.cancel(
|
|
@@ -471,8 +485,9 @@ async def cancel_runs(
|
|
|
471
485
|
for rid in run_ids:
|
|
472
486
|
validate_uuid(rid, "Invalid run ID: must be a UUID")
|
|
473
487
|
action_str = request.query_params.get("action", "interrupt")
|
|
474
|
-
action
|
|
475
|
-
|
|
488
|
+
action = cast(
|
|
489
|
+
Literal["interrupt", "rollback"],
|
|
490
|
+
action_str if action_str in ("interrupt", "rollback") else "interrupt",
|
|
476
491
|
)
|
|
477
492
|
|
|
478
493
|
async with connect() as conn:
|
|
@@ -557,6 +572,7 @@ async def delete_cron(request: ApiRequest):
|
|
|
557
572
|
async def search_crons(request: ApiRequest):
|
|
558
573
|
"""List all cron jobs for an assistant"""
|
|
559
574
|
payload = await request.json(CronSearch)
|
|
575
|
+
select = validate_select_columns(payload.get("select") or None, CRON_FIELDS)
|
|
560
576
|
if assistant_id := payload.get("assistant_id"):
|
|
561
577
|
validate_uuid(assistant_id, "Invalid assistant ID: must be a UUID")
|
|
562
578
|
if thread_id := payload.get("thread_id"):
|
|
@@ -572,6 +588,7 @@ async def search_crons(request: ApiRequest):
|
|
|
572
588
|
offset=offset,
|
|
573
589
|
sort_by=payload.get("sort_by"),
|
|
574
590
|
sort_order=payload.get("sort_order"),
|
|
591
|
+
select=select,
|
|
575
592
|
)
|
|
576
593
|
crons, response_headers = await get_pagination_headers(
|
|
577
594
|
crons_iter, next_offset, offset
|
|
@@ -579,6 +596,24 @@ async def search_crons(request: ApiRequest):
|
|
|
579
596
|
return ApiResponse(crons, headers=response_headers)
|
|
580
597
|
|
|
581
598
|
|
|
599
|
+
@retry_db
|
|
600
|
+
async def count_crons(request: ApiRequest):
|
|
601
|
+
"""Count cron jobs."""
|
|
602
|
+
payload = await request.json(CronCountRequest)
|
|
603
|
+
if assistant_id := payload.get("assistant_id"):
|
|
604
|
+
validate_uuid(assistant_id, "Invalid assistant ID: must be a UUID")
|
|
605
|
+
if thread_id := payload.get("thread_id"):
|
|
606
|
+
validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
|
|
607
|
+
|
|
608
|
+
async with connect() as conn:
|
|
609
|
+
count = await Crons.count(
|
|
610
|
+
conn,
|
|
611
|
+
assistant_id=assistant_id,
|
|
612
|
+
thread_id=thread_id,
|
|
613
|
+
)
|
|
614
|
+
return ApiResponse(count)
|
|
615
|
+
|
|
616
|
+
|
|
582
617
|
runs_routes = [
|
|
583
618
|
ApiRoute("/runs/stream", stream_run_stateless, methods=["POST"]),
|
|
584
619
|
ApiRoute("/runs/wait", wait_run_stateless, methods=["POST"]),
|
|
@@ -595,6 +630,11 @@ runs_routes = [
|
|
|
595
630
|
if config.FF_CRONS_ENABLED and plus_features_enabled()
|
|
596
631
|
else None
|
|
597
632
|
),
|
|
633
|
+
(
|
|
634
|
+
ApiRoute("/runs/crons/count", count_crons, methods=["POST"])
|
|
635
|
+
if config.FF_CRONS_ENABLED and plus_features_enabled()
|
|
636
|
+
else None
|
|
637
|
+
),
|
|
598
638
|
ApiRoute("/threads/{thread_id}/runs/{run_id}/join", join_run, methods=["GET"]),
|
|
599
639
|
ApiRoute(
|
|
600
640
|
"/threads/{thread_id}/runs/{run_id}/stream",
|
langgraph_api/api/threads.py
CHANGED
|
@@ -5,9 +5,16 @@ from starlette.responses import Response
|
|
|
5
5
|
from starlette.routing import BaseRoute
|
|
6
6
|
|
|
7
7
|
from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
|
|
8
|
+
from langgraph_api.schema import THREAD_FIELDS
|
|
8
9
|
from langgraph_api.state import state_snapshot_to_thread_state
|
|
9
|
-
from langgraph_api.utils import
|
|
10
|
+
from langgraph_api.utils import (
|
|
11
|
+
fetchone,
|
|
12
|
+
get_pagination_headers,
|
|
13
|
+
validate_select_columns,
|
|
14
|
+
validate_uuid,
|
|
15
|
+
)
|
|
10
16
|
from langgraph_api.validation import (
|
|
17
|
+
ThreadCountRequest,
|
|
11
18
|
ThreadCreate,
|
|
12
19
|
ThreadPatch,
|
|
13
20
|
ThreadSearchRequest,
|
|
@@ -58,6 +65,7 @@ async def search_threads(
|
|
|
58
65
|
):
|
|
59
66
|
"""List threads."""
|
|
60
67
|
payload = await request.json(ThreadSearchRequest)
|
|
68
|
+
select = validate_select_columns(payload.get("select") or None, THREAD_FIELDS)
|
|
61
69
|
limit = int(payload.get("limit") or 10)
|
|
62
70
|
offset = int(payload.get("offset") or 0)
|
|
63
71
|
async with connect() as conn:
|
|
@@ -70,6 +78,7 @@ async def search_threads(
|
|
|
70
78
|
offset=offset,
|
|
71
79
|
sort_by=payload.get("sort_by"),
|
|
72
80
|
sort_order=payload.get("sort_order"),
|
|
81
|
+
select=select,
|
|
73
82
|
)
|
|
74
83
|
threads, response_headers = await get_pagination_headers(
|
|
75
84
|
threads_iter, next_offset, offset
|
|
@@ -77,6 +86,22 @@ async def search_threads(
|
|
|
77
86
|
return ApiResponse(threads, headers=response_headers)
|
|
78
87
|
|
|
79
88
|
|
|
89
|
+
@retry_db
|
|
90
|
+
async def count_threads(
|
|
91
|
+
request: ApiRequest,
|
|
92
|
+
):
|
|
93
|
+
"""Count threads."""
|
|
94
|
+
payload = await request.json(ThreadCountRequest)
|
|
95
|
+
async with connect() as conn:
|
|
96
|
+
count = await Threads.count(
|
|
97
|
+
conn,
|
|
98
|
+
status=payload.get("status"),
|
|
99
|
+
values=payload.get("values"),
|
|
100
|
+
metadata=payload.get("metadata"),
|
|
101
|
+
)
|
|
102
|
+
return ApiResponse(count)
|
|
103
|
+
|
|
104
|
+
|
|
80
105
|
@retry_db
|
|
81
106
|
async def get_thread_state(
|
|
82
107
|
request: ApiRequest,
|
|
@@ -260,6 +285,7 @@ async def copy_thread(request: ApiRequest):
|
|
|
260
285
|
threads_routes: list[BaseRoute] = [
|
|
261
286
|
ApiRoute("/threads", endpoint=create_thread, methods=["POST"]),
|
|
262
287
|
ApiRoute("/threads/search", endpoint=search_threads, methods=["POST"]),
|
|
288
|
+
ApiRoute("/threads/count", endpoint=count_threads, methods=["POST"]),
|
|
263
289
|
ApiRoute("/threads/{thread_id}", endpoint=get_thread, methods=["GET"]),
|
|
264
290
|
ApiRoute("/threads/{thread_id}", endpoint=patch_thread, methods=["PATCH"]),
|
|
265
291
|
ApiRoute("/threads/{thread_id}", endpoint=delete_thread, methods=["DELETE"]),
|
langgraph_api/api/ui.py
CHANGED
|
@@ -56,6 +56,8 @@ async def handle_ui(request: ApiRequest) -> Response:
|
|
|
56
56
|
|
|
57
57
|
# Use http:// protocol if accessing a localhost service
|
|
58
58
|
def is_host(needle: str) -> bool:
|
|
59
|
+
if not isinstance(host, str):
|
|
60
|
+
return False
|
|
59
61
|
return host.startswith(needle + ":") or host == needle
|
|
60
62
|
|
|
61
63
|
protocol = "http:" if is_host("localhost") or is_host("127.0.0.1") else ""
|
langgraph_api/asgi_transport.py
CHANGED
|
@@ -13,7 +13,7 @@ from httpx import AsyncByteStream, Request, Response
|
|
|
13
13
|
if typing.TYPE_CHECKING: # pragma: no cover
|
|
14
14
|
import asyncio
|
|
15
15
|
|
|
16
|
-
import trio
|
|
16
|
+
import trio # type: ignore[unresolved-import]
|
|
17
17
|
|
|
18
18
|
Event = asyncio.Event | trio.Event
|
|
19
19
|
|
|
@@ -37,7 +37,7 @@ def is_running_trio() -> bool:
|
|
|
37
37
|
|
|
38
38
|
def create_event() -> Event:
|
|
39
39
|
if is_running_trio():
|
|
40
|
-
import trio
|
|
40
|
+
import trio # type: ignore[unresolved-import]
|
|
41
41
|
|
|
42
42
|
return trio.Event()
|
|
43
43
|
|
langgraph_api/asyncio.py
CHANGED
|
@@ -119,7 +119,7 @@ def create_task(
|
|
|
119
119
|
|
|
120
120
|
def run_coroutine_threadsafe(
|
|
121
121
|
coro: Coroutine[Any, Any, T], ignore_exceptions: tuple[type[Exception], ...] = ()
|
|
122
|
-
) -> concurrent.futures.Future[T | None]:
|
|
122
|
+
) -> concurrent.futures.Future[T] | concurrent.futures.Future[None]:
|
|
123
123
|
if _MAIN_LOOP is None:
|
|
124
124
|
raise RuntimeError("No event loop set")
|
|
125
125
|
future = asyncio.run_coroutine_threadsafe(coro, _MAIN_LOOP)
|
|
@@ -226,7 +226,7 @@ def to_aiter(*args: T) -> AsyncIterator[T]:
|
|
|
226
226
|
V = TypeVar("V")
|
|
227
227
|
|
|
228
228
|
|
|
229
|
-
class aclosing(Generic[V], AbstractAsyncContextManager):
|
|
229
|
+
class aclosing(Generic[V], AbstractAsyncContextManager[V]):
|
|
230
230
|
"""Async context manager for safely finalizing an asynchronously cleaned-up
|
|
231
231
|
resource such as an async generator, calling its ``aclose()`` method.
|
|
232
232
|
|
|
@@ -255,14 +255,16 @@ class aclosing(Generic[V], AbstractAsyncContextManager):
|
|
|
255
255
|
await self.thing.aclose()
|
|
256
256
|
|
|
257
257
|
|
|
258
|
-
async def aclosing_aiter(
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
258
|
+
async def aclosing_aiter(
|
|
259
|
+
aiterator: AsyncIterator[T],
|
|
260
|
+
) -> AsyncIterator[T]:
|
|
261
|
+
if hasattr(aiterator, "__aenter__"):
|
|
262
|
+
async with aiterator: # type: ignore[invalid-context-manager]
|
|
263
|
+
async for item in aiterator:
|
|
262
264
|
yield item
|
|
263
265
|
else:
|
|
264
|
-
async with aclosing(
|
|
265
|
-
async for item in
|
|
266
|
+
async with aclosing(aiterator):
|
|
267
|
+
async for item in aiterator:
|
|
266
268
|
yield item
|
|
267
269
|
|
|
268
270
|
|
langgraph_api/auth/custom.py
CHANGED
|
@@ -251,14 +251,15 @@ def _get_auth_instance(path: str | None = None) -> Auth | Literal["js"] | None:
|
|
|
251
251
|
deps := _get_dependencies(auth_instance._authenticate_handler)
|
|
252
252
|
):
|
|
253
253
|
auth_instance._authenticate_handler = _solve_fastapi_dependencies(
|
|
254
|
-
auth_instance._authenticate_handler,
|
|
254
|
+
auth_instance._authenticate_handler, # type: ignore[invalid-argument-type]
|
|
255
|
+
deps,
|
|
255
256
|
)
|
|
256
257
|
logger.info(f"Loaded auth instance from path {path}: {auth_instance}")
|
|
257
258
|
return auth_instance
|
|
258
259
|
|
|
259
260
|
|
|
260
261
|
def _extract_arguments_from_scope(
|
|
261
|
-
scope:
|
|
262
|
+
scope: Mapping[str, Any],
|
|
262
263
|
param_names: set[str],
|
|
263
264
|
request: Request | None = None,
|
|
264
265
|
response: Response | None = None,
|
|
@@ -283,7 +284,11 @@ def _extract_arguments_from_scope(
|
|
|
283
284
|
if "path" in param_names:
|
|
284
285
|
args["path"] = scope["path"]
|
|
285
286
|
if "query_params" in param_names:
|
|
286
|
-
|
|
287
|
+
query_params = scope.get("query_string")
|
|
288
|
+
if query_params:
|
|
289
|
+
args["query_params"] = QueryParams(query_params)
|
|
290
|
+
else:
|
|
291
|
+
args["query_params"] = QueryParams()
|
|
287
292
|
if "headers" in param_names:
|
|
288
293
|
args["headers"] = dict(scope.get("headers", {}))
|
|
289
294
|
if "authorization" in param_names:
|
|
@@ -595,7 +600,7 @@ def _load_auth_obj(path: str) -> Auth | Literal["js"]:
|
|
|
595
600
|
raise ValueError(f"Could not load file: {module_name}")
|
|
596
601
|
module = importlib.util.module_from_spec(modspec)
|
|
597
602
|
sys.modules[modname] = module
|
|
598
|
-
modspec.loader.exec_module(module)
|
|
603
|
+
modspec.loader.exec_module(module) # type: ignore[possibly-unbound-attribute]
|
|
599
604
|
else:
|
|
600
605
|
# Load from Python module
|
|
601
606
|
module = importlib.import_module(module_name)
|
|
@@ -14,7 +14,7 @@ from langgraph_api.config import LANGSMITH_AUTH_ENDPOINT
|
|
|
14
14
|
_client: "JsonHttpClient"
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
def is_retriable_error(exception:
|
|
17
|
+
def is_retriable_error(exception: BaseException) -> bool:
|
|
18
18
|
if isinstance(exception, httpx.TransportError):
|
|
19
19
|
return True
|
|
20
20
|
if isinstance(exception, httpx.HTTPStatusError):
|
langgraph_api/cli.py
CHANGED
|
@@ -204,7 +204,7 @@ def run_server(
|
|
|
204
204
|
mount_prefix = os.environ.get("LANGGRAPH_MOUNT_PREFIX")
|
|
205
205
|
if isinstance(env, str | pathlib.Path):
|
|
206
206
|
try:
|
|
207
|
-
from dotenv.main import DotEnv
|
|
207
|
+
from dotenv.main import DotEnv # type: ignore[unresolved-import]
|
|
208
208
|
|
|
209
209
|
env_vars = DotEnv(dotenv_path=env).dict() or {}
|
|
210
210
|
logger.debug(f"Loaded environment variables from {env}: {sorted(env_vars)}")
|
|
@@ -216,7 +216,7 @@ def run_server(
|
|
|
216
216
|
|
|
217
217
|
if debug_port is not None:
|
|
218
218
|
try:
|
|
219
|
-
import debugpy
|
|
219
|
+
import debugpy # type: ignore[unresolved-import]
|
|
220
220
|
except ImportError:
|
|
221
221
|
logger.warning("debugpy is not installed. Debugging will not be available.")
|
|
222
222
|
logger.info("To enable debugging, install debugpy: pip install debugpy")
|
|
@@ -301,6 +301,7 @@ def run_server(
|
|
|
301
301
|
def _open_browser():
|
|
302
302
|
nonlocal studio_origin, full_studio_url
|
|
303
303
|
import time
|
|
304
|
+
import urllib.error
|
|
304
305
|
import urllib.request
|
|
305
306
|
import webbrowser
|
|
306
307
|
from concurrent.futures import ThreadPoolExecutor
|
|
@@ -377,8 +378,8 @@ For production use, please use LangGraph Platform.
|
|
|
377
378
|
reload=reload,
|
|
378
379
|
env_file=env_file,
|
|
379
380
|
access_log=False,
|
|
380
|
-
reload_includes=reload_includes,
|
|
381
|
-
reload_excludes=reload_excludes,
|
|
381
|
+
reload_includes=list(reload_includes) if reload_includes else None,
|
|
382
|
+
reload_excludes=list(reload_excludes) if reload_excludes else None,
|
|
382
383
|
log_config={
|
|
383
384
|
"version": 1,
|
|
384
385
|
"incremental": False,
|
langgraph_api/config.py
CHANGED
|
@@ -287,7 +287,7 @@ if THREAD_TTL is None and CHECKPOINTER_CONFIG is not None:
|
|
|
287
287
|
N_JOBS_PER_WORKER = env("N_JOBS_PER_WORKER", cast=int, default=10)
|
|
288
288
|
BG_JOB_TIMEOUT_SECS = env("BG_JOB_TIMEOUT_SECS", cast=float, default=3600)
|
|
289
289
|
FF_CRONS_ENABLED = env("FF_CRONS_ENABLED", cast=bool, default=True)
|
|
290
|
-
FF_RICH_THREADS = env("FF_RICH_THREADS", cast=bool, default=
|
|
290
|
+
FF_RICH_THREADS = env("FF_RICH_THREADS", cast=bool, default=True)
|
|
291
291
|
|
|
292
292
|
# auth
|
|
293
293
|
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import asyncio
|
|
3
|
+
import json
|
|
4
|
+
import logging.config
|
|
5
|
+
import pathlib
|
|
6
|
+
|
|
7
|
+
from langgraph_api.queue_entrypoint import main
|
|
8
|
+
|
|
9
|
+
if __name__ == "__main__":
|
|
10
|
+
parser = argparse.ArgumentParser()
|
|
11
|
+
|
|
12
|
+
parser.add_argument("--grpc-port", type=int, default=50051)
|
|
13
|
+
args = parser.parse_args()
|
|
14
|
+
with open(pathlib.Path(__file__).parent.parent / "logging.json") as file:
|
|
15
|
+
loaded_config = json.load(file)
|
|
16
|
+
logging.config.dictConfig(loaded_config)
|
|
17
|
+
try:
|
|
18
|
+
import uvloop # type: ignore[unresolved-import]
|
|
19
|
+
|
|
20
|
+
uvloop.install()
|
|
21
|
+
except ImportError:
|
|
22
|
+
pass
|
|
23
|
+
asyncio.run(main(grpc_port=args.grpc_port, entrypoint_name="python-executor"))
|
langgraph_api/graph.py
CHANGED
|
@@ -9,7 +9,7 @@ import warnings
|
|
|
9
9
|
from collections.abc import AsyncIterator, Callable
|
|
10
10
|
from contextlib import asynccontextmanager
|
|
11
11
|
from itertools import filterfalse
|
|
12
|
-
from typing import TYPE_CHECKING, Any, NamedTuple, cast
|
|
12
|
+
from typing import TYPE_CHECKING, Any, NamedTuple, TypeGuard, cast
|
|
13
13
|
from uuid import UUID, uuid5
|
|
14
14
|
|
|
15
15
|
import orjson
|
|
@@ -35,10 +35,10 @@ logger = structlog.stdlib.get_logger(__name__)
|
|
|
35
35
|
|
|
36
36
|
GraphFactoryFromConfig = Callable[[Config], Pregel | StateGraph]
|
|
37
37
|
GraphFactory = Callable[[], Pregel | StateGraph]
|
|
38
|
-
GraphValue = Pregel | GraphFactory
|
|
38
|
+
GraphValue = Pregel | GraphFactory | GraphFactoryFromConfig
|
|
39
39
|
|
|
40
40
|
|
|
41
|
-
GRAPHS: dict[str,
|
|
41
|
+
GRAPHS: dict[str, GraphValue] = {}
|
|
42
42
|
NAMESPACE_GRAPH = UUID("6ba7b821-9dad-11d1-80b4-00c04fd430c8")
|
|
43
43
|
FACTORY_ACCEPTS_CONFIG: dict[str, bool] = {}
|
|
44
44
|
|
|
@@ -110,11 +110,23 @@ async def _generate_graph(value: Any) -> AsyncIterator[Any]:
|
|
|
110
110
|
yield value
|
|
111
111
|
|
|
112
112
|
|
|
113
|
-
def is_js_graph(graph_id: str) ->
|
|
113
|
+
def is_js_graph(graph_id: str) -> TypeGuard[BaseRemotePregel]:
|
|
114
114
|
"""Return whether a graph is a JS graph."""
|
|
115
115
|
return graph_id in GRAPHS and isinstance(GRAPHS[graph_id], BaseRemotePregel)
|
|
116
116
|
|
|
117
117
|
|
|
118
|
+
def is_factory(
|
|
119
|
+
value: GraphValue, graph_id: str
|
|
120
|
+
) -> TypeGuard[GraphFactoryFromConfig | GraphFactory]:
|
|
121
|
+
return graph_id in FACTORY_ACCEPTS_CONFIG
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def factory_accepts_config(
|
|
125
|
+
value: GraphValue, graph_id: str
|
|
126
|
+
) -> TypeGuard[GraphFactoryFromConfig]:
|
|
127
|
+
return FACTORY_ACCEPTS_CONFIG.get(graph_id, False)
|
|
128
|
+
|
|
129
|
+
|
|
118
130
|
@asynccontextmanager
|
|
119
131
|
async def get_graph(
|
|
120
132
|
graph_id: str,
|
|
@@ -128,7 +140,7 @@ async def get_graph(
|
|
|
128
140
|
|
|
129
141
|
assert_graph_exists(graph_id)
|
|
130
142
|
value = GRAPHS[graph_id]
|
|
131
|
-
if graph_id
|
|
143
|
+
if is_factory(value, graph_id):
|
|
132
144
|
config = lg_config.ensure_config(config)
|
|
133
145
|
|
|
134
146
|
if store is not None:
|
|
@@ -139,6 +151,8 @@ async def get_graph(
|
|
|
139
151
|
runtime = config["configurable"].get(CONFIG_KEY_RUNTIME)
|
|
140
152
|
if runtime is None:
|
|
141
153
|
patched_runtime = Runtime(store=store)
|
|
154
|
+
elif isinstance(runtime, dict):
|
|
155
|
+
patched_runtime = Runtime(**(runtime | {"store": store}))
|
|
142
156
|
elif runtime.store is None:
|
|
143
157
|
patched_runtime = cast(Runtime, runtime).override(store=store)
|
|
144
158
|
else:
|
|
@@ -156,7 +170,7 @@ async def get_graph(
|
|
|
156
170
|
):
|
|
157
171
|
config["configurable"][CONFIG_KEY_CHECKPOINTER] = checkpointer
|
|
158
172
|
var_child_runnable_config.set(config)
|
|
159
|
-
value = value(config) if
|
|
173
|
+
value = value(config) if factory_accepts_config(value, graph_id) else value()
|
|
160
174
|
try:
|
|
161
175
|
async with _generate_graph(value) as graph_obj:
|
|
162
176
|
if isinstance(graph_obj, StateGraph):
|
|
@@ -451,7 +465,7 @@ def _graph_from_spec(spec: GraphSpec) -> GraphValue:
|
|
|
451
465
|
raise ValueError(f"Could not find python file for graph: {spec}")
|
|
452
466
|
module = importlib.util.module_from_spec(modspec)
|
|
453
467
|
sys.modules[modname] = module
|
|
454
|
-
modspec.loader.exec_module(module)
|
|
468
|
+
modspec.loader.exec_module(module) # type: ignore[possibly-unbound-attribute]
|
|
455
469
|
except ImportError as e:
|
|
456
470
|
e.add_note(f"Could not import python module for graph:\n{spec}")
|
|
457
471
|
if config.API_VARIANT == "local_dev":
|
|
@@ -565,7 +579,9 @@ def _graph_from_spec(spec: GraphSpec) -> GraphValue:
|
|
|
565
579
|
@functools.lru_cache(maxsize=1)
|
|
566
580
|
def _get_init_embeddings() -> Callable[[str, ...], "Embeddings"] | None:
|
|
567
581
|
try:
|
|
568
|
-
from langchain.embeddings import
|
|
582
|
+
from langchain.embeddings import ( # type: ignore[unresolved-import]
|
|
583
|
+
init_embeddings,
|
|
584
|
+
)
|
|
569
585
|
|
|
570
586
|
return init_embeddings
|
|
571
587
|
except ImportError:
|
|
@@ -606,7 +622,7 @@ def resolve_embeddings(index_config: dict) -> "Embeddings":
|
|
|
606
622
|
raise ValueError(f"Could not find embeddings file: {module_name}")
|
|
607
623
|
module = importlib.util.module_from_spec(modspec)
|
|
608
624
|
sys.modules[modname] = module
|
|
609
|
-
modspec.loader.exec_module(module)
|
|
625
|
+
modspec.loader.exec_module(module) # type: ignore[possibly-unbound-attribute]
|
|
610
626
|
else:
|
|
611
627
|
# Load from Python module
|
|
612
628
|
module = importlib.import_module(module_name)
|