langgraph-api 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langgraph-api might be problematic. Click here for more details.
- LICENSE +93 -0
- langgraph_api/__init__.py +0 -0
- langgraph_api/api/__init__.py +63 -0
- langgraph_api/api/assistants.py +326 -0
- langgraph_api/api/meta.py +71 -0
- langgraph_api/api/openapi.py +32 -0
- langgraph_api/api/runs.py +463 -0
- langgraph_api/api/store.py +116 -0
- langgraph_api/api/threads.py +263 -0
- langgraph_api/asyncio.py +201 -0
- langgraph_api/auth/__init__.py +0 -0
- langgraph_api/auth/langsmith/__init__.py +0 -0
- langgraph_api/auth/langsmith/backend.py +67 -0
- langgraph_api/auth/langsmith/client.py +145 -0
- langgraph_api/auth/middleware.py +41 -0
- langgraph_api/auth/noop.py +14 -0
- langgraph_api/cli.py +209 -0
- langgraph_api/config.py +70 -0
- langgraph_api/cron_scheduler.py +60 -0
- langgraph_api/errors.py +52 -0
- langgraph_api/graph.py +314 -0
- langgraph_api/http.py +168 -0
- langgraph_api/http_logger.py +89 -0
- langgraph_api/js/.gitignore +2 -0
- langgraph_api/js/build.mts +49 -0
- langgraph_api/js/client.mts +849 -0
- langgraph_api/js/global.d.ts +6 -0
- langgraph_api/js/package.json +33 -0
- langgraph_api/js/remote.py +673 -0
- langgraph_api/js/server_sent_events.py +126 -0
- langgraph_api/js/src/graph.mts +88 -0
- langgraph_api/js/src/hooks.mjs +12 -0
- langgraph_api/js/src/parser/parser.mts +443 -0
- langgraph_api/js/src/parser/parser.worker.mjs +12 -0
- langgraph_api/js/src/schema/types.mts +2136 -0
- langgraph_api/js/src/schema/types.template.mts +74 -0
- langgraph_api/js/src/utils/importMap.mts +85 -0
- langgraph_api/js/src/utils/pythonSchemas.mts +28 -0
- langgraph_api/js/src/utils/serde.mts +21 -0
- langgraph_api/js/tests/api.test.mts +1566 -0
- langgraph_api/js/tests/compose-postgres.yml +56 -0
- langgraph_api/js/tests/graphs/.gitignore +1 -0
- langgraph_api/js/tests/graphs/agent.mts +127 -0
- langgraph_api/js/tests/graphs/error.mts +17 -0
- langgraph_api/js/tests/graphs/langgraph.json +8 -0
- langgraph_api/js/tests/graphs/nested.mts +44 -0
- langgraph_api/js/tests/graphs/package.json +7 -0
- langgraph_api/js/tests/graphs/weather.mts +57 -0
- langgraph_api/js/tests/graphs/yarn.lock +159 -0
- langgraph_api/js/tests/parser.test.mts +870 -0
- langgraph_api/js/tests/utils.mts +17 -0
- langgraph_api/js/yarn.lock +1340 -0
- langgraph_api/lifespan.py +41 -0
- langgraph_api/logging.py +121 -0
- langgraph_api/metadata.py +101 -0
- langgraph_api/models/__init__.py +0 -0
- langgraph_api/models/run.py +229 -0
- langgraph_api/patch.py +42 -0
- langgraph_api/queue.py +245 -0
- langgraph_api/route.py +118 -0
- langgraph_api/schema.py +190 -0
- langgraph_api/serde.py +124 -0
- langgraph_api/server.py +48 -0
- langgraph_api/sse.py +118 -0
- langgraph_api/state.py +67 -0
- langgraph_api/stream.py +289 -0
- langgraph_api/utils.py +60 -0
- langgraph_api/validation.py +141 -0
- langgraph_api-0.0.1.dist-info/LICENSE +93 -0
- langgraph_api-0.0.1.dist-info/METADATA +26 -0
- langgraph_api-0.0.1.dist-info/RECORD +86 -0
- langgraph_api-0.0.1.dist-info/WHEEL +4 -0
- langgraph_api-0.0.1.dist-info/entry_points.txt +3 -0
- langgraph_license/__init__.py +0 -0
- langgraph_license/middleware.py +21 -0
- langgraph_license/validation.py +11 -0
- langgraph_storage/__init__.py +0 -0
- langgraph_storage/checkpoint.py +94 -0
- langgraph_storage/database.py +190 -0
- langgraph_storage/ops.py +1523 -0
- langgraph_storage/queue.py +108 -0
- langgraph_storage/retry.py +27 -0
- langgraph_storage/store.py +28 -0
- langgraph_storage/ttl_dict.py +54 -0
- logging.json +22 -0
- openapi.json +4304 -0
langgraph_api/queue.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from contextlib import AsyncExitStack
|
|
3
|
+
from datetime import UTC, datetime
|
|
4
|
+
from random import random
|
|
5
|
+
from typing import cast
|
|
6
|
+
|
|
7
|
+
import structlog
|
|
8
|
+
from langgraph.pregel.debug import CheckpointPayload
|
|
9
|
+
|
|
10
|
+
from langgraph_api.config import BG_JOB_NO_DELAY, STATS_INTERVAL_SECS
|
|
11
|
+
from langgraph_api.errors import (
|
|
12
|
+
UserInterrupt,
|
|
13
|
+
UserRollback,
|
|
14
|
+
)
|
|
15
|
+
from langgraph_api.http import get_http_client
|
|
16
|
+
from langgraph_api.js.remote import RemoteException
|
|
17
|
+
from langgraph_api.metadata import incr_runs
|
|
18
|
+
from langgraph_api.schema import Run
|
|
19
|
+
from langgraph_api.stream import (
|
|
20
|
+
astream_state,
|
|
21
|
+
consume,
|
|
22
|
+
)
|
|
23
|
+
from langgraph_api.utils import AsyncConnectionProto
|
|
24
|
+
from langgraph_storage.database import connect
|
|
25
|
+
from langgraph_storage.ops import Runs, Threads
|
|
26
|
+
from langgraph_storage.retry import RETRIABLE_EXCEPTIONS
|
|
27
|
+
|
|
28
|
+
# Per-run retry ceiling: a run claimed for an attempt beyond this is failed.
MAX_RETRY_ATTEMPTS = 3
# Seconds shutdown waits for cancelled workers before giving up on them.
SHUTDOWN_GRACE_PERIOD_SECS = 5

