langgraph-api 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langgraph-api might be problematic. Click here for more details.

Files changed (86) hide show
  1. LICENSE +93 -0
  2. langgraph_api/__init__.py +0 -0
  3. langgraph_api/api/__init__.py +63 -0
  4. langgraph_api/api/assistants.py +326 -0
  5. langgraph_api/api/meta.py +71 -0
  6. langgraph_api/api/openapi.py +32 -0
  7. langgraph_api/api/runs.py +463 -0
  8. langgraph_api/api/store.py +116 -0
  9. langgraph_api/api/threads.py +263 -0
  10. langgraph_api/asyncio.py +201 -0
  11. langgraph_api/auth/__init__.py +0 -0
  12. langgraph_api/auth/langsmith/__init__.py +0 -0
  13. langgraph_api/auth/langsmith/backend.py +67 -0
  14. langgraph_api/auth/langsmith/client.py +145 -0
  15. langgraph_api/auth/middleware.py +41 -0
  16. langgraph_api/auth/noop.py +14 -0
  17. langgraph_api/cli.py +209 -0
  18. langgraph_api/config.py +70 -0
  19. langgraph_api/cron_scheduler.py +60 -0
  20. langgraph_api/errors.py +52 -0
  21. langgraph_api/graph.py +314 -0
  22. langgraph_api/http.py +168 -0
  23. langgraph_api/http_logger.py +89 -0
  24. langgraph_api/js/.gitignore +2 -0
  25. langgraph_api/js/build.mts +49 -0
  26. langgraph_api/js/client.mts +849 -0
  27. langgraph_api/js/global.d.ts +6 -0
  28. langgraph_api/js/package.json +33 -0
  29. langgraph_api/js/remote.py +673 -0
  30. langgraph_api/js/server_sent_events.py +126 -0
  31. langgraph_api/js/src/graph.mts +88 -0
  32. langgraph_api/js/src/hooks.mjs +12 -0
  33. langgraph_api/js/src/parser/parser.mts +443 -0
  34. langgraph_api/js/src/parser/parser.worker.mjs +12 -0
  35. langgraph_api/js/src/schema/types.mts +2136 -0
  36. langgraph_api/js/src/schema/types.template.mts +74 -0
  37. langgraph_api/js/src/utils/importMap.mts +85 -0
  38. langgraph_api/js/src/utils/pythonSchemas.mts +28 -0
  39. langgraph_api/js/src/utils/serde.mts +21 -0
  40. langgraph_api/js/tests/api.test.mts +1566 -0
  41. langgraph_api/js/tests/compose-postgres.yml +56 -0
  42. langgraph_api/js/tests/graphs/.gitignore +1 -0
  43. langgraph_api/js/tests/graphs/agent.mts +127 -0
  44. langgraph_api/js/tests/graphs/error.mts +17 -0
  45. langgraph_api/js/tests/graphs/langgraph.json +8 -0
  46. langgraph_api/js/tests/graphs/nested.mts +44 -0
  47. langgraph_api/js/tests/graphs/package.json +7 -0
  48. langgraph_api/js/tests/graphs/weather.mts +57 -0
  49. langgraph_api/js/tests/graphs/yarn.lock +159 -0
  50. langgraph_api/js/tests/parser.test.mts +870 -0
  51. langgraph_api/js/tests/utils.mts +17 -0
  52. langgraph_api/js/yarn.lock +1340 -0
  53. langgraph_api/lifespan.py +41 -0
  54. langgraph_api/logging.py +121 -0
  55. langgraph_api/metadata.py +101 -0
  56. langgraph_api/models/__init__.py +0 -0
  57. langgraph_api/models/run.py +229 -0
  58. langgraph_api/patch.py +42 -0
  59. langgraph_api/queue.py +245 -0
  60. langgraph_api/route.py +118 -0
  61. langgraph_api/schema.py +190 -0
  62. langgraph_api/serde.py +124 -0
  63. langgraph_api/server.py +48 -0
  64. langgraph_api/sse.py +118 -0
  65. langgraph_api/state.py +67 -0
  66. langgraph_api/stream.py +289 -0
  67. langgraph_api/utils.py +60 -0
  68. langgraph_api/validation.py +141 -0
  69. langgraph_api-0.0.1.dist-info/LICENSE +93 -0
  70. langgraph_api-0.0.1.dist-info/METADATA +26 -0
  71. langgraph_api-0.0.1.dist-info/RECORD +86 -0
  72. langgraph_api-0.0.1.dist-info/WHEEL +4 -0
  73. langgraph_api-0.0.1.dist-info/entry_points.txt +3 -0
  74. langgraph_license/__init__.py +0 -0
  75. langgraph_license/middleware.py +21 -0
  76. langgraph_license/validation.py +11 -0
  77. langgraph_storage/__init__.py +0 -0
  78. langgraph_storage/checkpoint.py +94 -0
  79. langgraph_storage/database.py +190 -0
  80. langgraph_storage/ops.py +1523 -0
  81. langgraph_storage/queue.py +108 -0
  82. langgraph_storage/retry.py +27 -0
  83. langgraph_storage/store.py +28 -0
  84. langgraph_storage/ttl_dict.py +54 -0
  85. logging.json +22 -0
  86. openapi.json +4304 -0
