patientzero 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- backend/README.md +71 -0
- backend/__init__.py +0 -0
- backend/api/__init__.py +0 -0
- backend/api/dependencies.py +8 -0
- backend/api/main.py +81 -0
- backend/api/routes/__init__.py +0 -0
- backend/api/routes/agents.py +24 -0
- backend/api/routes/analysis.py +138 -0
- backend/api/routes/chat.py +109 -0
- backend/api/routes/distributions.py +18 -0
- backend/api/routes/experiments.py +150 -0
- backend/api/routes/settings.py +12 -0
- backend/api/routes/simulate.py +229 -0
- core/__init__.py +18 -0
- core/agent.py +42 -0
- core/agents/__init__.py +3 -0
- core/agents/base.py +61 -0
- core/analysis/__init__.py +0 -0
- core/analysis/coverage.py +123 -0
- core/config/__init__.py +0 -0
- core/config/settings.py +41 -0
- core/db/__init__.py +0 -0
- core/db/database.py +42 -0
- core/db/queries/__init__.py +0 -0
- core/db/queries/sessions.py +69 -0
- core/db/schema.sql +69 -0
- core/distribution.py +391 -0
- core/examples/__init__.py +1 -0
- core/examples/medical/__init__.py +1 -0
- core/examples/medical/config.py +32 -0
- core/examples/medical/distributions.py +157 -0
- core/examples/medical/prompts.py +25 -0
- core/examples/medical/run.py +54 -0
- core/experiment.py +171 -0
- core/feedback/__init__.py +0 -0
- core/feedback/feedback.py +141 -0
- core/judge.py +107 -0
- core/llm/__init__.py +0 -0
- core/llm/base.py +8 -0
- core/llm/claude_cli_provider.py +73 -0
- core/llm/factory.py +52 -0
- core/llm/mock.py +65 -0
- core/llm/openai_provider.py +35 -0
- core/logger.py +123 -0
- core/repositories/__init__.py +44 -0
- core/repositories/base.py +33 -0
- core/repositories/evaluations.py +131 -0
- core/repositories/experiments.py +182 -0
- core/repositories/optimization_targets.py +83 -0
- core/repositories/simulations.py +117 -0
- core/sampling.py +10 -0
- core/services/__init__.py +0 -0
- core/services/feedback.py +83 -0
- core/simulation.py +366 -0
- core/types/__init__.py +51 -0
- core/types/analysis.py +28 -0
- core/types/enums.py +6 -0
- core/types/events.py +13 -0
- core/types/feedback.py +54 -0
- core/types/judge_result.py +44 -0
- core/types/message.py +7 -0
- core/types/records.py +257 -0
- core/types/settings.py +9 -0
- core/types/simulation.py +9 -0
- core/types/trace.py +63 -0
- core/types/transcript.py +33 -0
- evaluations/__init__.py +0 -0
- evaluations/feedback/RECOMMENDATIONS.md +60 -0
- evaluations/feedback/RUNBOOK.md +104 -0
- evaluations/feedback/artifacts/baseline_low_lit_v1/analysis.csv +17 -0
- evaluations/feedback/artifacts/baseline_low_lit_v1/analysis.json +110 -0
- evaluations/feedback/artifacts/baseline_low_lit_v1/analysis_experiment.csv +4 -0
- evaluations/feedback/artifacts/baseline_low_lit_v1/run_summary.json +13 -0
- evaluations/feedback/artifacts/baseline_low_lit_v2_n10_restart/analysis.csv +55 -0
- evaluations/feedback/artifacts/baseline_low_lit_v2_n10_restart/analysis.json +110 -0
- evaluations/feedback/artifacts/baseline_low_lit_v2_n10_restart/analysis_experiment.csv +11 -0
- evaluations/feedback/artifacts/baseline_low_lit_v2_n10_restart/run_summary.json +13 -0
- evaluations/feedback/artifacts/baseline_low_lit_v2_n10_stage1/analysis.csv +45 -0
- evaluations/feedback/artifacts/baseline_low_lit_v2_n10_stage1/analysis.json +110 -0
- evaluations/feedback/artifacts/baseline_low_lit_v2_n10_stage1/analysis_experiment.csv +9 -0
- evaluations/feedback/artifacts/baseline_low_lit_v2_n10_stage1/run_summary.json +13 -0
- evaluations/feedback/artifacts/baseline_low_lit_v2_n20/analysis.csv +23 -0
- evaluations/feedback/artifacts/baseline_low_lit_v2_n20/analysis.json +110 -0
- evaluations/feedback/artifacts/baseline_low_lit_v2_n20/run_summary.json +13 -0
- evaluations/feedback/artifacts/baseline_low_lit_v2_n20_rerun1/analysis.csv +41 -0
- evaluations/feedback/artifacts/baseline_low_lit_v2_n20_rerun1/analysis.json +110 -0
- evaluations/feedback/artifacts/baseline_low_lit_v2_n20_rerun1/analysis_experiment.csv +15 -0
- evaluations/feedback/artifacts/baseline_low_lit_v2_n20_rerun1/run_summary.json +13 -0
- evaluations/feedback/artifacts/baseline_smoke/analysis.csv +1 -0
- evaluations/feedback/artifacts/baseline_smoke/analysis.json +8 -0
- evaluations/feedback/artifacts/baseline_smoke/run_summary.json +13 -0
- evaluations/feedback/artifacts/baseline_smoke2/analysis.csv +2 -0
- evaluations/feedback/artifacts/baseline_smoke2/analysis.json +38 -0
- evaluations/feedback/artifacts/baseline_smoke2/analysis_experiment.csv +2 -0
- evaluations/feedback/artifacts/baseline_smoke2/run_summary.json +13 -0
- evaluations/feedback/artifacts/baseline_v1/analysis.csv +14 -0
- evaluations/feedback/artifacts/baseline_v1/analysis.json +110 -0
- evaluations/feedback/artifacts/baseline_v1/analysis_experiment.csv +10 -0
- evaluations/feedback/artifacts/baseline_v1/hypotheses.md +39 -0
- evaluations/feedback/artifacts/baseline_v1/run_summary.json +13 -0
- evaluations/feedback/artifacts/intervention_low_lit_v2/analysis.csv +23 -0
- evaluations/feedback/artifacts/intervention_low_lit_v2/analysis.json +110 -0
- evaluations/feedback/artifacts/intervention_low_lit_v2/analysis_experiment.csv +7 -0
- evaluations/feedback/artifacts/intervention_low_lit_v2/compare_vs_baseline.json +39 -0
- evaluations/feedback/artifacts/intervention_low_lit_v2/compare_vs_baseline.md +34 -0
- evaluations/feedback/artifacts/intervention_low_lit_v2/run_summary.json +13 -0
- evaluations/feedback/artifacts/intervention_low_lit_v3_n10_restart/analysis.csv +61 -0
- evaluations/feedback/artifacts/intervention_low_lit_v3_n10_restart/analysis.json +110 -0
- evaluations/feedback/artifacts/intervention_low_lit_v3_n10_restart/analysis_experiment.csv +7 -0
- evaluations/feedback/artifacts/intervention_low_lit_v3_n10_restart/run_summary.json +13 -0
- evaluations/feedback/artifacts/intervention_low_lit_v3_n10_retry2/analysis.csv +71 -0
- evaluations/feedback/artifacts/intervention_low_lit_v3_n10_retry2/analysis.json +110 -0
- evaluations/feedback/artifacts/intervention_low_lit_v3_n10_retry2/analysis_experiment.csv +11 -0
- evaluations/feedback/artifacts/intervention_low_lit_v3_n10_retry2/compare_vs_baseline.json +55 -0
- evaluations/feedback/artifacts/intervention_low_lit_v3_n10_retry2/compare_vs_baseline.md +34 -0
- evaluations/feedback/artifacts/intervention_low_lit_v3_n10_retry2/run_summary.json +13 -0
- evaluations/feedback/artifacts/intervention_low_lit_v3_n20/analysis.csv +23 -0
- evaluations/feedback/artifacts/intervention_low_lit_v3_n20/analysis.json +110 -0
- evaluations/feedback/artifacts/intervention_low_lit_v3_n20/run_summary.json +13 -0
- evaluations/feedback/artifacts/intervention_v2/analysis.csv +10 -0
- evaluations/feedback/artifacts/intervention_v2/analysis.json +102 -0
- evaluations/feedback/artifacts/intervention_v2/analysis_experiment.csv +4 -0
- evaluations/feedback/artifacts/intervention_v2/compare_vs_baseline.json +61 -0
- evaluations/feedback/artifacts/intervention_v2/compare_vs_baseline.md +20 -0
- evaluations/feedback/artifacts/intervention_v2/run_summary.json +13 -0
- evaluations/judge/__init__.py +0 -0
- evaluations/judge/cases/__init__.py +0 -0
- evaluations/judge/cases/bad_explanation.py +43 -0
- evaluations/judge/cases/cbc_good.py +36 -0
- evaluations/judge/cases/cbc_poor.py +32 -0
- evaluations/judge/cases/confidence_gap.py +34 -0
- evaluations/judge/cases/hba1c_good.py +43 -0
- evaluations/judge/cases/liver_passive.py +40 -0
- evaluations/judge/cases/metabolic_high_literacy.py +51 -0
- evaluations/judge/cases/metformin_mixed.py +39 -0
- evaluations/judge/cases/patient_contradicts.py +49 -0
- evaluations/judge/cases/short_exchange.py +29 -0
- evaluations/judge/output/bad_explanation.json +13 -0
- evaluations/judge/output/cbc_good.json +13 -0
- evaluations/judge/output/cbc_poor.json +13 -0
- evaluations/judge/output/confidence_gap.json +13 -0
- evaluations/judge/output/hba1c_good.json +13 -0
- evaluations/judge/output/liver_passive.json +13 -0
- evaluations/judge/output/metabolic_high_literacy.json +13 -0
- evaluations/judge/output/metformin_mixed.json +13 -0
- evaluations/judge/output/patient_contradicts.json +13 -0
- evaluations/judge/output/short_exchange.json +13 -0
- patientzero-0.1.0.dist-info/METADATA +96 -0
- patientzero-0.1.0.dist-info/RECORD +151 -0
- patientzero-0.1.0.dist-info/WHEEL +4 -0
- patientzero-0.1.0.dist-info/licenses/LICENSE +21 -0
backend/README.md
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
# PatientZero Backend
|
|
2
|
+
|
|
3
|
+
A FastAPI server backed by a SQLite database, with an abstracted LLM provider layer.
|
|
4
|
+
|
|
5
|
+
## Setup
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
uv sync
|
|
9
|
+
cp ../.env.example ../.env # if not done already
|
|
10
|
+
```
|
|
11
|
+
|
|
12
|
+
## Running
|
|
13
|
+
|
|
14
|
+
```bash
|
|
15
|
+
uv run uvicorn api.main:app --reload
|
|
16
|
+
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
The server runs at http://localhost:8000.
|
|
20
|
+
|
|
21
|
+
## Project Structure
|
|
22
|
+
|
|
23
|
+
```
|
|
24
|
+
backend/
|
|
25
|
+
├── api/
|
|
26
|
+
│ ├── main.py # FastAPI app, CORS, lifespan
|
|
27
|
+
│ ├── dependencies.py # Shared db + provider instances
|
|
28
|
+
│ └── routes/
|
|
29
|
+
│ └── chat.py # Chat + session endpoints
|
|
30
|
+
├── config/
|
|
31
|
+
│ └── settings.py # Environment config
|
|
32
|
+
├── db/
|
|
33
|
+
│ ├── database.py # Database class (SQLite, raw queries)
|
|
34
|
+
│ ├── schema.sql # Table definitions
|
|
35
|
+
│ └── queries/
|
|
36
|
+
│ └── sessions.py # Session + turn CRUD
|
|
37
|
+
├── llm/
|
|
38
|
+
│ ├── base.py # Abstract LLMProvider
|
|
39
|
+
│ ├── mock.py # Mock provider (testing)
|
|
40
|
+
│ └── factory.py # Provider factory
|
|
41
|
+
└── pyproject.toml
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
## API Endpoints
|
|
45
|
+
|
|
46
|
+
| Method | Path | Description |
|
|
47
|
+
|--------|------|-------------|
|
|
48
|
+
| `POST` | `/api/sessions` | Create a new chat session |
|
|
49
|
+
| `GET` | `/api/sessions` | List all sessions |
|
|
50
|
+
| `GET` | `/api/sessions/{id}` | Get session with turns |
|
|
51
|
+
| `POST` | `/api/chat` | Send message, receive SSE stream |
|
|
52
|
+
|
|
53
|
+
## LLM Providers
|
|
54
|
+
|
|
55
|
+
Set `LLM_PROVIDER` in `.env`:
|
|
56
|
+
|
|
57
|
+
| Provider | Value | Status |
|
|
58
|
+
|----------|-------|--------|
|
|
59
|
+
| Mock | `mock` | Available |
|
|
60
|
+
| OpenAI | `openai` | Planned |
|
|
61
|
+
| Claude | `claude` | Planned |
|
|
62
|
+
| Local | `local` | Planned |
|
|
63
|
+
|
|
64
|
+
## Database
|
|
65
|
+
|
|
66
|
+
SQLite with WAL mode. Tables:
|
|
67
|
+
|
|
68
|
+
- `sessions` — chat sessions (id, title, created_at)
|
|
69
|
+
- `turns` — individual messages (session_id, role, content, turn_number)
|
|
70
|
+
|
|
71
|
+
Database file is created automatically on first run.
|
backend/__init__.py
ADDED
|
File without changes
|
backend/api/__init__.py
ADDED
|
File without changes
|
backend/api/main.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import traceback
|
|
3
|
+
from contextlib import asynccontextmanager
|
|
4
|
+
|
|
5
|
+
from fastapi import FastAPI, Request
|
|
6
|
+
from fastapi.exceptions import HTTPException, RequestValidationError
|
|
7
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
8
|
+
from fastapi.responses import JSONResponse
|
|
9
|
+
|
|
10
|
+
# Named logger for the API layer.
logger = logging.getLogger("patientzero.api")
|
|
11
|
+
|
|
12
|
+
from core import Experiment
|
|
13
|
+
from core.config.settings import FRONTEND_URL
|
|
14
|
+
from core.examples.medical.config import MEDICAL_EXAMPLE_CONFIG
|
|
15
|
+
from backend.api.dependencies import db, repos
|
|
16
|
+
from backend.api.routes.agents import router as agents_router
|
|
17
|
+
from backend.api.routes.analysis import router as analysis_router
|
|
18
|
+
from backend.api.routes.chat import router as chat_router
|
|
19
|
+
from backend.api.routes.distributions import router as distributions_router
|
|
20
|
+
from backend.api.routes.experiments import router as experiments_router
|
|
21
|
+
from backend.api.routes.settings import router as settings_router
|
|
22
|
+
from backend.api.routes.simulate import router as simulate_router
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown.

    Startup initialises the database and, when no experiments are stored
    yet, constructs an ``Experiment`` from the bundled medical example
    config. Shutdown closes the database.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    db.init()
    try:
        if not repos.experiments.list_all():
            # Constructed only for its side effect; the instance is discarded.
            # Presumably the constructor persists the experiment through
            # ``repos`` -- confirm against core.Experiment.
            Experiment(MEDICAL_EXAMPLE_CONFIG, repos)
        yield
    finally:
        # Always release the database, even when seeding fails or the
        # serving phase ends with an exception thrown in at ``yield``.
        db.close()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Application instance; startup and shutdown are delegated to ``lifespan``.
app = FastAPI(lifespan=lifespan, title="PatientZero")

# CORS: only the configured frontend origin is allowed, with any method
# and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=[FRONTEND_URL],
)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Log an ``HTTPException`` to stdout and return it as a JSON response.

    Args:
        request: The request that triggered the exception.
        exc: The raised HTTP exception.

    Returns:
        A ``JSONResponse`` with the exception's status code, a
        ``{"detail": ...}`` body and any headers attached to the exception.
    """
    if exc.status_code >= 400:
        print(
            f"\033[33m[{exc.status_code}]\033[0m {request.method} {request.url.path} "
            f"→ {exc.detail}",
            flush=True,
        )
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
        # Forward headers set on the exception (e.g. WWW-Authenticate);
        # overriding the default handler would otherwise drop them.
        headers=getattr(exc, "headers", None),
    )
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Log a request validation failure to stdout and return it as a 422.

    Args:
        request: The request whose payload failed validation.
        exc: The validation error raised by FastAPI.

    Returns:
        A ``JSONResponse`` with status 422 and the error list under ``detail``.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.encoders import jsonable_encoder

    errors = exc.errors()
    print(
        f"\033[33m[422]\033[0m {request.method} {request.url.path} → validation error:\n"
        f"  {errors}",
        flush=True,
    )
    # The error entries may carry values that the JSON encoder cannot handle
    # (for example exception objects or bytes), so normalise them first,
    # as FastAPI's default handler does.
    return JSONResponse(status_code=422, content={"detail": jsonable_encoder(errors)})
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Log an unhandled exception with its traceback and return a 500.

    Args:
        request: The request being processed when the exception escaped.
        exc: The unhandled exception.

    Returns:
        A ``JSONResponse`` with status 500 whose ``detail`` is
        ``"<ExceptionType>: <message>"``.
    """
    print(
        f"\033[31m[500]\033[0m {request.method} {request.url.path} → {type(exc).__name__}: {exc}",
        flush=True,
    )
    # Print the traceback of the exception we were handed instead of relying
    # on an ambient "currently handled" exception being set.
    traceback.print_exception(type(exc), exc, exc.__traceback__)
    # NOTE(review): the response exposes the exception type and message to
    # the client; confirm this is acceptable outside local development.
    return JSONResponse(status_code=500, content={"detail": f"{type(exc).__name__}: {exc}"})
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# Mount every feature router under the shared "/api" prefix, preserving the
# registration order.
for _feature_router in (
    chat_router,
    simulate_router,
    analysis_router,
    settings_router,
    experiments_router,
    distributions_router,
    agents_router,
):
    app.include_router(_feature_router, prefix="/api")
del _feature_router  # do not leave the loop variable in the module namespace
|
|
File without changes
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from fastapi import APIRouter, HTTPException
|
|
2
|
+
|
|
3
|
+
from backend.api.dependencies import repos
|
|
4
|
+
|
|
5
|
+
# Router holding the experiment-agents endpoint defined below.
router = APIRouter()
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@router.get("/experiments/{exp_id}/agents")
def get_experiment_agents(exp_id: str):
    """Return the agent and judge configuration of one experiment.

    Raises:
        HTTPException: 404 when no experiment with ``exp_id`` exists.
    """
    record = repos.experiments.get(exp_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Experiment not found")

    cfg = record.config

    agent_entries = []
    for agent in cfg.agents:
        agent_entries.append(
            {"name": agent.name, "prompt": agent.prompt, "model": agent.model}
        )

    judge_entry = {
        # Copy the rubric into a plain dict for serialisation.
        "rubric": dict(cfg.judge.rubric),
        "instructions": cfg.judge.instructions,
        "model": cfg.judge.model,
    }
    return {"agents": agent_entries, "judge": judge_entry}
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Experiment analysis — aggregates judge scores across each agent's sampled traits.
|
|
3
|
+
|
|
4
|
+
The shape is derived from the experiment's actual distributions: group by
|
|
5
|
+
(agent_name, trait, value) and report per-metric mean/std/n per bucket.
|
|
6
|
+
No hardcoded dimension names — whatever traits the distributions declare
|
|
7
|
+
are what you get.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import math
|
|
11
|
+
|
|
12
|
+
from fastapi import APIRouter, HTTPException
|
|
13
|
+
|
|
14
|
+
from backend.api.dependencies import repos
|
|
15
|
+
|
|
16
|
+
# Router holding the experiment analysis endpoint defined below.
router = APIRouter()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# ── Stats helpers ─────────────────────────────────────────────────────────────
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _mean(vals: list[float]) -> float | None:
|
|
23
|
+
return sum(vals) / len(vals) if vals else None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _std(vals: list[float]) -> float | None:
|
|
27
|
+
if len(vals) < 2:
|
|
28
|
+
return None
|
|
29
|
+
m = _mean(vals)
|
|
30
|
+
assert m is not None
|
|
31
|
+
return math.sqrt(sum((v - m) ** 2 for v in vals) / (len(vals) - 1))
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _score_stats(per_dim_values: dict[str, list[float]]) -> dict:
    """Summarise each dimension's values as rounded mean, std and count.

    Returns a dict mapping every dimension to
    ``{"mean": ..., "std": ..., "n": ...}``. ``mean`` and ``std`` are
    rounded to two decimals and are ``None`` when undefined.
    """

    def _round2(number: float | None) -> float | None:
        # Keep "undefined" as None instead of rounding it.
        return None if number is None else round(number, 2)

    return {
        dim: {
            "mean": _round2(_mean(vals)),
            "std": _round2(_std(vals)),
            "n": len(vals),
        }
        for dim, vals in per_dim_values.items()
    }
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _row_dims(row: dict, dims: set[str]) -> dict[str, list[float]]:
    """Wrap each dimension's score from ``row`` in a list.

    A dimension that is missing from ``row`` or is ``None`` maps to an
    empty list.
    """
    wrapped: dict[str, list[float]] = {}
    for dim in dims:
        score = row.get(dim)
        wrapped[dim] = [] if score is None else [score]
    return wrapped
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# ── Row building ──────────────────────────────────────────────────────────────
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _build_rows(experiment_id: str) -> tuple[list[dict], set[str]]:
    """Flatten an experiment's evaluated simulations into analysis rows.

    Each row holds ``profiles`` (taken from the simulation config) plus one
    top-level entry per judge dimension seen so far: the mean of that
    dimension across the evaluation's judge results, or ``None`` when no
    judge scored it.

    Returns:
        ``(rows, judge_dimensions)``.
    """
    rows: list[dict] = []
    dims: set[str] = set()
    pairs = repos.evaluations.list_completed_with_evaluations_for_experiment(experiment_id)

    for sim, ev in pairs:
        # Register every dimension scored in this evaluation before averaging.
        for judge in ev.judge_results:
            dims.update(judge.scores.keys())

        row: dict = {"profiles": sim.config.profiles}
        # ``dims`` also contains dimensions discovered on earlier
        # evaluations; those resolve to None here when unscored.
        for dim in dims:
            scored = [
                score
                for judge in ev.judge_results
                if (score := judge.scores.get(dim)) is not None
            ]
            row[dim] = sum(scored) / len(scored) if scored else None
        rows.append(row)

    return rows, dims
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# ── Grouping ──────────────────────────────────────────────────────────────────
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _aggregate_per_dim(rows: list[dict], dims: set[str]) -> dict[str, list[float]]:
    """Collect, per dimension, the non-``None`` scores found across ``rows``.

    Every dimension in ``dims`` appears in the result, with an empty list
    when no row carries a score for it. Scores keep the order of ``rows``.
    """
    return {
        dim: [score for row in rows if (score := row.get(dim)) is not None]
        for dim in dims
    }
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _group_by_trait(
    rows: list[dict], dims: set[str], agent_name: str, trait: str
) -> dict[str, dict]:
    """Bucket rows by one agent's trait value and summarise each bucket.

    Rows in which the agent has no value for ``trait`` are ignored.

    Returns:
        A dict mapping each observed trait value to the per-dimension
        statistics of the rows in that bucket.
    """
    grouped: dict[str, list[dict]] = {}
    for row in rows:
        agent_profile = row.get("profiles", {}).get(agent_name, {})
        trait_value = agent_profile.get(trait)
        if trait_value is not None:
            grouped.setdefault(trait_value, []).append(row)

    summary: dict[str, dict] = {}
    for trait_value, members in grouped.items():
        summary[trait_value] = _score_stats(_aggregate_per_dim(members, dims))
    return summary
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
# ── Endpoints ─────────────────────────────────────────────────────────────────
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@router.get("/experiments/{exp_id}/analysis")
def get_experiment_analysis(exp_id: str):
    """Return aggregate judge-score statistics for one experiment.

    The response has ``total_evaluations``, ``overall`` (stats per judge
    dimension) and ``by_trait`` (stats per ``"<agent>.<trait>"`` key and
    trait value).

    Raises:
        HTTPException: 404 when no experiment with ``exp_id`` exists.
    """
    experiment = repos.experiments.get(exp_id)
    if experiment is None:
        raise HTTPException(status_code=404, detail="Experiment not found")

    rows, dims = _build_rows(exp_id)
    if not rows:
        # Nothing evaluated yet: keep the response shape, with empty stats.
        return {"total_evaluations": 0, "overall": {}, "by_trait": {}}

    overall = _score_stats(_aggregate_per_dim(rows, dims))
    by_trait: dict[str, dict[str, dict]] = {
        f"{agent.name}.{trait}": _group_by_trait(rows, dims, agent.name, trait)
        for agent in experiment.config.agents
        for trait in agent.distribution.topo_order
    }
    return {
        "total_evaluations": len(rows),
        "overall": overall,
        "by_trait": by_trait,
    }
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
from fastapi import APIRouter, HTTPException
|
|
4
|
+
from pydantic import BaseModel, Field
|
|
5
|
+
from sse_starlette.sse import EventSourceResponse
|
|
6
|
+
|
|
7
|
+
from backend.api.dependencies import db
|
|
8
|
+
from core.config.settings import AVAILABLE_MODELS
|
|
9
|
+
from core.db.queries.sessions import (
|
|
10
|
+
create_session,
|
|
11
|
+
create_turn,
|
|
12
|
+
delete_session,
|
|
13
|
+
get_session,
|
|
14
|
+
get_turn_count,
|
|
15
|
+
get_turns,
|
|
16
|
+
list_sessions,
|
|
17
|
+
update_session_model,
|
|
18
|
+
update_session_title,
|
|
19
|
+
)
|
|
20
|
+
from core.llm.factory import parse_provider_model
|
|
21
|
+
|
|
22
|
+
# Router holding the model, session and chat endpoints defined below.
router = APIRouter()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class CreateSessionRequest(BaseModel):
    """Request body for creating a chat session."""

    # Model identifier stored on the new session.
    # Presumably of the form "<provider>:<model>" -- confirm against
    # core.llm.factory.parse_provider_model.
    model: str = "mock:default"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ChatRequest(BaseModel):
    """Request body for sending one user message to a chat session."""

    # Identifier of the session the message belongs to.
    session_id: str
    # User message text; must contain 1 to 10000 characters.
    message: str = Field(min_length=1, max_length=10_000)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class UpdateSessionRequest(BaseModel):
    """Request body for changing the model of an existing session."""

    # New model identifier to store on the session.
    model: str
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@router.get("/models")
def get_available_models():
    """Return the configured ``AVAILABLE_MODELS`` value unchanged."""
    models = AVAILABLE_MODELS
    return models
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@router.post("/sessions")
def create_new_session(request: CreateSessionRequest):
    """Create a chat session for the requested model and return it as a dict."""
    session = create_session(db, request.model)
    return session.to_dict()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@router.get("/sessions")
def get_all_sessions():
    """Return every stored chat session, each serialised as a dict."""
    serialised = []
    for session in list_sessions(db):
        serialised.append(session.to_dict())
    return serialised
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@router.get("/sessions/{session_id}")
def get_session_detail(session_id: str):
    """Return one session together with its turns.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = get_session(db, session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    turns = get_turns(db, session_id)
    detail = dict(session.to_dict())
    # Attach the serialised turns under the "turns" key.
    detail["turns"] = [turn.to_dict() for turn in turns]
    return detail
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@router.patch("/sessions/{session_id}")
def update_session(session_id: str, request: UpdateSessionRequest):
    """Change the model of a session and return the refreshed session.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    if not get_session(db, session_id):
        raise HTTPException(status_code=404, detail="Session not found")

    update_session_model(db, session_id, request.model)
    # Re-read so the response reflects the stored state.
    refreshed = get_session(db, session_id)
    return refreshed.to_dict()
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@router.delete("/sessions/{session_id}")
def delete_session_endpoint(session_id: str):
    """Delete a session and acknowledge with ``{"ok": True}``.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    existing = get_session(db, session_id)
    if not existing:
        raise HTTPException(status_code=404, detail="Session not found")

    delete_session(db, session_id)
    return {"ok": True}
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@router.post("/chat")
async def chat(request: ChatRequest):
    """Store a user message and stream the assistant reply as server-sent events.

    The user turn is stored before streaming starts. The first message of a
    session also becomes its title, truncated to 50 characters plus "...".
    Tokens are streamed as ``{"token": ...}`` data events. Once the stream
    completes, the full reply is stored as the next turn and a ``done``
    event is sent. Any exception during streaming is reported as an
    ``error`` event instead.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session_id = request.session_id
    session = get_session(db, session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # The current turn count is the index of the new user turn.
    user_turn_number = get_turn_count(db, session_id)
    create_turn(db, session_id, "user", request.message, user_turn_number)

    if user_turn_number == 0:
        text = request.message
        ellipsis = "..." if len(text) > 50 else ""
        update_session_title(db, session_id, text[:50] + ellipsis)

    # Full conversation so far, including the turn stored above.
    messages = [
        {"role": turn.role, "content": turn.content}
        for turn in get_turns(db, session_id)
    ]

    provider, model = parse_provider_model(session.model)

    async def generate():
        reply = ""
        try:
            async for token in provider.stream(messages, model):
                reply += token
                yield {"data": json.dumps({"token": token})}
            # Store the assistant turn only after the stream has finished.
            create_turn(db, session_id, "assistant", reply, user_turn_number + 1)
            yield {"event": "done", "data": ""}
        except Exception as e:
            yield {"event": "error", "data": json.dumps({"error": str(e)})}

    return EventSourceResponse(generate())
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from fastapi import APIRouter, HTTPException
|
|
2
|
+
|
|
3
|
+
from backend.api.dependencies import repos
|
|
4
|
+
from core.distribution import distribution_to_dict
|
|
5
|
+
|
|
6
|
+
# Router holding the agent distribution endpoint defined below.
router = APIRouter()
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@router.get("/experiments/{exp_id}/distributions/{agent_name}")
def get_agent_distribution(exp_id: str, agent_name: str):
    """Return the serialised trait distribution of one agent in an experiment.

    Raises:
        HTTPException: 404 when the experiment does not exist, or when the
            agent lookup on its config raises ``KeyError``.
    """
    experiment = repos.experiments.get(exp_id)
    if experiment is None:
        raise HTTPException(status_code=404, detail="Experiment not found")

    try:
        agent = experiment.config.agent(agent_name)
    except KeyError:
        not_found = f"Agent {agent_name!r} not found in experiment"
        raise HTTPException(status_code=404, detail=not_found)

    serialised = distribution_to_dict(agent.distribution)
    return {"distribution": serialised}
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
|
|
3
|
+
from fastapi import APIRouter, HTTPException, Query, Response
|
|
4
|
+
from pydantic import BaseModel, Field
|
|
5
|
+
|
|
6
|
+
from backend.api.dependencies import repos
|
|
7
|
+
from core import Experiment
|
|
8
|
+
from core.analysis.coverage import compute_coverage
|
|
9
|
+
from core.config.settings import APP_SETTINGS
|
|
10
|
+
from core.examples.medical.config import MEDICAL_EXAMPLE_CONFIG
|
|
11
|
+
from core.services.feedback import FeedbackService
|
|
12
|
+
|
|
13
|
+
# Router holding the experiment endpoints defined below.
router = APIRouter()
|
|
14
|
+
|
|
15
|
+
# Limits how many optimize requests run at once. The optimize endpoint
# acquires it without blocking and answers 409 when no slot is free.
_optimize_semaphore = threading.Semaphore(APP_SETTINGS.max_concurrent_optimizations)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CreateExperimentRequest(BaseModel):
    """Request body for creating an experiment."""

    # Name given to the new experiment's config.
    name: str
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SetCurrentOptimizationTargetBody(BaseModel):
    """Request body for selecting an experiment's current optimization target."""

    # Identifier of the optimization target to make current.
    optimization_target_id: str
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _experiment_or_404(exp_id: str):
    """Return the experiment stored under ``exp_id``.

    Raises:
        HTTPException: 404 when the repository returns a falsy result.
    """
    experiment = repos.experiments.get(exp_id)
    if experiment:
        return experiment
    raise HTTPException(status_code=404, detail="Experiment not found")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@router.post("/experiments")
def post_experiment(request: CreateExperimentRequest):
    """Create an experiment from the medical example config under a new name.

    Returns:
        The new experiment record as a dict, including its counts.

    Raises:
        HTTPException: 400 when constructing the experiment raises
            ``ValueError``; the message is passed through as ``detail``.
    """
    from dataclasses import replace

    # Shallow-override the name for this instance; the rest of the example
    # config is shared as-is.
    named_config = replace(MEDICAL_EXAMPLE_CONFIG, name=request.name)
    try:
        exp = Experiment(named_config, repos)
    except ValueError as e:
        # Chain the cause so the original error stays visible in tracebacks.
        raise HTTPException(status_code=400, detail=str(e)) from e
    record = exp.record
    return record.to_dict(counts=repos.experiments.counts_for(record.id))
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@router.get("/experiments")
def get_experiments():
    """Return every experiment as a dict, each with its counts when available."""
    experiments = repos.experiments.list_all()
    counts_by_id = repos.experiments.counts_all()

    payload = []
    for experiment in experiments:
        # ``counts`` is None for an experiment without an entry in the map.
        payload.append(experiment.to_dict(counts=counts_by_id.get(experiment.id)))
    return payload
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@router.get("/experiments/{exp_id}")
def get_experiment_by_id(exp_id: str):
    """Return one experiment as a dict with its counts (404 when missing)."""
    experiment = _experiment_or_404(exp_id)
    counts = repos.experiments.counts_for(exp_id)
    return experiment.to_dict(counts=counts)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@router.delete("/experiments/{exp_id}")
def delete_experiment_by_id(exp_id: str):
    """Delete an experiment and answer 204 with no body (404 when missing)."""
    _experiment_or_404(exp_id)  # existence check only; result is unused
    repos.experiments.delete(exp_id)
    no_content = Response(status_code=204)
    return no_content
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@router.patch("/experiments/{exp_id}")
def patch_experiment(exp_id: str):
    """Currently only supports resetting the sample draw index.

    Returns:
        The refreshed experiment as a dict, including its counts.

    Raises:
        HTTPException: 404 when the experiment does not exist, either
            before the reset or when it is re-read afterwards.
    """
    _experiment_or_404(exp_id)
    repos.experiments.reset_sample_draw_index(exp_id)
    # Re-read through the 404 helper instead of an ``assert``: asserts are
    # stripped under ``python -O`` and would otherwise surface as a 500.
    updated = _experiment_or_404(exp_id)
    return updated.to_dict(counts=repos.experiments.counts_for(exp_id))
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# ── Experiment-scoped lists ─────────────────────────────────────────────────
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@router.get("/experiments/{exp_id}/simulations")
def list_experiment_simulations(exp_id: str):
    """Return the experiment's simulations as dicts (404 when it is missing)."""
    _experiment_or_404(exp_id)
    simulations = repos.simulations.list_for_experiment(exp_id)
    return [simulation.to_dict() for simulation in simulations]
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@router.get("/experiments/{exp_id}/evaluations")
def list_experiment_evaluations(exp_id: str):
    """Return the experiment's evaluations as dicts (404 when it is missing)."""
    _experiment_or_404(exp_id)
    evaluations = repos.evaluations.list_for_experiment(exp_id)
    return [evaluation.to_dict() for evaluation in evaluations]
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# ── Optimization targets ────────────────────────────────────────────────────
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@router.get("/experiments/{exp_id}/optimization-targets")
def list_experiment_optimization_targets(exp_id: str):
    """Return the experiment's optimization targets as dicts (404 when it is missing)."""
    _experiment_or_404(exp_id)
    targets = repos.optimization_targets.list_for_experiment(exp_id)
    return [target.to_dict() for target in targets]
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@router.post("/experiments/{exp_id}/optimization-target/current")
def set_experiment_current_optimization_target(
    exp_id: str, body: SetCurrentOptimizationTargetBody
):
    """Make one of the experiment's optimization targets the current one.

    Returns:
        The refreshed experiment as a dict, including its counts.

    Raises:
        HTTPException: 404 when the experiment does not exist (before the
            update or when re-read afterwards), or when the target is
            unknown or belongs to another experiment.
    """
    _experiment_or_404(exp_id)
    target = repos.optimization_targets.get(body.optimization_target_id)
    if target is None or target.experiment_id != exp_id:
        raise HTTPException(status_code=404, detail="Optimization target not found for this experiment")
    repos.experiments.set_current_optimization_target(exp_id, body.optimization_target_id)
    # Re-read through the 404 helper instead of an ``assert``: asserts are
    # stripped under ``python -O`` and would otherwise surface as a 500.
    updated = _experiment_or_404(exp_id)
    return updated.to_dict(counts=repos.experiments.counts_for(exp_id))
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# ── Coverage ────────────────────────────────────────────────────────────────
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@router.get("/experiments/{exp_id}/coverage")
def get_experiment_coverage(
    exp_id: str,
    mc_samples: int = Query(100_000, ge=5_000, le=500_000),
):
    """Return the coverage report for an experiment's simulations.

    Args:
        exp_id: Identifier of the experiment.
        mc_samples: Value forwarded to ``compute_coverage`` as ``samples``;
            accepted range is 5000 to 500000, default 100000.

    Raises:
        HTTPException: 404 when the experiment does not exist.
    """
    experiment = _experiment_or_404(exp_id)
    simulations = repos.simulations.list_for_experiment(exp_id)
    report = compute_coverage(simulations, experiment.config.agents, samples=mc_samples)
    return report.to_dict()
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# ── Optimize ────────────────────────────────────────────────────────────────
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@router.post("/experiments/{exp_id}/optimize")
async def optimize_experiment(exp_id: str):
    """Run one optimization pass for an experiment and return its result.

    Concurrency is capped by ``_optimize_semaphore``. A request that finds
    no free slot is rejected immediately rather than queued.

    Raises:
        HTTPException: 404 when the experiment does not exist, 409 when the
            concurrency limit is reached, 400 when the optimization itself
            raises ``ValueError``.
    """
    _experiment_or_404(exp_id)

    acquired = _optimize_semaphore.acquire(blocking=False)
    if not acquired:
        busy_detail = (
            "Another optimization run is in progress "
            f"(max_concurrent_optimizations={APP_SETTINGS.max_concurrent_optimizations})"
        )
        raise HTTPException(status_code=409, detail=busy_detail)

    try:
        result = await FeedbackService(repos).optimize(exp_id)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    else:
        # Serialised outside the ``except`` scope: a ValueError raised here
        # is not turned into a 400.
        return result.to_dict()
    finally:
        # Free the slot whatever the outcome.
        _optimize_semaphore.release()
|