themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +343 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +463 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +184 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +348 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +134 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +798 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +415 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +1458 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +72 -0
- themis/integrations/wandb.py +77 -0
- themis/interfaces/__init__.py +169 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- themis_eval-0.2.0.dist-info/METADATA +596 -0
- themis_eval-0.2.0.dist-info/RECORD +157 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
- themis_eval-0.1.0.dist-info/METADATA +0 -758
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
themis/presets/models.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""Model name parsing and provider detection.
|
|
2
|
+
|
|
3
|
+
This module automatically detects the appropriate provider based on
|
|
4
|
+
model names, eliminating the need for users to specify providers manually.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import re
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def parse_model_name(model: str, **kwargs: Any) -> tuple[str, str, dict[str, Any]]:
    """Resolve a model identifier to ``(provider, model_id, options)``.

    Detection is a case-insensitive substring match on ``model``. Names that
    look like OpenAI, Anthropic, Google, Meta, Mistral, Cohere or AI21 models,
    and every name that is not recognised at all, are routed to the
    ``"litellm"`` provider. Only a name that matches none of the vendor
    patterns and contains ``"fake"`` selects the ``"fake"`` testing provider.

    Args:
        model: Model identifier (e.g., "gpt-4", "claude-3-opus", "llama-2-70b").
        **kwargs: Candidate provider options; only the keys recognised by
            ``_extract_provider_options`` are forwarded.

    Returns:
        Tuple of ``(provider_name, model_id, provider_options)``. ``model_id``
        is ``model`` exactly as given (original casing kept). The fake
        provider always receives an empty options dict.

    Examples:
        >>> parse_model_name("gpt-4")
        ('litellm', 'gpt-4', {})

        >>> parse_model_name("local-llm", base_url="http://localhost:1234/v1")
        ('litellm', 'local-llm', {'base_url': 'http://localhost:1234/v1'})

        >>> parse_model_name("fake-model")
        ('fake', 'fake-model', {})
    """
    lowered = model.lower()

    # Substrings identifying a known vendor family:
    # OpenAI, Anthropic, Google, Meta, Mistral and AI21, in that order.
    vendor_markers = (
        "gpt-",
        "o1-",
        "text-davinci",
        "claude",
        "gemini",
        "palm",
        "llama",
        "mistral",
        "mixtral",
        "j2-",
    )
    # Cohere names need both fragments to be present.
    looks_like_cohere = "command" in lowered and "xl" in lowered
    is_known_vendor = looks_like_cohere or any(
        marker in lowered for marker in vendor_markers
    )

    # The fake provider is only considered once no vendor pattern matched.
    if not is_known_vendor and "fake" in lowered:
        return "fake", model, {}

    # Known vendors and the default case alike go through litellm; callers
    # can supply base_url for custom endpoints.
    return "litellm", model, _extract_provider_options(kwargs)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _extract_provider_options(kwargs: dict[str, Any]) -> dict[str, Any]:
|
|
73
|
+
"""Extract provider-specific options from kwargs.
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
kwargs: Dictionary of options
|
|
77
|
+
|
|
78
|
+
Returns:
|
|
79
|
+
Dictionary of provider options
|
|
80
|
+
"""
|
|
81
|
+
provider_options = {}
|
|
82
|
+
|
|
83
|
+
# Known provider options
|
|
84
|
+
option_keys = [
|
|
85
|
+
"api_key",
|
|
86
|
+
"base_url",
|
|
87
|
+
"api_base",
|
|
88
|
+
"api_version",
|
|
89
|
+
"timeout",
|
|
90
|
+
"max_retries",
|
|
91
|
+
"n_parallel",
|
|
92
|
+
"organization",
|
|
93
|
+
"api_type",
|
|
94
|
+
"region_name",
|
|
95
|
+
]
|
|
96
|
+
|
|
97
|
+
for key in option_keys:
|
|
98
|
+
if key in kwargs:
|
|
99
|
+
provider_options[key] = kwargs[key]
|
|
100
|
+
|
|
101
|
+
return provider_options
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def get_provider_for_model(model: str) -> str:
    """Return only the provider name that ``parse_model_name`` selects.

    Args:
        model: Model identifier.

    Returns:
        The first element of ``parse_model_name(model)``.

    Examples:
        >>> get_provider_for_model("gpt-4")
        'litellm'

        >>> get_provider_for_model("claude-3-opus")
        'litellm'
    """
    return parse_model_name(model)[0]
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
# Model family detection for preset selection
|
|
125
|
+
def get_model_family(model: str) -> str:
    """Get the model family for capability detection.

    Matching is case-insensitive and checks are applied in a fixed order
    (OpenAI, Anthropic, Google, Meta, Mistral); the first match wins.

    Args:
        model: Model identifier.

    Returns:
        Model family name, or ``"unknown"`` when no pattern matches.

    Examples:
        >>> get_model_family("gpt-4-turbo")
        'gpt-4'

        >>> get_model_family("claude-3-opus-20240229")
        'claude-3-opus'

        >>> get_model_family("o1-preview")
        'o1'
    """
    model_lower = model.lower()

    # OpenAI families
    if "gpt-4" in model_lower:
        return "gpt-4"
    if "gpt-3.5" in model_lower:
        return "gpt-3.5"
    # Match "o1" only as a standalone token. A bare substring test would also
    # claim names such as "llama-3-pro1" and stop the later vendor checks
    # from ever running for them.
    if re.search(r"(?<![a-z0-9])o1(?![a-z0-9])", model_lower):
        return "o1"

    # Anthropic families
    if "claude-3" in model_lower:
        for tier in ("opus", "sonnet", "haiku"):
            if tier in model_lower:
                return f"claude-3-{tier}"
        return "claude-3"
    if "claude-2" in model_lower:
        return "claude-2"

    # Google families
    if "gemini-pro" in model_lower:
        return "gemini-pro"
    if "gemini-ultra" in model_lower:
        return "gemini-ultra"

    # Meta families
    if "llama-2" in model_lower:
        # "70b" is tested before "7b" on purpose; order matters here.
        for size in ("70b", "13b", "7b"):
            if size in model_lower:
                return f"llama-2-{size}"
        return "llama-2"
    if "llama-3" in model_lower:
        return "llama-3"

    # Mistral families ("mixtral" first so it is not shadowed)
    if "mixtral" in model_lower:
        return "mixtral"
    if "mistral" in model_lower:
        return "mistral"

    return "unknown"
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
__all__ = ["parse_model_name", "get_provider_for_model", "get_model_family"]
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Project helpers for managing experiment collections."""
|
|
2
|
+
|
|
3
|
+
from themis.project.definitions import Project, ProjectExperiment
|
|
4
|
+
from themis.project.patterns import (
|
|
5
|
+
AblationChart,
|
|
6
|
+
AblationChartPoint,
|
|
7
|
+
AblationVariant,
|
|
8
|
+
XAbationPattern,
|
|
9
|
+
XAbationPatternApplication,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"Project",
|
|
14
|
+
"ProjectExperiment",
|
|
15
|
+
"AblationChart",
|
|
16
|
+
"AblationChartPoint",
|
|
17
|
+
"AblationVariant",
|
|
18
|
+
"XAbationPattern",
|
|
19
|
+
"XAbationPatternApplication",
|
|
20
|
+
]
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""Project-level definitions for grouping experiments."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, Mapping, Sequence
|
|
7
|
+
|
|
8
|
+
from themis.experiment.definitions import ExperimentDefinition
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(frozen=True)
class ProjectExperiment:
    """Immutable record pairing a name with an experiment definition.

    ``Project`` uses ``name`` as the unique key under which the experiment is
    registered and looked up.
    """

    # Unique name of the experiment within its project.
    name: str
    # The experiment definition this entry wraps.
    definition: ExperimentDefinition
    # Optional free-form description.
    description: str | None = None
    # Experiment-specific metadata; a fresh dict per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Tags attached to the experiment; empty by default.
    tags: tuple[str, ...] = ()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class Project:
    """Named collection of experiments that share project-level metadata.

    Experiments are kept both as an ordered tuple (``experiments``) and in a
    private name index built in ``__post_init__``. Experiment names must be
    unique within a project.
    """

    project_id: str
    name: str
    description: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    tags: tuple[str, ...] = field(default_factory=tuple)
    experiments: Sequence[ProjectExperiment] = field(default_factory=tuple)

    def __post_init__(self) -> None:
        """Index the initial experiments and freeze them into a tuple.

        Raises:
            ValueError: If two initial experiments share a name.
        """
        self._experiment_index: dict[str, ProjectExperiment] = {}
        for initial in self.experiments:
            self._register_experiment(initial)
        # Names are unique, so the index holds every experiment exactly once
        # and in the order it was registered.
        self.experiments = tuple(self._experiment_index.values())

    def add_experiment(self, experiment: ProjectExperiment) -> ProjectExperiment:
        """Attach an experiment to the project, enforcing unique names.

        Returns:
            The same ``experiment`` object, for chaining.

        Raises:
            ValueError: If an experiment with that name is already registered.
        """
        self._register_experiment(experiment)
        self.experiments = (*self.experiments, experiment)
        return experiment

    def create_experiment(
        self,
        *,
        name: str,
        definition: ExperimentDefinition,
        description: str | None = None,
        metadata: Mapping[str, Any] | None = None,
        tags: Sequence[str] | None = None,
    ) -> ProjectExperiment:
        """Build a ``ProjectExperiment`` from its parts and register it.

        ``metadata`` and ``tags`` are copied into a new dict / tuple, so the
        caller's containers are not shared with the stored experiment.

        Raises:
            ValueError: If ``name`` is already registered.
        """
        copied_metadata = dict(metadata) if metadata else {}
        copied_tags = tuple(tags) if tags else ()
        return self.add_experiment(
            ProjectExperiment(
                name=name,
                description=description,
                definition=definition,
                metadata=copied_metadata,
                tags=copied_tags,
            )
        )

    def get_experiment(self, name: str) -> ProjectExperiment:
        """Return the experiment registered under ``name``.

        Raises:
            KeyError: If no experiment with that name is registered.
        """
        try:
            found = self._experiment_index[name]
        except KeyError as missing:  # pragma: no cover - defensive guard
            raise KeyError(
                f"Experiment '{name}' not registered in project '{self.project_id}'"
            ) from missing
        return found

    def metadata_for_experiment(self, name: str) -> dict[str, Any]:
        """Return project metadata overlaid with the experiment's own metadata.

        Experiment values win on key collisions. The result is a new dict.

        Raises:
            KeyError: If no experiment with that name is registered.
        """
        return {**self.metadata, **self.get_experiment(name).metadata}

    def list_experiment_names(self) -> tuple[str, ...]:
        """Return the registered experiment names in registration order."""
        return tuple(self._experiment_index)

    def _register_experiment(self, experiment: ProjectExperiment) -> None:
        """Add ``experiment`` to the name index, rejecting duplicate names."""
        key = experiment.name
        if key in self._experiment_index:
            raise ValueError(
                f"Experiment '{experiment.name}' already registered "
                f"in project '{self.project_id}'"
            )
        self._experiment_index[key] = experiment
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
__all__ = [
|
|
96
|
+
"Project",
|
|
97
|
+
"ProjectExperiment",
|
|
98
|
+
]
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
"""Reusable experiment patterns for organizing projects."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Any, Callable, Mapping, Sequence
|
|
8
|
+
|
|
9
|
+
from themis.experiment.definitions import ExperimentDefinition
|
|
10
|
+
from themis.experiment.orchestrator import ExperimentReport
|
|
11
|
+
from themis.project.definitions import Project, ProjectExperiment
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _slugify(value: str) -> str:
|
|
15
|
+
text = value.strip().lower()
|
|
16
|
+
text = re.sub(r"[^a-z0-9]+", "-", text)
|
|
17
|
+
text = text.strip("-")
|
|
18
|
+
return text or "variant"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class AblationVariant:
    """One value of the ablated parameter, with an optional label and metadata."""

    # The parameter value this variant stands for.
    value: Any
    # Optional human-readable label; when empty or None, str(value) is shown.
    label: str | None = None
    # Extra metadata carried by the variant; a fresh dict per instance.
    metadata: Mapping[str, Any] = field(default_factory=dict)

    def display_label(self) -> str:
        """Return the label if it is non-empty, otherwise ``str(self.value)``."""
        if self.label:
            return self.label
        return str(self.value)

    def slug(self) -> str:
        """Return the slugified form of ``display_label()``."""
        shown = self.display_label()
        return _slugify(shown)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass(frozen=True)