@@ -0,0 +1,263 @@
1
+ from uuid import uuid4
2
+
3
+ from starlette.responses import Response
4
+ from starlette.routing import BaseRoute
5
+
6
+ from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
7
+ from langgraph_api.state import state_snapshot_to_thread_state
8
+ from langgraph_api.utils import fetchone, validate_uuid
9
+ from langgraph_api.validation import (
10
+ ThreadCreate,
11
+ ThreadPatch,
12
+ ThreadSearchRequest,
13
+ ThreadStateCheckpointRequest,
14
+ ThreadStateSearch,
15
+ ThreadStateUpdate,
16
+ )
17
+ from langgraph_storage.database import connect
18
+ from langgraph_storage.ops import Threads
19
+ from langgraph_storage.retry import retry_db
20
+
21
+
22
@retry_db
async def create_thread(
    request: ApiRequest,
):
    """Create a thread, using the caller-supplied ``thread_id`` if present.

    A supplied ID must be a UUID; otherwise a fresh ``uuid4()`` is used.
    ``if_exists`` defaults to ``"raise"``. An empty result from the store is
    reported via ``fetchone(..., not_found_code=409)``.
    """
    payload = await request.json(ThreadCreate)
    requested_id = payload.get("thread_id")
    if requested_id:
        validate_uuid(requested_id, "Invalid thread ID: must be a UUID")
    conflict_policy = payload.get("if_exists") or "raise"
    async with connect() as conn:
        rows = await Threads.put(
            conn,
            requested_id or uuid4(),
            metadata=payload.get("metadata"),
            if_exists=conflict_policy,
        )
        return ApiResponse(await fetchone(rows, not_found_code=409))
38
+
39
+
40
@retry_db
async def search_threads(
    request: ApiRequest,
):
    """List threads matching the filters in the request body.

    ``limit`` defaults to 10 and ``offset`` to 0 when missing or falsy.
    """
    payload = await request.json(ThreadSearchRequest)
    filters = {
        "status": payload.get("status"),
        "values": payload.get("values"),
        "metadata": payload.get("metadata"),
        "limit": payload.get("limit") or 10,
        "offset": payload.get("offset") or 0,
    }
    async with connect() as conn:
        matches = await Threads.search(conn, **filters)
        return ApiResponse([row async for row in matches])
56
+
57
+
58
@retry_db
async def get_thread_state(
    request: ApiRequest,
):
    """Return the latest state of the thread named in the URL.

    The ``subgraphs`` query parameter ("true"/"True") is forwarded to
    ``Threads.State.get`` as ``subgraphs=True``.
    """
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    include_subgraphs = request.query_params.get("subgraphs") in ("true", "True")
    config = {"configurable": {"thread_id": thread_id}}
    async with connect() as conn:
        snapshot = await Threads.State.get(conn, config, subgraphs=include_subgraphs)
        return ApiResponse(state_snapshot_to_thread_state(snapshot))
