dynamic-subgraphs 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- app/__init__.py +1 -0
- app/api/__init__.py +6 -0
- app/api/__main__.py +18 -0
- app/api/app.py +32 -0
- app/api/deps.py +146 -0
- app/api/errors.py +67 -0
- app/api/jobs.py +136 -0
- app/api/routers/__init__.py +1 -0
- app/api/routers/chains.py +170 -0
- app/api/routers/health.py +11 -0
- app/api/routers/registry.py +38 -0
- app/api/routers/runs.py +300 -0
- app/api/run_config_store.py +51 -0
- app/api/schemas.py +86 -0
- app/api/serialize.py +76 -0
- app/api/settings.py +53 -0
- app/assembly.py +256 -0
- app/compiler/__init__.py +6 -0
- app/compiler/build.py +168 -0
- app/compiler/errors.py +5 -0
- app/main.py +202 -0
- app/models/__init__.py +29 -0
- app/models/graph_spec.py +51 -0
- app/models/node_kinds.py +13 -0
- app/models/run_state.py +44 -0
- app/models/trace.py +31 -0
- app/py.typed +1 -0
- app/recording/__init__.py +27 -0
- app/recording/mermaid.py +27 -0
- app/recording/recorder.py +646 -0
- app/registry/__init__.py +22 -0
- app/registry/allowlists.py +30 -0
- app/registry/definitions.py +91 -0
- app/registry/errors.py +20 -0
- app/registry/params.py +88 -0
- app/registry/registry.py +214 -0
- app/registry/validator.py +348 -0
- app/runtime/__init__.py +133 -0
- app/runtime/artifacts.py +176 -0
- app/runtime/branch.py +103 -0
- app/runtime/chat_models.py +39 -0
- app/runtime/executor.py +304 -0
- app/runtime/llm_runner.py +152 -0
- app/runtime/model_providers.py +307 -0
- app/runtime/parallel_map.py +342 -0
- app/runtime/runners.py +218 -0
- app/runtime/state.py +40 -0
- app/runtime/subagents.py +172 -0
- app/runtime/subgraph.py +238 -0
- app/runtime/tools.py +583 -0
- app/runtime/wait_for_event.py +88 -0
- app/runtime/wrappers.py +162 -0
- app/supervisor/__init__.py +51 -0
- app/supervisor/graph.py +235 -0
- app/supervisor/iteration.py +525 -0
- app/supervisor/llm_planner.py +340 -0
- app/supervisor/planner.py +26 -0
- app/supervisor/state.py +45 -0
- app/supervisor/supervisor.py +510 -0
- dynamic_subgraphs/__init__.py +65 -0
- dynamic_subgraphs/engine.py +525 -0
- dynamic_subgraphs/py.typed +1 -0
- dynamic_subgraphs/recording.py +169 -0
- dynamic_subgraphs/types.py +63 -0
- dynamic_subgraphs-0.1.0.dist-info/METADATA +335 -0
- dynamic_subgraphs-0.1.0.dist-info/RECORD +69 -0
- dynamic_subgraphs-0.1.0.dist-info/WHEEL +4 -0
- dynamic_subgraphs-0.1.0.dist-info/licenses/LICENSE +201 -0
- dynamic_subgraphs-0.1.0.dist-info/licenses/NOTICE +24 -0
app/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Dynamic Subgraphs — governed runtime for transient workflow graphs."""
|
app/api/__init__.py
ADDED
app/api/__main__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# app/api/__main__.py
|
|
2
|
+
"""`python -m app.api` -> run uvicorn."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
import uvicorn
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def main() -> None:
    """Serve the ``app.api:create_app`` application factory with uvicorn.

    The bind host is read from the ``DS_HOST`` environment variable
    (default ``127.0.0.1``) and the port from ``DS_PORT`` (default ``8000``,
    converted with ``int``).
    """
    env = os.environ
    bind_host = env.get("DS_HOST", "127.0.0.1")
    bind_port = int(env.get("DS_PORT", "8000"))
    uvicorn.run("app.api:create_app", host=bind_host, port=bind_port, factory=True)


if __name__ == "__main__":
    main()
