langgraph-api 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langgraph-api might be problematic. Click here for more details.
- LICENSE +93 -0
- langgraph_api/__init__.py +0 -0
- langgraph_api/api/__init__.py +63 -0
- langgraph_api/api/assistants.py +326 -0
- langgraph_api/api/meta.py +71 -0
- langgraph_api/api/openapi.py +32 -0
- langgraph_api/api/runs.py +463 -0
- langgraph_api/api/store.py +116 -0
- langgraph_api/api/threads.py +263 -0
- langgraph_api/asyncio.py +201 -0
- langgraph_api/auth/__init__.py +0 -0
- langgraph_api/auth/langsmith/__init__.py +0 -0
- langgraph_api/auth/langsmith/backend.py +67 -0
- langgraph_api/auth/langsmith/client.py +145 -0
- langgraph_api/auth/middleware.py +41 -0
- langgraph_api/auth/noop.py +14 -0
- langgraph_api/cli.py +209 -0
- langgraph_api/config.py +70 -0
- langgraph_api/cron_scheduler.py +60 -0
- langgraph_api/errors.py +52 -0
- langgraph_api/graph.py +314 -0
- langgraph_api/http.py +168 -0
- langgraph_api/http_logger.py +89 -0
- langgraph_api/js/.gitignore +2 -0
- langgraph_api/js/build.mts +49 -0
- langgraph_api/js/client.mts +849 -0
- langgraph_api/js/global.d.ts +6 -0
- langgraph_api/js/package.json +33 -0
- langgraph_api/js/remote.py +673 -0
- langgraph_api/js/server_sent_events.py +126 -0
- langgraph_api/js/src/graph.mts +88 -0
- langgraph_api/js/src/hooks.mjs +12 -0
- langgraph_api/js/src/parser/parser.mts +443 -0
- langgraph_api/js/src/parser/parser.worker.mjs +12 -0
- langgraph_api/js/src/schema/types.mts +2136 -0
- langgraph_api/js/src/schema/types.template.mts +74 -0
- langgraph_api/js/src/utils/importMap.mts +85 -0
- langgraph_api/js/src/utils/pythonSchemas.mts +28 -0
- langgraph_api/js/src/utils/serde.mts +21 -0
- langgraph_api/js/tests/api.test.mts +1566 -0
- langgraph_api/js/tests/compose-postgres.yml +56 -0
- langgraph_api/js/tests/graphs/.gitignore +1 -0
- langgraph_api/js/tests/graphs/agent.mts +127 -0
- langgraph_api/js/tests/graphs/error.mts +17 -0
- langgraph_api/js/tests/graphs/langgraph.json +8 -0
- langgraph_api/js/tests/graphs/nested.mts +44 -0
- langgraph_api/js/tests/graphs/package.json +7 -0
- langgraph_api/js/tests/graphs/weather.mts +57 -0
- langgraph_api/js/tests/graphs/yarn.lock +159 -0
- langgraph_api/js/tests/parser.test.mts +870 -0
- langgraph_api/js/tests/utils.mts +17 -0
- langgraph_api/js/yarn.lock +1340 -0
- langgraph_api/lifespan.py +41 -0
- langgraph_api/logging.py +121 -0
- langgraph_api/metadata.py +101 -0
- langgraph_api/models/__init__.py +0 -0
- langgraph_api/models/run.py +229 -0
- langgraph_api/patch.py +42 -0
- langgraph_api/queue.py +245 -0
- langgraph_api/route.py +118 -0
- langgraph_api/schema.py +190 -0
- langgraph_api/serde.py +124 -0
- langgraph_api/server.py +48 -0
- langgraph_api/sse.py +118 -0
- langgraph_api/state.py +67 -0
- langgraph_api/stream.py +289 -0
- langgraph_api/utils.py +60 -0
- langgraph_api/validation.py +141 -0
- langgraph_api-0.0.1.dist-info/LICENSE +93 -0
- langgraph_api-0.0.1.dist-info/METADATA +26 -0
- langgraph_api-0.0.1.dist-info/RECORD +86 -0
- langgraph_api-0.0.1.dist-info/WHEEL +4 -0
- langgraph_api-0.0.1.dist-info/entry_points.txt +3 -0
- langgraph_license/__init__.py +0 -0
- langgraph_license/middleware.py +21 -0
- langgraph_license/validation.py +11 -0
- langgraph_storage/__init__.py +0 -0
- langgraph_storage/checkpoint.py +94 -0
- langgraph_storage/database.py +190 -0
- langgraph_storage/ops.py +1523 -0
- langgraph_storage/queue.py +108 -0
- langgraph_storage/retry.py +27 -0
- langgraph_storage/store.py +28 -0
- langgraph_storage/ttl_dict.py +54 -0
- logging.json +22 -0
- openapi.json +4304 -0
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
from uuid import uuid4
|
|
2
|
+
|
|
3
|
+
from starlette.responses import Response
|
|
4
|
+
from starlette.routing import BaseRoute
|
|
5
|
+
|
|
6
|
+
from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
|
|
7
|
+
from langgraph_api.state import state_snapshot_to_thread_state
|
|
8
|
+
from langgraph_api.utils import fetchone, validate_uuid
|
|
9
|
+
from langgraph_api.validation import (
|
|
10
|
+
ThreadCreate,
|
|
11
|
+
ThreadPatch,
|
|
12
|
+
ThreadSearchRequest,
|
|
13
|
+
ThreadStateCheckpointRequest,
|
|
14
|
+
ThreadStateSearch,
|
|
15
|
+
ThreadStateUpdate,
|
|
16
|
+
)
|
|
17
|
+
from langgraph_storage.database import connect
|
|
18
|
+
from langgraph_storage.ops import Threads
|
|
19
|
+
from langgraph_storage.retry import retry_db
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@retry_db
async def create_thread(
    request: ApiRequest,
):
    """Create a thread.

    Body fields: optional ``thread_id`` (must be a UUID), optional
    ``metadata``, and an ``if_exists`` policy defaulting to "raise".
    Returns the created thread row; an empty result is surfaced as 409
    (the thread already existed under the "raise" policy).
    """
    payload = await request.json(ThreadCreate)
    if thread_id := payload.get("thread_id"):
        validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    async with connect() as conn:
        # Renamed from `iter` to avoid shadowing the builtin of that name.
        thread_iter = await Threads.put(
            conn,
            thread_id or uuid4(),
            metadata=payload.get("metadata"),
            if_exists=payload.get("if_exists") or "raise",
        )
        # 409 Conflict rather than 404: no row back means the insert was
        # suppressed because the thread already exists.
        return ApiResponse(await fetchone(thread_iter, not_found_code=409))
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@retry_db
async def search_threads(
    request: ApiRequest,
):
    """List threads matching the requested filters, with pagination."""
    body = await request.json(ThreadSearchRequest)
    async with connect() as conn:
        results = await Threads.search(
            conn,
            status=body.get("status"),
            values=body.get("values"),
            metadata=body.get("metadata"),
            limit=body.get("limit") or 10,
            offset=body.get("offset") or 0,
        )
        # Materialize inside the connection scope before responding.
        threads = [thread async for thread in results]
        return ApiResponse(threads)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@retry_db