73
+
74
+
75
@retry_db
async def get_thread_state_at_checkpoint(
    request: ApiRequest,
):
    """Return a thread's state at the checkpoint named in the URL path.

    Only ``thread_id`` is UUID-validated; ``checkpoint_id`` is passed through
    as-is. The ``subgraphs`` query parameter ("true"/"True") is forwarded to
    ``Threads.State.get``.
    """
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    checkpoint_id = request.path_params["checkpoint_id"]
    config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_id": checkpoint_id,
        }
    }
    async with connect() as conn:
        include_subgraphs = request.query_params.get("subgraphs") in ("true", "True")
        snapshot = await Threads.State.get(conn, config, subgraphs=include_subgraphs)
        return ApiResponse(state_snapshot_to_thread_state(snapshot))
97
+
98
+
99
@retry_db
async def get_thread_state_at_checkpoint_post(
    request: ApiRequest,
):
    """Return a thread's state at the checkpoint described in the request body.

    The body's ``checkpoint`` dict is merged into the configurable, but the
    thread ID always comes from the (UUID-validated) URL path. ``subgraphs``
    defaults to False.
    """
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    payload = await request.json(ThreadStateCheckpointRequest)
    # Spread the checkpoint first so a "thread_id" key inside it cannot
    # override the validated thread ID from the URL.
    config = {"configurable": {**payload["checkpoint"], "thread_id": thread_id}}
    async with connect() as conn:
        state = state_snapshot_to_thread_state(
            await Threads.State.get(
                conn,
                config,
                subgraphs=payload.get("subgraphs", False),
            )
        )
        return ApiResponse(state)
116
+
117
+
118
@retry_db
async def update_thread_state(
    request: ApiRequest,
):
    """Add a new state to a thread.

    The target checkpoint may be given as ``checkpoint_id`` and/or a
    ``checkpoint`` dict, whose keys are merged into the configurable. The
    thread ID from the URL always takes precedence over any ``thread_id`` in
    the body. When the request has a user with a non-empty display name, it is
    recorded as ``user_id``.
    """
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    payload = await request.json(ThreadStateUpdate)
    config = {"configurable": {"thread_id": thread_id}}
    if payload.get("checkpoint_id"):
        config["configurable"]["checkpoint_id"] = payload["checkpoint_id"]
    if payload.get("checkpoint"):
        config["configurable"].update(payload["checkpoint"])
    # Re-pin the validated thread ID: a "thread_id" inside the body's
    # checkpoint must not redirect the write to a different thread.
    config["configurable"]["thread_id"] = thread_id
    try:
        if user_id := request.user.display_name:
            config["configurable"]["user_id"] = user_id
    except AssertionError:
        # Presumably raised by Starlette's ``request.user`` when no
        # AuthenticationMiddleware is installed; skip user attribution then.
        pass
    async with connect() as conn:
        inserted = await Threads.State.post(
            conn,
            config,
            payload.get("values"),
            payload.get("as_node"),
        )
        return ApiResponse(inserted)
144
+
145
+
146
@retry_db
async def get_thread_history(
    request: ApiRequest,
):
    """Get past states for a thread, with ``checkpoint_ns`` fixed to "".

    Query parameters:
        limit: maximum number of states, parsed as an integer (default 10).
            This mirrors the ``int(...)`` conversion in the POST variant.
        before: passed through unchanged to ``Threads.State.list``.

    Raises:
        HTTPException: 422 if ``limit`` is not an integer.
    """
    from starlette.exceptions import HTTPException

    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    # Query-string values are always strings; convert so the storage layer
    # receives an int, as it does from get_thread_history_post.
    raw_limit = request.query_params.get("limit")
    try:
        limit = int(raw_limit) if raw_limit else 10
    except ValueError:
        raise HTTPException(
            status_code=422, detail="Invalid limit: must be an integer"
        ) from None
    before = request.query_params.get("before")
    config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
    async with connect() as conn:
        states = [
            state_snapshot_to_thread_state(c)
            for c in await Threads.State.list(
                conn, config=config, limit=limit, before=before
            )
        ]
        return ApiResponse(states)
164
+
165
+
166
@retry_db
async def get_thread_history_post(
    request: ApiRequest,
):
    """Get past states for a thread using filters from the request body.

    ``checkpoint_ns`` defaults to "" and may be overridden by the body's
    ``checkpoint`` dict. The thread ID always comes from the URL. ``limit``
    defaults to 10.
    """
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    payload = await request.json(ThreadStateSearch)
    config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
    # `or {}` also covers an explicit "checkpoint": null, which would make
    # dict.update raise TypeError.
    config["configurable"].update(payload.get("checkpoint") or {})
    # The validated URL thread ID is authoritative over any body value.
    config["configurable"]["thread_id"] = thread_id
    async with connect() as conn:
        states = [
            state_snapshot_to_thread_state(c)
            for c in await Threads.State.list(
                conn,
                config=config,
                limit=int(payload.get("limit") or 10),
                before=payload.get("before"),
                metadata=payload.get("metadata"),
            )
        ]
        return ApiResponse(states)
