langgraph-api 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langgraph-api might be problematic. Click here for more details.
- LICENSE +93 -0
- langgraph_api/__init__.py +0 -0
- langgraph_api/api/__init__.py +63 -0
- langgraph_api/api/assistants.py +326 -0
- langgraph_api/api/meta.py +71 -0
- langgraph_api/api/openapi.py +32 -0
- langgraph_api/api/runs.py +463 -0
- langgraph_api/api/store.py +116 -0
- langgraph_api/api/threads.py +263 -0
- langgraph_api/asyncio.py +201 -0
- langgraph_api/auth/__init__.py +0 -0
- langgraph_api/auth/langsmith/__init__.py +0 -0
- langgraph_api/auth/langsmith/backend.py +67 -0
- langgraph_api/auth/langsmith/client.py +145 -0
- langgraph_api/auth/middleware.py +41 -0
- langgraph_api/auth/noop.py +14 -0
- langgraph_api/cli.py +209 -0
- langgraph_api/config.py +70 -0
- langgraph_api/cron_scheduler.py +60 -0
- langgraph_api/errors.py +52 -0
- langgraph_api/graph.py +314 -0
- langgraph_api/http.py +168 -0
- langgraph_api/http_logger.py +89 -0
- langgraph_api/js/.gitignore +2 -0
- langgraph_api/js/build.mts +49 -0
- langgraph_api/js/client.mts +849 -0
- langgraph_api/js/global.d.ts +6 -0
- langgraph_api/js/package.json +33 -0
- langgraph_api/js/remote.py +673 -0
- langgraph_api/js/server_sent_events.py +126 -0
- langgraph_api/js/src/graph.mts +88 -0
- langgraph_api/js/src/hooks.mjs +12 -0
- langgraph_api/js/src/parser/parser.mts +443 -0
- langgraph_api/js/src/parser/parser.worker.mjs +12 -0
- langgraph_api/js/src/schema/types.mts +2136 -0
- langgraph_api/js/src/schema/types.template.mts +74 -0
- langgraph_api/js/src/utils/importMap.mts +85 -0
- langgraph_api/js/src/utils/pythonSchemas.mts +28 -0
- langgraph_api/js/src/utils/serde.mts +21 -0
- langgraph_api/js/tests/api.test.mts +1566 -0
- langgraph_api/js/tests/compose-postgres.yml +56 -0
- langgraph_api/js/tests/graphs/.gitignore +1 -0
- langgraph_api/js/tests/graphs/agent.mts +127 -0
- langgraph_api/js/tests/graphs/error.mts +17 -0
- langgraph_api/js/tests/graphs/langgraph.json +8 -0
- langgraph_api/js/tests/graphs/nested.mts +44 -0
- langgraph_api/js/tests/graphs/package.json +7 -0
- langgraph_api/js/tests/graphs/weather.mts +57 -0
- langgraph_api/js/tests/graphs/yarn.lock +159 -0
- langgraph_api/js/tests/parser.test.mts +870 -0
- langgraph_api/js/tests/utils.mts +17 -0
- langgraph_api/js/yarn.lock +1340 -0
- langgraph_api/lifespan.py +41 -0
- langgraph_api/logging.py +121 -0
- langgraph_api/metadata.py +101 -0
- langgraph_api/models/__init__.py +0 -0
- langgraph_api/models/run.py +229 -0
- langgraph_api/patch.py +42 -0
- langgraph_api/queue.py +245 -0
- langgraph_api/route.py +118 -0
- langgraph_api/schema.py +190 -0
- langgraph_api/serde.py +124 -0
- langgraph_api/server.py +48 -0
- langgraph_api/sse.py +118 -0
- langgraph_api/state.py +67 -0
- langgraph_api/stream.py +289 -0
- langgraph_api/utils.py +60 -0
- langgraph_api/validation.py +141 -0
- langgraph_api-0.0.1.dist-info/LICENSE +93 -0
- langgraph_api-0.0.1.dist-info/METADATA +26 -0
- langgraph_api-0.0.1.dist-info/RECORD +86 -0
- langgraph_api-0.0.1.dist-info/WHEEL +4 -0
- langgraph_api-0.0.1.dist-info/entry_points.txt +3 -0
- langgraph_license/__init__.py +0 -0
- langgraph_license/middleware.py +21 -0
- langgraph_license/validation.py +11 -0
- langgraph_storage/__init__.py +0 -0
- langgraph_storage/checkpoint.py +94 -0
- langgraph_storage/database.py +190 -0
- langgraph_storage/ops.py +1523 -0
- langgraph_storage/queue.py +108 -0
- langgraph_storage/retry.py +27 -0
- langgraph_storage/store.py +28 -0
- langgraph_storage/ttl_dict.py +54 -0
- logging.json +22 -0
- openapi.json +4304 -0
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from contextlib import asynccontextmanager
|
|
2
|
+
|
|
3
|
+
from starlette.applications import Starlette
|
|
4
|
+
|
|
5
|
+
import langgraph_api.config as config
|
|
6
|
+
from langgraph_api.asyncio import SimpleTaskGroup
|
|
7
|
+
from langgraph_api.cron_scheduler import cron_scheduler
|
|
8
|
+
from langgraph_api.graph import collect_graphs_from_env, stop_remote_graphs
|
|
9
|
+
from langgraph_api.http import start_http_client, stop_http_client
|
|
10
|
+
from langgraph_api.metadata import metadata_loop
|
|
11
|
+
from langgraph_api.queue import queue
|
|
12
|
+
from langgraph_license.validation import get_license_status, plus_features_enabled
|
|
13
|
+
from langgraph_storage.database import start_pool, stop_pool
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@asynccontextmanager
async def lifespan(app: Starlette):
    """Starlette lifespan handler for the API server.

    Startup: verifies the license, then starts the shared HTTP client, the
    database pool and collects graphs from the environment. While the app is
    serving, a ``SimpleTaskGroup`` runs the metadata loop, the background run
    queue and, when crons are enabled and plus features are available, the cron
    scheduler. Shutdown stops remote graphs, the HTTP client and the pool, in
    that order.

    Args:
        app: The Starlette application (not used directly here).

    Raises:
        ValueError: If license verification fails.
    """
    if not await get_license_status():
        raise ValueError(
            "License verification failed. Please ensure proper configuration:\n"
            "- For local development, set a valid LANGSMITH_API_KEY for an account with LangGraph Cloud access "
            "in the environment defined in your langgraph.json file.\n"
            "- For production, configure the LANGGRAPH_CLOUD_LICENSE_KEY environment variable "
            "with your LangGraph Cloud license key.\n"
            "Review your configuration settings and try again. If issues persist, "
            "contact support for assistance."
        )
    # NOTE(review): these start before the try below, so if start_pool() or
    # collect_graphs_from_env() raises, the services already started are not
    # released by this handler.
    await start_http_client()
    await start_pool()
    await collect_graphs_from_env(True)
    try:
        # cancel=True presumably cancels the background tasks when the app
        # stops serving (TODO confirm against SimpleTaskGroup).
        async with SimpleTaskGroup(cancel=True) as tg:
            tg.create_task(metadata_loop())
            tg.create_task(queue(config.N_JOBS_PER_WORKER, config.BG_JOB_TIMEOUT_SECS))
            if config.FF_CRONS_ENABLED and plus_features_enabled():
                tg.create_task(cron_scheduler())
            yield
    finally:
        # Chain the shutdown steps so that a failure in one still lets the
        # later ones run; the first exception still propagates afterwards.
        try:
            await stop_remote_graphs()
        finally:
            try:
                await stop_http_client()
            finally:
                await stop_pool()
