langgraph-api 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langgraph-api might be problematic. Click here for more details.
- LICENSE +93 -0
- langgraph_api/__init__.py +0 -0
- langgraph_api/api/__init__.py +63 -0
- langgraph_api/api/assistants.py +326 -0
- langgraph_api/api/meta.py +71 -0
- langgraph_api/api/openapi.py +32 -0
- langgraph_api/api/runs.py +463 -0
- langgraph_api/api/store.py +116 -0
- langgraph_api/api/threads.py +263 -0
- langgraph_api/asyncio.py +201 -0
- langgraph_api/auth/__init__.py +0 -0
- langgraph_api/auth/langsmith/__init__.py +0 -0
- langgraph_api/auth/langsmith/backend.py +67 -0
- langgraph_api/auth/langsmith/client.py +145 -0
- langgraph_api/auth/middleware.py +41 -0
- langgraph_api/auth/noop.py +14 -0
- langgraph_api/cli.py +209 -0
- langgraph_api/config.py +70 -0
- langgraph_api/cron_scheduler.py +60 -0
- langgraph_api/errors.py +52 -0
- langgraph_api/graph.py +314 -0
- langgraph_api/http.py +168 -0
- langgraph_api/http_logger.py +89 -0
- langgraph_api/js/.gitignore +2 -0
- langgraph_api/js/build.mts +49 -0
- langgraph_api/js/client.mts +849 -0
- langgraph_api/js/global.d.ts +6 -0
- langgraph_api/js/package.json +33 -0
- langgraph_api/js/remote.py +673 -0
- langgraph_api/js/server_sent_events.py +126 -0
- langgraph_api/js/src/graph.mts +88 -0
- langgraph_api/js/src/hooks.mjs +12 -0
- langgraph_api/js/src/parser/parser.mts +443 -0
- langgraph_api/js/src/parser/parser.worker.mjs +12 -0
- langgraph_api/js/src/schema/types.mts +2136 -0
- langgraph_api/js/src/schema/types.template.mts +74 -0
- langgraph_api/js/src/utils/importMap.mts +85 -0
- langgraph_api/js/src/utils/pythonSchemas.mts +28 -0
- langgraph_api/js/src/utils/serde.mts +21 -0
- langgraph_api/js/tests/api.test.mts +1566 -0
- langgraph_api/js/tests/compose-postgres.yml +56 -0
- langgraph_api/js/tests/graphs/.gitignore +1 -0
- langgraph_api/js/tests/graphs/agent.mts +127 -0
- langgraph_api/js/tests/graphs/error.mts +17 -0
- langgraph_api/js/tests/graphs/langgraph.json +8 -0
- langgraph_api/js/tests/graphs/nested.mts +44 -0
- langgraph_api/js/tests/graphs/package.json +7 -0
- langgraph_api/js/tests/graphs/weather.mts +57 -0
- langgraph_api/js/tests/graphs/yarn.lock +159 -0
- langgraph_api/js/tests/parser.test.mts +870 -0
- langgraph_api/js/tests/utils.mts +17 -0
- langgraph_api/js/yarn.lock +1340 -0
- langgraph_api/lifespan.py +41 -0
- langgraph_api/logging.py +121 -0
- langgraph_api/metadata.py +101 -0
- langgraph_api/models/__init__.py +0 -0
- langgraph_api/models/run.py +229 -0
- langgraph_api/patch.py +42 -0
- langgraph_api/queue.py +245 -0
- langgraph_api/route.py +118 -0
- langgraph_api/schema.py +190 -0
- langgraph_api/serde.py +124 -0
- langgraph_api/server.py +48 -0
- langgraph_api/sse.py +118 -0
- langgraph_api/state.py +67 -0
- langgraph_api/stream.py +289 -0
- langgraph_api/utils.py +60 -0
- langgraph_api/validation.py +141 -0
- langgraph_api-0.0.1.dist-info/LICENSE +93 -0
- langgraph_api-0.0.1.dist-info/METADATA +26 -0
- langgraph_api-0.0.1.dist-info/RECORD +86 -0
- langgraph_api-0.0.1.dist-info/WHEEL +4 -0
- langgraph_api-0.0.1.dist-info/entry_points.txt +3 -0
- langgraph_license/__init__.py +0 -0
- langgraph_license/middleware.py +21 -0
- langgraph_license/validation.py +11 -0
- langgraph_storage/__init__.py +0 -0
- langgraph_storage/checkpoint.py +94 -0
- langgraph_storage/database.py +190 -0
- langgraph_storage/ops.py +1523 -0
- langgraph_storage/queue.py +108 -0
- langgraph_storage/retry.py +27 -0
- langgraph_storage/store.py +28 -0
- langgraph_storage/ttl_dict.py +54 -0
- logging.json +22 -0
- openapi.json +4304 -0
langgraph_api/server.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
import jsonschema_rs
|
|
4
|
+
import structlog
|
|
5
|
+
from langgraph.errors import EmptyInputError, InvalidUpdateError
|
|
6
|
+
from starlette.applications import Starlette
|
|
7
|
+
from starlette.middleware import Middleware
|
|
8
|
+
from starlette.middleware.cors import CORSMiddleware
|
|
9
|
+
|
|
10
|
+
import langgraph_api.config as config
|
|
11
|
+
import langgraph_api.patch # noqa: F401
|
|
12
|
+
from langgraph_api.api import routes
|
|
13
|
+
from langgraph_api.errors import (
|
|
14
|
+
overloaded_error_handler,
|
|
15
|
+
validation_error_handler,
|
|
16
|
+
value_error_handler,
|
|
17
|
+
)
|
|
18
|
+
from langgraph_api.http_logger import AccessLoggerMiddleware
|
|
19
|
+
from langgraph_api.lifespan import lifespan
|
|
20
|
+
from langgraph_license.middleware import LicenseValidationMiddleware
|
|
21
|
+
from langgraph_storage.retry import OVERLOADED_EXCEPTIONS
|
|
22
|
+
|
|
23
|
+
# Route stdlib `warnings` through the logging system so they are emitted
# alongside the rest of the log output.
logging.captureWarnings(True)
logger = structlog.stdlib.get_logger(__name__)