# Tasks of the background runs that are currently executing.
WORKERS: set[asyncio.Task] = set()

logger = structlog.stdlib.get_logger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def ms(after: datetime, before: datetime) -> int:
    """Return the milliseconds elapsed from ``before`` to ``after``.

    Fractional milliseconds are dropped by ``int()`` (truncation toward zero).
    """
    elapsed = after - before
    return int(elapsed.total_seconds() * 1000)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
async def queue(concurrency: int, timeout: float):
    """Continuously claim pending runs and execute each in a worker task.

    Runs until cancelled. A semaphore caps the number of concurrently
    executing workers at ``concurrency``. For each claimed run, the exit stack
    holding its DB connection and claim is handed to ``worker``, which closes
    it when the run ends. On shutdown all in-flight workers are cancelled and
    given ``SHUTDOWN_GRACE_PERIOD_SECS`` seconds to finish.

    Args:
        concurrency: Maximum number of runs executing at the same time.
        timeout: Per-run execution timeout in seconds, forwarded to ``worker``.
    """
    loop = asyncio.get_running_loop()
    # loop.time() returns a float, so this is a float timestamp.
    last_stats_secs: float | None = None
    semaphore = asyncio.Semaphore(concurrency)

    def cleanup(task: asyncio.Task):
        # Done-callback for a worker: free its slot and log unexpected failures.
        WORKERS.remove(task)
        semaphore.release()
        try:
            if exc := task.exception():
                logger.exception("Background worker failed", exc_info=exc)
        except asyncio.CancelledError:
            pass

    await logger.ainfo(f"Starting {concurrency} background workers")
    try:
        while True:
            # Track what this iteration still owns, so the error handler only
            # releases a slot that was really acquired and only closes a stack
            # that was not already handed to a worker (or closed).
            acquired = False
            stack: AsyncExitStack | None = None
            try:
                if calc_stats := (
                    last_stats_secs is None
                    or loop.time() - last_stats_secs > STATS_INTERVAL_SECS
                ):
                    last_stats_secs = loop.time()
                    active = len(WORKERS)
                    await logger.ainfo(
                        "Worker stats",
                        max=concurrency,
                        available=concurrency - active,
                        active=active,
                    )
                await semaphore.acquire()
                acquired = True
                stack = AsyncExitStack()
                conn = await stack.enter_async_context(connect())
                if calc_stats:
                    stats = await Runs.stats(conn)
                    await logger.ainfo("Queue stats", **stats)
                if tup := await stack.enter_async_context(Runs.next(conn)):
                    task = asyncio.create_task(
                        worker(timeout, stack, conn, *tup),
                        name=f"run-{tup[0]['run_id']}-attempt-{tup[1]}",
                    )
                    # The worker now owns the exit stack, and its semaphore
                    # slot is released by ``cleanup`` when it finishes.
                    stack = None
                    acquired = False
                    task.add_done_callback(cleanup)
                    WORKERS.add(task)
                else:
                    # Nothing to run: give back the slot and connection, then
                    # back off before polling again.
                    acquired = False
                    semaphore.release()
                    idle_stack, stack = stack, None
                    await idle_stack.aclose()
                    await asyncio.sleep(0 if BG_JOB_NO_DELAY else random())
            except Exception as exc:
                # keep trying to run the scheduler indefinitely
                logger.exception("Background worker scheduler failed", exc_info=exc)
                if acquired:
                    semaphore.release()
                if stack is not None:
                    try:
                        await stack.aclose()
                    except Exception as close_exc:
                        # Never let cleanup failures stop the scheduler loop.
                        logger.exception(
                            "Failed to release scheduler resources",
                            exc_info=close_exc,
                        )
                await asyncio.sleep(0 if BG_JOB_NO_DELAY else random())
    finally:
        logger.info("Shutting down background workers")
        for task in WORKERS:
            task.cancel()
        try:
            await asyncio.wait_for(
                asyncio.gather(*WORKERS, return_exceptions=True),
                SHUTDOWN_GRACE_PERIOD_SECS,
            )
        except TimeoutError:
            # Don't let a slow worker replace the exception (usually the
            # cancellation) that is ending the scheduler.
            logger.warning(
                "Background workers did not stop within the grace period",
                grace_period_secs=SHUTDOWN_GRACE_PERIOD_SECS,
            )
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
async def worker(
    timeout: float,
    exit: AsyncExitStack,
    conn: AsyncConnectionProto,
    run: Run,
    attempt: int,
):
    """Execute one claimed background run and record how it ended.

    Args:
        timeout: Seconds the run may execute before it is marked "timeout".
        exit: Exit stack owned by this worker; it is closed when the run ends.
        conn: Database connection used for every status update.
        run: The run to execute. Its ``kwargs["webhook"]`` entry is popped.
        attempt: Attempt number for this run (1 on the first attempt).

    Exceptions in ``RETRIABLE_EXCEPTIONS`` and ``asyncio.CancelledError``
    propagate, so the run becomes available to another worker. Every other
    outcome is recorded on the run and its thread, then POSTed to the run's
    webhook if one was configured.
    """
    run_id = run["run_id"]
    if attempt == 1:
        incr_runs()
    async with Runs.enter(run_id) as done, exit:
        temporary = run["kwargs"].get("temporary", False)
        webhook = run["kwargs"].pop("webhook", None)
        checkpoint: CheckpointPayload | None = None
        exception: Exception | None = None
        status: str | None = None
        run_started_at = datetime.now(UTC)
        run_created_at = run["created_at"].isoformat()

        def base_log_fields() -> dict:
            # Structured-log fields shared by every log line of this attempt.
            return {
                "run_id": str(run_id),
                "run_attempt": attempt,
                "run_created_at": run_created_at,
                "run_started_at": run_started_at.isoformat(),
            }

        def ended_log_fields() -> dict:
            # Shared fields plus end time and duration. One UTC-aware
            # timestamp feeds both, so they agree with each other and are in
            # the same timezone as ``run_started_at``.
            run_ended_at = datetime.now(UTC)
            return {
                **base_log_fields(),
                "run_ended_at": run_ended_at.isoformat(),
                "run_exec_ms": ms(run_ended_at, run_started_at),
            }

        await logger.ainfo(
            "Starting background run",
            **base_log_fields(),
            run_queue_ms=ms(run_started_at, run["created_at"]),
        )

        def on_checkpoint(checkpoint_arg: CheckpointPayload):
            nonlocal checkpoint
            checkpoint = checkpoint_arg

        try:
            if attempt > MAX_RETRY_ATTEMPTS:
                raise RuntimeError(f"Run {run['run_id']} exceeded max attempts")
            # Temporary threads are deleted afterwards, so their latest
            # checkpoint is never needed.
            stream_kwargs = {} if temporary else {"on_checkpoint": on_checkpoint}
            stream = astream_state(
                AsyncExitStack(), conn, cast(Run, run), attempt, done, **stream_kwargs
            )
            await asyncio.wait_for(consume(stream, run_id), timeout)
            await logger.ainfo("Background run succeeded", **ended_log_fields())
            status = "success"
            await Runs.set_status(conn, run_id, "success")
        except TimeoutError as e:
            exception = e
            status = "timeout"
            await logger.awarning("Background run timed out", **ended_log_fields())
            await Runs.set_status(conn, run_id, "timeout")
        except UserRollback as e:
            exception = e
            status = "rollback"
            await logger.ainfo("Background run rolled back", **ended_log_fields())
            await Runs.delete(conn, run_id, thread_id=run["thread_id"])
        except UserInterrupt as e:
            exception = e
            status = "interrupted"
            await logger.ainfo("Background run interrupted", **ended_log_fields())
            await Runs.set_status(conn, run_id, "interrupted")
        except RETRIABLE_EXCEPTIONS:
            await logger.awarning(
                "Background run failed, will retry",
                exc_info=True,
                **ended_log_fields(),
            )
            # Note we re-raise here, thus marking the run
            # as available to be picked up by another worker
            raise
        except Exception as exc:
            exception = exc
            status = "error"
            await logger.aexception(
                "Background run failed",
                exc_info=not isinstance(exc, RemoteException),
                **ended_log_fields(),
            )
            await Runs.set_status(conn, run_id, "error")
        # delete or set status of thread
        if temporary:
            await Threads.delete(conn, run["thread_id"])
        else:
            await Threads.set_status(conn, run["thread_id"], checkpoint, exception)
        if webhook:
            # TODO add error, values to webhook payload
            # TODO add retries for webhook calls
            try:
                await get_http_client().post(
                    webhook, json={**run, "status": status}, total_timeout=5
                )
            except Exception as e:
                await logger.awarning("Failed to send webhook", exc_info=e)
        # Note we don't handle asyncio.CancelledError here, as we want to
        # let it bubble up and rollback db transaction, thus marking the run
        # as available to be picked up by another worker
