themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +343 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +463 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +184 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +348 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +134 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +798 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +415 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +1458 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +72 -0
- themis/integrations/wandb.py +77 -0
- themis/interfaces/__init__.py +169 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- themis_eval-0.2.0.dist-info/METADATA +596 -0
- themis_eval-0.2.0.dist-info/RECORD +157 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
- themis_eval-0.1.0.dist-info/METADATA +0 -758
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0

themis/evaluation/statistics/hypothesis_tests.py
@@ -0,0 +1,305 @@
"""Hypothesis testing functions."""

from __future__ import annotations

import math
import random
from statistics import mean, stdev
from typing import List, Literal, Sequence

from themis.core import entities as core_entities

from .confidence_intervals import compute_confidence_interval
from .distributions import t_to_p_value
from .types import ComparisonResult, PermutationTestResult


def compare_metrics(
    baseline_scores: List[core_entities.MetricScore],
    treatment_scores: List[core_entities.MetricScore],
    significance_level: float = 0.05,
) -> ComparisonResult:
    """Perform two-sample t-test to compare baseline vs treatment metrics.

    Args:
        baseline_scores: Metric scores from baseline/control group
        treatment_scores: Metric scores from treatment group
        significance_level: Threshold for statistical significance (default: 0.05)

    Returns:
        ComparisonResult with comparison statistics

    Raises:
        ValueError: If either scores list is empty or metric names don't match
    """
    if not baseline_scores or not treatment_scores:
        raise ValueError("Both baseline and treatment scores must be non-empty")

    baseline_name = baseline_scores[0].metric_name
    treatment_name = treatment_scores[0].metric_name
    if baseline_name != treatment_name:
        raise ValueError(
            f"Metric names must match: baseline='{baseline_name}', "
            f"treatment='{treatment_name}'"
        )

    baseline_values = [score.value for score in baseline_scores]
    treatment_values = [score.value for score in treatment_scores]

    n1 = len(baseline_values)
    n2 = len(treatment_values)
    mean1 = mean(baseline_values)
    mean2 = mean(treatment_values)

    # Compute standard deviations
    std1 = stdev(baseline_values) if n1 >= 2 else 0.0
    std2 = stdev(treatment_values) if n2 >= 2 else 0.0

    # Two-sample t-test (Welch's t-test for unequal variances)
    if std1 == 0.0 and std2 == 0.0:
        # Both groups have no variance
        t_stat = 0.0 if mean1 == mean2 else float("inf")
        p_value = 1.0 if mean1 == mean2 else 0.0
    else:
        pooled_se = math.sqrt((std1**2) / n1 + (std2**2) / n2)
        if pooled_se == 0.0:
            t_stat = 0.0
            p_value = 1.0
        else:
            t_stat = (mean2 - mean1) / pooled_se
            # Degrees of freedom (Welch-Satterthwaite approximation)
            if std1 > 0 and std2 > 0:
                df = ((std1**2 / n1 + std2**2 / n2) ** 2) / (
                    (std1**2 / n1) ** 2 / (n1 - 1) + (std2**2 / n2) ** 2 / (n2 - 1)
                )
            else:
                df = max(n1, n2) - 1
            # Approximate p-value using t-distribution
            p_value = t_to_p_value(abs(t_stat), int(df))

    difference = mean2 - mean1
    relative_change = (difference / mean1 * 100.0) if mean1 != 0 else float("inf")

    # Compute confidence intervals
    baseline_ci = compute_confidence_interval(baseline_values)
    treatment_ci = compute_confidence_interval(treatment_values)

    return ComparisonResult(
        metric_name=baseline_name,
        baseline_mean=mean1,
        treatment_mean=mean2,
        difference=difference,
        relative_change=relative_change,
        t_statistic=t_stat,
        p_value=p_value,
        is_significant=p_value < significance_level,
        baseline_ci=baseline_ci,
        treatment_ci=treatment_ci,
    )


def permutation_test(
    group_a: Sequence[float],
    group_b: Sequence[float],
    statistic: Literal["mean_diff", "median_diff"] = "mean_diff",
    n_permutations: int = 10000,
    seed: int | None = None,
) -> PermutationTestResult:
    """Perform permutation test to compare two groups.

    This non-parametric test does not assume normality and is robust
    to outliers and skewed distributions.

    Args:
        group_a: Values from first group
        group_b: Values from second group
        statistic: Test statistic to use ("mean_diff" or "median_diff")
        n_permutations: Number of permutation iterations (default: 10000)
        seed: Random seed for reproducibility

    Returns:
        PermutationTestResult with p-value and statistics

    Raises:
        ValueError: If either group is empty
    """
    if not group_a or not group_b:
        raise ValueError("Both groups must be non-empty")

    rng = random.Random(seed)

    # Compute observed statistic
    def compute_stat(a: Sequence[float], b: Sequence[float]) -> float:
        if statistic == "mean_diff":
            return mean(b) - mean(a)
        elif statistic == "median_diff":
            import statistics

            return statistics.median(b) - statistics.median(a)
        else:
            raise ValueError(f"Unknown statistic: {statistic}")

    observed = compute_stat(group_a, group_b)

    # Combine all values for permutation
    combined = list(group_a) + list(group_b)
    n_a = len(group_a)

    # Permutation iterations
    count_extreme = 0
    for _ in range(n_permutations):
        # Shuffle and split into two groups
        rng.shuffle(combined)
        perm_a = combined[:n_a]
        perm_b = combined[n_a:]

        # Compute permuted statistic
        perm_stat = compute_stat(perm_a, perm_b)

        # Two-tailed test: count if |perm_stat| >= |observed|
        if abs(perm_stat) >= abs(observed):
            count_extreme += 1

    # +1 correction avoids zero p-values with finite permutations
    p_value = (count_extreme + 1) / (n_permutations + 1)

    return PermutationTestResult(
        observed_statistic=observed,
        p_value=p_value,
        n_permutations=n_permutations,
        is_significant=p_value < 0.05,
    )


def paired_permutation_test(
    group_a: Sequence[float],
    group_b: Sequence[float],
    statistic: Literal["mean_diff", "median_diff"] = "mean_diff",
    n_permutations: int = 10000,
    seed: int | None = None,
) -> PermutationTestResult:
    """Perform paired permutation test using sign flips on paired differences."""
    if len(group_a) != len(group_b):
        raise ValueError("Paired test requires equal-length groups")
    if not group_a:
        raise ValueError("Paired test requires non-empty groups")

    rng = random.Random(seed)
    diffs = [b - a for a, b in zip(group_a, group_b)]

    def compute_stat(values: Sequence[float]) -> float:
        if statistic == "mean_diff":
            return mean(values)
        elif statistic == "median_diff":
            import statistics

            return statistics.median(values)
        else:
            raise ValueError(f"Unknown statistic: {statistic}")

    observed = compute_stat(diffs)
    count_extreme = 0
    for _ in range(n_permutations):
        flipped = [d if rng.random() < 0.5 else -d for d in diffs]
        perm_stat = compute_stat(flipped)
        if abs(perm_stat) >= abs(observed):
            count_extreme += 1

    p_value = (count_extreme + 1) / (n_permutations + 1)

    return PermutationTestResult(
        observed_statistic=observed,
        p_value=p_value,
        n_permutations=n_permutations,
        is_significant=p_value < 0.05,
    )


def paired_t_test(
    group_a: Sequence[float],
    group_b: Sequence[float],
    significance_level: float = 0.05,
) -> ComparisonResult:
    """Perform paired t-test on matched samples."""
    if len(group_a) != len(group_b):
        raise ValueError("Paired t-test requires equal-length groups")
    if not group_a:
        raise ValueError("Paired t-test requires non-empty groups")

    diffs = [b - a for a, b in zip(group_a, group_b)]
    n = len(diffs)
    mean_diff = mean(diffs)
    std_diff = stdev(diffs) if n >= 2 else 0.0

    if std_diff == 0.0:
        t_stat = 0.0 if mean_diff == 0.0 else float("inf")
        p_value = 1.0 if mean_diff == 0.0 else 0.0
    else:
        se = std_diff / math.sqrt(n)
        t_stat = mean_diff / se if se > 0 else 0.0
        p_value = t_to_p_value(abs(t_stat), n - 1)

    baseline_mean = mean(group_a)
    treatment_mean = mean(group_b)
    difference = treatment_mean - baseline_mean
    relative_change = (difference / baseline_mean * 100.0) if baseline_mean != 0 else float("inf")

    baseline_ci = compute_confidence_interval(group_a)
    treatment_ci = compute_confidence_interval(group_b)

    return ComparisonResult(
        metric_name="paired",
        baseline_mean=baseline_mean,
        treatment_mean=treatment_mean,
        difference=difference,
        relative_change=relative_change,
        t_statistic=t_stat,
        p_value=p_value,
        is_significant=p_value < significance_level,
        baseline_ci=baseline_ci,
        treatment_ci=treatment_ci,
    )


def holm_bonferroni(p_values: Sequence[float]) -> List[bool]:
    """Apply Holm-Bonferroni correction for multiple comparisons.

    This method controls the family-wise error rate (FWER) while being
    more powerful than the simple Bonferroni correction.

    Args:
        p_values: List of p-values from multiple tests

    Returns:
        List of boolean values indicating which tests remain significant
        after correction (True = significant, False = not significant)

    Example:
        >>> p_vals = [0.01, 0.04, 0.03, 0.20]
        >>> significant = holm_bonferroni(p_vals)
        >>> # Returns which tests are significant after correction
    """
    if not p_values:
        return []

    n = len(p_values)

    # Create (p-value, original_index) pairs and sort by p-value
    indexed_pvals = [(p, i) for i, p in enumerate(p_values)]
    indexed_pvals.sort(key=lambda x: x[0])

    # Apply Holm-Bonferroni sequential rejection
    results = [False] * n
    alpha = 0.05  # Standard significance level

    for rank, (p_val, orig_idx) in enumerate(indexed_pvals):
        # Adjusted threshold: alpha / (n - rank)
        threshold = alpha / (n - rank)

        if p_val < threshold:
            results[orig_idx] = True
        else:
            # Once we fail to reject, all subsequent tests also fail
            break

    return results
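
For orientation, a minimal usage sketch of the hypothesis-testing helpers added above (the sample scores are invented; the import path follows the file layout in the listing):

# Sketch: exercising the new hypothesis-testing helpers on made-up accuracy scores.
from themis.evaluation.statistics.hypothesis_tests import (
    holm_bonferroni,
    paired_t_test,
    permutation_test,
)

baseline = [0.62, 0.58, 0.71, 0.65, 0.60, 0.67]   # per-sample scores, model A
treatment = [0.70, 0.66, 0.74, 0.69, 0.68, 0.72]  # per-sample scores, model B

# Non-parametric comparison; the seed makes the shuffles reproducible.
perm = permutation_test(baseline, treatment, statistic="mean_diff",
                        n_permutations=5000, seed=0)
print(perm.observed_statistic, perm.p_value, perm.is_significant)

# Paired t-test when the two lists score the same samples.
paired = paired_t_test(baseline, treatment)
print(paired.difference, paired.t_statistic, paired.p_value)

# Holm-Bonferroni over several per-metric p-values.
print(holm_bonferroni([perm.p_value, paired.p_value, 0.20]))

The return values are the PermutationTestResult and ComparisonResult dataclasses defined in themis/evaluation/statistics/types.py (next hunk).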

themis/evaluation/statistics/types.py
@@ -0,0 +1,139 @@
"""Statistical result types and dataclasses."""

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ConfidenceInterval:
    """Confidence interval for a metric.

    Attributes:
        mean: Sample mean of the metric
        lower: Lower bound of the confidence interval
        upper: Upper bound of the confidence interval
        confidence_level: Confidence level (e.g., 0.95 for 95%)
        sample_size: Number of samples used
    """

    mean: float
    lower: float
    upper: float
    confidence_level: float
    sample_size: int

    @property
    def margin_of_error(self) -> float:
        """Return the margin of error (half-width of the interval)."""
        return (self.upper - self.lower) / 2.0

    @property
    def width(self) -> float:
        """Return the width of the confidence interval."""
        return self.upper - self.lower


@dataclass
class StatisticalSummary:
    """Statistical summary for a set of metric scores.

    Attributes:
        metric_name: Name of the metric
        count: Number of samples
        mean: Sample mean
        std: Sample standard deviation
        min_value: Minimum value
        max_value: Maximum value
        median: Median value
        confidence_interval_95: 95% confidence interval for the mean
    """

    metric_name: str
    count: int
    mean: float
    std: float
    min_value: float
    max_value: float
    median: float
    confidence_interval_95: ConfidenceInterval | None


@dataclass
class ComparisonResult:
    """Result of a statistical comparison between two metric sets.

    Attributes:
        metric_name: Name of the metric being compared
        baseline_mean: Mean of the baseline (control) group
        treatment_mean: Mean of the treatment group
        difference: Difference between treatment and baseline means
        relative_change: Relative change as a percentage
        t_statistic: t-test statistic
        p_value: p-value for the two-sample t-test
        is_significant: Whether the difference is statistically significant (p < 0.05)
        baseline_ci: 95% confidence interval for baseline mean
        treatment_ci: 95% confidence interval for treatment mean
    """

    metric_name: str
    baseline_mean: float
    treatment_mean: float
    difference: float
    relative_change: float
    t_statistic: float
    p_value: float
    is_significant: bool
    baseline_ci: ConfidenceInterval
    treatment_ci: ConfidenceInterval


@dataclass
class PermutationTestResult:
    """Result of a permutation test.

    Attributes:
        observed_statistic: Observed test statistic
        p_value: Permutation test p-value
        n_permutations: Number of permutations performed
        is_significant: Whether result is significant at alpha=0.05
    """

    observed_statistic: float
    p_value: float
    n_permutations: int
    is_significant: bool


@dataclass
class BootstrapResult:
    """Result of bootstrap resampling.

    Attributes:
        statistic: Point estimate of the statistic
        ci_lower: Lower bound of bootstrap CI
        ci_upper: Upper bound of bootstrap CI
        confidence_level: Confidence level used
        n_bootstrap: Number of bootstrap iterations
    """

    statistic: float
    ci_lower: float
    ci_upper: float
    confidence_level: float
    n_bootstrap: int


@dataclass
class EffectSize:
    """Effect size measure.

    Attributes:
        name: Name of effect size measure (e.g., "cohen_h", "cohen_d")
        value: Effect size value
        interpretation: Text interpretation (e.g., "small", "medium", "large")
    """

    name: str
    value: float
    interpretation: str
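
These are plain data containers; for instance, ConfidenceInterval derives its width and margin of error from the stored bounds (the numbers below are invented):

from themis.evaluation.statistics.types import ConfidenceInterval

# A 95% interval for an accuracy estimate over 200 samples (illustrative values).
ci = ConfidenceInterval(mean=0.68, lower=0.63, upper=0.73,
                        confidence_level=0.95, sample_size=200)
print(ci.margin_of_error)  # half-width, (upper - lower) / 2, ~0.05
print(ci.width)            # full width, upper - lower, ~0.10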

themis/evaluation/strategies/__init__.py
@@ -0,0 +1,13 @@
from __future__ import annotations

from .attempt_aware_evaluation_strategy import AttemptAwareEvaluationStrategy
from .default_evaluation_strategy import DefaultEvaluationStrategy
from .evaluation_strategy import EvaluationStrategy
from .judge_evaluation_strategy import JudgeEvaluationStrategy

__all__ = [
    "EvaluationStrategy",
    "DefaultEvaluationStrategy",
    "JudgeEvaluationStrategy",
    "AttemptAwareEvaluationStrategy",
]

themis/evaluation/strategies/attempt_aware_evaluation_strategy.py
@@ -0,0 +1,51 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, List

from themis.core import entities as core_entities


@dataclass
class AttemptAwareEvaluationStrategy:
    """Evaluates each generation attempt independently.

    When average_attempts=True, returns a single averaged score per metric.
    """

    average_attempts: bool = True

    def prepare(
        self, record: core_entities.GenerationRecord
    ) -> Iterable[core_entities.EvaluationItem]:
        attempts = record.attempts or [record]
        for attempt in attempts:
            yield core_entities.EvaluationItem(
                record=attempt, reference=attempt.task.reference
            )

    def aggregate(
        self,
        record: core_entities.GenerationRecord,
        scores: List[core_entities.MetricScore],
    ) -> List[core_entities.MetricScore]:
        if not self.average_attempts or not scores:
            return scores
        aggregated: list[core_entities.MetricScore] = []
        grouped: dict[str, list[core_entities.MetricScore]] = {}
        for score in scores:
            grouped.setdefault(score.metric_name, []).append(score)
        for metric_name, group in grouped.items():
            value = sum(item.value for item in group) / len(group)
            aggregated.append(
                core_entities.MetricScore(
                    metric_name=metric_name,
                    value=value,
                    metadata={
                        "attempts": len(group),
                        "sample_id": group[0].metadata.get("sample_id"),
                    },
                    details={},
                )
            )
        return aggregated

themis/evaluation/strategies/default_evaluation_strategy.py
@@ -0,0 +1,25 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, List

from themis.core import entities as core_entities


@dataclass
class DefaultEvaluationStrategy:
    """Single-item evaluation for exact-match style metrics."""

    def prepare(
        self, record: core_entities.GenerationRecord
    ) -> Iterable[core_entities.EvaluationItem]:
        yield core_entities.EvaluationItem(
            record=record, reference=record.task.reference
        )

    def aggregate(
        self,
        record: core_entities.GenerationRecord,
        scores: List[core_entities.MetricScore],
    ) -> List[core_entities.MetricScore]:
        return scores

themis/evaluation/strategies/evaluation_strategy.py
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import Iterable, List, Protocol

from themis.core import entities as core_entities


class EvaluationStrategy(Protocol):
    """Strategy controlling how evaluation items are constructed and aggregated."""

    def prepare(
        self, record: core_entities.GenerationRecord
    ) -> Iterable[core_entities.EvaluationItem]:  # pragma: no cover - interface
        ...

    def aggregate(
        self,
        record: core_entities.GenerationRecord,
        scores: List[core_entities.MetricScore],
    ) -> List[core_entities.MetricScore]:  # pragma: no cover - interface
        ...


__all__ = ["EvaluationStrategy"]
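
Because EvaluationStrategy is a typing.Protocol, any class with structurally matching prepare and aggregate methods can be plugged in; a hypothetical sketch (the class name and the max-per-metric rule are illustrative, not part of the package):

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, List

from themis.core import entities as core_entities


@dataclass
class MaxScoreEvaluationStrategy:
    """Hypothetical strategy: keep only the highest score per metric name."""

    def prepare(
        self, record: core_entities.GenerationRecord
    ) -> Iterable[core_entities.EvaluationItem]:
        # Same single-item preparation as DefaultEvaluationStrategy above.
        yield core_entities.EvaluationItem(
            record=record, reference=record.task.reference
        )

    def aggregate(
        self,
        record: core_entities.GenerationRecord,
        scores: List[core_entities.MetricScore],
    ) -> List[core_entities.MetricScore]:
        # Retain the best score seen for each metric name.
        best: dict[str, core_entities.MetricScore] = {}
        for score in scores:
            current = best.get(score.metric_name)
            if current is None or score.value > current.value:
                best[score.metric_name] = score
        return list(best.values())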

themis/evaluation/strategies/judge_evaluation_strategy.py
@@ -0,0 +1,64 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, List

from themis.core import entities as core_entities


@dataclass
class JudgeEvaluationStrategy:
    """Aggregate multiple judge metric scores and report agreement.

    This strategy groups incoming MetricScore items by metric_name and returns
    a single aggregated score per metric, including inter-judge agreement.
    It is model-agnostic and works with RubricJudgeMetric and PairwiseJudgeMetric.
    """

    def prepare(
        self, record: core_entities.GenerationRecord
    ) -> Iterable[core_entities.EvaluationItem]:
        yield core_entities.EvaluationItem(
            record=record, reference=record.task.reference
        )

    def aggregate(
        self,
        record: core_entities.GenerationRecord,
        scores: List[core_entities.MetricScore],
    ) -> List[core_entities.MetricScore]:
        if not scores:
            return []
        grouped: dict[str, list[core_entities.MetricScore]] = {}
        for score in scores:
            grouped.setdefault(score.metric_name, []).append(score)

        aggregated: list[core_entities.MetricScore] = []
        for metric_name, group in grouped.items():
            value = sum(item.value for item in group) / max(1, len(group))
            labels: list[str] = []
            for item in group:
                details = item.details or {}
                label = details.get("verdict") or details.get("preference")
                if isinstance(label, str) and label:
                    labels.append(label.lower().strip())
            agreement = 0.0
            if labels:
                from collections import Counter

                counts = Counter(labels)
                agreement = max(counts.values()) / max(1, len(labels))

            aggregated.append(
                core_entities.MetricScore(
                    metric_name=metric_name,
                    value=value,
                    details={
                        "judge_count": len(group),
                        "agreement": agreement,
                        "labels": labels,
                    },
                    metadata={"sample_id": group[0].metadata.get("sample_id")},
                )
            )
        return aggregated
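
A rough sketch of how JudgeEvaluationStrategy.aggregate collapses several judge verdicts into one score per metric with an agreement ratio. The MetricScore keyword arguments mirror the constructions in the hunk above (if the entity has other required fields this would need adjusting), and since aggregate() never touches the record argument, the sketch passes None for it:

from themis.core import entities as core_entities
from themis.evaluation.strategies import JudgeEvaluationStrategy

# Three judges score the same sample on one rubric metric (invented data).
scores = [
    core_entities.MetricScore(
        metric_name="rubric_judge",
        value=value,
        details={"verdict": verdict},
        metadata={"sample_id": "q-1"},
    )
    for value, verdict in [(1.0, "pass"), (1.0, "pass"), (0.0, "fail")]
]

(combined,) = JudgeEvaluationStrategy().aggregate(record=None, scores=scores)
print(combined.value)                   # mean of the three judge scores, ~0.667
print(combined.details["judge_count"])  # 3
print(combined.details["agreement"])    # 2/3 majority share on "pass"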