patientzero 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151) hide show
  1. backend/README.md +71 -0
  2. backend/__init__.py +0 -0
  3. backend/api/__init__.py +0 -0
  4. backend/api/dependencies.py +8 -0
  5. backend/api/main.py +81 -0
  6. backend/api/routes/__init__.py +0 -0
  7. backend/api/routes/agents.py +24 -0
  8. backend/api/routes/analysis.py +138 -0
  9. backend/api/routes/chat.py +109 -0
  10. backend/api/routes/distributions.py +18 -0
  11. backend/api/routes/experiments.py +150 -0
  12. backend/api/routes/settings.py +12 -0
  13. backend/api/routes/simulate.py +229 -0
  14. core/__init__.py +18 -0
  15. core/agent.py +42 -0
  16. core/agents/__init__.py +3 -0
  17. core/agents/base.py +61 -0
  18. core/analysis/__init__.py +0 -0
  19. core/analysis/coverage.py +123 -0
  20. core/config/__init__.py +0 -0
  21. core/config/settings.py +41 -0
  22. core/db/__init__.py +0 -0
  23. core/db/database.py +42 -0
  24. core/db/queries/__init__.py +0 -0
  25. core/db/queries/sessions.py +69 -0
  26. core/db/schema.sql +69 -0
  27. core/distribution.py +391 -0
  28. core/examples/__init__.py +1 -0
  29. core/examples/medical/__init__.py +1 -0
  30. core/examples/medical/config.py +32 -0
  31. core/examples/medical/distributions.py +157 -0
  32. core/examples/medical/prompts.py +25 -0
  33. core/examples/medical/run.py +54 -0
  34. core/experiment.py +171 -0
  35. core/feedback/__init__.py +0 -0
  36. core/feedback/feedback.py +141 -0
  37. core/judge.py +107 -0
  38. core/llm/__init__.py +0 -0
  39. core/llm/base.py +8 -0
  40. core/llm/claude_cli_provider.py +73 -0
  41. core/llm/factory.py +52 -0
  42. core/llm/mock.py +65 -0
  43. core/llm/openai_provider.py +35 -0
  44. core/logger.py +123 -0
  45. core/repositories/__init__.py +44 -0
  46. core/repositories/base.py +33 -0
  47. core/repositories/evaluations.py +131 -0
  48. core/repositories/experiments.py +182 -0
  49. core/repositories/optimization_targets.py +83 -0
  50. core/repositories/simulations.py +117 -0
  51. core/sampling.py +10 -0
  52. core/services/__init__.py +0 -0
  53. core/services/feedback.py +83 -0
  54. core/simulation.py +366 -0
  55. core/types/__init__.py +51 -0
  56. core/types/analysis.py +28 -0
  57. core/types/enums.py +6 -0
  58. core/types/events.py +13 -0
  59. core/types/feedback.py +54 -0
  60. core/types/judge_result.py +44 -0
  61. core/types/message.py +7 -0
  62. core/types/records.py +257 -0
  63. core/types/settings.py +9 -0
  64. core/types/simulation.py +9 -0
  65. core/types/trace.py +63 -0
  66. core/types/transcript.py +33 -0
  67. evaluations/__init__.py +0 -0
  68. evaluations/feedback/RECOMMENDATIONS.md +60 -0
  69. evaluations/feedback/RUNBOOK.md +104 -0
  70. evaluations/feedback/artifacts/baseline_low_lit_v1/analysis.csv +17 -0
  71. evaluations/feedback/artifacts/baseline_low_lit_v1/analysis.json +110 -0
  72. evaluations/feedback/artifacts/baseline_low_lit_v1/analysis_experiment.csv +4 -0
  73. evaluations/feedback/artifacts/baseline_low_lit_v1/run_summary.json +13 -0
  74. evaluations/feedback/artifacts/baseline_low_lit_v2_n10_restart/analysis.csv +55 -0
  75. evaluations/feedback/artifacts/baseline_low_lit_v2_n10_restart/analysis.json +110 -0
  76. evaluations/feedback/artifacts/baseline_low_lit_v2_n10_restart/analysis_experiment.csv +11 -0
  77. evaluations/feedback/artifacts/baseline_low_lit_v2_n10_restart/run_summary.json +13 -0
  78. evaluations/feedback/artifacts/baseline_low_lit_v2_n10_stage1/analysis.csv +45 -0
  79. evaluations/feedback/artifacts/baseline_low_lit_v2_n10_stage1/analysis.json +110 -0
  80. evaluations/feedback/artifacts/baseline_low_lit_v2_n10_stage1/analysis_experiment.csv +9 -0
  81. evaluations/feedback/artifacts/baseline_low_lit_v2_n10_stage1/run_summary.json +13 -0
  82. evaluations/feedback/artifacts/baseline_low_lit_v2_n20/analysis.csv +23 -0
  83. evaluations/feedback/artifacts/baseline_low_lit_v2_n20/analysis.json +110 -0
  84. evaluations/feedback/artifacts/baseline_low_lit_v2_n20/run_summary.json +13 -0
  85. evaluations/feedback/artifacts/baseline_low_lit_v2_n20_rerun1/analysis.csv +41 -0
  86. evaluations/feedback/artifacts/baseline_low_lit_v2_n20_rerun1/analysis.json +110 -0
  87. evaluations/feedback/artifacts/baseline_low_lit_v2_n20_rerun1/analysis_experiment.csv +15 -0
  88. evaluations/feedback/artifacts/baseline_low_lit_v2_n20_rerun1/run_summary.json +13 -0
  89. evaluations/feedback/artifacts/baseline_smoke/analysis.csv +1 -0
  90. evaluations/feedback/artifacts/baseline_smoke/analysis.json +8 -0
  91. evaluations/feedback/artifacts/baseline_smoke/run_summary.json +13 -0
  92. evaluations/feedback/artifacts/baseline_smoke2/analysis.csv +2 -0
  93. evaluations/feedback/artifacts/baseline_smoke2/analysis.json +38 -0
  94. evaluations/feedback/artifacts/baseline_smoke2/analysis_experiment.csv +2 -0
  95. evaluations/feedback/artifacts/baseline_smoke2/run_summary.json +13 -0
  96. evaluations/feedback/artifacts/baseline_v1/analysis.csv +14 -0
  97. evaluations/feedback/artifacts/baseline_v1/analysis.json +110 -0
  98. evaluations/feedback/artifacts/baseline_v1/analysis_experiment.csv +10 -0
  99. evaluations/feedback/artifacts/baseline_v1/hypotheses.md +39 -0
  100. evaluations/feedback/artifacts/baseline_v1/run_summary.json +13 -0
  101. evaluations/feedback/artifacts/intervention_low_lit_v2/analysis.csv +23 -0
  102. evaluations/feedback/artifacts/intervention_low_lit_v2/analysis.json +110 -0
  103. evaluations/feedback/artifacts/intervention_low_lit_v2/analysis_experiment.csv +7 -0
  104. evaluations/feedback/artifacts/intervention_low_lit_v2/compare_vs_baseline.json +39 -0
  105. evaluations/feedback/artifacts/intervention_low_lit_v2/compare_vs_baseline.md +34 -0
  106. evaluations/feedback/artifacts/intervention_low_lit_v2/run_summary.json +13 -0
  107. evaluations/feedback/artifacts/intervention_low_lit_v3_n10_restart/analysis.csv +61 -0
  108. evaluations/feedback/artifacts/intervention_low_lit_v3_n10_restart/analysis.json +110 -0
  109. evaluations/feedback/artifacts/intervention_low_lit_v3_n10_restart/analysis_experiment.csv +7 -0
  110. evaluations/feedback/artifacts/intervention_low_lit_v3_n10_restart/run_summary.json +13 -0
  111. evaluations/feedback/artifacts/intervention_low_lit_v3_n10_retry2/analysis.csv +71 -0
  112. evaluations/feedback/artifacts/intervention_low_lit_v3_n10_retry2/analysis.json +110 -0
  113. evaluations/feedback/artifacts/intervention_low_lit_v3_n10_retry2/analysis_experiment.csv +11 -0
  114. evaluations/feedback/artifacts/intervention_low_lit_v3_n10_retry2/compare_vs_baseline.json +55 -0
  115. evaluations/feedback/artifacts/intervention_low_lit_v3_n10_retry2/compare_vs_baseline.md +34 -0
  116. evaluations/feedback/artifacts/intervention_low_lit_v3_n10_retry2/run_summary.json +13 -0
  117. evaluations/feedback/artifacts/intervention_low_lit_v3_n20/analysis.csv +23 -0
  118. evaluations/feedback/artifacts/intervention_low_lit_v3_n20/analysis.json +110 -0
  119. evaluations/feedback/artifacts/intervention_low_lit_v3_n20/run_summary.json +13 -0
  120. evaluations/feedback/artifacts/intervention_v2/analysis.csv +10 -0
  121. evaluations/feedback/artifacts/intervention_v2/analysis.json +102 -0
  122. evaluations/feedback/artifacts/intervention_v2/analysis_experiment.csv +4 -0
  123. evaluations/feedback/artifacts/intervention_v2/compare_vs_baseline.json +61 -0
  124. evaluations/feedback/artifacts/intervention_v2/compare_vs_baseline.md +20 -0
  125. evaluations/feedback/artifacts/intervention_v2/run_summary.json +13 -0
  126. evaluations/judge/__init__.py +0 -0
  127. evaluations/judge/cases/__init__.py +0 -0
  128. evaluations/judge/cases/bad_explanation.py +43 -0
  129. evaluations/judge/cases/cbc_good.py +36 -0
  130. evaluations/judge/cases/cbc_poor.py +32 -0
  131. evaluations/judge/cases/confidence_gap.py +34 -0
  132. evaluations/judge/cases/hba1c_good.py +43 -0
  133. evaluations/judge/cases/liver_passive.py +40 -0
  134. evaluations/judge/cases/metabolic_high_literacy.py +51 -0
  135. evaluations/judge/cases/metformin_mixed.py +39 -0
  136. evaluations/judge/cases/patient_contradicts.py +49 -0
  137. evaluations/judge/cases/short_exchange.py +29 -0
  138. evaluations/judge/output/bad_explanation.json +13 -0
  139. evaluations/judge/output/cbc_good.json +13 -0
  140. evaluations/judge/output/cbc_poor.json +13 -0
  141. evaluations/judge/output/confidence_gap.json +13 -0
  142. evaluations/judge/output/hba1c_good.json +13 -0
  143. evaluations/judge/output/liver_passive.json +13 -0
  144. evaluations/judge/output/metabolic_high_literacy.json +13 -0
  145. evaluations/judge/output/metformin_mixed.json +13 -0
  146. evaluations/judge/output/patient_contradicts.json +13 -0
  147. evaluations/judge/output/short_exchange.json +13 -0
  148. patientzero-0.1.0.dist-info/METADATA +96 -0
  149. patientzero-0.1.0.dist-info/RECORD +151 -0
  150. patientzero-0.1.0.dist-info/WHEEL +4 -0
  151. patientzero-0.1.0.dist-info/licenses/LICENSE +21 -0