|
langgraph_api/logging.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
from starlette.config import Config
|
|
6
|
+
from structlog.typing import EventDict
|
|
7
|
+
|
|
8
|
+
from langgraph_api.metadata import append_log
|
|
9
|
+
from langgraph_api.serde import json_dumpb
|
|
10
|
+
|
|
11
|
+
# env
# Logging options, read from the environment through starlette's ``Config``.

log_env = Config()

# Use the JSON renderer instead of the console renderer.
LOG_JSON = log_env("LOG_JSON", cast=bool, default=False)
# Colorize console output (only used by the console renderer).
LOG_COLOR = log_env("LOG_COLOR", cast=bool, default=True)
# Root logger level name; upper-cased before use.
LOG_LEVEL = log_env("LOG_LEVEL", cast=str, default="INFO")

logging.root.setLevel(LOG_LEVEL.upper())
# Only surface psycopg messages at WARNING and above.
logging.getLogger("psycopg").setLevel(logging.WARNING)
|
|
21
|
+
|
|
22
|
+
# custom processors
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AddPrefixedEnvVars:
    """Structlog processor that adds selected environment variables to events.

    The environment is read once, at construction: every variable whose name
    starts with ``prefix`` is kept, keyed by its name with the prefix removed
    and lower-cased. For example, with prefix ``"APP_"``, ``APP_REGION=eu``
    adds ``region="eu"`` to every event.
    """

    def __init__(self, prefix: str) -> None:
        captured = {}
        for name, val in os.environ.items():
            if not name.startswith(prefix):
                continue
            captured[name.removeprefix(prefix).lower()] = val
        self.kv = captured

    def __call__(
        self, logger: logging.Logger, method_name: str, event_dict: EventDict
    ) -> EventDict:
        """Copy the captured values into ``event_dict`` (overwriting keys of the same name) and return it."""
        for field, val in self.kv.items():
            event_dict[field] = val
        return event_dict
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class JSONRenderer:
    """Structlog renderer that serializes the event dict to a JSON string."""

    def __call__(
        self, logger: logging.Logger, method_name: str, event_dict: EventDict
    ) -> str:
        """Serialize ``event_dict`` with ``json_dumpb`` and return it as ``str``.

        ``json_dumpb`` produces bytes, which are decoded here because the
        rendered log message must be text.
        """
        return json_dumpb(event_dict).decode()
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# Level name (e.g. "WARNING") -> numeric level, snapshotted at import time.
LEVELS = logging.getLevelNamesMapping()