async def get_thread_state(
    request: ApiRequest,
):
    """Get the current state for a thread."""
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    # Only the exact strings "true"/"True" enable subgraph expansion.
    include_subgraphs = request.query_params.get("subgraphs") in ("true", "True")
    config = {"configurable": {"thread_id": thread_id}}
    async with connect() as conn:
        snapshot = await Threads.State.get(conn, config, subgraphs=include_subgraphs)
        state = state_snapshot_to_thread_state(snapshot)
        return ApiResponse(state)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@retry_db
async def get_thread_state_at_checkpoint(
    request: ApiRequest,
):
    """Get the state of a thread as of a specific checkpoint id."""
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    checkpoint_id = request.path_params["checkpoint_id"]
    config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_id": checkpoint_id,
        }
    }
    include_subgraphs = request.query_params.get("subgraphs") in ("true", "True")
    async with connect() as conn:
        snapshot = await Threads.State.get(conn, config, subgraphs=include_subgraphs)
        state = state_snapshot_to_thread_state(snapshot)
        return ApiResponse(state)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@retry_db
async def get_thread_state_at_checkpoint_post(
    request: ApiRequest,
):
    """Get the state of a thread at a checkpoint supplied in the request body."""
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    body = await request.json(ThreadStateCheckpointRequest)
    # The checkpoint dict is merged into the configurable mapping wholesale.
    config = {"configurable": {"thread_id": thread_id, **body["checkpoint"]}}
    async with connect() as conn:
        snapshot = await Threads.State.get(
            conn,
            config,
            subgraphs=body.get("subgraphs", False),
        )
        state = state_snapshot_to_thread_state(snapshot)
        return ApiResponse(state)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@retry_db