|
langgraph_api/route.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import inspect
|
|
3
|
+
import typing
|
|
4
|
+
|
|
5
|
+
import jsonschema_rs
|
|
6
|
+
import orjson
|
|
7
|
+
from starlette._exception_handler import wrap_app_handling_exceptions
|
|
8
|
+
from starlette._utils import is_async_callable
|
|
9
|
+
from starlette.concurrency import run_in_threadpool
|
|
10
|
+
from starlette.middleware import Middleware
|
|
11
|
+
from starlette.requests import Request
|
|
12
|
+
from starlette.responses import JSONResponse, Response
|
|
13
|
+
from starlette.routing import Route, compile_path, get_name
|
|
14
|
+
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
15
|
+
|
|
16
|
+
from langgraph_api.serde import json_dumpb
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def api_request_response(
    func: typing.Callable[[Request], typing.Awaitable[Response] | Response],
) -> ASGIApp:
    """
    Adapt a ``func(request) -> response`` handler into an ASGI application.

    Each incoming connection is wrapped in an ``ApiRequest``. Coroutine
    handlers are awaited directly; plain functions run via
    ``run_in_threadpool``. The handling app is wrapped with Starlette's
    ``wrap_app_handling_exceptions``.
    """

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        request = ApiRequest(scope, receive, send)

        async def respond(scope: Scope, receive: Receive, send: Send) -> None:
            response = (
                await func(request)
                if is_async_callable(func)
                else await run_in_threadpool(func, request)
            )
            await response(scope, receive, send)

        guarded = wrap_app_handling_exceptions(respond, request)
        await guarded(scope, receive, send)

    return app
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ApiResponse(JSONResponse):
    """JSON response whose body bytes are produced by ``json_dumpb``."""

    def render(self, content: typing.Any) -> bytes:
        encoded = json_dumpb(content)
        return encoded
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _json_loads(
    content: bytearray, schema: jsonschema_rs.Draft4Validator | None
) -> typing.Any:
    """Decode ``content`` as JSON, validating it with ``schema`` when given.

    Any validation error raised by ``schema.validate`` propagates to the caller.
    """
    parsed = orjson.loads(content)
    if schema is None:
        return parsed
    schema.validate(parsed)
    return parsed
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class ApiRequest(Request):
    """Request whose body is buffered into a ``bytearray`` and whose JSON
    decoding (with optional schema validation) runs via ``run_in_threadpool``.
    """

    async def body(self) -> bytearray:
        """Read the whole request body once, cache it, and return it."""
        if not hasattr(self, "_body"):
            chunks = bytearray()
            async for chunk in self.stream():
                chunks.extend(chunk)
            self._body = chunks
        return self._body

    async def json(
        self, schema: jsonschema_rs.Draft4Validator | None = None
    ) -> typing.Any:
        """Decode the body as JSON, validating it against ``schema`` if given.

        ``schema`` defaults to ``None`` so the signature stays compatible with
        ``Request.json()``. The decoded value is cached by the first call;
        later calls return it without re-validating, whatever ``schema``
        they pass.
        """
        if not hasattr(self, "_json"):
            body = await self.body()
            self._json = await run_in_threadpool(_json_loads, body, schema)
        return self._json
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# ApiRoute uses our custom ApiRequest class to handle requests.
class ApiRoute(Route):
    """Route whose function/method endpoints are adapted by ``api_request_response``.

    Function and method endpoints therefore receive an ``ApiRequest``; any
    other endpoint (e.g. a class) is mounted as a raw ASGI app.
    """

    def __init__(
        self,
        path: str,
        endpoint: typing.Callable[..., typing.Any],
        *,
        methods: list[str] | None = None,
        name: str | None = None,
        include_in_schema: bool = True,
        middleware: typing.Sequence[Middleware] | None = None,
    ) -> None:
        assert path.startswith("/"), "Routed paths must start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = name if name is not None else get_name(endpoint)
        self.include_in_schema = include_in_schema

        # Peel off functools.partial layers to inspect the underlying callable.
        target = endpoint
        while isinstance(target, functools.partial):
            target = target.func

        if not (inspect.isfunction(target) or inspect.ismethod(target)):
            # Not a plain function or method (e.g. a class): mount it as ASGI.
            self.app = endpoint
        else:
            # Function or method: adapt it as `func(request) -> response`.
            self.app = api_request_response(endpoint)
            if methods is None:
                methods = ["GET"]

        if middleware is not None:
            # Wrap in reverse order so the first middleware listed is outermost.
            for mw_cls, mw_args, mw_kwargs in reversed(middleware):
                self.app = mw_cls(app=self.app, *mw_args, **mw_kwargs)  # noqa: B026

        if methods is None:
            self.methods = None
        else:
            allowed = {verb.upper() for verb in methods}
            # Routes that accept GET also accept HEAD.
            if "GET" in allowed:
                allowed.add("HEAD")
            self.methods = allowed

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        # https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
        # Store this route's path template in the scope before dispatching.
        scope["route"] = self.path
        return await super().handle(scope, receive, send)
