promptuna-server 1.24.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,8 @@
1
+ Metadata-Version: 2.3
2
+ Name: promptuna-server
3
+ Version: 1.24.0
4
+ Summary: FastAPI server for streaming promptuna run / evaluate / optimize jobs
5
+ Requires-Dist: fastapi>=0.115.0
6
+ Requires-Dist: promptuna==1.24.0
7
+ Requires-Dist: uvicorn[standard]>=0.32.0
8
+ Requires-Python: >=3.13
@@ -0,0 +1,17 @@
1
+ [project]
2
+ name = "promptuna-server"
3
+ version = "1.24.0"
4
+ description = "FastAPI server for streaming promptuna run / evaluate / optimize jobs"
5
+ requires-python = ">=3.13"
6
+ dependencies = [
7
+ "fastapi>=0.115.0",
8
+ "promptuna==1.24.0",
9
+ "uvicorn[standard]>=0.32.0",
10
+ ]
11
+
12
+ [tool.uv.sources]
13
+ promptuna = { workspace = true }
14
+
15
+ [build-system]
16
+ requires = ["uv_build>=0.11.1,<0.12.0"]
17
+ build-backend = "uv_build"
@@ -0,0 +1 @@
1
+ """FastAPI transport layer for streaming promptuna jobs."""
@@ -0,0 +1,182 @@
1
+ """Job lifecycle: background threads, queues, SSE event bridging, and on-disk jobs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import json
7
+ import threading
8
+ import uuid
9
+ from collections.abc import Callable, Iterator
10
+ from dataclasses import dataclass, field
11
+ from typing import Any
12
+
13
+ from promptuna.evaluate import Metric, Scoring, stream_evaluate
14
+ from promptuna.jobs import JobArchive, JobConfig, JobKind, JobStatus, get_jobs_root, stream_job
15
+ from promptuna.optimize import Step, stream_optimize
16
+ from promptuna.program import Example, Experiment
17
+ from promptuna.run import Trial, stream_run
18
+
19
# Timeout, in seconds, handed to each Condition.wait() call made by
# _wait_for_events before the job state is checked again.
_WAIT_TIMEOUT_SECONDS: float = 1.0

# Zero-argument factory that produces the iterator of domain objects a job
# consumes; _job_thread calls it on the job's background thread.
StreamSource = Callable[
    [],
    Iterator[Trial | Scoring | Step],
]
22
+
23
+
24
@dataclass(eq=False)
class JobState:
    """In-memory record for one streaming job.

    The job's producer thread mutates an instance while SSE readers poll it,
    so it is a mutable, identity-bearing object: ``eq=False`` gives identity
    equality and hashing, and the internal condition variable is excluded
    from ``repr`` so logs show job data rather than lock internals.
    """

    # UUID string assigned by _start_job.
    job_id: str
    # Job flavour passed by the start_* helpers ("run" / "evaluate" / "optimize").
    kind: JobKind
    # Archive handed to promptuna's stream_job (presumably persists the job
    # on disk — confirm against promptuna.jobs).
    archive: JobArchive
    # "running" until the producer thread publishes a terminal status.
    status: JobStatus = "running"
    # Every envelope emitted so far, in order; readers track an offset into it.
    events: list[dict[str, Any]] = field(default_factory=list)
    # Failure message recorded when the job ends in "error".
    error: str | None = None
    # Guards events/status and wakes readers blocked in _wait_for_events.
    _cond: threading.Condition = field(default_factory=threading.Condition, repr=False)
35
+
36
+
37
# Process-wide registry of every job started since import (or since the last
# reset_jobs()), keyed by job_id. Registration and clearing happen while
# holding _jobs_lock.
_jobs: dict[str, JobState] = dict()
_jobs_lock = threading.Lock()
39
+
40
+
41
def _has_running_job() -> bool:
    """Tell whether any registered job still reports status ``"running"``.

    _start_job calls this while holding ``_jobs_lock``.
    """
    for candidate in _jobs.values():
        if candidate.status == "running":
            return True
    return False
43
+
44
+
45
def _append_event(job: JobState, envelope: dict[str, Any]) -> None:
    """Store ``envelope`` at the end of ``job.events`` and wake all waiters.

    Both the append and the notification happen while holding the job's
    condition, so a reader in _wait_for_events sees the new event on wake-up.
    """
    cond = job._cond
    with cond:
        job.events.append(envelope)
        cond.notify_all()
49
+
50
+
51
def _finish_job(job: JobState) -> None:
    """Wake every thread waiting on ``job`` so it re-checks the job's status.

    Called from _job_thread's ``finally`` once the producer stops emitting.
    """
    cond = job._cond
    with cond:
        cond.notify_all()
54
+
55
+
56
def start_run_job(
    *,
    config: JobConfig,
    experiment: Experiment,
    examples: list[Example],
    workers: int,
) -> str:
    """Launch a background ``run`` job and hand back its new ``job_id``.

    The trial stream is only built once the job thread starts.

    Raises:
        ConflictError: another job is still running (raised by _start_job).
    """

    def build_source():
        return stream_run(experiment, examples, workers=workers)

    return _start_job("run", build_source, config=config)
69
+
70
+
71
def start_evaluate_job(
    *,
    config: JobConfig,
    experiment: Experiment,
    examples: list[Example],
    metrics: list[Metric],
    workers: int,
) -> str:
    """Launch a background ``evaluate`` job and hand back its new ``job_id``.

    The evaluation stream is only built once the job thread starts.

    Raises:
        ConflictError: another job is still running (raised by _start_job).
    """

    def build_source():
        return stream_evaluate(experiment, examples, metrics, workers=workers)

    return _start_job("evaluate", build_source, config=config)
85
+
86
+
87
def start_optimize_job(
    *,
    config: JobConfig,
    experiment: Experiment,
    examples: list[Example],
    metrics: list[Metric],
    workers: int,
    steps: int,
    proposer_model: str,
) -> str:
    """Launch a background ``optimize`` job and hand back its new ``job_id``.

    The optimisation stream is only built once the job thread starts.

    Raises:
        ConflictError: another job is still running (raised by _start_job).
    """

    def build_source():
        return stream_optimize(
            experiment,
            examples,
            metrics,
            proposer_model=proposer_model,
            steps=steps,
            workers=workers,
        )

    return _start_job("optimize", build_source, config=config)
110
+
111
+
112
def _start_job(kind: JobKind, build_source: StreamSource, *, config: JobConfig) -> str:
    """Register a new job and start the daemon thread that produces its events.

    Only one job may be running at a time. The check, archive creation, and
    registration all happen under ``_jobs_lock`` so two concurrent starts
    cannot both pass the check.

    Args:
        kind: Job flavour recorded on the JobState and used in the thread name.
        build_source: Zero-argument factory for the job's domain stream; it is
            invoked on the new thread, not here.
        config: Job configuration passed to ``JobArchive.open``.

    Returns:
        The new job's id (a UUID4 string).

    Raises:
        ConflictError: another job is still running.
        RuntimeError: the background thread could not be started. The job is
            removed from the registry first, so later starts are not blocked.
            Whatever ``JobArchive.open`` created is not cleaned up here.
    """
    with _jobs_lock:
        if _has_running_job():
            raise ConflictError("another job is already running")
        job_id = str(uuid.uuid4())
        archive = JobArchive.open(get_jobs_root(), job_id, config)
        job = JobState(job_id=job_id, kind=kind, archive=archive)
        _jobs[job_id] = job

    thread = threading.Thread(
        target=_job_thread,
        args=(job, build_source),
        daemon=True,
        name=f"promptuna-{kind}-{job_id[:8]}",
    )
    try:
        thread.start()
    except RuntimeError:
        # A registered job whose thread never ran would report "running"
        # forever and make every later start fail with ConflictError.
        with _jobs_lock:
            _jobs.pop(job_id, None)
        raise
    return job_id
129
+
130
+
131
class ConflictError(Exception):
    """A job start was refused because another job has not finished yet."""


class JobNotFoundError(Exception):
    """No job is registered under the requested ``job_id``."""
137
+
138
+
139
def get_job(job_id: str) -> JobState:
    """Look up a registered job by id.

    Raises:
        JobNotFoundError: ``job_id`` is not in the registry; the original
            ``KeyError`` is chained as its cause.
    """
    try:
        found = _jobs[job_id]
    except KeyError as missing:
        raise JobNotFoundError(job_id) from missing
    return found
145
+
146
+
147
def _job_thread(job: JobState, build_source: StreamSource) -> None:
    """Consume one job's event stream on its background thread.

    ``build_source()`` is run through promptuna's ``stream_job`` with the job's
    archive, and every yielded envelope is appended to ``job.events``. A clean
    finish marks the job ``"done"``. An ``Exception`` marks it ``"error"``,
    records the message, and logs the traceback.

    The terminal status is published in ``finally``. That covers
    ``BaseException`` too (e.g. ``SystemExit`` raised by user code), so a dead
    job can never stay ``"running"`` and block every later _start_job.
    Waiting readers are then woken via _finish_job.
    """
    status: JobStatus = "error"
    error: str | None = "job thread exited unexpectedly"
    try:
        for envelope in stream_job(job.archive, build_source()):
            _append_event(job, envelope)
        status, error = "done", None
    except Exception as exc:
        import logging

        logging.getLogger(__name__).exception("job %s failed", job.job_id)
        # Some exceptions stringify to "" (e.g. RuntimeError()); fall back to
        # the type name so the failure is never reported as an empty string.
        error = str(exc) or type(exc).__name__
    finally:
        # Write under the condition, error first, so any reader that sees a
        # terminal status also sees the matching error message.
        with job._cond:
            job.error = error
            job.status = status
        _finish_job(job)
157
+
158
+
159
async def stream_job_events(job_id: str):
    """Yield each event of ``job_id`` as a Server-Sent Events ``data:`` frame.

    Every frame is ``"data: <json>\\n\\n"``. Blocking waits on the job's
    condition are pushed to a worker thread with ``asyncio.to_thread`` so the
    event loop stays free. The generator ends once the job is no longer
    running and every recorded event has been sent. ``get_job`` is called when
    iteration starts, so an unknown id raises JobNotFoundError at the first
    ``__anext__``.
    """
    job = get_job(job_id)
    cursor = 0
    while True:
        pending, finished = await asyncio.to_thread(_wait_for_events, job, cursor)
        for envelope in pending:
            frame = "data: " + json.dumps(envelope) + "\n\n"
            yield frame
            cursor += 1
        if finished and cursor >= len(job.events):
            return
170
+
171
+
172
def _wait_for_events(
    job: JobState,
    offset: int,
    timeout: float | None = None,
) -> tuple[list[dict[str, Any]], bool]:
    """Wait briefly for events on ``job`` beyond ``offset``.

    If nothing new is available and the job is still running, this makes one
    wait on the job's condition of at most ``timeout`` seconds
    (``_WAIT_TIMEOUT_SECONDS`` when ``None``). The wait ends early when
    notified, e.g. by _append_event or _finish_job.

    Args:
        job: Job whose ``events`` / ``status`` are read under its condition.
        offset: Number of events the caller has already consumed.
        timeout: Optional override for the single wait's duration.

    Returns:
        ``(job.events[offset:], finished)``, where ``finished`` is true once
        the status is no longer ``"running"``. The list may be empty after a
        timeout, and callers are expected to poll again.
    """
    wait_seconds = _WAIT_TIMEOUT_SECONDS if timeout is None else timeout
    with job._cond:
        if offset >= len(job.events) and job.status == "running":
            # A single bounded wait, not a loop. This runs in an
            # asyncio.to_thread worker, and cancelling the awaiting coroutine
            # does not stop that thread. Returning after one timeout frees the
            # pool thread shortly after an SSE client disconnects, instead of
            # pinning it until the job emits its next event.
            job._cond.wait(timeout=wait_seconds)
        return job.events[offset:], job.status != "running"
177
+
178
+
179
def reset_jobs() -> None:
    """Drop every registered job from the registry; meant for test isolation.

    Job threads that are still running are not stopped. Their records simply
    stop being reachable through get_job.
    """
    with _jobs_lock:
        _jobs.clear()
@@ -0,0 +1,206 @@
1
+ """FastAPI application: routes, CORS, and job orchestration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from fastapi import FastAPI, HTTPException, status
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi.responses import StreamingResponse
8
+
9
+ from promptuna.jobs import JobConfig, JobKind
10
+ from promptuna.projects import (
11
+ ProjectValidationError,
12
+ build_catalog,
13
+ build_experiment,
14
+ get_projects_root,
15
+ resolve_dataset_path,
16
+ resolve_project_dir,
17
+ )
18
+ from promptuna_server import jobs
19
+ from promptuna_server.schemas import (
20
+ CatalogResponse,
21
+ EvaluateRequest,
22
+ JobStartResponse,
23
+ OptimizeRequest,
24
+ ProjectCatalogResponse,
25
+ RunRequest,
26
+ )
27
+
28
app = FastAPI(title="promptuna-server")