class TapForMetadata:
    """Structlog processor that copies langgraph warning-and-above events into
    the metadata log buffer via ``append_log``."""

    def __call__(
        self, logger: logging.Logger, method_name: str, event_dict: EventDict
    ) -> EventDict:
        """
        Tap WARN and above logs for metadata. Exclude user loggers.

        An event is copied (via ``append_log``) only when its ``logger`` name
        starts with "langgraph" and its ``level`` is above INFO. The event dict
        itself is always returned unchanged so the rest of the chain sees it.

        Events without a string ``logger``/``level`` entry, or whose level name
        is not in ``LEVELS`` (e.g. a custom level registered after import), are
        passed through untapped instead of raising inside the log formatter.
        """
        logger_name = event_dict.get("logger")
        level_name = event_dict.get("level")
        if not isinstance(logger_name, str) or not isinstance(level_name, str):
            return event_dict
        level = LEVELS.get(level_name.upper())
        if (
            level is not None
            and level > LEVELS["INFO"]
            and logger_name.startswith("langgraph")
        ):
            append_log(event_dict.copy())
        return event_dict
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# Processor chain shared by structlog loggers and, as the formatter's
# ``foreign_pre_chain``, by records coming from plain stdlib loggers.
shared_processors = [
    structlog.stdlib.add_logger_name,
    structlog.stdlib.add_log_level,
    structlog.stdlib.PositionalArgumentsFormatter(),
    structlog.stdlib.ExtraAdder(),
    AddPrefixedEnvVars("LANGSMITH_LANGGRAPH_"),  # injected by docker build
    structlog.processors.TimeStamper(fmt="iso", utc=True),
    structlog.processors.StackInfoRenderer(),
    structlog.processors.format_exc_info,
    structlog.processors.UnicodeDecoder(),
]


# Final renderer used by ``Formatter`` (referenced from logging.json and
# applied by uvicorn): JSON when LOG_JSON is set, console output otherwise.
if LOG_JSON:
    renderer = JSONRenderer()
