themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. themis/cli/__init__.py +5 -0
  2. themis/cli/__main__.py +6 -0
  3. themis/cli/commands/__init__.py +19 -0
  4. themis/cli/commands/benchmarks.py +221 -0
  5. themis/cli/commands/comparison.py +394 -0
  6. themis/cli/commands/config_commands.py +244 -0
  7. themis/cli/commands/cost.py +214 -0
  8. themis/cli/commands/demo.py +68 -0
  9. themis/cli/commands/info.py +90 -0
  10. themis/cli/commands/leaderboard.py +362 -0
  11. themis/cli/commands/math_benchmarks.py +318 -0
  12. themis/cli/commands/mcq_benchmarks.py +207 -0
  13. themis/cli/commands/sample_run.py +244 -0
  14. themis/cli/commands/visualize.py +299 -0
  15. themis/cli/main.py +93 -0
  16. themis/cli/new_project.py +33 -0
  17. themis/cli/utils.py +51 -0
  18. themis/config/__init__.py +19 -0
  19. themis/config/loader.py +27 -0
  20. themis/config/registry.py +34 -0
  21. themis/config/runtime.py +214 -0
  22. themis/config/schema.py +112 -0
  23. themis/core/__init__.py +5 -0
  24. themis/core/conversation.py +354 -0
  25. themis/core/entities.py +164 -0
  26. themis/core/serialization.py +231 -0
  27. themis/core/tools.py +393 -0
  28. themis/core/types.py +141 -0
  29. themis/datasets/__init__.py +273 -0
  30. themis/datasets/base.py +264 -0
  31. themis/datasets/commonsense_qa.py +174 -0
  32. themis/datasets/competition_math.py +265 -0
  33. themis/datasets/coqa.py +133 -0
  34. themis/datasets/gpqa.py +190 -0
  35. themis/datasets/gsm8k.py +123 -0
  36. themis/datasets/gsm_symbolic.py +124 -0
  37. themis/datasets/math500.py +122 -0
  38. themis/datasets/med_qa.py +179 -0
  39. themis/datasets/medmcqa.py +169 -0
  40. themis/datasets/mmlu_pro.py +262 -0
  41. themis/datasets/piqa.py +146 -0
  42. themis/datasets/registry.py +201 -0
  43. themis/datasets/schema.py +245 -0
  44. themis/datasets/sciq.py +150 -0
  45. themis/datasets/social_i_qa.py +151 -0
  46. themis/datasets/super_gpqa.py +263 -0
  47. themis/evaluation/__init__.py +1 -0
  48. themis/evaluation/conditional.py +410 -0
  49. themis/evaluation/extractors/__init__.py +19 -0
  50. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  51. themis/evaluation/extractors/exceptions.py +7 -0
  52. themis/evaluation/extractors/identity_extractor.py +29 -0
  53. themis/evaluation/extractors/json_field_extractor.py +45 -0
  54. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  55. themis/evaluation/extractors/regex_extractor.py +43 -0
  56. themis/evaluation/math_verify_utils.py +87 -0
  57. themis/evaluation/metrics/__init__.py +21 -0
  58. themis/evaluation/metrics/composite_metric.py +47 -0
  59. themis/evaluation/metrics/consistency_metric.py +80 -0
  60. themis/evaluation/metrics/exact_match.py +51 -0
  61. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  62. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  63. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  64. themis/evaluation/metrics/response_length.py +33 -0
  65. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  66. themis/evaluation/pipeline.py +49 -0
  67. themis/evaluation/pipelines/__init__.py +15 -0
  68. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  69. themis/evaluation/pipelines/standard_pipeline.py +288 -0
  70. themis/evaluation/reports.py +293 -0
  71. themis/evaluation/statistics/__init__.py +53 -0
  72. themis/evaluation/statistics/bootstrap.py +79 -0
  73. themis/evaluation/statistics/confidence_intervals.py +121 -0
  74. themis/evaluation/statistics/distributions.py +207 -0
  75. themis/evaluation/statistics/effect_sizes.py +124 -0
  76. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  77. themis/evaluation/statistics/types.py +139 -0
  78. themis/evaluation/strategies/__init__.py +13 -0
  79. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  80. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  81. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  82. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  83. themis/experiment/__init__.py +5 -0
  84. themis/experiment/builder.py +151 -0
  85. themis/experiment/cache_manager.py +129 -0
  86. themis/experiment/comparison.py +631 -0
  87. themis/experiment/cost.py +310 -0
  88. themis/experiment/definitions.py +62 -0
  89. themis/experiment/export.py +690 -0
  90. themis/experiment/export_csv.py +159 -0
  91. themis/experiment/integration_manager.py +104 -0
  92. themis/experiment/math.py +192 -0
  93. themis/experiment/mcq.py +169 -0
  94. themis/experiment/orchestrator.py +373 -0
  95. themis/experiment/pricing.py +317 -0
  96. themis/experiment/storage.py +255 -0
  97. themis/experiment/visualization.py +588 -0
  98. themis/generation/__init__.py +1 -0
  99. themis/generation/agentic_runner.py +420 -0
  100. themis/generation/batching.py +254 -0
  101. themis/generation/clients.py +143 -0
  102. themis/generation/conversation_runner.py +236 -0
  103. themis/generation/plan.py +456 -0
  104. themis/generation/providers/litellm_provider.py +221 -0
  105. themis/generation/providers/vllm_provider.py +135 -0
  106. themis/generation/router.py +34 -0
  107. themis/generation/runner.py +207 -0
  108. themis/generation/strategies.py +98 -0
  109. themis/generation/templates.py +71 -0
  110. themis/generation/turn_strategies.py +393 -0
  111. themis/generation/types.py +9 -0
  112. themis/integrations/__init__.py +0 -0
  113. themis/integrations/huggingface.py +61 -0
  114. themis/integrations/wandb.py +65 -0
  115. themis/interfaces/__init__.py +83 -0
  116. themis/project/__init__.py +20 -0
  117. themis/project/definitions.py +98 -0
  118. themis/project/patterns.py +230 -0
  119. themis/providers/__init__.py +5 -0
  120. themis/providers/registry.py +39 -0
  121. themis/utils/api_generator.py +379 -0
  122. themis/utils/cost_tracking.py +376 -0
  123. themis/utils/dashboard.py +452 -0
  124. themis/utils/logging_utils.py +41 -0
  125. themis/utils/progress.py +58 -0
  126. themis/utils/tracing.py +320 -0
  127. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
  128. themis_eval-0.1.1.dist-info/RECORD +134 -0
  129. themis_eval-0.1.0.dist-info/RECORD +0 -8
  130. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
  131. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
  132. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,159 @@
+ """CSV export functionality for experiment reports."""
+
+ from __future__ import annotations
+
+ import csv
+ from pathlib import Path
+ from typing import MutableMapping, Sequence
+
+ from themis.core import entities as core_entities
+ from themis.experiment import orchestrator
+
+
+ def export_report_csv(
+     report: orchestrator.ExperimentReport,
+     path: str | Path,
+     *,
+     include_failures: bool = True,
+ ) -> Path:
+     """Write per-sample metrics to a CSV file for offline analysis.
+
+     Args:
+         report: Experiment report to export
+         path: Output path for CSV file
+         include_failures: Whether to include failures column
+
+     Returns:
+         Path to created CSV file
+     """
+     path = Path(path)
+     path.parent.mkdir(parents=True, exist_ok=True)
+     metadata_by_condition, metadata_fields = _collect_sample_metadata(
+         report.generation_results
+     )
+
+     # Create a proper index mapping generation records to their metadata
+     gen_record_index = {}
+     for gen_record in report.generation_results:
+         sample_id = gen_record.task.metadata.get(
+             "dataset_id"
+         ) or gen_record.task.metadata.get("sample_id")
+         prompt_template = gen_record.task.prompt.spec.name
+         model_identifier = gen_record.task.model.identifier
+         sampling_temp = gen_record.task.sampling.temperature
+         sampling_max_tokens = gen_record.task.sampling.max_tokens
+         condition_id = f"{sample_id}_{prompt_template}_{model_identifier}_{sampling_temp}_{sampling_max_tokens}"
+         gen_record_index[condition_id] = gen_record
+
+     metric_names = sorted(report.evaluation_report.metrics.keys())
+     fieldnames = (
+         ["sample_id"] + metadata_fields + [f"metric:{name}" for name in metric_names]
+     )
+     if include_failures:
+         fieldnames.append("failures")
+
+     with path.open("w", encoding="utf-8", newline="") as handle:
+         writer = csv.DictWriter(handle, fieldnames=fieldnames)
+         writer.writeheader()
+
+         # Process evaluation records in the same order as generation records
+         for i, eval_record in enumerate(report.evaluation_report.records):
+             # Find the corresponding generation record by index
+             if i < len(report.generation_results):
+                 gen_record = report.generation_results[i]
+                 sample_id = gen_record.task.metadata.get(
+                     "dataset_id"
+                 ) or gen_record.task.metadata.get("sample_id")
+                 prompt_template = gen_record.task.prompt.spec.name
+                 model_identifier = gen_record.task.model.identifier
+                 sampling_temp = gen_record.task.sampling.temperature
+                 sampling_max_tokens = gen_record.task.sampling.max_tokens
+                 condition_id = f"{sample_id}_{prompt_template}_{model_identifier}_{sampling_temp}_{sampling_max_tokens}"
+                 metadata = metadata_by_condition.get(condition_id, {})
+             else:
+                 # Fallback for extra evaluation records
+                 sample_id = eval_record.sample_id or ""
+                 metadata = {}
+
+             row: dict[str, object] = {"sample_id": sample_id}
+             for field in metadata_fields:
+                 row[field] = metadata.get(field, "")
+             score_by_name = {
+                 score.metric_name: score.value for score in eval_record.scores
+             }
+             for name in metric_names:
+                 row[f"metric:{name}"] = score_by_name.get(name, "")
+             if include_failures:
+                 row["failures"] = "; ".join(eval_record.failures)
+             writer.writerow(row)
+     return path
+
+
+ def _collect_sample_metadata(
+     records: Sequence[core_entities.GenerationRecord],
+ ) -> tuple[dict[str, MutableMapping[str, object]], list[str]]:
+     """Collect metadata from generation records.
+
+     Args:
+         records: Generation records
+
+     Returns:
+         Tuple of (metadata by condition ID, list of metadata fields)
+     """
+     metadata: dict[str, MutableMapping[str, object]] = {}
+     for index, record in enumerate(records):
+         sample_id = _extract_sample_id(record.task.metadata)
+         if sample_id is None:
+             sample_id = f"sample-{index}"
+
+         # Create unique identifier for each experimental condition
+         prompt_template = record.task.prompt.spec.name
+         model_identifier = record.task.model.identifier
+         sampling_temp = record.task.sampling.temperature
+         sampling_max_tokens = record.task.sampling.max_tokens
+
+         # Create unique condition key
+         condition_id = f"{sample_id}_{prompt_template}_{model_identifier}_{sampling_temp}_{sampling_max_tokens}"
+
+         # Store metadata with unique condition ID
+         condition_metadata = _metadata_from_task(record)
+         metadata[condition_id] = condition_metadata
+
+     # Collect all field names from all conditions
+     fields = sorted({field for meta in metadata.values() for field in meta.keys()})
+
+     return metadata, fields
+
+
+ def _extract_sample_id(metadata: dict[str, object]) -> str | None:
+     """Extract sample ID from metadata.
+
+     Args:
+         metadata: Task metadata
+
+     Returns:
+         Sample ID or None
+     """
+     value = metadata.get("dataset_id") or metadata.get("sample_id")
+     if value is None:
+         return None
+     return str(value)
+
+
+ def _metadata_from_task(record: core_entities.GenerationRecord) -> dict[str, object]:
+     """Build metadata dict from generation record.
+
+     Args:
+         record: Generation record
+
+     Returns:
+         Metadata dictionary
+     """
+     metadata = dict(record.task.metadata)
+     metadata.setdefault("model_identifier", record.task.model.identifier)
+     metadata.setdefault("model_provider", record.task.model.provider)
+     metadata.setdefault("prompt_template", record.task.prompt.spec.name)
+     metadata.setdefault("sampling_temperature", record.task.sampling.temperature)
+     metadata.setdefault("sampling_top_p", record.task.sampling.top_p)
+     metadata.setdefault("sampling_max_tokens", record.task.sampling.max_tokens)
+     return metadata
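
For orientation, a minimal usage sketch of the new `export_report_csv` helper. This is a hypothetical example, not part of the wheel: the `report` object and the output path are assumed, with `report` coming from `ExperimentOrchestrator.run(...)` as shown later in this diff.

# Hypothetical usage sketch; `report` is assumed to be an ExperimentReport
# produced elsewhere by ExperimentOrchestrator.run(...).
from themis.experiment.export_csv import export_report_csv

csv_path = export_report_csv(report, "results/run-001.csv", include_failures=True)
print(f"Wrote per-sample metrics to {csv_path}")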
@@ -0,0 +1,104 @@
+ """Integration management for external services (WandB, HuggingFace Hub)."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any
+
+ from themis.config.schema import IntegrationsConfig
+ from themis.core.entities import ExperimentReport
+ from themis.integrations.huggingface import HuggingFaceHubUploader
+ from themis.integrations.wandb import WandbTracker
+
+
+ class IntegrationManager:
+     """Manages external integrations (WandB, HuggingFace Hub).
+
+     This class handles all integration-related operations including:
+     - Initializing integrations based on configuration
+     - Logging experiment results to WandB
+     - Uploading results to HuggingFace Hub
+     - Finalizing integrations on completion
+
+     Single Responsibility: External integration management
+     """
+
+     def __init__(self, config: IntegrationsConfig | None = None) -> None:
+         """Initialize integration manager.
+
+         Args:
+             config: Integration configuration (None disables all integrations)
+         """
+         self._config = config or IntegrationsConfig()
+
+         # Initialize WandB tracker if enabled
+         self._wandb_tracker = (
+             WandbTracker(self._config.wandb) if self._config.wandb.enable else None
+         )
+
+         # Initialize HuggingFace Hub uploader if enabled
+         self._hf_uploader = (
+             HuggingFaceHubUploader(self._config.huggingface_hub)
+             if self._config.huggingface_hub.enable
+             else None
+         )
+
+     @property
+     def has_wandb(self) -> bool:
+         """Check if WandB integration is enabled."""
+         return self._wandb_tracker is not None
+
+     @property
+     def has_huggingface(self) -> bool:
+         """Check if HuggingFace Hub integration is enabled."""
+         return self._hf_uploader is not None
+
+     def initialize_run(self, run_config: dict[str, Any]) -> None:
+         """Initialize integrations for a new run.
+
+         Args:
+             run_config: Configuration dictionary for the run
+                 Common keys: max_samples, run_id, resume
+         """
+         if self._wandb_tracker:
+             self._wandb_tracker.init(run_config)
+
+     def log_results(self, report: ExperimentReport) -> None:
+         """Log experiment results to integrations.
+
+         Args:
+             report: Completed experiment report with all results
+         """
+         if self._wandb_tracker:
+             self._wandb_tracker.log_results(report)
+
+     def upload_results(
+         self,
+         report: ExperimentReport,
+         run_path: str | Path | None,
+     ) -> None:
+         """Upload results to HuggingFace Hub.
+
+         Args:
+             report: Completed experiment report
+             run_path: Path to run directory with cached results
+         """
+         if self._hf_uploader and run_path is not None:
+             self._hf_uploader.upload_results(report, run_path)
+
+     def finalize(self) -> None:
+         """Finalize all integrations.
+
+         This should be called after experiment completion to properly
+         close connections and clean up resources.
+         """
+         if self._wandb_tracker:
+             # WandB tracker handles finalization in log_results
+             pass
+
+         if self._hf_uploader:
+             # HuggingFace uploader is stateless, no finalization needed
+             pass
+
+
+ __all__ = ["IntegrationManager"]
@@ -0,0 +1,192 @@
+ """High-level helpers for math-focused experiments."""
+
+ from __future__ import annotations
+
+ from textwrap import dedent
+ from typing import Sequence
+
+ from themis.core import entities as core_entities
+ from themis.evaluation import extractors, math_verify_utils, metrics, pipeline
+ from themis.experiment import orchestrator
+ from themis.experiment import storage as experiment_storage
+ from themis.generation import clients, plan, runner, templates
+ from themis.interfaces import ModelProvider
+
+
+ def build_math500_zero_shot_experiment(
+     *,
+     model_client: ModelProvider | None = None,
+     model_name: str = "fake-math-llm",
+     provider_name: str = "fake",
+     temperature: float | None = None,
+     sampling: core_entities.SamplingConfig | None = None,
+     storage: experiment_storage.ExperimentStorage | None = None,
+     runner_options: dict[str, object] | None = None,
+     task_name: str = "math500",
+ ) -> orchestrator.ExperimentOrchestrator:
+     """Create an experiment orchestrator tailored for competition math benchmarks."""
+
+     prompt_template = templates.PromptTemplate(
+         name=f"{task_name}-zero-shot-json",
+         template=dedent(
+             """
+             You are an expert competition mathematician. Solve the following problem in a zero-shot
+             manner. Think carefully and provide a short reasoning paragraph followed by a line of the
+             form `Final Answer: \\boxed{{value}}` where `value` is the final numeric result.
+
+             Problem:
+             {problem}
+             """
+         ).strip(),
+         metadata={"task": task_name, "expect_boxed": True},
+     )
+
+     sampling = sampling or core_entities.SamplingConfig(
+         temperature=temperature if temperature is not None else 0.0,
+         top_p=0.95,
+         max_tokens=512,
+     )
+     model_spec = core_entities.ModelSpec(
+         identifier=model_name, provider=provider_name, default_sampling=sampling
+     )
+     math_plan = plan.GenerationPlan(
+         templates=[prompt_template],
+         models=[model_spec],
+         sampling_parameters=[sampling],
+         dataset_id_field="unique_id",
+         reference_field="answer",
+         metadata_fields=("subject", "level"),
+         context_builder=lambda row: {"problem": row.get("problem", "")},
+     )
+
+     # Extract runner options with proper type conversion
+     runner_kwargs = {}
+     if runner_options:
+         # Convert values to appropriate types with type checking
+         if (
+             "max_parallel" in runner_options
+             and runner_options["max_parallel"] is not None
+         ):
+             runner_kwargs["max_parallel"] = int(str(runner_options["max_parallel"]))
+         if (
+             "max_retries" in runner_options
+             and runner_options["max_retries"] is not None
+         ):
+             runner_kwargs["max_retries"] = int(str(runner_options["max_retries"]))
+         if (
+             "retry_initial_delay" in runner_options
+             and runner_options["retry_initial_delay"] is not None
+         ):
+             runner_kwargs["retry_initial_delay"] = float(
+                 str(runner_options["retry_initial_delay"])
+             )
+         if (
+             "retry_backoff_multiplier" in runner_options
+             and runner_options["retry_backoff_multiplier"] is not None
+         ):
+             runner_kwargs["retry_backoff_multiplier"] = float(
+                 str(runner_options["retry_backoff_multiplier"])
+             )
+         if "retry_max_delay" in runner_options:
+             retry_max_delay = runner_options["retry_max_delay"]
+             runner_kwargs["retry_max_delay"] = (
+                 float(str(retry_max_delay)) if retry_max_delay is not None else None
+             )
+
+     math_runner = runner.GenerationRunner(
+         provider=model_client or clients.FakeMathModelClient(),
+         **runner_kwargs,
+     )
+     if math_verify_utils.math_verify_available():
+         extractor = extractors.MathVerifyExtractor()
+         metric_list = [
+             metrics.MathVerifyAccuracy(),
+             metrics.ExactMatch(case_sensitive=False, strip_whitespace=True),
+         ]
+     else:
+         extractor = extractors.JsonFieldExtractor(field_path="answer")
+         metric_list = [
+             metrics.ExactMatch(case_sensitive=False, strip_whitespace=True),
+         ]
+     eval_pipeline = pipeline.EvaluationPipeline(
+         extractor=extractor,
+         metrics=metric_list,
+     )
+
+     return orchestrator.ExperimentOrchestrator(
+         generation_plan=math_plan,
+         generation_runner=math_runner,
+         evaluation_pipeline=eval_pipeline,
+         storage=storage,
+     )
+
+
+ def run_math500_zero_shot(
+     dataset: Sequence[dict[str, object]],
+     *,
+     model_client: clients.FakeMathModelClient | None = None,
+     max_samples: int | None = None,
+     storage: experiment_storage.ExperimentStorage | None = None,
+     run_id: str | None = None,
+     resume: bool = True,
+ ) -> orchestrator.ExperimentReport:
+     """Run the zero-shot math experiment against a prepared dataset."""
+
+     experiment = build_math500_zero_shot_experiment(
+         model_client=model_client, storage=storage
+     )
+     return experiment.run(
+         dataset, max_samples=max_samples, run_id=run_id, resume=resume
+     )
+
+
+ def summarize_report(report: orchestrator.ExperimentReport) -> str:
+     # Get exact match metric
+     exact = report.evaluation_report.metrics.get("ExactMatch")
+     exact_mean = exact.mean if exact else 0.0
+     exact_count = exact.count if exact else 0
+
+     # Get MathVerify metric if available
+     math_verify = report.evaluation_report.metrics.get("MathVerifyAccuracy")
+     math_verify_mean = math_verify.mean if math_verify else None
+     math_verify_count = math_verify.count if math_verify else 0
+
+     # Get failure counts
+     generation_failures = len(report.failures)
+     evaluation_failures = len(report.evaluation_report.failures)
+     total_failures = generation_failures + evaluation_failures
+
+     # Get metadata
+     total_samples = report.metadata.get("total_samples", 0)
+     successful_generations = report.metadata.get("successful_generations", 0)
+     failed_generations = report.metadata.get("failed_generations", 0)
+
+     # Build summary string
+     summary_parts = [
+         f"Evaluated {total_samples} samples",
+         f"Successful generations: {successful_generations}/{total_samples}",
+         f"Exact match: {exact_mean:.3f} ({exact_count} evaluated)",
+     ]
+
+     # Add MathVerify accuracy if available
+     if math_verify_mean is not None:
+         summary_parts.append(
+             f"MathVerify accuracy: {math_verify_mean:.3f} ({math_verify_count} evaluated)"
+         )
+
+     # Add failure information
+     if total_failures > 0:
+         summary_parts.append(
+             f"Failures: {total_failures} (gen: {failed_generations}, eval: {evaluation_failures})"
+         )
+     else:
+         summary_parts.append("No failures")
+
+     return " | ".join(summary_parts)
+
+
+ __all__ = [
+     "build_math500_zero_shot_experiment",
+     "run_math500_zero_shot",
+     "summarize_report",
+ ]
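
For reference, a minimal end-to-end sketch using the fake client that `build_math500_zero_shot_experiment` falls back to when no provider is supplied. The two-row dataset is illustrative only; its field names mirror `dataset_id_field="unique_id"`, `reference_field="answer"`, and `metadata_fields=("subject", "level")` above.

# Illustrative only: a tiny in-memory dataset shaped like the rows this plan expects.
from themis.experiment import math as math_experiment

dataset = [
    {"unique_id": "demo-1", "problem": "What is 2 + 2?", "answer": "4",
     "subject": "Arithmetic", "level": 1},
    {"unique_id": "demo-2", "problem": "Compute 3 * 5.", "answer": "15",
     "subject": "Arithmetic", "level": 1},
]

report = math_experiment.run_math500_zero_shot(dataset, max_samples=2, resume=False)
print(math_experiment.summarize_report(report))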
@@ -0,0 +1,169 @@
+ """Experiment builders for multiple-choice benchmarks."""
+
+ from __future__ import annotations
+
+ from textwrap import dedent
+ from typing import Callable, Sequence
+
+ from themis.core import entities as core_entities
+ from themis.evaluation import extractors, metrics, pipeline
+ from themis.experiment import orchestrator
+ from themis.experiment import storage as experiment_storage
+ from themis.generation import clients, plan, runner, templates
+ from themis.interfaces import ModelProvider
+
+
+ def build_multiple_choice_json_experiment(
+     *,
+     dataset_name: str,
+     task_id: str | None = None,
+     model_client: ModelProvider | None = None,
+     model_name: str = "fake-math-llm",
+     provider_name: str = "fake",
+     temperature: float | None = None,
+     sampling: core_entities.SamplingConfig | None = None,
+     storage: experiment_storage.ExperimentStorage | None = None,
+     runner_options: dict[str, object] | None = None,
+     metadata_fields: Sequence[str] = ("subject",),
+     context_builder: Callable[[dict[str, object]], dict[str, object]] | None = None,
+ ) -> orchestrator.ExperimentOrchestrator:
+     """Create an experiment orchestrator for multiple-choice QA benchmarks."""
+
+     task_id = task_id or dataset_name
+     prompt_template = templates.PromptTemplate(
+         name=f"{dataset_name}-multiple-choice-json",
+         template=dedent(
+             """
+             You are an expert test taker. Select the single best answer to the following
+             multiple-choice question.
+
+             Question:
+             {question}
+
+             Choices:
+             {choices_block}
+
+             Respond with a JSON object containing two keys:
+             "answer" - the capital letter of the chosen option (e.g. "A")
+             "explanation" - one or two sentences explaining your reasoning
+
+             Example response:
+             {{"answer": "A", "explanation": "Reasoning..."}}
+             """
+         ).strip(),
+         metadata={"task": task_id, "response_format": "json"},
+     )
+
+     sampling = sampling or core_entities.SamplingConfig(
+         temperature=temperature if temperature is not None else 0.0,
+         top_p=0.95,
+         max_tokens=512,
+     )
+     model_spec = core_entities.ModelSpec(
+         identifier=model_name, provider=provider_name, default_sampling=sampling
+     )
+
+     def _default_context_builder(row: dict[str, object]) -> dict[str, object]:
+         labels: Sequence[str] = tuple(
+             str(label) for label in row.get("choice_labels", [])
+         ) or tuple("ABCD")
+         choices: Sequence[str] = tuple(str(choice) for choice in row.get("choices", []))
+         choice_lines = []
+         for label, choice in zip(labels, choices, strict=False):
+             choice_lines.append(f"{label}. {choice}")
+         choices_block = "\n".join(choice_lines)
+         return {
+             "question": str(row.get("question", "")),
+             "choices_block": choices_block,
+         }
+
+     mcq_plan = plan.GenerationPlan(
+         templates=[prompt_template],
+         models=[model_spec],
+         sampling_parameters=[sampling],
+         dataset_id_field="unique_id",
+         reference_field="answer",
+         metadata_fields=tuple(metadata_fields),
+         context_builder=context_builder or _default_context_builder,
+     )
+
+     runner_kwargs = {}
+     if runner_options:
+         if (
+             "max_parallel" in runner_options
+             and runner_options["max_parallel"] is not None
+         ):
+             runner_kwargs["max_parallel"] = int(str(runner_options["max_parallel"]))
+         if (
+             "max_retries" in runner_options
+             and runner_options["max_retries"] is not None
+         ):
+             runner_kwargs["max_retries"] = int(str(runner_options["max_retries"]))
+         if (
+             "retry_initial_delay" in runner_options
+             and runner_options["retry_initial_delay"] is not None
+         ):
+             runner_kwargs["retry_initial_delay"] = float(
+                 str(runner_options["retry_initial_delay"])
+             )
+         if (
+             "retry_backoff_multiplier" in runner_options
+             and runner_options["retry_backoff_multiplier"] is not None
+         ):
+             runner_kwargs["retry_backoff_multiplier"] = float(
+                 str(runner_options["retry_backoff_multiplier"])
+             )
+         if "retry_max_delay" in runner_options:
+             retry_max_delay = runner_options["retry_max_delay"]
+             runner_kwargs["retry_max_delay"] = (
+                 float(str(retry_max_delay)) if retry_max_delay is not None else None
+             )
+
+     mcq_runner = runner.GenerationRunner(
+         provider=model_client or clients.FakeMathModelClient(),
+         **runner_kwargs,
+     )
+
+     extractor = extractors.JsonFieldExtractor(field_path="answer")
+     metric_list = [
+         metrics.ExactMatch(case_sensitive=False, strip_whitespace=True),
+     ]
+     eval_pipeline = pipeline.EvaluationPipeline(
+         extractor=extractor,
+         metrics=metric_list,
+     )
+
+     return orchestrator.ExperimentOrchestrator(
+         generation_plan=mcq_plan,
+         generation_runner=mcq_runner,
+         evaluation_pipeline=eval_pipeline,
+         storage=storage,
+     )
+
+
+ def summarize_report(report: orchestrator.ExperimentReport) -> str:
+     exact = report.evaluation_report.metrics.get("ExactMatch")
+     accuracy = exact.mean if exact else 0.0
+     evaluated = exact.count if exact else 0
+
+     total_samples = report.metadata.get("total_samples", evaluated)
+     successful_generations = report.metadata.get("successful_generations", evaluated)
+     failed_generations = report.metadata.get("failed_generations", 0)
+     evaluation_failures = len(report.evaluation_report.failures)
+     total_failures = failed_generations + evaluation_failures
+
+     summary_parts = [
+         f"Evaluated {total_samples} samples",
+         f"Successful generations: {successful_generations}/{total_samples}",
+         f"Accuracy: {accuracy:.3f} ({evaluated} evaluated)",
+     ]
+     if total_failures:
+         summary_parts.append(
+             f"Failures: {total_failures} (gen: {failed_generations}, eval: {evaluation_failures})"
+         )
+     else:
+         summary_parts.append("No failures")
+     return " | ".join(summary_parts)
+
+
+ __all__ = ["build_multiple_choice_json_experiment", "summarize_report"]