class AblationChartPoint:
    """One immutable point of an ablation chart: a variant and its metric summary."""

    # Value of the ablated parameter for this point (the variant's ``value``).
    x_value: Any
    # Display label of the variant.
    label: str
    # Metric value plotted for the variant (filled from the metric's ``mean``).
    metric_value: float
    # Name of the metric the value belongs to.
    metric_name: str
    # Count reported alongside the metric (filled from the metric's ``count``).
    count: int
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass(frozen=True)
class AblationChart:
    """Immutable description of an ablation chart: labels plus ordered points."""

    title: str
    x_label: str
    y_label: str
    metric_name: str
    points: tuple[AblationChartPoint, ...]

    def as_dict(self) -> dict[str, Any]:
        """Return the chart as plain dicts and lists.

        The chart-level metric name is emitted under ``"metric"``. Each point
        becomes a dict with ``"label"``, ``"x"``, ``"value"`` and ``"count"``;
        the per-point metric name is not included.
        """
        serialized_points = []
        for entry in self.points:
            serialized_points.append(
                {
                    "label": entry.label,
                    "x": entry.x_value,
                    "value": entry.metric_value,
                    "count": entry.count,
                }
            )
        return {
            "title": self.title,
            "x_label": self.x_label,
            "y_label": self.y_label,
            "metric": self.metric_name,
            "points": serialized_points,
        }
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass(frozen=True)
class XAbationPatternApplication:
    """Immutable result of materializing an ``XAbationPattern`` on a project.

    It records which experiments were created and which variant each
    experiment name corresponds to, and keeps a reference to the originating
    pattern so a chart can be built later.
    """

    # Name of the pattern that produced this application.
    pattern_name: str
    # Name of the parameter that was varied.
    parameter_name: str
    # Experiments created by the pattern.
    experiments: tuple[ProjectExperiment, ...]
    # Maps each created experiment name to its variant.
    variant_by_name: Mapping[str, AblationVariant]
    # Originating pattern, left out of repr. The forward reference must spell
    # the class exactly as it is defined in this module (``XAbationPattern``),
    # otherwise resolving the type hints raises NameError.
    _pattern: "XAbationPattern" = field(repr=False)

    def build_chart(self, reports: Mapping[str, ExperimentReport]) -> AblationChart:
        """Build the ablation chart from per-experiment reports.

        Args:
            reports: Reports keyed by experiment name.

        Returns:
            The chart produced by the originating pattern's ``_build_chart``
            for this application's variants.
        """
        return self._pattern._build_chart(reports, self.variant_by_name)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class XAbationPattern:
    """Vary a single factor across values to compare performance.

    The pattern keeps an ordered list of variants of one parameter.
    ``materialize`` registers one experiment per variant on a project, and the
    returned application can later turn per-experiment reports into an
    ``AblationChart``.
    """

    pattern_type = "x-ablation"

    def __init__(
        self,
        *,
        name: str,
        parameter_name: str,
        values: Sequence[AblationVariant | Any],
        definition_builder: Callable[[AblationVariant], ExperimentDefinition],
        metric_name: str,
        x_axis_label: str | None = None,
        y_axis_label: str | None = None,
        title: str | None = None,
    ) -> None:
        """Configure the pattern.

        Args:
            name: Pattern name; used in experiment names and metadata.
            parameter_name: Name of the parameter being varied.
            values: Variants or raw values; raw values are wrapped in
                ``AblationVariant`` and existing variants are copied.
            definition_builder: Called once per variant during
                ``materialize`` to obtain its experiment definition.
            metric_name: Metric looked up in each report when charting.
            x_axis_label: Defaults to ``parameter_name``.
            y_axis_label: Defaults to ``metric_name``.
            title: Defaults to ``"{name} ({parameter_name} ablation)"``.

        Raises:
            ValueError: If ``values`` is empty.
        """
        if not values:
            raise ValueError("XAblationPattern requires at least one value")
        self.name = name
        self.parameter_name = parameter_name
        self._variants = [self._normalize_variant(value) for value in values]
        self._definition_builder = definition_builder
        self.metric_name = metric_name
        self.x_axis_label = x_axis_label or parameter_name
        self.y_axis_label = y_axis_label or metric_name
        self.title = title or f"{name} ({parameter_name} ablation)"

    def materialize(
        self,
        project: Project,
        *,
        name_template: str | None = None,
        description_template: str | None = None,
        base_tags: Sequence[str] | None = None,
    ) -> XAbationPatternApplication:
        """Register one experiment per variant on ``project``.

        Both templates are ``str.format`` templates and receive the same
        placeholders: ``pattern``, ``parameter``, ``value``, ``value_label``,
        ``value_slug`` and ``index``.

        Args:
            project: Project the experiments are added to.
            name_template: Experiment-name template; defaults to
                ``"{pattern}-{value_slug}"``.
            description_template: Optional description template; when omitted
                the experiments get no description.
            base_tags: Tags applied to every experiment. The pattern type is
                appended and duplicates are removed, keeping first occurrence.

        Returns:
            An application describing the created experiments.

        Raises:
            ValueError: Propagated from ``project.add_experiment`` when a
                generated name is already registered. Experiments added
                before the failure stay on the project.
        """
        template = name_template or "{pattern}-{value_slug}"
        # Tags are identical for every variant, so de-duplicate them once.
        tags = tuple(dict.fromkeys((*(base_tags or ()), self.pattern_type)))
        experiments: list[ProjectExperiment] = []
        variant_map: dict[str, AblationVariant] = {}
        for index, variant in enumerate(self._variants):
            label = variant.display_label()
            # One placeholder set for both templates keeps them consistent;
            # previously the description template lacked ``value_slug``.
            placeholders = {
                "pattern": self.name,
                "parameter": self.parameter_name,
                "value": variant.value,
                "value_label": label,
                "value_slug": variant.slug(),
                "index": index,
            }
            experiment_name = template.format(**placeholders)
            description: str | None = None
            if description_template is not None:
                description = description_template.format(**placeholders)
            metadata = {
                "pattern": self.pattern_type,
                "pattern_name": self.name,
                "parameter_name": self.parameter_name,
                "parameter_value": variant.value,
                "parameter_label": label,
                "pattern_index": index,
            }
            # Variant metadata is applied last, so it may override the
            # pattern-provided keys above.
            metadata.update(dict(variant.metadata))
            definition = self._definition_builder(variant)
            project_experiment = project.add_experiment(
                ProjectExperiment(
                    name=experiment_name,
                    description=description,
                    definition=definition,
                    metadata=metadata,
                    tags=tags,
                )
            )
            experiments.append(project_experiment)
            variant_map[project_experiment.name] = variant
        return XAbationPatternApplication(
            pattern_name=self.name,
            parameter_name=self.parameter_name,
            experiments=tuple(experiments),
            variant_by_name=variant_map,
            _pattern=self,
        )

    def _build_chart(
        self,
        reports: Mapping[str, ExperimentReport],
        variant_by_name: Mapping[str, AblationVariant],
    ) -> AblationChart:
        """Collect one chart point per experiment and assemble the chart.

        For each experiment name the report's
        ``evaluation_report.metrics[self.metric_name]`` is read, and its
        ``mean``, ``name`` and ``count`` fill the point.

        Raises:
            KeyError: If ``reports`` has no report for an experiment.
            ValueError: If a report lacks the configured metric.
        """
        points: list[AblationChartPoint] = []
        for experiment, variant in variant_by_name.items():
            report = reports.get(experiment)
            if report is None:
                raise KeyError(
                    f"Missing report for experiment '{experiment}' in pattern '{self.name}'"
                )
            metric = report.evaluation_report.metrics.get(self.metric_name)
            if metric is None:
                raise ValueError(
                    f"Metric '{self.metric_name}' not found for experiment '{experiment}'"
                )
            points.append(
                AblationChartPoint(
                    x_value=variant.value,
                    label=variant.display_label(),
                    metric_value=metric.mean,
                    metric_name=metric.name,
                    count=metric.count,
                )
            )
        ordered_points = self._order_points(points, variant_by_name)
        return AblationChart(
            title=self.title,
            x_label=self.x_axis_label,
            y_label=self.y_axis_label,
            metric_name=self.metric_name,
            points=tuple(ordered_points),
        )

    def _order_points(
        self,
        points: Sequence[AblationChartPoint],
        variant_by_name: Mapping[str, AblationVariant],
    ) -> list[AblationChartPoint]:
        """Sort ``points`` into the order the variants were declared.

        Positions are found by comparing values rather than hashing them, so
        unhashable variant values (lists, dicts) work as well. When several
        variants share a value the last position is used, which matches the
        earlier dict-based lookup. A value matching no variant sorts at
        position 0. ``variant_by_name`` is accepted for interface
        compatibility and is not used.
        """
        declared = [variant.value for variant in self._variants]

        def position(point: AblationChartPoint) -> int:
            target = point.x_value
            for index in range(len(declared) - 1, -1, -1):
                candidate = declared[index]
                # Identity first, so self-unequal values such as NaN match.
                if candidate is target or candidate == target:
                    return index
            return 0

        return sorted(points, key=position)

    def _normalize_variant(self, value: AblationVariant | Any) -> AblationVariant:
        """Return an ``AblationVariant`` for ``value``.

        An existing variant is copied (with its metadata copied into a new
        dict) so later changes to the caller's object do not affect the
        pattern; any other value is wrapped as the variant's ``value``.
        """
        if isinstance(value, AblationVariant):
            return AblationVariant(
                value=value.value,
                label=value.label,
                metadata=dict(value.metadata),
            )
        return AblationVariant(value=value)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