else:
    renderer = structlog.dev.ConsoleRenderer(colors=LOG_COLOR)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class Formatter(structlog.stdlib.ProcessorFormatter):
    """stdlib ``logging`` formatter that renders records through structlog.

    Used by logging.json (applied by uvicorn). Records from plain stdlib
    loggers first run through ``shared_processors``. Every event then has the
    processor metadata removed, is tapped by ``TapForMetadata`` and is finally
    rendered by ``renderer``.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Accept ``logging.Formatter``-style positional arguments.

        ``logging.config`` instantiates formatter classes positionally, as
        ``cls(fmt, datefmt, style)``, or as ``cls(fmt, datefmt, style, validate)``
        when a ``validate`` key is configured. ``ProcessorFormatter`` reserves
        its own first positional parameter, so positional values are remapped
        to the matching keyword arguments. Positional values take precedence
        over keywords of the same name.

        Raises:
            RuntimeError: If more than four positional arguments are given.
        """
        names = ("fmt", "datefmt", "style", "validate")
        if len(args) > len(names):
            raise RuntimeError("Invalid number of arguments")
        kwargs.update(zip(names, args))
        super().__init__(
            processors=[
                structlog.stdlib.ProcessorFormatter.remove_processors_meta,
                TapForMetadata(),
                renderer,
            ],
            foreign_pre_chain=shared_processors,
            **kwargs,
        )
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
# configure structlog
# structlog loggers wrap stdlib loggers: events below the stdlib level are
# dropped, run through the shared processors, then wrapped for
# ProcessorFormatter (``Formatter``) to render.
structlog.configure(
    logger_factory=structlog.stdlib.LoggerFactory(),
    wrapper_class=structlog.stdlib.BoundLogger,
    processors=[
        structlog.stdlib.filter_by_level,
        *shared_processors,
        structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
    ],
    cache_logger_on_first_use=True,
)
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import os
|
|
3
|
+
from datetime import UTC, datetime
|
|
4
|
+
|
|
5
|
+
import langgraph.version
|
|
6
|
+
import orjson
|
|
7
|
+
import structlog
|
|
8
|
+
|
|
9
|
+
from langgraph_api.config import LANGGRAPH_CLOUD_LICENSE_KEY, LANGSMITH_API_KEY
|
|
10
|
+
from langgraph_api.http import http_request
|
|
11
|
+
from langgraph_license.validation import plus_features_enabled
|
|
12
|
+
|
|
13
|
+
logger = structlog.stdlib.get_logger(__name__)

# Seconds to sleep between metadata submissions.
INTERVAL = 300
# Build information read from the environment.
REVISION = os.getenv("LANGSMITH_LANGGRAPH_API_REVISION")
VARIANT = os.getenv("LANGSMITH_LANGGRAPH_API_VARIANT")
# Deployment host type reported with each submission.
HOST = (
    "saas"
    if VARIANT == "cloud"
    else ("byoc" if os.getenv("LANGSMITH_HOST_PROJECT_ID") else "self-hosted")
)
PLAN = "enterprise" if plus_features_enabled() else "developer"

# Per-window accumulators, drained by ``metadata_loop``.
LOGS: list[dict] = []
RUN_COUNTER = 0
NODE_COUNTER = 0
FROM_TIMESTAMP = datetime.now(UTC).isoformat()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def incr_runs(*, incr: int = 1) -> None:
    """Add ``incr`` (default 1) to the module-level ``RUN_COUNTER``."""
    global RUN_COUNTER
    RUN_COUNTER = RUN_COUNTER + incr
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def incr_nodes(_, *, incr: int = 1) -> None:
    """Add ``incr`` (default 1) to the module-level ``NODE_COUNTER``.

    The first positional argument is accepted and ignored (``metadata_loop``
    passes ``""``); presumably this lets the function be used as a callback
    that receives an argument — confirm with other callers.
    """
    global NODE_COUNTER
    NODE_COUNTER = NODE_COUNTER + incr
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def append_log(log: dict) -> None:
    """Buffer ``log`` for the next metadata submission.

    Does nothing unless a license key or a LangSmith API key is configured,
    the same condition under which ``metadata_loop`` runs.
    """
    # LOGS is only mutated here, never rebound, so no ``global`` is needed.
    if LANGGRAPH_CLOUD_LICENSE_KEY or LANGSMITH_API_KEY:
        LOGS.append(log)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