188
+
189
+
190
@retry_db
async def get_thread(
    request: ApiRequest,
):
    """Fetch a single thread by the UUID in the URL path."""
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    async with connect() as conn:
        rows = await Threads.get(conn, thread_id)
        found = await fetchone(rows)
        return ApiResponse(found)
200
+
201
+
202
@retry_db
async def patch_thread(
    request: ApiRequest,
):
    """Update the metadata of the thread named in the URL path."""
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    payload = await request.json(ThreadPatch)
    async with connect() as conn:
        rows = await Threads.patch(conn, thread_id, metadata=payload["metadata"])
        updated = await fetchone(rows)
        return ApiResponse(updated)
213
+
214
+
215
@retry_db
async def delete_thread(request: ApiRequest):
    """Delete the thread named in the URL path and reply 204 No Content."""
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    async with connect() as conn:
        # fetchone consumes the delete result (and surfaces a missing thread).
        await fetchone(await Threads.delete(conn, thread_id))
    return Response(status_code=204)
224
+
225
+
226
@retry_db
async def copy_thread(request: ApiRequest):
    """Copy the thread named in the URL path and return the row from ``Threads.copy``.

    The thread ID is UUID-validated like in every other thread endpoint. An
    empty result is reported via ``fetchone(..., not_found_code=409)``.
    """
    thread_id = request.path_params["thread_id"]
    validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
    async with connect() as conn:
        rows = await Threads.copy(conn, thread_id)
        return ApiResponse(await fetchone(rows, not_found_code=409))
232
+
233
+
234
# Route table for the thread endpoints. Assuming ApiRoute follows Starlette's
# Route matching, routes are tried in list order, so ordering matters when two
# paths can match the same URL.
threads_routes: list[BaseRoute] = [
    ApiRoute("/threads", endpoint=create_thread, methods=["POST"]),
    # Listed ahead of the "/threads/{thread_id}" routes below.
    ApiRoute("/threads/search", endpoint=search_threads, methods=["POST"]),
    ApiRoute("/threads/{thread_id}", endpoint=get_thread, methods=["GET"]),
    ApiRoute("/threads/{thread_id}", endpoint=patch_thread, methods=["PATCH"]),
    ApiRoute("/threads/{thread_id}", endpoint=delete_thread, methods=["DELETE"]),
    ApiRoute("/threads/{thread_id}/state", endpoint=get_thread_state, methods=["GET"]),
    ApiRoute(
        "/threads/{thread_id}/state", endpoint=update_thread_state, methods=["POST"]
    ),
    ApiRoute(
        "/threads/{thread_id}/history", endpoint=get_thread_history, methods=["GET"]
    ),
    ApiRoute("/threads/{thread_id}/copy", endpoint=copy_thread, methods=["POST"]),
    ApiRoute(
        "/threads/{thread_id}/history",
        endpoint=get_thread_history_post,
        methods=["POST"],
    ),
    ApiRoute(
        "/threads/{thread_id}/state/{checkpoint_id}",
        endpoint=get_thread_state_at_checkpoint,
        methods=["GET"],
    ),
    # NOTE(review): a GET to ".../state/checkpoint" is caught by the
    # "{checkpoint_id}" route above (checkpoint_id="checkpoint"); only POST
    # reaches get_thread_state_at_checkpoint_post.
    ApiRoute(
        "/threads/{thread_id}/state/checkpoint",
        endpoint=get_thread_state_at_checkpoint_post,
        methods=["POST"],
    ),
]
@@ -0,0 +1,201 @@
1
+ import asyncio
2
+ from collections.abc import AsyncIterator, Coroutine
3
+ from contextlib import AbstractAsyncContextManager
4
+ from functools import partial
5
+ from typing import Any, Generic, TypeVar
6
+
7
+ import structlog
8
+
9
# Generic result type for the coroutine/task helpers in this module.
T = TypeVar("T")