backend/README.md ADDED
@@ -0,0 +1,71 @@
1
+ # PatientZero Backend
2
+
3
+ A FastAPI server backed by a SQLite database, with an abstracted LLM provider layer.
4
+
5
+ ## Setup
6
+
7
+ ```bash
8
+ uv sync
9
+ cp ../.env.example ../.env # if not done already
10
+ ```
11
+
12
+ ## Running
13
+
14
+ ```bash
15
+ uv run uvicorn api.main:app --reload
16
+
17
+ ```
18
+
19
+ Server runs at http://localhost:8000
20
+
21
+ ## Project Structure
22
+
23
+ ```
24
+ backend/
25
+ ├── api/
26
+ │ ├── main.py # FastAPI app, CORS, lifespan
27
+ │ ├── dependencies.py # Shared db + provider instances
28
+ │ └── routes/
29
+ │ └── chat.py # Chat + session endpoints
30
+ ├── config/
31
+ │ └── settings.py # Environment config
32
+ ├── db/
33
+ │ ├── database.py # Database class (SQLite, raw queries)
34
+ │ ├── schema.sql # Table definitions
35
+ │ └── queries/
36
+ │ └── sessions.py # Session + turn CRUD
37
+ ├── llm/
38
+ │ ├── base.py # Abstract LLMProvider
39
+ │ ├── mock.py # Mock provider (testing)
40
+ │ └── factory.py # Provider factory
41
+ └── pyproject.toml
42
+ ```
43
+
44
+ ## API Endpoints
45
+
46
+ | Method | Path | Description |
47
+ |--------|------|-------------|
48
+ | `POST` | `/api/sessions` | Create a new chat session |
49
+ | `GET` | `/api/sessions` | List all sessions |
50
+ | `GET` | `/api/sessions/{id}` | Get session with turns |
51
+ | `POST` | `/api/chat` | Send message, receive SSE stream |
52
+
53
+ ## LLM Providers
54
+
55
+ Set `LLM_PROVIDER` in `.env`:
56
+
57
+ | Provider | Value | Status |
58
+ |----------|-------|--------|
59
+ | Mock | `mock` | Available |
60
+ | OpenAI | `openai` | Planned |
61
+ | Claude | `claude` | Planned |
62
+ | Local | `local` | Planned |
63
+
64
+ ## Database
65
+
66
+ SQLite with WAL mode. Tables:
67
+
68
+ - `sessions` — chat sessions (id, title, created_at)
69
+ - `turns` — individual messages (session_id, role, content, turn_number)
70
+
71
+ Database file is created automatically on first run.
backend/__init__.py ADDED
File without changes
File without changes
@@ -0,0 +1,8 @@
1
+ from core.config.settings import DB_PATH
2
+ from core.db.database import Database
3
+ from core.repositories import RepoSet
4
+ from core.logger import SimulationLogger
5
+
6
+ db = Database(DB_PATH)
7
+ repos = RepoSet.for_db(db)
8
+ logger = SimulationLogger()
backend/api/main.py ADDED
@@ -0,0 +1,81 @@
1
+ import logging
2
+ import traceback
3
+ from contextlib import asynccontextmanager
4
+
5
+ from fastapi import FastAPI, Request
6
+ from fastapi.exceptions import HTTPException, RequestValidationError
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi.responses import JSONResponse
9
+
10
+ logger = logging.getLogger("patientzero.api")
11
+
12
+ from core import Experiment
13
+ from core.config.settings import FRONTEND_URL
14
+ from core.examples.medical.config import MEDICAL_EXAMPLE_CONFIG
15
+ from backend.api.dependencies import db, repos
16
+ from backend.api.routes.agents import router as agents_router
17
+ from backend.api.routes.analysis import router as analysis_router
18
+ from backend.api.routes.chat import router as chat_router
19
+ from backend.api.routes.distributions import router as distributions_router
20
+ from backend.api.routes.experiments import router as experiments_router
21
+ from backend.api.routes.settings import router as settings_router
22
+ from backend.api.routes.simulate import router as simulate_router
23
+
24
+
25
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the shared database for the lifetime of the application.

    Startup: initialise ``db`` and, when ``repos.experiments.list_all()`` is
    empty, construct an ``Experiment`` from ``MEDICAL_EXAMPLE_CONFIG``.
    Shutdown: close ``db``.

    The close lives in a ``finally`` so the database is released even when
    seeding fails or the application leaves the lifespan with an exception.
    """
    db.init()
    try:
        if not repos.experiments.list_all():
            # The instance is discarded; presumably constructing it registers
            # the experiment through ``repos`` — TODO confirm.
            Experiment(MEDICAL_EXAMPLE_CONFIG, repos)
        yield
    finally:
        db.close()
32
+
33
+
34
# Application instance; startup and shutdown are delegated to ``lifespan``.
app = FastAPI(lifespan=lifespan, title="PatientZero")

# CORS: only the configured frontend origin is allowed, with any method/header.
_cors_options = {
    "allow_origins": [FRONTEND_URL],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
42
+
43
+
44
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Log an ``HTTPException`` to stdout and return it as a JSON response.

    The response body is ``{"detail": exc.detail}`` with the exception's
    status code. Any headers attached to the exception are forwarded so they
    are not silently dropped by this custom handler.
    """
    if exc.status_code >= 400:
        print(
            f"\033[33m[{exc.status_code}]\033[0m {request.method} {request.url.path} "
            f"→ {exc.detail}",
            flush=True,
        )
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
        # ``headers`` may be absent or None; JSONResponse accepts None.
        headers=getattr(exc, "headers", None),
    )