async def update_thread_state(
    request: ApiRequest,
):
    """Add state to a thread.

    Builds a runnable config from the path's thread id plus optional
    checkpoint information in the body, then writes the supplied values
    via the storage layer. Returns the inserted state description.
    """
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    payload = await request.json(ThreadStateUpdate)
    config = {"configurable": {"thread_id": thread_id}}
    # Legacy single-field checkpoint id is applied first ...
    if payload.get("checkpoint_id"):
        config["configurable"]["checkpoint_id"] = payload["checkpoint_id"]
    # ... then a full "checkpoint" dict may add or override configurable keys.
    if payload.get("checkpoint"):
        config["configurable"].update(payload["checkpoint"])
    try:
        # NOTE(review): starlette's request.user raises AssertionError when
        # AuthenticationMiddleware is not installed — presumably why the
        # assertion is swallowed here instead of treated as a failure.
        if user_id := request.user.display_name:
            config["configurable"]["user_id"] = user_id
    except AssertionError:
        pass
    async with connect() as conn:
        inserted = await Threads.State.post(
            conn,
            config,
            payload.get("values"),
            payload.get("as_node"),
        )
        return ApiResponse(inserted)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@retry_db
async def get_thread_history(
    request: ApiRequest,
):
    """Get all past states for a thread."""
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    # Query params arrive as strings; coerce limit to int so the storage
    # layer receives a number — the POST variant already does this with
    # int(payload.get("limit") or 10).
    limit = int(request.query_params.get("limit", 10))
    before = request.query_params.get("before")
    config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
    async with connect() as conn:
        states = [
            state_snapshot_to_thread_state(c)
            for c in await Threads.State.list(
                conn, config=config, limit=limit, before=before
            )
        ]
        return ApiResponse(states)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@retry_db
async def get_thread_history_post(
    request: ApiRequest,
):
    """Get all past states for a thread (search filters in the body)."""
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    body = await request.json(ThreadStateSearch)
    config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
    # Optional checkpoint filter narrows the config further.
    config["configurable"].update(body.get("checkpoint", {}))
    async with connect() as conn:
        snapshots = await Threads.State.list(
            conn,
            config=config,
            limit=int(body.get("limit") or 10),
            before=body.get("before"),
            metadata=body.get("metadata"),
        )
        states = [state_snapshot_to_thread_state(snap) for snap in snapshots]
        return ApiResponse(states)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
@retry_db
async def get_thread(
    request: ApiRequest,
):
    """Get a thread by ID."""
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    async with connect() as conn:
        row_iter = await Threads.get(conn, thread_id)
        row = await fetchone(row_iter)
        return ApiResponse(row)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
@retry_db
async def patch_thread(
    request: ApiRequest,
):
    """Update a thread's metadata."""
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    body = await request.json(ThreadPatch)
    async with connect() as conn:
        # "metadata" is required by ThreadPatch, hence the direct indexing.
        updated = await Threads.patch(conn, thread_id, metadata=body["metadata"])
        return ApiResponse(await fetchone(updated))
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
@retry_db
async def delete_thread(request: ApiRequest):
    """Delete a thread by ID."""
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    async with connect() as conn:
        deleted = await Threads.delete(conn, thread_id)
        # Consume the result so a missing thread surfaces as an error
        # instead of silently returning 204.
        await fetchone(deleted)
    return Response(status_code=204)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