|
langgraph_api/schema.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
from typing import Any, Literal, Optional, TypedDict
|
|
4
|
+
from uuid import UUID
|
|
5
|
+
|
|
6
|
+
from langgraph_api.serde import Fragment
|
|
7
|
+
|
|
8
|
+
MetadataInput = dict[str, Any] | None
|
|
9
|
+
MetadataValue = dict[str, Any]
|
|
10
|
+
|
|
11
|
+
RunStatus = Literal["pending", "error", "success", "timeout", "interrupted"]
|
|
12
|
+
|
|
13
|
+
ThreadStatus = Literal["idle", "busy", "interrupted", "error"]
|
|
14
|
+
|
|
15
|
+
StreamMode = Literal["values", "messages", "updates", "events", "debug", "custom"]
|
|
16
|
+
|
|
17
|
+
MultitaskStrategy = Literal["reject", "rollback", "interrupt", "enqueue"]
|
|
18
|
+
|
|
19
|
+
OnConflictBehavior = Literal["raise", "do_nothing"]
|
|
20
|
+
|
|
21
|
+
OnCompletion = Literal["delete", "keep"]
|
|
22
|
+
|
|
23
|
+
IfNotExists = Literal["create", "reject"]
|
|
24
|
+
|
|
25
|
+
All = Literal["*"]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Config(TypedDict, total=False):
    """Run configuration; every key is optional (``total=False``)."""

    tags: list[str]
    """
    Tags applied to this call and to any sub-calls (e.g. a Chain calling an
    LLM). Useful for filtering calls.
    """

    recursion_limit: int
    """
    Maximum number of times a call can recurse. Defaults to 25 when omitted.
    """

    configurable: dict[str, Any]
    """
    Runtime values for attributes previously made configurable on this
    Runnable, or its sub-Runnables, via .configurable_fields() or
    .configurable_alternatives(). See .output_schema() for a description of
    the attributes that have been made configurable.
    """
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class Checkpoint(TypedDict):
|
|
50
|
+
thread_id: str
|
|
51
|
+
checkpoint_ns: str
|
|
52
|
+
checkpoint_id: str | None
|
|
53
|
+
checkpoint_map: dict[str, Any] | None
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class GraphSchema(TypedDict):
    """Graph model: a graph's ID plus the schemas of its state and config."""

    graph_id: str
    """Identifier of the graph."""
    state_schema: dict
    """Schema describing the graph's state."""
    config_schema: dict
    """Schema describing the graph's config."""
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class Assistant(TypedDict):
    """Assistant model: a configured, versioned use of a graph."""

    assistant_id: UUID
    """Unique identifier of the assistant."""
    graph_id: str
    """Identifier of the graph this assistant uses."""
    config: Config
    """Configuration of the assistant."""
    created_at: datetime
    """When the assistant was created."""
    updated_at: datetime
    """When the assistant was last updated."""
    metadata: Fragment
    """Metadata attached to the assistant."""
    version: int
    """Version number of the assistant."""
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class Thread(TypedDict):
    """Thread model."""

    thread_id: UUID
    """Unique identifier of the thread."""
    created_at: datetime
    """When the thread was created."""
    updated_at: datetime
    """When the thread was last updated."""
    metadata: Fragment
    """Metadata attached to the thread."""
    config: Fragment
    """Config of the thread."""
    status: ThreadStatus
    """Thread status: one of 'idle', 'busy', 'interrupted', 'error'."""
    values: Fragment
    """Current state values of the thread."""
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class ThreadTask(TypedDict):
    """A task of a thread state, with any error and interrupts it produced."""

    id: str
    name: str
    error: str | None
    interrupts: list[dict]
    checkpoint: Checkpoint | None
    state: Optional["ThreadState"]  # forward reference, defined below
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class ThreadState(TypedDict):
    """Snapshot of a thread's state at one checkpoint."""

    values: dict[str, Any]
    """Values of the state."""
    next: Sequence[str]
    """Name of the node each task of this step will execute."""
    checkpoint: Checkpoint
    """Keys of this checkpoint. Pass this object to the /threads and /runs
    endpoints to resume execution or update state."""
    metadata: Fragment
    """Metadata of this state."""
    created_at: str | None
    """When this state was created."""
    parent_checkpoint: Checkpoint | None
    """Checkpoint this one descends from; absent for the root checkpoint."""
    tasks: Sequence[ThreadTask]
    """Tasks to execute in this step; already-attempted tasks may carry an error."""
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class Run(TypedDict):
    """Run model: one execution of an assistant on a thread."""

    run_id: UUID
    """The ID of the run."""
    thread_id: UUID
    """The ID of the thread."""
    assistant_id: UUID
    """The assistant that was used for this run."""
    created_at: datetime
    """The time the run was created."""
    updated_at: datetime
    """The last time the run was updated."""
    status: RunStatus
    """The status of the run. One of 'pending', 'error', 'success', 'timeout',
    'interrupted'."""
    metadata: Fragment
    """The run metadata."""
    kwargs: Fragment
    """The run kwargs."""
    multitask_strategy: MultitaskStrategy
    """Strategy to handle concurrent runs on the same thread."""
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class RunSend(TypedDict):
|
|
152
|
+
node: str
|
|
153
|
+
input: dict[str, Any] | None
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class RunCommand(TypedDict):
    """Command for a run: ``send`` targets, ``update`` values, and/or a ``resume`` value."""

    send: RunSend | Sequence[RunSend] | None
    update: dict[str, Any] | None
    resume: Any | None
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class Cron(TypedDict):
    """Cron model: a schedule that creates runs from a stored payload."""

    cron_id: UUID
    """Unique identifier of the cron."""
    thread_id: UUID | None
    """Identifier of the thread."""
    end_time: datetime | None
    """Date after which the cron stops running."""
    schedule: str
    """Schedule, in cron format."""
    created_at: datetime
    """When the cron was created."""
    updated_at: datetime
    """When the cron was last updated."""
    payload: Fragment
    """Run payload used when creating a new run."""
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class ThreadUpdateResponse(TypedDict):
    """Body returned in response to a thread update."""

    checkpoint: Checkpoint  # keys of the checkpoint included in the response
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
class QueueStats(TypedDict):
|
|
188
|
+
n_pending: int
|
|
189
|
+
max_age_secs: datetime | None
|
|
190
|
+
med_age_secs: datetime | None
|
langgraph_api/serde.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import asyncio
import pickle
import re
import uuid
from base64 import b64encode
from collections import deque
from datetime import timedelta, timezone
from decimal import Decimal
from ipaddress import (
    IPv4Address,
    IPv4Interface,
    IPv4Network,
    IPv6Address,
    IPv6Interface,
    IPv6Network,
)
from pathlib import Path
from re import Pattern
from typing import Any, NamedTuple
from zoneinfo import ZoneInfo