async def metadata_loop() -> None:
    """Submit usage metadata to LangSmith every ``INTERVAL`` seconds.

    Each iteration drains the module-level accumulators: the run and node
    counters, the buffered logs and the window start timestamp. It then POSTs
    them with the deployment tags to the LangSmith metadata endpoint.

    If the request raises, the run/node counts and the window start are put
    back so they are reported on the next attempt. The logs from that window
    are not re-queued.

    Returns immediately when neither a license key nor a LangSmith API key is
    configured; otherwise loops forever (until the task is cancelled).
    """
    global RUN_COUNTER, NODE_COUNTER, FROM_TIMESTAMP

    if not (LANGGRAPH_CLOUD_LICENSE_KEY or LANGSMITH_API_KEY):
        return

    logger.info("Starting metadata loop")

    while True:
        # Snapshot and reset. Readers and writers are all coroutines on the
        # main thread, so no lock is needed as long as this section contains
        # no ``await``.
        window_start = FROM_TIMESTAMP
        window_end = datetime.now(UTC).isoformat()
        node_count = NODE_COUNTER
        run_count = RUN_COUNTER
        pending_logs = LOGS.copy()
        LOGS.clear()
        RUN_COUNTER = 0
        NODE_COUNTER = 0
        FROM_TIMESTAMP = window_end

        tags = {
            "langgraph.python.version": langgraph.version.__version__,
            "langgraph.platform.revision": REVISION,
            "langgraph.platform.variant": VARIANT,
            "langgraph.platform.host": HOST,
            "langgraph.platform.plan": PLAN,
        }
        measures = {
            "langgraph.platform.runs": run_count,
            "langgraph.platform.nodes": node_count,
        }
        payload = {
            "license_key": LANGGRAPH_CLOUD_LICENSE_KEY,
            "api_key": LANGSMITH_API_KEY,
            "from_timestamp": window_start,
            "to_timestamp": window_end,
            "tags": tags,
            "measures": measures,
            "logs": pending_logs,
        }
        try:
            await http_request(
                "POST",
                "https://api.smith.langchain.com/v1/metadata/submit",
                body=orjson.dumps(payload),
                headers={"Content-Type": "application/json"},
            )
        except Exception as exc:
            # Restore counts and window start so they are retried next time.
            incr_runs(incr=run_count)
            incr_nodes("", incr=node_count)
            FROM_TIMESTAMP = window_start
            logger.warning("Failed to submit metadata", exc_info=exc)
        await asyncio.sleep(INTERVAL)