|
app/api/app.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# app/api/app.py
|
|
2
|
+
"""create_app() — wire settings, context, routers, error handlers."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from dotenv import load_dotenv
|
|
7
|
+
from fastapi import FastAPI
|
|
8
|
+
|
|
9
|
+
from app.api.deps import AppContext
|
|
10
|
+
from app.api.errors import install_error_handlers
|
|
11
|
+
from app.api.routers import chains, health, registry, runs
|
|
12
|
+
from app.api.settings import ApiSettings
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def create_app(settings: ApiSettings | None = None) -> FastAPI:
    """Assemble the FastAPI application.

    Calls ``load_dotenv()``, falls back to ``ApiSettings.from_env()`` when no
    settings object is supplied, stores an ``AppContext`` on
    ``app.state.context``, installs the error handlers, mounts the health,
    registry, runs and chains routers (in that order) and registers a
    shutdown hook that shuts the job store down.
    """
    load_dotenv()
    effective_settings = settings or ApiSettings.from_env()

    app = FastAPI(title="Dynamic Subgraphs API", version="1.0.0")
    app.state.context = AppContext.build(effective_settings)
    install_error_handlers(app)

    # Same mounting order as before: health, registry, runs, chains.
    for module in (health, registry, runs, chains):
        app.include_router(module.router)

    @app.on_event("shutdown")
    def _stop_jobs() -> None:
        app.state.context.jobs.shutdown()

    return app
|
app/api/deps.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
# app/api/deps.py
|
|
2
|
+
"""App-wide context + per-request config resolution + auth."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
|
|
8
|
+
from fastapi import Request
|
|
9
|
+
|
|
10
|
+
from app.api.errors import BadRequest, ServiceUnavailable, Unauthorized
|
|
11
|
+
from app.api.jobs import JobStore
|
|
12
|
+
from app.api.settings import ApiSettings
|
|
13
|
+
from app.assembly import RunConfig, build_supervisor
|
|
14
|
+
from app.recording import FileRecorder
|
|
15
|
+
from app.runtime import (
|
|
16
|
+
MissingModelProviderCredential,
|
|
17
|
+
ProviderRegistry,
|
|
18
|
+
default_model_providers,
|
|
19
|
+
)
|
|
20
|
+
from app.supervisor import (
|
|
21
|
+
IterationDecider,
|
|
22
|
+
StatusIterationDecider,
|
|
23
|
+
Supervisor,
|
|
24
|
+
build_provider_iteration_decider,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class AppContext:
    """Bundle of long-lived objects shared by the request handlers."""

    settings: ApiSettings  # settings the context was built from
    recorder: FileRecorder  # recorder rooted at settings.runs_dir
    jobs: JobStore  # background job store
    checkpointer: object  # checkpointer handed to every supervisor
    model_providers: ProviderRegistry  # provider registry handed to every supervisor

    @classmethod
    def build(cls, settings: ApiSettings) -> AppContext:
        """Create a context with a file recorder, a four-worker job store,
        an in-memory checkpointer and the default model providers."""
        # Imported inside the method, exactly as the original did.
        from langgraph.checkpoint.memory import MemorySaver

        run_recorder = FileRecorder(root_dir=settings.runs_dir, overwrite=True)
        job_store = JobStore(max_workers=4)
        saver = MemorySaver()
        providers = default_model_providers()
        return cls(
            settings=settings,
            recorder=run_recorder,
            jobs=job_store,
            checkpointer=saver,
            model_providers=providers,
        )

    def supervisor_for(self, config: RunConfig) -> Supervisor:
        """Build a supervisor for *config* wired to this context's recorder,
        checkpointer and model providers."""
        shared = {
            "recorder": self.recorder,
            "checkpointer": self.checkpointer,
            "model_providers": self.model_providers,
        }
        return build_supervisor(config, **shared)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def get_context(request: Request) -> AppContext:
    """Return the object stored at ``request.app.state.context``."""
    app_state = request.app.state
    return app_state.context
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def resolve_run_config(
    ctx: AppContext,
    *,
    planner: str | None,
    provider: str | None = None,
    model: str | None,
) -> RunConfig:
    """Turn per-request overrides into a validated ``RunConfig``.

    Missing (falsy) overrides fall back to the values on ``ctx.settings``.
    The provider name is stripped and lower-cased. A planner of ``"openai"``
    is rewritten to planner ``"llm"`` on provider ``"openai"``.

    Raises:
        BadRequest: the model is not allowed by the settings allowlist, or
            ``require_credentials`` raised ``KeyError`` for a provider.
        ServiceUnavailable: a provider in use is missing its credential.
    """
    settings = ctx.settings
    planner_name = planner if planner else settings.planner
    provider_name = (provider if provider else settings.provider).strip().lower()
    if planner_name == "openai":
        planner_name, provider_name = "llm", "openai"
    model_name = model if model else settings.model

    allowed = settings.is_model_allowed(model_name, provider=provider_name)
    if not allowed:
        raise BadRequest(
            f"Model {provider_name}:{model_name} is not in the allowlist "
            f"{list(ctx.settings.model_allowlist)}"
        )

    config = RunConfig(
        planner=planner_name,  # type: ignore[arg-type]
        provider=provider_name,
        model=model_name,
        strict_runners=planner_name == "llm",
    )

    # Credentials are only checked when the LLM planner is selected.
    if config.planner != "llm":
        return config

    try:
        for name in config.providers_in_use():
            ctx.model_providers.require_credentials(name)
    except KeyError as exc:
        raise BadRequest(str(exc)) from exc
    except MissingModelProviderCredential as exc:
        raise ServiceUnavailable(str(exc)) from exc
    return config
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def resolve_chain_decider(
    ctx: AppContext,
    *,
    config: RunConfig,
    decider: str,
    success_criteria: str | None,
    judge_failed_runs: bool,
) -> IterationDecider:
    """Pick the chain-level orchestration judge used by `/chains`.

    ``"status"`` yields the token-free ``StatusIterationDecider`` (the
    default). ``"llm"`` must be requested explicitly because it spends an
    extra model call after each successful iteration to choose between
    stopping, replanning, asking the user, or failing.

    Raises:
        BadRequest: unknown decider name, model not in the allowlist, or a
            ``KeyError`` from the provider registry.
        ServiceUnavailable: the provider's credential is missing.
    """
    if decider == "status":
        return StatusIterationDecider()

    if decider != "llm":
        raise BadRequest(f"Unknown chain decider {decider!r}")

    provider_name = config.provider
    if not ctx.settings.is_model_allowed(config.model, provider=provider_name):
        raise BadRequest(
            f"Model {config.provider}:{config.model} is not in the allowlist "
            f"{list(ctx.settings.model_allowlist)}"
        )

    registry = ctx.model_providers
    try:
        registry.require_credentials(config.provider)
        model_provider = registry.get(config.provider)
    except KeyError as exc:
        raise BadRequest(str(exc)) from exc
    except MissingModelProviderCredential as exc:
        raise ServiceUnavailable(str(exc)) from exc

    return build_provider_iteration_decider(
        model_provider,
        config.model_ref,
        success_criteria=success_criteria,
        judge_failed_runs=judge_failed_runs,
    )
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def require_auth(request: Request) -> None:
    """Guard for POST endpoints. No-op when DS_API_KEY is unset.

    When ``settings.api_key`` on the app context is falsy the request is let
    through. Otherwise the ``Authorization`` header must equal
    ``"Bearer <api_key>"`` exactly.

    Raises:
        Unauthorized: the header is missing or does not match.
    """
    # Function-scope import keeps this block self-contained.
    import hmac

    ctx: AppContext = request.app.state.context
    expected = ctx.settings.api_key
    if not expected:
        return

    header = request.headers.get("authorization", "")
    # Compare as bytes in constant time so the response time does not reveal
    # how much of the token matched. "surrogatepass" keeps the encoding total
    # and one-to-one for any str.
    supplied = header.encode("utf-8", "surrogatepass")
    wanted = f"Bearer {expected}".encode("utf-8", "surrogatepass")
    if not hmac.compare_digest(supplied, wanted):
        raise Unauthorized("Missing or invalid bearer token")
|
app/api/errors.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# app/api/errors.py
|
|
2
|
+
"""API exceptions and JSON error-envelope handlers."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from fastapi import FastAPI, Request
|
|
7
|
+
from fastapi.exceptions import RequestValidationError
|
|
8
|
+
from fastapi.responses import JSONResponse
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ApiError(Exception):
    """Base API exception carrying an HTTP status, a type tag, a message
    and an optional detail object."""

    status_code = 500
    error_type = "ApiError"

    def __init__(self, message: str, *, detail: object = None) -> None:
        super().__init__(message)
        self.detail = detail
        self.message = message


class NotFound(ApiError):
    """API error with status 404."""

    error_type = "NotFound"
    status_code = 404


class Conflict(ApiError):
    """API error with status 409."""

    error_type = "Conflict"
    status_code = 409


class Unauthorized(ApiError):
    """API error with status 401."""

    error_type = "Unauthorized"
    status_code = 401


class BadRequest(ApiError):
    """API error with status 400."""

    error_type = "BadRequest"
    status_code = 400


class ServiceUnavailable(ApiError):
    """API error with status 503."""

    error_type = "ServiceUnavailable"
    status_code = 503
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _envelope(error_type: str, message: str, detail: object = None) -> dict:
|
|
47
|
+
return {"error": {"type": error_type, "message": message, "detail": detail}}
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def install_error_handlers(app: FastAPI) -> None:
    """Register the JSON error-envelope handlers on *app*.

    ``ApiError`` instances respond with their own ``status_code``,
    ``error_type``, ``message`` and ``detail``. ``RequestValidationError``
    responds with 422, type ``"ValidationError"`` and ``exc.errors()`` as
    the detail.
    """

    @app.exception_handler(ApiError)
    async def _on_api_error(_request: Request, err: ApiError) -> JSONResponse:
        body = _envelope(err.error_type, err.message, err.detail)
        return JSONResponse(status_code=err.status_code, content=body)

    @app.exception_handler(RequestValidationError)
    async def _on_validation_error(
        _request: Request, err: RequestValidationError
    ) -> JSONResponse:
        body = _envelope("ValidationError", "Request validation failed", err.errors())
        return JSONResponse(status_code=422, content=body)
|
app/api/jobs.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
# app/api/jobs.py
|
|
2
|
+
"""In-process job store: background execution + subscribe bus.
|
|
3
|
+
|
|
4
|
+
Every run/chain becomes a Job executed on a thread pool. The request handler
|
|
5
|
+
decides how long to wait (sync/async/auto). The subscribe bus feeds SSE today;
|
|
6
|
+
it is the seam where per-node events will publish later.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import threading
|
|
12
|
+
from collections.abc import Callable
|
|
13
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
14
|
+
from datetime import UTC, datetime
|
|
15
|
+
from enum import StrEnum
|
|
16
|
+
from queue import Queue
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
_TERMINAL: frozenset[str] = frozenset({"ok", "failed", "paused"})
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class JobState(StrEnum):
|
|
23
|
+
QUEUED = "queued"
|
|
24
|
+
RUNNING = "running"
|
|
25
|
+
OK = "ok"
|
|
26
|
+
FAILED = "failed"
|
|
27
|
+
PAUSED = "paused"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class JobExistsError(Exception):
    """Raised when a job is created under a run id that is already taken.

    The offending id is kept on ``run_id``.
    """

    def __init__(self, run_id: str) -> None:
        message = f"Job already exists: {run_id!r}"
        super().__init__(message)
        self.run_id = run_id
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _now() -> datetime:
|
|
37
|
+
return datetime.now(UTC)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class Job:
|
|
41
|
+
"""Mutable, thread-safe handle for one background run/chain."""
|
|
42
|
+
|
|
43
|
+
def __init__(self, run_id: str, kind: str) -> None:
|
|
44
|
+
self.run_id = run_id
|
|
45
|
+
self.kind = kind
|
|
46
|
+
self.state: JobState = JobState.QUEUED
|
|
47
|
+
self.submitted_at: datetime = _now()
|
|
48
|
+
self.started_at: datetime | None = None
|
|
49
|
+
self.finished_at: datetime | None = None
|
|
50
|
+
self.result: Any = None
|
|
51
|
+
self.error: str | None = None
|
|
52
|
+
self.budget_wall_seconds: int | None = None
|
|
53
|
+
self._lock = threading.Lock()
|
|
54
|
+
self._done = threading.Event()
|
|
55
|
+
self._subscribers: list[Queue[dict[str, Any]]] = []
|
|
56
|
+
|
|
57
|
+
def _publish(self, msg: dict[str, Any]) -> None:
|
|
58
|
+
for q in self._subscribers:
|
|
59
|
+
q.put(msg)
|
|
60
|
+
|
|
61
|
+
def set_state(self, state: JobState) -> None:
|
|
62
|
+
with self._lock:
|
|
63
|
+
self.state = state
|
|
64
|
+
if state == JobState.RUNNING and self.started_at is None:
|
|
65
|
+
self.started_at = _now()
|
|
66
|
+
self._publish({"type": "status", "state": state.value})
|
|
67
|
+
|
|
68
|
+
def complete(self, *, result: Any, state: JobState) -> None:
|
|
69
|
+
with self._lock:
|
|
70
|
+
self.result = result
|
|
71
|
+
self.state = state
|
|
72
|
+
self.finished_at = _now()
|
|
73
|
+
self._publish({"type": "status", "state": state.value})
|
|
74
|
+
self._publish({"type": "__end__"})
|
|
75
|
+
self._subscribers.clear()
|
|
76
|
+
self._done.set()
|
|
77
|
+
|
|
78
|
+
def fail(self, message: str) -> None:
|
|
79
|
+
with self._lock:
|
|
80
|
+
self.error = message
|
|
81
|
+
self.state = JobState.FAILED
|
|
82
|
+
self.finished_at = _now()
|
|
83
|
+
self._publish({"type": "status", "state": JobState.FAILED.value})
|
|
84
|
+
self._publish({"type": "__end__"})
|
|
85
|
+
self._subscribers.clear()
|
|
86
|
+
self._done.set()
|
|
87
|
+
|
|
88
|
+
def wait(self, timeout: float | None = None) -> bool:
|
|
89
|
+
return self._done.wait(timeout=timeout)
|
|
90
|
+
|
|
91
|
+
def is_terminal(self) -> bool:
|
|
92
|
+
return self.state.value in _TERMINAL
|
|
93
|
+
|
|
94
|
+
def subscribe(self) -> Queue[dict[str, Any]]:
|
|
95
|
+
q: Queue[dict[str, Any]] = Queue()
|
|
96
|
+
with self._lock:
|
|
97
|
+
q.put({"type": "status", "state": self.state.value})
|
|
98
|
+
if self.is_terminal():
|
|
99
|
+
q.put({"type": "__end__"})
|
|
100
|
+
else:
|
|
101
|
+
self._subscribers.append(q)
|
|
102
|
+
return q
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class JobStore:
    """Registry of jobs keyed by run id, plus the thread pool that runs them."""

    def __init__(self, max_workers: int = 4) -> None:
        self._lock = threading.Lock()
        self._jobs: dict[str, Job] = {}
        self._executor = ThreadPoolExecutor(
            thread_name_prefix="ds-job",
            max_workers=max_workers,
        )

    def create(self, run_id: str, *, kind: str) -> Job:
        """Register and return a new job.

        Raises:
            JobExistsError: *run_id* is already registered.
        """
        with self._lock:
            if run_id in self._jobs:
                raise JobExistsError(run_id)
            created = Job(run_id=run_id, kind=kind)
            self._jobs[run_id] = created
        return created

    def submit(self, job: Job, fn: Callable[[Job], None]) -> None:
        """Schedule ``fn(job)`` on the thread pool."""
        self._executor.submit(self._wrap, job, fn)

    @staticmethod
    def _wrap(job: Job, fn: Callable[[Job], None]) -> None:
        """Run ``fn(job)``; any ``Exception`` marks the job as failed."""
        try:
            fn(job)
        except Exception as exc:
            reason = f"{type(exc).__name__}: {exc}"
            job.fail(reason)

    def get(self, run_id: str) -> Job | None:
        """Return the job registered under *run_id*, or ``None``."""
        with self._lock:
            found = self._jobs.get(run_id)
        return found

    def shutdown(self) -> None:
        """Shut the thread pool down without waiting for running work."""
        self._executor.shutdown(wait=False)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# app/api/routers/__init__.py
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
# app/api/routers/chains.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from fastapi import APIRouter, Depends, Request, Response
|
|
7
|
+
from fastapi.responses import JSONResponse
|
|
8
|
+
|
|
9
|
+
from app.api.deps import (
|
|
10
|
+
AppContext,
|
|
11
|
+
get_context,
|
|
12
|
+
require_auth,
|
|
13
|
+
resolve_chain_decider,
|
|
14
|
+
resolve_run_config,
|
|
15
|
+
)
|
|
16
|
+
from app.api.errors import Conflict, NotFound
|
|
17
|
+
from app.api.jobs import Job, JobState
|
|
18
|
+
from app.api.schemas import ChainRequest
|
|
19
|
+
from app.recording.recorder import _validate_run_id
|
|
20
|
+
|
|
21
|
+
router = APIRouter(tags=["chains"])
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _job_state_for_chain(status: str) -> JobState:
    """Translate an iterative-chain terminal status into a job state.

    ``ok``, ``stopped`` and ``max_iterations`` count as OK. ``ask_user`` is a
    *paused* outcome, not a failure: the LLM judge needs a human answer before
    the chain can continue, so a legitimate clarification request must not be
    recorded as FAILED. Every other status maps to FAILED.
    """
    by_status = {
        "ok": JobState.OK,
        "stopped": JobState.OK,
        "max_iterations": JobState.OK,
        "ask_user": JobState.PAUSED,
    }
    return by_status.get(status, JobState.FAILED)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _decision_payload(decision: Any | None) -> dict[str, Any] | None:
|
|
39
|
+
if decision is None:
|
|
40
|
+
return None
|
|
41
|
+
return {
|
|
42
|
+
"action": decision.action,
|
|
43
|
+
"reason": decision.reason,
|
|
44
|
+
"success_criteria_met": decision.success_criteria_met,
|
|
45
|
+
"gaps": list(decision.gaps),
|
|
46
|
+
"next_prompt": decision.next_prompt,
|
|
47
|
+
"question_to_user": decision.question_to_user,
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _chain_payload(result: Any) -> dict[str, Any]:
    """Build the JSON-ready response body for a chain result.

    Each step contributes its iteration, run id, run status and decision
    (action, full detail and reason); the chain's final decision is attached
    under ``final_decision``.
    """
    payload: dict[str, Any] = {
        "chain_id": result.chain_id,
        "status": result.status,
        "response": result.response,
    }

    step_rows: list[dict[str, Any]] = []
    for step in result.steps:
        step_rows.append(
            {
                "iteration": step.iteration,
                "run_id": step.run_id,
                "status": step.result.status,
                "decision": step.decision.action,
                "decision_detail": _decision_payload(step.decision),
                "reason": step.decision.reason,
            }
        )
    payload["steps"] = step_rows
    payload["final_decision"] = _decision_payload(result.final_decision)
    return payload
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _make_chain_worker(
    ctx: AppContext,
    config,
    prompt: str,
    run_id: str,
    max_iter: int,
    decider,
):
    """Return the callable that runs one chain on a job-store worker.

    The supervisor is built here, before the returned function is ever
    called. The returned function marks the job RUNNING, runs the chain
    iteratively, and completes the job with the state mapped from the
    chain's status.
    """
    supervisor = ctx.supervisor_for(config)

    def _run_chain(job: Job) -> None:
        job.set_state(JobState.RUNNING)
        outcome = supervisor.run_iteratively(
            prompt,
            run_id=run_id,
            max_iterations=max_iter,
            decider=decider,
        )
        final_state = _job_state_for_chain(outcome.status)
        job.complete(result=outcome, state=final_state)

    return _run_chain
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@router.post("/chains")
def create_chain(
    request: Request,
    body: ChainRequest,
    _: None = Depends(require_auth),
) -> Response:
    """Start an iterative chain and answer according to ``body.mode``.

    ``"async"`` answers 202 immediately. ``"sync"`` waits up to
    ``settings.max_sync_seconds`` and any other mode waits up to
    ``settings.auto_sync_seconds``; if the job finished with a result in
    that window the full chain payload is returned with 200, otherwise 202
    with a link to poll.

    Raises:
        Conflict: the chain id is already known to the job store or recorder.
    """
    # Function-scope imports keep this block self-contained.
    import uuid

    from app.api.jobs import JobExistsError

    ctx = get_context(request)
    config = resolve_run_config(
        ctx,
        planner=body.planner,
        provider=body.provider,
        model=body.model,
    )
    decider = resolve_chain_decider(
        ctx,
        config=config,
        decider=body.decider,
        success_criteria=body.success_criteria,
        judge_failed_runs=body.judge_failed_runs,
    )
    run_id = body.run_id or f"chain-{uuid.uuid4().hex[:12]}"
    _validate_run_id(run_id)
    if ctx.jobs.get(run_id) is not None or ctx.recorder.exists(run_id):
        raise Conflict(f"chain id already exists: {run_id!r}")

    # Build the worker (and its supervisor) before registering the job, so a
    # failure here cannot leave a permanently QUEUED job holding the id.
    worker = _make_chain_worker(
        ctx,
        config,
        body.prompt,
        run_id,
        body.max_iterations,
        decider,
    )

    # Two concurrent requests can both pass the existence check above;
    # JobStore.create raises JobExistsError for the loser, which is a 409.
    try:
        job = ctx.jobs.create(run_id, kind="chain")
    except JobExistsError as exc:
        raise Conflict(f"chain id already exists: {run_id!r}") from exc
    ctx.jobs.submit(job, worker)

    if body.mode != "async":
        timeout = (
            ctx.settings.max_sync_seconds
            if body.mode == "sync"
            else ctx.settings.auto_sync_seconds
        )
        finished = job.wait(timeout=timeout)
        if finished and job.result is not None:
            return JSONResponse(status_code=200, content=_chain_payload(job.result))

    # Async mode, timeout, or a finished job without a result.
    return JSONResponse(
        status_code=202,
        content={
            "chain_id": run_id,
            "status": job.state.value,
            "links": {"self": f"/chains/{run_id}"},
        },
    )
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
@router.get("/chains/{chain_id}")
def get_chain(request: Request, chain_id: str) -> dict[str, Any]:
    """Return a chain by id.

    A result held by the in-process job is preferred; otherwise the chain is
    loaded from the recorder.

    Raises:
        NotFound: the recorder has no such chain.
    """
    ctx = get_context(request)
    _validate_run_id(chain_id)

    job = ctx.jobs.get(chain_id)
    in_memory = None if job is None else job.result
    if in_memory is not None:
        return _chain_payload(in_memory)

    try:
        return ctx.recorder.load_chain(chain_id)
    except FileNotFoundError as exc:
        raise NotFound(f"No chain {chain_id!r}") from exc
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# app/api/routers/registry.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from fastapi import APIRouter
|
|
7
|
+
|
|
8
|
+
from app.registry import (
|
|
9
|
+
DEFAULT_SUBAGENTS,
|
|
10
|
+
DEFAULT_TOOLS,
|
|
11
|
+
FORBIDDEN_KINDS,
|
|
12
|
+
default_kind_definitions,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
router = APIRouter(tags=["registry"])
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@router.get("/registry")
def get_registry() -> dict[str, Any]:
    """Describe the default registry.

    Returns every default node kind with its flags and parameter JSON schema,
    plus the sorted default tools, default subagents and forbidden kinds.
    """
    node_kinds: list[dict[str, Any]] = [
        {
            "kind": kind.value,
            "description": definition.description,
            "counts_as_llm_call": definition.counts_as_llm_call,
            "has_side_effects": definition.has_side_effects,
            "requires_tool_allowlist": definition.requires_tool_allowlist,
            "requires_subagent_allowlist": definition.requires_subagent_allowlist,
            "param_schema": definition.param_model.model_json_schema(),
        }
        for kind, definition in default_kind_definitions().items()
    ]
    listing: dict[str, Any] = {"node_kinds": node_kinds}
    listing["tools"] = sorted(DEFAULT_TOOLS)
    listing["subagents"] = sorted(DEFAULT_SUBAGENTS)
    listing["forbidden_kinds"] = sorted(FORBIDDEN_KINDS)
    return listing
|