themis-eval 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. themis/__init__.py +12 -1
  2. themis/_version.py +2 -2
  3. themis/api.py +343 -0
  4. themis/backends/__init__.py +17 -0
  5. themis/backends/execution.py +197 -0
  6. themis/backends/storage.py +260 -0
  7. themis/cli/commands/results.py +252 -0
  8. themis/cli/main.py +427 -57
  9. themis/comparison/__init__.py +25 -0
  10. themis/comparison/engine.py +348 -0
  11. themis/comparison/reports.py +283 -0
  12. themis/comparison/statistics.py +402 -0
  13. themis/core/entities.py +23 -3
  14. themis/evaluation/metrics/code/__init__.py +19 -0
  15. themis/evaluation/metrics/code/codebleu.py +144 -0
  16. themis/evaluation/metrics/code/execution.py +280 -0
  17. themis/evaluation/metrics/code/pass_at_k.py +181 -0
  18. themis/evaluation/metrics/nlp/__init__.py +21 -0
  19. themis/evaluation/metrics/nlp/bertscore.py +138 -0
  20. themis/evaluation/metrics/nlp/bleu.py +129 -0
  21. themis/evaluation/metrics/nlp/meteor.py +153 -0
  22. themis/evaluation/metrics/nlp/rouge.py +136 -0
  23. themis/evaluation/pipelines/standard_pipeline.py +68 -8
  24. themis/experiment/cache_manager.py +8 -3
  25. themis/experiment/export.py +110 -2
  26. themis/experiment/orchestrator.py +48 -6
  27. themis/experiment/storage.py +1313 -110
  28. themis/integrations/huggingface.py +12 -1
  29. themis/integrations/wandb.py +13 -1
  30. themis/interfaces/__init__.py +86 -0
  31. themis/presets/__init__.py +10 -0
  32. themis/presets/benchmarks.py +354 -0
  33. themis/presets/models.py +190 -0
  34. themis/server/__init__.py +28 -0
  35. themis/server/app.py +337 -0
  36. themis_eval-0.2.0.dist-info/METADATA +596 -0
  37. {themis_eval-0.1.1.dist-info → themis_eval-0.2.0.dist-info}/RECORD +40 -17
  38. {themis_eval-0.1.1.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
  39. themis_eval-0.1.1.dist-info/METADATA +0 -758
  40. {themis_eval-0.1.1.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
  41. {themis_eval-0.1.1.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,190 @@
1
+ """Model name parsing and provider detection.
2
+
3
+ This module automatically detects the appropriate provider based on
4
+ model names, eliminating the need for users to specify providers manually.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import re
10
+ from typing import Any
11
+
12
+
13
+ def parse_model_name(model: str, **kwargs: Any) -> tuple[str, str, dict[str, Any]]:
14
+ """Parse model name and detect provider.
15
+
16
+ Args:
17
+ model: Model identifier (e.g., "gpt-4", "claude-3-opus", "llama-2-70b")
18
+ **kwargs: Additional provider-specific options
19
+
20
+ Returns:
21
+ Tuple of (provider_name, model_id, provider_options)
22
+
23
+ Examples:
24
+ >>> parse_model_name("gpt-4")
25
+ ("litellm", "gpt-4", {})
26
+
27
+ >>> parse_model_name("claude-3-opus-20240229")
28
+ ("litellm", "claude-3-opus-20240229", {})
29
+
30
+ >>> parse_model_name("local-llm", base_url="http://localhost:1234/v1")
31
+ ("litellm", "local-llm", {"base_url": "http://localhost:1234/v1"})
32
+ """
33
+ model_lower = model.lower()
34
+
35
+ # OpenAI models
36
+ if any(pattern in model_lower for pattern in ["gpt-", "o1-", "text-davinci"]):
37
+ return "litellm", model, _extract_provider_options(kwargs)
38
+
39
+ # Anthropic models
40
+ if "claude" in model_lower:
41
+ return "litellm", model, _extract_provider_options(kwargs)
42
+
43
+ # Google models
44
+ if any(pattern in model_lower for pattern in ["gemini", "palm"]):
45
+ return "litellm", model, _extract_provider_options(kwargs)
46
+
47
+ # Meta models
48
+ if "llama" in model_lower:
49
+ return "litellm", model, _extract_provider_options(kwargs)
50
+
51
+ # Mistral models
52
+ if "mistral" in model_lower or "mixtral" in model_lower:
53
+ return "litellm", model, _extract_provider_options(kwargs)
54
+
55
+ # Cohere models
56
+ if "command" in model_lower and "xl" in model_lower:
57
+ return "litellm", model, _extract_provider_options(kwargs)
58
+
59
+ # AI21 models
60
+ if "j2-" in model_lower:
61
+ return "litellm", model, _extract_provider_options(kwargs)
62
+
63
+ # Fake model for testing
64
+ if "fake" in model_lower:
65
+ return "fake", model, {}
66
+
67
+ # Default: assume it's a litellm-compatible model
68
+ # User can provide base_url for custom endpoints
69
+ return "litellm", model, _extract_provider_options(kwargs)
70
+
71
+
72
+ def _extract_provider_options(kwargs: dict[str, Any]) -> dict[str, Any]:
73
+ """Extract provider-specific options from kwargs.
74
+
75
+ Args:
76
+ kwargs: Dictionary of options
77
+
78
+ Returns:
79
+ Dictionary of provider options
80
+ """
81
+ provider_options = {}
82
+
83
+ # Known provider options
84
+ option_keys = [
85
+ "api_key",
86
+ "base_url",
87
+ "api_base",
88
+ "api_version",
89
+ "timeout",
90
+ "max_retries",
91
+ "n_parallel",
92
+ "organization",
93
+ "api_type",
94
+ "region_name",
95
+ ]
96
+
97
+ for key in option_keys:
98
+ if key in kwargs:
99
+ provider_options[key] = kwargs[key]
100
+
101
+ return provider_options
102
+
103
+
104
def get_provider_for_model(model: str) -> str:
    """Get provider name for a model (without parsing full options).

    Args:
        model: Model identifier

    Returns:
        Provider name

    Examples:
        >>> get_provider_for_model("gpt-4")
        "litellm"

        >>> get_provider_for_model("claude-3-opus")
        "litellm"
    """
    # Delegate to the full parser and keep only the provider component.
    return parse_model_name(model)[0]
122
+
123
+
124
# Model family detection for preset selection
def get_model_family(model: str) -> str:
    """Get the model family for capability detection.

    Args:
        model: Model identifier

    Returns:
        Model family name

    Examples:
        >>> get_model_family("gpt-4-turbo")
        "gpt-4"

        >>> get_model_family("claude-3-opus-20240229")
        "claude-3-opus"
    """
    name = model.lower()

    # (marker, base_family, variant_markers) in priority order -- the same
    # order as the original if-chain.  When a base marker matches, the
    # first matching variant marker refines the family name; otherwise the
    # base family is returned as-is.
    families: tuple[tuple[str, str, tuple[str, ...]], ...] = (
        ("gpt-4", "gpt-4", ()),
        ("gpt-3.5", "gpt-3.5", ()),
        ("o1", "o1", ()),
        ("claude-3", "claude-3", ("opus", "sonnet", "haiku")),
        ("claude-2", "claude-2", ()),
        ("gemini-pro", "gemini-pro", ()),
        ("gemini-ultra", "gemini-ultra", ()),
        ("llama-2", "llama-2", ("70b", "13b", "7b")),
        ("llama-3", "llama-3", ()),
        ("mixtral", "mixtral", ()),
        ("mistral", "mistral", ()),
    )
    for marker, base, variants in families:
        if marker not in name:
            continue
        for variant in variants:
            if variant in name:
                return f"{base}-{variant}"
        return base

    return "unknown"
188
+
189
+
190
# Public API of this module.
__all__ = ["parse_model_name", "get_provider_for_model", "get_model_family"]
@@ -0,0 +1,28 @@
1
"""FastAPI server for Themis web dashboard.

This module provides a REST API and WebSocket interface for:
- Listing and viewing experiment runs
- Comparing multiple runs
- Real-time monitoring of running experiments
- Exporting results in various formats

The server is optional and requires the 'server' extra:
    pip install themis[server]
    # or
    uv pip install themis[server]

Usage:
    # Start the server
    themis serve --port 8080

    # Or programmatically
    from themis.server import create_app
    app = create_app(storage_path=".cache/experiments")

    # Run with uvicorn.  Note there is no module-level ``app`` attribute:
    # ``create_app`` is an application factory, so uvicorn must be told so.
    uvicorn "themis.server:create_app" --factory --host 0.0.0.0 --port 8080
"""

from themis.server.app import create_app

__all__ = ["create_app"]
themis/server/app.py ADDED
@@ -0,0 +1,337 @@
1
+ """FastAPI application for Themis server.
2
+
3
+ This module defines the main FastAPI app with REST endpoints and WebSocket support.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ from pathlib import Path
10
+ from typing import Any, Dict, List
11
+
12
+ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
13
+ from fastapi.middleware.cors import CORSMiddleware
14
+ from fastapi.responses import JSONResponse
15
+ from fastapi.staticfiles import StaticFiles
16
+ from pydantic import BaseModel, Field
17
+
18
+ from themis.comparison import compare_runs
19
+ from themis.comparison.statistics import StatisticalTest
20
+ from themis.experiment.storage import ExperimentStorage
21
+
22
+
23
class RunSummary(BaseModel):
    """Summary of an experiment run.

    Lightweight listing view returned by ``GET /api/runs``; per-sample
    data is exposed separately through the run-detail endpoint.
    """

    run_id: str  # storage identifier of the run
    experiment_id: str = "default"  # always "default" in current endpoints
    status: str  # always "completed" in current endpoints -- real state tracking TBD
    num_samples: int = 0  # count of cached evaluation records for the run
    metrics: Dict[str, float] = Field(default_factory=dict)  # metric name -> mean score across samples
    created_at: str | None = None  # not populated by current endpoints -- presumably an ISO timestamp; verify against caller
32
+
33
+
34
class RunDetail(BaseModel):
    """Detailed information about a run, including per-sample results.

    Returned by ``GET /api/runs/{run_id}``.
    """

    run_id: str  # storage identifier of the run
    experiment_id: str = "default"  # always "default" in current endpoints
    status: str  # always "completed" in current endpoints
    num_samples: int  # count of cached evaluation records
    metrics: Dict[str, float]  # metric name -> mean score across samples
    samples: List[Dict[str, Any]] = Field(default_factory=list)  # each has keys: "id", "prompt", "response", "scores"
    metadata: Dict[str, Any] = Field(default_factory=dict)  # reserved; not populated by current endpoints
44
+
45
+
46
class ComparisonRequest(BaseModel):
    """Request body for ``POST /api/compare``."""

    run_ids: List[str]  # at least 2 run ids; each must exist in storage (404 otherwise)
    metrics: List[str] | None = None  # restrict comparison to these metrics; None = all
    statistical_test: str = "bootstrap"  # must be a valid StatisticalTest enum value (400 otherwise)
    alpha: float = 0.05  # significance level forwarded to compare_runs
53
+
54
+
55
class ErrorResponse(BaseModel):
    """Error response model.

    NOTE(review): not referenced by any endpoint in this module -- errors
    are raised via HTTPException; presumably kept for OpenAPI/client
    codegen. Confirm before removing.
    """

    error: str  # short error label
    detail: str | None = None  # optional human-readable details
60
+
61
+
62
def create_app(storage_path: str | Path = ".cache/experiments") -> FastAPI:
    """Create FastAPI application.

    Args:
        storage_path: Directory used by ExperimentStorage to locate cached
            experiment runs.

    Returns:
        Configured FastAPI application exposing REST endpoints under
        ``/api`` and a WebSocket endpoint at ``/ws``.
    """

    def _score_value(score_obj: Any) -> Any:
        """Best-effort numeric value of a metric score.

        Cached evaluation records store scores either as objects exposing a
        ``value`` attribute or as bare numbers; anything else yields None
        and is skipped by callers.  Shared by list_runs and get_run, which
        previously duplicated this logic inline.
        """
        if hasattr(score_obj, 'value'):
            return score_obj.value
        if isinstance(score_obj, (int, float)):
            return float(score_obj)
        return None

    def _averages(metrics_dict: Dict[str, List[float]]) -> Dict[str, float]:
        """Per-metric arithmetic mean; an empty score list averages to 0.0."""
        return {
            name: sum(scores) / len(scores) if scores else 0.0
            for name, scores in metrics_dict.items()
        }

    app = FastAPI(
        title="Themis API",
        description="REST API for Themis experiment management",
        version="2.0.0",
    )

    # Enable CORS for web dashboard
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],  # Configure appropriately for production
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Storage backend all endpoints read from.
    storage = ExperimentStorage(storage_path)

    # Mount static files (dashboard) only when shipped with the package.
    static_dir = Path(__file__).parent / "static"
    if static_dir.exists():
        app.mount("/dashboard", StaticFiles(directory=str(static_dir), html=True), name="static")

    # WebSocket connection manager
    class ConnectionManager:
        """Tracks open WebSocket connections for server-push updates."""

        def __init__(self) -> None:
            self.active_connections: List[WebSocket] = []

        async def connect(self, websocket: WebSocket) -> None:
            await websocket.accept()
            self.active_connections.append(websocket)

        def disconnect(self, websocket: WebSocket) -> None:
            # Tolerate double-disconnect: list.remove would raise
            # ValueError if the socket was already dropped elsewhere.
            if websocket in self.active_connections:
                self.active_connections.remove(websocket)

        async def broadcast(self, message: dict) -> None:
            # Iterate a snapshot so a disconnect during an await cannot
            # mutate the list mid-iteration.
            for connection in list(self.active_connections):
                await connection.send_json(message)

    manager = ConnectionManager()

    # ===== REST ENDPOINTS =====

    @app.get("/", tags=["health"])
    async def root():
        """Health check endpoint."""
        return {
            "status": "ok",
            "service": "themis-api",
            "version": "2.0.0",
        }

    @app.get("/api/runs", response_model=List[RunSummary], tags=["runs"])
    async def list_runs():
        """List all experiment runs with their average metrics."""
        summaries = []
        for run_id in storage.list_runs():
            eval_records = storage.load_cached_evaluations(run_id)

            # Collect per-metric scores.  The bucket is created before the
            # numeric check so a metric whose scores are all non-numeric
            # still appears (averaging to 0.0) -- matches prior behavior.
            metrics_dict: Dict[str, List[float]] = {}
            for record in eval_records.values():
                for metric_name, score_obj in record.scores.items():
                    bucket = metrics_dict.setdefault(metric_name, [])
                    value = _score_value(score_obj)
                    if value is not None:
                        bucket.append(value)

            summaries.append(RunSummary(
                run_id=run_id,
                experiment_id="default",
                status="completed",
                num_samples=len(eval_records),
                metrics=_averages(metrics_dict),
            ))

        return summaries

    @app.get("/api/runs/{run_id}", response_model=RunDetail, tags=["runs"])
    async def get_run(run_id: str):
        """Get detailed information about a run, including per-sample data."""
        if run_id not in storage.list_runs():
            raise HTTPException(status_code=404, detail=f"Run not found: {run_id}")

        # Load records
        eval_records = storage.load_cached_evaluations(run_id)
        gen_records_dict = storage.load_cached_records(run_id)

        metrics_dict: Dict[str, List[float]] = {}
        samples = []

        for cache_key, eval_record in eval_records.items():
            gen_record = gen_records_dict.get(cache_key)

            # Numeric scores only.  Unlike list_runs, a metric with no
            # numeric score at all is omitted here (matches prior behavior).
            scores = {}
            for metric_name, score_obj in eval_record.scores.items():
                value = _score_value(score_obj)
                if value is None:
                    continue
                scores[metric_name] = value
                metrics_dict.setdefault(metric_name, []).append(value)

            samples.append({
                "id": gen_record.id if gen_record else cache_key,
                "prompt": gen_record.prompt if gen_record else "",
                "response": gen_record.response if gen_record else "",
                "scores": scores,
            })

        return RunDetail(
            run_id=run_id,
            experiment_id="default",
            status="completed",
            num_samples=len(eval_records),
            metrics=_averages(metrics_dict),
            samples=samples,
        )

    @app.delete("/api/runs/{run_id}", tags=["runs"])
    async def delete_run(run_id: str):
        """Delete a run (not yet supported by the storage backend)."""
        if run_id not in storage.list_runs():
            raise HTTPException(status_code=404, detail=f"Run not found: {run_id}")

        # Note: Current storage doesn't implement delete; 501 signals the
        # endpoint exists but the capability is missing.
        raise HTTPException(
            status_code=501,
            detail="Delete not implemented in current storage"
        )

    @app.post("/api/compare", tags=["comparison"])
    async def compare_runs_api(request: ComparisonRequest):
        """Compare multiple runs with a statistical significance test."""
        # Validate runs exist
        existing_runs = set(storage.list_runs())
        for run_id in request.run_ids:
            if run_id not in existing_runs:
                raise HTTPException(
                    status_code=404,
                    detail=f"Run not found: {run_id}"
                )

        if len(request.run_ids) < 2:
            raise HTTPException(
                status_code=400,
                detail="Need at least 2 runs to compare"
            )

        # Parse statistical test
        try:
            test_enum = StatisticalTest(request.statistical_test)
        except ValueError as exc:
            # Chain the cause (PEP 3134) so logs show the failed enum lookup.
            raise HTTPException(
                status_code=400,
                detail=f"Invalid statistical test: {request.statistical_test}"
            ) from exc

        # NOTE(review): reaches into a private attribute of storage;
        # consider a public accessor on ExperimentStorage.
        report = compare_runs(
            run_ids=request.run_ids,
            storage_path=storage._base_dir,
            metrics=request.metrics,
            statistical_test=test_enum,
            alpha=request.alpha,
        )

        return report.to_dict()

    @app.get("/api/benchmarks", tags=["presets"])
    async def list_benchmarks():
        """List available benchmark presets."""
        # Imported lazily so the presets module loads only on demand.
        from themis.presets import list_benchmarks

        return {"benchmarks": list_benchmarks()}

    # ===== WEBSOCKET ENDPOINTS =====

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """WebSocket endpoint for real-time updates.

        Messages sent from server:
        - {"type": "run_started", "run_id": "...", "data": {...}}
        - {"type": "run_progress", "run_id": "...", "progress": 0.5}
        - {"type": "run_completed", "run_id": "...", "data": {...}}
        - {"type": "error", "message": "..."}

        Messages expected from client:
        - {"type": "subscribe", "run_id": "..."}
        - {"type": "unsubscribe", "run_id": "..."}
        - {"type": "ping"}
        """
        await manager.connect(websocket)

        try:
            while True:
                # Receive message from client
                data = await websocket.receive_text()
                try:
                    message = json.loads(data)
                except json.JSONDecodeError:
                    # Malformed frame: report it rather than letting the
                    # exception tear down the whole connection.
                    await websocket.send_json({
                        "type": "error",
                        "message": "Invalid JSON payload"
                    })
                    continue

                msg_type = message.get("type")

                if msg_type == "ping":
                    await websocket.send_json({"type": "pong"})

                elif msg_type == "subscribe":
                    run_id = message.get("run_id")
                    # TODO: Implement run subscription logic
                    await websocket.send_json({
                        "type": "subscribed",
                        "run_id": run_id
                    })

                elif msg_type == "unsubscribe":
                    run_id = message.get("run_id")
                    # TODO: Implement unsubscribe logic
                    await websocket.send_json({
                        "type": "unsubscribed",
                        "run_id": run_id
                    })

                else:
                    await websocket.send_json({
                        "type": "error",
                        "message": f"Unknown message type: {msg_type}"
                    })

        except WebSocketDisconnect:
            manager.disconnect(websocket)

    return app
335
+
336
+
337
# Public API of this module; re-exported by themis.server.__init__.
__all__ = ["create_app"]