|
|
File without changes
|
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import uuid
|
|
3
|
+
from collections.abc import Mapping, Sequence
|
|
4
|
+
from typing import Any, TypedDict
|
|
5
|
+
from uuid import UUID
|
|
6
|
+
|
|
7
|
+
from langgraph.checkpoint.base.id import uuid6
|
|
8
|
+
from starlette.exceptions import HTTPException
|
|
9
|
+
|
|
10
|
+
from langgraph_api.graph import get_assistant_id
|
|
11
|
+
from langgraph_api.schema import (
|
|
12
|
+
All,
|
|
13
|
+
Config,
|
|
14
|
+
IfNotExists,
|
|
15
|
+
MetadataInput,
|
|
16
|
+
MultitaskStrategy,
|
|
17
|
+
OnCompletion,
|
|
18
|
+
Run,
|
|
19
|
+
RunCommand,
|
|
20
|
+
StreamMode,
|
|
21
|
+
)
|
|
22
|
+
from langgraph_api.utils import AsyncConnectionProto
|
|
23
|
+
from langgraph_storage.ops import Runs, logger
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class RunCreateDict(TypedDict):
    """Payload for creating a run.

    ``create_valid_run`` reads ``assistant_id`` directly and every other key
    with ``.get()``, so at runtime only ``assistant_id`` has to be present.
    The defaults applied to missing keys are noted on each field.
    """

    # NOTE(review): create_valid_run also reads a "checkpoint" key (merged into
    # config["configurable"]) that is not declared here — consider adding it.

    assistant_id: str
    """Assistant ID to use for this run."""
    checkpoint_id: str | None
    """Checkpoint ID to start from. Defaults to the latest checkpoint."""
    input: Sequence[dict] | dict[str, Any] | None
    """Input to the run. Pass null to resume from the current state of the thread."""
    command: RunCommand | None
    """One or more commands to update the graph's state and send messages to nodes."""
    metadata: MetadataInput
    """Metadata for the run."""
    config: Config | None
    """Additional configuration for the run."""
    webhook: str | None
    """Webhook to call when the run is complete."""

    interrupt_before: All | list[str] | None
    """Interrupt execution before entering these nodes."""
    interrupt_after: All | list[str] | None
    """Interrupt execution after leaving these nodes."""

    multitask_strategy: MultitaskStrategy
    """Strategy to handle concurrent runs on the same thread. Only relevant if
    there is a pending/inflight run on the same thread. Defaults to "reject".
    One of:
    - "reject": Reject the new run.
    - "interrupt": Interrupt the current run, keeping steps completed until now,
        and start a new one.
    - "rollback": Cancel and delete the existing run, rolling back the thread to
        the state before it had started, then start the new run.
    - "enqueue": Queue up the new run to start after the current run finishes.
    """
    on_completion: OnCompletion
    """What to do when the run completes. One of:
    - "keep": Keep the thread in the database.
    - "delete": Delete the thread from the database.

    Only consulted when the run is created without a thread ID; defaults to
    "delete".
    """
    stream_mode: list[StreamMode] | StreamMode
    """One or more of "values", "messages", "updates", "events" or "custom".
    Defaults to ["values"].
    - "values": Stream the thread state any time it changes.
    - "messages": Stream chat messages from thread state and calls to chat models,
        token-by-token where possible.
    - "updates": Stream the state updates returned by each node.
    - "events": Stream all events produced by sub-runs (eg. nodes, LLMs, etc.).
    - "custom": Stream custom events produced by your nodes.
    """
    stream_subgraphs: bool | None
    """Stream output from subgraphs. By default, streams only the top graph."""
    feedback_keys: list[str] | None
    """Pass one or more feedback_keys if you want to request short-lived signed URLs
    for submitting feedback to LangSmith with this key for this run."""
    after_seconds: int | None
    """Start the run after this many seconds. Defaults to 0."""
    if_not_exists: IfNotExists
    """Whether to create the thread if it doesn't exist (see ``IfNotExists`` for
    the accepted values). Defaults to "reject", which replies with 404."""
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _parse_uuid(value: Any, label: str) -> UUID:
    """Parse one ID for ``ensure_ids``, mapping failures to a 422 response.

    ``UUID`` instances are returned unchanged. Anything ``UUID()`` rejects
    raises an ``HTTPException`` (422, ``"Invalid <label> ID"``) rather than
    escaping as a server error. That covers malformed strings (``ValueError``)
    as well as non-string values such as numbers (``TypeError`` /
    ``AttributeError``).
    """
    if isinstance(value, UUID):
        return value
    try:
        return UUID(value)
    except (ValueError, TypeError, AttributeError):
        raise HTTPException(status_code=422, detail=f"Invalid {label} ID") from None