# Browser origins allowed to call this API cross-origin. Port 5173 is
# presumably a local Vite dev server for the UI — confirm.
_ALLOWED_ORIGINS = [
    "http://localhost:5173",
    "http://127.0.0.1:5173",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
40
+
41
+
42
def _validation_error(exc: ProjectValidationError) -> HTTPException:
    """Wrap a project validation failure as an HTTP 400 that carries its message."""
    message = str(exc)
    return HTTPException(detail=message, status_code=status.HTTP_400_BAD_REQUEST)
44
+
45
+
46
def _job_config(
    request: RunRequest,
    *,
    kind: JobKind,
    metrics: list[str] | None = None,
    steps: int | None = None,
    proposer_model: str | None = None,
) -> JobConfig:
    """Capture a job request as the JobConfig used to open its archive.

    The project directory and dataset path are resolved from the request.
    The endpoint callers treat any ProjectValidationError from here as a 400.
    ``metrics`` is frozen into a tuple; ``None`` stays ``None`` (run jobs).
    """
    project_dir = resolve_project_dir(request.project)
    dataset_path = resolve_dataset_path(project_dir, request.examples)
    metric_names = None if metrics is None else tuple(metrics)
    return JobConfig(
        kind=kind,
        projects_root=get_projects_root(),
        project=request.project,
        program=request.program,
        prompt=request.prompt,
        examples=request.examples,
        dataset_path=dataset_path,
        model=request.model,
        workers=request.workers,
        metrics=metric_names,
        steps=steps,
        proposer_model=proposer_model,
    )
70
+
71
+
72
@app.get("/health")
def health() -> dict[str, str]:
    """Liveness check."""
    # Fixed payload: responding at all is the liveness signal.
    payload = {"status": "ok"}
    return payload
76
+
77
+
78
@app.get("/catalog", response_model=CatalogResponse)
def catalog() -> CatalogResponse:
    """List project and artifact names under the active projects root."""
    workspace = build_catalog()
    root = str(workspace.projects_root)
    # One response entry per project that build_catalog() discovered.
    entries = [
        ProjectCatalogResponse(
            name=project.name,
            programs=project.programs,
            metrics=project.metrics,
            prompts=project.prompts,
            datasets=project.datasets,
        )
        for project in workspace.projects
    ]
    return CatalogResponse(projects_root=root, projects=entries)
95
+
96
+
97
@app.post("/run", response_model=JobStartResponse)
def start_run(request: RunRequest) -> JobStartResponse:
    """Start a run job; stream trials via ``GET /jobs/{job_id}/events``."""
    # A ProjectValidationError while resolving the request becomes a 400.
    try:
        experiment, examples, _unused_metrics = build_experiment(
            project=request.project,
            program=request.program,
            prompt=request.prompt,
            model=request.model,
            examples=request.examples,
        )
        config = _job_config(request, kind="run")
    except ProjectValidationError as invalid:
        raise _validation_error(invalid) from invalid

    # Only one job may run at a time; a refused start becomes a 409.
    try:
        started_id = jobs.start_run_job(
            config=config,
            experiment=experiment,
            examples=examples,
            workers=request.workers,
        )
    except jobs.ConflictError as conflict:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT, detail=str(conflict)
        ) from conflict
    else:
        return JobStartResponse(job_id=started_id)
