themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +343 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +463 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +184 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +348 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +134 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +798 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +415 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +1458 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +72 -0
- themis/integrations/wandb.py +77 -0
- themis/interfaces/__init__.py +169 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- themis_eval-0.2.0.dist-info/METADATA +596 -0
- themis_eval-0.2.0.dist-info/RECORD +157 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
- themis_eval-0.1.0.dist-info/METADATA +0 -758
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
themis/server/app.py
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
"""FastAPI application for Themis server.
|
|
2
|
+
|
|
3
|
+
This module provides ``create_app``, a factory that builds the FastAPI app with REST endpoints and WebSocket support.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Dict, List
|
|
11
|
+
|
|
12
|
+
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
|
13
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
14
|
+
from fastapi.responses import JSONResponse
|
|
15
|
+
from fastapi.staticfiles import StaticFiles
|
|
16
|
+
from pydantic import BaseModel, Field
|
|
17
|
+
|
|
18
|
+
from themis.comparison import compare_runs
|
|
19
|
+
from themis.comparison.statistics import StatisticalTest
|
|
20
|
+
from themis.experiment.storage import ExperimentStorage
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class RunSummary(BaseModel):
    """Compact per-run view returned by ``GET /api/runs``."""

    # Identifier of the run in experiment storage.
    run_id: str
    # Owning experiment; defaults to "default".
    experiment_id: str = Field(default="default")
    # Lifecycle state reported for the run (e.g. "completed").
    status: str
    # Number of evaluated samples in the run.
    num_samples: int = Field(default=0)
    # Average value of each metric across the run's samples.
    metrics: Dict[str, float] = Field(default_factory=dict)
    # Creation timestamp as a string, when known.
    created_at: str | None = Field(default=None)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class RunDetail(BaseModel):
    """Full per-run view returned by ``GET /api/runs/{run_id}``."""

    # Identifier of the run in experiment storage.
    run_id: str
    # Owning experiment; defaults to "default".
    experiment_id: str = Field(default="default")
    # Lifecycle state reported for the run (e.g. "completed").
    status: str
    # Number of evaluated samples in the run (required).
    num_samples: int
    # Average value of each metric across the run's samples (required).
    metrics: Dict[str, float]
    # Per-sample payloads (id, prompt, response, scores).
    samples: List[Dict[str, Any]] = Field(default_factory=list)
    # Free-form extra information about the run.
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ComparisonRequest(BaseModel):
    """Request body for ``POST /api/compare``.

    ``alpha`` is validated to lie strictly between 0 and 1, so an invalid
    significance level is rejected with a validation error at the API
    boundary instead of being handed to the statistics code.
    """

    # IDs of the stored runs to compare (the compare endpoint requires at least two).
    run_ids: List[str]
    # Metric names to compare; passed through to ``compare_runs`` (None when omitted).
    metrics: List[str] | None = None
    # Name converted to a ``StatisticalTest`` member, e.g. "bootstrap".
    statistical_test: str = "bootstrap"
    # Significance level for the statistical test; must satisfy 0 < alpha < 1.
    alpha: float = Field(default=0.05, gt=0.0, lt=1.0)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class ErrorResponse(BaseModel):
    """Body describing a failed request."""

    # Short error summary.
    error: str
    # Optional longer explanation of the failure.
    detail: str | None = Field(default=None)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# Version string reported by the OpenAPI schema and the health-check endpoint.
_API_VERSION = "2.0.0"


def _score_to_float(score_obj: Any) -> float | None:
    """Convert a stored metric score to a float.

    Objects exposing a ``value`` attribute are unwrapped first; anything else
    is treated as the score itself.

    Args:
        score_obj: A score object with a ``value`` attribute, or a bare number.

    Returns:
        The score as a float, or ``None`` when it is missing or not numeric
        (such scores are skipped rather than breaking the averages).
    """
    value = getattr(score_obj, "value", score_obj)
    # Strings would pass through float() but are not scores; reject them explicitly.
    if value is None or isinstance(value, (str, bytes)):
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _record_scores(eval_record: Any) -> Dict[str, float]:
    """Return the numeric scores of one evaluation record keyed by metric name.

    Args:
        eval_record: Cached evaluation record whose ``scores`` attribute is read
            as a mapping of metric name to score object.

    Returns:
        Mapping of metric name to float score; non-numeric scores are omitted.
    """
    scores: Dict[str, float] = {}
    for metric_name, score_obj in eval_record.scores.items():
        value = _score_to_float(score_obj)
        if value is not None:
            scores[metric_name] = value
    return scores


def _mean_metrics(values_by_metric: Dict[str, List[float]]) -> Dict[str, float]:
    """Average each metric's collected values (0.0 for an empty list)."""
    return {
        name: sum(values) / len(values) if values else 0.0
        for name, values in values_by_metric.items()
    }