# The Starlette (ASGI) application exposing the API routes.
app = Starlette(
    routes=routes,
    lifespan=lifespan,
    # Starlette wraps middleware in list order, first entry outermost:
    # CORS handling runs before license validation, which runs before
    # access logging.
    middleware=[
        Middleware(
            CORSMiddleware,
            allow_origins=config.CORS_ALLOW_ORIGINS,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        ),
        Middleware(LicenseValidationMiddleware),
        Middleware(AccessLoggerMiddleware, logger=logger),
    ],
    # Route these exception types to the handlers from langgraph_api.errors.
    # The overload handlers are merged with `|`, so if a type appears in both
    # dicts the OVERLOADED_EXCEPTIONS entry wins.
    exception_handlers={
        ValueError: value_error_handler,
        InvalidUpdateError: value_error_handler,
        EmptyInputError: value_error_handler,
        jsonschema_rs.ValidationError: validation_error_handler,
    }
    | {exc: overloaded_error_handler for exc in OVERLOADED_EXCEPTIONS},
)
|
langgraph_api/sse.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from collections.abc import AsyncIterator, Awaitable, Callable, Mapping
|
|
3
|
+
from functools import partial
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import anyio
|
|
7
|
+
import sse_starlette
|
|
8
|
+
import sse_starlette.sse
|
|
9
|
+
import structlog.stdlib
|
|
10
|
+
from starlette.types import Receive, Scope, Send
|
|
11
|
+
|
|
12
|
+
from langgraph_api.asyncio import SimpleTaskGroup, aclosing
|
|
13
|
+
from langgraph_api.serde import json_dumpb
|
|
14
|
+
|
|
15
|
+
# Module-level structlog logger; used to report errors raised while streaming
# an SSE response body.
logger = structlog.stdlib.get_logger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class EventSourceResponse(sse_starlette.EventSourceResponse):
    """SSE response streaming pre-encoded bytes or ``(event, data)`` tuples.

    Behaviour added on top of ``sse_starlette.EventSourceResponse``:

    * tuple items from the body iterator are encoded with ``json_to_sse``;
      bytes items are sent as-is;
    * a heartbeat comment is sent via ``sse_heartbeat`` while the body is
      streaming;
    * an exception raised by the body iterator (other than a send timeout) is
      logged and reported to the client as an ``error`` event, after which the
      response is terminated normally.
    """

    def __init__(
        self,
        content: AsyncIterator[bytes | tuple[bytes, Any | bytes]],
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
    ) -> None:
        super().__init__(content=content, status_code=status_code, headers=headers)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        async with anyio.create_task_group() as task_group:
            # https://trio.readthedocs.io/en/latest/reference-core.html#custom-supervisors
            async def wrap(func: Callable[[], Awaitable[None]]) -> None:
                await func()
                # Whichever wrapped coroutine finishes first (body streamed,
                # exit signal, or client disconnect) cancels all the others.
                # noinspection PyAsyncCall
                task_group.cancel_scope.cancel()

            task_group.start_soon(wrap, partial(self.stream_response, send))
            task_group.start_soon(wrap, self.listen_for_exit_signal)

            if self.data_sender_callable:
                task_group.start_soon(self.data_sender_callable)

            await wrap(partial(self.listen_for_disconnect, receive))

        if self.background is not None:  # pragma: no cover, tested in StreamResponse
            await self.background()

    async def stream_response(self, send: Send) -> None:
        """Send the response start, every body chunk, then one terminating frame.

        Raises:
            sse_starlette.sse.SendTimeoutError: if a chunk cannot be sent
                within ``self.send_timeout``.
        """
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )
        async with (
            SimpleTaskGroup(sse_heartbeat(send), cancel=True, wait=False),
            aclosing(self.body_iterator) as body,
        ):
            try:
                async for data in body:
                    with anyio.move_on_after(self.send_timeout) as timeout:
                        await send(
                            {
                                "type": "http.response.body",
                                "body": json_to_sse(*data)
                                if isinstance(data, tuple)
                                else data,
                                "more_body": True,
                            }
                        )
                    if timeout.cancel_called:
                        raise sse_starlette.sse.SendTimeoutError()
            except sse_starlette.sse.SendTimeoutError:
                raise
            except Exception as exc:
                await logger.aexception("Error streaming response", exc_info=exc)
                # Keep the response open here: the terminating empty frame is
                # sent exactly once, below. Marking this frame as the last one
                # (more_body=False) would make the closing send a second body
                # message after the response already completed, which the ASGI
                # spec says is ignored and servers like uvicorn reject.
                await send(
                    {
                        "type": "http.response.body",
                        "body": json_to_sse(b"error", exc),
                        "more_body": True,
                    }
                )

        # Heartbeat task is cancelled by now (SimpleTaskGroup exited), so this
        # is the final message on the channel.
        async with self._send_lock:
            self.active = False
            await send({"type": "http.response.body", "body": b"", "more_body": False})
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
async def sse_heartbeat(send: Send) -> None:
    """Send an SSE ``heartbeat`` comment frame every 5 seconds until cancelled.

    The first frame goes out after the first 5-second sleep. Cancellation is
    the expected way for this loop to end, so ``asyncio.CancelledError`` is
    swallowed and the coroutine simply returns.
    """
    heartbeat_frame = sse_starlette.ServerSentEvent(comment="heartbeat").encode()
    try:
        while True:
            await asyncio.sleep(5)
            message = {
                "type": "http.response.body",
                "body": heartbeat_frame,
                "more_body": True,
            }
            await send(message)
    except asyncio.CancelledError:
        return
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
SEP = b"\r\n"
|
|
102
|
+
EVENT = b"event: "
|
|
103
|
+
DATA = b"data: "
|
|
104
|
+
BYTES_LIKE = (bytes, bytearray, memoryview)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def json_to_sse(event: bytes, data: Any | bytes) -> bytes:
|
|
108
|
+
return b"".join(
|
|
109
|
+
(
|
|
110
|
+
EVENT,
|
|
111
|
+
event,
|
|
112
|
+
SEP,
|
|
113
|
+
DATA,
|
|
114
|
+
data if isinstance(data, BYTES_LIKE) else json_dumpb(data),
|
|
115
|
+
SEP,
|
|
116
|
+
SEP,
|
|
117
|
+
)
|
|
118
|
+
)
|
langgraph_api/state.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from langchain_core.runnables.config import RunnableConfig
|
|
2
|
+
from langgraph.types import StateSnapshot
|
|
3
|
+
|
|
4
|
+
from langgraph_api.schema import Checkpoint, ThreadState
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def runnable_config_to_checkpoint(
    config: RunnableConfig | None,
) -> Checkpoint | None:
    """Extract checkpoint coordinates from a runnable config.

    Returns a ``Checkpoint`` dict with ``thread_id`` and ``checkpoint_id``,
    plus ``checkpoint_ns`` and ``checkpoint_map`` when those keys are present
    in ``configurable``. A falsy ``checkpoint_ns`` is normalised to ``""``.

    Returns None when the config, its ``configurable`` mapping, or either id
    is missing or empty.
    """
    if not config:
        return None
    # RunnableConfig keys are all optional, so a config without
    # "configurable" must yield None rather than raising KeyError.
    configurable = config.get("configurable")
    if (
        not configurable
        or not configurable.get("thread_id")
        or not configurable.get("checkpoint_id")
    ):
        return None

    checkpoint: Checkpoint = {
        "checkpoint_id": configurable["checkpoint_id"],
        "thread_id": configurable["thread_id"],
    }

    if "checkpoint_ns" in configurable:
        checkpoint["checkpoint_ns"] = configurable["checkpoint_ns"] or ""

    if "checkpoint_map" in configurable:
        checkpoint["checkpoint_map"] = configurable["checkpoint_map"]

    return checkpoint
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _config_checkpoint_id(config: RunnableConfig | None) -> str | None:
    """Return ``config["configurable"]["checkpoint_id"]``, or None if absent.

    A missing/empty config, ``configurable`` mapping or id yields None
    instead of raising.
    """
    if not config:
        return None
    return (config.get("configurable") or {}).get("checkpoint_id")