123
+
124
+
125
@app.post("/evaluate", response_model=JobStartResponse)
def start_evaluate(request: EvaluateRequest) -> JobStartResponse:
    """Start an evaluate job; stream trials and scorings via SSE."""
    # A ProjectValidationError while resolving the request becomes a 400.
    try:
        experiment, examples, metrics = build_experiment(
            project=request.project,
            program=request.program,
            prompt=request.prompt,
            model=request.model,
            examples=request.examples,
            metrics=request.metrics,
        )
        config = _job_config(request, kind="evaluate", metrics=request.metrics)
    except ProjectValidationError as invalid:
        raise _validation_error(invalid) from invalid

    # Metrics were requested, so build_experiment is expected to return them;
    # this narrows the type for the call below.
    assert metrics is not None
    # Only one job may run at a time; a refused start becomes a 409.
    try:
        started_id = jobs.start_evaluate_job(
            config=config,
            experiment=experiment,
            examples=examples,
            metrics=metrics,
            workers=request.workers,
        )
    except jobs.ConflictError as conflict:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT, detail=str(conflict)
        ) from conflict
    else:
        return JobStartResponse(job_id=started_id)
154
+
155
+
156
@app.post("/optimize", response_model=JobStartResponse)
def start_optimize(request: OptimizeRequest) -> JobStartResponse:
    """Start an optimize job; stream trials, scorings, and steps via SSE."""
    # A ProjectValidationError while resolving the request becomes a 400.
    try:
        experiment, examples, metrics = build_experiment(
            project=request.project,
            program=request.program,
            prompt=request.prompt,
            model=request.model,
            examples=request.examples,
            metrics=request.metrics,
        )
        config = _job_config(
            request,
            kind="optimize",
            metrics=request.metrics,
            steps=request.steps,
            proposer_model=request.proposer_model,
        )
    except ProjectValidationError as invalid:
        raise _validation_error(invalid) from invalid

    # Metrics were requested, so build_experiment is expected to return them;
    # this narrows the type for the call below.
    assert metrics is not None
    # Only one job may run at a time; a refused start becomes a 409.
    try:
        started_id = jobs.start_optimize_job(
            config=config,
            experiment=experiment,
            examples=examples,
            metrics=metrics,
            workers=request.workers,
            steps=request.steps,
            proposer_model=request.proposer_model,
        )
    except jobs.ConflictError as conflict:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT, detail=str(conflict)
        ) from conflict
    else:
        return JobStartResponse(job_id=started_id)
