langgraph_api-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of langgraph-api might be problematic.

Files changed (86)
  1. LICENSE +93 -0
  2. langgraph_api/__init__.py +0 -0
  3. langgraph_api/api/__init__.py +63 -0
  4. langgraph_api/api/assistants.py +326 -0
  5. langgraph_api/api/meta.py +71 -0
  6. langgraph_api/api/openapi.py +32 -0
  7. langgraph_api/api/runs.py +463 -0
  8. langgraph_api/api/store.py +116 -0
  9. langgraph_api/api/threads.py +263 -0
  10. langgraph_api/asyncio.py +201 -0
  11. langgraph_api/auth/__init__.py +0 -0
  12. langgraph_api/auth/langsmith/__init__.py +0 -0
  13. langgraph_api/auth/langsmith/backend.py +67 -0
  14. langgraph_api/auth/langsmith/client.py +145 -0
  15. langgraph_api/auth/middleware.py +41 -0
  16. langgraph_api/auth/noop.py +14 -0
  17. langgraph_api/cli.py +209 -0
  18. langgraph_api/config.py +70 -0
  19. langgraph_api/cron_scheduler.py +60 -0
  20. langgraph_api/errors.py +52 -0
  21. langgraph_api/graph.py +314 -0
  22. langgraph_api/http.py +168 -0
  23. langgraph_api/http_logger.py +89 -0
  24. langgraph_api/js/.gitignore +2 -0
  25. langgraph_api/js/build.mts +49 -0
  26. langgraph_api/js/client.mts +849 -0
  27. langgraph_api/js/global.d.ts +6 -0
  28. langgraph_api/js/package.json +33 -0
  29. langgraph_api/js/remote.py +673 -0
  30. langgraph_api/js/server_sent_events.py +126 -0
  31. langgraph_api/js/src/graph.mts +88 -0
  32. langgraph_api/js/src/hooks.mjs +12 -0
  33. langgraph_api/js/src/parser/parser.mts +443 -0
  34. langgraph_api/js/src/parser/parser.worker.mjs +12 -0
  35. langgraph_api/js/src/schema/types.mts +2136 -0
  36. langgraph_api/js/src/schema/types.template.mts +74 -0
  37. langgraph_api/js/src/utils/importMap.mts +85 -0
  38. langgraph_api/js/src/utils/pythonSchemas.mts +28 -0
  39. langgraph_api/js/src/utils/serde.mts +21 -0
  40. langgraph_api/js/tests/api.test.mts +1566 -0
  41. langgraph_api/js/tests/compose-postgres.yml +56 -0
  42. langgraph_api/js/tests/graphs/.gitignore +1 -0
  43. langgraph_api/js/tests/graphs/agent.mts +127 -0
  44. langgraph_api/js/tests/graphs/error.mts +17 -0
  45. langgraph_api/js/tests/graphs/langgraph.json +8 -0
  46. langgraph_api/js/tests/graphs/nested.mts +44 -0
  47. langgraph_api/js/tests/graphs/package.json +7 -0
  48. langgraph_api/js/tests/graphs/weather.mts +57 -0
  49. langgraph_api/js/tests/graphs/yarn.lock +159 -0
  50. langgraph_api/js/tests/parser.test.mts +870 -0
  51. langgraph_api/js/tests/utils.mts +17 -0
  52. langgraph_api/js/yarn.lock +1340 -0
  53. langgraph_api/lifespan.py +41 -0
  54. langgraph_api/logging.py +121 -0
  55. langgraph_api/metadata.py +101 -0
  56. langgraph_api/models/__init__.py +0 -0
  57. langgraph_api/models/run.py +229 -0
  58. langgraph_api/patch.py +42 -0
  59. langgraph_api/queue.py +245 -0
  60. langgraph_api/route.py +118 -0
  61. langgraph_api/schema.py +190 -0
  62. langgraph_api/serde.py +124 -0
  63. langgraph_api/server.py +48 -0
  64. langgraph_api/sse.py +118 -0
  65. langgraph_api/state.py +67 -0
  66. langgraph_api/stream.py +289 -0
  67. langgraph_api/utils.py +60 -0
  68. langgraph_api/validation.py +141 -0
  69. langgraph_api-0.0.1.dist-info/LICENSE +93 -0
  70. langgraph_api-0.0.1.dist-info/METADATA +26 -0
  71. langgraph_api-0.0.1.dist-info/RECORD +86 -0
  72. langgraph_api-0.0.1.dist-info/WHEEL +4 -0
  73. langgraph_api-0.0.1.dist-info/entry_points.txt +3 -0
  74. langgraph_license/__init__.py +0 -0
  75. langgraph_license/middleware.py +21 -0
  76. langgraph_license/validation.py +11 -0
  77. langgraph_storage/__init__.py +0 -0
  78. langgraph_storage/checkpoint.py +94 -0
  79. langgraph_storage/database.py +190 -0
  80. langgraph_storage/ops.py +1523 -0
  81. langgraph_storage/queue.py +108 -0
  82. langgraph_storage/retry.py +27 -0
  83. langgraph_storage/store.py +28 -0
  84. langgraph_storage/ttl_dict.py +54 -0
  85. logging.json +22 -0
  86. openapi.json +4304 -0
langgraph_api/queue.py ADDED
@@ -0,0 +1,245 @@
+import asyncio
+from contextlib import AsyncExitStack
+from datetime import UTC, datetime
+from random import random
+from typing import cast
+
+import structlog
+from langgraph.pregel.debug import CheckpointPayload
+
+from langgraph_api.config import BG_JOB_NO_DELAY, STATS_INTERVAL_SECS
+from langgraph_api.errors import (
+    UserInterrupt,
+    UserRollback,
+)
+from langgraph_api.http import get_http_client
+from langgraph_api.js.remote import RemoteException
+from langgraph_api.metadata import incr_runs
+from langgraph_api.schema import Run
+from langgraph_api.stream import (
+    astream_state,
+    consume,
+)
+from langgraph_api.utils import AsyncConnectionProto
+from langgraph_storage.database import connect
+from langgraph_storage.ops import Runs, Threads
+from langgraph_storage.retry import RETRIABLE_EXCEPTIONS
+
+logger = structlog.stdlib.get_logger(__name__)
+
+WORKERS: set[asyncio.Task] = set()
+MAX_RETRY_ATTEMPTS = 3
+SHUTDOWN_GRACE_PERIOD_SECS = 5
+
+
+def ms(after: datetime, before: datetime) -> int:
+    return int((after - before).total_seconds() * 1000)
+
+
+async def queue(concurrency: int, timeout: float):
+    loop = asyncio.get_running_loop()
+    last_stats_secs: int | None = None
+    semaphore = asyncio.Semaphore(concurrency)
+
+    def cleanup(task: asyncio.Task):
+        WORKERS.remove(task)
+        semaphore.release()
+        try:
+            if exc := task.exception():
+                logger.exception("Background worker failed", exc_info=exc)
+        except asyncio.CancelledError:
+            pass
+
+    await logger.ainfo(f"Starting {concurrency} background workers")
+    try:
+        while True:
+            try:
+                if calc_stats := (
+                    last_stats_secs is None
+                    or loop.time() - last_stats_secs > STATS_INTERVAL_SECS
+                ):
+                    last_stats_secs = loop.time()
+                    active = len(WORKERS)
+                    await logger.ainfo(
+                        "Worker stats",
+                        max=concurrency,
+                        available=concurrency - active,
+                        active=active,
+                    )
+                await semaphore.acquire()
+                exit = AsyncExitStack()
+                conn = await exit.enter_async_context(connect())
+                if calc_stats:
+                    stats = await Runs.stats(conn)
+                    await logger.ainfo("Queue stats", **stats)
+                if tup := await exit.enter_async_context(Runs.next(conn)):
+                    task = asyncio.create_task(
+                        worker(timeout, exit, conn, *tup),
+                        name=f"run-{tup[0]['run_id']}-attempt-{tup[1]}",
+                    )
+                    task.add_done_callback(cleanup)
+                    WORKERS.add(task)
+                else:
+                    semaphore.release()
+                    await exit.aclose()
+                    await asyncio.sleep(0 if BG_JOB_NO_DELAY else random())
+            except Exception as exc:
+                # keep trying to run the scheduler indefinitely
+                logger.exception("Background worker scheduler failed", exc_info=exc)
+                semaphore.release()
+                await exit.aclose()
+                await asyncio.sleep(0 if BG_JOB_NO_DELAY else random())
+    finally:
+        logger.info("Shutting down background workers")
+        for task in WORKERS:
+            task.cancel()
+        await asyncio.wait_for(
+            asyncio.gather(*WORKERS, return_exceptions=True), SHUTDOWN_GRACE_PERIOD_SECS
+        )
+
+
+async def worker(
+    timeout: float,
+    exit: AsyncExitStack,
+    conn: AsyncConnectionProto,
+    run: Run,
+    attempt: int,
+):
+    run_id = run["run_id"]
+    if attempt == 1:
+        incr_runs()
+    async with Runs.enter(run_id) as done, exit:
+        temporary = run["kwargs"].get("temporary", False)
+        webhook = run["kwargs"].pop("webhook", None)
+        checkpoint: CheckpointPayload | None = None
+        exception: Exception | None = None
+        status: str | None = None
+        run_started_at = datetime.now(UTC)
+        run_created_at = run["created_at"].isoformat()
+        await logger.ainfo(
+            "Starting background run",
+            run_id=str(run_id),
+            run_attempt=attempt,
+            run_created_at=run_created_at,
+            run_started_at=run_started_at.isoformat(),
+            run_queue_ms=ms(run_started_at, run["created_at"]),
+        )
+
+        def on_checkpoint(checkpoint_arg: CheckpointPayload):
+            nonlocal checkpoint
+            checkpoint = checkpoint_arg
+
+        try:
+            if attempt > MAX_RETRY_ATTEMPTS:
+                raise RuntimeError(f"Run {run['run_id']} exceeded max attempts")
+            if temporary:
+                stream = astream_state(
+                    AsyncExitStack(), conn, cast(Run, run), attempt, done
+                )
+            else:
+                stream = astream_state(
+                    AsyncExitStack(),
+                    conn,
+                    cast(Run, run),
+                    attempt,
+                    done,
+                    on_checkpoint=on_checkpoint,
+                )
+            await asyncio.wait_for(consume(stream, run_id), timeout)
+            await logger.ainfo(
+                "Background run succeeded",
+                run_id=str(run_id),
+                run_attempt=attempt,
+                run_created_at=run_created_at,
+                run_started_at=run_started_at.isoformat(),
+                run_ended_at=datetime.now(UTC).isoformat(),
+                run_exec_ms=ms(datetime.now(UTC), run_started_at),
+            )
+            status = "success"
+            await Runs.set_status(conn, run_id, "success")
+        except TimeoutError as e:
+            exception = e
+            status = "timeout"
+            await logger.awarning(
+                "Background run timed out",
+                run_id=str(run_id),
+                run_attempt=attempt,
+                run_created_at=run_created_at,
+                run_started_at=run_started_at.isoformat(),
+                run_ended_at=datetime.now(UTC).isoformat(),
+                run_exec_ms=ms(datetime.now(UTC), run_started_at),
+            )
+            await Runs.set_status(conn, run_id, "timeout")
+        except UserRollback as e:
+            exception = e
+            status = "rollback"
+            await logger.ainfo(
+                "Background run rolled back",
+                run_id=str(run_id),
+                run_attempt=attempt,
+                run_created_at=run_created_at,
+                run_started_at=run_started_at.isoformat(),
+                run_ended_at=datetime.now(UTC).isoformat(),
+                run_exec_ms=ms(datetime.now(UTC), run_started_at),
+            )
+            await Runs.delete(conn, run_id, thread_id=run["thread_id"])
+        except UserInterrupt as e:
+            exception = e
+            status = "interrupted"
+            await logger.ainfo(
+                "Background run interrupted",
+                run_id=str(run_id),
+                run_attempt=attempt,
+                run_created_at=run_created_at,
+                run_started_at=run_started_at.isoformat(),
+                run_ended_at=datetime.now(UTC).isoformat(),
+                run_exec_ms=ms(datetime.now(UTC), run_started_at),
+            )
+            await Runs.set_status(conn, run_id, "interrupted")
+        except RETRIABLE_EXCEPTIONS as e:
+            exception = e
+            status = "retry"
+            await logger.awarning(
+                "Background run failed, will retry",
+                exc_info=True,
+                run_id=str(run_id),
+                run_attempt=attempt,
+                run_created_at=run_created_at,
+                run_started_at=run_started_at.isoformat(),
+                run_ended_at=datetime.now(UTC).isoformat(),
+                run_exec_ms=ms(datetime.now(UTC), run_started_at),
+            )
+            # Note we re-raise here, thus marking the run
+            # as available to be picked up by another worker
+            raise
+        except Exception as exc:
+            exception = exc
+            status = "error"
+            await logger.aexception(
+                "Background run failed",
+                exc_info=not isinstance(exc, RemoteException),
+                run_id=str(run_id),
+                run_attempt=attempt,
+                run_created_at=run_created_at,
+                run_started_at=run_started_at.isoformat(),
+                run_ended_at=datetime.now(UTC).isoformat(),
+                run_exec_ms=ms(datetime.now(UTC), run_started_at),
+            )
+            await Runs.set_status(conn, run_id, "error")
+        # delete or set status of thread
+        if temporary:
+            await Threads.delete(conn, run["thread_id"])
+        else:
+            await Threads.set_status(conn, run["thread_id"], checkpoint, exception)
+        if webhook:
+            # TODO add error, values to webhook payload
+            # TODO add retries for webhook calls
+            try:
+                await get_http_client().post(
+                    webhook, json={**run, "status": status}, total_timeout=5
+                )
+            except Exception as e:
+                logger.warning("Failed to send webhook", exc_info=e)
+        # Note we don't handle asyncio.CancelledError here, as we want to
+        # let it bubble up and rollback db transaction, thus marking the run
+        # as available to be picked up by another worker
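
For orientation, a minimal sketch of how this worker pool might be driven. It is illustrative only: the released package presumably starts the loop from its own server lifespan (see langgraph_api/lifespan.py in the file list), and the concurrency and timeout values below are made up.

    import asyncio

    from langgraph_api.queue import queue

    async def main() -> None:
        # Illustrative values: 10 concurrent workers, 1-hour per-run timeout.
        # queue() loops until cancelled, then cancels and drains WORKERS.
        await queue(concurrency=10, timeout=3600.0)

    if __name__ == "__main__":
        asyncio.run(main())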
langgraph_api/route.py ADDED
@@ -0,0 +1,118 @@
+import functools
+import inspect
+import typing
+
+import jsonschema_rs
+import orjson
+from starlette._exception_handler import wrap_app_handling_exceptions
+from starlette._utils import is_async_callable
+from starlette.concurrency import run_in_threadpool
+from starlette.middleware import Middleware
+from starlette.requests import Request
+from starlette.responses import JSONResponse, Response
+from starlette.routing import Route, compile_path, get_name
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+from langgraph_api.serde import json_dumpb
+
+
+def api_request_response(
+    func: typing.Callable[[Request], typing.Awaitable[Response] | Response],
+) -> ASGIApp:
+    """
+    Takes a function or coroutine `func(request) -> response`,
+    and returns an ASGI application.
+    """
+
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        request = ApiRequest(scope, receive, send)
+
+        # the inner app deliberately shadows the outer one, so the exception
+        # wrapper receives a plain ASGI callable bound to this request
+        async def app(scope: Scope, receive: Receive, send: Send) -> None:
+            if is_async_callable(func):
+                response = await func(request)
+            else:
+                response = await run_in_threadpool(func, request)
+            await response(scope, receive, send)
+
+        await wrap_app_handling_exceptions(app, request)(scope, receive, send)
+
+    return app
+
+
+class ApiResponse(JSONResponse):
+    def render(self, content: typing.Any) -> bytes:
+        return json_dumpb(content)
+
+
+def _json_loads(
+    content: bytearray, schema: jsonschema_rs.Draft4Validator | None
+) -> typing.Any:
+    json = orjson.loads(content)
+    if schema is not None:
+        schema.validate(json)
+    return json
+
+
+class ApiRequest(Request):
+    async def body(self) -> bytearray:
+        if not hasattr(self, "_body"):
+            chunks = bytearray()
+            async for chunk in self.stream():
+                chunks.extend(chunk)
+            self._body = chunks
+        return self._body
+
+    async def json(self, schema: jsonschema_rs.Draft4Validator | None) -> typing.Any:
+        if not hasattr(self, "_json"):
+            body = await self.body()
+            self._json = await run_in_threadpool(_json_loads, body, schema)
+        return self._json
+
+
+# ApiRoute uses our custom ApiRequest class to handle requests.
+class ApiRoute(Route):
+    def __init__(
+        self,
+        path: str,
+        endpoint: typing.Callable[..., typing.Any],
+        *,
+        methods: list[str] | None = None,
+        name: str | None = None,
+        include_in_schema: bool = True,
+        middleware: typing.Sequence[Middleware] | None = None,
+    ) -> None:
+        assert path.startswith("/"), "Routed paths must start with '/'"
+        self.path = path
+        self.endpoint = endpoint
+        self.name = get_name(endpoint) if name is None else name
+        self.include_in_schema = include_in_schema
+
+        endpoint_handler = endpoint
+        while isinstance(endpoint_handler, functools.partial):
+            endpoint_handler = endpoint_handler.func
+        if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
+            # Endpoint is function or method. Treat it as `func(request) -> response`.
+            self.app = api_request_response(endpoint)
+            if methods is None:
+                methods = ["GET"]
+        else:
+            # Endpoint is a class. Treat it as ASGI.
+            self.app = endpoint
+
+        if middleware is not None:
+            for cls, args, kwargs in reversed(middleware):
+                self.app = cls(app=self.app, *args, **kwargs)  # noqa: B026
+
+        if methods is None:
+            self.methods = None
+        else:
+            self.methods = {method.upper() for method in methods}
+            if "GET" in self.methods:
+                self.methods.add("HEAD")
+
+        self.path_regex, self.path_format, self.param_convertors = compile_path(path)
+
+    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
+        # https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
+        scope["route"] = self.path
+        return await super().handle(scope, receive, send)
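
A minimal usage sketch for the route machinery above (the /echo path and handler are hypothetical, not part of this package; passing schema=None to ApiRequest.json skips jsonschema_rs validation):

    from starlette.applications import Starlette

    from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute

    async def echo(request: ApiRequest) -> ApiResponse:
        # Body is parsed (and optionally schema-validated) in a thread pool.
        payload = await request.json(None)
        return ApiResponse(payload)

    # ApiRoute hands the endpoint an ApiRequest and records the path
    # template in scope["route"].
    app = Starlette(routes=[ApiRoute("/echo", echo, methods=["POST"])])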
langgraph_api/schema.py ADDED
@@ -0,0 +1,190 @@
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Any, Literal, Optional, TypedDict
+from uuid import UUID
+
+from langgraph_api.serde import Fragment
+
+MetadataInput = dict[str, Any] | None
+MetadataValue = dict[str, Any]
+
+RunStatus = Literal["pending", "error", "success", "timeout", "interrupted"]
+
+ThreadStatus = Literal["idle", "busy", "interrupted", "error"]
+
+StreamMode = Literal["values", "messages", "updates", "events", "debug", "custom"]
+
+MultitaskStrategy = Literal["reject", "rollback", "interrupt", "enqueue"]
+
+OnConflictBehavior = Literal["raise", "do_nothing"]
+
+OnCompletion = Literal["delete", "keep"]
+
+IfNotExists = Literal["create", "reject"]
+
+All = Literal["*"]
+
+
+class Config(TypedDict, total=False):
+    tags: list[str]
+    """
+    Tags for this call and any sub-calls (e.g. a Chain calling an LLM).
+    You can use these to filter calls.
+    """
+
+    recursion_limit: int
+    """
+    Maximum number of times a call can recurse. If not provided, defaults to 25.
+    """
+
+    configurable: dict[str, Any]
+    """
+    Runtime values for attributes previously made configurable on this Runnable,
+    or sub-Runnables, through .configurable_fields() or .configurable_alternatives().
+    Check .output_schema() for a description of the attributes that have been made
+    configurable.
+    """
+
+
+class Checkpoint(TypedDict):
+    thread_id: str
+    checkpoint_ns: str
+    checkpoint_id: str | None
+    checkpoint_map: dict[str, Any] | None
+
+
+class GraphSchema(TypedDict):
+    """Graph model."""
+
+    graph_id: str
+    """The ID of the graph."""
+    state_schema: dict
+    """The schema for the graph state."""
+    config_schema: dict
+    """The schema for the graph config."""
+
+
+class Assistant(TypedDict):
+    """Assistant model."""
+
+    assistant_id: UUID
+    """The ID of the assistant."""
+    graph_id: str
+    """The ID of the graph."""
+    config: Config
+    """The assistant config."""
+    created_at: datetime
+    """The time the assistant was created."""
+    updated_at: datetime
+    """The last time the assistant was updated."""
+    metadata: Fragment
+    """The assistant metadata."""
+    version: int
+    """The assistant version."""
+
+
+class Thread(TypedDict):
+    thread_id: UUID
+    """The ID of the thread."""
+    created_at: datetime
+    """The time the thread was created."""
+    updated_at: datetime
+    """The last time the thread was updated."""
+    metadata: Fragment
+    """The thread metadata."""
+    config: Fragment
+    """The thread config."""
+    status: ThreadStatus
+    """The status of the thread. One of 'idle', 'busy', 'interrupted', 'error'."""
+    values: Fragment
+    """The current state of the thread."""
+
+
+class ThreadTask(TypedDict):
+    id: str
+    name: str
+    error: str | None
+    interrupts: list[dict]
+    checkpoint: Checkpoint | None
+    state: Optional["ThreadState"]
+
+
+class ThreadState(TypedDict):
+    values: dict[str, Any]
+    """The state values."""
+    next: Sequence[str]
+    """The name of the node to execute in each task for this step."""
+    checkpoint: Checkpoint
+    """The checkpoint keys. This object can be passed to the /threads and /runs
+    endpoints to resume execution or update state."""
+    metadata: Fragment
+    """Metadata for this state."""
+    created_at: str | None
+    """Timestamp of state creation."""
+    parent_checkpoint: Checkpoint | None
+    """The parent checkpoint. If missing, this is the root checkpoint."""
+    tasks: Sequence[ThreadTask]
+    """Tasks to execute in this step. If already attempted, may contain an error."""
+
+
+class Run(TypedDict):
+    run_id: UUID
+    """The ID of the run."""
+    thread_id: UUID
+    """The ID of the thread."""
+    assistant_id: UUID
+    """The assistant that was used for this run."""
+    created_at: datetime
+    """The time the run was created."""
+    updated_at: datetime
+    """The last time the run was updated."""
+    status: RunStatus
+    """The status of the run. One of 'pending', 'error', 'success', 'timeout',
+    'interrupted'."""
+    metadata: Fragment
+    """The run metadata."""
+    kwargs: Fragment
+    """The run kwargs."""
+    multitask_strategy: MultitaskStrategy
+    """Strategy to handle concurrent runs on the same thread."""
+
+
+class RunSend(TypedDict):
+    node: str
+    input: dict[str, Any] | None
+
+
+class RunCommand(TypedDict):
+    send: RunSend | Sequence[RunSend] | None
+    update: dict[str, Any] | None
+    resume: Any | None
+
+
+class Cron(TypedDict):
+    """Cron model."""
+
+    cron_id: UUID
+    """The ID of the cron."""
+    thread_id: UUID | None
+    """The ID of the thread."""
+    end_time: datetime | None
+    """The end date to stop running the cron."""
+    schedule: str
+    """The schedule to run, cron format."""
+    created_at: datetime
+    """The time the cron was created."""
+    updated_at: datetime
+    """The last time the cron was updated."""
+    payload: Fragment
+    """The run payload to use for creating a new run."""
+
+
+class ThreadUpdateResponse(TypedDict):
+    """Response for updating a thread."""
+
+    checkpoint: Checkpoint
+
+
+class QueueStats(TypedDict):
+    n_pending: int
+    max_age_secs: datetime | None
+    med_age_secs: datetime | None
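
Since TypedDicts are plain dicts at runtime, values of these types are just keyword data. A small illustrative example (the UUID is made up):

    from langgraph_api.schema import Checkpoint

    checkpoint: Checkpoint = {
        "thread_id": "123e4567-e89b-12d3-a456-426614174000",
        "checkpoint_ns": "",    # root namespace
        "checkpoint_id": None,  # no specific checkpoint pinned
        "checkpoint_map": None,
    }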
langgraph_api/serde.py ADDED
@@ -0,0 +1,124 @@
+import asyncio
+import pickle
+import uuid
+from base64 import b64encode
+from collections import deque
+from datetime import timedelta, timezone
+from decimal import Decimal
+from ipaddress import (
+    IPv4Address,
+    IPv4Interface,
+    IPv4Network,
+    IPv6Address,
+    IPv6Interface,
+    IPv6Network,
+)
+from pathlib import Path
+from re import Pattern
+from typing import Any, NamedTuple
+from zoneinfo import ZoneInfo
+
+import orjson
+from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
+
+
+class Fragment(NamedTuple):
+    buf: bytes
+
+
+def decimal_encoder(dec_value: Decimal) -> int | float:
+    """
+    Encodes a Decimal as an int if there's no exponent, otherwise as a float.
+
+    This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
+    where an integer (but not int-typed) value is used. Encoding this as a float
+    results in failed round-tripping between encode and parse.
+    Our Id type is a prime example of this.
+
+    >>> decimal_encoder(Decimal("1.0"))
+    1.0
+
+    >>> decimal_encoder(Decimal("1"))
+    1
+    """
+    if dec_value.as_tuple().exponent >= 0:
+        return int(dec_value)
+    else:
+        return float(dec_value)
+
+
+def default(obj):
+    # Only need to handle types that orjson doesn't serialize by default
+    # https://github.com/ijl/orjson#serialize
+    if isinstance(obj, Fragment):
+        return orjson.Fragment(obj.buf)
+    if hasattr(obj, "model_dump") and callable(obj.model_dump):
+        return obj.model_dump()
+    elif hasattr(obj, "dict") and callable(obj.dict):
+        return obj.dict()
+    elif hasattr(obj, "_asdict") and callable(obj._asdict):
+        return obj._asdict()
+    elif isinstance(obj, BaseException):
+        return {"error": type(obj).__name__, "message": str(obj)}
+    elif isinstance(obj, (set, frozenset, deque)):  # noqa: UP038
+        return list(obj)
+    elif isinstance(obj, (timezone, ZoneInfo)):  # noqa: UP038
+        return obj.tzname(None)
+    elif isinstance(obj, timedelta):
+        return obj.total_seconds()
+    elif isinstance(obj, Decimal):
+        return decimal_encoder(obj)
+    elif isinstance(obj, uuid.UUID):
+        return str(obj)
+    elif isinstance(  # noqa: UP038
+        obj,
+        (
+            IPv4Address,
+            IPv4Interface,
+            IPv4Network,
+            IPv6Address,
+            IPv6Interface,
+            IPv6Network,
+            Path,
+        ),
+    ):
+        return str(obj)
+    elif isinstance(obj, Pattern):
+        return obj.pattern
+    elif isinstance(obj, bytes | bytearray):
+        return b64encode(obj).decode()
+    return None
+
+
+_option = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NON_STR_KEYS
+
+
+def json_dumpb(obj) -> bytes:
+    # strip the escaped form of the null unicode char, not allowed in json
+    return orjson.dumps(obj, default=default, option=_option).replace(
+        b"\\u0000", b""
+    )
+
+
+def json_loads(content: bytes | Fragment | dict) -> Any:
+    if isinstance(content, Fragment):
+        content = content.buf
+    if isinstance(content, dict):
+        return content
+    return orjson.loads(content)
+
+
+async def ajson_loads(content: bytes | Fragment) -> Any:
+    return await asyncio.to_thread(json_loads, content)
+
+
+class Serializer(JsonPlusSerializer):
+    def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
+        try:
+            return super().dumps_typed(obj)
+        except TypeError:
+            # fall back to pickle for values the JSON serializer can't handle
+            return "pickle", pickle.dumps(obj)
+
+    def loads_typed(self, data: tuple[str, bytes]) -> Any:
+        if data[0] == "pickle":
+            return pickle.loads(data[1])
+        return super().loads_typed(data)
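
To illustrate the round-tripping behavior defined above (sample values are arbitrary; the conversions follow from default() and the orjson options):

    import uuid
    from decimal import Decimal

    from langgraph_api.serde import json_dumpb, json_loads

    data = {"id": uuid.uuid4(), "score": Decimal("1"), "tags": {"py", "js"}}
    buf = json_dumpb(data)  # UUID -> str, Decimal("1") -> 1 (int), set -> list
    print(json_loads(buf))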