# Module logger; used by the task done-callbacks to report background failures.
logger = structlog.stdlib.get_logger(__name__)
12
+
13
+
14
+ async def sleep_if_not_done(delay: float, done: asyncio.Event) -> None:
15
+ try:
16
+ await asyncio.wait_for(done.wait(), delay)
17
+ except TimeoutError:
18
+ pass
19
+
20
+
21
+ class ValueEvent(asyncio.Event):
22
+ def set(self, value: Any = True) -> None:
23
+ """Set the internal flag to true. All coroutines waiting for it to
24
+ become set are awakened. Coroutine that call wait() once the flag is
25
+ true will not block at all.
26
+ """
27
+ if not self._value:
28
+ self._value = value
29
+
30
+ for fut in self._waiters:
31
+ if not fut.done():
32
+ fut.set_result(value)
33
+
34
+ async def wait(self):
35
+ """Block until the internal flag is set.
36
+
37
+ If the internal flag is set on entry, return value
38
+ immediately. Otherwise, block until another coroutine calls
39
+ set() to set the flag, then return the value.
40
+ """
41
+ if self._value:
42
+ return self._value
43
+
44
+ fut = self._get_loop().create_future()
45
+ self._waiters.append(fut)
46
+ try:
47
+ return await fut
48
+ finally:
49
+ self._waiters.remove(fut)
50
+
51
+
52
async def wait_if_not_done(coro: Coroutine[Any, Any, T], done: ValueEvent) -> T:
    """Run ``coro`` until it finishes or ``done`` is set, whichever comes first.

    If ``done`` wins, ``coro`` is cancelled with the event's value as the
    cancellation message. When that value is an ``Exception``, it is raised
    instead of ``CancelledError``. Errors surfacing as an ``ExceptionGroup``
    from the task group are unwrapped to their first member.
    """
    try:
        async with asyncio.TaskGroup() as group:
            work = group.create_task(coro)
            stopper = group.create_task(done.wait())
            # Whichever task finishes first cancels the other one.
            work.add_done_callback(lambda _finished: stopper.cancel())
            stopper.add_done_callback(lambda _finished: work.cancel(done._value))
            try:
                return await work
            except asyncio.CancelledError as cancelled:
                reason = cancelled.args[0] if cancelled.args else None
                if isinstance(reason, Exception):
                    raise reason from None
                raise
    except ExceptionGroup as group_error:
        raise group_error.exceptions[0] from None
68
+
69
+
70
# Strong references to tasks started by create_task(); each task removes
# itself from this set in its done-callback.
PENDING_TASKS: set[asyncio.Task] = set()
71
+
72
+
73
def _create_task_done_callback(
    ignore_exceptions: tuple[type[BaseException], ...], task: asyncio.Task
) -> None:
    """Done-callback for tasks started via ``create_task``.

    Removes ``task`` from ``PENDING_TASKS``. If the task failed, its exception
    is logged, unless the task was cancelled or the exception is an instance
    of one of the classes in ``ignore_exceptions``.
    """
    PENDING_TASKS.remove(task)
    if task.cancelled():
        # task.exception() would raise CancelledError for a cancelled task.
        return
    exc = task.exception()
    if exc is not None and not isinstance(exc, ignore_exceptions):
        logger.exception("Background task failed", exc_info=exc)
83
+
84
+
85
def create_task(
    coro: Coroutine[Any, Any, T],
    ignore_exceptions: tuple[type[BaseException], ...] = (),
) -> asyncio.Task[T]:
    """Schedule ``coro`` as a background task and return the task.

    A strong reference is kept in ``PENDING_TASKS`` until the task finishes,
    so it is not garbage-collected mid-execution. On completion, failures are
    logged unless they are instances of the classes in ``ignore_exceptions``.

    Args:
        coro: the coroutine to run.
        ignore_exceptions: exception classes whose failures are not logged.
    """
    task = asyncio.create_task(coro)
    PENDING_TASKS.add(task)
    task.add_done_callback(partial(_create_task_done_callback, ignore_exceptions))
    return task