193
+
194
+
195
@app.get("/jobs/{job_id}/events")
async def job_events(job_id: str) -> StreamingResponse:
    """Server-sent events for one job until it completes or errors."""
    # Validate eagerly. The async generator's body only runs after the
    # response headers are sent, so an unknown id found there could no longer
    # become a 404.
    try:
        jobs.get_job(job_id)
    except jobs.JobNotFoundError as exc:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="job not found") from exc

    return StreamingResponse(
        jobs.stream_job_events(job_id),
        media_type="text/event-stream",
        headers={
            # Keep browsers and intermediaries from caching the live stream.
            "Cache-Control": "no-cache",
            # Ask nginx-style reverse proxies not to buffer the response, so
            # each event reaches the client as soon as it is yielded.
            "X-Accel-Buffering": "no",
        },
    )
@@ -0,0 +1,50 @@
1
+ """Pydantic request and response models for the HTTP API."""
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+
6
class JobStartResponse(BaseModel):
    """Response body returned when a job is accepted."""

    # Id to pass to GET /jobs/{job_id}/events to follow the job.
    job_id: str
10
+
11
+
12
class ProjectCatalogResponse(BaseModel):
    """Name lists for one on-disk project."""

    # Project name as accepted by the job request models' ``project`` field.
    name: str
    # Artifact names available inside this project, one list per kind.
    programs: list[str]
    metrics: list[str]
    prompts: list[str]
    datasets: list[str]
20
+
21
+
22
class CatalogResponse(BaseModel):
    """Workspace catalog for building job request selectors."""

    # Stringified path of the projects root the catalog was built from.
    projects_root: str
    # One entry per project found under ``projects_root``.
    projects: list[ProjectCatalogResponse]
27
+
28
+
29
class RunRequest(BaseModel):
    """Start a run job over a project-local program and dataset."""

    # Names resolved against the projects root by promptuna.projects.
    project: str
    program: str
    prompt: str
    # Model identifier forwarded to build_experiment.
    model: str
    # Dataset name inside the project.
    examples: str
    # Parallelism forwarded to the promptuna stream functions; at least 1.
    workers: int = Field(1, ge=1)
38
+
39
+
40
class EvaluateRequest(RunRequest):
    """Start an evaluate job with one or more project-local metrics."""

    # Metric names within the project; the list must be non-empty.
    metrics: list[str] = Field(min_length=1)
44
+
45
+
46
class OptimizeRequest(EvaluateRequest):
    """Start an optimize job with a proposer budget."""

    # Number of optimisation steps to run; zero is allowed.
    steps: int = Field(ge=0)
    # Model identifier forwarded to stream_optimize as ``proposer_model``.
    proposer_model: str