themis-eval 0.1.1__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +429 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/main.py +427 -57
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/core/entities.py +23 -3
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/pipelines/standard_pipeline.py +68 -8
- themis/experiment/cache_manager.py +8 -3
- themis/experiment/export.py +110 -2
- themis/experiment/orchestrator.py +109 -11
- themis/experiment/storage.py +1457 -110
- themis/generation/providers/litellm_provider.py +46 -0
- themis/generation/runner.py +22 -6
- themis/integrations/huggingface.py +12 -1
- themis/integrations/wandb.py +13 -1
- themis/interfaces/__init__.py +86 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis_eval-0.2.1.dist-info/METADATA +596 -0
- {themis_eval-0.1.1.dist-info → themis_eval-0.2.1.dist-info}/RECORD +42 -19
- {themis_eval-0.1.1.dist-info → themis_eval-0.2.1.dist-info}/WHEEL +1 -1
- themis_eval-0.1.1.dist-info/METADATA +0 -758
- {themis_eval-0.1.1.dist-info → themis_eval-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.1.dist-info → themis_eval-0.2.1.dist-info}/top_level.txt +0 -0
themis/presets/models.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""Model name parsing and provider detection.
|
|
2
|
+
|
|
3
|
+
This module automatically detects the appropriate provider based on
|
|
4
|
+
model names, eliminating the need for users to specify providers manually.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import re
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def parse_model_name(model: str, **kwargs: Any) -> tuple[str, str, dict[str, Any]]:
    """Resolve which provider should serve ``model``.

    Every recognised hosted family (OpenAI, Anthropic, Google, Meta, Mistral,
    Cohere, AI21) and every unrecognised name is routed through LiteLLM. The
    only exception is a name containing ``"fake"`` that matches none of the
    known family markers: it selects the ``"fake"`` provider intended for
    testing.

    Args:
        model: Model identifier (e.g. "gpt-4", "claude-3-opus", "llama-2-70b").
            Matching is case-insensitive; the identifier is returned unchanged.
        **kwargs: Extra options. Only the keys recognised by
            ``_extract_provider_options`` are forwarded to LiteLLM; the fake
            provider ignores them entirely.

    Returns:
        A ``(provider_name, model_id, provider_options)`` tuple.

    Examples:
        >>> parse_model_name("gpt-4")
        ('litellm', 'gpt-4', {})
        >>> parse_model_name("claude-3-opus-20240229")
        ('litellm', 'claude-3-opus-20240229', {})
        >>> parse_model_name("local-llm", base_url="http://localhost:1234/v1")
        ('litellm', 'local-llm', {'base_url': 'http://localhost:1234/v1'})
    """
    lowered = model.lower()

    # Substring markers of the model families that take precedence over the
    # fake provider. Their relative order is irrelevant: all route to LiteLLM.
    family_markers = (
        "gpt-", "o1-", "text-davinci",  # OpenAI
        "claude",  # Anthropic
        "gemini", "palm",  # Google
        "llama",  # Meta
        "mistral", "mixtral",  # Mistral
        "j2-",  # AI21
    )
    # Cohere is recognised only when both markers are present.
    looks_like_cohere = "command" in lowered and "xl" in lowered
    is_known_family = looks_like_cohere or any(
        marker in lowered for marker in family_markers
    )

    if "fake" in lowered and not is_known_family:
        return "fake", model, {}

    # Known families and the default case alike go through LiteLLM; callers
    # can pass base_url / api_base to target custom endpoints.
    return "litellm", model, _extract_provider_options(kwargs)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _extract_provider_options(kwargs: dict[str, Any]) -> dict[str, Any]:
|
|
73
|
+
"""Extract provider-specific options from kwargs.
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
kwargs: Dictionary of options
|
|
77
|
+
|
|
78
|
+
Returns:
|
|
79
|
+
Dictionary of provider options
|
|
80
|
+
"""
|
|
81
|
+
provider_options = {}
|
|
82
|
+
|
|
83
|
+
# Known provider options
|
|
84
|
+
option_keys = [
|
|
85
|
+
"api_key",
|
|
86
|
+
"base_url",
|
|
87
|
+
"api_base",
|
|
88
|
+
"api_version",
|
|
89
|
+
"timeout",
|
|
90
|
+
"max_retries",
|
|
91
|
+
"n_parallel",
|
|
92
|
+
"organization",
|
|
93
|
+
"api_type",
|
|
94
|
+
"region_name",
|
|
95
|
+
]
|
|
96
|
+
|
|
97
|
+
for key in option_keys:
|
|
98
|
+
if key in kwargs:
|
|
99
|
+
provider_options[key] = kwargs[key]
|
|
100
|
+
|
|
101
|
+
return provider_options
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def get_provider_for_model(model: str) -> str:
    """Return just the provider name that ``parse_model_name`` picks for ``model``.

    No provider options are supplied, so this is a pure name lookup.

    Args:
        model: Model identifier.

    Returns:
        Provider name (e.g. ``"litellm"`` or ``"fake"``).

    Examples:
        >>> get_provider_for_model("gpt-4")
        'litellm'
        >>> get_provider_for_model("claude-3-opus")
        'litellm'
    """
    return parse_model_name(model)[0]
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
# Model family detection for preset selection
|
|
125
|
+
def get_model_family(model: str) -> str:
    """Classify ``model`` into a coarse family name for capability detection.

    Matching is case-insensitive and the rules are checked in a fixed order;
    the first matching rule wins (e.g. "gpt-4" is tested before "o1").

    Args:
        model: Model identifier, optionally provider-prefixed
            (e.g. "openai/o1-mini").

    Returns:
        The family name, or ``"unknown"`` when no rule matches.

    Examples:
        >>> get_model_family("gpt-4-turbo")
        'gpt-4'
        >>> get_model_family("claude-3-opus-20240229")
        'claude-3-opus'
        >>> get_model_family("o1-preview")
        'o1'
    """
    lowered = model.lower()

    # OpenAI families
    if "gpt-4" in lowered:
        return "gpt-4"
    if "gpt-3.5" in lowered:
        return "gpt-3.5"
    # Match "o1" only as a standalone token (string start/end or next to a
    # non-alphanumeric separator such as "-", "/", "_", "."). A plain
    # substring test misclassified names like "llama-2-7b-zero1" as "o1",
    # because this rule runs before the Claude/Gemini/Llama/Mistral rules.
    if re.search(r"(?<![a-z0-9])o1(?![a-z0-9])", lowered):
        return "o1"

    # Anthropic families
    if "claude-3" in lowered:
        if "opus" in lowered:
            return "claude-3-opus"
        if "sonnet" in lowered:
            return "claude-3-sonnet"
        if "haiku" in lowered:
            return "claude-3-haiku"
        return "claude-3"
    if "claude-2" in lowered:
        return "claude-2"

    # Google families
    if "gemini-pro" in lowered:
        return "gemini-pro"
    if "gemini-ultra" in lowered:
        return "gemini-ultra"

    # Meta families
    if "llama-2" in lowered:
        if "70b" in lowered:
            return "llama-2-70b"
        if "13b" in lowered:
            return "llama-2-13b"
        if "7b" in lowered:
            return "llama-2-7b"
        return "llama-2"
    if "llama-3" in lowered:
        return "llama-3"

    # Mistral families
    if "mixtral" in lowered:
        return "mixtral"
    if "mistral" in lowered:
        return "mistral"

    return "unknown"
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
# Public names of this module (what ``from ... import *`` exposes).
__all__ = ["parse_model_name", "get_provider_for_model", "get_model_family"]
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""FastAPI server for Themis web dashboard.
|
|
2
|
+
|
|
3
|
+
This module provides a REST API and WebSocket interface for:
|
|
4
|
+
- Listing and viewing experiment runs
|
|
5
|
+
- Comparing multiple runs
|
|
6
|
+
- Real-time monitoring of running experiments
|
|
7
|
+
- Exporting results in various formats
|
|
8
|
+
|
|
9
|
+
The server is optional and requires the 'server' extra:
|
|
10
|
+
pip install themis-eval[server]
|
|
11
|
+
# or
|
|
12
|
+
uv pip install themis-eval[server]
|
|
13
|
+
|
|
14
|
+
Usage:
|
|
15
|
+
# Start the server
|
|
16
|
+
themis serve --port 8080
|
|
17
|
+
|
|
18
|
+
# Or programmatically
|
|
19
|
+
from themis.server import create_app
|
|
20
|
+
app = create_app(storage_path=".cache/experiments")
|
|
21
|
+
|
|
22
|
+
# Run with uvicorn
|
|
23
|
+
uvicorn themis.server:create_app --factory --host 0.0.0.0 --port 8080
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from themis.server.app import create_app
|
|
27
|
+
|
|
28
|
+
# Public API of the server package: the application factory.
__all__ = ["create_app"]
|
themis/server/app.py
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
"""FastAPI application for Themis server.
|
|
2
|
+
|
|
3
|
+
This module defines the main FastAPI app with REST endpoints and WebSocket support.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Dict, List
|
|
11
|
+
|
|
12
|
+
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
|
13
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
14
|
+
from fastapi.responses import JSONResponse
|
|
15
|
+
from fastapi.staticfiles import StaticFiles
|
|
16
|
+
from pydantic import BaseModel, Field
|
|
17
|
+
|
|
18
|
+
from themis.comparison import compare_runs
|
|
19
|
+
from themis.comparison.statistics import StatisticalTest
|
|
20
|
+
from themis.experiment.storage import ExperimentStorage
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class RunSummary(BaseModel):
    """Summary of an experiment run."""

    # Identifier of the run in experiment storage.
    run_id: str
    experiment_id: str = "default"
    # Free-form status label.
    status: str
    # Number of evaluation records in the run.
    num_samples: int = 0
    # Metric name -> aggregated score.
    metrics: dict[str, float] = Field(default_factory=dict)
    # Creation timestamp as a string, when known.
    created_at: str | None = None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class RunDetail(BaseModel):
    """Detailed information about a run."""

    # Identifier of the run in experiment storage.
    run_id: str
    experiment_id: str = "default"
    # Free-form status label.
    status: str
    # Number of evaluation records in the run (required, unlike RunSummary).
    num_samples: int
    # Metric name -> aggregated score (required).
    metrics: dict[str, float]
    # Per-sample payloads (free-form dicts).
    samples: list[dict[str, Any]] = Field(default_factory=list)
    # Extra free-form run metadata.
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ComparisonRequest(BaseModel):
    """Request to compare multiple runs."""

    # IDs of the runs to compare (the /api/compare endpoint rejects fewer than two).
    run_ids: List[str]
    # Optional subset of metric names; passed through to compare_runs unchanged.
    metrics: List[str] | None = None
    # Value of a StatisticalTest member; the endpoint validates it.
    statistical_test: str = "bootstrap"
    # Significance level. Values outside the open interval (0, 1) make every
    # (or no) difference "significant", so they are rejected at validation time.
    alpha: float = Field(0.05, gt=0.0, lt=1.0)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class ErrorResponse(BaseModel):
    """Error response model."""

    # NOTE(review): not referenced by the endpoints defined in create_app,
    # which raise HTTPException instead; kept as a public schema type.

    # Short error summary.
    error: str
    # Optional longer explanation.
    detail: str | None = None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _record_scores(record: Any) -> Dict[str, Any]:
    """Return ``{metric_name: value}`` for one evaluation record.

    A score object exposing ``.value`` contributes that attribute as-is; a bare
    ``int``/``float`` score is converted to ``float``. Anything else, including
    a ``.value`` of ``None``, is skipped so it never reaches the averages.
    """
    values: Dict[str, Any] = {}
    for metric_name, score_obj in record.scores.items():
        if hasattr(score_obj, "value"):
            value = score_obj.value
        elif isinstance(score_obj, (int, float)):
            value = float(score_obj)
        else:
            value = None
        if value is not None:
            values[metric_name] = value
    return values


def _average_metrics(score_maps: List[Dict[str, Any]]) -> Dict[str, float]:
    """Average each metric over the records that reported a value for it.

    Metrics appear in order of first occurrence across ``score_maps``.
    """
    collected: Dict[str, List[Any]] = {}
    for scores in score_maps:
        for name, value in scores.items():
            collected.setdefault(name, []).append(value)
    return {name: sum(vals) / len(vals) for name, vals in collected.items()}


class _ConnectionManager:
    """Track open WebSocket connections and fan messages out to them."""

    def __init__(self) -> None:
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket) -> None:
        # Compare by identity and tolerate sockets that were already removed
        # (e.g. pruned by broadcast), instead of raising like list.remove.
        self.active_connections = [
            conn for conn in self.active_connections if conn is not websocket
        ]

    async def broadcast(self, message: dict) -> None:
        # Iterate a snapshot: each await yields to the event loop, during
        # which other handlers may connect or disconnect.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception:
                # Best-effort delivery: a dead client must not stop the
                # broadcast to everyone else; drop it from the registry.
                self.disconnect(connection)


def create_app(storage_path: str | Path = ".cache/experiments") -> FastAPI:
    """Create the Themis FastAPI application.

    Args:
        storage_path: Directory handed to ``ExperimentStorage``.

    Returns:
        A FastAPI app exposing a health check, run listing/detail/deletion,
        run comparison, benchmark listing and the ``/ws`` WebSocket endpoint.
        If a ``static`` directory exists next to this module it is also served
        under ``/dashboard``.
    """
    api_version = "2.0.0"

    app = FastAPI(
        title="Themis API",
        description="REST API for Themis experiment management",
        version=api_version,
    )

    # Enable CORS for web dashboard.
    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # very permissive; restrict allow_origins before exposing this server
    # beyond a trusted network.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],  # Configure appropriately for production
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    storage = ExperimentStorage(storage_path)

    # Mount static files (dashboard) only when they were shipped.
    static_dir = Path(__file__).parent / "static"
    if static_dir.exists():
        app.mount("/dashboard", StaticFiles(directory=str(static_dir), html=True), name="static")

    manager = _ConnectionManager()

    # NOTE(review): these handlers are ``async def`` but call the storage layer
    # synchronously, blocking the event loop while it works. Plain ``def``
    # handlers would run in FastAPI's threadpool — switch only after
    # confirming ExperimentStorage is safe to use from worker threads.

    # ===== REST ENDPOINTS =====

    @app.get("/", tags=["health"])
    async def root():
        """Health check endpoint."""
        return {
            "status": "ok",
            "service": "themis-api",
            "version": api_version,
        }

    @app.get("/api/runs", response_model=List[RunSummary], tags=["runs"])
    async def list_runs():
        """List all experiment runs."""
        summaries = []
        for run_id in storage.list_runs():
            eval_records = storage.load_cached_evaluations(run_id)
            score_maps = [_record_scores(record) for record in eval_records.values()]
            summaries.append(
                RunSummary(
                    run_id=run_id,
                    experiment_id="default",
                    status="completed",
                    num_samples=len(eval_records),
                    metrics=_average_metrics(score_maps),
                )
            )
        return summaries

    @app.get("/api/runs/{run_id}", response_model=RunDetail, tags=["runs"])
    async def get_run(run_id: str):
        """Get detailed information about a run."""
        if run_id not in storage.list_runs():
            raise HTTPException(status_code=404, detail=f"Run not found: {run_id}")

        eval_records = storage.load_cached_evaluations(run_id)
        gen_records_dict = storage.load_cached_records(run_id)

        samples = []
        score_maps = []
        for cache_key, eval_record in eval_records.items():
            # Generation records share the evaluation cache key; either may be
            # missing, in which case the sample falls back to placeholders.
            gen_record = gen_records_dict.get(cache_key)
            scores = _record_scores(eval_record)
            score_maps.append(scores)
            samples.append({
                "id": gen_record.id if gen_record else cache_key,
                "prompt": gen_record.prompt if gen_record else "",
                "response": gen_record.response if gen_record else "",
                "scores": scores,
            })

        return RunDetail(
            run_id=run_id,
            experiment_id="default",
            status="completed",
            num_samples=len(eval_records),
            metrics=_average_metrics(score_maps),
            samples=samples,
        )

    @app.delete("/api/runs/{run_id}", tags=["runs"])
    async def delete_run(run_id: str):
        """Delete a run."""
        if run_id not in storage.list_runs():
            raise HTTPException(status_code=404, detail=f"Run not found: {run_id}")

        # The current storage has no delete operation, so existing runs are
        # answered with 501 until one is implemented.
        raise HTTPException(
            status_code=501,
            detail="Delete not implemented in current storage"
        )

    @app.post("/api/compare", tags=["comparison"])
    async def compare_runs_api(request: ComparisonRequest):
        """Compare multiple runs."""
        existing_runs = set(storage.list_runs())
        for run_id in request.run_ids:
            if run_id not in existing_runs:
                raise HTTPException(
                    status_code=404,
                    detail=f"Run not found: {run_id}"
                )

        if len(request.run_ids) < 2:
            raise HTTPException(
                status_code=400,
                detail="Need at least 2 runs to compare"
            )

        try:
            test_enum = StatisticalTest(request.statistical_test)
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid statistical test: {request.statistical_test}"
            ) from None

        # NOTE(review): relies on the private ``_base_dir`` attribute of
        # ExperimentStorage; a public accessor would be less brittle.
        report = compare_runs(
            run_ids=request.run_ids,
            storage_path=storage._base_dir,
            metrics=request.metrics,
            statistical_test=test_enum,
            alpha=request.alpha,
        )

        return report.to_dict()

    @app.get("/api/benchmarks", tags=["presets"])
    async def list_benchmarks():
        """List available benchmark presets."""
        # Imported lazily; the local name shadows this handler only inside it.
        from themis.presets import list_benchmarks

        benchmarks = list_benchmarks()
        return {"benchmarks": benchmarks}

    # ===== WEBSOCKET ENDPOINTS =====

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """WebSocket endpoint for real-time updates.

        Messages sent from server:
        - {"type": "run_started", "run_id": "...", "data": {...}}
        - {"type": "run_progress", "run_id": "...", "progress": 0.5}
        - {"type": "run_completed", "run_id": "...", "data": {...}}
        - {"type": "error", "message": "..."}

        Messages expected from client:
        - {"type": "subscribe", "run_id": "..."}
        - {"type": "unsubscribe", "run_id": "..."}
        - {"type": "ping"}
        """
        # NOTE: subscriptions are only acknowledged for now, and nothing in
        # this module emits run_* events yet.
        await manager.connect(websocket)

        try:
            while True:
                data = await websocket.receive_text()

                # Malformed input is reported to the client instead of killing
                # the handler (which previously leaked the registered socket).
                try:
                    message = json.loads(data)
                except json.JSONDecodeError:
                    await websocket.send_json({
                        "type": "error",
                        "message": "Invalid JSON message"
                    })
                    continue
                if not isinstance(message, dict):
                    await websocket.send_json({
                        "type": "error",
                        "message": "Message must be a JSON object"
                    })
                    continue

                msg_type = message.get("type")

                if msg_type == "ping":
                    await websocket.send_json({"type": "pong"})

                elif msg_type == "subscribe":
                    # TODO: Implement run subscription logic
                    await websocket.send_json({
                        "type": "subscribed",
                        "run_id": message.get("run_id")
                    })

                elif msg_type == "unsubscribe":
                    # TODO: Implement unsubscribe logic
                    await websocket.send_json({
                        "type": "unsubscribed",
                        "run_id": message.get("run_id")
                    })

                else:
                    await websocket.send_json({
                        "type": "error",
                        "message": f"Unknown message type: {msg_type}"
                    })

        except WebSocketDisconnect:
            pass  # Normal client disconnect; cleanup happens below.
        finally:
            # Always unregister, whatever ended the loop.
            manager.disconnect(websocket)

    return app
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
# Public API of this module: only the application factory.
__all__ = ["create_app"]
|