# Public API of this module. Every entry must name an object that is actually
# defined above; a name with no matching definition makes
# ``from themis.project.patterns import *`` fail with AttributeError.
__all__ = [
    "AblationChart",
    "AblationChartPoint",
    "AblationVariant",
    "XAbationPattern",
    "XAbationPatternApplication",
]
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Simple registry for ModelProvider factories."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Callable, Dict
|
|
6
|
+
|
|
7
|
+
from themis.interfaces import ModelProvider
|
|
8
|
+
|
|
9
|
+
# A provider factory is any callable that returns a ModelProvider; the
# registry calls it with the keyword options passed to ``create``.
ProviderFactory = Callable[..., ModelProvider]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class _ProviderRegistry:
|
|
13
|
+
def __init__(self) -> None:
|
|
14
|
+
self._factories: Dict[str, ProviderFactory] = {}
|
|
15
|
+
|
|
16
|
+
def register(self, name: str, factory: ProviderFactory) -> None:
|
|
17
|
+
key = name.lower()
|
|
18
|
+
self._factories[key] = factory
|
|
19
|
+
|
|
20
|
+
def create(self, name: str, **options) -> ModelProvider:
|
|
21
|
+
key = name.lower()
|
|
22
|
+
factory = self._factories.get(key)
|
|
23
|
+
if factory is None:
|
|
24
|
+
raise KeyError(f"No provider registered under name '{name}'")
|
|
25
|
+
return factory(**options)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Module-level registry instance used by register_provider() and
# create_provider() below.
_REGISTRY = _ProviderRegistry()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def register_provider(name: str, factory: ProviderFactory) -> None:
    """Register ``factory`` under ``name`` in the module-level registry.

    Names are stored lower-cased; registering an existing name replaces the
    previous factory.
    """
    _REGISTRY.register(name, factory)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def create_provider(name: str, **options) -> ModelProvider:
    """Create a provider through the module-level registry.

    Args:
        name: Provider name; looked up case-insensitively.
        **options: Keyword arguments forwarded to the registered factory.

    Raises:
        KeyError: If no provider is registered under ``name``.
    """
    return _REGISTRY.create(name, **options)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