53
+
54
+
55
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Log a request validation failure to stdout and return a 422 JSON response.

    The error list is passed through ``jsonable_encoder`` before it is handed
    to ``JSONResponse``: validation errors can carry values that the plain
    JSON encoder cannot serialise, which would turn a 422 into a 500.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.encoders import jsonable_encoder

    errors = exc.errors()
    print(
        f"\033[33m[422]\033[0m {request.method} {request.url.path} → validation error:\n"
        f" {errors}",
        flush=True,
    )
    return JSONResponse(status_code=422, content={"detail": jsonable_encoder(errors)})
63
+
64
+
65
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Log any otherwise-unhandled exception and return a 500 JSON response.

    A one-line summary is printed to stdout, followed by the traceback of
    ``exc``. The response body is ``{"detail": "<ExceptionType>: <message>"}``.
    """
    summary = f"{type(exc).__name__}: {exc}"
    print(
        f"\033[31m[500]\033[0m {request.method} {request.url.path} → {summary}",
        flush=True,
    )
    # Print the traceback of the exception we were handed instead of relying
    # on whatever exception happens to be "current" when the handler runs.
    traceback.print_exception(type(exc), exc, exc.__traceback__)
    return JSONResponse(status_code=500, content={"detail": summary})
73
+
74
+
75
# Mount every feature router under the shared "/api" prefix, keeping the
# original registration order.
for _feature_router in (
    chat_router,
    simulate_router,
    analysis_router,
    settings_router,
    experiments_router,
    distributions_router,
    agents_router,
):
    app.include_router(_feature_router, prefix="/api")
File without changes
@@ -0,0 +1,24 @@
1
+ from fastapi import APIRouter, HTTPException
2
+
3
+ from backend.api.dependencies import repos
4
+
5
+ router = APIRouter()
6
+
7
+
8
# Returns the agents (name, prompt, model) and the judge settings (rubric,
# instructions, model) from the experiment's config; 404 when the experiment
# does not exist.
@router.get("/experiments/{exp_id}/agents")
def get_experiment_agents(exp_id: str):
    record = repos.experiments.get(exp_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Experiment not found")

    cfg = record.config
    agent_payload = []
    for agent in cfg.agents:
        agent_payload.append(
            {"name": agent.name, "prompt": agent.prompt, "model": agent.model}
        )

    judge_payload = {
        "rubric": dict(cfg.judge.rubric),  # copied into a plain dict
        "instructions": cfg.judge.instructions,
        "model": cfg.judge.model,
    }
    return {"agents": agent_payload, "judge": judge_payload}
@@ -0,0 +1,138 @@
1
+ """
2
+ Experiment analysis — aggregates judge scores across each agent's sampled traits.
3
+
4
+ The shape is derived from the experiment's actual distributions: group by
5
+ (agent_name, trait, value) and report per-metric mean/std/n per bucket.
6
+ No hardcoded dimension names — whatever traits the distributions declare
7
+ are what you get.
8
+ """
9
+
10
+ import math
11
+
12
+ from fastapi import APIRouter, HTTPException
13
+
14
+ from backend.api.dependencies import repos
15
+
16
+ router = APIRouter()
17
+
18
+
19
+ # ── Stats helpers ─────────────────────────────────────────────────────────────
20
+
21
+
22
+ def _mean(vals: list[float]) -> float | None:
23
+ return sum(vals) / len(vals) if vals else None
24
+
25
+
26
+ def _std(vals: list[float]) -> float | None:
27
+ if len(vals) < 2:
28
+ return None
29
+ m = _mean(vals)
30
+ assert m is not None
31
+ return math.sqrt(sum((v - m) ** 2 for v in vals) / (len(vals) - 1))
32
+
33
+
34
def _score_stats(per_dim_values: dict[str, list[float]]) -> dict:
    """Summarise each dimension's values as ``{"mean", "std", "n"}``.

    ``mean`` and ``std`` are rounded to two decimals, or ``None`` when the
    underlying helper returns ``None``; ``n`` is the number of values.
    """

    def _rounded(x: float | None) -> float | None:
        return None if x is None else round(x, 2)

    return {
        dim: {
            "mean": _rounded(_mean(vals)),
            "std": _rounded(_std(vals)),
            "n": len(vals),
        }
        for dim, vals in per_dim_values.items()
    }
45
+
46
+
47
def _row_dims(row: dict, dims: set[str]) -> dict[str, list[float]]:
    """Map each dimension to a one-element list of the row's value.

    A dimension that is missing from ``row`` or set to ``None`` maps to an
    empty list.
    """
    per_dim: dict[str, list[float]] = {}
    for dim in dims:
        if row.get(dim) is None:
            per_dim[dim] = []
        else:
            per_dim[dim] = [row[dim]]
    return per_dim
49
+
50
+
51
+ # ── Row building ──────────────────────────────────────────────────────────────
52
+
53
+
54
def _build_rows(experiment_id: str) -> tuple[list[dict], set[str]]:
    """Flatten an experiment's (simulation, evaluation) pairs into score rows.

    Each row holds ``profiles`` (taken from ``sim.config.profiles``) plus one
    key per judge dimension known so far, whose value is the mean of that
    dimension's non-``None`` scores across the evaluation's judge results
    (``None`` when there are none).

    The dimension set grows as pairs are processed, so a row only carries the
    dimensions discovered up to and including its own evaluation.

    Returns ``(rows, judge_dimensions)``.
    """
    seen_dims: set[str] = set()
    flattened: list[dict] = []
    pairs = repos.evaluations.list_completed_with_evaluations_for_experiment(experiment_id)
    for sim, ev in pairs:
        entry: dict = {"profiles": sim.config.profiles}
        for result in ev.judge_results:
            for dim_name in result.scores:
                seen_dims.add(dim_name)
        for dim in seen_dims:
            present = []
            for result in ev.judge_results:
                score = result.scores.get(dim)
                if score is not None:
                    present.append(score)
            entry[dim] = (sum(present) / len(present)) if present else None
        flattened.append(entry)
    return flattened, seen_dims
79
+
80
+
81
+ # ── Grouping ──────────────────────────────────────────────────────────────────
82
+
83
+
84
def _aggregate_per_dim(rows: list[dict], dims: set[str]) -> dict[str, list[float]]:
    """Collect, per dimension, every non-``None`` value found across ``rows``.

    Every dimension in ``dims`` appears in the result, with an empty list when
    no row provides a value for it.
    """
    return {
        dim: [row[dim] for row in rows if row.get(dim) is not None]
        for dim in dims
    }
92
+
93
+
94
def _group_by_trait(
    rows: list[dict], dims: set[str], agent_name: str, trait: str
) -> dict[str, dict]:
    """Bucket rows by one agent's trait value and summarise each bucket.

    The trait value is read from ``row["profiles"][agent_name][trait]``; rows
    where any level is missing or the value is ``None`` are skipped. Returns
    ``{trait_value: per-dimension stats}``.
    """
    grouped: dict[str, list[dict]] = {}
    for row in rows:
        trait_value = row.get("profiles", {}).get(agent_name, {}).get(trait)
        if trait_value is not None:
            grouped.setdefault(trait_value, []).append(row)

    summary: dict[str, dict] = {}
    for trait_value, members in grouped.items():
        summary[trait_value] = _score_stats(_aggregate_per_dim(members, dims))
    return summary
107
+
108
+
109
+ # ── Endpoints ─────────────────────────────────────────────────────────────────
110
+
111
+
112
# Aggregated judge scores for one experiment: overall stats per dimension plus
# a breakdown keyed "<agent name>.<trait>" for every trait in each agent's
# distribution.topo_order. 404 when the experiment does not exist; an empty
# payload when there are no rows.
@router.get("/experiments/{exp_id}/analysis")
def get_experiment_analysis(exp_id: str):
    experiment = repos.experiments.get(exp_id)
    if experiment is None:
        raise HTTPException(status_code=404, detail="Experiment not found")

    rows, dims = _build_rows(exp_id)
    if not rows:
        return {"total_evaluations": 0, "overall": {}, "by_trait": {}}

    overall = _score_stats(_aggregate_per_dim(rows, dims))
    by_trait: dict[str, dict[str, dict]] = {
        f"{agent.name}.{trait}": _group_by_trait(rows, dims, agent.name, trait)
        for agent in experiment.config.agents
        for trait in agent.distribution.topo_order
    }
    return {
        "total_evaluations": len(rows),
        "overall": overall,
        "by_trait": by_trait,
    }
@@ -0,0 +1,109 @@
1
+ import json
2
+
3
+ from fastapi import APIRouter, HTTPException
4
+ from pydantic import BaseModel, Field
5
+ from sse_starlette.sse import EventSourceResponse
6
+
7
+ from backend.api.dependencies import db
8
+ from core.config.settings import AVAILABLE_MODELS
9
+ from core.db.queries.sessions import (
10
+ create_session,
11
+ create_turn,
12
+ delete_session,
13
+ get_session,
14
+ get_turn_count,
15
+ get_turns,
16
+ list_sessions,
17
+ update_session_model,
18
+ update_session_title,
19
+ )
20
+ from core.llm.factory import parse_provider_model
21
+
22
+ router = APIRouter()
23
+
24
+
25
class CreateSessionRequest(BaseModel):
    # Model identifier stored on the new session; optional in the request body.
    model: str = Field(default="mock:default")
27
+
28
+
29
class ChatRequest(BaseModel):
    # Session the message belongs to.
    session_id: str
    # Required user message, between 1 and 10000 characters.
    message: str = Field(..., min_length=1, max_length=10_000)
32
+
33
+
34
class UpdateSessionRequest(BaseModel):
    # Required: the new model identifier to store on the session.
    model: str = Field(...)
36
+
37
+
38
# Exposes the AVAILABLE_MODELS setting as-is.
@router.get("/models")
def get_available_models():
    models = AVAILABLE_MODELS
    return models
41
+
42
+
43
# Creates a session for the requested model and returns it as a dict.
@router.post("/sessions")
def create_new_session(request: CreateSessionRequest):
    new_session = create_session(db, request.model)
    return new_session.to_dict()
46
+
47
+
48
# Returns every session from list_sessions(db), each converted with to_dict().
@router.get("/sessions")
def get_all_sessions():
    payload = []
    for session in list_sessions(db):
        payload.append(session.to_dict())
    return payload
51
+
52
+
53
# Returns the session's fields plus a "turns" list; 404 when it is unknown.
@router.get("/sessions/{session_id}")
def get_session_detail(session_id: str):
    session = get_session(db, session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    turns = get_turns(db, session_id)
    detail = dict(session.to_dict())
    detail["turns"] = [turn.to_dict() for turn in turns]
    return detail
60
+
61
+
62
# Changes the session's model, then returns the re-read session as a dict;
# 404 when the session is unknown.
@router.patch("/sessions/{session_id}")
def update_session(session_id: str, request: UpdateSessionRequest):
    existing = get_session(db, session_id)
    if not existing:
        raise HTTPException(status_code=404, detail="Session not found")

    update_session_model(db, session_id, request.model)
    refreshed = get_session(db, session_id)
    return refreshed.to_dict()
69
+
70
+
71
# Deletes the session and answers {"ok": True}; 404 when the session is unknown.
@router.delete("/sessions/{session_id}")
def delete_session_endpoint(session_id: str):
    existing = get_session(db, session_id)
    if not existing:
        raise HTTPException(status_code=404, detail="Session not found")

    delete_session(db, session_id)
    return {"ok": True}
78
+
79
+
80
# Stores the user's message as a turn, then streams the provider's reply as
# server-sent events: one {"token": ...} data event per chunk, followed by a
# "done" event, or an "error" event carrying the exception text. The assembled
# reply is stored as the assistant turn once streaming completes.
@router.post("/chat")
async def chat(request: ChatRequest):
    session = get_session(db, request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # Resolve the provider before writing anything: if this call raises, no
    # user turn or title change is left behind without a reply.
    provider, model = parse_provider_model(session.model)

    turn_number = get_turn_count(db, request.session_id)
    create_turn(db, request.session_id, "user", request.message, turn_number)

    if turn_number == 0:
        # First message of the session doubles as its title (50 chars max,
        # with an ellipsis when truncated).
        title = request.message[:50] + ("..." if len(request.message) > 50 else "")
        update_session_title(db, request.session_id, title)

    turns = get_turns(db, request.session_id)
    messages = [{"role": t.role, "content": t.content} for t in turns]

    async def generate():
        # Chunks are collected and joined once instead of repeated `+=`.
        chunks: list[str] = []
        try:
            async for chunk in provider.stream(messages, model):
                chunks.append(chunk)
                yield {"data": json.dumps({"token": chunk})}
            create_turn(
                db, request.session_id, "assistant", "".join(chunks), turn_number + 1
            )
            yield {"event": "done", "data": ""}
        except Exception as e:
            # Boundary of the stream: report the failure to the client as an
            # SSE "error" event rather than breaking the connection.
            yield {"event": "error", "data": json.dumps({"error": str(e)})}

    return EventSourceResponse(generate())
@@ -0,0 +1,18 @@
1
+ from fastapi import APIRouter, HTTPException
2
+
3
+ from backend.api.dependencies import repos
4
+ from core.distribution import distribution_to_dict
5
+
6
+ router = APIRouter()
7
+
8
+
9
# Returns {"distribution": ...} for one agent of an experiment. 404 when the
# experiment is unknown or when config.agent(agent_name) raises KeyError.
@router.get("/experiments/{exp_id}/distributions/{agent_name}")
def get_agent_distribution(exp_id: str, agent_name: str):
    record = repos.experiments.get(exp_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Experiment not found")

    try:
        agent = record.config.agent(agent_name)
    except KeyError:
        missing_detail = f"Agent {agent_name!r} not found in experiment"
        raise HTTPException(status_code=404, detail=missing_detail)

    serialized = distribution_to_dict(agent.distribution)
    return {"distribution": serialized}
@@ -0,0 +1,150 @@
1
+ import threading
2
+
3
+ from fastapi import APIRouter, HTTPException, Query, Response
4
+ from pydantic import BaseModel, Field
5
+
6
+ from backend.api.dependencies import repos
7
+ from core import Experiment
8
+ from core.analysis.coverage import compute_coverage
9
+ from core.config.settings import APP_SETTINGS
10
+ from core.examples.medical.config import MEDICAL_EXAMPLE_CONFIG
11
+ from core.services.feedback import FeedbackService
12
+
13
+ router = APIRouter()
14
+
15
+ _optimize_semaphore = threading.Semaphore(APP_SETTINGS.max_concurrent_optimizations)
16
+
17
+
18
class CreateExperimentRequest(BaseModel):
    # Required: the name given to the new experiment.
    name: str = Field(...)
20
+
21
+
22
class SetCurrentOptimizationTargetBody(BaseModel):
    # Required: id of the optimization target to make current.
    optimization_target_id: str = Field(...)
24
+
25
+
26
def _experiment_or_404(exp_id: str):
    """Return the experiment stored under ``exp_id``.

    Raises:
        HTTPException: 404 when ``repos.experiments.get`` returns a falsy value.
    """
    record = repos.experiments.get(exp_id)
    if record:
        return record
    raise HTTPException(status_code=404, detail="Experiment not found")
31
+
32
+
33
# Creates an experiment from MEDICAL_EXAMPLE_CONFIG with the requested name.
# A ValueError from the Experiment constructor is reported as a 400.
@router.post("/experiments")
def post_experiment(request: CreateExperimentRequest):
    from dataclasses import replace

    # Copy of the example config that differs only in its name.
    named_config = replace(MEDICAL_EXAMPLE_CONFIG, name=request.name)
    try:
        experiment = Experiment(named_config, repos)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    record = experiment.record
    counts = repos.experiments.counts_for(record.id)
    return record.to_dict(counts=counts)
45
+
46
+
47
# Lists all experiments, each serialised with its entry from counts_all()
# (None when there is no entry for that id).
@router.get("/experiments")
def get_experiments():
    experiments = repos.experiments.list_all()
    counts_by_id = repos.experiments.counts_all()
    payload = []
    for experiment in experiments:
        payload.append(experiment.to_dict(counts=counts_by_id.get(experiment.id)))
    return payload
52
+
53
+
54
# Returns one experiment together with its counts; 404 when it is unknown.
@router.get("/experiments/{exp_id}")
def get_experiment_by_id(exp_id: str):
    experiment = _experiment_or_404(exp_id)
    counts = repos.experiments.counts_for(exp_id)
    return experiment.to_dict(counts=counts)
58
+
59
+
60
# Deletes the experiment and answers 204 with no body; 404 when it is unknown.
@router.delete("/experiments/{exp_id}")
def delete_experiment_by_id(exp_id: str):
    _experiment_or_404(exp_id)  # existence check only; the record is not used
    repos.experiments.delete(exp_id)
    no_content = Response(status_code=204)
    return no_content
65
+
66
+
67
@router.patch("/experiments/{exp_id}")
def patch_experiment(exp_id: str):
    """Currently only supports resetting the sample draw index."""
    _experiment_or_404(exp_id)
    repos.experiments.reset_sample_draw_index(exp_id)
    # Re-read through the 404 helper instead of `assert`: asserts are stripped
    # under `python -O`, and an experiment that vanished between the two reads
    # should produce a clean 404 rather than a 500.
    updated = _experiment_or_404(exp_id)
    return updated.to_dict(counts=repos.experiments.counts_for(exp_id))
75
+
76
+
77
+ # ── Experiment-scoped lists ─────────────────────────────────────────────────
78
+
79
+
80
# Lists the experiment's simulations as dicts; 404 when the experiment is unknown.
@router.get("/experiments/{exp_id}/simulations")
def list_experiment_simulations(exp_id: str):
    _experiment_or_404(exp_id)
    payload = []
    for simulation in repos.simulations.list_for_experiment(exp_id):
        payload.append(simulation.to_dict())
    return payload
84
+
85
+
86
# Lists the experiment's evaluations as dicts; 404 when the experiment is unknown.
@router.get("/experiments/{exp_id}/evaluations")
def list_experiment_evaluations(exp_id: str):
    _experiment_or_404(exp_id)
    payload = []
    for evaluation in repos.evaluations.list_for_experiment(exp_id):
        payload.append(evaluation.to_dict())
    return payload
90
+
91
+
92
+ # ── Optimization targets ────────────────────────────────────────────────────
93
+
94
+
95
# Lists the experiment's optimization targets as dicts; 404 when the
# experiment is unknown.
@router.get("/experiments/{exp_id}/optimization-targets")
def list_experiment_optimization_targets(exp_id: str):
    _experiment_or_404(exp_id)
    payload = []
    for target in repos.optimization_targets.list_for_experiment(exp_id):
        payload.append(target.to_dict())
    return payload
99
+
100
+
101
# Marks one optimization target as the experiment's current one and returns
# the updated experiment. 404 when the experiment is unknown, or when the
# target is unknown or belongs to a different experiment.
@router.post("/experiments/{exp_id}/optimization-target/current")
def set_experiment_current_optimization_target(
    exp_id: str, body: SetCurrentOptimizationTargetBody
):
    _experiment_or_404(exp_id)
    target = repos.optimization_targets.get(body.optimization_target_id)
    if target is None or target.experiment_id != exp_id:
        raise HTTPException(status_code=404, detail="Optimization target not found for this experiment")
    repos.experiments.set_current_optimization_target(exp_id, body.optimization_target_id)
    # Re-read through the 404 helper instead of `assert`: asserts are stripped
    # under `python -O`, and an experiment that vanished between the two reads
    # should produce a clean 404 rather than a 500.
    updated = _experiment_or_404(exp_id)
    return updated.to_dict(counts=repos.experiments.counts_for(exp_id))
113
+
114
+
115
+ # ── Coverage ────────────────────────────────────────────────────────────────
116
+
117
+
118
# Computes coverage for the experiment's simulations against its configured
# agents. `mc_samples` is forwarded as `samples` and is validated to lie in
# [5000, 500000] (default 100000). 404 when the experiment is unknown.
@router.get("/experiments/{exp_id}/coverage")
def get_experiment_coverage(
    exp_id: str,
    mc_samples: int = Query(100_000, ge=5_000, le=500_000),
):
    experiment = _experiment_or_404(exp_id)
    simulations = repos.simulations.list_for_experiment(exp_id)
    coverage = compute_coverage(simulations, experiment.config.agents, samples=mc_samples)
    return coverage.to_dict()
126
+
127
+
128
+ # ── Optimize ────────────────────────────────────────────────────────────────
129
+
130
+
131
# Runs FeedbackService.optimize for the experiment. A module-level semaphore
# caps concurrent runs: when no slot is free the request is rejected with 409
# instead of waiting. A ValueError from optimize() becomes a 400. The slot is
# always released once acquired.
@router.post("/experiments/{exp_id}/optimize")
async def optimize_experiment(exp_id: str):
    _experiment_or_404(exp_id)

    slot_acquired = _optimize_semaphore.acquire(blocking=False)
    if not slot_acquired:
        busy_detail = (
            "Another optimization run is in progress "
            f"(max_concurrent_optimizations={APP_SETTINGS.max_concurrent_optimizations})"
        )
        raise HTTPException(status_code=409, detail=busy_detail)

    try:
        result = await FeedbackService(repos).optimize(exp_id)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    else:
        return result.to_dict()
    finally:
        _optimize_semaphore.release()
@@ -0,0 +1,12 @@
1
+ from dataclasses import asdict
2
+
3
+ from fastapi import APIRouter
4
+
5
+ from core.config.settings import APP_SETTINGS
6
+
7
+ router = APIRouter()
8
+
9
+
10
# Returns APP_SETTINGS converted to a plain dict with dataclasses.asdict.
@router.get("/settings")
def get_settings():
    settings_payload = asdict(APP_SETTINGS)
    return settings_payload