93
+
94
+
95
class SimpleTaskGroup(AbstractAsyncContextManager["SimpleTaskGroup"]):
    """An async task group that can be configured to wait and/or cancel tasks on exit.

    asyncio.TaskGroup and anyio.TaskGroup both expect enter and exit to be called
    in the same asyncio task, which is not true for our use case, where exit is
    shielded from cancellation.

    Args:
        *coros: coroutines started as tasks immediately on construction.
        cancel: if true, cancel every tracked task on exit.
        wait: if true, wait for every tracked task to finish on exit. Task
            exceptions are gathered with ``return_exceptions=True`` and are
            not re-raised.
    """

    tasks: set[asyncio.Task]

    def __init__(
        self, *coros: Coroutine[Any, Any, T], cancel: bool = False, wait: bool = True
    ) -> None:
        self.tasks = set()
        self.cancel = cancel
        self.wait = wait
        for coro in coros:
            self.create_task(coro)

    def _create_task_done_callback(
        self, ignore_exceptions: tuple[type[BaseException], ...], task: asyncio.Task
    ) -> None:
        """Stop tracking ``task`` and log its failure unless cancelled or ignored."""
        try:
            self.tasks.remove(task)
        except AttributeError:
            # __aexit__ already deleted self.tasks to break reference cycles.
            pass
        if task.cancelled():
            # task.exception() would raise CancelledError for a cancelled task.
            return
        exc = task.exception()
        if exc is not None and not isinstance(exc, ignore_exceptions):
            logger.exception("Background task failed", exc_info=exc)

    def create_task(
        self,
        coro: Coroutine[Any, Any, T],
        ignore_exceptions: tuple[type[BaseException], ...] = (),
    ) -> asyncio.Task[T]:
        """Start ``coro`` as a task tracked by this group and return it.

        Failures are logged unless they are instances of the classes in
        ``ignore_exceptions``.
        """
        task = asyncio.create_task(coro)
        self.tasks.add(task)
        task.add_done_callback(
            partial(self._create_task_done_callback, ignore_exceptions)
        )
        return task

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        tasks = self.tasks
        # break reference cycles between tasks and task group
        del self.tasks
        # cancel all tasks
        if self.cancel:
            for task in tasks:
                task.cancel()
        # wait for all tasks
        if self.wait:
            await asyncio.gather(*tasks, return_exceptions=True)
151
+
152
+
153
def to_aiter(*args: T) -> AsyncIterator[T]:
    """Return an async iterator that yields each positional argument in order."""

    async def _yield_each():
        for value in args:
            yield value

    iterator = _yield_each()
    return iterator
159
+
160
+
161
# Type of the object wrapped by aclosing.
V = TypeVar("V")
162
+
163
+
164
class aclosing(Generic[V], AbstractAsyncContextManager):
    """Async context manager that awaits ``aclose()`` on the wrapped object at exit.

    Entering returns the wrapped object unchanged; exiting (normally or via an
    exception) awaits its ``aclose()``. For example::

        async with aclosing(<module>.fetch(<arguments>)) as agen:
            <block>

    behaves like::

        agen = <module>.fetch(<arguments>)
        try:
            <block>
        finally:
            await agen.aclose()
    """

    def __init__(self, thing: V):
        self.thing = thing

    async def __aenter__(self) -> V:
        wrapped = self.thing
        return wrapped

    async def __aexit__(self, *exc_info):
        close = self.thing.aclose
        await close()
191
+
192
+
193
async def aclosing_aiter(aiter: AsyncIterator[T]) -> AsyncIterator[T]:
    """Re-yield every item of ``aiter`` inside a context that cleans it up.

    If ``aiter`` is itself an async context manager, it is entered directly;
    otherwise it is wrapped in ``aclosing`` so ``aclose()`` runs at the end.
    """
    manager = aiter if hasattr(aiter, "__aenter__") else aclosing(aiter)
    async with manager:
        async for item in aiter:
            yield item