def state_snapshot_to_thread_state(state: StateSnapshot) -> ThreadState:
    """Convert a LangGraph ``StateSnapshot`` into the API ``ThreadState`` dict.

    For each task, a nested ``StateSnapshot`` in ``t.state`` is converted
    recursively into ``"state"``; any other non-None ``t.state`` (a config)
    contributes its ``configurable`` mapping as ``"checkpoint"``.
    """
    return {
        "values": state.values,
        "next": state.next,
        "tasks": [
            {
                "id": t.id,
                "name": t.name,
                "path": t.path,
                "error": t.error,
                "interrupts": t.interrupts,
                "checkpoint": t.state.get("configurable")
                if t.state is not None and not isinstance(t.state, StateSnapshot)
                else None,
                "state": state_snapshot_to_thread_state(t.state)
                if isinstance(t.state, StateSnapshot)
                else None,
                "result": getattr(t, "result", None),
            }
            for t in state.tasks
        ],
        "metadata": state.metadata,
        "created_at": state.created_at,
        "checkpoint": runnable_config_to_checkpoint(state.config),
        "parent_checkpoint": runnable_config_to_checkpoint(state.parent_config),
        # below are deprecated; both use the same tolerant lookup so a config
        # lacking "configurable"/"checkpoint_id" yields None, not KeyError
        "checkpoint_id": _config_checkpoint_id(state.config),
        "parent_checkpoint_id": _config_checkpoint_id(state.parent_config),
    }
