promptuna-server 1.24.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- promptuna_server-1.24.0/PKG-INFO +8 -0
- promptuna_server-1.24.0/pyproject.toml +17 -0
- promptuna_server-1.24.0/src/promptuna_server/__init__.py +1 -0
- promptuna_server-1.24.0/src/promptuna_server/jobs.py +182 -0
- promptuna_server-1.24.0/src/promptuna_server/main.py +206 -0
- promptuna_server-1.24.0/src/promptuna_server/schemas.py +50 -0
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: promptuna-server
|
|
3
|
+
Version: 1.24.0
|
|
4
|
+
Summary: FastAPI server for streaming promptuna run / evaluate / optimize jobs
|
|
5
|
+
Requires-Dist: fastapi>=0.115.0
|
|
6
|
+
Requires-Dist: promptuna==1.24.0
|
|
7
|
+
Requires-Dist: uvicorn[standard]>=0.32.0
|
|
8
|
+
Requires-Python: >=3.13
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "promptuna-server"
|
|
3
|
+
version = "1.24.0"
|
|
4
|
+
description = "FastAPI server for streaming promptuna run / evaluate / optimize jobs"
|
|
5
|
+
requires-python = ">=3.13"
|
|
6
|
+
dependencies = [
|
|
7
|
+
"fastapi>=0.115.0",
|
|
8
|
+
"promptuna==1.24.0",
|
|
9
|
+
"uvicorn[standard]>=0.32.0",
|
|
10
|
+
]
|
|
11
|
+
|
|
12
|
+
[tool.uv.sources]
|
|
13
|
+
promptuna = { workspace = true }
|
|
14
|
+
|
|
15
|
+
[build-system]
|
|
16
|
+
requires = ["uv_build>=0.11.1,<0.12.0"]
|
|
17
|
+
build-backend = "uv_build"
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""FastAPI transport layer for streaming promptuna jobs."""
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""Job lifecycle: background threads, queues, SSE event bridging, and on-disk jobs."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations

import asyncio
import json
import logging
import threading
import uuid
from collections.abc import Callable, Iterator
from dataclasses import dataclass, field
from typing import Any

from promptuna.evaluate import Metric, Scoring, stream_evaluate
from promptuna.jobs import JobArchive, JobConfig, JobKind, JobStatus, get_jobs_root, stream_job
from promptuna.optimize import Step, stream_optimize
from promptuna.program import Example, Experiment
from promptuna.run import Trial, stream_run
|
|
18
|
+
|
|
19
|
+
# Longest single Condition.wait() in _wait_for_events: SSE readers re-check the
# job's state at least this often, even if a notification is missed.
_WAIT_TIMEOUT_SECONDS = 1.0

# Zero-argument factory that produces a job's event stream (trials, scorings,
# and/or optimizer steps). _start_job hands it to the job thread, which calls it
# there, so the stream is created off the request thread.
StreamSource = Callable[[], Iterator[Trial | Scoring | Step]]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class JobState:
    """In-memory record of one streaming job.

    ``events``, ``status`` and ``error`` are written by the job's background
    thread and read by SSE readers. ``_cond`` guards ``events`` and is used to
    wake readers when a new event arrives or the job finishes.
    """

    job_id: str
    kind: JobKind
    archive: JobArchive
    status: JobStatus = "running"
    events: list[dict[str, Any]] = field(default_factory=list)
    error: str | None = None
    # Synchronization primitive, not data: keep it out of repr (which would
    # otherwise embed a lock object and its address) and out of equality.
    _cond: threading.Condition = field(
        default_factory=threading.Condition, repr=False, compare=False
    )
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Every job started in this process, keyed by job_id. Entries are added by
# _start_job and only removed by reset_jobs.
# NOTE(review): finished jobs (and all their buffered events) stay in memory
# for the life of the process; confirm that is acceptable for long-lived servers.
_jobs: dict[str, JobState] = {}
# Held by _start_job across the "is a job running?" check and registration,
# and by reset_jobs.
_jobs_lock = threading.Lock()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _has_running_job() -> bool:
    """Report whether any registered job still has status ``"running"``.

    The only caller, ``_start_job``, invokes this while holding ``_jobs_lock``.
    """
    for candidate in _jobs.values():
        if candidate.status == "running":
            return True
    return False
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _append_event(job: JobState, envelope: dict[str, Any]) -> None:
|
|
46
|
+
with job._cond:
|
|
47
|
+
job.events.append(envelope)
|
|
48
|
+
job._cond.notify_all()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _finish_job(job: JobState) -> None:
|
|
52
|
+
with job._cond:
|
|
53
|
+
job._cond.notify_all()
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def start_run_job(
    *,
    config: JobConfig,
    experiment: Experiment,
    examples: list[Example],
    workers: int,
) -> str:
    """Launch a background ``run`` job and return its ``job_id``.

    The stream is built lazily: ``stream_run`` is only called once the job's
    background thread starts.

    Raises:
        ConflictError: If another job is still running.
    """

    def build_source():
        return stream_run(experiment, examples, workers=workers)

    return _start_job("run", build_source, config=config)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def start_evaluate_job(
    *,
    config: JobConfig,
    experiment: Experiment,
    examples: list[Example],
    metrics: list[Metric],
    workers: int,
) -> str:
    """Launch a background ``evaluate`` job and return its ``job_id``.

    ``stream_evaluate`` is only invoked on the job's background thread.

    Raises:
        ConflictError: If another job is still running.
    """

    def build_source():
        return stream_evaluate(experiment, examples, metrics, workers=workers)

    return _start_job("evaluate", build_source, config=config)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def start_optimize_job(
    *,
    config: JobConfig,
    experiment: Experiment,
    examples: list[Example],
    metrics: list[Metric],
    workers: int,
    steps: int,
    proposer_model: str,
) -> str:
    """Launch a background ``optimize`` job and return its ``job_id``.

    ``stream_optimize`` is only invoked on the job's background thread, with
    ``steps`` and ``proposer_model`` forwarded unchanged.

    Raises:
        ConflictError: If another job is still running.
    """

    def build_source():
        return stream_optimize(
            experiment,
            examples,
            metrics,
            proposer_model=proposer_model,
            steps=steps,
            workers=workers,
        )

    return _start_job("optimize", build_source, config=config)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _start_job(kind: JobKind, build_source: StreamSource, *, config: JobConfig) -> str:
    """Register a new job and launch its background thread.

    Only one job may run at a time. The running-job check, archive creation
    and registration all happen under ``_jobs_lock``, so two concurrent
    requests cannot both start a job.

    Args:
        kind: Job kind, stored on the state and used in the thread name.
        build_source: Zero-argument factory for the event stream; it is
            called on the background thread, not here.
        config: Passed to ``JobArchive.open`` for the new job.

    Returns:
        The new job's id (a random UUID4 string).

    Raises:
        ConflictError: If another job still has status ``"running"``.
    """
    with _jobs_lock:
        if _has_running_job():
            raise ConflictError("another job is already running")
        job_id = str(uuid.uuid4())
        archive = JobArchive.open(get_jobs_root(), job_id, config)
        job = JobState(job_id=job_id, kind=kind, archive=archive)
        _jobs[job_id] = job

    thread = threading.Thread(
        target=_job_thread,
        args=(job, build_source),
        daemon=True,
        name=f"promptuna-{kind}-{job_id[:8]}",
    )
    try:
        thread.start()
    except BaseException as exc:
        # Thread creation can fail (e.g. RuntimeError when no new thread can be
        # started). Without this, the job would stay "running" with nothing
        # driving it, and every later start would raise ConflictError.
        with job._cond:
            job.status = "error"
            job.error = str(exc)
            job._cond.notify_all()
        raise
    return job_id
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class ConflictError(Exception):
    """Signals that a job cannot start because another one is still running.

    The HTTP layer translates this into a 409 Conflict response.
    """
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class JobNotFoundError(Exception):
    """Signals a lookup of a ``job_id`` that is not registered.

    The exception's argument is the unknown id; the HTTP layer maps this
    error to a 404 response.
    """
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def get_job(job_id: str) -> JobState:
    """Look up the in-memory state registered under ``job_id``.

    Raises:
        JobNotFoundError: If no such job exists; the underlying ``KeyError``
            is chained as the cause.
    """
    try:
        found = _jobs[job_id]
    except KeyError as missing:
        raise JobNotFoundError(job_id) from missing
    return found
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _job_thread(job: JobState, build_source: StreamSource) -> None:
    """Drive one job's event stream to completion on a background thread.

    Each envelope produced by ``stream_job`` is appended to ``job.events``,
    which wakes SSE readers. On normal completion the job becomes ``"done"``.
    On an exception it becomes ``"error"``, with ``str(exc)`` in ``job.error``
    and the traceback logged. Waiters are always notified on exit.
    """
    # Pessimistic defaults: if a BaseException (e.g. SystemExit) escapes, the
    # job must still leave "running". Otherwise every future start raises
    # ConflictError and SSE readers keep polling forever.
    status: JobStatus = "error"
    error: str | None = "job thread exited unexpectedly"
    try:
        for envelope in stream_job(job.archive, build_source()):
            _append_event(job, envelope)
        status, error = "done", None
    except Exception as exc:
        # Keep the full traceback; job.error only stores the message.
        logging.getLogger(__name__).exception("promptuna job %s failed", job.job_id)
        error = str(exc)
    finally:
        # Publish status and error together under the condition lock that
        # readers use, then wake them.
        with job._cond:
            job.status = status
            job.error = error
        _finish_job(job)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
async def stream_job_events(job_id: str):
    """Yield Server-Sent-Events ``data:`` frames for every event of ``job_id``.

    Replays the job's events from the first one, then follows new ones as the
    background thread appends them. Blocking waits run in a worker thread via
    :func:`asyncio.to_thread`, so the event loop is never blocked. The
    generator ends once the job is no longer ``"running"`` and every event
    has been sent.

    Raises:
        JobNotFoundError: If ``job_id`` is unknown (raised on first iteration).
    """
    job = get_job(job_id)
    sent = 0
    while True:
        pending, finished = await asyncio.to_thread(_wait_for_events, job, sent)
        for payload in pending:
            yield "data: " + json.dumps(payload) + "\n\n"
            sent += 1
        if finished and sent >= len(job.events):
            return
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _wait_for_events(job: JobState, offset: int) -> tuple[list[dict[str, Any]], bool]:
|
|
173
|
+
with job._cond:
|
|
174
|
+
while offset >= len(job.events) and job.status == "running":
|
|
175
|
+
job._cond.wait(timeout=_WAIT_TIMEOUT_SECONDS)
|
|
176
|
+
return job.events[offset:], job.status != "running"
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def reset_jobs() -> None:
    """Clear all jobs (tests only)."""
    # Only forgets the records: threads of jobs that are still running are not
    # stopped; they just become unreachable through get_job(). Taking _jobs_lock
    # keeps this from interleaving with _start_job's check-and-register.
    with _jobs_lock:
        _jobs.clear()
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
"""FastAPI application: routes, CORS, and job orchestration."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from fastapi import FastAPI, HTTPException, status
|
|
6
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
7
|
+
from fastapi.responses import StreamingResponse
|
|
8
|
+
|
|
9
|
+
from promptuna.jobs import JobConfig, JobKind
|
|
10
|
+
from promptuna.projects import (
|
|
11
|
+
ProjectValidationError,
|
|
12
|
+
build_catalog,
|
|
13
|
+
build_experiment,
|
|
14
|
+
get_projects_root,
|
|
15
|
+
resolve_dataset_path,
|
|
16
|
+
resolve_project_dir,
|
|
17
|
+
)
|
|
18
|
+
from promptuna_server import jobs
|
|
19
|
+
from promptuna_server.schemas import (
|
|
20
|
+
CatalogResponse,
|
|
21
|
+
EvaluateRequest,
|
|
22
|
+
JobStartResponse,
|
|
23
|
+
OptimizeRequest,
|
|
24
|
+
ProjectCatalogResponse,
|
|
25
|
+
RunRequest,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
app = FastAPI(title="promptuna-server")

# Browser origins permitted to call this API cross-origin: a local frontend on
# port 5173, reached by hostname or by loopback IP (presumably the Vite dev
# server for the UI; confirm).
_DEV_ORIGINS = [
    "http://localhost:5173",
    "http://127.0.0.1:5173",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_DEV_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _validation_error(exc: ProjectValidationError) -> HTTPException:
    """Wrap a project validation failure as an HTTP 400 carrying its message.

    Callers ``raise`` the returned exception ``from exc`` to keep the cause.
    """
    message = str(exc)
    return HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _job_config(
    request: RunRequest,
    *,
    kind: JobKind,
    metrics: list[str] | None = None,
    steps: int | None = None,
    proposer_model: str | None = None,
) -> JobConfig:
    """Assemble the :class:`JobConfig` recorded for a job started from ``request``.

    The project directory and dataset path are resolved here, so the config
    records the concrete dataset location. Run jobs leave ``metrics``,
    ``steps`` and ``proposer_model`` as ``None``; evaluate and optimize jobs
    fill in the ones they use.

    Callers wrap this in ``except ProjectValidationError``, so the resolve
    helpers are presumably allowed to raise it.
    """
    dataset_path = resolve_dataset_path(
        resolve_project_dir(request.project),
        request.examples,
    )
    # Metric names are handed to JobConfig as a tuple (None when not given).
    metric_names = None if metrics is None else tuple(metrics)
    return JobConfig(
        kind=kind,
        projects_root=get_projects_root(),
        project=request.project,
        program=request.program,
        prompt=request.prompt,
        examples=request.examples,
        dataset_path=dataset_path,
        model=request.model,
        workers=request.workers,
        metrics=metric_names,
        steps=steps,
        proposer_model=proposer_model,
    )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@app.get("/health")
def health() -> dict[str, str]:
    """Liveness check."""
    # Constant payload: reaching this handler at all shows the app is serving.
    body: dict[str, str] = {"status": "ok"}
    return body
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@app.get("/catalog", response_model=CatalogResponse)
def catalog() -> CatalogResponse:
    """List project and artifact names under the active projects root."""
    workspace = build_catalog()

    def to_response(entry) -> ProjectCatalogResponse:
        # Copy the name lists of one catalog entry into its response model.
        return ProjectCatalogResponse(
            name=entry.name,
            programs=entry.programs,
            metrics=entry.metrics,
            prompts=entry.prompts,
            datasets=entry.datasets,
        )

    return CatalogResponse(
        projects_root=str(workspace.projects_root),
        projects=list(map(to_response, workspace.projects)),
    )
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@app.post("/run", response_model=JobStartResponse)
def start_run(request: RunRequest) -> JobStartResponse:
    """Start a run job; stream trials via ``GET /jobs/{job_id}/events``."""
    # Resolve everything up front so bad project/program/prompt/dataset names
    # surface as HTTP 400 before any job is registered. Run jobs take no metrics.
    try:
        experiment, examples, _metrics = build_experiment(
            project=request.project,
            program=request.program,
            prompt=request.prompt,
            model=request.model,
            examples=request.examples,
        )
        config = _job_config(request, kind="run")
    except ProjectValidationError as exc:
        raise _validation_error(exc) from exc

    try:
        new_job_id = jobs.start_run_job(
            config=config,
            experiment=experiment,
            examples=examples,
            workers=request.workers,
        )
    except jobs.ConflictError as conflict:
        # Only one job may run at a time.
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(conflict)) from conflict
    else:
        return JobStartResponse(job_id=new_job_id)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@app.post("/evaluate", response_model=JobStartResponse)
def start_evaluate(request: EvaluateRequest) -> JobStartResponse:
    """Start an evaluate job; stream trials and scorings via SSE."""
    # Resolve everything up front so bad names (including metrics) surface as
    # HTTP 400 before any job is registered.
    try:
        experiment, examples, metrics = build_experiment(
            project=request.project,
            program=request.program,
            prompt=request.prompt,
            model=request.model,
            examples=request.examples,
            metrics=request.metrics,
        )
        config = _job_config(request, kind="evaluate", metrics=request.metrics)
    except ProjectValidationError as exc:
        raise _validation_error(exc) from exc

    # Type narrowing: with metrics requested, build_experiment is expected to
    # return them (note: asserts are skipped under python -O).
    assert metrics is not None
    try:
        new_job_id = jobs.start_evaluate_job(
            config=config,
            experiment=experiment,
            examples=examples,
            metrics=metrics,
            workers=request.workers,
        )
    except jobs.ConflictError as conflict:
        # Only one job may run at a time.
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(conflict)) from conflict
    else:
        return JobStartResponse(job_id=new_job_id)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@app.post("/optimize", response_model=JobStartResponse)