import orjson
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Fragment(NamedTuple):
    """Bytes that already hold encoded JSON.

    ``default`` emits them verbatim via ``orjson.Fragment``, and
    ``json_loads`` decodes ``buf`` directly.
    """

    buf: bytes
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def decimal_encoder(dec_value: Decimal) -> int | float:
|
|
30
|
+
"""
|
|
31
|
+
Encodes a Decimal as int of there's no exponent, otherwise float
|
|
32
|
+
|
|
33
|
+
This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
|
|
34
|
+
where a integer (but not int typed) is used. Encoding this as a float
|
|
35
|
+
results in failed round-tripping between encode and parse.
|
|
36
|
+
Our Id type is a prime example of this.
|
|
37
|
+
|
|
38
|
+
>>> decimal_encoder(Decimal("1.0"))
|
|
39
|
+
1.0
|
|
40
|
+
|
|
41
|
+
>>> decimal_encoder(Decimal("1"))
|
|
42
|
+
1
|
|
43
|
+
"""
|
|
44
|
+
if dec_value.as_tuple().exponent >= 0:
|
|
45
|
+
return int(dec_value)
|
|
46
|
+
else:
|
|
47
|
+
return float(dec_value)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def default(obj):
    """Convert ``obj`` into a value orjson can serialize.

    Used as orjson's ``default=`` hook, so it only handles types orjson does
    not serialize natively (https://github.com/ijl/orjson#serialize).
    Unrecognized objects fall through to ``None`` and are serialized as null.
    """
    # Only need to handle types that orjson doesn't serialize by default
    if isinstance(obj, Fragment):
        # Pre-encoded JSON: embed the bytes verbatim.
        return orjson.Fragment(obj.buf)
    if hasattr(obj, "model_dump") and callable(obj.model_dump):
        # Objects exposing model_dump() (e.g. Pydantic v2 models).
        return obj.model_dump()
    elif hasattr(obj, "dict") and callable(obj.dict):
        # Objects exposing dict() (e.g. Pydantic v1 models).
        return obj.dict()
    elif hasattr(obj, "_asdict") and callable(obj._asdict):
        # namedtuple-like objects.
        return obj._asdict()
    elif isinstance(obj, BaseException):
        return {"error": type(obj).__name__, "message": str(obj)}
    elif isinstance(obj, (set, frozenset, deque)):  # noqa: UP038
        return list(obj)
    elif isinstance(obj, (timezone, ZoneInfo)):  # noqa: UP038
        name = obj.tzname(None)
        if name is None and isinstance(obj, ZoneInfo):
            # Region zones (e.g. "America/New_York") have no single name
            # without a datetime, so tzname(None) is None; use the IANA key
            # instead of serializing null.
            name = obj.key
        return name
    elif isinstance(obj, timedelta):
        return obj.total_seconds()
    elif isinstance(obj, Decimal):
        return decimal_encoder(obj)
    elif isinstance(obj, uuid.UUID):
        return str(obj)
    elif isinstance(  # noqa: UP038
        obj,
        (
            IPv4Address,
            IPv4Interface,
            IPv4Network,
            IPv6Address,
            IPv6Interface,
            IPv6Network,
            Path,
        ),
    ):
        return str(obj)
    elif isinstance(obj, Pattern):
        return obj.pattern
    elif isinstance(obj, bytes | bytearray):
        return b64encode(obj).decode()
    return None
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# orjson flags: allow non-str dict keys and serialize numpy objects natively.
_option = orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# Matches a JSON-escaped NUL (the six bytes backslash-u-0-0-0-0) only when its
# backslash is a real escape, i.e. preceded by an even number (possibly zero)
# of backslashes. Group 1 keeps those preceding backslash pairs.
_ESCAPED_NUL = re.compile(rb"(?<!\\)((?:\\\\)*)\\u0000")


