themis-eval 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +343 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +463 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +184 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +348 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +134 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +798 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +415 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +1458 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +72 -0
- themis/integrations/wandb.py +77 -0
- themis/interfaces/__init__.py +169 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- themis_eval-0.2.0.dist-info/METADATA +596 -0
- themis_eval-0.2.0.dist-info/RECORD +157 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
- themis_eval-0.1.0.dist-info/METADATA +0 -758
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
themis/evaluation/pipelines/standard_pipeline.py

@@ -0,0 +1,348 @@
"""Standard evaluation pipeline implementation."""

from __future__ import annotations

import logging
import time
import warnings
from typing import Callable, Sequence

from themis.core import entities as core_entities
from themis.evaluation import extractors
from themis.evaluation import strategies as evaluation_strategies
from themis.evaluation.reports import (
    EvaluationFailure,
    EvaluationReport,
    MetricAggregate,
)
from themis.interfaces import Metric as MetricInterface
from themis.utils import tracing

logger = logging.getLogger(__name__)


def _default_reference_selector(record: core_entities.GenerationRecord):
    """Default reference selector from generation record.

    Args:
        record: Generation record

    Returns:
        Reference value or None
    """
    reference = record.task.reference
    if reference is None:
        return None
    return reference.value


def _normalize_references(reference) -> list:
    """Normalize reference to list format for metric consumption.

    This function converts various reference formats into a standardized list
    that metrics can reliably consume. The normalized format is always a list
    where each element represents one reference value.

    Args:
        reference: Reference value in various formats:
            - Reference object: Extracts .value field
            - dict: Kept as-is in a list (for multi-value references)
            - list/tuple: Returned as list
            - scalar: Wrapped in a list

    Returns:
        List of reference values. Each element can be:
        - A scalar value (str, int, float, bool)
        - A dict (for multi-value references like {"target": 122, "numbers": [...]})
        - Any other type from the original reference

    Examples:
        >>> _normalize_references(Reference(kind="answer", value="42"))
        ["42"]

        >>> _normalize_references(Reference(kind="task", value={"target": 122, "numbers": [25, 50]}))
        [{"target": 122, "numbers": [25, 50]}]

        >>> _normalize_references(["yes", "no", "maybe"])
        ["yes", "no", "maybe"]

        >>> _normalize_references("42")
        ["42"]

    Note:
        Metrics receive references in this normalized format and should handle
        both simple values and dict values appropriately.
    """
    if isinstance(reference, core_entities.Reference):
        reference = reference.value
    if isinstance(reference, list):
        return reference
    if isinstance(reference, tuple):
        return list(reference)
    return [reference]


class EvaluationPipeline:
    """Traditional batch evaluation pipeline.

    This pipeline evaluates generation records using extractors, metrics,
    and evaluation strategies. It supports slicing for subset analysis.

    Example:
        >>> pipeline = EvaluationPipeline(
        ...     extractor=JsonFieldExtractor("answer"),
        ...     metrics=[ExactMatch()]
        ... )
        >>> report = pipeline.evaluate(records)

    Attributes:
        _extractor: Extractor for parsing model output
        _metrics: List of metrics to compute
        _reference_selector: Function to extract reference from record
        _strategy_resolver: Function to resolve evaluation strategy
        _slices: List of (name, predicate) tuples for slicing
    """

    def __init__(
        self,
        *,
        extractor,
        metrics: Sequence[MetricInterface],
        reference_selector: Callable[[core_entities.GenerationRecord], object]
        | None = None,
        strategy_resolver: Callable[
            [core_entities.GenerationRecord], evaluation_strategies.EvaluationStrategy
        ]
        | None = None,
    ) -> None:
        """Initialize evaluation pipeline.

        Args:
            extractor: Extractor for parsing model output
            metrics: List of metrics to compute
            reference_selector: Optional function to extract reference from record.
                If provided, this takes precedence over item.reference from strategies.
            strategy_resolver: Optional function to resolve evaluation strategy.
                If using a custom reference_selector with DefaultEvaluationStrategy,
                the selector will take precedence.

        Note:
            When using DefaultEvaluationStrategy with a custom reference_selector,
            the reference_selector will override the default behavior. Consider
            using a custom strategy if you need more control over reference selection.
        """
        self._extractor = extractor
        self._metrics = list(metrics)
        self._reference_selector = reference_selector
        self._has_custom_reference_selector = reference_selector is not None
        self._strategy_resolver = strategy_resolver or (
            lambda record: evaluation_strategies.DefaultEvaluationStrategy()
        )
        self._slices: list[
            tuple[str, Callable[[core_entities.GenerationRecord], bool]]
        ] = []

        # Validation: warn if custom reference_selector is used with default strategy
        if self._has_custom_reference_selector and strategy_resolver is None:
            warnings.warn(
                "Custom reference_selector provided without custom strategy_resolver. "
                "The reference_selector will take precedence over DefaultEvaluationStrategy's "
                "reference handling. If you need more control, consider providing a custom "
                "strategy_resolver that sets reference=None in EvaluationItem.",
                UserWarning,
                stacklevel=2,
            )

    def evaluate(
        self, records: Sequence[core_entities.GenerationRecord]
    ) -> EvaluationReport:
        """Evaluate generation records.

        Args:
            records: Generation records to evaluate

        Returns:
            Evaluation report with metrics and failures
        """
        with tracing.span("evaluate_pipeline", total_records=len(records)):
            per_metric: dict[str, list[core_entities.MetricScore]] = {
                metric.name: [] for metric in self._metrics
            }
            failures: list[EvaluationFailure] = []
            per_record: list[core_entities.EvaluationRecord] = []
            slice_members: dict[str, set[str]] = {
                name: set() for name, _ in self._slices
            }

            for record in records:
                with tracing.span("evaluate_record"):
                    logger.debug(
                        "Evaluating sample %s with %s metric(s)",
                        record.task.metadata.get("dataset_id")
                        or record.task.metadata.get("sample_id"),
                        len(self._metrics),
                    )
                    strategy = self._strategy_resolver(record)
                    task_metadata = record.task.metadata
                    sample_id = task_metadata.get("dataset_id") or task_metadata.get(
                        "sample_id"
                    )
                    for name, fn in self._slices:
                        try:
                            if fn(record) and sample_id is not None:
                                slice_members[name].add(sample_id)
                        except Exception:
                            pass
                    eval_items = list(strategy.prepare(record))
                    item_scores: list[core_entities.MetricScore] = []
                    record_failures: list[str] = []

                    for item in eval_items:
                        if item.record.output is None:
                            message = "Missing model output"
                            failures.append(
                                EvaluationFailure(sample_id=sample_id, message=message)
                            )
                            record_failures.append(message)
                            continue
                        try:
                            with tracing.span("extract"):
                                prediction = self._extractor.extract(
                                    item.record.output.text
                                )
                        except extractors.FieldExtractionError as exc:
                            message = str(exc)
                            failures.append(
                                EvaluationFailure(sample_id=sample_id, message=message)
                            )
                            record_failures.append(message)
                            continue

                        # CRITICAL: Always call reference_selector if provided (takes precedence)
                        # This fixes the issue where DefaultEvaluationStrategy's reference
                        # would prevent custom reference_selector from being called
                        if self._has_custom_reference_selector:
                            reference = self._reference_selector(record)
                        elif item.reference is not None:
                            reference = item.reference
                        else:
                            reference = _default_reference_selector(record)

                        references = (
                            _normalize_references(reference)
                            if reference is not None
                            else []
                        )
                        metadata = {"sample_id": sample_id}
                        extract_start = time.perf_counter()
                        item_scores_for_item: list[core_entities.MetricScore] = []
                        for metric in self._metrics:
                            requires_reference = getattr(
                                metric, "requires_reference", True
                            )
                            if requires_reference and not references:
                                message = (
                                    f"Missing reference for metric '{metric.name}'"
                                )
                                failures.append(
                                    EvaluationFailure(
                                        sample_id=sample_id, message=message
                                    )
                                )
                                record_failures.append(message)
                                continue
                            metric_start = time.perf_counter()
                            try:
                                with tracing.span(
                                    "compute_metric", metric_name=metric.name
                                ):
                                    score = metric.compute(
                                        prediction=prediction,
                                        references=references,
                                        metadata=metadata,
                                    )
                                score.metadata["evaluation_time_ms"] = (
                                    time.perf_counter() - metric_start
                                ) * 1000
                                item_scores_for_item.append(score)
                            except Exception as exc:  # pragma: no cover - guarded
                                message = (
                                    f"Metric '{metric.name}' failed for sample {sample_id}: {exc}"
                                )
                                logger.warning(message)
                                failures.append(
                                    EvaluationFailure(
                                        sample_id=sample_id, message=message
                                    )
                                )
                                record_failures.append(message)
                        extraction_duration = (
                            time.perf_counter() - extract_start
                        ) * 1000
                        for score in item_scores_for_item:
                            score.metadata.setdefault(
                                "extraction_time_ms", extraction_duration
                            )
                        item_scores.extend(item_scores_for_item)

                    aggregated_scores = strategy.aggregate(record, item_scores)
                    for score in aggregated_scores:
                        per_metric[score.metric_name].append(score)
                    per_record.append(
                        core_entities.EvaluationRecord(
                            sample_id=sample_id,
                            scores=aggregated_scores,
                            failures=record_failures,
                        )
                    )

            aggregates = {
                name: MetricAggregate.from_scores(name, scores)
                for name, scores in per_metric.items()
            }

            return EvaluationReport(
                metrics=aggregates,
                failures=failures,
                records=per_record,
                slices=self._compute_slice_aggregates(per_metric, slice_members),
            )

    def register_slice(
        self, name: str, fn: Callable[[core_entities.GenerationRecord], bool]
    ) -> None:
        """Register a slice for subset analysis.

        Args:
            name: Slice name
            fn: Predicate function to determine slice membership
        """
        self._slices.append((name, fn))

    def _compute_slice_aggregates(
        self,
        per_metric: dict[str, list[core_entities.MetricScore]],
        slice_members: dict[str, set[str]],
    ) -> dict[str, dict[str, MetricAggregate]]:
        """Compute metric aggregates for each slice.

        Args:
            per_metric: Scores by metric name
            slice_members: Sample IDs by slice name

        Returns:
            Nested dict of slice -> metric -> aggregate
        """
        if not slice_members:
            return {}
        slice_aggregates: dict[str, dict[str, MetricAggregate]] = {}
        for name, members in slice_members.items():
            slice_scores_by_metric: dict[str, list[core_entities.MetricScore]] = {}
            for metric_name, scores in per_metric.items():
                filtered = [s for s in scores if s.metadata.get("sample_id") in members]
                slice_scores_by_metric[metric_name] = filtered
            slice_aggregates[name] = {
                metric_name: MetricAggregate.from_scores(metric_name, scores)
                for metric_name, scores in slice_scores_by_metric.items()
            }
        return slice_aggregates
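
For orientation, a minimal usage sketch of the pipeline defined above. It assumes that JsonFieldExtractor and ExactMatch (both shipped in this release per the file list) are re-exported from themis.evaluation.extractors and themis.evaluation.metrics, and that records come from the generation layer; these import paths and the "difficulty" metadata key are illustrative assumptions, not verified against the wheel.

from themis.core import entities as core_entities
from themis.evaluation.extractors import JsonFieldExtractor  # assumed re-export
from themis.evaluation.metrics import ExactMatch  # assumed re-export
from themis.evaluation.pipelines.standard_pipeline import EvaluationPipeline

pipeline = EvaluationPipeline(
    extractor=JsonFieldExtractor("answer"),
    metrics=[ExactMatch()],
)
# Optional slice: the predicate receives the full GenerationRecord.
# "difficulty" is a hypothetical metadata key used only for illustration.
pipeline.register_slice(
    "easy", lambda record: record.task.metadata.get("difficulty") == "easy"
)

records: list[core_entities.GenerationRecord] = []  # filled by the generation step
report = pipeline.evaluate(records)
for name, aggregate in report.metrics.items():
    print(name, aggregate.count, aggregate.mean)
for failure in report.failures:
    print("failed:", failure.sample_id, failure.message)

Note that register_slice only affects the slices section of the report; the top-level aggregates always cover all records.
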
themis/evaluation/reports.py

@@ -0,0 +1,293 @@
"""Evaluation report data structures."""

from __future__ import annotations

from dataclasses import dataclass, field
from statistics import mean
from typing import Dict, List, Literal, Sequence

from themis.core import entities as core_entities
from themis.evaluation.statistics import (
    bootstrap_ci,
    cohens_d,
    cohens_h,
    holm_bonferroni,
    paired_permutation_test,
    paired_t_test,
    permutation_test,
)
from themis.evaluation.statistics.types import (
    BootstrapResult,
    ComparisonResult,
    EffectSize,
    PermutationTestResult,
)


@dataclass
class EvaluationFailure:
    sample_id: str | None
    message: str


@dataclass
class MetricAggregate:
    name: str
    count: int
    mean: float
    per_sample: List[core_entities.MetricScore]

    @classmethod
    def from_scores(
        cls, name: str, scores: List[core_entities.MetricScore]
    ) -> "MetricAggregate":
        if not scores:
            return cls(name=name, count=0, mean=0.0, per_sample=[])
        return cls(
            name=name,
            count=len(scores),
            mean=mean(score.value for score in scores),
            per_sample=scores,
        )


@dataclass
class EvaluationReport:
    metrics: dict[str, MetricAggregate]
    failures: List[EvaluationFailure]
    records: List[core_entities.EvaluationRecord]
    slices: dict[str, dict[str, MetricAggregate]] = field(default_factory=dict)


def _metric_values(report: EvaluationReport, metric_name: str) -> list[float]:
    agg = report.metrics.get(metric_name)
    if not agg:
        return []
    return [s.value for s in agg.per_sample]


def _metric_values_by_sample(
    report: EvaluationReport, metric_name: str
) -> dict[str, float]:
    values: dict[str, float] = {}
    for record in report.records:
        if not record.sample_id:
            continue
        for score in record.scores:
            if score.metric_name == metric_name:
                values[record.sample_id] = score.value
                break
    return values


def aligned_metric_values(
    report_a: EvaluationReport, report_b: EvaluationReport, metric_name: str
) -> tuple[list[float], list[float]]:
    values_a = _metric_values_by_sample(report_a, metric_name)
    values_b = _metric_values_by_sample(report_b, metric_name)
    common_ids = sorted(set(values_a) & set(values_b))
    if not common_ids:
        raise ValueError(f"No overlapping sample_ids for metric '{metric_name}'")
    aligned_a = [values_a[sample_id] for sample_id in common_ids]
    aligned_b = [values_b[sample_id] for sample_id in common_ids]
    return aligned_a, aligned_b


def ci_for_metric(
    report: EvaluationReport,
    metric_name: str,
    confidence_level: float = 0.95,
    n_bootstrap: int = 10000,
) -> BootstrapResult:
    values = _metric_values(report, metric_name)
    if not values:
        raise ValueError(f"No scores for metric '{metric_name}'")
    return bootstrap_ci(
        values, n_bootstrap=n_bootstrap, confidence_level=confidence_level
    )


def permutation_test_for_metric(
    report_a: EvaluationReport,
    report_b: EvaluationReport,
    metric_name: str,
    statistic: Literal["mean_diff", "median_diff"] = "mean_diff",
    n_permutations: int = 10000,
    seed: int | None = None,
    align_by_sample_id: bool = True,
) -> PermutationTestResult:
    if align_by_sample_id:
        values_a, values_b = aligned_metric_values(report_a, report_b, metric_name)
    else:
        values_a = _metric_values(report_a, metric_name)
        values_b = _metric_values(report_b, metric_name)
    if not values_a or not values_b:
        raise ValueError(f"Both reports must have scores for metric '{metric_name}'")
    return permutation_test(
        values_a,
        values_b,
        statistic=statistic,
        n_permutations=n_permutations,
        seed=seed,
    )


def paired_permutation_test_for_metric(
    report_a: EvaluationReport,
    report_b: EvaluationReport,
    metric_name: str,
    statistic: Literal["mean_diff", "median_diff"] = "mean_diff",
    n_permutations: int = 10000,
    seed: int | None = None,
) -> PermutationTestResult:
    values_a, values_b = aligned_metric_values(report_a, report_b, metric_name)
    return paired_permutation_test(
        values_a,
        values_b,
        statistic=statistic,
        n_permutations=n_permutations,
        seed=seed,
    )


def cohens_h_for_metric(
    report_a: EvaluationReport,
    report_b: EvaluationReport,
    metric_name: str,
) -> EffectSize:
    agg_a = report_a.metrics.get(metric_name)
    agg_b = report_b.metrics.get(metric_name)
    if not agg_a or not agg_b:
        raise ValueError(f"Both reports must have aggregate for metric '{metric_name}'")
    return cohens_h(agg_a.mean, agg_b.mean)


def cohens_d_for_metric(
    report_a: EvaluationReport,
    report_b: EvaluationReport,
    metric_name: str,
) -> EffectSize:
    values_a, values_b = aligned_metric_values(report_a, report_b, metric_name)
    if len(values_a) < 2 or len(values_b) < 2:
        raise ValueError("Each group must have at least 2 values for Cohen's d")
    return cohens_d(values_a, values_b)


def paired_t_test_for_metric(
    report_a: EvaluationReport,
    report_b: EvaluationReport,
    metric_name: str,
    significance_level: float = 0.05,
) -> ComparisonResult:
    values_a, values_b = aligned_metric_values(report_a, report_b, metric_name)
    result = paired_t_test(values_a, values_b, significance_level=significance_level)
    return ComparisonResult(
        metric_name=metric_name,
        baseline_mean=result.baseline_mean,
        treatment_mean=result.treatment_mean,
        difference=result.difference,
        relative_change=result.relative_change,
        t_statistic=result.t_statistic,
        p_value=result.p_value,
        is_significant=result.is_significant,
        baseline_ci=result.baseline_ci,
        treatment_ci=result.treatment_ci,
    )


def _slice_metric_values(
    report: EvaluationReport, slice_name: str, metric_name: str
) -> list[float]:
    slice_map = report.slices.get(slice_name)
    if not slice_map:
        return []
    agg = slice_map.get(metric_name)
    if not agg:
        return []
    return [s.value for s in agg.per_sample]


def ci_for_slice_metric(
    report: EvaluationReport,
    slice_name: str,
    metric_name: str,
    confidence_level: float = 0.95,
    n_bootstrap: int = 10000,
) -> BootstrapResult:
    values = _slice_metric_values(report, slice_name, metric_name)
    if not values:
        raise ValueError(
            f"No scores for metric '{metric_name}' in slice '{slice_name}'"
        )
    return bootstrap_ci(
        values, n_bootstrap=n_bootstrap, confidence_level=confidence_level
    )


def compare_reports_with_holm(
    report_a: EvaluationReport,
    report_b: EvaluationReport,
    metric_names: Sequence[str],
    statistic: Literal["mean_diff", "median_diff"] = "mean_diff",
    n_permutations: int = 10000,
    seed: int | None = None,
    paired: bool = True,
) -> Dict[str, object]:
    p_values: list[float] = []
    pt_results: Dict[str, PermutationTestResult] = {}
    for name in metric_names:
        if paired:
            pt = paired_permutation_test_for_metric(
                report_a,
                report_b,
                name,
                statistic=statistic,
                n_permutations=n_permutations,
                seed=seed,
            )
        else:
            pt = permutation_test_for_metric(
                report_a,
                report_b,
                name,
                statistic=statistic,
                n_permutations=n_permutations,
                seed=seed,
                align_by_sample_id=True,
            )
        pt_results[name] = pt
        p_values.append(pt.p_value)
    corrected = holm_bonferroni(p_values)
    return {
        "per_metric": pt_results,
        "holm_significant": corrected,
    }


def confusion_matrix(
    labels_true: Sequence[str], labels_pred: Sequence[str]
) -> Dict[str, Dict[str, int]]:
    if len(labels_true) != len(labels_pred):
        raise ValueError("labels_true and labels_pred must have same length")
    cm: Dict[str, Dict[str, int]] = {}
    for t, p in zip(labels_true, labels_pred):
        cm.setdefault(t, {})
        cm[t][p] = cm[t].get(p, 0) + 1
    return cm


__all__ = [
    "EvaluationFailure",
    "MetricAggregate",
    "EvaluationReport",
    "aligned_metric_values",
    "ci_for_metric",
    "ci_for_slice_metric",
    "permutation_test_for_metric",
    "paired_permutation_test_for_metric",
    "cohens_h_for_metric",
    "cohens_d_for_metric",
    "paired_t_test_for_metric",
    "confusion_matrix",
    "compare_reports_with_holm",
]
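
The module above also exposes helpers for comparing two evaluation runs. A brief sketch of how they fit together, wrapped in a function so the two EvaluationReport inputs stay explicit; the metric name strings are placeholders for whatever metric.name values the configured metrics report.

from themis.evaluation.reports import (
    EvaluationReport,
    ci_for_metric,
    compare_reports_with_holm,
    paired_t_test_for_metric,
)


def compare_runs(report_a: EvaluationReport, report_b: EvaluationReport) -> None:
    """Compare two evaluation runs that share sample_ids."""
    # Bootstrap confidence interval for one run and one metric.
    ci = ci_for_metric(report_a, "exact_match", confidence_level=0.95)
    print(ci)

    # Paired t-test between the runs, aligned by sample_id.
    t = paired_t_test_for_metric(report_a, report_b, "exact_match")
    print(t.difference, t.p_value, t.is_significant)

    # Several metrics at once, with Holm-Bonferroni correction across the tests.
    outcome = compare_reports_with_holm(
        report_a,
        report_b,
        metric_names=["exact_match", "response_length"],
        seed=0,
    )
    print(outcome["per_metric"])        # PermutationTestResult per metric name
    print(outcome["holm_significant"])  # corrected significance decisions
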
themis/evaluation/statistics/__init__.py

@@ -0,0 +1,53 @@
"""Statistical analysis utilities for experiment evaluation results.

This module provides statistical analysis tools for computing confidence intervals,
significance tests, and statistical comparisons across experiment runs.
"""

from __future__ import annotations

from .bootstrap import bootstrap_ci
from .confidence_intervals import (
    compute_confidence_interval,
    compute_statistical_summary,
)
from .effect_sizes import cohens_d, cohens_h
from .hypothesis_tests import (
    compare_metrics,
    holm_bonferroni,
    paired_permutation_test,
    paired_t_test,
    permutation_test,
)
from .types import (
    BootstrapResult,
    ComparisonResult,
    ConfidenceInterval,
    EffectSize,
    PermutationTestResult,
    StatisticalSummary,
)

__all__ = [
    # Types
    "ConfidenceInterval",
    "StatisticalSummary",
    "ComparisonResult",
    "PermutationTestResult",
    "BootstrapResult",
    "EffectSize",
    # Confidence intervals
    "compute_confidence_interval",
    "compute_statistical_summary",
    # Hypothesis tests
    "compare_metrics",
    "permutation_test",
    "paired_permutation_test",
    "paired_t_test",
    "holm_bonferroni",
    # Bootstrap
    "bootstrap_ci",
    # Effect sizes
    "cohens_h",
    "cohens_d",
]
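
These statistics helpers can also be applied directly to raw per-sample score lists, independent of EvaluationReport. A minimal sketch, assuming the functions accept plain Python lists exactly as they are called from themis/evaluation/reports.py above; the score values are illustrative only.

from themis.evaluation.statistics import (
    bootstrap_ci,
    cohens_d,
    holm_bonferroni,
    paired_permutation_test,
)

# Per-sample scores for two runs, aligned by sample order (illustrative values).
scores_a = [1.0, 0.0, 1.0, 1.0, 0.0, 1.0]
scores_b = [1.0, 0.0, 0.0, 1.0, 0.0, 1.0]

ci = bootstrap_ci(scores_a, n_bootstrap=10000, confidence_level=0.95)
effect = cohens_d(scores_a, scores_b)
test = paired_permutation_test(
    scores_a, scores_b, statistic="mean_diff", n_permutations=10000, seed=0
)
# Family-wise error control when several p-values are tested together.
decisions = holm_bonferroni([test.p_value, 0.03, 0.20])
print(ci, effect, test.p_value, decisions)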