themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (132)
  1. themis/cli/__init__.py +5 -0
  2. themis/cli/__main__.py +6 -0
  3. themis/cli/commands/__init__.py +19 -0
  4. themis/cli/commands/benchmarks.py +221 -0
  5. themis/cli/commands/comparison.py +394 -0
  6. themis/cli/commands/config_commands.py +244 -0
  7. themis/cli/commands/cost.py +214 -0
  8. themis/cli/commands/demo.py +68 -0
  9. themis/cli/commands/info.py +90 -0
  10. themis/cli/commands/leaderboard.py +362 -0
  11. themis/cli/commands/math_benchmarks.py +318 -0
  12. themis/cli/commands/mcq_benchmarks.py +207 -0
  13. themis/cli/commands/sample_run.py +244 -0
  14. themis/cli/commands/visualize.py +299 -0
  15. themis/cli/main.py +93 -0
  16. themis/cli/new_project.py +33 -0
  17. themis/cli/utils.py +51 -0
  18. themis/config/__init__.py +19 -0
  19. themis/config/loader.py +27 -0
  20. themis/config/registry.py +34 -0
  21. themis/config/runtime.py +214 -0
  22. themis/config/schema.py +112 -0
  23. themis/core/__init__.py +5 -0
  24. themis/core/conversation.py +354 -0
  25. themis/core/entities.py +164 -0
  26. themis/core/serialization.py +231 -0
  27. themis/core/tools.py +393 -0
  28. themis/core/types.py +141 -0
  29. themis/datasets/__init__.py +273 -0
  30. themis/datasets/base.py +264 -0
  31. themis/datasets/commonsense_qa.py +174 -0
  32. themis/datasets/competition_math.py +265 -0
  33. themis/datasets/coqa.py +133 -0
  34. themis/datasets/gpqa.py +190 -0
  35. themis/datasets/gsm8k.py +123 -0
  36. themis/datasets/gsm_symbolic.py +124 -0
  37. themis/datasets/math500.py +122 -0
  38. themis/datasets/med_qa.py +179 -0
  39. themis/datasets/medmcqa.py +169 -0
  40. themis/datasets/mmlu_pro.py +262 -0
  41. themis/datasets/piqa.py +146 -0
  42. themis/datasets/registry.py +201 -0
  43. themis/datasets/schema.py +245 -0
  44. themis/datasets/sciq.py +150 -0
  45. themis/datasets/social_i_qa.py +151 -0
  46. themis/datasets/super_gpqa.py +263 -0
  47. themis/evaluation/__init__.py +1 -0
  48. themis/evaluation/conditional.py +410 -0
  49. themis/evaluation/extractors/__init__.py +19 -0
  50. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  51. themis/evaluation/extractors/exceptions.py +7 -0
  52. themis/evaluation/extractors/identity_extractor.py +29 -0
  53. themis/evaluation/extractors/json_field_extractor.py +45 -0
  54. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  55. themis/evaluation/extractors/regex_extractor.py +43 -0
  56. themis/evaluation/math_verify_utils.py +87 -0
  57. themis/evaluation/metrics/__init__.py +21 -0
  58. themis/evaluation/metrics/composite_metric.py +47 -0
  59. themis/evaluation/metrics/consistency_metric.py +80 -0
  60. themis/evaluation/metrics/exact_match.py +51 -0
  61. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  62. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  63. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  64. themis/evaluation/metrics/response_length.py +33 -0
  65. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  66. themis/evaluation/pipeline.py +49 -0
  67. themis/evaluation/pipelines/__init__.py +15 -0
  68. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  69. themis/evaluation/pipelines/standard_pipeline.py +288 -0
  70. themis/evaluation/reports.py +293 -0
  71. themis/evaluation/statistics/__init__.py +53 -0
  72. themis/evaluation/statistics/bootstrap.py +79 -0
  73. themis/evaluation/statistics/confidence_intervals.py +121 -0
  74. themis/evaluation/statistics/distributions.py +207 -0
  75. themis/evaluation/statistics/effect_sizes.py +124 -0
  76. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  77. themis/evaluation/statistics/types.py +139 -0
  78. themis/evaluation/strategies/__init__.py +13 -0
  79. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  80. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  81. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  82. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  83. themis/experiment/__init__.py +5 -0
  84. themis/experiment/builder.py +151 -0
  85. themis/experiment/cache_manager.py +129 -0
  86. themis/experiment/comparison.py +631 -0
  87. themis/experiment/cost.py +310 -0
  88. themis/experiment/definitions.py +62 -0
  89. themis/experiment/export.py +690 -0
  90. themis/experiment/export_csv.py +159 -0
  91. themis/experiment/integration_manager.py +104 -0
  92. themis/experiment/math.py +192 -0
  93. themis/experiment/mcq.py +169 -0
  94. themis/experiment/orchestrator.py +373 -0
  95. themis/experiment/pricing.py +317 -0
  96. themis/experiment/storage.py +255 -0
  97. themis/experiment/visualization.py +588 -0
  98. themis/generation/__init__.py +1 -0
  99. themis/generation/agentic_runner.py +420 -0
  100. themis/generation/batching.py +254 -0
  101. themis/generation/clients.py +143 -0
  102. themis/generation/conversation_runner.py +236 -0
  103. themis/generation/plan.py +456 -0
  104. themis/generation/providers/litellm_provider.py +221 -0
  105. themis/generation/providers/vllm_provider.py +135 -0
  106. themis/generation/router.py +34 -0
  107. themis/generation/runner.py +207 -0
  108. themis/generation/strategies.py +98 -0
  109. themis/generation/templates.py +71 -0
  110. themis/generation/turn_strategies.py +393 -0
  111. themis/generation/types.py +9 -0
  112. themis/integrations/__init__.py +0 -0
  113. themis/integrations/huggingface.py +61 -0
  114. themis/integrations/wandb.py +65 -0
  115. themis/interfaces/__init__.py +83 -0
  116. themis/project/__init__.py +20 -0
  117. themis/project/definitions.py +98 -0
  118. themis/project/patterns.py +230 -0
  119. themis/providers/__init__.py +5 -0
  120. themis/providers/registry.py +39 -0
  121. themis/utils/api_generator.py +379 -0
  122. themis/utils/cost_tracking.py +376 -0
  123. themis/utils/dashboard.py +452 -0
  124. themis/utils/logging_utils.py +41 -0
  125. themis/utils/progress.py +58 -0
  126. themis/utils/tracing.py +320 -0
  127. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
  128. themis_eval-0.1.1.dist-info/RECORD +134 -0
  129. themis_eval-0.1.0.dist-info/RECORD +0 -8
  130. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
  131. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
  132. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
themis/evaluation/pipelines/standard_pipeline.py
@@ -0,0 +1,288 @@
+"""Standard evaluation pipeline implementation."""
+
+from __future__ import annotations
+
+import logging
+import time
+from typing import Callable, Sequence
+
+from themis.core import entities as core_entities
+from themis.evaluation import extractors
+from themis.evaluation import strategies as evaluation_strategies
+from themis.evaluation.reports import (
+    EvaluationFailure,
+    EvaluationReport,
+    MetricAggregate,
+)
+from themis.interfaces import Metric as MetricInterface
+from themis.utils import tracing
+
+logger = logging.getLogger(__name__)
+
+
+def _default_reference_selector(record: core_entities.GenerationRecord):
+    """Default reference selector from generation record.
+
+    Args:
+        record: Generation record
+
+    Returns:
+        Reference value or None
+    """
+    reference = record.task.reference
+    if reference is None:
+        return None
+    return reference.value
+
+
+def _normalize_references(reference):
+    """Normalize reference to list format.
+
+    Args:
+        reference: Reference value
+
+    Returns:
+        List of references
+    """
+    if isinstance(reference, core_entities.Reference):
+        reference = reference.value
+    if isinstance(reference, list):
+        return reference
+    return [reference]
+
+
+class EvaluationPipeline:
+    """Traditional batch evaluation pipeline.
+
+    This pipeline evaluates generation records using extractors, metrics,
+    and evaluation strategies. It supports slicing for subset analysis.
+
+    Example:
+        >>> pipeline = EvaluationPipeline(
+        ...     extractor=JsonFieldExtractor("answer"),
+        ...     metrics=[ExactMatch()]
+        ... )
+        >>> report = pipeline.evaluate(records)
+
+    Attributes:
+        _extractor: Extractor for parsing model output
+        _metrics: List of metrics to compute
+        _reference_selector: Function to extract reference from record
+        _strategy_resolver: Function to resolve evaluation strategy
+        _slices: List of (name, predicate) tuples for slicing
+    """
+
+    def __init__(
+        self,
+        *,
+        extractor,
+        metrics: Sequence[MetricInterface],
+        reference_selector: Callable[[core_entities.GenerationRecord], object]
+        | None = None,
+        strategy_resolver: Callable[
+            [core_entities.GenerationRecord], evaluation_strategies.EvaluationStrategy
+        ]
+        | None = None,
+    ) -> None:
+        """Initialize evaluation pipeline.
+
+        Args:
+            extractor: Extractor for parsing model output
+            metrics: List of metrics to compute
+            reference_selector: Optional function to extract reference
+            strategy_resolver: Optional function to resolve strategy
+        """
+        self._extractor = extractor
+        self._metrics = list(metrics)
+        self._reference_selector = reference_selector or _default_reference_selector
+        self._strategy_resolver = strategy_resolver or (
+            lambda record: evaluation_strategies.DefaultEvaluationStrategy()
+        )
+        self._slices: list[
+            tuple[str, Callable[[core_entities.GenerationRecord], bool]]
+        ] = []
+
+    def evaluate(
+        self, records: Sequence[core_entities.GenerationRecord]
+    ) -> EvaluationReport:
+        """Evaluate generation records.
+
+        Args:
+            records: Generation records to evaluate
+
+        Returns:
+            Evaluation report with metrics and failures
+        """
+        with tracing.span("evaluate_pipeline", total_records=len(records)):
+            per_metric: dict[str, list[core_entities.MetricScore]] = {
+                metric.name: [] for metric in self._metrics
+            }
+            failures: list[EvaluationFailure] = []
+            per_record: list[core_entities.EvaluationRecord] = []
+            slice_members: dict[str, set[str]] = {
+                name: set() for name, _ in self._slices
+            }
+
+            for record in records:
+                with tracing.span("evaluate_record"):
+                    logger.debug(
+                        "Evaluating sample %s with %s metric(s)",
+                        record.task.metadata.get("dataset_id")
+                        or record.task.metadata.get("sample_id"),
+                        len(self._metrics),
+                    )
+                    strategy = self._strategy_resolver(record)
+                    task_metadata = record.task.metadata
+                    sample_id = task_metadata.get("dataset_id") or task_metadata.get(
+                        "sample_id"
+                    )
+                    for name, fn in self._slices:
+                        try:
+                            if fn(record) and sample_id is not None:
+                                slice_members[name].add(sample_id)
+                        except Exception:
+                            pass
+                    eval_items = list(strategy.prepare(record))
+                    item_scores: list[core_entities.MetricScore] = []
+                    record_failures: list[str] = []
+
+                    for item in eval_items:
+                        if item.record.output is None:
+                            message = "Missing model output"
+                            failures.append(
+                                EvaluationFailure(sample_id=sample_id, message=message)
+                            )
+                            record_failures.append(message)
+                            continue
+                        try:
+                            with tracing.span("extract"):
+                                prediction = self._extractor.extract(
+                                    item.record.output.text
+                                )
+                        except extractors.FieldExtractionError as exc:
+                            message = str(exc)
+                            failures.append(
+                                EvaluationFailure(sample_id=sample_id, message=message)
+                            )
+                            record_failures.append(message)
+                            continue
+
+                        reference = item.reference or self._reference_selector(record)
+                        references = (
+                            _normalize_references(reference)
+                            if reference is not None
+                            else []
+                        )
+                        metadata = {"sample_id": sample_id}
+                        extract_start = time.perf_counter()
+                        item_scores_for_item: list[core_entities.MetricScore] = []
+                        for metric in self._metrics:
+                            requires_reference = getattr(
+                                metric, "requires_reference", True
+                            )
+                            if requires_reference and not references:
+                                message = (
+                                    f"Missing reference for metric '{metric.name}'"
+                                )
+                                failures.append(
+                                    EvaluationFailure(
+                                        sample_id=sample_id, message=message
+                                    )
+                                )
+                                record_failures.append(message)
+                                continue
+                            metric_start = time.perf_counter()
+                            try:
+                                with tracing.span(
+                                    "compute_metric", metric_name=metric.name
+                                ):
+                                    score = metric.compute(
+                                        prediction=prediction,
+                                        references=references,
+                                        metadata=metadata,
+                                    )
+                                score.metadata["evaluation_time_ms"] = (
+                                    time.perf_counter() - metric_start
+                                ) * 1000
+                                item_scores_for_item.append(score)
+                            except Exception as exc:  # pragma: no cover - guarded
+                                message = (
+                                    f"Metric '{metric.name}' failed for sample {sample_id}: {exc}"
+                                )
+                                logger.warning(message)
+                                failures.append(
+                                    EvaluationFailure(
+                                        sample_id=sample_id, message=message
+                                    )
+                                )
+                                record_failures.append(message)
+                        extraction_duration = (
+                            time.perf_counter() - extract_start
+                        ) * 1000
+                        for score in item_scores_for_item:
+                            score.metadata.setdefault(
+                                "extraction_time_ms", extraction_duration
+                            )
+                        item_scores.extend(item_scores_for_item)
+
+                    aggregated_scores = strategy.aggregate(record, item_scores)
+                    for score in aggregated_scores:
+                        per_metric[score.metric_name].append(score)
+                    per_record.append(
+                        core_entities.EvaluationRecord(
+                            sample_id=sample_id,
+                            scores=aggregated_scores,
+                            failures=record_failures,
+                        )
+                    )
+
+            aggregates = {
+                name: MetricAggregate.from_scores(name, scores)
+                for name, scores in per_metric.items()
+            }
+
+            return EvaluationReport(
+                metrics=aggregates,
+                failures=failures,
+                records=per_record,
+                slices=self._compute_slice_aggregates(per_metric, slice_members),
+            )
+
+    def register_slice(
+        self, name: str, fn: Callable[[core_entities.GenerationRecord], bool]
+    ) -> None:
+        """Register a slice for subset analysis.

+        Args:
+            name: Slice name
+            fn: Predicate function to determine slice membership
+        """
+        self._slices.append((name, fn))
+
+    def _compute_slice_aggregates(
+        self,
+        per_metric: dict[str, list[core_entities.MetricScore]],
+        slice_members: dict[str, set[str]],
+    ) -> dict[str, dict[str, MetricAggregate]]:
+        """Compute metric aggregates for each slice.
+
+        Args:
+            per_metric: Scores by metric name
+            slice_members: Sample IDs by slice name
+
+        Returns:
+            Nested dict of slice -> metric -> aggregate
+        """
+        if not slice_members:
+            return {}
+        slice_aggregates: dict[str, dict[str, MetricAggregate]] = {}
+        for name, members in slice_members.items():
+            slice_scores_by_metric: dict[str, list[core_entities.MetricScore]] = {}
+            for metric_name, scores in per_metric.items():
+                filtered = [s for s in scores if s.metadata.get("sample_id") in members]
+                slice_scores_by_metric[metric_name] = filtered
+            slice_aggregates[name] = {
+                metric_name: MetricAggregate.from_scores(metric_name, scores)
+                for metric_name, scores in slice_scores_by_metric.items()
+            }
+        return slice_aggregates
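
Not part of the diff: a minimal wiring sketch for the pipeline above. It assumes the JsonFieldExtractor and ExactMatch classes named in the class docstring live in the extractor and metric modules listed in this release (themis/evaluation/extractors/json_field_extractor.py, themis/evaluation/metrics/exact_match.py), and that records is a sequence of GenerationRecord objects produced by a prior generation run; the "hard" slice predicate and its "difficulty" metadata key are purely illustrative.

# Sketch only; import locations and constructor details are assumptions,
# not confirmed by this diff.
from themis.evaluation.extractors.json_field_extractor import JsonFieldExtractor
from themis.evaluation.metrics.exact_match import ExactMatch
from themis.evaluation.pipelines.standard_pipeline import EvaluationPipeline

pipeline = EvaluationPipeline(
    extractor=JsonFieldExtractor("answer"),
    metrics=[ExactMatch()],
)
# Registered slices produce per-subset aggregates under report.slices["hard"].
pipeline.register_slice(
    "hard", lambda record: record.task.metadata.get("difficulty") == "hard"
)

report = pipeline.evaluate(records)  # records: Sequence[core_entities.GenerationRecord]
for name, aggregate in report.metrics.items():
    print(name, aggregate.count, aggregate.mean)
print("failures:", len(report.failures))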
themis/evaluation/reports.py
@@ -0,0 +1,293 @@
+"""Evaluation report data structures."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from statistics import mean
+from typing import Dict, List, Literal, Sequence
+
+from themis.core import entities as core_entities
+from themis.evaluation.statistics import (
+    bootstrap_ci,
+    cohens_d,
+    cohens_h,
+    holm_bonferroni,
+    paired_permutation_test,
+    paired_t_test,
+    permutation_test,
+)
+from themis.evaluation.statistics.types import (
+    BootstrapResult,
+    ComparisonResult,
+    EffectSize,
+    PermutationTestResult,
+)
+
+
+@dataclass
+class EvaluationFailure:
+    sample_id: str | None
+    message: str
+
+
+@dataclass
+class MetricAggregate:
+    name: str
+    count: int
+    mean: float
+    per_sample: List[core_entities.MetricScore]
+
+    @classmethod
+    def from_scores(
+        cls, name: str, scores: List[core_entities.MetricScore]
+    ) -> "MetricAggregate":
+        if not scores:
+            return cls(name=name, count=0, mean=0.0, per_sample=[])
+        return cls(
+            name=name,
+            count=len(scores),
+            mean=mean(score.value for score in scores),
+            per_sample=scores,
+        )
+
+
+@dataclass
+class EvaluationReport:
+    metrics: dict[str, MetricAggregate]
+    failures: List[EvaluationFailure]
+    records: List[core_entities.EvaluationRecord]
+    slices: dict[str, dict[str, MetricAggregate]] = field(default_factory=dict)
+
+
+def _metric_values(report: EvaluationReport, metric_name: str) -> list[float]:
+    agg = report.metrics.get(metric_name)
+    if not agg:
+        return []
+    return [s.value for s in agg.per_sample]
+
+
+def _metric_values_by_sample(
+    report: EvaluationReport, metric_name: str
+) -> dict[str, float]:
+    values: dict[str, float] = {}
+    for record in report.records:
+        if not record.sample_id:
+            continue
+        for score in record.scores:
+            if score.metric_name == metric_name:
+                values[record.sample_id] = score.value
+                break
+    return values
+
+
+def aligned_metric_values(
+    report_a: EvaluationReport, report_b: EvaluationReport, metric_name: str
+) -> tuple[list[float], list[float]]:
+    values_a = _metric_values_by_sample(report_a, metric_name)
+    values_b = _metric_values_by_sample(report_b, metric_name)
+    common_ids = sorted(set(values_a) & set(values_b))
+    if not common_ids:
+        raise ValueError(f"No overlapping sample_ids for metric '{metric_name}'")
+    aligned_a = [values_a[sample_id] for sample_id in common_ids]
+    aligned_b = [values_b[sample_id] for sample_id in common_ids]
+    return aligned_a, aligned_b
+
+
+def ci_for_metric(
+    report: EvaluationReport,
+    metric_name: str,
+    confidence_level: float = 0.95,
+    n_bootstrap: int = 10000,
+) -> BootstrapResult:
+    values = _metric_values(report, metric_name)
+    if not values:
+        raise ValueError(f"No scores for metric '{metric_name}'")
+    return bootstrap_ci(
+        values, n_bootstrap=n_bootstrap, confidence_level=confidence_level
+    )
+
+
+def permutation_test_for_metric(
+    report_a: EvaluationReport,
+    report_b: EvaluationReport,
+    metric_name: str,
+    statistic: Literal["mean_diff", "median_diff"] = "mean_diff",
+    n_permutations: int = 10000,
+    seed: int | None = None,
+    align_by_sample_id: bool = True,
+) -> PermutationTestResult:
+    if align_by_sample_id:
+        values_a, values_b = aligned_metric_values(report_a, report_b, metric_name)
+    else:
+        values_a = _metric_values(report_a, metric_name)
+        values_b = _metric_values(report_b, metric_name)
+    if not values_a or not values_b:
+        raise ValueError(f"Both reports must have scores for metric '{metric_name}'")
+    return permutation_test(
+        values_a,
+        values_b,
+        statistic=statistic,
+        n_permutations=n_permutations,
+        seed=seed,
+    )
+
+
+def paired_permutation_test_for_metric(
+    report_a: EvaluationReport,
+    report_b: EvaluationReport,
+    metric_name: str,
+    statistic: Literal["mean_diff", "median_diff"] = "mean_diff",
+    n_permutations: int = 10000,
+    seed: int | None = None,
+) -> PermutationTestResult:
+    values_a, values_b = aligned_metric_values(report_a, report_b, metric_name)
+    return paired_permutation_test(
+        values_a,
+        values_b,
+        statistic=statistic,
+        n_permutations=n_permutations,
+        seed=seed,
+    )
+
+
+def cohens_h_for_metric(
+    report_a: EvaluationReport,
+    report_b: EvaluationReport,
+    metric_name: str,
+) -> EffectSize:
+    agg_a = report_a.metrics.get(metric_name)
+    agg_b = report_b.metrics.get(metric_name)
+    if not agg_a or not agg_b:
+        raise ValueError(f"Both reports must have aggregate for metric '{metric_name}'")
+    return cohens_h(agg_a.mean, agg_b.mean)
+
+
+def cohens_d_for_metric(
+    report_a: EvaluationReport,
+    report_b: EvaluationReport,
+    metric_name: str,
+) -> EffectSize:
+    values_a, values_b = aligned_metric_values(report_a, report_b, metric_name)
+    if len(values_a) < 2 or len(values_b) < 2:
+        raise ValueError("Each group must have at least 2 values for Cohen's d")
+    return cohens_d(values_a, values_b)
+
+
+def paired_t_test_for_metric(
+    report_a: EvaluationReport,
+    report_b: EvaluationReport,
+    metric_name: str,
+    significance_level: float = 0.05,
+) -> ComparisonResult:
+    values_a, values_b = aligned_metric_values(report_a, report_b, metric_name)
+    result = paired_t_test(values_a, values_b, significance_level=significance_level)
+    return ComparisonResult(
+        metric_name=metric_name,
+        baseline_mean=result.baseline_mean,
+        treatment_mean=result.treatment_mean,
+        difference=result.difference,
+        relative_change=result.relative_change,
+        t_statistic=result.t_statistic,
+        p_value=result.p_value,
+        is_significant=result.is_significant,
+        baseline_ci=result.baseline_ci,
+        treatment_ci=result.treatment_ci,
+    )
+
+
+def _slice_metric_values(
+    report: EvaluationReport, slice_name: str, metric_name: str
+) -> list[float]:
+    slice_map = report.slices.get(slice_name)
+    if not slice_map:
+        return []
+    agg = slice_map.get(metric_name)
+    if not agg:
+        return []
+    return [s.value for s in agg.per_sample]
+
+
+def ci_for_slice_metric(
+    report: EvaluationReport,
+    slice_name: str,
+    metric_name: str,
+    confidence_level: float = 0.95,
+    n_bootstrap: int = 10000,
+) -> BootstrapResult:
+    values = _slice_metric_values(report, slice_name, metric_name)
+    if not values:
+        raise ValueError(
+            f"No scores for metric '{metric_name}' in slice '{slice_name}'"
+        )
+    return bootstrap_ci(
+        values, n_bootstrap=n_bootstrap, confidence_level=confidence_level
+    )
+
+
+def compare_reports_with_holm(
+    report_a: EvaluationReport,
+    report_b: EvaluationReport,
+    metric_names: Sequence[str],
+    statistic: Literal["mean_diff", "median_diff"] = "mean_diff",
+    n_permutations: int = 10000,
+    seed: int | None = None,
+    paired: bool = True,
+) -> Dict[str, object]:
+    p_values: list[float] = []
+    pt_results: Dict[str, PermutationTestResult] = {}
+    for name in metric_names:
+        if paired:
+            pt = paired_permutation_test_for_metric(
+                report_a,
+                report_b,
+                name,
+                statistic=statistic,
+                n_permutations=n_permutations,
+                seed=seed,
+            )
+        else:
+            pt = permutation_test_for_metric(
+                report_a,
+                report_b,
+                name,
+                statistic=statistic,
+                n_permutations=n_permutations,
+                seed=seed,
+                align_by_sample_id=True,
+            )
+        pt_results[name] = pt
+        p_values.append(pt.p_value)
+    corrected = holm_bonferroni(p_values)
+    return {
+        "per_metric": pt_results,
+        "holm_significant": corrected,
+    }
+
+
+def confusion_matrix(
+    labels_true: Sequence[str], labels_pred: Sequence[str]
+) -> Dict[str, Dict[str, int]]:
+    if len(labels_true) != len(labels_pred):
+        raise ValueError("labels_true and labels_pred must have same length")
+    cm: Dict[str, Dict[str, int]] = {}
+    for t, p in zip(labels_true, labels_pred):
+        cm.setdefault(t, {})
+        cm[t][p] = cm[t].get(p, 0) + 1
+    return cm
+
+
+__all__ = [
+    "EvaluationFailure",
+    "MetricAggregate",
+    "EvaluationReport",
+    "aligned_metric_values",
+    "ci_for_metric",
+    "ci_for_slice_metric",
+    "permutation_test_for_metric",
+    "paired_permutation_test_for_metric",
+    "cohens_h_for_metric",
+    "cohens_d_for_metric",
+    "paired_t_test_for_metric",
+    "confusion_matrix",
+    "compare_reports_with_holm",
+]
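
Not part of the diff: a short comparison sketch using the helpers defined above. report_a and report_b stand for EvaluationReport instances from two runs over the same samples, the metric names passed to compare_reports_with_holm are placeholders, and the shape of the "holm_significant" value depends on holm_bonferroni in hypothesis_tests.py, which this excerpt does not show.

# Sketch only; report_a / report_b and the metric names are placeholders.
from themis.evaluation.reports import compare_reports_with_holm, confusion_matrix

comparison = compare_reports_with_holm(
    report_a,
    report_b,
    metric_names=["exact_match", "response_length"],
    n_permutations=10000,
    seed=0,
    paired=True,  # align by sample_id and use the paired permutation test
)
for name, pt in comparison["per_metric"].items():
    print(name, pt.p_value)

# confusion_matrix is a standalone nested-dict tally:
confusion_matrix(["A", "B", "A"], ["A", "A", "A"])
# -> {"A": {"A": 2}, "B": {"A": 1}}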
themis/evaluation/statistics/__init__.py
@@ -0,0 +1,53 @@
+"""Statistical analysis utilities for experiment evaluation results.
+
+This module provides statistical analysis tools for computing confidence intervals,
+significance tests, and statistical comparisons across experiment runs.
+"""
+
+from __future__ import annotations
+
+from .bootstrap import bootstrap_ci
+from .confidence_intervals import (
+    compute_confidence_interval,
+    compute_statistical_summary,
+)
+from .effect_sizes import cohens_d, cohens_h
+from .hypothesis_tests import (
+    compare_metrics,
+    holm_bonferroni,
+    paired_permutation_test,
+    paired_t_test,
+    permutation_test,
+)
+from .types import (
+    BootstrapResult,
+    ComparisonResult,
+    ConfidenceInterval,
+    EffectSize,
+    PermutationTestResult,
+    StatisticalSummary,
+)
+
+__all__ = [
+    # Types
+    "ConfidenceInterval",
+    "StatisticalSummary",
+    "ComparisonResult",
+    "PermutationTestResult",
+    "BootstrapResult",
+    "EffectSize",
+    # Confidence intervals
+    "compute_confidence_interval",
+    "compute_statistical_summary",
+    # Hypothesis tests
+    "compare_metrics",
+    "permutation_test",
+    "paired_permutation_test",
+    "paired_t_test",
+    "holm_bonferroni",
+    # Bootstrap
+    "bootstrap_ci",
+    # Effect sizes
+    "cohens_h",
+    "cohens_d",
+]
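
Not part of the diff: the package __init__ above re-exports the statistics helpers, so downstream code can import them from themis.evaluation.statistics directly. A brief sketch follows; the keyword arguments mirror the calls made in reports.py, and the fields of the returned BootstrapResult are defined in types.py, which this excerpt does not show.

# Sketch only; return-value fields are not shown in this diff.
from themis.evaluation.statistics import bootstrap_ci, holm_bonferroni

per_sample_accuracy = [1.0, 0.0, 1.0, 1.0, 0.0, 1.0]  # toy per-sample scores
ci = bootstrap_ci(per_sample_accuracy, n_bootstrap=10000, confidence_level=0.95)

# Holm-Bonferroni correction over per-metric p-values, as used by
# compare_reports_with_holm in reports.py.
corrected = holm_bonferroni([0.01, 0.04, 0.20])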