themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. themis/__init__.py +12 -1
  2. themis/_version.py +2 -2
  3. themis/api.py +343 -0
  4. themis/backends/__init__.py +17 -0
  5. themis/backends/execution.py +197 -0
  6. themis/backends/storage.py +260 -0
  7. themis/cli/__init__.py +5 -0
  8. themis/cli/__main__.py +6 -0
  9. themis/cli/commands/__init__.py +19 -0
  10. themis/cli/commands/benchmarks.py +221 -0
  11. themis/cli/commands/comparison.py +394 -0
  12. themis/cli/commands/config_commands.py +244 -0
  13. themis/cli/commands/cost.py +214 -0
  14. themis/cli/commands/demo.py +68 -0
  15. themis/cli/commands/info.py +90 -0
  16. themis/cli/commands/leaderboard.py +362 -0
  17. themis/cli/commands/math_benchmarks.py +318 -0
  18. themis/cli/commands/mcq_benchmarks.py +207 -0
  19. themis/cli/commands/results.py +252 -0
  20. themis/cli/commands/sample_run.py +244 -0
  21. themis/cli/commands/visualize.py +299 -0
  22. themis/cli/main.py +463 -0
  23. themis/cli/new_project.py +33 -0
  24. themis/cli/utils.py +51 -0
  25. themis/comparison/__init__.py +25 -0
  26. themis/comparison/engine.py +348 -0
  27. themis/comparison/reports.py +283 -0
  28. themis/comparison/statistics.py +402 -0
  29. themis/config/__init__.py +19 -0
  30. themis/config/loader.py +27 -0
  31. themis/config/registry.py +34 -0
  32. themis/config/runtime.py +214 -0
  33. themis/config/schema.py +112 -0
  34. themis/core/__init__.py +5 -0
  35. themis/core/conversation.py +354 -0
  36. themis/core/entities.py +184 -0
  37. themis/core/serialization.py +231 -0
  38. themis/core/tools.py +393 -0
  39. themis/core/types.py +141 -0
  40. themis/datasets/__init__.py +273 -0
  41. themis/datasets/base.py +264 -0
  42. themis/datasets/commonsense_qa.py +174 -0
  43. themis/datasets/competition_math.py +265 -0
  44. themis/datasets/coqa.py +133 -0
  45. themis/datasets/gpqa.py +190 -0
  46. themis/datasets/gsm8k.py +123 -0
  47. themis/datasets/gsm_symbolic.py +124 -0
  48. themis/datasets/math500.py +122 -0
  49. themis/datasets/med_qa.py +179 -0
  50. themis/datasets/medmcqa.py +169 -0
  51. themis/datasets/mmlu_pro.py +262 -0
  52. themis/datasets/piqa.py +146 -0
  53. themis/datasets/registry.py +201 -0
  54. themis/datasets/schema.py +245 -0
  55. themis/datasets/sciq.py +150 -0
  56. themis/datasets/social_i_qa.py +151 -0
  57. themis/datasets/super_gpqa.py +263 -0
  58. themis/evaluation/__init__.py +1 -0
  59. themis/evaluation/conditional.py +410 -0
  60. themis/evaluation/extractors/__init__.py +19 -0
  61. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  62. themis/evaluation/extractors/exceptions.py +7 -0
  63. themis/evaluation/extractors/identity_extractor.py +29 -0
  64. themis/evaluation/extractors/json_field_extractor.py +45 -0
  65. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  66. themis/evaluation/extractors/regex_extractor.py +43 -0
  67. themis/evaluation/math_verify_utils.py +87 -0
  68. themis/evaluation/metrics/__init__.py +21 -0
  69. themis/evaluation/metrics/code/__init__.py +19 -0
  70. themis/evaluation/metrics/code/codebleu.py +144 -0
  71. themis/evaluation/metrics/code/execution.py +280 -0
  72. themis/evaluation/metrics/code/pass_at_k.py +181 -0
  73. themis/evaluation/metrics/composite_metric.py +47 -0
  74. themis/evaluation/metrics/consistency_metric.py +80 -0
  75. themis/evaluation/metrics/exact_match.py +51 -0
  76. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  77. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  78. themis/evaluation/metrics/nlp/__init__.py +21 -0
  79. themis/evaluation/metrics/nlp/bertscore.py +138 -0
  80. themis/evaluation/metrics/nlp/bleu.py +129 -0
  81. themis/evaluation/metrics/nlp/meteor.py +153 -0
  82. themis/evaluation/metrics/nlp/rouge.py +136 -0
  83. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  84. themis/evaluation/metrics/response_length.py +33 -0
  85. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  86. themis/evaluation/pipeline.py +49 -0
  87. themis/evaluation/pipelines/__init__.py +15 -0
  88. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  89. themis/evaluation/pipelines/standard_pipeline.py +348 -0
  90. themis/evaluation/reports.py +293 -0
  91. themis/evaluation/statistics/__init__.py +53 -0
  92. themis/evaluation/statistics/bootstrap.py +79 -0
  93. themis/evaluation/statistics/confidence_intervals.py +121 -0
  94. themis/evaluation/statistics/distributions.py +207 -0
  95. themis/evaluation/statistics/effect_sizes.py +124 -0
  96. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  97. themis/evaluation/statistics/types.py +139 -0
  98. themis/evaluation/strategies/__init__.py +13 -0
  99. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  100. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  101. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  102. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  103. themis/experiment/__init__.py +5 -0
  104. themis/experiment/builder.py +151 -0
  105. themis/experiment/cache_manager.py +134 -0
  106. themis/experiment/comparison.py +631 -0
  107. themis/experiment/cost.py +310 -0
  108. themis/experiment/definitions.py +62 -0
  109. themis/experiment/export.py +798 -0
  110. themis/experiment/export_csv.py +159 -0
  111. themis/experiment/integration_manager.py +104 -0
  112. themis/experiment/math.py +192 -0
  113. themis/experiment/mcq.py +169 -0
  114. themis/experiment/orchestrator.py +415 -0
  115. themis/experiment/pricing.py +317 -0
  116. themis/experiment/storage.py +1458 -0
  117. themis/experiment/visualization.py +588 -0
  118. themis/generation/__init__.py +1 -0
  119. themis/generation/agentic_runner.py +420 -0
  120. themis/generation/batching.py +254 -0
  121. themis/generation/clients.py +143 -0
  122. themis/generation/conversation_runner.py +236 -0
  123. themis/generation/plan.py +456 -0
  124. themis/generation/providers/litellm_provider.py +221 -0
  125. themis/generation/providers/vllm_provider.py +135 -0
  126. themis/generation/router.py +34 -0
  127. themis/generation/runner.py +207 -0
  128. themis/generation/strategies.py +98 -0
  129. themis/generation/templates.py +71 -0
  130. themis/generation/turn_strategies.py +393 -0
  131. themis/generation/types.py +9 -0
  132. themis/integrations/__init__.py +0 -0
  133. themis/integrations/huggingface.py +72 -0
  134. themis/integrations/wandb.py +77 -0
  135. themis/interfaces/__init__.py +169 -0
  136. themis/presets/__init__.py +10 -0
  137. themis/presets/benchmarks.py +354 -0
  138. themis/presets/models.py +190 -0
  139. themis/project/__init__.py +20 -0
  140. themis/project/definitions.py +98 -0
  141. themis/project/patterns.py +230 -0
  142. themis/providers/__init__.py +5 -0
  143. themis/providers/registry.py +39 -0
  144. themis/server/__init__.py +28 -0
  145. themis/server/app.py +337 -0
  146. themis/utils/api_generator.py +379 -0
  147. themis/utils/cost_tracking.py +376 -0
  148. themis/utils/dashboard.py +452 -0
  149. themis/utils/logging_utils.py +41 -0
  150. themis/utils/progress.py +58 -0
  151. themis/utils/tracing.py +320 -0
  152. themis_eval-0.2.0.dist-info/METADATA +596 -0
  153. themis_eval-0.2.0.dist-info/RECORD +157 -0
  154. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
  155. themis_eval-0.1.0.dist-info/METADATA +0 -758
  156. themis_eval-0.1.0.dist-info/RECORD +0 -8
  157. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
  158. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,402 @@
+ """Statistical tests for comparing experiment results.
+
+ This module provides various statistical tests to determine if differences
+ between runs are statistically significant.
+ """
+
+ from __future__ import annotations
+
+ import math
+ import random
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Sequence
+
+
+ class StatisticalTest(str, Enum):
+     """Available statistical tests."""
+
+     T_TEST = "t_test"
+     BOOTSTRAP = "bootstrap"
+     PERMUTATION = "permutation"
+     NONE = "none"
+
+
+ @dataclass
+ class StatisticalTestResult:
+     """Result of a statistical test.
+
+     Attributes:
+         test_name: Name of the test performed
+         statistic: Test statistic value
+         p_value: P-value (probability of observing this difference by chance)
+         significant: Whether the difference is statistically significant
+         confidence_level: Confidence level used (e.g., 0.95 for 95%)
+         effect_size: Effect size (e.g., Cohen's d)
+         confidence_interval: Confidence interval for the difference
+     """
+
+     test_name: str
+     statistic: float
+     p_value: float
+     significant: bool
+     confidence_level: float = 0.95
+     effect_size: float | None = None
+     confidence_interval: tuple[float, float] | None = None
+
+     def __str__(self) -> str:
+         """Human-readable summary."""
+         sig_str = "significant" if self.significant else "not significant"
+         result = f"{self.test_name}: p={self.p_value:.4f} ({sig_str})"
+
+         if self.effect_size is not None:
+             result += f", effect_size={self.effect_size:.3f}"
+
+         if self.confidence_interval is not None:
+             low, high = self.confidence_interval
+             result += f", CI=[{low:.3f}, {high:.3f}]"
+
+         return result
+
+
+ def t_test(
+     samples_a: Sequence[float],
+     samples_b: Sequence[float],
+     *,
+     alpha: float = 0.05,
+     paired: bool = True,
+ ) -> StatisticalTestResult:
+     """Perform a t-test to compare two sets of samples.
+
+     Args:
+         samples_a: First set of samples
+         samples_b: Second set of samples
+         alpha: Significance level (default: 0.05 for 95% confidence)
+         paired: Whether to use paired t-test (default: True)
+
+     Returns:
+         StatisticalTestResult with test statistics and significance
+
+     Raises:
+         ValueError: If samples are empty or have mismatched lengths (for paired test)
+     """
+     if not samples_a or not samples_b:
+         raise ValueError("Cannot perform t-test on empty samples")
+
+     if paired and len(samples_a) != len(samples_b):
+         raise ValueError(
+             f"Paired t-test requires equal sample sizes. "
+             f"Got {len(samples_a)} and {len(samples_b)}"
+         )
+
+     n_a = len(samples_a)
+     n_b = len(samples_b)
+
+     # Calculate means
+     mean_a = sum(samples_a) / n_a
+     mean_b = sum(samples_b) / n_b
+
+     if paired:
+         # Paired t-test: test on differences
+         diffs = [a - b for a, b in zip(samples_a, samples_b)]
+         mean_diff = sum(diffs) / len(diffs)
+
+         # Standard deviation of differences
+         var_diff = sum((d - mean_diff) ** 2 for d in diffs) / (len(diffs) - 1) if len(diffs) > 1 else 0
+         se_diff = math.sqrt(var_diff / len(diffs))
+
+         # T-statistic
+         if se_diff > 1e-10:  # Non-zero standard error
+             t_stat = mean_diff / se_diff
+         elif abs(mean_diff) > 1e-10:  # Perfect consistency with non-zero difference
+             t_stat = float('inf') if mean_diff > 0 else float('-inf')
+         else:  # No difference at all
+             t_stat = 0.0
+
+         df = len(diffs) - 1
+
+         # Effect size (Cohen's d for paired samples)
+         sd_diff = math.sqrt(var_diff)
+         effect_size = mean_diff / sd_diff if sd_diff > 1e-10 else (1.0 if abs(mean_diff) > 1e-10 else 0.0)
+
+     else:
+         # Independent samples t-test
+         # Calculate pooled standard deviation
+         var_a = sum((x - mean_a) ** 2 for x in samples_a) / (n_a - 1) if n_a > 1 else 0
+         var_b = sum((x - mean_b) ** 2 for x in samples_b) / (n_b - 1) if n_b > 1 else 0
+
+         pooled_sd = math.sqrt(((n_a - 1) * var_a + (n_b - 1) * var_b) / (n_a + n_b - 2))
+         se = pooled_sd * math.sqrt(1/n_a + 1/n_b)
+
+         # T-statistic
+         t_stat = (mean_a - mean_b) / se if se > 0 else 0.0
+         df = n_a + n_b - 2
+
+         # Effect size (Cohen's d)
+         effect_size = (mean_a - mean_b) / pooled_sd if pooled_sd > 0 else 0.0
+
+     # Approximate p-value using t-distribution
+     # For simplicity, we use a conservative approximation
+     # In practice, you'd use scipy.stats.t.sf for accurate p-values
+     p_value = _approximate_t_test_p_value(abs(t_stat), df)
+
+     # Confidence interval (approximate)
+     # t_critical ≈ 2.0 for 95% CI and reasonable df
+     t_critical = 2.0  # Conservative estimate
+     margin = t_critical * (se_diff if paired else se)
+     ci = (mean_a - mean_b - margin, mean_a - mean_b + margin)
+
+     return StatisticalTestResult(
+         test_name="t-test (paired)" if paired else "t-test (independent)",
+         statistic=t_stat,
+         p_value=p_value,
+         significant=p_value < alpha,
+         confidence_level=1 - alpha,
+         effect_size=effect_size,
+         confidence_interval=ci,
+     )
+
+
+ def _approximate_t_test_p_value(t_stat: float, df: int) -> float:
+     """Approximate p-value for t-test.
+
+     This is a rough approximation. For accurate p-values, use scipy.stats.
+     """
+     # Very rough approximation based on standard normal
+     # This gets less accurate for small df
+     if df < 1:
+         return 1.0
+
+     # Convert to z-score approximation for large df
+     if df > 30:
+         z = t_stat
+         # Approximate p-value for two-tailed test
+         # P(|Z| > z) ≈ 2 * (1 - Φ(z))
+         if z > 6:
+             return 0.0
+         elif z < 0.5:
+             return 1.0
+         else:
+             # Rough approximation
+             return min(1.0, 2 * math.exp(-0.5 * z * z) / math.sqrt(2 * math.pi))
+
+     # For small df, be conservative
+     return min(1.0, 0.5 if t_stat < 2 else 0.1 if t_stat < 3 else 0.01)
+
+
+ def bootstrap_confidence_interval(
+     samples_a: Sequence[float],
+     samples_b: Sequence[float],
+     *,
+     n_bootstrap: int = 10000,
+     confidence_level: float = 0.95,
+     statistic_fn: callable = None,
+     seed: int | None = None,
+ ) -> StatisticalTestResult:
+     """Compute bootstrap confidence interval for difference between two samples.
+
+     Uses bootstrap resampling to estimate the confidence interval for the
+     difference in means (or other statistic) between two samples.
+
+     Args:
+         samples_a: First set of samples
+         samples_b: Second set of samples
+         n_bootstrap: Number of bootstrap iterations (default: 10000)
+         confidence_level: Confidence level (default: 0.95)
+         statistic_fn: Function to compute statistic (default: mean difference)
+         seed: Random seed for reproducibility
+
+     Returns:
+         StatisticalTestResult with bootstrap confidence interval
+     """
+     if not samples_a or not samples_b:
+         raise ValueError("Cannot perform bootstrap on empty samples")
+
+     if seed is not None:
+         random.seed(seed)
+
+     # Default statistic: difference in means
+     if statistic_fn is None:
+         def statistic_fn(a, b):
+             return sum(a) / len(a) - sum(b) / len(b)
+
+     # Observed difference
+     observed_diff = statistic_fn(samples_a, samples_b)
+
+     # Bootstrap resampling
+     bootstrap_diffs = []
+     for _ in range(n_bootstrap):
+         # Resample with replacement
+         resampled_a = [random.choice(samples_a) for _ in range(len(samples_a))]
+         resampled_b = [random.choice(samples_b) for _ in range(len(samples_b))]
+
+         diff = statistic_fn(resampled_a, resampled_b)
+         bootstrap_diffs.append(diff)
+
+     # Sort for percentile method
+     bootstrap_diffs.sort()
+
+     # Compute confidence interval
+     alpha = 1 - confidence_level
+     lower_idx = int(n_bootstrap * (alpha / 2))
+     upper_idx = int(n_bootstrap * (1 - alpha / 2))
+
+     ci = (bootstrap_diffs[lower_idx], bootstrap_diffs[upper_idx])
+
+     # Check if 0 is in the confidence interval
+     significant = not (ci[0] <= 0 <= ci[1])
+
+     # Pseudo p-value: proportion of bootstrap samples with opposite sign
+     p_value = sum(1 for d in bootstrap_diffs if (d * observed_diff) < 0) / n_bootstrap
+     p_value = max(p_value, 1 / n_bootstrap)  # Lower bound
+
+     return StatisticalTestResult(
+         test_name=f"bootstrap (n={n_bootstrap})",
+         statistic=observed_diff,
+         p_value=p_value,
+         significant=significant,
+         confidence_level=confidence_level,
+         confidence_interval=ci,
+     )
+
+
+ def permutation_test(
+     samples_a: Sequence[float],
+     samples_b: Sequence[float],
+     *,
+     n_permutations: int = 10000,
+     alpha: float = 0.05,
+     statistic_fn: callable = None,
+     seed: int | None = None,
+ ) -> StatisticalTestResult:
+     """Perform permutation test to compare two samples.
+
+     Tests the null hypothesis that the two samples come from the same
+     distribution by randomly permuting the labels and computing the test
+     statistic.
+
+     Args:
+         samples_a: First set of samples
+         samples_b: Second set of samples
+         n_permutations: Number of permutations (default: 10000)
+         alpha: Significance level (default: 0.05)
+         statistic_fn: Function to compute statistic (default: difference in means)
+         seed: Random seed for reproducibility
+
+     Returns:
+         StatisticalTestResult with permutation test results
+     """
+     if not samples_a or not samples_b:
+         raise ValueError("Cannot perform permutation test on empty samples")
+
+     if seed is not None:
+         random.seed(seed)
+
+     # Default statistic: absolute difference in means
+     if statistic_fn is None:
+         def statistic_fn(a, b):
+             return abs(sum(a) / len(a) - sum(b) / len(b))
+
+     # Observed statistic
+     observed_stat = statistic_fn(samples_a, samples_b)
+
+     # Combine all samples
+     combined = list(samples_a) + list(samples_b)
+     n_a = len(samples_a)
+     n_total = len(combined)
+
+     # Permutation testing
+     more_extreme = 0
+     for _ in range(n_permutations):
+         # Shuffle and split
+         shuffled = combined.copy()
+         random.shuffle(shuffled)
+
+         perm_a = shuffled[:n_a]
+         perm_b = shuffled[n_a:]
+
+         perm_stat = statistic_fn(perm_a, perm_b)
+
+         if perm_stat >= observed_stat:
+             more_extreme += 1
+
+     # P-value: proportion of permutations as extreme as observed
+     p_value = more_extreme / n_permutations
+
+     return StatisticalTestResult(
+         test_name=f"permutation (n={n_permutations})",
+         statistic=observed_stat,
+         p_value=p_value,
+         significant=p_value < alpha,
+         confidence_level=1 - alpha,
+     )
+
+
+ def mcnemar_test(
+     contingency_table: tuple[int, int, int, int],
+     *,
+     alpha: float = 0.05,
+ ) -> StatisticalTestResult:
+     """Perform McNemar's test for paired nominal data.
+
+     Useful for comparing two models on the same test set, where you want to
+     know if one model consistently outperforms the other.
+
+     Args:
+         contingency_table: 2x2 contingency table as (n_00, n_01, n_10, n_11)
+             where n_ij = number of samples where model A predicts i and model B predicts j
+             (0 = incorrect, 1 = correct)
+         alpha: Significance level
+
+     Returns:
+         StatisticalTestResult with McNemar's test results
+     """
+     n_00, n_01, n_10, n_11 = contingency_table
+
+     # Only discordant pairs matter
+     b = n_01  # A wrong, B correct
+     c = n_10  # A correct, B wrong
+
+     if b + c == 0:
+         # No discordant pairs
+         return StatisticalTestResult(
+             test_name="McNemar's test",
+             statistic=0.0,
+             p_value=1.0,
+             significant=False,
+             confidence_level=1 - alpha,
+         )
+
+     # McNemar's statistic with continuity correction
+     chi_square = ((abs(b - c) - 1) ** 2) / (b + c)
+
+     # Approximate p-value (chi-square with 1 df)
+     # For chi-square > 3.84, p < 0.05
+     # For chi-square > 6.63, p < 0.01
+     if chi_square > 10.83:
+         p_value = 0.001
+     elif chi_square > 6.63:
+         p_value = 0.01
+     elif chi_square > 3.84:
+         p_value = 0.05
+     else:
+         # Rough linear approximation
+         p_value = 1.0 - (chi_square / 3.84) * 0.95
+
+     return StatisticalTestResult(
+         test_name="McNemar's test",
+         statistic=chi_square,
+         p_value=p_value,
+         significant=p_value < alpha,
+         confidence_level=1 - alpha,
+     )
+
+
+ __all__ = [
+     "StatisticalTest",
+     "StatisticalTestResult",
+     "t_test",
+     "bootstrap_confidence_interval",
+     "permutation_test",
+     "mcnemar_test",
+ ]
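Usage note: a minimal sketch of calling these helpers, assuming the 402-line hunk above is themis/comparison/statistics.py from the file list; the per-example scores and the McNemar contingency table below are made-up illustration data.

# Hypothetical per-example accuracy scores (1.0 = correct) for two runs on the
# same prompts; only the function signatures come from the diff above.
from themis.comparison.statistics import (
    bootstrap_confidence_interval,
    mcnemar_test,
    t_test,
)

run_a = [1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0]
run_b = [1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]

print(t_test(run_a, run_b, paired=True))                   # paired t-test on per-example scores
print(bootstrap_confidence_interval(run_a, run_b, seed=0))  # percentile CI for the mean difference

# McNemar's test takes the 2x2 table (n_00, n_01, n_10, n_11) of paired
# incorrect/correct outcomes; these counts are invented for the example.
print(mcnemar_test((2, 1, 3, 2)))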
@@ -0,0 +1,19 @@
+ """Hydra-backed configuration helpers for assembling experiments."""
+
+ from __future__ import annotations
+
+ from .loader import load_experiment_config
+ from .runtime import (
+     load_dataset_from_config,
+     run_experiment_from_config,
+     summarize_report_for_config,
+ )
+ from .schema import ExperimentConfig
+
+ __all__ = [
+     "ExperimentConfig",
+     "load_dataset_from_config",
+     "load_experiment_config",
+     "run_experiment_from_config",
+     "summarize_report_for_config",
+ ]
@@ -0,0 +1,27 @@
+ """Utilities for loading experiment configs via Hydra/OmegaConf."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Iterable
+
+ from omegaconf import OmegaConf
+
+ from . import schema
+
+
+ def load_experiment_config(
+     config_path: Path,
+     overrides: Iterable[str] | None = None,
+ ) -> schema.ExperimentConfig:
+     """Load and validate an experiment config file with optional overrides."""
+
+     base = OmegaConf.structured(schema.ExperimentConfig)
+     file_conf = OmegaConf.load(config_path)
+     merged = OmegaConf.merge(base, file_conf)
+
+     if overrides:
+         override_conf = OmegaConf.from_dotlist(list(overrides))
+         merged = OmegaConf.merge(merged, override_conf)
+
+     return OmegaConf.to_object(merged)
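Usage note: a sketch of loading a config with dotlist overrides. The YAML path and the override keys are assumptions; only load_experiment_config and its signature come from the hunk above, and the themis.config re-export is shown in the __init__.py hunk.

from pathlib import Path

from themis.config import load_experiment_config  # re-exported from .loader

# "configs/math500.yaml" and the override keys are hypothetical; overrides use
# OmegaConf dotlist syntax and are merged on top of the file contents.
config = load_experiment_config(
    Path("configs/math500.yaml"),
    overrides=["max_samples=50", "generation.sampling.temperature=0.0"],
)
print(config.task, config.max_samples)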
@@ -0,0 +1,34 @@
+ """Registry for experiment builders."""
+
+ from __future__ import annotations
+
+ from typing import Callable
+
+ from themis.config import schema
+ from themis.experiment import orchestrator
+
+ ExperimentBuilder = Callable[
+     [schema.ExperimentConfig], orchestrator.ExperimentOrchestrator
+ ]
+
+ _EXPERIMENT_BUILDERS: dict[str, ExperimentBuilder] = {}
+
+
+ def register_experiment_builder(task: str) -> Callable[[ExperimentBuilder], ExperimentBuilder]:
+     """Decorator to register an experiment builder for a specific task."""
+
+     def decorator(builder: ExperimentBuilder) -> ExperimentBuilder:
+         _EXPERIMENT_BUILDERS[task] = builder
+         return builder
+
+     return decorator
+
+
+ def get_experiment_builder(task: str) -> ExperimentBuilder:
+     """Get the experiment builder for a specific task."""
+     if task not in _EXPERIMENT_BUILDERS:
+         raise ValueError(
+             f"No experiment builder registered for task '{task}'. "
+             f"Available tasks: {', '.join(sorted(_EXPERIMENT_BUILDERS.keys()))}"
+         )
+     return _EXPERIMENT_BUILDERS[task]
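Usage note: a sketch of how downstream code might register its own builder. The task name "my_task" and the builder body are hypothetical; the decorator and lookup come from the hunk above.

from themis.config import schema
from themis.config.registry import (
    get_experiment_builder,
    register_experiment_builder,
)
from themis.experiment import orchestrator


@register_experiment_builder("my_task")  # hypothetical task name
def build_my_task(
    config: schema.ExperimentConfig,
) -> orchestrator.ExperimentOrchestrator:
    # Assemble and return an ExperimentOrchestrator for this task.
    raise NotImplementedError


builder = get_experiment_builder("my_task")  # returns build_my_task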
@@ -0,0 +1,214 @@
+ """Runtime helpers for executing experiments from Hydra configs."""
+
+ from __future__ import annotations
+
+ from dataclasses import asdict
+ from pathlib import Path
+ from typing import List
+
+ from themis.core import entities as core_entities
+ from themis.datasets import create_dataset
+ from themis.experiment import math as math_experiment
+ from themis.experiment import mcq as mcq_experiment
+ from themis.experiment import orchestrator as experiment_orchestrator
+ from themis.experiment import storage as experiment_storage
+ from themis.providers import registry as provider_registry
+
+ from . import registry, schema
+
+
+
+
+
+ def run_experiment_from_config(
+     config: schema.ExperimentConfig,
+     *,
+     dataset: list[dict[str, object]] | None = None,
+     on_result=None,
+ ) -> experiment_orchestrator.ExperimentReport:
+     dataset_to_use = (
+         dataset
+         if dataset is not None
+         else _load_dataset(config.dataset, experiment_name=config.name)
+     )
+     experiment = _build_experiment(config)
+     return experiment.run(
+         dataset_to_use,
+         max_samples=config.max_samples,
+         run_id=config.run_id,
+         resume=config.resume,
+         on_result=on_result,
+     )
+
+
+ def summarize_report_for_config(
+     config: schema.ExperimentConfig,
+     report: experiment_orchestrator.ExperimentReport,
+ ) -> str:
+     if config.task in {
+         "math500",
+         "aime24",
+         "aime25",
+         "amc23",
+         "olympiadbench",
+         "beyondaime",
+     }:
+         return math_experiment.summarize_report(report)
+     if config.task in {"supergpqa", "mmlu_pro"}:
+         return mcq_experiment.summarize_report(report)
+     raise ValueError(f"Unsupported task '{config.task}' for summarization.")
+
+
+ def load_dataset_from_config(
+     config: schema.ExperimentConfig,
+ ) -> list[dict[str, object]]:
+     return _load_dataset(config.dataset, experiment_name=config.name)
+
+
+ def _build_experiment(
+     config: schema.ExperimentConfig,
+ ) -> experiment_orchestrator.ExperimentOrchestrator:
+     if config.task:
+         builder = registry.get_experiment_builder(config.task)
+         return builder(config)
+
+     raise ValueError(
+         "Experiment configuration must specify a 'task'. "
+         f"Available tasks: {', '.join(sorted(registry._EXPERIMENT_BUILDERS.keys()))}"
+     )
+
+
+ @registry.register_experiment_builder("math500")
+ @registry.register_experiment_builder("aime24")
+ @registry.register_experiment_builder("aime25")
+ @registry.register_experiment_builder("amc23")
+ @registry.register_experiment_builder("olympiadbench")
+ @registry.register_experiment_builder("beyondaime")
+ def _build_math_experiment(
+     config: schema.ExperimentConfig,
+ ) -> experiment_orchestrator.ExperimentOrchestrator:
+     # Use the specific path if provided, otherwise use the default path
+     storage_path = config.storage.path or config.storage.default_path
+     storage = (
+         experiment_storage.ExperimentStorage(Path(storage_path))
+         if storage_path
+         else None
+     )
+     sampling_cfg = core_entities.SamplingConfig(
+         temperature=config.generation.sampling.temperature,
+         top_p=config.generation.sampling.top_p,
+         max_tokens=config.generation.sampling.max_tokens,
+     )
+     provider = provider_registry.create_provider(
+         config.generation.provider.name, **config.generation.provider.options
+     )
+     runner_options = asdict(config.generation.runner)
+
+     # Use the task name from config as the default task name
+     task_name = config.task or "math500"
+     # Override task name if provided in task_options
+     if config.task_options and "task_name" in config.task_options:
+         task_name = config.task_options["task_name"]
+
+     return math_experiment.build_math500_zero_shot_experiment(
+         model_client=provider,
+         model_name=config.generation.model_identifier,
+         storage=storage,
+         sampling=sampling_cfg,
+         provider_name=config.generation.provider.name,
+         runner_options=runner_options,
+         task_name=task_name,
+     )
+
+
+ @registry.register_experiment_builder("supergpqa")
+ def _build_supergpqa_experiment(
+     config: schema.ExperimentConfig,
+ ) -> experiment_orchestrator.ExperimentOrchestrator:
+     return _build_mcq_experiment(config, "supergpqa", "supergpqa")
+
+
+ @registry.register_experiment_builder("mmlu_pro")
+ def _build_mmlu_pro_experiment(
+     config: schema.ExperimentConfig,
+ ) -> experiment_orchestrator.ExperimentOrchestrator:
+     return _build_mcq_experiment(config, "mmlu-pro", "mmlu_pro")
+
+
+ def _build_mcq_experiment(
+     config: schema.ExperimentConfig, dataset_name: str, task_id: str
+ ) -> experiment_orchestrator.ExperimentOrchestrator:
+     # Use the specific path if provided, otherwise use the default path
+     storage_path = config.storage.path or config.storage.default_path
+     storage = (
+         experiment_storage.ExperimentStorage(Path(storage_path))
+         if storage_path
+         else None
+     )
+     sampling_cfg = core_entities.SamplingConfig(
+         temperature=config.generation.sampling.temperature,
+         top_p=config.generation.sampling.top_p,
+         max_tokens=config.generation.sampling.max_tokens,
+     )
+     provider = provider_registry.create_provider(
+         config.generation.provider.name, **config.generation.provider.options
+     )
+     runner_options = asdict(config.generation.runner)
+
+     return mcq_experiment.build_multiple_choice_json_experiment(
+         dataset_name=dataset_name,
+         task_id=task_id,
+         model_client=provider,
+         model_name=config.generation.model_identifier,
+         storage=storage,
+         sampling=sampling_cfg,
+         provider_name=config.generation.provider.name,
+         runner_options=runner_options,
+     )
+
+
+ def _load_dataset(
+     config: schema.DatasetConfig, *, experiment_name: str
+ ) -> List[dict[str, object]]:
+     """Load dataset samples using the dataset registry.
+
+     Args:
+         config: Dataset configuration
+         experiment_name: Name of the experiment (used to map to dataset)
+
+     Returns:
+         List of sample dictionaries ready for generation
+     """
+     # Handle inline datasets (not in registry)
+     if config.source == "inline":
+         if not config.inline_samples:
+             raise ValueError(
+                 "dataset.inline_samples must contain at least one row when"
+                 " dataset.source='inline'."
+             )
+         return list(config.inline_samples)
+
+     # Use explicit dataset_id if provided
+     dataset_name = config.dataset_id
+     if not dataset_name:
+         # Registry-backed sources require an explicit dataset identifier. Only
+         # the DatasetConfig and experiment name are available here, so there is
+         # no task name to fall back on; inferring the dataset from the task
+         # would mean passing the full experiment config into this helper.
+         raise ValueError(
+             "dataset.dataset_id must be provided when source is not 'inline'."
+         )
+
+     # Prepare options for dataset factory
+     options = {
+         "source": config.source,
+         "data_dir": config.data_dir,
+         "split": config.split,
+         "limit": config.limit,
+         "subjects": list(config.subjects) if config.subjects else None,
+     }
+
+     # Load samples via registry
+     return create_dataset(dataset_name, **options)
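Usage note: an end-to-end sketch tying the config modules together. The YAML path is an assumption; the file is expected to set task, dataset.dataset_id (or inline samples), and generation settings so the builder and dataset lookups above succeed.

from pathlib import Path

from themis.config import (
    load_experiment_config,
    run_experiment_from_config,
    summarize_report_for_config,
)

# "configs/math500.yaml" is hypothetical; it must name a registered task
# (e.g. "math500") and a resolvable dataset.
config = load_experiment_config(Path("configs/math500.yaml"))
report = run_experiment_from_config(config)
print(summarize_report_for_config(config, report))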