class _ConnectionManager:
    """Track open WebSocket connections and fan messages out to them."""

    def __init__(self) -> None:
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Accept the handshake and start tracking the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket) -> None:
        """Stop tracking a connection; unknown connections are ignored."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def broadcast(self, message: dict) -> None:
        """Send ``message`` to every tracked connection.

        Iterates over a snapshot because connections can be removed while a
        send is awaited. A connection whose send fails is dropped, so one dead
        client does not stop delivery to the others.
        """
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception:  # treat any send failure as a dead connection
                self.disconnect(connection)


def create_app(
    storage_path: str | Path = ".cache/experiments",
    cors_origins: List[str] | None = None,
) -> FastAPI:
    """Create the Themis FastAPI application.

    Args:
        storage_path: Path to the experiment storage the API reads from.
        cors_origins: Origins allowed to call the API from a browser. Defaults
            to ``["*"]`` (any origin). Credentialed CORS requests are only
            enabled when explicit origins (no ``"*"``) are given.

    Returns:
        Configured FastAPI application with REST routes under ``/api``, a
        WebSocket endpoint at ``/ws`` and, when a ``static`` directory exists
        next to this module, the dashboard mounted at ``/dashboard``.
    """
    app = FastAPI(
        title="Themis API",
        description="REST API for Themis experiment management",
        version=_API_VERSION,
    )

    # Enable CORS for the web dashboard.
    origins = list(cors_origins) if cors_origins is not None else ["*"]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        # Starlette's CORS documentation requires explicitly listed origins for
        # credentialed requests, so never combine credentials with "*".
        allow_credentials="*" not in origins,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    storage = ExperimentStorage(storage_path)

    # Mount static files (dashboard) when they are shipped with the package.
    static_dir = Path(__file__).parent / "static"
    if static_dir.exists():
        app.mount("/dashboard", StaticFiles(directory=str(static_dir), html=True), name="static")

    manager = _ConnectionManager()

    # ===== REST ENDPOINTS =====
    # Handlers that touch storage are plain ``def`` so FastAPI runs them in its
    # threadpool instead of blocking the event loop with file I/O.

    @app.get("/", tags=["health"])
    async def root():
        """Health check endpoint."""
        return {
            "status": "ok",
            "service": "themis-api",
            "version": _API_VERSION,
        }

    @app.get("/api/runs", response_model=List[RunSummary], tags=["runs"])
    def list_runs():
        """List all experiment runs with per-metric averages."""
        summaries = []
        for run_id in storage.list_runs():
            eval_records = storage.load_cached_evaluations(run_id)

            values_by_metric: Dict[str, List[float]] = {}
            for record in eval_records.values():
                for metric_name, value in _record_scores(record).items():
                    values_by_metric.setdefault(metric_name, []).append(value)

            summaries.append(
                RunSummary(
                    run_id=run_id,
                    experiment_id="default",
                    status="completed",
                    num_samples=len(eval_records),
                    metrics=_mean_metrics(values_by_metric),
                )
            )
        return summaries

    @app.get("/api/runs/{run_id}", response_model=RunDetail, tags=["runs"])
    def get_run(run_id: str):
        """Get per-sample scores and metric averages for one run."""
        if run_id not in storage.list_runs():
            raise HTTPException(status_code=404, detail=f"Run not found: {run_id}")

        eval_records = storage.load_cached_evaluations(run_id)
        gen_records = storage.load_cached_records(run_id)

        values_by_metric: Dict[str, List[float]] = {}
        samples = []
        for cache_key, eval_record in eval_records.items():
            scores = _record_scores(eval_record)
            for metric_name, value in scores.items():
                values_by_metric.setdefault(metric_name, []).append(value)

            # NOTE(review): assumes generation records expose ``id``, ``prompt``
            # and ``response``; confirm against the storage's record type.
            gen_record = gen_records.get(cache_key)
            samples.append({
                "id": gen_record.id if gen_record else cache_key,
                "prompt": gen_record.prompt if gen_record else "",
                "response": gen_record.response if gen_record else "",
                "scores": scores,
            })

        return RunDetail(
            run_id=run_id,
            experiment_id="default",
            status="completed",
            num_samples=len(eval_records),
            metrics=_mean_metrics(values_by_metric),
            samples=samples,
        )

    @app.delete("/api/runs/{run_id}", tags=["runs"])
    def delete_run(run_id: str):
        """Delete a run (not supported by the current storage: always 404 or 501)."""
        if run_id not in storage.list_runs():
            raise HTTPException(status_code=404, detail=f"Run not found: {run_id}")

        # Current storage doesn't implement delete; placeholder for future work.
        raise HTTPException(
            status_code=501,
            detail="Delete not implemented in current storage"
        )

    @app.post("/api/compare", tags=["comparison"])
    def compare_runs_api(request: ComparisonRequest):
        """Compare multiple runs and return the comparison report as a dict."""
        existing_runs = set(storage.list_runs())
        for run_id in request.run_ids:
            if run_id not in existing_runs:
                raise HTTPException(
                    status_code=404,
                    detail=f"Run not found: {run_id}"
                )

        if len(request.run_ids) < 2:
            raise HTTPException(
                status_code=400,
                detail="Need at least 2 runs to compare"
            )

        try:
            test_enum = StatisticalTest(request.statistical_test)
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid statistical test: {request.statistical_test}"
            ) from None

        # NOTE(review): relies on the private ``_base_dir`` attribute of
        # ExperimentStorage; confirm it stays available.
        report = compare_runs(
            run_ids=request.run_ids,
            storage_path=storage._base_dir,
            metrics=request.metrics,
            statistical_test=test_enum,
            alpha=request.alpha,
        )

        return report.to_dict()

    @app.get("/api/benchmarks", tags=["presets"])
    async def list_benchmarks():
        """List available benchmark presets."""
        # Aliased so the import does not shadow this route function's own name.
        from themis.presets import list_benchmarks as list_benchmark_presets

        return {"benchmarks": list_benchmark_presets()}

    # ===== WEBSOCKET ENDPOINTS =====

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """WebSocket endpoint for real-time updates.

        Messages sent from server:
        - {"type": "run_started", "run_id": "...", "data": {...}}
        - {"type": "run_progress", "run_id": "...", "progress": 0.5}
        - {"type": "run_completed", "run_id": "...", "data": {...}}
        - {"type": "error", "message": "..."}

        Messages expected from client:
        - {"type": "subscribe", "run_id": "..."}
        - {"type": "unsubscribe", "run_id": "..."}
        - {"type": "ping"}

        Frames that are not valid JSON objects get an ``error`` reply instead
        of terminating the connection.
        """
        await manager.connect(websocket)

        try:
            while True:
                data = await websocket.receive_text()
                try:
                    message = json.loads(data)
                except json.JSONDecodeError:
                    await websocket.send_json(
                        {"type": "error", "message": "Invalid JSON message"}
                    )
                    continue
                if not isinstance(message, dict):
                    await websocket.send_json(
                        {"type": "error", "message": "Message must be a JSON object"}
                    )
                    continue

                msg_type = message.get("type")

                if msg_type == "ping":
                    await websocket.send_json({"type": "pong"})

                elif msg_type == "subscribe":
                    # TODO: Implement run subscription logic
                    await websocket.send_json({
                        "type": "subscribed",
                        "run_id": message.get("run_id"),
                    })

                elif msg_type == "unsubscribe":
                    # TODO: Implement unsubscribe logic
                    await websocket.send_json({
                        "type": "unsubscribed",
                        "run_id": message.get("run_id"),
                    })

                else:
                    await websocket.send_json({
                        "type": "error",
                        "message": f"Unknown message type: {msg_type}"
                    })

        except WebSocketDisconnect:
            pass  # Normal client disconnect; cleanup happens in ``finally``.
        finally:
            # Always untrack the socket, even when an unexpected error ends the
            # handler, so ``broadcast`` never targets dead connections.
            manager.disconnect(websocket)

    return app
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
# Public API of this module: the application factory.
__all__ = [
    "create_app",
]
|
|
@@ -0,0 +1,379 @@
|
|
|
1
|
+
"""API generation utilities for creating REST APIs from Python functions.
|
|
2
|
+
|
|
3
|
+
This module provides tools to automatically generate REST API endpoints from
|
|
4
|
+
Python functions using their docstrings and type hints. It leverages FastAPI
|
|
5
|
+
for automatic OpenAPI schema generation.
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
```python
|
|
9
|
+
from themis.evaluation import statistics
from themis.utils.api_generator import create_api_from_module

# Generate API from a module
app = create_api_from_module(
    module=statistics,
|
|
14
|
+
prefix="/api/v1/statistics"
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# Run the API server
|
|
18
|
+
# uvicorn main:app --reload
|
|
19
|
+
```
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import inspect
|
|
25
|
+
from typing import Any, Callable, Dict, List, get_type_hints
|
|
26
|
+
|
|
27
|
+
try:
|
|
28
|
+
from fastapi import FastAPI, HTTPException
|
|
29
|
+
from pydantic import BaseModel, create_model
|
|
30
|
+
|
|
31
|
+
FASTAPI_AVAILABLE = True
|
|
32
|
+
except ImportError:
|
|
33
|
+
FASTAPI_AVAILABLE = False
|
|
34
|
+
FastAPI = None
|
|
35
|
+
HTTPException = None
|
|
36
|
+
BaseModel = None
|
|
37
|
+
create_model = None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class APIGenerationError(Exception):
    """Raised when an API cannot be generated (for example, FastAPI is missing)."""
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def create_api_from_functions(
    functions: List[Callable],
    title: str = "Auto-Generated API",
    description: str = "API generated from Python functions",
    version: str = "1.0.0",
    prefix: str = "",
) -> Any:
    """Create a FastAPI application from a list of functions.

    This function inspects each function's signature, type hints, and docstring
    to automatically generate REST API endpoints with proper request/response
    validation and OpenAPI documentation.

    Args:
        functions: List of functions to expose as API endpoints
        title: API title
        description: API description
        version: API version
        prefix: URL prefix for all endpoints (e.g., "/api/v1")

    Returns:
        FastAPI application instance

    Raises:
        APIGenerationError: If FastAPI is not installed or function inspection
            fails (the underlying error is chained as ``__cause__``)

    Example:
        ```python
        from themis.evaluation.statistics import compute_confidence_interval

        app = create_api_from_functions(
            functions=[compute_confidence_interval],
            title="Statistics API",
            prefix="/api/stats"
        )
        ```
    """
    if not FASTAPI_AVAILABLE:
        raise APIGenerationError(
            "FastAPI is not installed. Install it with: pip install fastapi uvicorn"
        )

    app = FastAPI(title=title, description=description, version=version)

    for func in functions:
        try:
            _register_function_as_endpoint(app, func, prefix)
        except Exception as exc:
            # Honour the documented contract: inspection and model-building
            # failures (unresolvable type hints, unsupported signatures, ...)
            # surface as APIGenerationError, with the original error chained.
            name = getattr(func, "__name__", repr(func))
            raise APIGenerationError(
                f"Failed to create endpoint for {name}: {exc}"
            ) from exc

    return app
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def create_api_from_module(
    module: Any,
    title: str | None = None,
    description: str | None = None,
    version: str = "1.0.0",
    prefix: str = "",
    include_private: bool = False,
) -> Any:
    """Create a FastAPI application from all functions in a module.

    A function is exposed when it is defined in ``module`` itself, or when its
    name is listed in ``module.__all__``. The second rule lets a package whose
    ``__init__`` re-exports functions from submodules expose them, even though
    their ``__module__`` is the submodule rather than the package.

    Args:
        module: Python module containing functions to expose
        title: API title (defaults to module name)
        description: API description (defaults to module docstring)
        version: API version
        prefix: URL prefix for all endpoints
        include_private: Whether to include private functions (starting with _)

    Returns:
        FastAPI application instance

    Raises:
        APIGenerationError: If FastAPI is not installed

    Example:
        ```python
        from themis.evaluation import statistics

        app = create_api_from_module(
            module=statistics,
            prefix="/api/stats"
        )
        ```
    """
    if not FASTAPI_AVAILABLE:
        raise APIGenerationError(
            "FastAPI is not installed. Install it with: pip install fastapi uvicorn"
        )

    # Extract module metadata
    if title is None:
        title = f"{module.__name__} API"

    if description is None:
        description = inspect.getdoc(module) or f"API for {module.__name__}"

    # Names the module explicitly declares as its public API.
    exported = set(getattr(module, "__all__", None) or ())

    functions = []
    for name, obj in inspect.getmembers(module, inspect.isfunction):
        # Skip private functions unless explicitly included
        if not include_private and name.startswith("_"):
            continue

        # Own functions as before, plus re-exports declared in ``__all__``.
        if obj.__module__ == module.__name__ or name in exported:
            functions.append(obj)

    return create_api_from_functions(
        functions=functions,
        title=title,
        description=description,
        version=version,
        prefix=prefix,
    )
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _register_function_as_endpoint(
    app: Any,
    func: Callable,
    prefix: str = "",
) -> None:
    """Register a single function as a POST endpoint in the FastAPI app.

    The endpoint path is ``{prefix}/{func.__name__}``. The request body is a
    pydantic model built from the function's parameters, and the response is
    ``{"result": <return value>}``. An exception raised by ``func`` becomes an
    HTTP 500 whose detail is the exception message. An ``HTTPException``
    raised by ``func`` is passed through unchanged.

    Args:
        app: FastAPI application instance
        func: Function to register
        prefix: URL prefix for the endpoint
    """
    func_name = func.__name__
    endpoint_path = f"{prefix}/{func_name}".replace("//", "/")

    # Get function signature and type hints
    sig = inspect.signature(func)
    type_hints = get_type_hints(func)

    # Docstring drives both the endpoint description and the field descriptions.
    docstring = inspect.getdoc(func) or f"Execute {func_name}"
    param_docs = _parse_docstring_params(docstring)

    # Build Pydantic model for request body
    request_model = _create_request_model(func_name, sig, type_hints, param_docs)

    # Plain ``def`` so FastAPI runs the synchronous (possibly CPU-heavy)
    # function in its threadpool instead of blocking the event loop.
    def endpoint(request):
        # ``dict(model)`` is a shallow field -> value mapping, so parameters
        # typed as models or dataclasses receive the validated objects rather
        # than the plain dicts that ``model.dict()`` would produce.
        params = dict(request)
        try:
            result = func(**params)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
        return {"result": result}

    # This module uses ``from __future__ import annotations``, so an inline
    # ``request: request_model`` annotation would be stored as the string
    # "request_model", which FastAPI evaluates against module globals where
    # this local model does not exist. Attach the real class instead.
    endpoint.__annotations__ = {"request": request_model}
    endpoint.__name__ = f"endpoint_{func_name}"
    endpoint.__doc__ = docstring

    # Register the endpoint
    app.post(
        endpoint_path,
        response_model=Dict[str, Any],
        summary=f"Execute {func_name}",
        description=docstring,
    )(endpoint)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def _create_request_model(
    func_name: str,
    sig: inspect.Signature,
    type_hints: Dict[str, type],
    param_docs: Dict[str, str],
) -> type:
    """Create a Pydantic model for function parameters.

    Each named parameter becomes a field that is:

    - typed by its hint (``Any`` when unannotated);
    - required unless the parameter has a default;
    - described by its docstring entry, so the text shows up in the OpenAPI
      schema.

    Some parameters are skipped:

    - ``self``/``cls``;
    - variadic ``*args``/``**kwargs``, because a JSON object cannot be
      splatted into them by keyword.

    Args:
        func_name: Function name (used for model name)
        sig: Function signature
        type_hints: Type hints dictionary
        param_docs: Parameter documentation from docstring

    Returns:
        Pydantic model class
    """
    # Only reached after the module-level FastAPI/pydantic import succeeded.
    from pydantic import Field

    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    fields = {}

    for param_name, param in sig.parameters.items():
        if param_name in ("self", "cls") or param.kind in variadic:
            continue

        param_type = type_hints.get(param_name, Any)

        # ``...`` marks a required field in pydantic.
        default = ... if param.default is inspect.Parameter.empty else param.default

        description = param_docs.get(param_name) or None
        fields[param_name] = (param_type, Field(default, description=description))

    model_name = f"{func_name.title().replace('_', '')}Request"

    return create_model(model_name, **fields)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def _parse_docstring_params(docstring: str) -> Dict[str, str]:
    """Parse parameter descriptions from a Google-style docstring.

    Reads the ``Args:`` (or ``Arguments:``) section of an already-dedented
    docstring, such as one returned by ``inspect.getdoc``. Parsing rules:

    - Entries are recognised only at the section's first indentation level,
      so a wrapped description line containing a colon is treated as a
      continuation, not a new parameter.
    - A parenthesised type (``x (int): ...``) and leading ``*``/``**`` are
      stripped from the name.
    - The section ends at the next unindented line.

    Args:
        docstring: Function docstring

    Returns:
        Dictionary mapping parameter names to descriptions (continuation lines
        joined with single spaces); empty when there is no Args section.
    """
    param_docs: Dict[str, str] = {}
    if not docstring:
        return param_docs

    in_args_section = False
    entry_indent: int | None = None  # indentation of parameter entry lines
    current_param: str | None = None
    current_desc: List[str] = []

    for line in docstring.split("\n"):
        stripped = line.strip()

        if not in_args_section:
            if stripped.lower().startswith(("args:", "arguments:")):
                in_args_section = True
            continue

        if not stripped:
            continue

        indent = len(line) - len(line.lstrip())
        if indent == 0:
            # An unindented line starts the next section (e.g. "Returns:").
            break
        if entry_indent is None:
            entry_indent = indent

        head, sep, tail = stripped.partition(":")
        name = head.split("(", 1)[0].strip().lstrip("*")
        if sep and indent <= entry_indent and name.isidentifier():
            if current_param:
                param_docs[current_param] = " ".join(current_desc).strip()
            current_param = name
            tail = tail.strip()
            current_desc = [tail] if tail else []
        elif current_param:
            # Continuation of the previous parameter's description.
            current_desc.append(stripped)

    if current_param:
        param_docs[current_param] = " ".join(current_desc).strip()

    return param_docs
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def generate_api_documentation(
    app: Any,
    output_path: str = "api_docs.md",
) -> None:
    """Generate markdown documentation for a FastAPI application.

    Writes a UTF-8 Markdown file with the app's title, description and
    version, followed by one section per route that accepts POST. Each
    section includes the route's description when it has one.

    Args:
        app: FastAPI application instance
        output_path: Path to output markdown file

    Raises:
        APIGenerationError: If FastAPI is not installed
    """
    if not FASTAPI_AVAILABLE:
        raise APIGenerationError("FastAPI is not installed")

    lines = [
        f"# {app.title}",
        "",
        app.description or "",
        "",
        f"**Version:** {app.version}",
        "",
        "## Endpoints",
        "",
    ]

    for route in app.routes:
        # Not every route type has ``methods`` (mounts, websockets), and it may be None.
        methods = getattr(route, "methods", None) or ()
        if "POST" not in methods:
            continue
        lines.append(f"### `POST {route.path}`")
        lines.append("")
        # Not every route type carries a ``description`` attribute.
        route_description = getattr(route, "description", None)
        if route_description:
            lines.append(route_description)
            lines.append("")

    # Explicit UTF-8: docstrings may contain non-ASCII text that the platform
    # default encoding cannot represent.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
__all__ = [
    "create_api_from_functions",
    "create_api_from_module",
    "generate_api_documentation",
    "APIGenerationError",
]


# Example usage
if __name__ == "__main__":
    import sys

    # Check if FastAPI is available
    if not FASTAPI_AVAILABLE:
        print("FastAPI is not installed. Install with: pip install fastapi uvicorn")
        # ``sys.exit`` rather than the ``exit`` helper, which is injected by the
        # ``site`` module and is missing when Python runs with ``-S``.
        sys.exit(1)

    # Print usage instructions for this utility.
    print("Example: Creating API from functions...")
    print("To use this utility:")
    print("1. Install FastAPI: pip install fastapi uvicorn")
    print("2. Create an API:")
    print(" from themis.utils.api_generator import create_api_from_module")
    print(" from themis.evaluation import statistics")
    print(" app = create_api_from_module(statistics, prefix='/api/stats')")
    print("3. Run the server:")
    print(" uvicorn your_module:app --reload")
|