@retry_db
async def copy_thread(request: ApiRequest):
    """Copy a thread into a brand-new thread.

    Returns the new thread row; an empty result maps to 409 Conflict.
    """
    thread_id = request.path_params["thread_id"]
    # Validate up front, like every other thread endpoint, so a malformed id
    # yields a validation error instead of a database-level failure.
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    async with connect() as conn:
        copied = await Threads.copy(conn, thread_id)
        return ApiResponse(await fetchone(copied, not_found_code=409))
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
# Routing table for the /threads API. Order is preserved from the original;
# starlette matches routes in registration order.
threads_routes: list[BaseRoute] = [
    # Collection endpoints.
    ApiRoute("/threads", endpoint=create_thread, methods=["POST"]),
    ApiRoute("/threads/search", endpoint=search_threads, methods=["POST"]),
    # Single-thread CRUD.
    ApiRoute("/threads/{thread_id}", endpoint=get_thread, methods=["GET"]),
    ApiRoute("/threads/{thread_id}", endpoint=patch_thread, methods=["PATCH"]),
    ApiRoute("/threads/{thread_id}", endpoint=delete_thread, methods=["DELETE"]),
    # State and history.
    ApiRoute("/threads/{thread_id}/state", endpoint=get_thread_state, methods=["GET"]),
    ApiRoute(
        "/threads/{thread_id}/state", endpoint=update_thread_state, methods=["POST"]
    ),
    ApiRoute(
        "/threads/{thread_id}/history", endpoint=get_thread_history, methods=["GET"]
    ),
    ApiRoute("/threads/{thread_id}/copy", endpoint=copy_thread, methods=["POST"]),
    ApiRoute(
        "/threads/{thread_id}/history",
        endpoint=get_thread_history_post,
        methods=["POST"],
    ),
    # NOTE(review): the parameterized ".../state/{checkpoint_id}" route is
    # registered before the literal ".../state/checkpoint" route; they do not
    # clash today only because their methods differ (GET vs POST) — confirm
    # before adding more methods to either.
    ApiRoute(
        "/threads/{thread_id}/state/{checkpoint_id}",
        endpoint=get_thread_state_at_checkpoint,
        methods=["GET"],
    ),
    ApiRoute(
        "/threads/{thread_id}/state/checkpoint",
        endpoint=get_thread_state_at_checkpoint_post,
        methods=["POST"],
    ),
]
|
langgraph_api/asyncio.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from collections.abc import AsyncIterator, Coroutine
|
|
3
|
+
from contextlib import AbstractAsyncContextManager
|
|
4
|
+
from functools import partial
|
|
5
|
+
from typing import Any, Generic, TypeVar
|
|
6
|
+
|
|
7
|
+
import structlog
|
|
8
|
+
|
|
9
|
+
T = TypeVar("T")
|
|
10
|
+
|
|
11
|
+
logger = structlog.stdlib.get_logger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
async def sleep_if_not_done(delay: float, done: asyncio.Event) -> None:
|
|
15
|
+
try:
|
|
16
|
+
await asyncio.wait_for(done.wait(), delay)
|
|
17
|
+
except TimeoutError:
|
|
18
|
+
pass
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ValueEvent(asyncio.Event):
    """An asyncio.Event whose set() can carry a value that wait() returns.

    NOTE(review): relies on asyncio.Event's private attributes (_value,
    _waiters, _get_loop); a CPython change to those internals would break
    this subclass — confirm against the supported Python versions.
    """

    def set(self, value: Any = True) -> None:
        """Set the internal flag to true. All coroutines waiting for it to
        become set are awakened. Coroutine that call wait() once the flag is
        true will not block at all.
        """
        # First set() wins: once a (truthy) value is stored, later calls
        # are no-ops, so every waiter observes the same value.
        if not self._value:
            self._value = value

            # Resolve all pending waiter futures with the carried value.
            for fut in self._waiters:
                if not fut.done():
                    fut.set_result(value)

    async def wait(self):
        """Block until the internal flag is set.

        If the internal flag is set on entry, return value
        immediately. Otherwise, block until another coroutine calls
        set() to set the flag, then return the value.
        """
        if self._value:
            return self._value

        fut = self._get_loop().create_future()
        self._waiters.append(fut)
        try:
            return await fut
        finally:
            # Deregister whether we were resolved or cancelled.
            self._waiters.remove(fut)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
async def wait_if_not_done(coro: Coroutine[Any, Any, T], done: ValueEvent) -> T:
    """Wait for the coroutine to finish or the event to be set.

    Whichever side completes first cancels the other. When `done` fires
    first, the coroutine is cancelled with the event's value attached as
    the cancellation message, so a caller-supplied Exception stored in the
    event can be re-raised here.
    """
    try:
        async with asyncio.TaskGroup() as tg:
            coro_task = tg.create_task(coro)
            done_task = tg.create_task(done.wait())
            # Each side's completion cancels its sibling.
            coro_task.add_done_callback(lambda _: done_task.cancel())
            # Reaches into ValueEvent internals to forward the event's value
            # as the cancellation argument.
            done_task.add_done_callback(lambda _: coro_task.cancel(done._value))
            try:
                return await coro_task
            except asyncio.CancelledError as e:
                # If the event carried an Exception, surface that instead of
                # the bare cancellation.
                if e.args and isinstance(e.args[0], Exception):
                    raise e.args[0] from None
                raise
    except ExceptionGroup as e:
        # Unwrap TaskGroup's ExceptionGroup so callers see one exception.
        raise e.exceptions[0] from None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# Strong references to in-flight background tasks. asyncio.create_task only
