langgraph-api 0.4.1__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_api/__init__.py +1 -1
- langgraph_api/api/__init__.py +111 -51
- langgraph_api/api/a2a.py +1610 -0
- langgraph_api/api/assistants.py +212 -89
- langgraph_api/api/mcp.py +3 -3
- langgraph_api/api/meta.py +52 -28
- langgraph_api/api/openapi.py +27 -17
- langgraph_api/api/profile.py +108 -0
- langgraph_api/api/runs.py +342 -195
- langgraph_api/api/store.py +19 -2
- langgraph_api/api/threads.py +209 -27
- langgraph_api/asgi_transport.py +14 -9
- langgraph_api/asyncio.py +14 -4
- langgraph_api/auth/custom.py +52 -37
- langgraph_api/auth/langsmith/backend.py +4 -3
- langgraph_api/auth/langsmith/client.py +13 -8
- langgraph_api/cli.py +230 -133
- langgraph_api/command.py +5 -3
- langgraph_api/config/__init__.py +532 -0
- langgraph_api/config/_parse.py +58 -0
- langgraph_api/config/schemas.py +431 -0
- langgraph_api/cron_scheduler.py +17 -1
- langgraph_api/encryption/__init__.py +15 -0
- langgraph_api/encryption/aes_json.py +158 -0
- langgraph_api/encryption/context.py +35 -0
- langgraph_api/encryption/custom.py +280 -0
- langgraph_api/encryption/middleware.py +632 -0
- langgraph_api/encryption/shared.py +63 -0
- langgraph_api/errors.py +12 -1
- langgraph_api/executor_entrypoint.py +11 -6
- langgraph_api/feature_flags.py +29 -0
- langgraph_api/graph.py +176 -76
- langgraph_api/grpc/client.py +313 -0
- langgraph_api/grpc/config_conversion.py +231 -0
- langgraph_api/grpc/generated/__init__.py +29 -0
- langgraph_api/grpc/generated/checkpointer_pb2.py +63 -0
- langgraph_api/grpc/generated/checkpointer_pb2.pyi +99 -0
- langgraph_api/grpc/generated/checkpointer_pb2_grpc.py +329 -0
- langgraph_api/grpc/generated/core_api_pb2.py +216 -0
- langgraph_api/grpc/generated/core_api_pb2.pyi +905 -0
- langgraph_api/grpc/generated/core_api_pb2_grpc.py +1621 -0
- langgraph_api/grpc/generated/engine_common_pb2.py +219 -0
- langgraph_api/grpc/generated/engine_common_pb2.pyi +722 -0
- langgraph_api/grpc/generated/engine_common_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_cancel_run_action_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_cancel_run_action_pb2.pyi +12 -0
- langgraph_api/grpc/generated/enum_cancel_run_action_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_control_signal_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_control_signal_pb2.pyi +16 -0
- langgraph_api/grpc/generated/enum_control_signal_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_durability_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_durability_pb2.pyi +16 -0
- langgraph_api/grpc/generated/enum_durability_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_multitask_strategy_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_multitask_strategy_pb2.pyi +16 -0
- langgraph_api/grpc/generated/enum_multitask_strategy_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_run_status_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_run_status_pb2.pyi +22 -0
- langgraph_api/grpc/generated/enum_run_status_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_stream_mode_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_stream_mode_pb2.pyi +28 -0
- langgraph_api/grpc/generated/enum_stream_mode_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_thread_status_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_thread_status_pb2.pyi +16 -0
- langgraph_api/grpc/generated/enum_thread_status_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.pyi +16 -0
- langgraph_api/grpc/generated/enum_thread_stream_mode_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/errors_pb2.py +39 -0
- langgraph_api/grpc/generated/errors_pb2.pyi +21 -0
- langgraph_api/grpc/generated/errors_pb2_grpc.py +24 -0
- langgraph_api/grpc/ops/__init__.py +370 -0
- langgraph_api/grpc/ops/assistants.py +424 -0
- langgraph_api/grpc/ops/runs.py +792 -0
- langgraph_api/grpc/ops/threads.py +1013 -0
- langgraph_api/http.py +16 -5
- langgraph_api/http_metrics.py +15 -35
- langgraph_api/http_metrics_utils.py +38 -0
- langgraph_api/js/build.mts +1 -1
- langgraph_api/js/client.http.mts +13 -7
- langgraph_api/js/client.mts +2 -5
- langgraph_api/js/package.json +29 -28
- langgraph_api/js/remote.py +56 -30
- langgraph_api/js/src/graph.mts +20 -0
- langgraph_api/js/sse.py +2 -2
- langgraph_api/js/ui.py +1 -1
- langgraph_api/js/yarn.lock +1204 -1006
- langgraph_api/logging.py +29 -2
- langgraph_api/metadata.py +99 -28
- langgraph_api/middleware/http_logger.py +7 -2
- langgraph_api/middleware/private_network.py +7 -7
- langgraph_api/models/run.py +54 -93
- langgraph_api/otel_context.py +205 -0
- langgraph_api/patch.py +5 -3
- langgraph_api/queue_entrypoint.py +154 -65
- langgraph_api/route.py +47 -5
- langgraph_api/schema.py +88 -10
- langgraph_api/self_hosted_logs.py +124 -0
- langgraph_api/self_hosted_metrics.py +450 -0
- langgraph_api/serde.py +79 -37
- langgraph_api/server.py +138 -60
- langgraph_api/state.py +4 -3
- langgraph_api/store.py +25 -16
- langgraph_api/stream.py +80 -29
- langgraph_api/thread_ttl.py +31 -13
- langgraph_api/timing/__init__.py +25 -0
- langgraph_api/timing/profiler.py +200 -0
- langgraph_api/timing/timer.py +318 -0
- langgraph_api/utils/__init__.py +53 -8
- langgraph_api/utils/cache.py +47 -10
- langgraph_api/utils/config.py +2 -1
- langgraph_api/utils/errors.py +77 -0
- langgraph_api/utils/future.py +10 -6
- langgraph_api/utils/headers.py +76 -2
- langgraph_api/utils/retriable_client.py +74 -0
- langgraph_api/utils/stream_codec.py +315 -0
- langgraph_api/utils/uuids.py +29 -62
- langgraph_api/validation.py +9 -0
- langgraph_api/webhook.py +120 -6
- langgraph_api/worker.py +55 -24
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/METADATA +16 -8
- langgraph_api-0.7.3.dist-info/RECORD +168 -0
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/WHEEL +1 -1
- langgraph_runtime/__init__.py +1 -0
- langgraph_runtime/routes.py +11 -0
- logging.json +1 -3
- openapi.json +839 -478
- langgraph_api/config.py +0 -387
- langgraph_api/js/isolate-0x130008000-46649-46649-v8.log +0 -4430
- langgraph_api/js/isolate-0x138008000-44681-44681-v8.log +0 -4430
- langgraph_api/js/package-lock.json +0 -3308
- langgraph_api-0.4.1.dist-info/RECORD +0 -107
- /langgraph_api/{utils.py → grpc/__init__.py} +0 -0
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/entry_points.txt +0 -0
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/licenses/LICENSE +0 -0
langgraph_api/api/store.py
CHANGED
|
@@ -5,7 +5,13 @@ from starlette.responses import Response
|
|
|
5
5
|
from starlette.routing import BaseRoute
|
|
6
6
|
|
|
7
7
|
from langgraph_api.auth.custom import handle_event as _handle_event
|
|
8
|
+
from langgraph_api.encryption.middleware import (
|
|
9
|
+
decrypt_response,
|
|
10
|
+
decrypt_responses,
|
|
11
|
+
encrypt_request,
|
|
12
|
+
)
|
|
8
13
|
from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
|
|
14
|
+
from langgraph_api.schema import STORE_ENCRYPTION_FIELDS
|
|
9
15
|
from langgraph_api.store import get_store
|
|
10
16
|
from langgraph_api.utils import get_auth_ctx
|
|
11
17
|
from langgraph_api.validation import (
|
|
@@ -48,6 +54,7 @@ async def handle_event(
|
|
|
48
54
|
async def put_item(request: ApiRequest):
|
|
49
55
|
"""Store or update an item."""
|
|
50
56
|
payload = await request.json(StorePutRequest)
|
|
57
|
+
payload = await encrypt_request(payload, "store", STORE_ENCRYPTION_FIELDS)
|
|
51
58
|
namespace = tuple(payload["namespace"]) if payload.get("namespace") else ()
|
|
52
59
|
if err := _validate_namespace(namespace):
|
|
53
60
|
return err
|
|
@@ -78,7 +85,11 @@ async def get_item(request: ApiRequest):
|
|
|
78
85
|
}
|
|
79
86
|
await handle_event("get", handler_payload)
|
|
80
87
|
result = await (await get_store()).aget(namespace, key)
|
|
81
|
-
|
|
88
|
+
if result is None:
|
|
89
|
+
return ApiResponse(None)
|
|
90
|
+
return ApiResponse(
|
|
91
|
+
await decrypt_response(result.dict(), "store", STORE_ENCRYPTION_FIELDS)
|
|
92
|
+
)
|
|
82
93
|
|
|
83
94
|
|
|
84
95
|
@retry_db
|
|
@@ -125,7 +136,13 @@ async def search_items(request: ApiRequest):
|
|
|
125
136
|
offset=handler_payload["offset"],
|
|
126
137
|
query=handler_payload["query"],
|
|
127
138
|
)
|
|
128
|
-
return ApiResponse(
|
|
139
|
+
return ApiResponse(
|
|
140
|
+
{
|
|
141
|
+
"items": await decrypt_responses(
|
|
142
|
+
[item.dict() for item in items], "store", STORE_ENCRYPTION_FIELDS
|
|
143
|
+
)
|
|
144
|
+
}
|
|
145
|
+
)
|
|
129
146
|
|
|
130
147
|
|
|
131
148
|
@retry_db
|
langgraph_api/api/threads.py
CHANGED
|
@@ -1,11 +1,23 @@
|
|
|
1
|
+
from typing import get_args
|
|
1
2
|
from uuid import uuid4
|
|
2
3
|
|
|
3
4
|
from starlette.exceptions import HTTPException
|
|
4
5
|
from starlette.responses import Response
|
|
5
6
|
from starlette.routing import BaseRoute
|
|
6
7
|
|
|
8
|
+
from langgraph_api.encryption.middleware import (
|
|
9
|
+
decrypt_response,
|
|
10
|
+
decrypt_responses,
|
|
11
|
+
encrypt_request,
|
|
12
|
+
)
|
|
13
|
+
from langgraph_api.feature_flags import FF_USE_CORE_API
|
|
14
|
+
from langgraph_api.grpc.ops import Threads as GrpcThreads
|
|
7
15
|
from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
|
|
8
|
-
from langgraph_api.schema import
|
|
16
|
+
from langgraph_api.schema import (
|
|
17
|
+
THREAD_ENCRYPTION_FIELDS,
|
|
18
|
+
THREAD_FIELDS,
|
|
19
|
+
ThreadStreamMode,
|
|
20
|
+
)
|
|
9
21
|
from langgraph_api.sse import EventSourceResponse
|
|
10
22
|
from langgraph_api.state import state_snapshot_to_thread_state
|
|
11
23
|
from langgraph_api.utils import (
|
|
@@ -15,10 +27,12 @@ from langgraph_api.utils import (
|
|
|
15
27
|
validate_stream_id,
|
|
16
28
|
validate_uuid,
|
|
17
29
|
)
|
|
30
|
+
from langgraph_api.utils.headers import get_configurable_headers
|
|
18
31
|
from langgraph_api.validation import (
|
|
19
32
|
ThreadCountRequest,
|
|
20
33
|
ThreadCreate,
|
|
21
34
|
ThreadPatch,
|
|
35
|
+
ThreadPruneRequest,
|
|
22
36
|
ThreadSearchRequest,
|
|
23
37
|
ThreadStateCheckpointRequest,
|
|
24
38
|
ThreadStateSearch,
|
|
@@ -28,6 +42,8 @@ from langgraph_runtime.database import connect
|
|
|
28
42
|
from langgraph_runtime.ops import Threads
|
|
29
43
|
from langgraph_runtime.retry import retry_db
|
|
30
44
|
|
|
45
|
+
CrudThreads = GrpcThreads if FF_USE_CORE_API else Threads
|
|
46
|
+
|
|
31
47
|
|
|
32
48
|
@retry_db
|
|
33
49
|
async def create_thread(
|
|
@@ -37,28 +53,56 @@ async def create_thread(
|
|
|
37
53
|
payload = await request.json(ThreadCreate)
|
|
38
54
|
if thread_id := payload.get("thread_id"):
|
|
39
55
|
validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
|
|
56
|
+
|
|
57
|
+
# Validate keep_latest TTL requires core API
|
|
58
|
+
ttl = payload.get("ttl")
|
|
59
|
+
if ttl and ttl.get("strategy") == "keep_latest" and not FF_USE_CORE_API:
|
|
60
|
+
raise HTTPException(
|
|
61
|
+
status_code=422,
|
|
62
|
+
detail="keep_latest TTL strategy requires FF_USE_CORE_API=true",
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
# Encrypt metadata before storing
|
|
66
|
+
encrypted_payload = await encrypt_request(
|
|
67
|
+
payload,
|
|
68
|
+
"thread",
|
|
69
|
+
["metadata"],
|
|
70
|
+
)
|
|
71
|
+
|
|
40
72
|
async with connect() as conn:
|
|
41
73
|
thread_id = thread_id or str(uuid4())
|
|
42
|
-
iter = await
|
|
74
|
+
iter = await CrudThreads.put(
|
|
43
75
|
conn,
|
|
44
76
|
thread_id,
|
|
45
|
-
metadata=
|
|
77
|
+
metadata=encrypted_payload.get("metadata") or {},
|
|
46
78
|
if_exists=payload.get("if_exists") or "raise",
|
|
47
79
|
ttl=payload.get("ttl"),
|
|
48
80
|
)
|
|
49
|
-
|
|
81
|
+
config = {
|
|
82
|
+
"configurable": {
|
|
83
|
+
**get_configurable_headers(request.headers),
|
|
84
|
+
"thread_id": thread_id,
|
|
85
|
+
}
|
|
86
|
+
}
|
|
50
87
|
if supersteps := payload.get("supersteps"):
|
|
51
88
|
try:
|
|
52
89
|
await Threads.State.bulk(
|
|
53
90
|
conn,
|
|
54
|
-
config=
|
|
91
|
+
config=config,
|
|
55
92
|
supersteps=supersteps,
|
|
56
93
|
)
|
|
57
94
|
except HTTPException as e:
|
|
58
95
|
detail = f"Thread {thread_id} was created, but there were problems updating the state: {e.detail}"
|
|
59
96
|
raise HTTPException(status_code=201, detail=detail) from e
|
|
60
97
|
|
|
61
|
-
|
|
98
|
+
# Decrypt thread fields in response
|
|
99
|
+
thread = await fetchone(iter, not_found_code=409)
|
|
100
|
+
thread = await decrypt_response(
|
|
101
|
+
thread,
|
|
102
|
+
"thread",
|
|
103
|
+
THREAD_ENCRYPTION_FIELDS,
|
|
104
|
+
)
|
|
105
|
+
return ApiResponse(thread)
|
|
62
106
|
|
|
63
107
|
|
|
64
108
|
@retry_db
|
|
@@ -70,12 +114,14 @@ async def search_threads(
|
|
|
70
114
|
select = validate_select_columns(payload.get("select") or None, THREAD_FIELDS)
|
|
71
115
|
limit = int(payload.get("limit") or 10)
|
|
72
116
|
offset = int(payload.get("offset") or 0)
|
|
117
|
+
|
|
73
118
|
async with connect() as conn:
|
|
74
|
-
threads_iter, next_offset = await
|
|
119
|
+
threads_iter, next_offset = await CrudThreads.search(
|
|
75
120
|
conn,
|
|
76
121
|
status=payload.get("status"),
|
|
77
122
|
values=payload.get("values"),
|
|
78
123
|
metadata=payload.get("metadata"),
|
|
124
|
+
ids=payload.get("ids"),
|
|
79
125
|
limit=limit,
|
|
80
126
|
offset=offset,
|
|
81
127
|
sort_by=payload.get("sort_by"),
|
|
@@ -85,7 +131,15 @@ async def search_threads(
|
|
|
85
131
|
threads, response_headers = await get_pagination_headers(
|
|
86
132
|
threads_iter, next_offset, offset
|
|
87
133
|
)
|
|
88
|
-
|
|
134
|
+
|
|
135
|
+
# Decrypt metadata, values, interrupts, and error in all returned threads
|
|
136
|
+
decrypted_threads = await decrypt_responses(
|
|
137
|
+
threads,
|
|
138
|
+
"thread",
|
|
139
|
+
THREAD_ENCRYPTION_FIELDS,
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
return ApiResponse(decrypted_threads, headers=response_headers)
|
|
89
143
|
|
|
90
144
|
|
|
91
145
|
@retry_db
|
|
@@ -95,7 +149,7 @@ async def count_threads(
|
|
|
95
149
|
"""Count threads."""
|
|
96
150
|
payload = await request.json(ThreadCountRequest)
|
|
97
151
|
async with connect() as conn:
|
|
98
|
-
count = await
|
|
152
|
+
count = await CrudThreads.count(
|
|
99
153
|
conn,
|
|
100
154
|
status=payload.get("status"),
|
|
101
155
|
values=payload.get("values"),
|
|
@@ -113,10 +167,14 @@ async def get_thread_state(
|
|
|
113
167
|
validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
|
|
114
168
|
subgraphs = request.query_params.get("subgraphs") in ("true", "True")
|
|
115
169
|
async with connect() as conn:
|
|
170
|
+
config = {
|
|
171
|
+
"configurable": {
|
|
172
|
+
**get_configurable_headers(request.headers),
|
|
173
|
+
"thread_id": thread_id,
|
|
174
|
+
}
|
|
175
|
+
}
|
|
116
176
|
state = state_snapshot_to_thread_state(
|
|
117
|
-
await Threads.State.get(
|
|
118
|
-
conn, {"configurable": {"thread_id": thread_id}}, subgraphs=subgraphs
|
|
119
|
-
)
|
|
177
|
+
await Threads.State.get(conn, config=config, subgraphs=subgraphs)
|
|
120
178
|
)
|
|
121
179
|
return ApiResponse(state)
|
|
122
180
|
|
|
@@ -130,15 +188,17 @@ async def get_thread_state_at_checkpoint(
|
|
|
130
188
|
validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
|
|
131
189
|
checkpoint_id = request.path_params["checkpoint_id"]
|
|
132
190
|
async with connect() as conn:
|
|
191
|
+
config = {
|
|
192
|
+
"configurable": {
|
|
193
|
+
**get_configurable_headers(request.headers),
|
|
194
|
+
"thread_id": thread_id,
|
|
195
|
+
"checkpoint_id": checkpoint_id,
|
|
196
|
+
}
|
|
197
|
+
}
|
|
133
198
|
state = state_snapshot_to_thread_state(
|
|
134
199
|
await Threads.State.get(
|
|
135
200
|
conn,
|
|
136
|
-
|
|
137
|
-
"configurable": {
|
|
138
|
-
"thread_id": thread_id,
|
|
139
|
-
"checkpoint_id": checkpoint_id,
|
|
140
|
-
}
|
|
141
|
-
},
|
|
201
|
+
config=config,
|
|
142
202
|
subgraphs=request.query_params.get("subgraphs") in ("true", "True"),
|
|
143
203
|
)
|
|
144
204
|
)
|
|
@@ -154,10 +214,17 @@ async def get_thread_state_at_checkpoint_post(
|
|
|
154
214
|
validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
|
|
155
215
|
payload = await request.json(ThreadStateCheckpointRequest)
|
|
156
216
|
async with connect() as conn:
|
|
217
|
+
config = {
|
|
218
|
+
"configurable": {
|
|
219
|
+
**payload["checkpoint"],
|
|
220
|
+
**get_configurable_headers(request.headers),
|
|
221
|
+
"thread_id": thread_id,
|
|
222
|
+
}
|
|
223
|
+
}
|
|
157
224
|
state = state_snapshot_to_thread_state(
|
|
158
225
|
await Threads.State.get(
|
|
159
226
|
conn,
|
|
160
|
-
|
|
227
|
+
config=config,
|
|
161
228
|
subgraphs=payload.get("subgraphs", False),
|
|
162
229
|
)
|
|
163
230
|
)
|
|
@@ -182,6 +249,7 @@ async def update_thread_state(
|
|
|
182
249
|
config["configurable"]["user_id"] = user_id
|
|
183
250
|
except AssertionError:
|
|
184
251
|
pass
|
|
252
|
+
config["configurable"].update(get_configurable_headers(request.headers))
|
|
185
253
|
async with connect() as conn:
|
|
186
254
|
inserted = await Threads.State.post(
|
|
187
255
|
conn,
|
|
@@ -205,7 +273,13 @@ async def get_thread_history(
|
|
|
205
273
|
except ValueError:
|
|
206
274
|
raise HTTPException(status_code=422, detail=f"Invalid limit {limit_}") from None
|
|
207
275
|
before = request.query_params.get("before")
|
|
208
|
-
config = {
|
|
276
|
+
config = {
|
|
277
|
+
"configurable": {
|
|
278
|
+
"thread_id": thread_id,
|
|
279
|
+
"checkpoint_ns": "",
|
|
280
|
+
**get_configurable_headers(request.headers),
|
|
281
|
+
}
|
|
282
|
+
}
|
|
209
283
|
async with connect() as conn:
|
|
210
284
|
states = [
|
|
211
285
|
state_snapshot_to_thread_state(c)
|
|
@@ -226,6 +300,7 @@ async def get_thread_history_post(
|
|
|
226
300
|
payload = await request.json(ThreadStateSearch)
|
|
227
301
|
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
|
228
302
|
config["configurable"].update(payload.get("checkpoint", {}))
|
|
303
|
+
config["configurable"].update(get_configurable_headers(request.headers))
|
|
229
304
|
async with connect() as conn:
|
|
230
305
|
states = [
|
|
231
306
|
state_snapshot_to_thread_state(c)
|
|
@@ -247,9 +322,23 @@ async def get_thread(
|
|
|
247
322
|
"""Get a thread by ID."""
|
|
248
323
|
thread_id = request.path_params["thread_id"]
|
|
249
324
|
validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
|
|
325
|
+
|
|
326
|
+
# Parse include parameter for optional fields (e.g., ttl)
|
|
327
|
+
include_param = request.query_params.get("include", "")
|
|
328
|
+
include_fields = [f.strip() for f in include_param.split(",") if f.strip()]
|
|
329
|
+
include_ttl = "ttl" in include_fields
|
|
330
|
+
|
|
250
331
|
async with connect() as conn:
|
|
251
|
-
thread = await
|
|
252
|
-
|
|
332
|
+
thread = await CrudThreads.get(conn, thread_id, include_ttl=include_ttl)
|
|
333
|
+
|
|
334
|
+
# Decrypt metadata, values, interrupts, and error in response
|
|
335
|
+
thread_data = await fetchone(thread)
|
|
336
|
+
thread_data = await decrypt_response(
|
|
337
|
+
thread_data,
|
|
338
|
+
"thread",
|
|
339
|
+
THREAD_ENCRYPTION_FIELDS,
|
|
340
|
+
)
|
|
341
|
+
return ApiResponse(thread_data)
|
|
253
342
|
|
|
254
343
|
|
|
255
344
|
@retry_db
|
|
@@ -260,9 +349,37 @@ async def patch_thread(
|
|
|
260
349
|
thread_id = request.path_params["thread_id"]
|
|
261
350
|
validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
|
|
262
351
|
payload = await request.json(ThreadPatch)
|
|
352
|
+
|
|
353
|
+
# Validate keep_latest TTL requires core API
|
|
354
|
+
ttl = payload.get("ttl")
|
|
355
|
+
if ttl and ttl.get("strategy") == "keep_latest" and not FF_USE_CORE_API:
|
|
356
|
+
raise HTTPException(
|
|
357
|
+
status_code=422,
|
|
358
|
+
detail="keep_latest TTL strategy requires FF_USE_CORE_API=true",
|
|
359
|
+
)
|
|
360
|
+
|
|
361
|
+
# Encrypt metadata before storing
|
|
362
|
+
encrypted_payload = await encrypt_request(
|
|
363
|
+
payload,
|
|
364
|
+
"thread",
|
|
365
|
+
["metadata"],
|
|
366
|
+
)
|
|
367
|
+
|
|
263
368
|
async with connect() as conn:
|
|
264
|
-
thread = await
|
|
265
|
-
|
|
369
|
+
thread = await CrudThreads.patch(
|
|
370
|
+
conn,
|
|
371
|
+
thread_id,
|
|
372
|
+
metadata=encrypted_payload.get("metadata") or {},
|
|
373
|
+
ttl=payload.get("ttl"),
|
|
374
|
+
)
|
|
375
|
+
thread_data = await fetchone(thread)
|
|
376
|
+
# Decrypt metadata, values, interrupts, and error in response
|
|
377
|
+
thread_data = await decrypt_response(
|
|
378
|
+
thread_data,
|
|
379
|
+
"thread",
|
|
380
|
+
THREAD_ENCRYPTION_FIELDS,
|
|
381
|
+
)
|
|
382
|
+
return ApiResponse(thread_data)
|
|
266
383
|
|
|
267
384
|
|
|
268
385
|
@retry_db
|
|
@@ -271,17 +388,60 @@ async def delete_thread(request: ApiRequest):
|
|
|
271
388
|
thread_id = request.path_params["thread_id"]
|
|
272
389
|
validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
|
|
273
390
|
async with connect() as conn:
|
|
274
|
-
tid = await
|
|
391
|
+
tid = await CrudThreads.delete(conn, thread_id)
|
|
275
392
|
await fetchone(tid)
|
|
276
393
|
return Response(status_code=204)
|
|
277
394
|
|
|
278
395
|
|
|
396
|
+
@retry_db
|
|
397
|
+
async def prune_threads(request: ApiRequest):
|
|
398
|
+
"""Prune threads by ID."""
|
|
399
|
+
payload = await request.json(ThreadPruneRequest)
|
|
400
|
+
thread_ids = payload.get("thread_ids", [])
|
|
401
|
+
strategy = payload.get("strategy", "delete")
|
|
402
|
+
|
|
403
|
+
# Validate each thread_id is a valid UUID
|
|
404
|
+
for tid in thread_ids:
|
|
405
|
+
validate_uuid(tid, "Invalid thread ID: must be a UUID")
|
|
406
|
+
|
|
407
|
+
# Validate strategy
|
|
408
|
+
if strategy not in ("delete", "keep_latest"):
|
|
409
|
+
raise HTTPException(
|
|
410
|
+
status_code=422,
|
|
411
|
+
detail=f"Invalid strategy: {strategy}. Expected 'delete' or 'keep_latest'.",
|
|
412
|
+
)
|
|
413
|
+
|
|
414
|
+
# Empty list is a no-op, return early
|
|
415
|
+
if not thread_ids:
|
|
416
|
+
return ApiResponse({"pruned_count": 0})
|
|
417
|
+
|
|
418
|
+
if not FF_USE_CORE_API:
|
|
419
|
+
raise HTTPException(
|
|
420
|
+
status_code=422,
|
|
421
|
+
detail="Thread prune requires FF_USE_CORE_API=true",
|
|
422
|
+
)
|
|
423
|
+
|
|
424
|
+
pruned_count = await CrudThreads.prune(
|
|
425
|
+
thread_ids=thread_ids,
|
|
426
|
+
strategy=strategy,
|
|
427
|
+
)
|
|
428
|
+
|
|
429
|
+
return ApiResponse({"pruned_count": pruned_count})
|
|
430
|
+
|
|
431
|
+
|
|
279
432
|
@retry_db
|
|
280
433
|
async def copy_thread(request: ApiRequest):
|
|
281
434
|
thread_id = request.path_params["thread_id"]
|
|
282
435
|
async with connect() as conn:
|
|
283
|
-
iter = await
|
|
284
|
-
|
|
436
|
+
iter = await CrudThreads.copy(conn, thread_id)
|
|
437
|
+
thread_data = await fetchone(iter, not_found_code=409)
|
|
438
|
+
# Decrypt metadata, values, interrupts, and error in response
|
|
439
|
+
thread_data = await decrypt_response(
|
|
440
|
+
thread_data,
|
|
441
|
+
"thread",
|
|
442
|
+
THREAD_ENCRYPTION_FIELDS,
|
|
443
|
+
)
|
|
444
|
+
return ApiResponse(thread_data)
|
|
285
445
|
|
|
286
446
|
|
|
287
447
|
@retry_db
|
|
@@ -293,10 +453,31 @@ async def join_thread_stream(request: ApiRequest):
|
|
|
293
453
|
validate_stream_id(
|
|
294
454
|
last_event_id, "Invalid last-event-id: must be a valid Redis stream ID"
|
|
295
455
|
)
|
|
456
|
+
|
|
457
|
+
# Parse stream_modes parameter - can be single string or comma-separated list
|
|
458
|
+
stream_modes_param = request.query_params.get("stream_modes")
|
|
459
|
+
if stream_modes_param:
|
|
460
|
+
if "," in stream_modes_param:
|
|
461
|
+
# Handle comma-separated list
|
|
462
|
+
stream_modes = [mode.strip() for mode in stream_modes_param.split(",")]
|
|
463
|
+
else:
|
|
464
|
+
# Handle single value
|
|
465
|
+
stream_modes = [stream_modes_param]
|
|
466
|
+
# Validate each mode
|
|
467
|
+
for mode in stream_modes:
|
|
468
|
+
if mode not in get_args(ThreadStreamMode):
|
|
469
|
+
raise HTTPException(
|
|
470
|
+
status_code=422, detail=f"Invalid stream mode: {mode}"
|
|
471
|
+
)
|
|
472
|
+
else:
|
|
473
|
+
# Default to run_modes
|
|
474
|
+
stream_modes = ["run_modes"]
|
|
475
|
+
|
|
296
476
|
return EventSourceResponse(
|
|
297
477
|
Threads.Stream.join(
|
|
298
478
|
thread_id,
|
|
299
479
|
last_event_id=last_event_id,
|
|
480
|
+
stream_modes=stream_modes,
|
|
300
481
|
),
|
|
301
482
|
)
|
|
302
483
|
|
|
@@ -305,6 +486,7 @@ threads_routes: list[BaseRoute] = [
|
|
|
305
486
|
ApiRoute("/threads", endpoint=create_thread, methods=["POST"]),
|
|
306
487
|
ApiRoute("/threads/search", endpoint=search_threads, methods=["POST"]),
|
|
307
488
|
ApiRoute("/threads/count", endpoint=count_threads, methods=["POST"]),
|
|
489
|
+
ApiRoute("/threads/prune", endpoint=prune_threads, methods=["POST"]),
|
|
308
490
|
ApiRoute("/threads/{thread_id}", endpoint=get_thread, methods=["GET"]),
|
|
309
491
|
ApiRoute("/threads/{thread_id}", endpoint=patch_thread, methods=["PATCH"]),
|
|
310
492
|
ApiRoute("/threads/{thread_id}", endpoint=delete_thread, methods=["DELETE"]),
|
langgraph_api/asgi_transport.py
CHANGED
|
@@ -25,7 +25,7 @@ def is_running_trio() -> bool:
|
|
|
25
25
|
# sniffio is a dependency of trio.
|
|
26
26
|
|
|
27
27
|
# See https://github.com/python-trio/trio/issues/2802
|
|
28
|
-
import sniffio
|
|
28
|
+
import sniffio # type: ignore[unresolved-import]
|
|
29
29
|
|
|
30
30
|
if sniffio.current_async_library() == "trio":
|
|
31
31
|
return True
|
|
@@ -84,7 +84,8 @@ class ASGITransport(ASGITransportBase):
|
|
|
84
84
|
) -> Response:
|
|
85
85
|
from langgraph_api.asyncio import call_soon_in_main_loop
|
|
86
86
|
|
|
87
|
-
|
|
87
|
+
if not isinstance(request.stream, AsyncByteStream):
|
|
88
|
+
raise ValueError("Request stream must be an AsyncByteStream")
|
|
88
89
|
|
|
89
90
|
# ASGI scope.
|
|
90
91
|
scope = {
|
|
@@ -133,14 +134,15 @@ class ASGITransport(ASGITransportBase):
|
|
|
133
134
|
nonlocal status_code, response_headers, response_started
|
|
134
135
|
|
|
135
136
|
if message["type"] == "http.response.start":
|
|
136
|
-
|
|
137
|
-
|
|
137
|
+
if response_started:
|
|
138
|
+
raise RuntimeError("Response already started")
|
|
138
139
|
status_code = message["status"]
|
|
139
140
|
response_headers = message.get("headers", [])
|
|
140
141
|
response_started = True
|
|
141
142
|
|
|
142
143
|
elif message["type"] == "http.response.body":
|
|
143
|
-
|
|
144
|
+
if response_complete.is_set():
|
|
145
|
+
raise RuntimeError("Response already complete")
|
|
144
146
|
body = message.get("body", b"")
|
|
145
147
|
more_body = message.get("more_body", False)
|
|
146
148
|
|
|
@@ -152,7 +154,7 @@ class ASGITransport(ASGITransportBase):
|
|
|
152
154
|
|
|
153
155
|
try:
|
|
154
156
|
await call_soon_in_main_loop(self.app(scope, receive, send))
|
|
155
|
-
except Exception:
|
|
157
|
+
except Exception:
|
|
156
158
|
if self.raise_app_exceptions:
|
|
157
159
|
raise
|
|
158
160
|
|
|
@@ -162,9 +164,12 @@ class ASGITransport(ASGITransportBase):
|
|
|
162
164
|
if response_headers is None:
|
|
163
165
|
response_headers = {}
|
|
164
166
|
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
167
|
+
if not response_complete.is_set():
|
|
168
|
+
raise RuntimeError("Response not complete")
|
|
169
|
+
if status_code is None:
|
|
170
|
+
raise RuntimeError("Status code not set")
|
|
171
|
+
if response_headers is None:
|
|
172
|
+
raise RuntimeError("Response headers not set")
|
|
168
173
|
|
|
169
174
|
stream = ASGIResponseStream(body_parts)
|
|
170
175
|
|
langgraph_api/asyncio.py
CHANGED
|
@@ -115,11 +115,16 @@ def create_task(
|
|
|
115
115
|
|
|
116
116
|
|
|
117
117
|
def run_coroutine_threadsafe(
|
|
118
|
-
coro: Coroutine[Any, Any, T],
|
|
118
|
+
coro: Coroutine[Any, Any, T],
|
|
119
|
+
ignore_exceptions: tuple[type[Exception], ...] = (),
|
|
120
|
+
*,
|
|
121
|
+
loop: asyncio.AbstractEventLoop | None = None,
|
|
119
122
|
) -> concurrent.futures.Future[T] | concurrent.futures.Future[None]:
|
|
120
|
-
if
|
|
123
|
+
if loop is None:
|
|
124
|
+
loop = _MAIN_LOOP
|
|
125
|
+
if loop is None:
|
|
121
126
|
raise RuntimeError("No event loop set")
|
|
122
|
-
future = asyncio.run_coroutine_threadsafe(coro,
|
|
127
|
+
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
|
123
128
|
future.add_done_callback(partial(_create_task_done_callback, ignore_exceptions))
|
|
124
129
|
return future
|
|
125
130
|
|
|
@@ -158,12 +163,15 @@ class SimpleTaskGroup(AbstractAsyncContextManager["SimpleTaskGroup"]):
|
|
|
158
163
|
self,
|
|
159
164
|
*coros: Coroutine[Any, Any, T],
|
|
160
165
|
cancel: bool = False,
|
|
166
|
+
cancel_event: asyncio.Event | None = None,
|
|
161
167
|
wait: bool = True,
|
|
162
168
|
taskset: set[asyncio.Task] | None = None,
|
|
163
169
|
taskgroup_name: str | None = None,
|
|
164
170
|
) -> None:
|
|
165
|
-
|
|
171
|
+
# Copy the taskset to avoid modifying the original set unintentionally (like in lifespan)
|
|
172
|
+
self.tasks = taskset.copy() if taskset is not None else set()
|
|
166
173
|
self.cancel = cancel
|
|
174
|
+
self.cancel_event = cancel_event
|
|
167
175
|
self.wait = wait
|
|
168
176
|
if taskset:
|
|
169
177
|
for task in tuple(taskset):
|
|
@@ -180,6 +188,8 @@ class SimpleTaskGroup(AbstractAsyncContextManager["SimpleTaskGroup"]):
|
|
|
180
188
|
try:
|
|
181
189
|
if (exc := task.exception()) and not isinstance(exc, ignore_exceptions):
|
|
182
190
|
logger.exception("asyncio.task failed in task group", exc_info=exc)
|
|
191
|
+
if self.cancel_event:
|
|
192
|
+
self.cancel_event.set()
|
|
183
193
|
except asyncio.CancelledError:
|
|
184
194
|
pass
|
|
185
195
|
|