def ensure_ids(
    assistant_id: str,
    thread_id: str | None,
    payload: RunCreateDict,
) -> tuple[uuid.UUID, uuid.UUID | None, uuid.UUID | None]:
    """Validate and convert the assistant, thread and checkpoint IDs to UUIDs.

    Args:
        assistant_id: Assistant ID (required).
        thread_id: Thread ID, or ``None``/empty for a threadless run.
        payload: Run payload; its optional ``checkpoint_id`` is parsed as well.

    Returns:
        ``(assistant_id, thread_id, checkpoint_id)`` as UUIDs. A missing
        (falsy) thread or checkpoint ID becomes ``None``.

    Raises:
        HTTPException: 422 if any provided ID is not a valid UUID.
    """
    assistant_uuid = _parse_uuid(assistant_id, "assistant")
    thread_uuid = _parse_uuid(thread_id, "thread") if thread_id else None
    raw_checkpoint_id = payload.get("checkpoint_id")
    checkpoint_uuid = (
        _parse_uuid(raw_checkpoint_id, "checkpoint") if raw_checkpoint_id else None
    )
    return assistant_uuid, thread_uuid, checkpoint_uuid
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def assign_defaults(
    payload: RunCreateDict,
):
    """Resolve stream-mode and multitask defaults for a run payload.

    Returns:
        ``(stream_mode, multitask_strategy, prevent_insert_if_inflight)``:
        - ``stream_mode`` is always a list. A single mode is wrapped, and a
          missing or empty value becomes ``["values"]``.
        - ``multitask_strategy`` falls back to ``"reject"``.
        - ``prevent_insert_if_inflight`` is true exactly when the strategy
          is ``"reject"``.
    """
    requested = payload.get("stream_mode")
    if not requested:
        stream_mode = ["values"]
    elif isinstance(requested, list):
        stream_mode = requested
    else:
        stream_mode = [requested]
    multitask_strategy = payload.get("multitask_strategy") or "reject"
    return stream_mode, multitask_strategy, multitask_strategy == "reject"
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# Request headers that are never copied into a run's ``configurable``.
_EXCLUDED_CONFIG_HEADERS = frozenset({"x-api-key", "x-tenant-id", "x-service-key"})


def _build_run_config(
    payload: RunCreateDict,
    checkpoint_id: UUID | None,
    headers: Mapping[str, str],
) -> dict[str, Any]:
    """Build a run config from the payload, the checkpoint ID and the request headers.

    The payload's ``config`` and its ``configurable`` mapping are shallow-copied,
    so the caller's payload is not mutated. A ``"configurable": null`` value is
    treated as empty.
    """
    config = dict(payload.get("config") or {})
    configurable = dict(config.get("configurable") or {})
    config["configurable"] = configurable
    if checkpoint_id:
        configurable["checkpoint_id"] = str(checkpoint_id)
    # NOTE(review): "checkpoint" is read here but not declared on RunCreateDict.
    if checkpoint := payload.get("checkpoint"):
        configurable.update(checkpoint)
    # Forward custom ``x-`` headers, except the excluded ones above.
    for key, value in headers.items():
        if key.startswith("x-") and key not in _EXCLUDED_CONFIG_HEADERS:
            configurable[key] = value
    return config


