themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. themis/cli/__init__.py +5 -0
  2. themis/cli/__main__.py +6 -0
  3. themis/cli/commands/__init__.py +19 -0
  4. themis/cli/commands/benchmarks.py +221 -0
  5. themis/cli/commands/comparison.py +394 -0
  6. themis/cli/commands/config_commands.py +244 -0
  7. themis/cli/commands/cost.py +214 -0
  8. themis/cli/commands/demo.py +68 -0
  9. themis/cli/commands/info.py +90 -0
  10. themis/cli/commands/leaderboard.py +362 -0
  11. themis/cli/commands/math_benchmarks.py +318 -0
  12. themis/cli/commands/mcq_benchmarks.py +207 -0
  13. themis/cli/commands/sample_run.py +244 -0
  14. themis/cli/commands/visualize.py +299 -0
  15. themis/cli/main.py +93 -0
  16. themis/cli/new_project.py +33 -0
  17. themis/cli/utils.py +51 -0
  18. themis/config/__init__.py +19 -0
  19. themis/config/loader.py +27 -0
  20. themis/config/registry.py +34 -0
  21. themis/config/runtime.py +214 -0
  22. themis/config/schema.py +112 -0
  23. themis/core/__init__.py +5 -0
  24. themis/core/conversation.py +354 -0
  25. themis/core/entities.py +164 -0
  26. themis/core/serialization.py +231 -0
  27. themis/core/tools.py +393 -0
  28. themis/core/types.py +141 -0
  29. themis/datasets/__init__.py +273 -0
  30. themis/datasets/base.py +264 -0
  31. themis/datasets/commonsense_qa.py +174 -0
  32. themis/datasets/competition_math.py +265 -0
  33. themis/datasets/coqa.py +133 -0
  34. themis/datasets/gpqa.py +190 -0
  35. themis/datasets/gsm8k.py +123 -0
  36. themis/datasets/gsm_symbolic.py +124 -0
  37. themis/datasets/math500.py +122 -0
  38. themis/datasets/med_qa.py +179 -0
  39. themis/datasets/medmcqa.py +169 -0
  40. themis/datasets/mmlu_pro.py +262 -0
  41. themis/datasets/piqa.py +146 -0
  42. themis/datasets/registry.py +201 -0
  43. themis/datasets/schema.py +245 -0
  44. themis/datasets/sciq.py +150 -0
  45. themis/datasets/social_i_qa.py +151 -0
  46. themis/datasets/super_gpqa.py +263 -0
  47. themis/evaluation/__init__.py +1 -0
  48. themis/evaluation/conditional.py +410 -0
  49. themis/evaluation/extractors/__init__.py +19 -0
  50. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  51. themis/evaluation/extractors/exceptions.py +7 -0
  52. themis/evaluation/extractors/identity_extractor.py +29 -0
  53. themis/evaluation/extractors/json_field_extractor.py +45 -0
  54. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  55. themis/evaluation/extractors/regex_extractor.py +43 -0
  56. themis/evaluation/math_verify_utils.py +87 -0
  57. themis/evaluation/metrics/__init__.py +21 -0
  58. themis/evaluation/metrics/composite_metric.py +47 -0
  59. themis/evaluation/metrics/consistency_metric.py +80 -0
  60. themis/evaluation/metrics/exact_match.py +51 -0
  61. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  62. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  63. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  64. themis/evaluation/metrics/response_length.py +33 -0
  65. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  66. themis/evaluation/pipeline.py +49 -0
  67. themis/evaluation/pipelines/__init__.py +15 -0
  68. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  69. themis/evaluation/pipelines/standard_pipeline.py +288 -0
  70. themis/evaluation/reports.py +293 -0
  71. themis/evaluation/statistics/__init__.py +53 -0
  72. themis/evaluation/statistics/bootstrap.py +79 -0
  73. themis/evaluation/statistics/confidence_intervals.py +121 -0
  74. themis/evaluation/statistics/distributions.py +207 -0
  75. themis/evaluation/statistics/effect_sizes.py +124 -0
  76. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  77. themis/evaluation/statistics/types.py +139 -0
  78. themis/evaluation/strategies/__init__.py +13 -0
  79. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  80. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  81. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  82. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  83. themis/experiment/__init__.py +5 -0
  84. themis/experiment/builder.py +151 -0
  85. themis/experiment/cache_manager.py +129 -0
  86. themis/experiment/comparison.py +631 -0
  87. themis/experiment/cost.py +310 -0
  88. themis/experiment/definitions.py +62 -0
  89. themis/experiment/export.py +690 -0
  90. themis/experiment/export_csv.py +159 -0
  91. themis/experiment/integration_manager.py +104 -0
  92. themis/experiment/math.py +192 -0
  93. themis/experiment/mcq.py +169 -0
  94. themis/experiment/orchestrator.py +373 -0
  95. themis/experiment/pricing.py +317 -0
  96. themis/experiment/storage.py +255 -0
  97. themis/experiment/visualization.py +588 -0
  98. themis/generation/__init__.py +1 -0
  99. themis/generation/agentic_runner.py +420 -0
  100. themis/generation/batching.py +254 -0
  101. themis/generation/clients.py +143 -0
  102. themis/generation/conversation_runner.py +236 -0
  103. themis/generation/plan.py +456 -0
  104. themis/generation/providers/litellm_provider.py +221 -0
  105. themis/generation/providers/vllm_provider.py +135 -0
  106. themis/generation/router.py +34 -0
  107. themis/generation/runner.py +207 -0
  108. themis/generation/strategies.py +98 -0
  109. themis/generation/templates.py +71 -0
  110. themis/generation/turn_strategies.py +393 -0
  111. themis/generation/types.py +9 -0
  112. themis/integrations/__init__.py +0 -0
  113. themis/integrations/huggingface.py +61 -0
  114. themis/integrations/wandb.py +65 -0
  115. themis/interfaces/__init__.py +83 -0
  116. themis/project/__init__.py +20 -0
  117. themis/project/definitions.py +98 -0
  118. themis/project/patterns.py +230 -0
  119. themis/providers/__init__.py +5 -0
  120. themis/providers/registry.py +39 -0
  121. themis/utils/api_generator.py +379 -0
  122. themis/utils/cost_tracking.py +376 -0
  123. themis/utils/dashboard.py +452 -0
  124. themis/utils/logging_utils.py +41 -0
  125. themis/utils/progress.py +58 -0
  126. themis/utils/tracing.py +320 -0
  127. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
  128. themis_eval-0.1.1.dist-info/RECORD +134 -0
  129. themis_eval-0.1.0.dist-info/RECORD +0 -8
  130. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
  131. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
  132. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
themis/experiment/comparison.py (new file)
@@ -0,0 +1,631 @@
+ """Multi-experiment comparison tools for analyzing multiple runs."""
+
+ from __future__ import annotations
+
+ import json
+ import warnings
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Any
+
+ from themis.core.entities import ExperimentReport
+
+
+ @dataclass
+ class ComparisonRow:
+     """Single experiment in a multi-experiment comparison."""
+
+     run_id: str
+     metric_values: dict[str, float]
+     metadata: dict[str, Any] = field(default_factory=dict)
+     timestamp: str | None = None
+     sample_count: int = 0
+     failure_count: int = 0
+
+     def get_metric(self, metric_name: str) -> float | None:
+         """Get metric value by name.
+
+         Special metric names:
+         - 'cost' or 'total_cost': Checks metadata first, then metric_values
+         - Any other name: Returns from metric_values dict
+         """
+         # Handle special cost metrics - check metadata first
+         if metric_name in ("cost", "total_cost"):
+             cost_data = self.metadata.get("cost")
+             if cost_data and "total_cost" in cost_data:
+                 return cost_data["total_cost"]
+             # Fall back to metric_values if not in metadata
+             # (for backward compatibility and tests)
+             if metric_name in self.metric_values:
+                 return self.metric_values[metric_name]
+             return None
+
+         return self.metric_values.get(metric_name)
+
+     def get_cost(self) -> float | None:
+         """Get total cost if available.
+
+         Returns:
+             Total cost in USD, or None if not tracked
+         """
+         return self.get_metric("cost")
+
+
+ @dataclass
+ class ConfigDiff:
+     """Differences between two experiment configurations."""
+
+     run_id_a: str
+     run_id_b: str
+     changed_fields: dict[str, tuple[Any, Any]] = field(default_factory=dict)
+     added_fields: dict[str, Any] = field(default_factory=dict)
+     removed_fields: dict[str, Any] = field(default_factory=dict)
+
+     def has_differences(self) -> bool:
+         """Check if there are any differences."""
+         return bool(self.changed_fields or self.added_fields or self.removed_fields)
+
+
+ @dataclass
+ class MultiExperimentComparison:
+     """Comparison across multiple experiments."""
+
+     experiments: list[ComparisonRow]
+     metrics: list[str]
+
+     def __post_init__(self):
+         """Validate comparison data."""
+         if not self.experiments:
+             raise ValueError("Must have at least one experiment to compare")
+         if not self.metrics:
+             # Infer metrics from first experiment
+             if self.experiments:
+                 self.metrics = list(self.experiments[0].metric_values.keys())
+
+     def rank_by_metric(
+         self, metric: str, ascending: bool = False
+     ) -> list[ComparisonRow]:
+         """Rank experiments by metric value.
+
+         Args:
+             metric: Metric name to rank by (can be 'cost' or 'total_cost'
+                 for cost ranking)
+             ascending: If True, rank from lowest to highest (default: False)
+
+         Returns:
+             List of experiments sorted by metric value
+         """
+         # Special handling for cost metrics
+         if metric not in self.metrics and metric not in ("cost", "total_cost"):
+             raise ValueError(f"Metric '{metric}' not found. Available: {self.metrics}")
+
+         # Sort experiments, handling None values
+         def key_func(row: ComparisonRow) -> tuple[bool, float]:
+             value = row.get_metric(metric)
+             # Put None values at the end
+             if value is None:
+                 return (True, float("inf"))
+             return (False, value)
+
+         return sorted(self.experiments, key=key_func, reverse=not ascending)
+
+     def highlight_best(
+         self, metric: str, higher_is_better: bool = True
+     ) -> ComparisonRow | None:
+         """Find experiment with best value for metric.
+
+         Args:
+             metric: Metric name
+             higher_is_better: If True, higher values are better (default: True)
+
+         Returns:
+             Experiment with best metric value, or None if no valid values
+         """
+         ranked = self.rank_by_metric(metric, ascending=not higher_is_better)
+         # Return first experiment with valid metric value
+         for exp in ranked:
+             if exp.get_metric(metric) is not None:
+                 return exp
+         return None
+
+     def pareto_frontier(
+         self, objectives: list[str], maximize: list[bool] | None = None
+     ) -> list[str]:
+         """Find Pareto-optimal experiments.
+
+         Args:
+             objectives: List of metric names to optimize
+             maximize: For each objective, whether to maximize (True) or
+                 minimize (False). Default: maximize all objectives.
+
+         Returns:
+             List of run_ids on the Pareto frontier
+         """
+         if not objectives:
+             raise ValueError("Must specify at least one objective")
+
+         if maximize is None:
+             maximize = [True] * len(objectives)
+
+         if len(maximize) != len(objectives):
+             raise ValueError(
+                 f"maximize list length ({len(maximize)}) must match "
+                 f"objectives length ({len(objectives)})"
+             )
+
+         # Filter out experiments with missing values
+         valid_experiments = [
+             exp
+             for exp in self.experiments
+             if all(exp.get_metric(obj) is not None for obj in objectives)
+         ]
+
+         if not valid_experiments:
+             return []
+
+         pareto_optimal: list[ComparisonRow] = []
+
+         for candidate in valid_experiments:
+             is_dominated = False
+
+             # Check if candidate is dominated by any other experiment
+             for other in valid_experiments:
+                 if candidate.run_id == other.run_id:
+                     continue
+
+                 # Check if 'other' dominates 'candidate'
+                 dominates = True
+                 strictly_better_in_one = False
+
+                 for obj, should_maximize in zip(objectives, maximize, strict=True):
+                     candidate_val = candidate.get_metric(obj)
+                     other_val = other.get_metric(obj)
+
+                     # Should never be None due to filtering, but handle defensively
+                     if candidate_val is None or other_val is None:
+                         dominates = False
+                         break
+
+                     if should_maximize:
+                         if other_val < candidate_val:
+                             dominates = False
+                             break
+                         if other_val > candidate_val:
+                             strictly_better_in_one = True
+                     else:
+                         if other_val > candidate_val:
+                             dominates = False
+                             break
+                         if other_val < candidate_val:
+                             strictly_better_in_one = True
+
+                 if dominates and strictly_better_in_one:
+                     is_dominated = True
+                     break
+
+             if not is_dominated:
+                 pareto_optimal.append(candidate)
+
+         return [exp.run_id for exp in pareto_optimal]
+
+     def to_dict(self) -> dict[str, Any]:
+         """Export as dictionary."""
+         return {
+             "experiments": [
+                 {
+                     "run_id": exp.run_id,
+                     "metric_values": exp.metric_values,
+                     "metadata": exp.metadata,
+                     "timestamp": exp.timestamp,
+                     "sample_count": exp.sample_count,
+                     "failure_count": exp.failure_count,
+                 }
+                 for exp in self.experiments
+             ],
+             "metrics": self.metrics,
+         }
+
+     def to_csv(self, output_path: Path | str, include_metadata: bool = True) -> None:
+         """Export comparison to CSV.
+
+         Args:
+             output_path: Where to save CSV file
+             include_metadata: Whether to include metadata columns
+         """
+         import csv
+
+         output_path = Path(output_path)
+
+         with output_path.open("w", newline="", encoding="utf-8") as f:
+             # Build column names
+             columns = ["run_id"] + self.metrics
+
+             if include_metadata:
+                 # Collect all metadata keys
+                 all_metadata_keys: set[str] = set()
+                 for exp in self.experiments:
+                     all_metadata_keys.update(exp.metadata.keys())
+                 metadata_columns = sorted(all_metadata_keys)
+                 columns.extend(metadata_columns)
+                 columns.extend(["timestamp", "sample_count", "failure_count"])
+
+             writer = csv.DictWriter(f, fieldnames=columns)
+             writer.writeheader()
+
+             for exp in self.experiments:
+                 row: dict[str, Any] = {"run_id": exp.run_id}
+                 row.update(exp.metric_values)
+
+                 if include_metadata:
+                     for key in metadata_columns:
+                         row[key] = exp.metadata.get(key, "")
+                     row["timestamp"] = exp.timestamp or ""
+                     row["sample_count"] = exp.sample_count
+                     row["failure_count"] = exp.failure_count
+
+                 writer.writerow(row)
+
+     def to_markdown(self, output_path: Path | str | None = None) -> str:
+         """Export comparison as markdown table.
+
+         Args:
+             output_path: Optional path to save markdown file
+
+         Returns:
+             Markdown table string
+         """
+         lines = ["# Experiment Comparison\n"]
+
+         # Check if any experiment has cost data
+         has_cost = any(
+             exp.metadata.get("cost") and exp.metadata["cost"].get("total_cost")
+             for exp in self.experiments
+         )
+
+         # Build table header
+         headers = ["Run ID"] + self.metrics + ["Samples", "Failures"]
+         if has_cost:
+             headers.append("Cost ($)")
+         lines.append("| " + " | ".join(headers) + " |")
+         lines.append("| " + " | ".join(["---"] * len(headers)) + " |")
+
+         # Build table rows
+         for exp in self.experiments:
+             values = [exp.run_id]
+             for metric in self.metrics:
+                 val = exp.get_metric(metric)
+                 values.append(f"{val:.4f}" if val is not None else "N/A")
+             values.append(str(exp.sample_count))
+             values.append(str(exp.failure_count))
+
+             # Add cost if available
+             if has_cost:
+                 cost = exp.metadata.get("cost", {}).get("total_cost")
+                 if cost is not None:
+                     values.append(f"{cost:.4f}")
+                 else:
+                     values.append("N/A")
+
+             lines.append("| " + " | ".join(values) + " |")
+
+         markdown = "\n".join(lines)
+
+         if output_path:
+             Path(output_path).write_text(markdown, encoding="utf-8")
+
+         return markdown
+
+     def to_latex(
+         self,
+         output_path: Path | str | None = None,
+         style: str = "booktabs",
+         caption: str | None = None,
+         label: str | None = None,
+     ) -> str:
+         """Export comparison as LaTeX table.
+
+         Args:
+             output_path: Optional path to save LaTeX file
+             style: Table style - "booktabs" or "basic"
+             caption: Table caption
+             label: LaTeX label for referencing
+
+         Returns:
+             LaTeX table string
+
+         Example:
+             >>> latex = comparison.to_latex(
+             ...     caption="Experiment comparison results",
+             ...     label="tab:results"
+             ... )
+         """
+         lines = []
+
+         # Check if any experiment has cost data
+         has_cost = any(
+             exp.metadata.get("cost") and exp.metadata["cost"].get("total_cost")
+             for exp in self.experiments
+         )
+
+         # Determine number of columns
+         n_metrics = len(self.metrics)
+         n_cols = 1 + n_metrics + 2  # run_id + metrics + samples + failures
+         if has_cost:
+             n_cols += 1
+
+         # Table preamble
+         if style == "booktabs":
+             lines.append("\\begin{table}[htbp]")
+             lines.append("\\centering")
+             if caption:
+                 lines.append(f"\\caption{{{caption}}}")
+             if label:
+                 lines.append(f"\\label{{{label}}}")
+
+             # Column specification
+             col_spec = "l" + "r" * (n_cols - 1)  # Left for run_id, right for numbers
+             lines.append(f"\\begin{{tabular}}{{{col_spec}}}")
+             lines.append("\\toprule")
+
+             # Header
+             headers = ["Run ID"] + self.metrics + ["Samples", "Failures"]
+             if has_cost:
+                 headers.append("Cost (\\$)")
+             lines.append(" & ".join(headers) + " \\\\")
+             lines.append("\\midrule")
+
+             # Data rows
+             for exp in self.experiments:
+                 values = [exp.run_id.replace("_", "\\_")]  # Escape underscores
+                 for metric in self.metrics:
+                     val = exp.get_metric(metric)
+                     values.append(f"{val:.4f}" if val is not None else "---")
+                 values.append(str(exp.sample_count))
+                 values.append(str(exp.failure_count))
+
+                 # Add cost if available
+                 if has_cost:
+                     cost = exp.metadata.get("cost", {}).get("total_cost")
+                     if cost is not None:
+                         values.append(f"{cost:.4f}")
+                     else:
+                         values.append("---")
+
+                 lines.append(" & ".join(values) + " \\\\")
+
+             lines.append("\\bottomrule")
+             lines.append("\\end{tabular}")
+             lines.append("\\end{table}")
+
+         else:  # basic style
+             lines.append("\\begin{table}[htbp]")
+             lines.append("\\centering")
+             if caption:
+                 lines.append(f"\\caption{{{caption}}}")
+             if label:
+                 lines.append(f"\\label{{{label}}}")
+
+             col_spec = "|l|" + "r|" * (n_cols - 1)
+             lines.append(f"\\begin{{tabular}}{{{col_spec}}}")
+             lines.append("\\hline")
+
+             # Header
+             headers = ["Run ID"] + self.metrics + ["Samples", "Failures"]
+             if has_cost:
+                 headers.append("Cost (\\$)")
+             lines.append(" & ".join(headers) + " \\\\")
+             lines.append("\\hline")
+
+             # Data rows
+             for exp in self.experiments:
+                 values = [exp.run_id.replace("_", "\\_")]
+                 for metric in self.metrics:
+                     val = exp.get_metric(metric)
+                     values.append(f"{val:.4f}" if val is not None else "---")
+                 values.append(str(exp.sample_count))
+                 values.append(str(exp.failure_count))
+
+                 if has_cost:
+                     cost = exp.metadata.get("cost", {}).get("total_cost")
+                     if cost is not None:
+                         values.append(f"{cost:.4f}")
+                     else:
+                         values.append("---")
+
+                 lines.append(" & ".join(values) + " \\\\")
+                 lines.append("\\hline")
+
+             lines.append("\\end{tabular}")
+             lines.append("\\end{table}")
+
+         latex = "\n".join(lines)
+
+         if output_path:
+             output_path = Path(output_path)
+             output_path.write_text(latex, encoding="utf-8")
+
+         return latex
+
+
+ def load_experiment_report(storage_dir: Path, run_id: str) -> ExperimentReport | None:
+     """Load experiment report from storage.
+
+     Args:
+         storage_dir: Storage directory
+         run_id: Run identifier
+
+     Returns:
+         ExperimentReport if found, None otherwise
+     """
+     report_path = storage_dir / run_id / "report.json"
+
+     if not report_path.exists():
+         return None
+
+     with report_path.open("r", encoding="utf-8") as f:
+         data = json.load(f)
+
+     # Reconstruct ExperimentReport from JSON
+     # Note: This is a simplified loader. For production,
+     # you'd want proper deserialization
+     return data
+
+
+ def compare_experiments(
+     run_ids: list[str],
+     storage_dir: Path | str,
+     metrics: list[str] | None = None,
+     include_metadata: bool = True,
+ ) -> MultiExperimentComparison:
+     """Compare multiple experiments.
+
+     Args:
+         run_ids: List of experiment run IDs to compare
+         storage_dir: Directory containing experiment results
+         metrics: Metrics to compare (None = all available)
+         include_metadata: Include config metadata in comparison
+
+     Returns:
+         Comparison object with all experiment data
+
+     Raises:
+         FileNotFoundError: If experiment data not found
+         ValueError: If no valid experiments found
+     """
+     storage_dir = Path(storage_dir)
+
+     comparison_rows: list[ComparisonRow] = []
+     all_metrics: set[str] = set()
+
+     for run_id in run_ids:
+         # Load evaluation records
+         try:
+             # Try loading from a report.json if available
+             report_path = storage_dir / run_id / "report.json"
+             if report_path.exists():
+                 with report_path.open("r", encoding="utf-8") as f:
+                     report_data = json.load(f)
+
+                 metric_values: dict[str, float] = {}
+                 # The JSON structure has a "metrics" array with {name, count, mean}
+                 if "metrics" in report_data:
+                     for metric_data in report_data["metrics"]:
+                         if isinstance(metric_data, dict):
+                             metric_name = metric_data.get("name")
+                             metric_mean = metric_data.get("mean")
+                             if metric_name and metric_mean is not None:
+                                 metric_values[metric_name] = metric_mean
+                                 all_metrics.add(metric_name)
+
+                 metadata_dict: dict[str, Any] = {}
+                 if include_metadata and "summary" in report_data:
+                     # The summary section contains metadata
+                     metadata_dict = report_data.get("summary", {})
+
+                 # Count samples and failures
+                 sample_count = report_data.get("total_samples", 0)
+                 failure_count = report_data.get("summary", {}).get(
+                     "run_failures", 0
+                 ) + report_data.get("summary", {}).get("evaluation_failures", 0)
+
+                 # Get timestamp from metadata or file modification time
+                 timestamp = metadata_dict.get("timestamp")
+                 if not timestamp and report_path.exists():
+                     timestamp = datetime.fromtimestamp(
+                         report_path.stat().st_mtime
+                     ).isoformat()
+
+                 row = ComparisonRow(
+                     run_id=run_id,
+                     metric_values=metric_values,
+                     metadata=metadata_dict,
+                     timestamp=timestamp,
+                     sample_count=sample_count,
+                     failure_count=failure_count,
+                 )
+                 comparison_rows.append(row)
+             else:
+                 warnings.warn(
+                     f"No report.json found for run '{run_id}', skipping",
+                     stacklevel=2,
+                 )
+
+         except Exception as e:
+             warnings.warn(f"Failed to load run '{run_id}': {e}", stacklevel=2)
+             continue
+
+     if not comparison_rows:
+         raise ValueError(
+             f"No valid experiments found for run_ids: {run_ids}. "
+             "Make sure experiments have been run and saved with report.json files."
+         )
+
+     # Filter metrics if specified
+     if metrics:
+         all_metrics = set(metrics)
+
+     return MultiExperimentComparison(
+         experiments=comparison_rows, metrics=sorted(all_metrics)
+     )
+
+
+ def diff_configs(run_id_a: str, run_id_b: str, storage_dir: Path | str) -> ConfigDiff:
+     """Show configuration differences between two experiments.
+
+     Args:
+         run_id_a: First run ID
+         run_id_b: Second run ID
+         storage_dir: Storage directory
+
+     Returns:
+         ConfigDiff object with differences
+     """
+     storage_dir = Path(storage_dir)
+
+     # Load config files
+     config_a_path = storage_dir / run_id_a / "config.json"
+     config_b_path = storage_dir / run_id_b / "config.json"
+
+     if not config_a_path.exists():
+         raise FileNotFoundError(f"Config not found for run '{run_id_a}'")
+     if not config_b_path.exists():
+         raise FileNotFoundError(f"Config not found for run '{run_id_b}'")
+
+     with config_a_path.open("r", encoding="utf-8") as f:
+         config_a = json.load(f)
+     with config_b_path.open("r", encoding="utf-8") as f:
+         config_b = json.load(f)
+
+     # Compute differences
+     changed: dict[str, tuple[Any, Any]] = {}
+     added: dict[str, Any] = {}
+     removed: dict[str, Any] = {}
+
+     all_keys = set(config_a.keys()) | set(config_b.keys())
+
+     for key in all_keys:
+         if key in config_a and key in config_b:
+             if config_a[key] != config_b[key]:
+                 changed[key] = (config_a[key], config_b[key])
+         elif key in config_a:
+             removed[key] = config_a[key]
+         else:
+             added[key] = config_b[key]
+
+     return ConfigDiff(
+         run_id_a=run_id_a,
+         run_id_b=run_id_b,
+         changed_fields=changed,
+         added_fields=added,
+         removed_fields=removed,
+     )
+
+
+ __all__ = [
+     "ComparisonRow",
+     "ConfigDiff",
+     "MultiExperimentComparison",
+     "compare_experiments",
+     "diff_configs",
+ ]
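
The comparison module shown above is the largest single addition in 0.1.1. A minimal usage sketch, inferred from the diffed code rather than from any documented example: it assumes results are stored under a results/ directory laid out as <run_id>/report.json (the layout compare_experiments reads), and the run IDs and the "accuracy" metric name are hypothetical.

from themis.experiment.comparison import compare_experiments

# Hypothetical run IDs and storage directory; compare_experiments expects
# <storage_dir>/<run_id>/report.json for each run.
comparison = compare_experiments(
    run_ids=["baseline", "cot-prompt"],
    storage_dir="results",
    metrics=["accuracy"],  # hypothetical metric name; omit to keep all reported metrics
)

# Rank runs best-to-worst on the chosen metric (descending by default).
for row in comparison.rank_by_metric("accuracy"):
    print(row.run_id, row.get_metric("accuracy"), row.get_cost())

# Export summary tables next to the results.
comparison.to_csv("results/comparison.csv")
comparison.to_markdown("results/comparison.md")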
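
The Pareto-frontier and config-diff helpers can also be driven from in-memory rows, which is handy for quick sanity checks. The rows, metric values, and cost figures below are invented for illustration; diff_configs additionally assumes each run directory contains a config.json.

from themis.experiment.comparison import (
    ComparisonRow,
    MultiExperimentComparison,
    diff_configs,
)

# Hypothetical runs trading accuracy against cost (cost is read from metadata).
rows = [
    ComparisonRow("run-a", {"accuracy": 0.81}, metadata={"cost": {"total_cost": 2.10}}),
    ComparisonRow("run-b", {"accuracy": 0.79}, metadata={"cost": {"total_cost": 0.40}}),
    ComparisonRow("run-c", {"accuracy": 0.75}, metadata={"cost": {"total_cost": 1.90}}),
]
comparison = MultiExperimentComparison(experiments=rows, metrics=["accuracy"])

# Maximize accuracy, minimize cost: run-a (most accurate) and run-b (cheapest)
# survive, while run-c is dominated by run-b on both objectives.
print(comparison.pareto_frontier(["accuracy", "cost"], maximize=[True, False]))

# Config drift between two stored runs (expects <run_id>/config.json on disk).
diff = diff_configs("run-a", "run-b", storage_dir="results")
if diff.has_differences():
    print(diff.changed_fields)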