themis-eval 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. themis/__init__.py +12 -1
  2. themis/_version.py +2 -2
  3. themis/api.py +343 -0
  4. themis/backends/__init__.py +17 -0
  5. themis/backends/execution.py +197 -0
  6. themis/backends/storage.py +260 -0
  7. themis/cli/__init__.py +5 -0
  8. themis/cli/__main__.py +6 -0
  9. themis/cli/commands/__init__.py +19 -0
  10. themis/cli/commands/benchmarks.py +221 -0
  11. themis/cli/commands/comparison.py +394 -0
  12. themis/cli/commands/config_commands.py +244 -0
  13. themis/cli/commands/cost.py +214 -0
  14. themis/cli/commands/demo.py +68 -0
  15. themis/cli/commands/info.py +90 -0
  16. themis/cli/commands/leaderboard.py +362 -0
  17. themis/cli/commands/math_benchmarks.py +318 -0
  18. themis/cli/commands/mcq_benchmarks.py +207 -0
  19. themis/cli/commands/results.py +252 -0
  20. themis/cli/commands/sample_run.py +244 -0
  21. themis/cli/commands/visualize.py +299 -0
  22. themis/cli/main.py +463 -0
  23. themis/cli/new_project.py +33 -0
  24. themis/cli/utils.py +51 -0
  25. themis/comparison/__init__.py +25 -0
  26. themis/comparison/engine.py +348 -0
  27. themis/comparison/reports.py +283 -0
  28. themis/comparison/statistics.py +402 -0
  29. themis/config/__init__.py +19 -0
  30. themis/config/loader.py +27 -0
  31. themis/config/registry.py +34 -0
  32. themis/config/runtime.py +214 -0
  33. themis/config/schema.py +112 -0
  34. themis/core/__init__.py +5 -0
  35. themis/core/conversation.py +354 -0
  36. themis/core/entities.py +184 -0
  37. themis/core/serialization.py +231 -0
  38. themis/core/tools.py +393 -0
  39. themis/core/types.py +141 -0
  40. themis/datasets/__init__.py +273 -0
  41. themis/datasets/base.py +264 -0
  42. themis/datasets/commonsense_qa.py +174 -0
  43. themis/datasets/competition_math.py +265 -0
  44. themis/datasets/coqa.py +133 -0
  45. themis/datasets/gpqa.py +190 -0
  46. themis/datasets/gsm8k.py +123 -0
  47. themis/datasets/gsm_symbolic.py +124 -0
  48. themis/datasets/math500.py +122 -0
  49. themis/datasets/med_qa.py +179 -0
  50. themis/datasets/medmcqa.py +169 -0
  51. themis/datasets/mmlu_pro.py +262 -0
  52. themis/datasets/piqa.py +146 -0
  53. themis/datasets/registry.py +201 -0
  54. themis/datasets/schema.py +245 -0
  55. themis/datasets/sciq.py +150 -0
  56. themis/datasets/social_i_qa.py +151 -0
  57. themis/datasets/super_gpqa.py +263 -0
  58. themis/evaluation/__init__.py +1 -0
  59. themis/evaluation/conditional.py +410 -0
  60. themis/evaluation/extractors/__init__.py +19 -0
  61. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  62. themis/evaluation/extractors/exceptions.py +7 -0
  63. themis/evaluation/extractors/identity_extractor.py +29 -0
  64. themis/evaluation/extractors/json_field_extractor.py +45 -0
  65. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  66. themis/evaluation/extractors/regex_extractor.py +43 -0
  67. themis/evaluation/math_verify_utils.py +87 -0
  68. themis/evaluation/metrics/__init__.py +21 -0
  69. themis/evaluation/metrics/code/__init__.py +19 -0
  70. themis/evaluation/metrics/code/codebleu.py +144 -0
  71. themis/evaluation/metrics/code/execution.py +280 -0
  72. themis/evaluation/metrics/code/pass_at_k.py +181 -0
  73. themis/evaluation/metrics/composite_metric.py +47 -0
  74. themis/evaluation/metrics/consistency_metric.py +80 -0
  75. themis/evaluation/metrics/exact_match.py +51 -0
  76. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  77. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  78. themis/evaluation/metrics/nlp/__init__.py +21 -0
  79. themis/evaluation/metrics/nlp/bertscore.py +138 -0
  80. themis/evaluation/metrics/nlp/bleu.py +129 -0
  81. themis/evaluation/metrics/nlp/meteor.py +153 -0
  82. themis/evaluation/metrics/nlp/rouge.py +136 -0
  83. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  84. themis/evaluation/metrics/response_length.py +33 -0
  85. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  86. themis/evaluation/pipeline.py +49 -0
  87. themis/evaluation/pipelines/__init__.py +15 -0
  88. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  89. themis/evaluation/pipelines/standard_pipeline.py +348 -0
  90. themis/evaluation/reports.py +293 -0
  91. themis/evaluation/statistics/__init__.py +53 -0
  92. themis/evaluation/statistics/bootstrap.py +79 -0
  93. themis/evaluation/statistics/confidence_intervals.py +121 -0
  94. themis/evaluation/statistics/distributions.py +207 -0
  95. themis/evaluation/statistics/effect_sizes.py +124 -0
  96. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  97. themis/evaluation/statistics/types.py +139 -0
  98. themis/evaluation/strategies/__init__.py +13 -0
  99. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  100. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  101. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  102. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  103. themis/experiment/__init__.py +5 -0
  104. themis/experiment/builder.py +151 -0
  105. themis/experiment/cache_manager.py +134 -0
  106. themis/experiment/comparison.py +631 -0
  107. themis/experiment/cost.py +310 -0
  108. themis/experiment/definitions.py +62 -0
  109. themis/experiment/export.py +798 -0
  110. themis/experiment/export_csv.py +159 -0
  111. themis/experiment/integration_manager.py +104 -0
  112. themis/experiment/math.py +192 -0
  113. themis/experiment/mcq.py +169 -0
  114. themis/experiment/orchestrator.py +415 -0
  115. themis/experiment/pricing.py +317 -0
  116. themis/experiment/storage.py +1458 -0
  117. themis/experiment/visualization.py +588 -0
  118. themis/generation/__init__.py +1 -0
  119. themis/generation/agentic_runner.py +420 -0
  120. themis/generation/batching.py +254 -0
  121. themis/generation/clients.py +143 -0
  122. themis/generation/conversation_runner.py +236 -0
  123. themis/generation/plan.py +456 -0
  124. themis/generation/providers/litellm_provider.py +221 -0
  125. themis/generation/providers/vllm_provider.py +135 -0
  126. themis/generation/router.py +34 -0
  127. themis/generation/runner.py +207 -0
  128. themis/generation/strategies.py +98 -0
  129. themis/generation/templates.py +71 -0
  130. themis/generation/turn_strategies.py +393 -0
  131. themis/generation/types.py +9 -0
  132. themis/integrations/__init__.py +0 -0
  133. themis/integrations/huggingface.py +72 -0
  134. themis/integrations/wandb.py +77 -0
  135. themis/interfaces/__init__.py +169 -0
  136. themis/presets/__init__.py +10 -0
  137. themis/presets/benchmarks.py +354 -0
  138. themis/presets/models.py +190 -0
  139. themis/project/__init__.py +20 -0
  140. themis/project/definitions.py +98 -0
  141. themis/project/patterns.py +230 -0
  142. themis/providers/__init__.py +5 -0
  143. themis/providers/registry.py +39 -0
  144. themis/server/__init__.py +28 -0
  145. themis/server/app.py +337 -0
  146. themis/utils/api_generator.py +379 -0
  147. themis/utils/cost_tracking.py +376 -0
  148. themis/utils/dashboard.py +452 -0
  149. themis/utils/logging_utils.py +41 -0
  150. themis/utils/progress.py +58 -0
  151. themis/utils/tracing.py +320 -0
  152. themis_eval-0.2.0.dist-info/METADATA +596 -0
  153. themis_eval-0.2.0.dist-info/RECORD +157 -0
  154. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
  155. themis_eval-0.1.0.dist-info/METADATA +0 -758
  156. themis_eval-0.1.0.dist-info/RECORD +0 -8
  157. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
  158. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
themis/evaluation/pipelines/standard_pipeline.py (new file)
@@ -0,0 +1,348 @@
+"""Standard evaluation pipeline implementation."""
+
+from __future__ import annotations
+
+import logging
+import time
+import warnings
+from typing import Callable, Sequence
+
+from themis.core import entities as core_entities
+from themis.evaluation import extractors
+from themis.evaluation import strategies as evaluation_strategies
+from themis.evaluation.reports import (
+    EvaluationFailure,
+    EvaluationReport,
+    MetricAggregate,
+)
+from themis.interfaces import Metric as MetricInterface
+from themis.utils import tracing
+
+logger = logging.getLogger(__name__)
+
+
+def _default_reference_selector(record: core_entities.GenerationRecord):
+    """Default reference selector from generation record.
+
+    Args:
+        record: Generation record
+
+    Returns:
+        Reference value or None
+    """
+    reference = record.task.reference
+    if reference is None:
+        return None
+    return reference.value
+
+
+def _normalize_references(reference) -> list:
+    """Normalize reference to list format for metric consumption.
+
+    This function converts various reference formats into a standardized list
+    that metrics can reliably consume. The normalized format is always a list
+    where each element represents one reference value.
+
+    Args:
+        reference: Reference value in various formats:
+            - Reference object: Extracts .value field
+            - dict: Kept as-is in a list (for multi-value references)
+            - list/tuple: Returned as list
+            - scalar: Wrapped in a list
+
+    Returns:
+        List of reference values. Each element can be:
+        - A scalar value (str, int, float, bool)
+        - A dict (for multi-value references like {"target": 122, "numbers": [...]})
+        - Any other type from the original reference
+
+    Examples:
+        >>> _normalize_references(Reference(kind="answer", value="42"))
+        ["42"]
+
+        >>> _normalize_references(Reference(kind="task", value={"target": 122, "numbers": [25, 50]}))
+        [{"target": 122, "numbers": [25, 50]}]
+
+        >>> _normalize_references(["yes", "no", "maybe"])
+        ["yes", "no", "maybe"]
+
+        >>> _normalize_references("42")
+        ["42"]
+
+    Note:
+        Metrics receive references in this normalized format and should handle
+        both simple values and dict values appropriately.
+    """
+    if isinstance(reference, core_entities.Reference):
+        reference = reference.value
+    if isinstance(reference, list):
+        return reference
+    if isinstance(reference, tuple):
+        return list(reference)
+    return [reference]
+
+
+class EvaluationPipeline:
+    """Traditional batch evaluation pipeline.
+
+    This pipeline evaluates generation records using extractors, metrics,
+    and evaluation strategies. It supports slicing for subset analysis.
+
+    Example:
+        >>> pipeline = EvaluationPipeline(
+        ...     extractor=JsonFieldExtractor("answer"),
+        ...     metrics=[ExactMatch()]
+        ... )
+        >>> report = pipeline.evaluate(records)
+
+    Attributes:
+        _extractor: Extractor for parsing model output
+        _metrics: List of metrics to compute
+        _reference_selector: Function to extract reference from record
+        _strategy_resolver: Function to resolve evaluation strategy
+        _slices: List of (name, predicate) tuples for slicing
+    """
+
+    def __init__(
+        self,
+        *,
+        extractor,
+        metrics: Sequence[MetricInterface],
+        reference_selector: Callable[[core_entities.GenerationRecord], object]
+        | None = None,
+        strategy_resolver: Callable[
+            [core_entities.GenerationRecord], evaluation_strategies.EvaluationStrategy
+        ]
+        | None = None,
+    ) -> None:
+        """Initialize evaluation pipeline.
+
+        Args:
+            extractor: Extractor for parsing model output
+            metrics: List of metrics to compute
+            reference_selector: Optional function to extract reference from record.
+                If provided, this takes precedence over item.reference from strategies.
+            strategy_resolver: Optional function to resolve evaluation strategy.
+                If using a custom reference_selector with DefaultEvaluationStrategy,
+                the selector will take precedence.
+
+        Note:
+            When using DefaultEvaluationStrategy with a custom reference_selector,
+            the reference_selector will override the default behavior. Consider
+            using a custom strategy if you need more control over reference selection.
+        """
+        self._extractor = extractor
+        self._metrics = list(metrics)
+        self._reference_selector = reference_selector
+        self._has_custom_reference_selector = reference_selector is not None
+        self._strategy_resolver = strategy_resolver or (
+            lambda record: evaluation_strategies.DefaultEvaluationStrategy()
+        )
+        self._slices: list[
+            tuple[str, Callable[[core_entities.GenerationRecord], bool]]
+        ] = []
+
+        # Validation: warn if custom reference_selector is used with default strategy
+        if self._has_custom_reference_selector and strategy_resolver is None:
+            warnings.warn(
+                "Custom reference_selector provided without custom strategy_resolver. "
+                "The reference_selector will take precedence over DefaultEvaluationStrategy's "
+                "reference handling. If you need more control, consider providing a custom "
+                "strategy_resolver that sets reference=None in EvaluationItem.",
+                UserWarning,
+                stacklevel=2,
+            )
+
+    def evaluate(
+        self, records: Sequence[core_entities.GenerationRecord]
+    ) -> EvaluationReport:
+        """Evaluate generation records.
+
+        Args:
+            records: Generation records to evaluate
+
+        Returns:
+            Evaluation report with metrics and failures
+        """
+        with tracing.span("evaluate_pipeline", total_records=len(records)):
+            per_metric: dict[str, list[core_entities.MetricScore]] = {
+                metric.name: [] for metric in self._metrics
+            }
+            failures: list[EvaluationFailure] = []
+            per_record: list[core_entities.EvaluationRecord] = []
+            slice_members: dict[str, set[str]] = {
+                name: set() for name, _ in self._slices
+            }
+
+            for record in records:
+                with tracing.span("evaluate_record"):
+                    logger.debug(
+                        "Evaluating sample %s with %s metric(s)",
+                        record.task.metadata.get("dataset_id")
+                        or record.task.metadata.get("sample_id"),
+                        len(self._metrics),
+                    )
+                    strategy = self._strategy_resolver(record)
+                    task_metadata = record.task.metadata
+                    sample_id = task_metadata.get("dataset_id") or task_metadata.get(
+                        "sample_id"
+                    )
+                    for name, fn in self._slices:
+                        try:
+                            if fn(record) and sample_id is not None:
+                                slice_members[name].add(sample_id)
+                        except Exception:
+                            pass
+                    eval_items = list(strategy.prepare(record))
+                    item_scores: list[core_entities.MetricScore] = []
+                    record_failures: list[str] = []
+
+                    for item in eval_items:
+                        if item.record.output is None:
+                            message = "Missing model output"
+                            failures.append(
+                                EvaluationFailure(sample_id=sample_id, message=message)
+                            )
+                            record_failures.append(message)
+                            continue
+                        try:
+                            with tracing.span("extract"):
+                                prediction = self._extractor.extract(
+                                    item.record.output.text
+                                )
+                        except extractors.FieldExtractionError as exc:
+                            message = str(exc)
+                            failures.append(
+                                EvaluationFailure(sample_id=sample_id, message=message)
+                            )
+                            record_failures.append(message)
+                            continue
+
+                        # CRITICAL: Always call reference_selector if provided (takes precedence)
+                        # This fixes the issue where DefaultEvaluationStrategy's reference
+                        # would prevent custom reference_selector from being called
+                        if self._has_custom_reference_selector:
+                            reference = self._reference_selector(record)
+                        elif item.reference is not None:
+                            reference = item.reference
+                        else:
+                            reference = _default_reference_selector(record)
+
+                        references = (
+                            _normalize_references(reference)
+                            if reference is not None
+                            else []
+                        )
+                        metadata = {"sample_id": sample_id}
+                        extract_start = time.perf_counter()
+                        item_scores_for_item: list[core_entities.MetricScore] = []
+                        for metric in self._metrics:
+                            requires_reference = getattr(
+                                metric, "requires_reference", True
+                            )
+                            if requires_reference and not references:
+                                message = (
+                                    f"Missing reference for metric '{metric.name}'"
+                                )
+                                failures.append(
+                                    EvaluationFailure(
+                                        sample_id=sample_id, message=message
+                                    )
+                                )
+                                record_failures.append(message)
+                                continue
+                            metric_start = time.perf_counter()
+                            try:
+                                with tracing.span(
+                                    "compute_metric", metric_name=metric.name
+                                ):
+                                    score = metric.compute(
+                                        prediction=prediction,
+                                        references=references,
+                                        metadata=metadata,
+                                    )
+                                score.metadata["evaluation_time_ms"] = (
+                                    time.perf_counter() - metric_start
+                                ) * 1000
+                                item_scores_for_item.append(score)
+                            except Exception as exc:  # pragma: no cover - guarded
+                                message = (
+                                    f"Metric '{metric.name}' failed for sample {sample_id}: {exc}"
+                                )
+                                logger.warning(message)
+                                failures.append(
+                                    EvaluationFailure(
+                                        sample_id=sample_id, message=message
+                                    )
+                                )
+                                record_failures.append(message)
+                        extraction_duration = (
+                            time.perf_counter() - extract_start
+                        ) * 1000
+                        for score in item_scores_for_item:
+                            score.metadata.setdefault(
+                                "extraction_time_ms", extraction_duration
+                            )
+                        item_scores.extend(item_scores_for_item)
+
+                    aggregated_scores = strategy.aggregate(record, item_scores)
+                    for score in aggregated_scores:
+                        per_metric[score.metric_name].append(score)
+                    per_record.append(
+                        core_entities.EvaluationRecord(
+                            sample_id=sample_id,
+                            scores=aggregated_scores,
+                            failures=record_failures,
+                        )
+                    )
+
+            aggregates = {
+                name: MetricAggregate.from_scores(name, scores)
+                for name, scores in per_metric.items()
+            }
+
+            return EvaluationReport(
+                metrics=aggregates,
+                failures=failures,
+                records=per_record,
+                slices=self._compute_slice_aggregates(per_metric, slice_members),
+            )
+
+    def register_slice(
+        self, name: str, fn: Callable[[core_entities.GenerationRecord], bool]
+    ) -> None:
+        """Register a slice for subset analysis.
+
+        Args:
+            name: Slice name
+            fn: Predicate function to determine slice membership
+        """
+        self._slices.append((name, fn))
+
+    def _compute_slice_aggregates(
+        self,
+        per_metric: dict[str, list[core_entities.MetricScore]],
+        slice_members: dict[str, set[str]],
+    ) -> dict[str, dict[str, MetricAggregate]]:
+        """Compute metric aggregates for each slice.
+
+        Args:
+            per_metric: Scores by metric name
+            slice_members: Sample IDs by slice name
+
+        Returns:
+            Nested dict of slice -> metric -> aggregate
+        """
+        if not slice_members:
+            return {}
+        slice_aggregates: dict[str, dict[str, MetricAggregate]] = {}
+        for name, members in slice_members.items():
+            slice_scores_by_metric: dict[str, list[core_entities.MetricScore]] = {}
+            for metric_name, scores in per_metric.items():
+                filtered = [s for s in scores if s.metadata.get("sample_id") in members]
+                slice_scores_by_metric[metric_name] = filtered
+            slice_aggregates[name] = {
+                metric_name: MetricAggregate.from_scores(metric_name, scores)
+                for metric_name, scores in slice_scores_by_metric.items()
+            }
+        return slice_aggregates
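A minimal usage sketch of the EvaluationPipeline added above, wiring it to the extractor and metric named in the class docstring. The import paths for JsonFieldExtractor and ExactMatch are assumed from the new module files in the list at the top of this diff and may differ from the package's actual re-exports; the "gsm8k" dataset_id prefix in the slice predicate is purely illustrative.

# Sketch only: import locations and metric/dataset names are assumptions, not verified API.
from themis.evaluation.extractors.json_field_extractor import JsonFieldExtractor
from themis.evaluation.metrics.exact_match import ExactMatch
from themis.evaluation.pipelines.standard_pipeline import EvaluationPipeline

pipeline = EvaluationPipeline(
    extractor=JsonFieldExtractor("answer"),  # parse the "answer" field from model output
    metrics=[ExactMatch()],                  # score predictions against references
)

# Optional slice: aggregate separately for samples whose dataset_id starts with a prefix.
pipeline.register_slice(
    "gsm8k_samples",
    lambda record: str(record.task.metadata.get("dataset_id", "")).startswith("gsm8k"),
)

report = pipeline.evaluate(records)  # records: Sequence[GenerationRecord] from a generation run
for name, aggregate in report.metrics.items():
    print(name, aggregate.mean, aggregate.count)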
themis/evaluation/reports.py (new file)
@@ -0,0 +1,293 @@
+"""Evaluation report data structures."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from statistics import mean
+from typing import Dict, List, Literal, Sequence
+
+from themis.core import entities as core_entities
+from themis.evaluation.statistics import (
+    bootstrap_ci,
+    cohens_d,
+    cohens_h,
+    holm_bonferroni,
+    paired_permutation_test,
+    paired_t_test,
+    permutation_test,
+)
+from themis.evaluation.statistics.types import (
+    BootstrapResult,
+    ComparisonResult,
+    EffectSize,
+    PermutationTestResult,
+)
+
+
+@dataclass
+class EvaluationFailure:
+    sample_id: str | None
+    message: str
+
+
+@dataclass
+class MetricAggregate:
+    name: str
+    count: int
+    mean: float
+    per_sample: List[core_entities.MetricScore]
+
+    @classmethod
+    def from_scores(
+        cls, name: str, scores: List[core_entities.MetricScore]
+    ) -> "MetricAggregate":
+        if not scores:
+            return cls(name=name, count=0, mean=0.0, per_sample=[])
+        return cls(
+            name=name,
+            count=len(scores),
+            mean=mean(score.value for score in scores),
+            per_sample=scores,
+        )
+
+
+@dataclass
+class EvaluationReport:
+    metrics: dict[str, MetricAggregate]
+    failures: List[EvaluationFailure]
+    records: List[core_entities.EvaluationRecord]
+    slices: dict[str, dict[str, MetricAggregate]] = field(default_factory=dict)
+
+
+def _metric_values(report: EvaluationReport, metric_name: str) -> list[float]:
+    agg = report.metrics.get(metric_name)
+    if not agg:
+        return []
+    return [s.value for s in agg.per_sample]
+
+
+def _metric_values_by_sample(
+    report: EvaluationReport, metric_name: str
+) -> dict[str, float]:
+    values: dict[str, float] = {}
+    for record in report.records:
+        if not record.sample_id:
+            continue
+        for score in record.scores:
+            if score.metric_name == metric_name:
+                values[record.sample_id] = score.value
+                break
+    return values
+
+
+def aligned_metric_values(
+    report_a: EvaluationReport, report_b: EvaluationReport, metric_name: str
+) -> tuple[list[float], list[float]]:
+    values_a = _metric_values_by_sample(report_a, metric_name)
+    values_b = _metric_values_by_sample(report_b, metric_name)
+    common_ids = sorted(set(values_a) & set(values_b))
+    if not common_ids:
+        raise ValueError(f"No overlapping sample_ids for metric '{metric_name}'")
+    aligned_a = [values_a[sample_id] for sample_id in common_ids]
+    aligned_b = [values_b[sample_id] for sample_id in common_ids]
+    return aligned_a, aligned_b
+
+
+def ci_for_metric(
+    report: EvaluationReport,
+    metric_name: str,
+    confidence_level: float = 0.95,
+    n_bootstrap: int = 10000,
+) -> BootstrapResult:
+    values = _metric_values(report, metric_name)
+    if not values:
+        raise ValueError(f"No scores for metric '{metric_name}'")
+    return bootstrap_ci(
+        values, n_bootstrap=n_bootstrap, confidence_level=confidence_level
+    )
+
+
+def permutation_test_for_metric(
+    report_a: EvaluationReport,
+    report_b: EvaluationReport,
+    metric_name: str,
+    statistic: Literal["mean_diff", "median_diff"] = "mean_diff",
+    n_permutations: int = 10000,
+    seed: int | None = None,
+    align_by_sample_id: bool = True,
+) -> PermutationTestResult:
+    if align_by_sample_id:
+        values_a, values_b = aligned_metric_values(report_a, report_b, metric_name)
+    else:
+        values_a = _metric_values(report_a, metric_name)
+        values_b = _metric_values(report_b, metric_name)
+    if not values_a or not values_b:
+        raise ValueError(f"Both reports must have scores for metric '{metric_name}'")
+    return permutation_test(
+        values_a,
+        values_b,
+        statistic=statistic,
+        n_permutations=n_permutations,
+        seed=seed,
+    )
+
+
+def paired_permutation_test_for_metric(
+    report_a: EvaluationReport,
+    report_b: EvaluationReport,
+    metric_name: str,
+    statistic: Literal["mean_diff", "median_diff"] = "mean_diff",
+    n_permutations: int = 10000,
+    seed: int | None = None,
+) -> PermutationTestResult:
+    values_a, values_b = aligned_metric_values(report_a, report_b, metric_name)
+    return paired_permutation_test(
+        values_a,
+        values_b,
+        statistic=statistic,
+        n_permutations=n_permutations,
+        seed=seed,
+    )
+
+
+def cohens_h_for_metric(
+    report_a: EvaluationReport,
+    report_b: EvaluationReport,
+    metric_name: str,
+) -> EffectSize:
+    agg_a = report_a.metrics.get(metric_name)
+    agg_b = report_b.metrics.get(metric_name)
+    if not agg_a or not agg_b:
+        raise ValueError(f"Both reports must have aggregate for metric '{metric_name}'")
+    return cohens_h(agg_a.mean, agg_b.mean)
+
+
+def cohens_d_for_metric(
+    report_a: EvaluationReport,
+    report_b: EvaluationReport,
+    metric_name: str,
+) -> EffectSize:
+    values_a, values_b = aligned_metric_values(report_a, report_b, metric_name)
+    if len(values_a) < 2 or len(values_b) < 2:
+        raise ValueError("Each group must have at least 2 values for Cohen's d")
+    return cohens_d(values_a, values_b)
+
+
+def paired_t_test_for_metric(
+    report_a: EvaluationReport,
+    report_b: EvaluationReport,
+    metric_name: str,
+    significance_level: float = 0.05,
+) -> ComparisonResult:
+    values_a, values_b = aligned_metric_values(report_a, report_b, metric_name)
+    result = paired_t_test(values_a, values_b, significance_level=significance_level)
+    return ComparisonResult(
+        metric_name=metric_name,
+        baseline_mean=result.baseline_mean,
+        treatment_mean=result.treatment_mean,
+        difference=result.difference,
+        relative_change=result.relative_change,
+        t_statistic=result.t_statistic,
+        p_value=result.p_value,
+        is_significant=result.is_significant,
+        baseline_ci=result.baseline_ci,
+        treatment_ci=result.treatment_ci,
+    )
+
+
+def _slice_metric_values(
+    report: EvaluationReport, slice_name: str, metric_name: str
+) -> list[float]:
+    slice_map = report.slices.get(slice_name)
+    if not slice_map:
+        return []
+    agg = slice_map.get(metric_name)
+    if not agg:
+        return []
+    return [s.value for s in agg.per_sample]
+
+
+def ci_for_slice_metric(
+    report: EvaluationReport,
+    slice_name: str,
+    metric_name: str,
+    confidence_level: float = 0.95,
+    n_bootstrap: int = 10000,
+) -> BootstrapResult:
+    values = _slice_metric_values(report, slice_name, metric_name)
+    if not values:
+        raise ValueError(
+            f"No scores for metric '{metric_name}' in slice '{slice_name}'"
+        )
+    return bootstrap_ci(
+        values, n_bootstrap=n_bootstrap, confidence_level=confidence_level
+    )
+
+
+def compare_reports_with_holm(
+    report_a: EvaluationReport,
+    report_b: EvaluationReport,
+    metric_names: Sequence[str],
+    statistic: Literal["mean_diff", "median_diff"] = "mean_diff",
+    n_permutations: int = 10000,
+    seed: int | None = None,
+    paired: bool = True,
+) -> Dict[str, object]:
+    p_values: list[float] = []
+    pt_results: Dict[str, PermutationTestResult] = {}
+    for name in metric_names:
+        if paired:
+            pt = paired_permutation_test_for_metric(
+                report_a,
+                report_b,
+                name,
+                statistic=statistic,
+                n_permutations=n_permutations,
+                seed=seed,
+            )
+        else:
+            pt = permutation_test_for_metric(
+                report_a,
+                report_b,
+                name,
+                statistic=statistic,
+                n_permutations=n_permutations,
+                seed=seed,
+                align_by_sample_id=True,
+            )
+        pt_results[name] = pt
+        p_values.append(pt.p_value)
+    corrected = holm_bonferroni(p_values)
+    return {
+        "per_metric": pt_results,
+        "holm_significant": corrected,
+    }
+
+
+def confusion_matrix(
+    labels_true: Sequence[str], labels_pred: Sequence[str]
+) -> Dict[str, Dict[str, int]]:
+    if len(labels_true) != len(labels_pred):
+        raise ValueError("labels_true and labels_pred must have same length")
+    cm: Dict[str, Dict[str, int]] = {}
+    for t, p in zip(labels_true, labels_pred):
+        cm.setdefault(t, {})
+        cm[t][p] = cm[t].get(p, 0) + 1
+    return cm
+
+
+__all__ = [
+    "EvaluationFailure",
+    "MetricAggregate",
+    "EvaluationReport",
+    "aligned_metric_values",
+    "ci_for_metric",
+    "ci_for_slice_metric",
+    "permutation_test_for_metric",
+    "paired_permutation_test_for_metric",
+    "cohens_h_for_metric",
+    "cohens_d_for_metric",
+    "paired_t_test_for_metric",
+    "confusion_matrix",
+    "compare_reports_with_holm",
+]
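The report-level helpers above compose into a simple comparison workflow. A hedged sketch follows, assuming report_baseline and report_treatment are EvaluationReport objects produced by two pipeline runs; the metric names "exact_match" and "response_length" are illustrative placeholders, not confirmed metric identifiers.

# Sketch: compare two evaluation reports using the new reports.py helpers.
from themis.evaluation.reports import (
    ci_for_metric,
    compare_reports_with_holm,
    paired_permutation_test_for_metric,
)

# Bootstrap confidence interval for a single metric in one report.
ci = ci_for_metric(report_baseline, "exact_match", confidence_level=0.95)

# Paired permutation test on per-sample scores aligned by sample_id.
pt = paired_permutation_test_for_metric(
    report_baseline,
    report_treatment,
    "exact_match",
    n_permutations=10000,
    seed=0,
)
print("paired permutation p-value:", pt.p_value)

# Several metrics at once, with Holm-Bonferroni correction over the family of tests.
summary = compare_reports_with_holm(
    report_baseline,
    report_treatment,
    metric_names=["exact_match", "response_length"],
    paired=True,
)
per_metric_results = summary["per_metric"]    # PermutationTestResult per metric name
holm_flags = summary["holm_significant"]      # output of holm_bonferroni on the p-values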
themis/evaluation/statistics/__init__.py (new file)
@@ -0,0 +1,53 @@
+"""Statistical analysis utilities for experiment evaluation results.
+
+This module provides statistical analysis tools for computing confidence intervals,
+significance tests, and statistical comparisons across experiment runs.
+"""
+
+from __future__ import annotations
+
+from .bootstrap import bootstrap_ci
+from .confidence_intervals import (
+    compute_confidence_interval,
+    compute_statistical_summary,
+)
+from .effect_sizes import cohens_d, cohens_h
+from .hypothesis_tests import (
+    compare_metrics,
+    holm_bonferroni,
+    paired_permutation_test,
+    paired_t_test,
+    permutation_test,
+)
+from .types import (
+    BootstrapResult,
+    ComparisonResult,
+    ConfidenceInterval,
+    EffectSize,
+    PermutationTestResult,
+    StatisticalSummary,
+)
+
+__all__ = [
+    # Types
+    "ConfidenceInterval",
+    "StatisticalSummary",
+    "ComparisonResult",
+    "PermutationTestResult",
+    "BootstrapResult",
+    "EffectSize",
+    # Confidence intervals
+    "compute_confidence_interval",
+    "compute_statistical_summary",
+    # Hypothesis tests
+    "compare_metrics",
+    "permutation_test",
+    "paired_permutation_test",
+    "paired_t_test",
+    "holm_bonferroni",
+    # Bootstrap
+    "bootstrap_ci",
+    # Effect sizes
+    "cohens_h",
+    "cohens_d",
+]
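These re-exports can also be applied directly to raw per-sample score lists, independent of EvaluationReport. A small sketch under stated assumptions: the score lists are illustrative 0/1 exact-match values, and only the p_value attribute of the permutation result is accessed, since that is the field used by reports.py above; other result fields live in .types and are not shown here.

# Sketch only: illustrative inputs, assumed usage of the statistics re-exports above.
from themis.evaluation.statistics import bootstrap_ci, cohens_d, permutation_test

scores_a = [1.0, 0.0, 1.0, 1.0, 0.0, 1.0]  # per-sample scores from run A (made up)
scores_b = [1.0, 1.0, 1.0, 0.0, 1.0, 1.0]  # per-sample scores from run B (made up)

ci = bootstrap_ci(scores_a, n_bootstrap=10000, confidence_level=0.95)  # BootstrapResult
effect = cohens_d(scores_a, scores_b)                                  # EffectSize
test = permutation_test(
    scores_a, scores_b, statistic="mean_diff", n_permutations=10000, seed=0
)
print("permutation p-value:", test.p_value)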