# keeps a weak reference to its tasks, so without this set a still-running
# task could be garbage-collected (per the asyncio docs' recommendation to
# hold references yourself).
PENDING_TASKS = set()


def _create_task_done_callback(
    ignore_exceptions: tuple[type[Exception], ...], task: asyncio.Task
) -> None:
    """Deregister a finished background task and log unexpected failures."""
    PENDING_TASKS.remove(task)
    try:
        if exc := task.exception():
            if not isinstance(exc, ignore_exceptions):
                logger.exception("Background task failed", exc_info=exc)
    except asyncio.CancelledError:
        # task.exception() raises CancelledError for cancelled tasks;
        # cancellation is not an error worth logging.
        pass


def create_task(
    coro: Coroutine[Any, Any, T], ignore_exceptions: tuple[type[Exception], ...] = ()
) -> asyncio.Task[T]:
    """Create a new task in the current task group and return it."""
    task = asyncio.create_task(coro)
    PENDING_TASKS.add(task)
    task.add_done_callback(partial(_create_task_done_callback, ignore_exceptions))
    return task
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class SimpleTaskGroup(AbstractAsyncContextManager["SimpleTaskGroup"]):
    """An async task group that can be configured to wait and/or cancel tasks on exit.

    asyncio.TaskGroup and anyio.TaskGroup both expect enter and exit to be called
    in the same asyncio task, which is not true for our use case, where exit is
    shielded from cancellation."""

    # Strong references to the tasks owned by this group.
    tasks: set[asyncio.Task]

    def __init__(
        self, *coros: Coroutine[Any, Any, T], cancel: bool = False, wait: bool = True
    ) -> None:
        # cancel: cancel still-running tasks on exit.
        # wait: await all tasks (after any cancellation) on exit.
        self.tasks = set()
        self.cancel = cancel
        self.wait = wait
        for coro in coros:
            self.create_task(coro)

    def _create_task_done_callback(
        self, ignore_exceptions: tuple[type[Exception], ...], task: asyncio.Task
    ) -> None:
        """Deregister a finished task and log unexpected failures."""
        try:
            self.tasks.remove(task)
        except AttributeError:
            # __aexit__ deletes self.tasks; callbacks firing after exit have
            # nothing to deregister from.
            pass
        try:
            if exc := task.exception():
                if not isinstance(exc, ignore_exceptions):
                    logger.exception("Background task failed", exc_info=exc)
        except asyncio.CancelledError:
            # task.exception() raises CancelledError for cancelled tasks.
            pass

    def create_task(
        self,
        coro: Coroutine[Any, Any, T],
        ignore_exceptions: tuple[type[Exception], ...] = (),
    ) -> asyncio.Task[T]:
        """Create a new task in the current task group and return it."""
        task = asyncio.create_task(coro)
        self.tasks.add(task)
        task.add_done_callback(
            partial(self._create_task_done_callback, ignore_exceptions)
        )
        return task

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        tasks = self.tasks
        # break reference cycles between tasks and task group
        del self.tasks
        # cancel all tasks
        if self.cancel:
            for task in tasks:
                task.cancel()
        # wait for all tasks
        if self.wait:
            # return_exceptions=True: failures were already logged by the done
            # callback, so they must not propagate out of exit.
            await asyncio.gather(*tasks, return_exceptions=True)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def to_aiter(*args: T) -> AsyncIterator[T]:
    """Wrap the positional arguments in an async iterator yielding them in order."""

    async def _gen() -> AsyncIterator[T]:
        for item in args:
            yield item

    return _gen()
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
# Separate TypeVar for the wrapped resource type.
V = TypeVar("V")