File without changes
File without changes
@@ -0,0 +1,67 @@
1
+ from typing import NotRequired, TypedDict
2
+
3
+ from starlette.authentication import (
4
+ AuthCredentials,
5
+ AuthenticationBackend,
6
+ AuthenticationError,
7
+ BaseUser,
8
+ SimpleUser,
9
+ )
10
+ from starlette.requests import HTTPConnection
11
+
12
+ from langgraph_api.auth.langsmith.client import auth_client
13
+ from langgraph_api.config import (
14
+ LANGSMITH_AUTH_VERIFY_TENANT_ID,
15
+ LANGSMITH_TENANT_ID,
16
+ )
17
+
18
+
19
class AuthDict(TypedDict):
    """Expected shape of the JSON body returned by the LangSmith auth endpoints.

    This is a static type hint only; the response is not validated at runtime.
    """

    organization_id: str
    # Compared with LANGSMITH_TENANT_ID when tenant verification applies.
    tenant_id: str
    # May be absent; read with .get() and used as the SimpleUser's username.
    user_id: NotRequired[str]
    user_email: NotRequired[str]
24
+
25
+
26
+ class LangsmithAuthBackend(AuthenticationBackend):
27
+ async def authenticate(
28
+ self, conn: HTTPConnection
29
+ ) -> tuple[AuthCredentials, BaseUser] | None:
30
+ headers = [
31
+ ("Authorization", conn.headers.get("Authorization")),
32
+ ("X-Tenant-Id", conn.headers.get("x-tenant-id")),
33
+ ("X-Api-Key", conn.headers.get("x-api-key")),
34
+ ("X-Service-Key", conn.headers.get("x-service-key")),
35
+ ]
36
+ if not any(h[1] for h in headers):
37
+ raise AuthenticationError("Missing authentication headers")
38
+ async with auth_client() as auth:
39
+ if not LANGSMITH_AUTH_VERIFY_TENANT_ID and not conn.headers.get(
40
+ "x-api-key"
41
+ ):
42
+ # when LANGSMITH_AUTH_VERIFY_TENANT_ID is false, we allow
43
+ # any valid bearer token to pass through
44
+ # api key auth is always required to match the tenant id
45
+ res = await auth.get(
46
+ "/auth/verify", headers=[h for h in headers if h[1] is not None]
47
+ )
48
+ else:
49
+ res = await auth.get(
50
+ "/auth/public", headers=[h for h in headers if h[1] is not None]
51
+ )
52
+ if res.status_code == 401:
53
+ raise AuthenticationError("Invalid token")
54
+ elif res.status_code == 403:
55
+ raise AuthenticationError("Forbidden")
56
+ else:
57
+ res.raise_for_status()
58
+ auth_dict: AuthDict = res.json()
59
+
60
+ # If tenant id verification is disabled, the bearer token requests
61
+ # are not required to match the tenant id. Api key requests are
62
+ # always required to match the tenant id.
63
+ if LANGSMITH_AUTH_VERIFY_TENANT_ID or conn.headers.get("x-api-key"):
64
+ if auth_dict["tenant_id"] != LANGSMITH_TENANT_ID:
65
+ raise AuthenticationError("Invalid tenant ID")
66
+
67
+ return AuthCredentials(["authenticated"]), SimpleUser(auth_dict.get("user_id"))
@@ -0,0 +1,145 @@
1
+ from collections.abc import AsyncGenerator
2
+ from contextlib import asynccontextmanager
3
+ from typing import Any
4
+
5
+ import httpx
6
+ from httpx._types import HeaderTypes, QueryParamTypes, RequestData
7
+ from tenacity import retry
8
+ from tenacity.retry import retry_if_exception
9
+ from tenacity.stop import stop_after_attempt
10
+ from tenacity.wait import wait_exponential_jitter
11
+
12
+ from langgraph_api.config import LANGSMITH_AUTH_ENDPOINT
13
+
14
# Shared auth client. Only annotated here, not bound: it is assigned by
# initialize_auth_client(), and readers catch NameError until then.
_client: "JsonHttpClient"
15
+
16
+
17
def is_retriable_error(exception: Exception) -> bool:
    """Return True for transport failures and HTTP 5xx status errors."""
    if isinstance(exception, httpx.HTTPStatusError):
        return exception.response.status_code >= 500
    return isinstance(exception, httpx.TransportError)