def start_optimize(request: OptimizeRequest) -> JobStartResponse:
    """Start an optimize job; stream trials, scorings, and steps via SSE."""
    # Resolve everything up front so bad names surface as HTTP 400 before any
    # job is registered; the proposer budget is recorded in the job config.
    try:
        experiment, examples, metrics = build_experiment(
            project=request.project,
            program=request.program,
            prompt=request.prompt,
            model=request.model,
            examples=request.examples,
            metrics=request.metrics,
        )
        config = _job_config(
            request,
            kind="optimize",
            metrics=request.metrics,
            steps=request.steps,
            proposer_model=request.proposer_model,
        )
    except ProjectValidationError as exc:
        raise _validation_error(exc) from exc

    # Type narrowing: with metrics requested, build_experiment is expected to
    # return them (note: asserts are skipped under python -O).
    assert metrics is not None
    try:
        new_job_id = jobs.start_optimize_job(
            config=config,
            experiment=experiment,
            examples=examples,
            metrics=metrics,
            workers=request.workers,
            steps=request.steps,
            proposer_model=request.proposer_model,
        )
    except jobs.ConflictError as conflict:
        # Only one job may run at a time.
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(conflict)) from conflict
    else:
        return JobStartResponse(job_id=new_job_id)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
@app.get("/jobs/{job_id}/events")
async def job_events(job_id: str) -> StreamingResponse:
    """Server-sent events for one job until it completes or errors."""
    # Check existence up front so an unknown id yields a clean 404 instead of
    # failing inside the stream (stream_job_events looks the job up lazily).
    try:
        jobs.get_job(job_id)
    except jobs.JobNotFoundError as exc:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="job not found") from exc

    return StreamingResponse(
        jobs.stream_job_events(job_id),
        media_type="text/event-stream",
        # An event stream is live data: tell browsers and intermediaries not
        # to cache it.
        headers={"Cache-Control": "no-cache"},
    )
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""Pydantic request and response models for the HTTP API."""
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class JobStartResponse(BaseModel):
    """Response body returned when a job is accepted."""

    # Id of the started job; pass it to GET /jobs/{job_id}/events to follow it.
    job_id: str
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ProjectCatalogResponse(BaseModel):
    """Name lists for one on-disk project."""

    # Project name, usable as RunRequest.project.
    name: str
    # Artifact names available in this project, copied from the catalog entry.
    programs: list[str]
    metrics: list[str]
    prompts: list[str]
    datasets: list[str]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class CatalogResponse(BaseModel):
    """Workspace catalog for building job request selectors."""

    # Active projects root, rendered as a string path.
    projects_root: str
    # One entry per project returned by build_catalog.
    projects: list[ProjectCatalogResponse]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class RunRequest(BaseModel):
    """Start a run job over a project-local program and dataset."""

    # Project name; resolved by the server via resolve_project_dir.
    project: str
    # Program and prompt names within that project.
    program: str
    prompt: str
    # Model identifier passed through to build_experiment.
    model: str
    # Dataset name within the project; resolved via resolve_dataset_path.
    examples: str
    # Forwarded as ``workers`` to the promptuna stream functions; must be >= 1.
    workers: int = Field(default=1, ge=1)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class EvaluateRequest(RunRequest):
    """Start an evaluate job with one or more project-local metrics."""

    # Names of project-local metrics to score with; at least one is required.
    metrics: list[str] = Field(min_length=1)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class OptimizeRequest(EvaluateRequest):
    """Start an optimize job with a proposer budget."""

    # Proposer budget forwarded to stream_optimize as ``steps``; 0 is allowed.
    steps: int = Field(ge=0)
    # Model identifier forwarded to stream_optimize as ``proposer_model``.
    proposer_model: str
|