|
langgraph_api/stream.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
from collections.abc import AsyncIterator, Callable
|
|
2
|
+
from contextlib import AsyncExitStack, aclosing
|
|
3
|
+
from functools import lru_cache
|
|
4
|
+
from typing import Any, cast
|
|
5
|
+
|
|
6
|
+
import langgraph.version
|
|
7
|
+
import langsmith
|
|
8
|
+
import structlog
|
|
9
|
+
from langchain_core.messages import (
|
|
10
|
+
BaseMessage,
|
|
11
|
+
BaseMessageChunk,
|
|
12
|
+
message_chunk_to_message,
|
|
13
|
+
)
|
|
14
|
+
from langchain_core.runnables.config import run_in_executor
|
|
15
|
+
from langgraph.errors import (
|
|
16
|
+
EmptyChannelError,
|
|
17
|
+
EmptyInputError,
|
|
18
|
+
GraphRecursionError,
|
|
19
|
+
InvalidUpdateError,
|
|
20
|
+
)
|
|
21
|
+
from langgraph.pregel.debug import CheckpointPayload
|
|
22
|
+
from langgraph.types import Command, Send
|
|
23
|
+
from pydantic import ValidationError
|
|
24
|
+
from pydantic.v1 import ValidationError as ValidationErrorLegacy
|
|
25
|
+
|
|
26
|
+
from langgraph_api.asyncio import ValueEvent, wait_if_not_done
|
|
27
|
+
from langgraph_api.graph import get_graph
|
|
28
|
+
from langgraph_api.js.remote import RemotePregel
|
|
29
|
+
from langgraph_api.metadata import HOST, PLAN, incr_nodes
|
|
30
|
+
from langgraph_api.schema import Run, RunCommand, StreamMode
|
|
31
|
+
from langgraph_api.serde import json_dumpb
|
|
32
|
+
from langgraph_api.utils import AsyncConnectionProto
|
|
33
|
+
from langgraph_storage.checkpoint import Checkpointer
|
|
34
|
+
from langgraph_storage.ops import Runs
|
|
35
|
+
from langgraph_storage.store import Store
|
|
36
|
+
|
|
37
|
+
logger = structlog.stdlib.get_logger(__name__)