25
+
26
+
27
# Retry policy for JsonHttpClient.get/post: at most 3 attempts with jittered
# exponential backoff. Only errors accepted by is_retriable_error are retried,
# and the last error is re-raised as-is once attempts run out (reraise=True).
retry_httpx = retry(
    reraise=True,
    retry=retry_if_exception(is_retriable_error),
    wait=wait_exponential_jitter(),
    stop=stop_after_attempt(3),
)
33
+
34
+
35
class JsonHttpClient:
    """HTTPX client for JSON requests, with retries.

    ``get``/``post`` are wrapped in ``retry_httpx`` and retry transport errors
    and 5xx responses. ``_get``/``_post`` make a single attempt and return
    whatever response the server sent.
    """

    def __init__(self, client: httpx.AsyncClient) -> None:
        """Wrap an existing ``httpx.AsyncClient``."""
        self.client = client

    @staticmethod
    def _raise_if_server_error(response: httpx.Response) -> httpx.Response:
        """Raise ``httpx.HTTPStatusError`` for 5xx responses; else return ``response``.

        httpx does not raise on error statuses by itself. Without this, the
        ``HTTPStatusError`` branch of ``is_retriable_error`` could never
        trigger a retry.
        """
        if response.status_code > 499:
            response.raise_for_status()
        return response

    async def _get(
        self,
        path: str,
        *,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
    ) -> httpx.Response:
        """Single GET attempt; returns the response whatever its status."""
        return await self.client.get(path, params=params, headers=headers)

    @retry_httpx
    async def get(
        self,
        path: str,
        *,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
    ) -> httpx.Response:
        """GET ``path``, retrying transport errors and 5xx responses.

        Non-5xx responses (including 4xx) are returned for the caller to check.

        Raises:
            httpx.HTTPStatusError: if the server still answers 5xx on the
                final attempt.
        """
        response = await self.client.get(path, params=params, headers=headers)
        return self._raise_if_server_error(response)

    async def _post(
        self,
        path: str,
        *,
        data: RequestData | None = None,
        json: Any | None = None,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
    ) -> httpx.Response:
        """Single POST attempt; returns the response whatever its status."""
        return await self.client.post(
            path, data=data, json=json, params=params, headers=headers
        )

    @retry_httpx
    async def post(
        self,
        path: str,
        *,
        data: RequestData | None = None,
        json: Any | None = None,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
    ) -> httpx.Response:
        """POST to ``path``, retrying transport errors and 5xx responses.

        Non-5xx responses (including 4xx) are returned for the caller to check.

        Raises:
            httpx.HTTPStatusError: if the server still answers 5xx on the
                final attempt.
        """
        response = await self.client.post(
            path, data=data, json=json, params=params, headers=headers
        )
        return self._raise_if_server_error(response)
87
+
88
+
89
def create_client() -> JsonHttpClient:
    """Build a new JsonHttpClient pointed at LANGSMITH_AUTH_ENDPOINT."""
    transport = httpx.AsyncHTTPTransport(
        retries=5,  # this applies only to ConnectError, ConnectTimeout
        limits=httpx.Limits(max_keepalive_connections=40, keepalive_expiry=240.0),
    )
    http_client = httpx.AsyncClient(
        transport=transport,
        timeout=httpx.Timeout(2.0),
        base_url=LANGSMITH_AUTH_ENDPOINT,
    )
    return JsonHttpClient(http_client)
104
+
105
+
106
async def close_auth_client() -> None:
    """Close the shared auth client, doing nothing if it was never created."""
    try:
        await _client.client.aclose()
    except NameError:
        # _client is only annotated at import time; until
        # initialize_auth_client() assigns it there is nothing to close.
        return
113
+
114
+
115
async def initialize_auth_client() -> None:
    """(Re)create the shared auth client, closing any previous one first."""
    global _client
    await close_auth_client()
    _client = create_client()
120
+
121
+
122
@asynccontextmanager
async def auth_client() -> AsyncGenerator[JsonHttpClient, None]:
    """Yield a JsonHttpClient for the LangSmith auth service.

    For a ``http://localhost`` endpoint, a throwaway client is created and
    closed around each use. Otherwise the shared module-level client is
    yielded, (re)initialized first if it is missing or closed.
    """
    if LANGSMITH_AUTH_ENDPOINT.startswith("http://localhost"):
        # pytest does something funny with event loops,
        # so we can't use a global pool for tests
        scoped = create_client()
        try:
            yield scoped
        finally:
            await scoped.client.aclose()
        return

    try:
        needs_init = _client.client.is_closed
    except NameError:
        # Shared client has never been created.
        needs_init = True
    if needs_init:
        await initialize_auth_client()
    yield _client