__all__ = ["register_provider", "create_provider", "ProviderFactory"]
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""FastAPI server for Themis web dashboard.
|
|
2
|
+
|
|
3
|
+
This module provides a REST API and WebSocket interface for:
|
|
4
|
+
- Listing and viewing experiment runs
|
|
5
|
+
- Comparing multiple runs
|
|
6
|
+
- Real-time monitoring of running experiments
|
|
7
|
+
- Exporting results in various formats
|
|
8
|
+
|
|
9
|
+
The server is optional and requires the 'server' extra:
|
|
10
|
+
pip install "themis-eval[server]"
|
|
11
|
+
# or
|
|
12
|
+
uv pip install "themis-eval[server]"
|
|
13
|
+
|
|
14
|
+
Usage:
|
|
15
|
+
# Start the server
|
|
16
|
+
themis serve --port 8080
|
|
17
|
+
|
|
18
|
+
# Or programmatically
|
|
19
|
+
from themis.server import create_app
|
|
20
|
+
app = create_app(storage_path=".cache/experiments")
|
|
21
|
+
|
|
22
|
+
# Run with uvicorn
|
|
23
|
+
uvicorn themis.server:app --host 0.0.0.0 --port 8080
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from themis.server.app import create_app
|
|
27
|
+
|
|
28
|
+
__all__ = ["create_app"]
|