themis-eval 0.1.0-py3-none-any.whl → 0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +93 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +164 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +288 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +129 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +690 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +373 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +255 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +61 -0
- themis/integrations/wandb.py +65 -0
- themis/interfaces/__init__.py +83 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
- themis_eval-0.1.1.dist-info/RECORD +134 -0
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0

themis/experiment/export_csv.py
ADDED
@@ -0,0 +1,159 @@
"""CSV export functionality for experiment reports."""

from __future__ import annotations

import csv
from pathlib import Path
from typing import MutableMapping, Sequence

from themis.core import entities as core_entities
from themis.experiment import orchestrator


def export_report_csv(
    report: orchestrator.ExperimentReport,
    path: str | Path,
    *,
    include_failures: bool = True,
) -> Path:
    """Write per-sample metrics to a CSV file for offline analysis.

    Args:
        report: Experiment report to export
        path: Output path for CSV file
        include_failures: Whether to include failures column

    Returns:
        Path to created CSV file
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    metadata_by_condition, metadata_fields = _collect_sample_metadata(
        report.generation_results
    )

    # Create a proper index mapping generation records to their metadata
    gen_record_index = {}
    for gen_record in report.generation_results:
        sample_id = gen_record.task.metadata.get(
            "dataset_id"
        ) or gen_record.task.metadata.get("sample_id")
        prompt_template = gen_record.task.prompt.spec.name
        model_identifier = gen_record.task.model.identifier
        sampling_temp = gen_record.task.sampling.temperature
        sampling_max_tokens = gen_record.task.sampling.max_tokens
        condition_id = f"{sample_id}_{prompt_template}_{model_identifier}_{sampling_temp}_{sampling_max_tokens}"
        gen_record_index[condition_id] = gen_record

    metric_names = sorted(report.evaluation_report.metrics.keys())
    fieldnames = (
        ["sample_id"] + metadata_fields + [f"metric:{name}" for name in metric_names]
    )
    if include_failures:
        fieldnames.append("failures")

    with path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()

        # Process evaluation records in the same order as generation records
        for i, eval_record in enumerate(report.evaluation_report.records):
            # Find the corresponding generation record by index
            if i < len(report.generation_results):
                gen_record = report.generation_results[i]
                sample_id = gen_record.task.metadata.get(
                    "dataset_id"
                ) or gen_record.task.metadata.get("sample_id")
                prompt_template = gen_record.task.prompt.spec.name
                model_identifier = gen_record.task.model.identifier
                sampling_temp = gen_record.task.sampling.temperature
                sampling_max_tokens = gen_record.task.sampling.max_tokens
                condition_id = f"{sample_id}_{prompt_template}_{model_identifier}_{sampling_temp}_{sampling_max_tokens}"
                metadata = metadata_by_condition.get(condition_id, {})
            else:
                # Fallback for extra evaluation records
                sample_id = eval_record.sample_id or ""
                metadata = {}

            row: dict[str, object] = {"sample_id": sample_id}
            for field in metadata_fields:
                row[field] = metadata.get(field, "")
            score_by_name = {
                score.metric_name: score.value for score in eval_record.scores
            }
            for name in metric_names:
                row[f"metric:{name}"] = score_by_name.get(name, "")
            if include_failures:
                row["failures"] = "; ".join(eval_record.failures)
            writer.writerow(row)
    return path


def _collect_sample_metadata(
    records: Sequence[core_entities.GenerationRecord],
) -> tuple[dict[str, MutableMapping[str, object]], list[str]]:
    """Collect metadata from generation records.

    Args:
        records: Generation records

    Returns:
        Tuple of (metadata by condition ID, list of metadata fields)
    """
    metadata: dict[str, MutableMapping[str, object]] = {}
    for index, record in enumerate(records):
        sample_id = _extract_sample_id(record.task.metadata)
        if sample_id is None:
            sample_id = f"sample-{index}"

        # Create unique identifier for each experimental condition
        prompt_template = record.task.prompt.spec.name
        model_identifier = record.task.model.identifier
        sampling_temp = record.task.sampling.temperature
        sampling_max_tokens = record.task.sampling.max_tokens

        # Create unique condition key
        condition_id = f"{sample_id}_{prompt_template}_{model_identifier}_{sampling_temp}_{sampling_max_tokens}"

        # Store metadata with unique condition ID
        condition_metadata = _metadata_from_task(record)
        metadata[condition_id] = condition_metadata

    # Collect all field names from all conditions
    fields = sorted({field for meta in metadata.values() for field in meta.keys()})

    return metadata, fields


def _extract_sample_id(metadata: dict[str, object]) -> str | None:
    """Extract sample ID from metadata.

    Args:
        metadata: Task metadata

    Returns:
        Sample ID or None
    """
    value = metadata.get("dataset_id") or metadata.get("sample_id")
    if value is None:
        return None
    return str(value)


def _metadata_from_task(record: core_entities.GenerationRecord) -> dict[str, object]:
    """Build metadata dict from generation record.

    Args:
        record: Generation record

    Returns:
        Metadata dictionary
    """
    metadata = dict(record.task.metadata)
    metadata.setdefault("model_identifier", record.task.model.identifier)
    metadata.setdefault("model_provider", record.task.model.provider)
    metadata.setdefault("prompt_template", record.task.prompt.spec.name)
    metadata.setdefault("sampling_temperature", record.task.sampling.temperature)
    metadata.setdefault("sampling_top_p", record.task.sampling.top_p)
    metadata.setdefault("sampling_max_tokens", record.task.sampling.max_tokens)
    return metadata
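
A usage sketch (not part of the packaged diff): the snippet below runs the math helper added later in this release with its bundled fake client and exports the resulting report, assuming the illustrative row fields and output path shown here.

from themis.experiment.export_csv import export_report_csv
from themis.experiment.math import run_math500_zero_shot

# Illustrative rows shaped like the fields the math plan reads
# (unique_id, problem, answer, subject, level); not package fixtures.
rows = [
    {"unique_id": "demo-1", "problem": "What is 2 + 2?", "answer": "4",
     "subject": "arithmetic", "level": "1"},
    {"unique_id": "demo-2", "problem": "What is 3 * 5?", "answer": "15",
     "subject": "arithmetic", "level": "1"},
]

report = run_math500_zero_shot(rows)  # defaults to the bundled FakeMathModelClient
csv_path = export_report_csv(report, "results/math500_metrics.csv", include_failures=True)
print(f"Per-sample metrics written to {csv_path}")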

themis/experiment/integration_manager.py
ADDED
@@ -0,0 +1,104 @@
"""Integration management for external services (WandB, HuggingFace Hub)."""

from __future__ import annotations

from pathlib import Path
from typing import Any

from themis.config.schema import IntegrationsConfig
from themis.core.entities import ExperimentReport
from themis.integrations.huggingface import HuggingFaceHubUploader
from themis.integrations.wandb import WandbTracker


class IntegrationManager:
    """Manages external integrations (WandB, HuggingFace Hub).

    This class handles all integration-related operations including:
    - Initializing integrations based on configuration
    - Logging experiment results to WandB
    - Uploading results to HuggingFace Hub
    - Finalizing integrations on completion

    Single Responsibility: External integration management
    """

    def __init__(self, config: IntegrationsConfig | None = None) -> None:
        """Initialize integration manager.

        Args:
            config: Integration configuration (None disables all integrations)
        """
        self._config = config or IntegrationsConfig()

        # Initialize WandB tracker if enabled
        self._wandb_tracker = (
            WandbTracker(self._config.wandb) if self._config.wandb.enable else None
        )

        # Initialize HuggingFace Hub uploader if enabled
        self._hf_uploader = (
            HuggingFaceHubUploader(self._config.huggingface_hub)
            if self._config.huggingface_hub.enable
            else None
        )

    @property
    def has_wandb(self) -> bool:
        """Check if WandB integration is enabled."""
        return self._wandb_tracker is not None

    @property
    def has_huggingface(self) -> bool:
        """Check if HuggingFace Hub integration is enabled."""
        return self._hf_uploader is not None

    def initialize_run(self, run_config: dict[str, Any]) -> None:
        """Initialize integrations for a new run.

        Args:
            run_config: Configuration dictionary for the run
                Common keys: max_samples, run_id, resume
        """
        if self._wandb_tracker:
            self._wandb_tracker.init(run_config)

    def log_results(self, report: ExperimentReport) -> None:
        """Log experiment results to integrations.

        Args:
            report: Completed experiment report with all results
        """
        if self._wandb_tracker:
            self._wandb_tracker.log_results(report)

    def upload_results(
        self,
        report: ExperimentReport,
        run_path: str | Path | None,
    ) -> None:
        """Upload results to HuggingFace Hub.

        Args:
            report: Completed experiment report
            run_path: Path to run directory with cached results
        """
        if self._hf_uploader and run_path is not None:
            self._hf_uploader.upload_results(report, run_path)

    def finalize(self) -> None:
        """Finalize all integrations.

        This should be called after experiment completion to properly
        close connections and clean up resources.
        """
        if self._wandb_tracker:
            # WandB tracker handles finalization in log_results
            pass

        if self._hf_uploader:
            # HuggingFace uploader is stateless, no finalization needed
            pass


__all__ = ["IntegrationManager"]
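
A usage sketch (not part of the packaged diff): per the constructor docstring, a default IntegrationsConfig leaves both integrations disabled, so every call below is effectively a no-op; enabling WandB or the HuggingFace Hub would require filling in the corresponding fields of themis.config.schema.IntegrationsConfig. The run_config keys and run path are illustrative.

from themis.config.schema import IntegrationsConfig
from themis.experiment.integration_manager import IntegrationManager

manager = IntegrationManager(IntegrationsConfig())  # assumed defaults: all integrations disabled

manager.initialize_run({"run_id": "demo-001", "max_samples": 10, "resume": False})
print(manager.has_wandb, manager.has_huggingface)  # both False with the defaults

# After an experiment run produces an ExperimentReport named `report`:
# manager.log_results(report)
# manager.upload_results(report, run_path="runs/demo-001")
manager.finalize()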

themis/experiment/math.py
ADDED
@@ -0,0 +1,192 @@
"""High-level helpers for math-focused experiments."""

from __future__ import annotations

from textwrap import dedent
from typing import Sequence

from themis.core import entities as core_entities
from themis.evaluation import extractors, math_verify_utils, metrics, pipeline
from themis.experiment import orchestrator
from themis.experiment import storage as experiment_storage
from themis.generation import clients, plan, runner, templates
from themis.interfaces import ModelProvider


def build_math500_zero_shot_experiment(
    *,
    model_client: ModelProvider | None = None,
    model_name: str = "fake-math-llm",
    provider_name: str = "fake",
    temperature: float | None = None,
    sampling: core_entities.SamplingConfig | None = None,
    storage: experiment_storage.ExperimentStorage | None = None,
    runner_options: dict[str, object] | None = None,
    task_name: str = "math500",
) -> orchestrator.ExperimentOrchestrator:
    """Create an experiment orchestrator tailored for competition math benchmarks."""

    prompt_template = templates.PromptTemplate(
        name=f"{task_name}-zero-shot-json",
        template=dedent(
            """
            You are an expert competition mathematician. Solve the following problem in a zero-shot
            manner. Think carefully and provide a short reasoning paragraph followed by a line of the
            form `Final Answer: \\boxed{{value}}` where `value` is the final numeric result.

            Problem:
            {problem}
            """
        ).strip(),
        metadata={"task": task_name, "expect_boxed": True},
    )

    sampling = sampling or core_entities.SamplingConfig(
        temperature=temperature if temperature is not None else 0.0,
        top_p=0.95,
        max_tokens=512,
    )
    model_spec = core_entities.ModelSpec(
        identifier=model_name, provider=provider_name, default_sampling=sampling
    )
    math_plan = plan.GenerationPlan(
        templates=[prompt_template],
        models=[model_spec],
        sampling_parameters=[sampling],
        dataset_id_field="unique_id",
        reference_field="answer",
        metadata_fields=("subject", "level"),
        context_builder=lambda row: {"problem": row.get("problem", "")},
    )

    # Extract runner options with proper type conversion
    runner_kwargs = {}
    if runner_options:
        # Convert values to appropriate types with type checking
        if (
            "max_parallel" in runner_options
            and runner_options["max_parallel"] is not None
        ):
            runner_kwargs["max_parallel"] = int(str(runner_options["max_parallel"]))
        if (
            "max_retries" in runner_options
            and runner_options["max_retries"] is not None
        ):
            runner_kwargs["max_retries"] = int(str(runner_options["max_retries"]))
        if (
            "retry_initial_delay" in runner_options
            and runner_options["retry_initial_delay"] is not None
        ):
            runner_kwargs["retry_initial_delay"] = float(
                str(runner_options["retry_initial_delay"])
            )
        if (
            "retry_backoff_multiplier" in runner_options
            and runner_options["retry_backoff_multiplier"] is not None
        ):
            runner_kwargs["retry_backoff_multiplier"] = float(
                str(runner_options["retry_backoff_multiplier"])
            )
        if "retry_max_delay" in runner_options:
            retry_max_delay = runner_options["retry_max_delay"]
            runner_kwargs["retry_max_delay"] = (
                float(str(retry_max_delay)) if retry_max_delay is not None else None
            )

    math_runner = runner.GenerationRunner(
        provider=model_client or clients.FakeMathModelClient(),
        **runner_kwargs,
    )
    if math_verify_utils.math_verify_available():
        extractor = extractors.MathVerifyExtractor()
        metric_list = [
            metrics.MathVerifyAccuracy(),
            metrics.ExactMatch(case_sensitive=False, strip_whitespace=True),
        ]
    else:
        extractor = extractors.JsonFieldExtractor(field_path="answer")
        metric_list = [
            metrics.ExactMatch(case_sensitive=False, strip_whitespace=True),
        ]
    eval_pipeline = pipeline.EvaluationPipeline(
        extractor=extractor,
        metrics=metric_list,
    )

    return orchestrator.ExperimentOrchestrator(
        generation_plan=math_plan,
        generation_runner=math_runner,
        evaluation_pipeline=eval_pipeline,
        storage=storage,
    )


def run_math500_zero_shot(
    dataset: Sequence[dict[str, object]],
    *,
    model_client: clients.FakeMathModelClient | None = None,
    max_samples: int | None = None,
    storage: experiment_storage.ExperimentStorage | None = None,
    run_id: str | None = None,
    resume: bool = True,
) -> orchestrator.ExperimentReport:
    """Run the zero-shot math experiment against a prepared dataset."""

    experiment = build_math500_zero_shot_experiment(
        model_client=model_client, storage=storage
    )
    return experiment.run(
        dataset, max_samples=max_samples, run_id=run_id, resume=resume
    )


def summarize_report(report: orchestrator.ExperimentReport) -> str:
    # Get exact match metric
    exact = report.evaluation_report.metrics.get("ExactMatch")
    exact_mean = exact.mean if exact else 0.0
    exact_count = exact.count if exact else 0

    # Get MathVerify metric if available
    math_verify = report.evaluation_report.metrics.get("MathVerifyAccuracy")
    math_verify_mean = math_verify.mean if math_verify else None
    math_verify_count = math_verify.count if math_verify else 0

    # Get failure counts
    generation_failures = len(report.failures)
    evaluation_failures = len(report.evaluation_report.failures)
    total_failures = generation_failures + evaluation_failures

    # Get metadata
    total_samples = report.metadata.get("total_samples", 0)
    successful_generations = report.metadata.get("successful_generations", 0)
    failed_generations = report.metadata.get("failed_generations", 0)

    # Build summary string
    summary_parts = [
        f"Evaluated {total_samples} samples",
        f"Successful generations: {successful_generations}/{total_samples}",
        f"Exact match: {exact_mean:.3f} ({exact_count} evaluated)",
    ]

    # Add MathVerify accuracy if available
    if math_verify_mean is not None:
        summary_parts.append(
            f"MathVerify accuracy: {math_verify_mean:.3f} ({math_verify_count} evaluated)"
        )

    # Add failure information
    if total_failures > 0:
        summary_parts.append(
            f"Failures: {total_failures} (gen: {failed_generations}, eval: {evaluation_failures})"
        )
    else:
        summary_parts.append("No failures")

    return " | ".join(summary_parts)


__all__ = [
    "build_math500_zero_shot_experiment",
    "run_math500_zero_shot",
    "summarize_report",
]
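
A usage sketch (not part of the packaged diff): running the zero-shot math helper end to end on a tiny in-memory dataset and printing the one-line summary. The row contents and run_id are illustrative; with no model_client the helper falls back to the bundled FakeMathModelClient.

from themis.experiment.math import run_math500_zero_shot, summarize_report

rows = [
    {"unique_id": "math-demo-1", "problem": "Compute 7 + 8.", "answer": "15",
     "subject": "Prealgebra", "level": "Level 1"},
]

report = run_math500_zero_shot(rows, max_samples=1, run_id="math-demo", resume=False)
print(summarize_report(report))  # accuracy and failure counts joined with " | "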
themis/experiment/mcq.py
ADDED
@@ -0,0 +1,169 @@
"""Experiment builders for multiple-choice benchmarks."""

from __future__ import annotations

from textwrap import dedent
from typing import Callable, Sequence

from themis.core import entities as core_entities
from themis.evaluation import extractors, metrics, pipeline
from themis.experiment import orchestrator
from themis.experiment import storage as experiment_storage
from themis.generation import clients, plan, runner, templates
from themis.interfaces import ModelProvider


def build_multiple_choice_json_experiment(
    *,
    dataset_name: str,
    task_id: str | None = None,
    model_client: ModelProvider | None = None,
    model_name: str = "fake-math-llm",
    provider_name: str = "fake",
    temperature: float | None = None,
    sampling: core_entities.SamplingConfig | None = None,
    storage: experiment_storage.ExperimentStorage | None = None,
    runner_options: dict[str, object] | None = None,
    metadata_fields: Sequence[str] = ("subject",),
    context_builder: Callable[[dict[str, object]], dict[str, object]] | None = None,
) -> orchestrator.ExperimentOrchestrator:
    """Create an experiment orchestrator for multiple-choice QA benchmarks."""

    task_id = task_id or dataset_name
    prompt_template = templates.PromptTemplate(
        name=f"{dataset_name}-multiple-choice-json",
        template=dedent(
            """
            You are an expert test taker. Select the single best answer to the following
            multiple-choice question.

            Question:
            {question}

            Choices:
            {choices_block}

            Respond with a JSON object containing two keys:
            "answer" - the capital letter of the chosen option (e.g. "A")
            "explanation" - one or two sentences explaining your reasoning

            Example response:
            {{"answer": "A", "explanation": "Reasoning..."}}
            """
        ).strip(),
        metadata={"task": task_id, "response_format": "json"},
    )

    sampling = sampling or core_entities.SamplingConfig(
        temperature=temperature if temperature is not None else 0.0,
        top_p=0.95,
        max_tokens=512,
    )
    model_spec = core_entities.ModelSpec(
        identifier=model_name, provider=provider_name, default_sampling=sampling
    )

    def _default_context_builder(row: dict[str, object]) -> dict[str, object]:
        labels: Sequence[str] = tuple(
            str(label) for label in row.get("choice_labels", [])
        ) or tuple("ABCD")
        choices: Sequence[str] = tuple(str(choice) for choice in row.get("choices", []))
        choice_lines = []
        for label, choice in zip(labels, choices, strict=False):
            choice_lines.append(f"{label}. {choice}")
        choices_block = "\n".join(choice_lines)
        return {
            "question": str(row.get("question", "")),
            "choices_block": choices_block,
        }

    mcq_plan = plan.GenerationPlan(
        templates=[prompt_template],
        models=[model_spec],
        sampling_parameters=[sampling],
        dataset_id_field="unique_id",
        reference_field="answer",
        metadata_fields=tuple(metadata_fields),
        context_builder=context_builder or _default_context_builder,
    )

    runner_kwargs = {}
    if runner_options:
        if (
            "max_parallel" in runner_options
            and runner_options["max_parallel"] is not None
        ):
            runner_kwargs["max_parallel"] = int(str(runner_options["max_parallel"]))
        if (
            "max_retries" in runner_options
            and runner_options["max_retries"] is not None
        ):
            runner_kwargs["max_retries"] = int(str(runner_options["max_retries"]))
        if (
            "retry_initial_delay" in runner_options
            and runner_options["retry_initial_delay"] is not None
        ):
            runner_kwargs["retry_initial_delay"] = float(
                str(runner_options["retry_initial_delay"])
            )
        if (
            "retry_backoff_multiplier" in runner_options
            and runner_options["retry_backoff_multiplier"] is not None
        ):
            runner_kwargs["retry_backoff_multiplier"] = float(
                str(runner_options["retry_backoff_multiplier"])
            )
        if "retry_max_delay" in runner_options:
            retry_max_delay = runner_options["retry_max_delay"]
            runner_kwargs["retry_max_delay"] = (
                float(str(retry_max_delay)) if retry_max_delay is not None else None
            )

    mcq_runner = runner.GenerationRunner(
        provider=model_client or clients.FakeMathModelClient(),
        **runner_kwargs,
    )

    extractor = extractors.JsonFieldExtractor(field_path="answer")
    metric_list = [
        metrics.ExactMatch(case_sensitive=False, strip_whitespace=True),
    ]
    eval_pipeline = pipeline.EvaluationPipeline(
        extractor=extractor,
        metrics=metric_list,
    )

    return orchestrator.ExperimentOrchestrator(
        generation_plan=mcq_plan,
        generation_runner=mcq_runner,
        evaluation_pipeline=eval_pipeline,
        storage=storage,
    )


def summarize_report(report: orchestrator.ExperimentReport) -> str:
    exact = report.evaluation_report.metrics.get("ExactMatch")
    accuracy = exact.mean if exact else 0.0
    evaluated = exact.count if exact else 0

    total_samples = report.metadata.get("total_samples", evaluated)
    successful_generations = report.metadata.get("successful_generations", evaluated)
    failed_generations = report.metadata.get("failed_generations", 0)
    evaluation_failures = len(report.evaluation_report.failures)
    total_failures = failed_generations + evaluation_failures

    summary_parts = [
        f"Evaluated {total_samples} samples",
        f"Successful generations: {successful_generations}/{total_samples}",
        f"Accuracy: {accuracy:.3f} ({evaluated} evaluated)",
    ]
    if total_failures:
        summary_parts.append(
            f"Failures: {total_failures} (gen: {failed_generations}, eval: {evaluation_failures})"
        )
    else:
        summary_parts.append("No failures")
    return " | ".join(summary_parts)


__all__ = ["build_multiple_choice_json_experiment", "summarize_report"]
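
A usage sketch (not part of the packaged diff): building a multiple-choice experiment for a hypothetical dataset name and running it on one illustrative row; choice_labels, choices, and the other row fields mirror what the default context builder and plan read, while the run arguments follow the pattern used by run_math500_zero_shot above.

from themis.experiment.mcq import build_multiple_choice_json_experiment, summarize_report

rows = [
    {
        "unique_id": "mcq-demo-1",
        "question": "Which planet is known as the Red Planet?",
        "choice_labels": ["A", "B", "C", "D"],
        "choices": ["Venus", "Mars", "Jupiter", "Saturn"],
        "answer": "B",
        "subject": "astronomy",
    },
]

experiment = build_multiple_choice_json_experiment(dataset_name="demo-mcq")
report = experiment.run(rows, max_samples=1, run_id="mcq-demo", resume=False)
print(summarize_report(report))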