class aclosing(Generic[V], AbstractAsyncContextManager):
    """Async context manager for safely finalizing an asynchronously cleaned-up
    resource such as an async generator, calling its ``aclose()`` method.

    Code like this:

        async with aclosing(<module>.fetch(<arguments>)) as agen:
            <block>

    is equivalent to this:

        agen = <module>.fetch(<arguments>)
        try:
            <block>
        finally:
            await agen.aclose()

    """

    def __init__(self, thing: V):
        self.thing = thing

    async def __aenter__(self) -> V:
        return self.thing

    async def __aexit__(self, *exc_info):
        # Unconditional cleanup, regardless of whether an exception occurred.
        await self.thing.aclose()
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
async def aclosing_aiter(aiter: AsyncIterator[T]) -> AsyncIterator[T]:
    """Yield items from `aiter`, guaranteeing async cleanup when iteration ends."""
    if not hasattr(aiter, "__aenter__"):
        # Plain async iterators (e.g. async generators): ensure aclose() runs.
        async with aclosing(aiter):
            async for item in aiter:
                yield item
    else:
        # The iterator is itself an async context manager; honor its protocol.
        async with aiter:
            async for item in aiter:
                yield item
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from typing import NotRequired, TypedDict
|
|
2
|
+
|
|
3
|
+
from starlette.authentication import (
|
|
4
|
+
AuthCredentials,
|
|
5
|
+
AuthenticationBackend,
|
|
6
|
+
AuthenticationError,
|
|
7
|
+
BaseUser,
|
|
8
|
+
SimpleUser,
|
|
9
|
+
)
|
|
10
|
+
from starlette.requests import HTTPConnection
|
|
11
|
+
|
|
12
|
+
from langgraph_api.auth.langsmith.client import auth_client
|
|
13
|
+
from langgraph_api.config import (
|
|
14
|
+
LANGSMITH_AUTH_VERIFY_TENANT_ID,
|
|
15
|
+
LANGSMITH_TENANT_ID,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AuthDict(TypedDict):
    """Shape of the JSON payload returned by the LangSmith auth endpoints."""

    organization_id: str
    tenant_id: str
    # NOTE(review): user fields are NotRequired — presumably absent for
    # service-key / api-key authentication; confirm against the auth service.
    user_id: NotRequired[str]
    user_email: NotRequired[str]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class LangsmithAuthBackend(AuthenticationBackend):
    """Starlette auth backend that delegates verification to LangSmith."""

    async def authenticate(
        self, conn: HTTPConnection
    ) -> tuple[AuthCredentials, BaseUser] | None:
        # Collect all supported credential headers; at least one must be set.
        headers = [
            ("Authorization", conn.headers.get("Authorization")),
            ("X-Tenant-Id", conn.headers.get("x-tenant-id")),
            ("X-Api-Key", conn.headers.get("x-api-key")),
            ("X-Service-Key", conn.headers.get("x-service-key")),
        ]
        if not any(h[1] for h in headers):
            raise AuthenticationError("Missing authentication headers")
        async with auth_client() as auth:
            if not LANGSMITH_AUTH_VERIFY_TENANT_ID and not conn.headers.get(
                "x-api-key"
            ):
                # when LANGSMITH_AUTH_VERIFY_TENANT_ID is false, we allow
                # any valid bearer token to pass through
                # api key auth is always required to match the tenant id
                res = await auth.get(
                    "/auth/verify", headers=[h for h in headers if h[1] is not None]
                )
            else:
                res = await auth.get(
                    "/auth/public", headers=[h for h in headers if h[1] is not None]
                )
            # Map the auth service's 401/403 to starlette auth errors; any
            # other non-2xx status propagates as an HTTP error.
            if res.status_code == 401:
                raise AuthenticationError("Invalid token")
            elif res.status_code == 403:
                raise AuthenticationError("Forbidden")
            else:
                res.raise_for_status()
            auth_dict: AuthDict = res.json()

            # If tenant id verification is disabled, the bearer token requests
            # are not required to match the tenant id. Api key requests are
            # always required to match the tenant id.
            if LANGSMITH_AUTH_VERIFY_TENANT_ID or conn.headers.get("x-api-key"):
                if auth_dict["tenant_id"] != LANGSMITH_TENANT_ID:
                    raise AuthenticationError("Invalid tenant ID")

        # NOTE(review): user_id is NotRequired, so SimpleUser may receive
        # None here (e.g. api-key auth) — confirm downstream display_name
        # consumers tolerate that.
        return AuthCredentials(["authenticated"]), SimpleUser(auth_dict.get("user_id"))
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
from collections.abc import AsyncGenerator
|
|
2
|
+
from contextlib import asynccontextmanager
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import httpx
|
|
6
|
+
from httpx._types import HeaderTypes, QueryParamTypes, RequestData
|
|
7
|
+
from tenacity import retry
|
|
8
|
+
from tenacity.retry import retry_if_exception
|
|
9
|
+
from tenacity.stop import stop_after_attempt
|
|
10
|
+
from tenacity.wait import wait_exponential_jitter
|
|
11
|
+
|
|
12
|
+
from langgraph_api.config import LANGSMITH_AUTH_ENDPOINT
|
|
13
|
+
|
|
14
|
+
_client: "JsonHttpClient"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def is_retriable_error(exception: Exception) -> bool:
    """Return True when the exception represents a transient failure worth retrying."""
    # Network-level problems (timeouts, connection resets, ...) are transient.
    if isinstance(exception, httpx.TransportError):
        return True
    # Server-side errors (5xx) may succeed on retry; client errors won't.
    if isinstance(exception, httpx.HTTPStatusError):
        return exception.response.status_code >= 500
    return False
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Shared tenacity policy: retry transient failures (transport errors, 5xx)
# up to 3 attempts with exponential backoff plus jitter, re-raising the last
# underlying exception rather than tenacity's RetryError.
retry_httpx = retry(
    reraise=True,
    retry=retry_if_exception(is_retriable_error),
    wait=wait_exponential_jitter(),
    stop=stop_after_attempt(3),
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class JsonHttpClient:
    """HTTPX client for JSON requests, with retries.

    ``get``/``post`` are wrapped in the shared ``retry_httpx`` policy; the
    underscore-prefixed ``_get``/``_post`` are the same calls without the
    retry decorator (single attempt).
    """

    def __init__(self, client: httpx.AsyncClient) -> None:
        """Initialize the auth client."""
        self.client = client

    async def _get(
        self,
        path: str,
        *,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
    ) -> httpx.Response:
        # Single-attempt GET: identical to get() but without @retry_httpx.
        return await self.client.get(path, params=params, headers=headers)

    @retry_httpx
    async def get(
        self,
        path: str,
        *,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
    ) -> httpx.Response:
        """GET `path`, retrying transient failures per `retry_httpx`."""
        return await self.client.get(path, params=params, headers=headers)

    async def _post(
        self,
        path: str,
        *,
        data: RequestData | None = None,
        json: Any | None = None,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
    ) -> httpx.Response:
        # Single-attempt POST: identical to post() but without @retry_httpx.
        return await self.client.post(
            path, data=data, json=json, params=params, headers=headers
        )

    @retry_httpx
    async def post(
        self,
        path: str,
        *,
        data: RequestData | None = None,
        json: Any | None = None,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
    ) -> httpx.Response:
        """POST `path`, retrying transient failures per `retry_httpx`."""
        return await self.client.post(
            path, data=data, json=json, params=params, headers=headers
        )
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def create_client() -> JsonHttpClient:
    """Create the auth http client."""
    transport = httpx.AsyncHTTPTransport(
        retries=5,  # this applies only to ConnectError, ConnectTimeout
        limits=httpx.Limits(
            max_keepalive_connections=40,
            keepalive_expiry=240.0,
        ),
    )
    async_client = httpx.AsyncClient(
        transport=transport,
        timeout=httpx.Timeout(2.0),
        base_url=LANGSMITH_AUTH_ENDPOINT,
    )
    return JsonHttpClient(async_client)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
async def close_auth_client() -> None:
    """Close the module-level auth http client, if one was ever created."""
    try:
        client = _client
    except NameError:
        # Never initialized; nothing to close.
        return
    await client.client.aclose()
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
async def initialize_auth_client() -> None:
    """(Re)create the module-level auth http client."""
    global _client
    # Dispose of any previous client before replacing it.
    await close_auth_client()
    _client = create_client()
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@asynccontextmanager
async def auth_client() -> AsyncGenerator[JsonHttpClient, None]:
    """Get the auth http client."""
    # pytest does something funny with event loops,
    # so we can't use a global pool for tests
    if LANGSMITH_AUTH_ENDPOINT.startswith("http://localhost"):
        throwaway = create_client()
        try:
            yield throwaway
        finally:
            await throwaway.client.aclose()
        return
    # Reuse the shared client if it exists and is still open; otherwise
    # (re)initialize it first.
    try:
        usable = not _client.client.is_closed
    except NameError:
        usable = False
    if not usable:
        await initialize_auth_client()
    yield _client
|