# Async iterator of (stream mode / event name, payload) pairs; the return type
# of astream_state and the input of consume.
AnyStream = AsyncIterator[tuple[str, Any]]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _preproces_debug_checkpoint_task(task: dict[str, Any]) -> dict[str, Any]:
|
|
43
|
+
if (
|
|
44
|
+
"state" not in task
|
|
45
|
+
or not task["state"]
|
|
46
|
+
or "configurable" not in task["state"]
|
|
47
|
+
or not task["state"]["configurable"]
|
|
48
|
+
):
|
|
49
|
+
return task
|
|
50
|
+
|
|
51
|
+
task["checkpoint"] = task["state"]["configurable"]
|
|
52
|
+
del task["state"]
|
|
53
|
+
return task
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _preprocess_debug_checkpoint(payload: CheckpointPayload | None) -> dict[str, Any]:
|
|
57
|
+
from langgraph_api.state import runnable_config_to_checkpoint
|
|
58
|
+
|
|
59
|
+
if not payload:
|
|
60
|
+
return None
|
|
61
|
+
|
|
62
|
+
payload["checkpoint"] = runnable_config_to_checkpoint(payload["config"])
|
|
63
|
+
payload["parent_checkpoint"] = runnable_config_to_checkpoint(
|
|
64
|
+
payload["parent_config"] if "parent_config" in payload else None
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
payload["tasks"] = [_preproces_debug_checkpoint_task(t) for t in payload["tasks"]]
|
|
68
|
+
|
|
69
|
+
# TODO: deprecate the `config`` and `parent_config`` fields
|
|
70
|
+
return payload
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _map_cmd(cmd: RunCommand) -> Command:
    """Translate an API ``RunCommand`` dict into a LangGraph ``Command``.

    ``cmd["send"]`` may be a single ``{"node": ..., "input": ...}`` dict or a
    list of them; each becomes a ``Send``. A missing ``send`` or an empty list
    maps to ``send=None``. ``update`` and ``resume`` are passed through.
    """
    raw_sends = cmd.get("send")
    if raw_sends is not None and not isinstance(raw_sends, list):
        raw_sends = [raw_sends]

    if raw_sends:
        sends = [Send(item["node"], item["input"]) for item in raw_sends]
    else:
        sends = None

    return Command(update=cmd.get("update"), send=sends, resume=cmd.get("resume"))
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
async def astream_state(
    stack: AsyncExitStack,
    conn: AsyncConnectionProto,
    run: Run,
    attempt: int,
    done: ValueEvent,
    *,
    on_checkpoint: Callable[[CheckpointPayload], None] = lambda _: None,
) -> AnyStream:
    """Stream messages from the runnable.

    Async generator that executes ``run`` against its graph and yields
    ``(mode, payload)`` pairs:

    * first ``("metadata", {"run_id": ..., "attempt": ...})``;
    * then the graph output for each stream mode the client requested.
      With ``subgraphs`` enabled, output from a subgraph namespace is
      named ``"<mode>|<ns1>|<ns2>..."``;
    * finally ``("feedback", {key: url})`` if the run set ``feedback_keys``.

    Args:
        stack: Exit stack; the connection pipeline is entered on it, and the
            stack itself is entered around the graph stream below, so all of
            it is closed when streaming finishes.
        conn: Connection used for the pipeline, the store and the checkpointer.
        run: Run record. ``run["kwargs"]`` is shallow-copied, but the nested
            ``config`` dict is mutated in place (metadata and configurable
            entries are added).
        attempt: Attempt number; ``1`` is the first attempt. Recorded as
            ``metadata.run_attempt``.
        done: Passed to ``wait_if_not_done`` around every ``anext`` call;
            presumably aborts streaming once set — confirm in
            langgraph_api.asyncio.
        on_checkpoint: Called with each preprocessed debug checkpoint payload.
    """
    run_id = str(run["run_id"])
    await stack.enter_async_context(conn.pipeline())
    # extract args from run
    kwargs = run["kwargs"].copy()
    subgraphs = kwargs.get("subgraphs", False)
    temporary = kwargs.pop("temporary", False)
    config = kwargs.pop("config")
    graph = get_graph(
        config["configurable"]["graph_id"],
        config,
        store=None if not conn else Store(conn),
        checkpointer=None if temporary else Checkpointer(conn),
    )
    input = kwargs.pop("input")
    # NOTE(review): pop() without a default raises KeyError if the run kwargs
    # lack "command" — confirm every run producer sets it.
    if cmd := kwargs.pop("command"):
        input = _map_cmd(cmd)
    stream_mode: list[StreamMode] = kwargs.pop("stream_mode")
    feedback_keys = kwargs.pop("feedback_keys", None)
    # "events" is not a graph stream mode; it selects astream_events below.
    stream_modes_set: set[StreamMode] = set(stream_mode) - {"events"}
    # Always request "debug" from the graph so checkpoints reach on_checkpoint;
    # debug chunks are only forwarded if the client asked for "debug".
    if "debug" not in stream_modes_set:
        stream_modes_set.add("debug")
    # "messages-tuple" is an API-level mode served from the graph's "messages".
    if "messages-tuple" in stream_modes_set:
        stream_modes_set.remove("messages-tuple")
        stream_modes_set.add("messages")
    # attach attempt metadata
    config["metadata"]["run_attempt"] = attempt
    # attach langgraph metadata
    config["metadata"]["langgraph_version"] = langgraph.version.__version__
    config["metadata"]["langgraph_plan"] = PLAN
    config["metadata"]["langgraph_host"] = HOST
    # attach node counter
    if not isinstance(graph, RemotePregel):
        config["configurable"]["__pregel_node_finished"] = incr_nodes
    # TODO add node tracking for JS graphs
    # attach run_id to config
    # for attempts beyond the first, use a fresh, unique run_id
    config = {**config, "run_id": run["run_id"]} if attempt == 1 else config
    # set up state
    checkpoint: CheckpointPayload | None = None
    # Accumulated message chunks keyed by message id.
    messages: dict[str, BaseMessageChunk] = {}
    # JS (remote) graphs are always consumed through astream_events.
    use_astream_events = "events" in stream_mode or isinstance(graph, RemotePregel)
    # yield metadata chunk
    yield "metadata", {"run_id": run_id, "attempt": attempt}
    # stream run
    if use_astream_events:
        async with (
            stack,
            aclosing(
                graph.astream_events(
                    input,
                    config,
                    version="v2",
                    stream_mode=list(stream_modes_set),
                    **kwargs,
                )
            ) as stream,
        ):
            sentinel = object()
            while True:
                event = await wait_if_not_done(anext(stream, sentinel), done)
                if event is sentinel:
                    break
                if event.get("tags") and "langsmith:hidden" in event["tags"]:
                    continue
                if "messages" in stream_mode and isinstance(graph, RemotePregel):
                    # JS graphs emit already-formatted message events as
                    # custom events; forward them unchanged.
                    if event["event"] == "on_custom_event" and event["name"] in (
                        "messages/complete",
                        "messages/partial",
                        "messages/metadata",
                    ):
                        yield event["name"], event["data"]
                # TODO support messages-tuple for js graphs
                # NOTE(review): config only carries run_id when attempt == 1,
                # so on later attempts the root run gets a fresh id and this
                # comparison with run["run_id"] will likely never match,
                # dropping all root stream chunks — confirm.
                if event["event"] == "on_chain_stream" and event["run_id"] == run_id:
                    if subgraphs:
                        ns, mode, chunk = event["data"]["chunk"]
                    else:
                        mode, chunk = event["data"]["chunk"]
                    # --- begin shared logic with astream ---
                    if mode == "debug":
                        if chunk["type"] == "checkpoint":
                            checkpoint = _preprocess_debug_checkpoint(chunk["payload"])
                            on_checkpoint(checkpoint)
                    if mode == "messages":
                        if "messages-tuple" in stream_mode:
                            yield "messages", chunk
                        else:
                            msg, meta = cast(tuple[BaseMessage, dict[str, Any]], chunk)
                            # Merge chunks per message id so each yield carries
                            # the whole message so far; metadata is emitted
                            # once, the first time an id is seen.
                            if msg.id in messages:
                                messages[msg.id] += msg
                            else:
                                messages[msg.id] = msg
                                yield "messages/metadata", {msg.id: {"metadata": meta}}
                            yield (
                                "messages/partial"
                                if isinstance(msg, BaseMessageChunk)
                                else "messages/complete",
                                [message_chunk_to_message(messages[msg.id])],
                            )
                    elif mode in stream_mode:
                        # ns is only bound when subgraphs is true; `and`
                        # short-circuits before it is read otherwise.
                        if subgraphs and ns:
                            yield f"{mode}|{'|'.join(ns)}", chunk
                        else:
                            yield mode, chunk
                    # --- end shared logic with astream ---
                elif "events" in stream_mode:
                    yield "events", event
    else:
        async with (
            stack,
            aclosing(
                graph.astream(
                    input, config, stream_mode=list(stream_modes_set), **kwargs
                )
            ) as stream,
        ):
            sentinel = object()
            while True:
                event = await wait_if_not_done(anext(stream, sentinel), done)
                if event is sentinel:
                    break
                if subgraphs:
                    ns, mode, chunk = event
                else:
                    mode, chunk = event
                # --- begin shared logic with astream_events ---
                if mode == "debug":
                    if chunk["type"] == "checkpoint":
                        checkpoint = _preprocess_debug_checkpoint(chunk["payload"])
                        on_checkpoint(checkpoint)
                if mode == "messages":
                    if "messages-tuple" in stream_mode:
                        yield "messages", chunk
                    else:
                        msg, meta = cast(tuple[BaseMessage, dict[str, Any]], chunk)
                        if msg.id in messages:
                            messages[msg.id] += msg
                        else:
                            messages[msg.id] = msg
                            yield "messages/metadata", {msg.id: {"metadata": meta}}
                        yield (
                            "messages/partial"
                            if isinstance(msg, BaseMessageChunk)
                            else "messages/complete",
                            [message_chunk_to_message(messages[msg.id])],
                        )
                elif mode in stream_mode:
                    if subgraphs and ns:
                        yield f"{mode}|{'|'.join(ns)}", chunk
                    else:
                        yield mode, chunk
                # --- end shared logic with astream_events ---
    # Get feedback URLs
    if feedback_keys:
        # LangSmith client call is blocking; run it in the default executor.
        feedback_urls = await run_in_executor(
            None, get_feedback_urls, run_id, feedback_keys
        )
        yield "feedback", feedback_urls
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
async def consume(stream: AnyStream, run_id: str) -> None:
    """Drain *stream*, publishing every ``(mode, payload)`` for *run_id*.

    Payloads are JSON-encoded with ``json_dumpb`` in an executor and published
    through ``Runs.Stream.publish``. If the stream raises, an ``"error"``
    event carrying the exception is published and the exception is re-raised.
    For an ``ExceptionGroup`` only its first sub-exception is used. The stream
    is always closed on exit.
    """
    async with aclosing(stream):
        try:
            async for mode, payload in stream:
                # Serialise in an executor rather than on the event loop.
                encoded = await run_in_executor(None, json_dumpb, payload)
                await Runs.Stream.publish(run_id, mode, encoded)
        except Exception as exc:
            # Unwrap an ExceptionGroup (e.g. from a task group) to the first
            # underlying error so clients see the real cause.
            error = exc.exceptions[0] if isinstance(exc, ExceptionGroup) else exc
            encoded_error = await run_in_executor(None, json_dumpb, error)
            await Runs.Stream.publish(run_id, "error", encoded_error)
            raise error from None
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def get_feedback_urls(run_id: str, feedback_keys: list[str]) -> dict[str, str]:
    """Create presigned LangSmith feedback URLs for *run_id*, one per key.

    Returns a mapping of feedback key to URL. Keys and returned tokens are
    paired positionally; any surplus on either side is dropped.
    """
    tokens = get_langsmith_client().create_presigned_feedback_tokens(
        run_id, feedback_keys
    )
    urls = (token.url for token in tokens)
    return dict(zip(feedback_keys, urls, strict=False))
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
@lru_cache(maxsize=1)
def get_langsmith_client() -> langsmith.Client:
    """Return a shared ``langsmith.Client``, created lazily on first call.

    The function takes no arguments, so ``lru_cache`` holds exactly one
    instance and every later call returns it.
    """
    client = langsmith.Client()
    return client
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
# Exception types considered "expected" run failures. Not referenced in the
# visible code; presumably callers use it to separate user/graph input errors
# (bad values, invalid updates, recursion limit, validation failures) from
# internal server faults — confirm against callers.
EXPECTED_ERRORS = (
    ValueError,
    InvalidUpdateError,
    GraphRecursionError,
    EmptyInputError,
    EmptyChannelError,
    ValidationError,
    # pydantic v1-compat validation errors are raised by legacy models.
    ValidationErrorLegacy,
)
|
langgraph_api/utils.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
from collections.abc import AsyncIterator
|
|
3
|
+
from contextlib import asynccontextmanager
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Any, Protocol, TypeAlias, TypeVar
|
|
6
|
+
|
|
7
|
+
from starlette.exceptions import HTTPException
|
|
8
|
+
|
|
9
|
+
# Generic item type for fetchone.
T = TypeVar("T")
# A result row as a str-keyed dict (presumably column name -> value).
Row: TypeAlias = dict[str, Any]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AsyncCursorProto(Protocol):
    """Structural type for an async cursor over ``Row`` results.

    Supports fetching one row, fetching all rows, or ``async for`` iteration.
    """

    async def fetchone(self) -> Row: ...

    async def fetchall(self) -> list[Row]: ...

    # Declared as an async generator so calling it returns an async iterator.
    async def __aiter__(self) -> AsyncIterator[Row]:
        yield ...
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class AsyncPipelineProto(Protocol):
    """Structural type for a connection pipeline handle exposing ``sync()``."""

    async def sync(self) -> None: ...
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class AsyncConnectionProto(Protocol):
    """Structural type for the async database connection used by the API.

    Only two members are declared: ``pipeline()``, an async context manager
    yielding an ``AsyncPipelineProto``, and ``execute()``, returning an
    ``AsyncCursorProto``. The shape presumably mirrors psycopg 3's
    ``AsyncConnection`` — confirm against the storage layer.
    """

    @asynccontextmanager
    async def pipeline(self) -> AsyncIterator[AsyncPipelineProto]:
        yield ...

    async def execute(self, query: str, *args, **kwargs) -> AsyncCursorProto: ...
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
async def fetchone(
|
|
35
|
+
it: AsyncIterator[T],
|
|
36
|
+
*,
|
|
37
|
+
not_found_code: int = 404,
|
|
38
|
+
not_found_detail: str | None = None,
|
|
39
|
+
) -> T:
|
|
40
|
+
"""Fetch the first row from an async iterator."""
|
|
41
|
+
try:
|
|
42
|
+
return await anext(it)
|
|
43
|
+
except StopAsyncIteration:
|
|
44
|
+
raise HTTPException(
|
|
45
|
+
status_code=not_found_code, detail=not_found_detail
|
|
46
|
+
) from None
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def validate_uuid(uuid_str: str, invalid_uuid_detail: str | None) -> uuid.UUID:
|
|
50
|
+
try:
|
|
51
|
+
return uuid.UUID(uuid_str)
|
|
52
|
+
except ValueError:
|
|
53
|
+
raise HTTPException(status_code=422, detail=invalid_uuid_detail) from None
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def next_cron_date(schedule: str, base_time: datetime) -> datetime:
    """Return the next fire time of cron expression *schedule* after *base_time*.

    The computation is delegated to ``croniter``. It is imported lazily, so
    the dependency is only loaded when cron scheduling is actually used.
    """
    from croniter import croniter

    return croniter(schedule, base_time).get_next(datetime)
|