def _strip_escaped_nul(buf: bytes) -> bytes:
    """Remove escaped NUL characters from serialized JSON ``buf``.

    Escaped backslashes followed by the literal text "u0000" are left intact,
    so the output stays valid JSON.
    """
    if b"\\u0000" not in buf:
        # Fast path: nothing to strip.
        return buf
    return _ESCAPED_NUL.sub(rb"\1", buf)


def json_dumpb(obj) -> bytes:
    r"""Serialize ``obj`` to JSON bytes, using ``default`` for extra types.

    NUL characters inside strings (emitted by orjson as ``\u0000``) are
    removed, since the null unicode char is not allowed in our JSON.
    """
    return _strip_escaped_nul(orjson.dumps(obj, default=default, option=_option))
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def json_loads(content: bytes | Fragment | dict) -> Any:
    """Decode JSON ``content``.

    A ``Fragment`` is unwrapped to its bytes first; a dict is treated as
    already decoded and returned unchanged.
    """
    payload = content.buf if isinstance(content, Fragment) else content
    return payload if isinstance(payload, dict) else orjson.loads(payload)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
async def ajson_loads(content: bytes | Fragment) -> Any:
    """Run ``json_loads`` on ``content`` in a worker thread via ``asyncio.to_thread``."""
    decoded = await asyncio.to_thread(json_loads, content)
    return decoded
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class Serializer(JsonPlusSerializer):
    """``JsonPlusSerializer`` that falls back to pickle when serialization fails.

    NOTE(review): ``loads_typed`` unpickles any payload tagged "pickle";
    only feed it data this service wrote itself, never untrusted input.
    """

    def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
        """Serialize ``obj``; on ``TypeError`` return ``("pickle", pickle.dumps(obj))``."""
        try:
            return super().dumps_typed(obj)
        except TypeError:
            pickled = pickle.dumps(obj)
            return "pickle", pickled

    def loads_typed(self, data: tuple[str, bytes]) -> Any:
        """Deserialize ``data``, unpickling payloads tagged "pickle"."""
        if data[0] != "pickle":
            return super().loads_typed(data)
        return pickle.loads(data[1])
|