async def create_valid_run(
    conn: AsyncConnectionProto,
    thread_id: str | None,
    payload: RunCreateDict,
    user_id: str | None,
    headers: Mapping[str, str],
    barrier: asyncio.Barrier | None = None,
    run_id: UUID | None = None,
) -> Run:
    """Validate ``payload``, insert a pending run and apply the multitask strategy.

    Args:
        conn: Database connection passed through to ``Runs``.
        thread_id: Thread to run on, or ``None`` for a threadless run.
        payload: The run creation payload.
        user_id: ID of the requesting user, stored with the run.
        headers: Request headers. ``x-`` headers other than x-api-key,
            x-tenant-id and x-service-key are copied into
            ``config["configurable"]``.
        barrier: If given, awaited after the insert and before reading its
            results (presumably a synchronization hook for tests — confirm).
        run_id: ID for the new run; defaults to a fresh ``uuid6()``.

    Returns:
        The newly created run.

    Raises:
        HTTPException: 422 for malformed IDs (from ``ensure_ids``), 404 if the
            thread or assistant was not found, and 409 if the strategy is
            "reject" and the run was not inserted.
        NotImplementedError: If the run was not inserted under any other
            strategy.
    """
    # get assistant_id
    assistant_id = get_assistant_id(payload["assistant_id"])

    # ensure UUID validity defaults
    assistant_id, thread_id, checkpoint_id = ensure_ids(
        assistant_id, thread_id, payload
    )
    stream_mode, multitask_strategy, prevent_insert_if_inflight = assign_defaults(
        payload
    )
    # Threadless runs are temporary unless the caller asked to keep them.
    temporary = thread_id is None and payload.get("on_completion", "delete") == "delete"
    run_id = run_id or uuid6()

    # assign custom headers and checkpoint to a fresh config
    config = _build_run_config(payload, checkpoint_id, headers)

    # try to insert
    run_ = await Runs.put(
        conn,
        assistant_id,
        {
            "input": payload.get("input"),
            "command": payload.get("command"),
            "config": config,
            "stream_mode": stream_mode,
            "interrupt_before": payload.get("interrupt_before"),
            "interrupt_after": payload.get("interrupt_after"),
            "webhook": payload.get("webhook"),
            "feedback_keys": payload.get("feedback_keys"),
            "temporary": temporary,
            "subgraphs": payload.get("stream_subgraphs", False),
        },
        metadata=payload.get("metadata"),
        status="pending",
        user_id=user_id,
        thread_id=thread_id,
        run_id=run_id,
        multitask_strategy=multitask_strategy,
        prevent_insert_if_inflight=prevent_insert_if_inflight,
        after_seconds=payload.get("after_seconds", 0),
        if_not_exists=payload.get("if_not_exists", "reject"),
    )

    if barrier:
        await barrier.wait()

    # abort if thread, assistant, etc not found
    try:
        first = await anext(run_)
    except StopAsyncIteration:
        raise HTTPException(
            status_code=404, detail="Thread or assistant not found."
        ) from None

    # handle multitask strategy: the first item tells whether our run was
    # inserted; the remaining items are treated as inflight runs
    inflight_runs = [run async for run in run_]
    if first["run_id"] == run_id:
        logger.info("Created run", run_id=str(run_id), thread_id=str(thread_id))
        # inserted, proceed
        if multitask_strategy in ("interrupt", "rollback") and inflight_runs:
            try:
                await Runs.cancel(
                    conn,
                    [run["run_id"] for run in inflight_runs],
                    thread_id=thread_id,
                    action=multitask_strategy,
                )
            except HTTPException:
                # if we can't find the inflight runs again, we can proceed
                pass
        return first
    elif multitask_strategy == "reject":
        raise HTTPException(
            status_code=409,
            detail="Thread is already running a task. Wait for it to finish or choose a different multitask strategy.",
        )
    else:
        raise NotImplementedError
|
langgraph_api/patch.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from starlette.responses import Response, StreamingResponse
|
|
4
|
+
from starlette.types import Send
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
Patch Response.render and StreamingResponse.stream_response
|
|
8
|
+
to recognize bytearrays and memoryviews as bytes-like objects.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def Response_render(self, content: Any) -> bytes:
    """Replacement for ``starlette.responses.Response.render``.

    ``bytes``, ``bytearray`` and ``memoryview`` content is returned as-is,
    ``None`` renders as an empty body, and anything else is encoded with
    ``self.charset``.
    """
    bytes_like = (bytes, bytearray, memoryview)
    if isinstance(content, bytes_like):  # noqa: UP038
        return content
    if content is None:
        return b""
    return content.encode(self.charset)  # type: ignore
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
async def StreamingResponse_stream_response(self, send: Send) -> None:
    """Replacement for ``starlette.responses.StreamingResponse.stream_response``.

    Sends the response start message, then one body message per chunk of
    ``self.body_iterator``, then a final empty body message.
    ``bytes``/``bytearray``/``memoryview`` chunks are sent as-is; any other
    chunk is encoded with ``self.charset``.
    """
    start_message = {
        "type": "http.response.start",
        "status": self.status_code,
        "headers": self.raw_headers,
    }
    await send(start_message)

    bytes_like = (bytes, bytearray, memoryview)
    async for piece in self.body_iterator:
        body = piece if isinstance(piece, bytes_like) else piece.encode(self.charset)  # noqa: UP038
        await send({"type": "http.response.body", "body": body, "more_body": True})

    await send({"type": "http.response.body", "body": b"", "more_body": False})
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Install the replacements on Starlette's response classes (import-time side
# effect of this module).
Response.render = Response_render
StreamingResponse.stream_response = StreamingResponse_stream_response
|