wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.1.dist-info/METADATA +67 -0
- wisent-0.5.1.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
wisent/core/download_full_benchmarks.py
@@ -0,0 +1,1386 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Full Benchmark Downloader
|
|
4
|
+
|
|
5
|
+
Downloads complete benchmarks from lm-eval-harness and saves them in a structured format.
|
|
6
|
+
Downloads the ENTIRE benchmark datasets, not just samples.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
python download_full_benchmarks.py --benchmarks glue mmlu --force
|
|
10
|
+
python download_full_benchmarks.py --all # Download all benchmarks
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import argparse
|
|
14
|
+
import json
|
|
15
|
+
import pickle
|
|
16
|
+
import random
|
|
17
|
+
import sys
|
|
18
|
+
import time
|
|
19
|
+
from datetime import datetime
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Any, Dict, List, Optional
|
|
22
|
+
|
|
23
|
+
# Add current directory to path to import local modules
|
|
24
|
+
current_dir = Path(__file__).parent
|
|
25
|
+
sys.path.insert(0, str(current_dir.parent / "lm-harness-integration"))
|
|
26
|
+
|
|
27
|
+
# Import the benchmark list
|
|
28
|
+
from only_benchmarks import CORE_BENCHMARKS
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class FullBenchmarkDownloader:
|
|
32
|
+
"""Downloads complete benchmarks and saves them to disk."""
|
|
33
|
+
|
|
34
|
+
# Benchmarks that are known to be unavailable or problematic
|
|
35
|
+
UNAVAILABLE_BENCHMARKS = {
|
|
36
|
+
# Empty set - let all benchmarks be attempted and skip dynamically if they fail
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
def __init__(self, download_dir: str = "full_benchmarks"):
|
|
40
|
+
"""
|
|
41
|
+
Initialize the benchmark downloader.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
download_dir: Directory to save downloaded benchmarks
|
|
45
|
+
"""
|
|
46
|
+
self.download_dir = Path(download_dir)
|
|
47
|
+
self.download_dir.mkdir(exist_ok=True)
|
|
48
|
+
|
|
49
|
+
# Create subdirectories
|
|
50
|
+
self.data_dir = self.download_dir / "data"
|
|
51
|
+
self.metadata_dir = self.download_dir / "metadata"
|
|
52
|
+
self.data_dir.mkdir(exist_ok=True)
|
|
53
|
+
self.metadata_dir.mkdir(exist_ok=True)
|
|
54
|
+
|
|
55
|
+
print("š Full Benchmark Downloader")
|
|
56
|
+
print(f"š Download directory: {self.download_dir.absolute()}")
|
|
57
|
+
|
|
58
|
+
def download_complete_benchmark(
|
|
59
|
+
self, benchmark_name: str, benchmark_config: dict, force: bool = False
|
|
60
|
+
) -> Optional[str]:
|
|
61
|
+
"""
|
|
62
|
+
Download a complete benchmark dataset.
|
|
63
|
+
|
|
64
|
+
Args:
|
|
65
|
+
benchmark_name: Display name of the benchmark
|
|
66
|
+
benchmark_config: Config dict with 'task' and 'tags' keys
|
|
67
|
+
force: Force redownload even if exists
|
|
68
|
+
|
|
69
|
+
Returns:
|
|
70
|
+
Path to saved benchmark file, or None if failed
|
|
71
|
+
"""
|
|
72
|
+
task_name = benchmark_config["task"]
|
|
73
|
+
tags = benchmark_config.get("tags", [])
|
|
74
|
+
|
|
75
|
+
# Check if already exists
|
|
76
|
+
data_file = self.data_dir / f"{benchmark_name}.pkl"
|
|
77
|
+
metadata_file = self.metadata_dir / f"{benchmark_name}_metadata.json"
|
|
78
|
+
|
|
79
|
+
if data_file.exists() and metadata_file.exists() and not force:
|
|
80
|
+
print(f" ā© Skipping {benchmark_name} (already exists)")
|
|
81
|
+
return str(data_file)
|
|
82
|
+
|
|
83
|
+
print(f" š„ Downloading complete benchmark: {benchmark_name}")
|
|
84
|
+
print(f" š Loading full dataset for task: {task_name}")
|
|
85
|
+
|
|
86
|
+
start_time = time.time()
|
|
87
|
+
|
|
88
|
+
try:
|
|
89
|
+
# Import lm_eval to download complete datasets
|
|
90
|
+
from lm_eval import tasks
|
|
91
|
+
|
|
92
|
+
# Get the task
|
|
93
|
+
task_dict = tasks.get_task_dict([task_name])
|
|
94
|
+
if task_name not in task_dict:
|
|
95
|
+
print(f" ā Task {task_name} not found in lm_eval")
|
|
96
|
+
return None
|
|
97
|
+
|
|
98
|
+
task = task_dict[task_name]
|
|
99
|
+
|
|
100
|
+
# Download complete dataset - combine all splits into one unified dataset
|
|
101
|
+
complete_data = {
|
|
102
|
+
"benchmark_name": benchmark_name,
|
|
103
|
+
"task_name": task_name,
|
|
104
|
+
"config": benchmark_config,
|
|
105
|
+
"all_samples": [],
|
|
106
|
+
"total_samples": 0,
|
|
107
|
+
"splits_found": [],
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
# Get all available document splits
|
|
111
|
+
splits_to_try = ["test", "validation", "train", "dev"]
|
|
112
|
+
|
|
113
|
+
for split in splits_to_try:
|
|
114
|
+
try:
|
|
115
|
+
if hasattr(task, f"{split}_docs"):
|
|
116
|
+
docs_method = getattr(task, f"{split}_docs")
|
|
117
|
+
docs = list(docs_method())
|
|
118
|
+
|
|
119
|
+
if docs:
|
|
120
|
+
print(f" š Found {len(docs)} samples in {split} split")
|
|
121
|
+
complete_data["splits_found"].append(split)
|
|
122
|
+
|
|
123
|
+
# Convert documents to serializable format and add to unified list
|
|
124
|
+
for i, doc in enumerate(docs):
|
|
125
|
+
if i % 1000 == 0 and i > 0:
|
|
126
|
+
print(f" Processing {split} {i}/{len(docs)}...")
|
|
127
|
+
|
|
128
|
+
# Convert doc to dict, handling different doc types
|
|
129
|
+
if hasattr(doc, "__dict__"):
|
|
130
|
+
doc_dict = doc.__dict__.copy()
|
|
131
|
+
elif isinstance(doc, dict):
|
|
132
|
+
doc_dict = doc.copy()
|
|
133
|
+
else:
|
|
134
|
+
doc_dict = {"content": str(doc)}
|
|
135
|
+
|
|
136
|
+
# Add split origin info
|
|
137
|
+
doc_dict["_split_origin"] = split
|
|
138
|
+
|
|
139
|
+
# Ensure all values are serializable
|
|
140
|
+
serializable_doc = {}
|
|
141
|
+
for key, value in doc_dict.items():
|
|
142
|
+
try:
|
|
143
|
+
json.dumps(value) # Test if serializable
|
|
144
|
+
serializable_doc[key] = value
|
|
145
|
+
except (TypeError, ValueError):
|
|
146
|
+
serializable_doc[key] = str(value)
|
|
147
|
+
|
|
148
|
+
complete_data["all_samples"].append(serializable_doc)
|
|
149
|
+
|
|
150
|
+
complete_data["total_samples"] += len(docs)
|
|
151
|
+
|
|
152
|
+
except Exception as e:
|
|
153
|
+
print(f" ā ļø Could not load {split} split: {e}")
|
|
154
|
+
continue
|
|
155
|
+
|
|
156
|
+
if complete_data["total_samples"] == 0:
|
|
157
|
+
print(f" ā No data found for {benchmark_name}")
|
|
158
|
+
return None
|
|
159
|
+
|
|
160
|
+
processing_time = time.time() - start_time
|
|
161
|
+
|
|
162
|
+
# Add metadata
|
|
163
|
+
metadata = {
|
|
164
|
+
"benchmark_name": benchmark_name,
|
|
165
|
+
"task_name": task_name,
|
|
166
|
+
"config": benchmark_config,
|
|
167
|
+
"download_timestamp": datetime.now().isoformat(),
|
|
168
|
+
"processing_time_seconds": processing_time,
|
|
169
|
+
"total_samples": complete_data["total_samples"],
|
|
170
|
+
"splits_found": complete_data["splits_found"],
|
|
171
|
+
"task_info": {
|
|
172
|
+
"description": getattr(task, "DESCRIPTION", "No description available"),
|
|
173
|
+
"citation": getattr(task, "CITATION", "No citation available"),
|
|
174
|
+
"homepage": getattr(task, "HOMEPAGE", "No homepage available"),
|
|
175
|
+
},
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
# Convert to contrastive pairs
|
|
179
|
+
contrastive_data = self.convert_to_contrastive_pairs(benchmark_name, complete_data)
|
|
180
|
+
|
|
181
|
+
# Save only the contrastive pairs
|
|
182
|
+
data_file = self.data_dir / f"{benchmark_name}.pkl"
|
|
183
|
+
with open(data_file, "wb") as f:
|
|
184
|
+
pickle.dump(contrastive_data["contrastive_pairs"], f)
|
|
185
|
+
|
|
186
|
+
# Save metadata as JSON
|
|
187
|
+
with open(metadata_file, "w") as f:
|
|
188
|
+
json.dump(metadata, f, indent=2)
|
|
189
|
+
|
|
190
|
+
print(f" ā
Saved benchmark: {benchmark_name}")
|
|
191
|
+
print(f" š Contrastive pairs: {len(contrastive_data['contrastive_pairs'])}")
|
|
192
|
+
print(f" ā±ļø Time: {processing_time:.1f}s")
|
|
193
|
+
|
|
194
|
+
return str(data_file)
|
|
195
|
+
|
|
196
|
+
except Exception as e:
|
|
197
|
+
processing_time = time.time() - start_time
|
|
198
|
+
print(f" ā Failed to download {benchmark_name}: {e}")
|
|
199
|
+
print(f" ā±ļø Time: {processing_time:.1f}s")
|
|
200
|
+
return None
|
|
201
|
+
|
|
202
|
+
def download_all_benchmarks(self, benchmarks: Optional[List[str]] = None, force: bool = False) -> Dict[str, Any]:
|
|
203
|
+
"""
|
|
204
|
+
Download multiple complete benchmarks.
|
|
205
|
+
|
|
206
|
+
Args:
|
|
207
|
+
benchmarks: List of benchmark names to download, or None for all
|
|
208
|
+
force: Force redownload even if exists
|
|
209
|
+
|
|
210
|
+
Returns:
|
|
211
|
+
Dictionary with download results
|
|
212
|
+
"""
|
|
213
|
+
if benchmarks is None:
|
|
214
|
+
# Filter out known unavailable benchmarks when downloading all
|
|
215
|
+
available_benchmarks = {
|
|
216
|
+
name: config for name, config in CORE_BENCHMARKS.items() if name not in self.UNAVAILABLE_BENCHMARKS
|
|
217
|
+
}
|
|
218
|
+
benchmarks_to_download = available_benchmarks
|
|
219
|
+
|
|
220
|
+
# Report excluded benchmarks
|
|
221
|
+
excluded_count = len(CORE_BENCHMARKS) - len(available_benchmarks)
|
|
222
|
+
if excluded_count > 0:
|
|
223
|
+
print(f"ā© Excluding {excluded_count} known unavailable benchmarks")
|
|
224
|
+
print(f" š Available benchmarks: {len(available_benchmarks)}/{len(CORE_BENCHMARKS)}")
|
|
225
|
+
else:
|
|
226
|
+
benchmarks_to_download = {name: CORE_BENCHMARKS[name] for name in benchmarks if name in CORE_BENCHMARKS}
|
|
227
|
+
|
|
228
|
+
# Check for invalid benchmarks
|
|
229
|
+
invalid = [name for name in benchmarks if name not in CORE_BENCHMARKS]
|
|
230
|
+
if invalid:
|
|
231
|
+
print(f"ā ļø Invalid benchmarks (skipping): {invalid}")
|
|
232
|
+
|
|
233
|
+
# Warn about unavailable benchmarks that were explicitly requested
|
|
234
|
+
unavailable_requested = [name for name in benchmarks if name in self.UNAVAILABLE_BENCHMARKS]
|
|
235
|
+
if unavailable_requested:
|
|
236
|
+
print(f"ā ļø Requested benchmarks are known to be unavailable: {unavailable_requested}")
|
|
237
|
+
print(" š§ These will likely fail. Remove from list to avoid delays.")
|
|
238
|
+
|
|
239
|
+
print(f"\nšļø Downloading {len(benchmarks_to_download)} complete benchmarks")
|
|
240
|
+
print(f" Force redownload: {force}")
|
|
241
|
+
|
|
242
|
+
results = {
|
|
243
|
+
"successful": [],
|
|
244
|
+
"failed": [],
|
|
245
|
+
"skipped": [],
|
|
246
|
+
"excluded": list(self.UNAVAILABLE_BENCHMARKS) if benchmarks is None else [],
|
|
247
|
+
"total_time": 0,
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
total_start_time = time.time()
|
|
251
|
+
|
|
252
|
+
for i, (benchmark_name, benchmark_config) in enumerate(benchmarks_to_download.items(), 1):
|
|
253
|
+
print(f"\n[{i:2d}/{len(benchmarks_to_download)}] šÆ Processing benchmark: {benchmark_name}")
|
|
254
|
+
print(f" Task: {benchmark_config['task']}")
|
|
255
|
+
print(f" Tags: {benchmark_config.get('tags', [])}")
|
|
256
|
+
|
|
257
|
+
try:
|
|
258
|
+
result_path = self.download_complete_benchmark(benchmark_name, benchmark_config, force)
|
|
259
|
+
|
|
260
|
+
if result_path:
|
|
261
|
+
results["successful"].append(benchmark_name)
|
|
262
|
+
else:
|
|
263
|
+
results["failed"].append(benchmark_name)
|
|
264
|
+
|
|
265
|
+
except Exception as e:
|
|
266
|
+
print(f" ā Exception downloading {benchmark_name}: {e}")
|
|
267
|
+
results["failed"].append(benchmark_name)
|
|
268
|
+
|
|
269
|
+
# Progress update
|
|
270
|
+
elapsed = time.time() - total_start_time
|
|
271
|
+
if i < len(benchmarks_to_download):
|
|
272
|
+
eta = elapsed * (len(benchmarks_to_download) - i) / i
|
|
273
|
+
print(f"\nš Progress: {i}/{len(benchmarks_to_download)} benchmarks completed")
|
|
274
|
+
print(f" ā±ļø Elapsed: {elapsed / 60:.1f}min, ETA: {eta / 60:.1f}min")
|
|
275
|
+
|
|
276
|
+
results["total_time"] = time.time() - total_start_time
|
|
277
|
+
return results
|
|
278
|
+
|
|
279
|
+
def convert_to_contrastive_pairs(self, benchmark_name: str, complete_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
280
|
+
"""
|
|
281
|
+
Convert benchmark data to contrastive pair format.
|
|
282
|
+
|
|
283
|
+
Args:
|
|
284
|
+
benchmark_name: Name of the benchmark
|
|
285
|
+
complete_data: Raw benchmark data
|
|
286
|
+
|
|
287
|
+
Returns:
|
|
288
|
+
Dictionary with contrastive pairs
|
|
289
|
+
"""
|
|
290
|
+
print(" š Converting to contrastive pairs...")
|
|
291
|
+
|
|
292
|
+
contrastive_pairs = []
|
|
293
|
+
|
|
294
|
+
for i, sample in enumerate(complete_data["all_samples"]):
|
|
295
|
+
try:
|
|
296
|
+
pairs = self._convert_sample_to_pairs(sample, benchmark_name)
|
|
297
|
+
if pairs:
|
|
298
|
+
contrastive_pairs.extend(pairs)
|
|
299
|
+
except Exception as e:
|
|
300
|
+
print(f" ā ļø Conversion error for sample {i}: {e}")
|
|
301
|
+
|
|
302
|
+
return {"contrastive_pairs": contrastive_pairs}
|
|
303
|
+
|
|
304
|
+
def _convert_sample_to_pairs(self, sample: Dict[str, Any], benchmark_name: str) -> List[Dict[str, Any]]:
|
|
305
|
+
"""Convert a single sample to contrastive pairs based on benchmark type."""
|
|
306
|
+
|
|
307
|
+
# MMMLU format (instruction, option_a, option_b, option_c, option_d, answer)
|
|
308
|
+
if "instruction" in sample and "option_a" in sample and "answer" in sample:
|
|
309
|
+
return self._convert_mmmlu_format(sample)
|
|
310
|
+
|
|
311
|
+
# Multiple Choice with explicit choices and numeric label (HellaSwag, SWAG, etc.)
|
|
312
|
+
if ("endings" in sample and "label" in sample) or ("ending0" in sample and "label" in sample):
|
|
313
|
+
return self._convert_multiple_choice_numeric(sample)
|
|
314
|
+
|
|
315
|
+
# Multiple Choice with choices dict and answerKey (ARC, OpenBookQA, etc.)
|
|
316
|
+
if "choices" in sample and "answerKey" in sample:
|
|
317
|
+
return self._convert_multiple_choice_letter(sample)
|
|
318
|
+
|
|
319
|
+
# TruthfulQA MC1 format
|
|
320
|
+
if "mc1_targets" in sample:
|
|
321
|
+
return self._convert_truthfulqa_mc1(sample)
|
|
322
|
+
|
|
323
|
+
# TruthfulQA MC2 format
|
|
324
|
+
if "mc2_targets" in sample:
|
|
325
|
+
return self._convert_truthfulqa_mc2(sample)
|
|
326
|
+
|
|
327
|
+
# SQuAD2 format (id, title, context, question, answers)
|
|
328
|
+
if "context" in sample and "question" in sample and "answers" in sample:
|
|
329
|
+
return self._convert_squad2_format(sample)
|
|
330
|
+
|
|
331
|
+
# Textual entailment (premise/hypothesis format like CB, RTE)
|
|
332
|
+
if "premise" in sample and "hypothesis" in sample:
|
|
333
|
+
return self._convert_textual_entailment(sample)
|
|
334
|
+
|
|
335
|
+
# Boolean questions (BoolQ)
|
|
336
|
+
if "label" in sample and str(sample["label"]).lower() in ["true", "false", "0", "1"]:
|
|
337
|
+
return self._convert_boolean_question(sample)
|
|
338
|
+
|
|
339
|
+
# MBPP format (programming problems with code)
|
|
340
|
+
if "task_id" in sample and "text" in sample and "code" in sample:
|
|
341
|
+
return self._convert_mbpp_format(sample)
|
|
342
|
+
|
|
343
|
+
# MATH-500 format (problem, solution, answer, subject, level)
|
|
344
|
+
if (
|
|
345
|
+
"problem" in sample
|
|
346
|
+
and "solution" in sample
|
|
347
|
+
and "answer" in sample
|
|
348
|
+
and "subject" in sample
|
|
349
|
+
and "level" in sample
|
|
350
|
+
):
|
|
351
|
+
return self._convert_math500_format(sample)
|
|
352
|
+
|
|
353
|
+
# WebQS format (question, answers list)
|
|
354
|
+
if "question" in sample and "answers" in sample and isinstance(sample.get("answers"), list):
|
|
355
|
+
return self._convert_webqs_format(sample)
|
|
356
|
+
|
|
357
|
+
# NaturalQS format (question, answer as list)
|
|
358
|
+
if "question" in sample and "answer" in sample and isinstance(sample.get("answer"), list):
|
|
359
|
+
return self._convert_naturalqs_format(sample)
|
|
360
|
+
|
|
361
|
+
# TriviaQA format (question, answer as dict with aliases)
|
|
362
|
+
if "question" in sample and "answer" in sample and isinstance(sample.get("answer"), dict):
|
|
363
|
+
return self._convert_triviaqa_format(sample)
|
|
364
|
+
|
|
365
|
+
# Text generation with question/answer (GSM8K, math problems)
|
|
366
|
+
if "question" in sample and "answer" in sample:
|
|
367
|
+
return self._convert_text_generation(sample)
|
|
368
|
+
|
|
369
|
+
# Reading comprehension (CoQA, SQuAD)
|
|
370
|
+
if "story" in sample or "passage" in sample:
|
|
371
|
+
return self._convert_reading_comprehension(sample)
|
|
372
|
+
|
|
373
|
+
# SQuAD2 format (id, title, context, question, answers)
|
|
374
|
+
if (
|
|
375
|
+
"id" in sample
|
|
376
|
+
and "title" in sample
|
|
377
|
+
and "context" in sample
|
|
378
|
+
and "question" in sample
|
|
379
|
+
and "answers" in sample
|
|
380
|
+
):
|
|
381
|
+
return self._convert_squad2_format(sample)
|
|
382
|
+
|
|
383
|
+
# Winogrande format (sentence, option1, option2, answer)
|
|
384
|
+
if "sentence" in sample and "option1" in sample and "option2" in sample and "answer" in sample:
|
|
385
|
+
return self._convert_winogrande_format(sample)
|
|
386
|
+
|
|
387
|
+
# WikiText format (page)
|
|
388
|
+
if "page" in sample:
|
|
389
|
+
return self._convert_wikitext_format(sample)
|
|
390
|
+
|
|
391
|
+
# GPQA format (Question, choice1-4, answer, plus rich metadata)
|
|
392
|
+
if (
|
|
393
|
+
"Question" in sample
|
|
394
|
+
and "choice1" in sample
|
|
395
|
+
and "choice2" in sample
|
|
396
|
+
and "choice3" in sample
|
|
397
|
+
and "choice4" in sample
|
|
398
|
+
and "answer" in sample
|
|
399
|
+
):
|
|
400
|
+
return self._convert_gpqa_format(sample)
|
|
401
|
+
|
|
402
|
+
# HLE format (question, answer, answer_type, category)
|
|
403
|
+
if "question" in sample and "answer" in sample and "answer_type" in sample and "category" in sample:
|
|
404
|
+
return self._convert_hle_format(sample)
|
|
405
|
+
|
|
406
|
+
# HumanEval code generation format (task_id, canonical_solution, prompt, test, entry_point)
|
|
407
|
+
if "task_id" in sample and "canonical_solution" in sample and "prompt" in sample and "test" in sample:
|
|
408
|
+
return self._convert_humaneval_format(sample)
|
|
409
|
+
|
|
410
|
+
# MBPP code generation format (task_id, code, prompt, test)
|
|
411
|
+
if "task_id" in sample and "code" in sample and "prompt" in sample and "test" in sample:
|
|
412
|
+
return self._convert_mbpp_format(sample)
|
|
413
|
+
|
|
414
|
+
# Generic multiple choice fallback
|
|
415
|
+
if "choices" in sample:
|
|
416
|
+
return self._convert_generic_multiple_choice(sample)
|
|
417
|
+
|
|
418
|
+
print(f" ā ļø Unknown sample format: {list(sample.keys())}")
|
|
419
|
+
return []
|
|
420
|
+
|
|
421
|
+
def _convert_mmmlu_format(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
422
|
+
"""Convert MMMLU format (instruction, option_a/b/c/d, answer)."""
|
|
423
|
+
instruction = sample.get("instruction", "")
|
|
424
|
+
option_a = sample.get("option_a", "")
|
|
425
|
+
option_b = sample.get("option_b", "")
|
|
426
|
+
option_c = sample.get("option_c", "")
|
|
427
|
+
option_d = sample.get("option_d", "")
|
|
428
|
+
answer = sample.get("answer", "")
|
|
429
|
+
|
|
430
|
+
# Map answer letter to option
|
|
431
|
+
options = {"A": option_a, "B": option_b, "C": option_c, "D": option_d}
|
|
432
|
+
|
|
433
|
+
correct_answer = options.get(answer, option_a) # Default to A if answer not found
|
|
434
|
+
|
|
435
|
+
# Create pairs with each incorrect option
|
|
436
|
+
pairs = []
|
|
437
|
+
for letter, option in options.items():
|
|
438
|
+
if letter != answer and option:
|
|
439
|
+
pairs.append(
|
|
440
|
+
{
|
|
441
|
+
"context": instruction,
|
|
442
|
+
"good_response": correct_answer,
|
|
443
|
+
"bad_response": option,
|
|
444
|
+
"metadata": {
|
|
445
|
+
"answer_key": answer,
|
|
446
|
+
"sample_id": sample.get("id", ""),
|
|
447
|
+
"benchmark_type": "mmmlu",
|
|
448
|
+
},
|
|
449
|
+
}
|
|
450
|
+
)
|
|
451
|
+
|
|
452
|
+
return pairs
|
|
453
|
+
|
|
454
|
+
def _convert_multiple_choice_numeric(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
455
|
+
"""Convert multiple choice with numeric label (HellaSwag, SWAG)."""
|
|
456
|
+
context = sample.get("ctx", sample.get("query", ""))
|
|
457
|
+
|
|
458
|
+
# Handle different choice formats
|
|
459
|
+
if "endings" in sample:
|
|
460
|
+
# HellaSwag format: choices in "endings" list
|
|
461
|
+
choices = sample.get("endings", [])
|
|
462
|
+
elif "ending0" in sample:
|
|
463
|
+
# SWAG format: choices in separate ending0, ending1, etc. fields
|
|
464
|
+
choices = []
|
|
465
|
+
for i in range(4): # SWAG typically has 4 choices
|
|
466
|
+
ending_key = f"ending{i}"
|
|
467
|
+
if ending_key in sample:
|
|
468
|
+
choices.append(sample[ending_key])
|
|
469
|
+
# Build context from sent1, sent2, etc.
|
|
470
|
+
sent1 = sample.get("sent1", "")
|
|
471
|
+
sent2 = sample.get("sent2", "")
|
|
472
|
+
context = f"{sent1} {sent2}".strip()
|
|
473
|
+
else:
|
|
474
|
+
choices = sample.get("choices", [])
|
|
475
|
+
|
|
476
|
+
correct_idx = int(sample["label"])
|
|
477
|
+
|
|
478
|
+
if not choices or correct_idx >= len(choices):
|
|
479
|
+
return []
|
|
480
|
+
|
|
481
|
+
correct_answer = choices[correct_idx]
|
|
482
|
+
incorrect_answers = [choices[i] for i in range(len(choices)) if i != correct_idx]
|
|
483
|
+
|
|
484
|
+
pairs = []
|
|
485
|
+
for incorrect in incorrect_answers:
|
|
486
|
+
pairs.append(
|
|
487
|
+
{
|
|
488
|
+
"context": context,
|
|
489
|
+
"good_response": correct_answer,
|
|
490
|
+
"bad_response": incorrect,
|
|
491
|
+
"metadata": {
|
|
492
|
+
"correct_index": correct_idx,
|
|
493
|
+
"sample_id": sample.get("id", sample.get("ind", "")),
|
|
494
|
+
"source": sample.get("source", ""),
|
|
495
|
+
},
|
|
496
|
+
}
|
|
497
|
+
)
|
|
498
|
+
|
|
499
|
+
return pairs
|
|
500
|
+
|
|
501
|
+
def _convert_multiple_choice_letter(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
502
|
+
"""Convert multiple choice with letter answerKey (ARC, OpenBookQA)."""
|
|
503
|
+
question = sample.get("question", "")
|
|
504
|
+
choices_text = sample["choices"]["text"]
|
|
505
|
+
choices_labels = sample["choices"]["label"]
|
|
506
|
+
answer_key = sample["answerKey"]
|
|
507
|
+
|
|
508
|
+
# Find correct answer
|
|
509
|
+
correct_idx = None
|
|
510
|
+
for i, label in enumerate(choices_labels):
|
|
511
|
+
if label == answer_key:
|
|
512
|
+
correct_idx = i
|
|
513
|
+
break
|
|
514
|
+
|
|
515
|
+
if correct_idx is None:
|
|
516
|
+
return []
|
|
517
|
+
|
|
518
|
+
correct_answer = choices_text[correct_idx]
|
|
519
|
+
incorrect_answers = [choices_text[i] for i in range(len(choices_text)) if i != correct_idx]
|
|
520
|
+
|
|
521
|
+
pairs = []
|
|
522
|
+
for incorrect in incorrect_answers:
|
|
523
|
+
pairs.append(
|
|
524
|
+
{
|
|
525
|
+
"context": question,
|
|
526
|
+
"good_response": correct_answer,
|
|
527
|
+
"bad_response": incorrect,
|
|
528
|
+
"metadata": {
|
|
529
|
+
"answer_key": answer_key,
|
|
530
|
+
"sample_id": sample.get("id", ""),
|
|
531
|
+
"source": sample.get("source", ""),
|
|
532
|
+
},
|
|
533
|
+
}
|
|
534
|
+
)
|
|
535
|
+
|
|
536
|
+
return pairs
|
|
537
|
+
|
|
538
|
+
def _convert_truthfulqa_mc1(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
539
|
+
"""Convert TruthfulQA MC1 format."""
|
|
540
|
+
question = sample["question"]
|
|
541
|
+
choices = sample["mc1_targets"]["choices"]
|
|
542
|
+
labels = sample["mc1_targets"]["labels"]
|
|
543
|
+
|
|
544
|
+
# Find correct and incorrect answers
|
|
545
|
+
correct_answers = [choices[i] for i, label in enumerate(labels) if label == 1]
|
|
546
|
+
incorrect_answers = [choices[i] for i, label in enumerate(labels) if label == 0]
|
|
547
|
+
|
|
548
|
+
if not correct_answers or not incorrect_answers:
|
|
549
|
+
return []
|
|
550
|
+
|
|
551
|
+
pairs = []
|
|
552
|
+
for correct in correct_answers:
|
|
553
|
+
for incorrect in incorrect_answers[:3]: # Limit to 3 incorrect per correct
|
|
554
|
+
pairs.append(
|
|
555
|
+
{
|
|
556
|
+
"context": question,
|
|
557
|
+
"good_response": correct,
|
|
558
|
+
"bad_response": incorrect,
|
|
559
|
+
"metadata": {"sample_id": sample.get("id", ""), "benchmark_type": "truthfulqa_mc1"},
|
|
560
|
+
}
|
|
561
|
+
)
|
|
562
|
+
|
|
563
|
+
return pairs
|
|
564
|
+
|
|
565
|
+
def _convert_truthfulqa_mc2(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
566
|
+
"""Convert TruthfulQA MC2 format."""
|
|
567
|
+
question = sample["question"]
|
|
568
|
+
choices = sample["mc2_targets"]["choices"]
|
|
569
|
+
labels = sample["mc2_targets"]["labels"]
|
|
570
|
+
|
|
571
|
+
correct_answers = [choices[i] for i, label in enumerate(labels) if label == 1]
|
|
572
|
+
incorrect_answers = [choices[i] for i, label in enumerate(labels) if label == 0]
|
|
573
|
+
|
|
574
|
+
if not correct_answers or not incorrect_answers:
|
|
575
|
+
return []
|
|
576
|
+
|
|
577
|
+
pairs = []
|
|
578
|
+
for correct in correct_answers:
|
|
579
|
+
for incorrect in incorrect_answers[:2]: # Limit to 2 incorrect per correct
|
|
580
|
+
pairs.append(
|
|
581
|
+
{
|
|
582
|
+
"context": question,
|
|
583
|
+
"good_response": correct,
|
|
584
|
+
"bad_response": incorrect,
|
|
585
|
+
"metadata": {"sample_id": sample.get("id", ""), "benchmark_type": "truthfulqa_mc2"},
|
|
586
|
+
}
|
|
587
|
+
)
|
|
588
|
+
|
|
589
|
+
return pairs
|
|
590
|
+
|
|
591
|
+
def _convert_textual_entailment(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
592
|
+
"""Convert textual entailment tasks (CB, RTE)."""
|
|
593
|
+
premise = sample["premise"]
|
|
594
|
+
hypothesis = sample["hypothesis"]
|
|
595
|
+
label = sample["label"]
|
|
596
|
+
|
|
597
|
+
# Map different label formats
|
|
598
|
+
if isinstance(label, str):
|
|
599
|
+
if label.lower() in ["entailment", "true", "1"]:
|
|
600
|
+
correct_answer = "Yes, this follows logically."
|
|
601
|
+
incorrect_answer = "No, this does not follow logically."
|
|
602
|
+
elif label.lower() in ["contradiction", "false", "0"]:
|
|
603
|
+
correct_answer = "No, this contradicts the premise."
|
|
604
|
+
incorrect_answer = "Yes, this follows logically."
|
|
605
|
+
else: # neutral
|
|
606
|
+
correct_answer = "This is neither supported nor contradicted."
|
|
607
|
+
incorrect_answer = "Yes, this follows logically."
|
|
608
|
+
else:
|
|
609
|
+
# Numeric labels: typically 0=entailment, 1=neutral, 2=contradiction
|
|
610
|
+
if label == 0:
|
|
611
|
+
correct_answer = "Yes, this follows logically."
|
|
612
|
+
incorrect_answer = "No, this does not follow logically."
|
|
613
|
+
elif label == 2:
|
|
614
|
+
correct_answer = "No, this contradicts the premise."
|
|
615
|
+
incorrect_answer = "Yes, this follows logically."
|
|
616
|
+
else: # neutral
|
|
617
|
+
correct_answer = "This is neither supported nor contradicted."
|
|
618
|
+
incorrect_answer = "Yes, this follows logically."
|
|
619
|
+
|
|
620
|
+
context = f"Premise: {premise}\nHypothesis: {hypothesis}\nDoes the hypothesis follow from the premise?"
|
|
621
|
+
|
|
622
|
+
return [
|
|
623
|
+
{
|
|
624
|
+
"context": context,
|
|
625
|
+
"good_response": correct_answer,
|
|
626
|
+
"bad_response": incorrect_answer,
|
|
627
|
+
"metadata": {
|
|
628
|
+
"sample_id": sample.get("idx", ""),
|
|
629
|
+
"original_label": label,
|
|
630
|
+
"benchmark_type": "textual_entailment",
|
|
631
|
+
},
|
|
632
|
+
}
|
|
633
|
+
]
|
|
634
|
+
|
|
635
|
+
def _convert_boolean_question(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
636
|
+
"""Convert boolean questions (BoolQ)."""
|
|
637
|
+
question = sample.get("question", "")
|
|
638
|
+
passage = sample.get("passage", "")
|
|
639
|
+
label = sample["label"]
|
|
640
|
+
|
|
641
|
+
# Determine correct answer
|
|
642
|
+
if str(label).lower() in ["true", "1"]:
|
|
643
|
+
correct_answer = "Yes"
|
|
644
|
+
incorrect_answer = "No"
|
|
645
|
+
else:
|
|
646
|
+
correct_answer = "No"
|
|
647
|
+
incorrect_answer = "Yes"
|
|
648
|
+
|
|
649
|
+
context = f"{passage}\n\nQuestion: {question}" if passage else question
|
|
650
|
+
|
|
651
|
+
return [
|
|
652
|
+
{
|
|
653
|
+
"context": context,
|
|
654
|
+
"good_response": correct_answer,
|
|
655
|
+
"bad_response": incorrect_answer,
|
|
656
|
+
"metadata": {"sample_id": sample.get("id", ""), "original_label": label, "benchmark_type": "boolean"},
|
|
657
|
+
}
|
|
658
|
+
]
|
|
659
|
+
|
|
660
|
+
def _convert_text_generation(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
661
|
+
"""Convert text generation tasks (GSM8K, math problems)."""
|
|
662
|
+
question = sample["question"]
|
|
663
|
+
correct_answer = sample["answer"]
|
|
664
|
+
|
|
665
|
+
# Generate plausible incorrect answers for math problems
|
|
666
|
+
if any(
|
|
667
|
+
math_keyword in question.lower() for math_keyword in ["dollars", "cost", "price", "how much", "how many"]
|
|
668
|
+
):
|
|
669
|
+
incorrect_answers = self._generate_math_distractors(correct_answer)
|
|
670
|
+
else:
|
|
671
|
+
# For non-math, create generic incorrect responses
|
|
672
|
+
incorrect_answers = [
|
|
673
|
+
"I don't know the answer to this question.",
|
|
674
|
+
"This question cannot be answered with the given information.",
|
|
675
|
+
"The answer is unclear from the problem statement.",
|
|
676
|
+
]
|
|
677
|
+
|
|
678
|
+
pairs = []
|
|
679
|
+
for incorrect in incorrect_answers:
|
|
680
|
+
pairs.append(
|
|
681
|
+
{
|
|
682
|
+
"context": question,
|
|
683
|
+
"good_response": correct_answer,
|
|
684
|
+
"bad_response": incorrect,
|
|
685
|
+
"metadata": {"sample_id": sample.get("id", ""), "benchmark_type": "text_generation"},
|
|
686
|
+
}
|
|
687
|
+
)
|
|
688
|
+
|
|
689
|
+
return pairs
|
|
690
|
+
|
|
691
|
+
def _convert_math500_format(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
692
|
+
"""Convert MATH-500 format (problem, solution, answer, subject, level)."""
|
|
693
|
+
problem = sample.get("problem", "")
|
|
694
|
+
correct_answer = sample.get("answer", "")
|
|
695
|
+
solution = sample.get("solution", "")
|
|
696
|
+
subject = sample.get("subject", "")
|
|
697
|
+
level = sample.get("level", 0)
|
|
698
|
+
unique_id = sample.get("unique_id", "")
|
|
699
|
+
|
|
700
|
+
# Generate mathematical incorrect answers based on correct answer
|
|
701
|
+
incorrect_answers = self._generate_math_distractors(correct_answer)
|
|
702
|
+
|
|
703
|
+
pairs = []
|
|
704
|
+
for incorrect in incorrect_answers:
|
|
705
|
+
pairs.append(
|
|
706
|
+
{
|
|
707
|
+
"context": problem,
|
|
708
|
+
"good_response": correct_answer,
|
|
709
|
+
"bad_response": incorrect,
|
|
710
|
+
"metadata": {
|
|
711
|
+
"benchmark_type": "math500",
|
|
712
|
+
"subject": subject,
|
|
713
|
+
"level": level,
|
|
714
|
+
"sample_id": unique_id,
|
|
715
|
+
"has_solution": bool(solution.strip()), # Track if step-by-step solution available
|
|
716
|
+
},
|
|
717
|
+
}
|
|
718
|
+
)
|
|
719
|
+
|
|
720
|
+
return pairs
|
|
721
|
+
|
|
722
|
+
def _convert_reading_comprehension(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
723
|
+
"""Convert reading comprehension tasks (CoQA, SQuAD)."""
|
|
724
|
+
# This is complex as these often have multiple Q&A pairs
|
|
725
|
+
# For now, create a basic conversion
|
|
726
|
+
story = sample.get("story", sample.get("passage", ""))
|
|
727
|
+
|
|
728
|
+
pairs = []
|
|
729
|
+
|
|
730
|
+
# Handle CoQA format with multiple questions
|
|
731
|
+
if "questions" in sample and "answers" in sample:
|
|
732
|
+
questions_data = sample["questions"]
|
|
733
|
+
answers_data = sample["answers"]
|
|
734
|
+
|
|
735
|
+
# CoQA format has questions and answers as dicts with lists
|
|
736
|
+
if isinstance(questions_data, dict) and isinstance(answers_data, dict):
|
|
737
|
+
question_texts = questions_data.get("input_text", [])
|
|
738
|
+
answer_texts = answers_data.get("input_text", [])
|
|
739
|
+
|
|
740
|
+
for i, (q_text, a_text) in enumerate(zip(question_texts, answer_texts)):
|
|
741
|
+
context = f"{story}\n\nQuestion: {q_text}"
|
|
742
|
+
|
|
743
|
+
# Generate incorrect answer
|
|
744
|
+
incorrect_answer = "I cannot find this information in the passage."
|
|
745
|
+
|
|
746
|
+
pairs.append(
|
|
747
|
+
{
|
|
748
|
+
"context": context,
|
|
749
|
+
"good_response": a_text,
|
|
750
|
+
"bad_response": incorrect_answer,
|
|
751
|
+
"metadata": {
|
|
752
|
+
"sample_id": sample.get("id", ""),
|
|
753
|
+
"question_index": i,
|
|
754
|
+
"benchmark_type": "reading_comprehension",
|
|
755
|
+
},
|
|
756
|
+
}
|
|
757
|
+
)
|
|
758
|
+
# Handle other formats where questions/answers might be lists directly
|
|
759
|
+
elif isinstance(questions_data, list) and isinstance(answers_data, list):
|
|
760
|
+
for i, (q, a) in enumerate(zip(questions_data, answers_data)):
|
|
761
|
+
question_text = q.get("input_text", q.get("text", "")) if isinstance(q, dict) else str(q)
|
|
762
|
+
answer_text = a.get("input_text", a.get("text", "")) if isinstance(a, dict) else str(a)
|
|
763
|
+
|
|
764
|
+
context = f"{story}\n\nQuestion: {question_text}"
|
|
765
|
+
|
|
766
|
+
# Generate incorrect answer
|
|
767
|
+
incorrect_answer = "I cannot find this information in the passage."
|
|
768
|
+
|
|
769
|
+
pairs.append(
|
|
770
|
+
{
|
|
771
|
+
"context": context,
|
|
772
|
+
"good_response": answer_text,
|
|
773
|
+
"bad_response": incorrect_answer,
|
|
774
|
+
"metadata": {
|
|
775
|
+
"sample_id": sample.get("id", ""),
|
|
776
|
+
"question_index": i,
|
|
777
|
+
"benchmark_type": "reading_comprehension",
|
|
778
|
+
},
|
|
779
|
+
}
|
|
780
|
+
)
|
|
781
|
+
|
|
782
|
+
return pairs
|
|
783
|
+
|
|
784
|
+
def _convert_squad2_format(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
785
|
+
"""Convert SQuAD2 format (id, title, context, question, answers)."""
|
|
786
|
+
context = sample.get("context", "")
|
|
787
|
+
question = sample.get("question", "")
|
|
788
|
+
answers_data = sample.get("answers", {})
|
|
789
|
+
|
|
790
|
+
# Extract answer texts from answers dict
|
|
791
|
+
answer_texts = answers_data.get("text", [])
|
|
792
|
+
if not answer_texts:
|
|
793
|
+
# Handle empty answers (SQuAD2 has "no answer" questions)
|
|
794
|
+
correct_answer = "There is no answer to this question in the given context."
|
|
795
|
+
else:
|
|
796
|
+
# Use the first answer as the correct one
|
|
797
|
+
correct_answer = answer_texts[0]
|
|
798
|
+
|
|
799
|
+
# Generate plausible incorrect answers for reading comprehension
|
|
800
|
+
incorrect_answers = [
|
|
801
|
+
"I cannot find this information in the passage.",
|
|
802
|
+
"The question cannot be answered based on the given context.",
|
|
803
|
+
"This information is not provided in the text.",
|
|
804
|
+
]
|
|
805
|
+
|
|
806
|
+
# Format the context for the contrastive pair
|
|
807
|
+
full_context = f"Context: {context}\n\nQuestion: {question}"
|
|
808
|
+
|
|
809
|
+
pairs = []
|
|
810
|
+
for incorrect in incorrect_answers:
|
|
811
|
+
pairs.append(
|
|
812
|
+
{
|
|
813
|
+
"context": full_context,
|
|
814
|
+
"good_response": correct_answer,
|
|
815
|
+
"bad_response": incorrect,
|
|
816
|
+
"metadata": {
|
|
817
|
+
"sample_id": sample.get("id", ""),
|
|
818
|
+
"title": sample.get("title", ""),
|
|
819
|
+
"benchmark_type": "squad2",
|
|
820
|
+
"has_answer": bool(answer_texts), # Track if this question has an answer
|
|
821
|
+
},
|
|
822
|
+
}
|
|
823
|
+
)
|
|
824
|
+
|
|
825
|
+
return pairs
|
|
826
|
+
|
|
827
|
+
def _convert_generic_multiple_choice(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
828
|
+
"""Generic fallback for multiple choice formats."""
|
|
829
|
+
question = sample.get("question", sample.get("query", ""))
|
|
830
|
+
choices = sample.get("choices", [])
|
|
831
|
+
|
|
832
|
+
if len(choices) < 2:
|
|
833
|
+
return []
|
|
834
|
+
|
|
835
|
+
# Assume first choice is correct (this is a fallback)
|
|
836
|
+
correct_answer = choices[0]
|
|
837
|
+
incorrect_answers = choices[1:]
|
|
838
|
+
|
|
839
|
+
pairs = []
|
|
840
|
+
for incorrect in incorrect_answers:
|
|
841
|
+
pairs.append(
|
|
842
|
+
{
|
|
843
|
+
"context": question,
|
|
844
|
+
"good_response": correct_answer,
|
|
845
|
+
"bad_response": incorrect,
|
|
846
|
+
"metadata": {
|
|
847
|
+
"sample_id": sample.get("id", ""),
|
|
848
|
+
"benchmark_type": "generic_multiple_choice",
|
|
849
|
+
"warning": "Assumed first choice is correct",
|
|
850
|
+
},
|
|
851
|
+
}
|
|
852
|
+
)
|
|
853
|
+
|
|
854
|
+
return pairs
|
|
855
|
+
|
|
856
|
+
def _generate_math_distractors(self, correct_answer: str) -> List[str]:
|
|
857
|
+
"""Generate plausible incorrect answers for math problems."""
|
|
858
|
+
import re
|
|
859
|
+
|
|
860
|
+
# Extract final number from answer
|
|
861
|
+
numbers = re.findall(r"\d+(?:\.\d+)?", correct_answer)
|
|
862
|
+
if not numbers:
|
|
863
|
+
return ["42", "0", "Cannot be determined"]
|
|
864
|
+
|
|
865
|
+
final_number = float(numbers[-1])
|
|
866
|
+
|
|
867
|
+
# Generate distractors
|
|
868
|
+
distractors = []
|
|
869
|
+
|
|
870
|
+
# Off-by-one errors
|
|
871
|
+
distractors.append(str(int(final_number + 1)))
|
|
872
|
+
distractors.append(str(int(final_number - 1)))
|
|
873
|
+
|
|
874
|
+
# Calculation errors (common mistakes)
|
|
875
|
+
distractors.append(str(int(final_number * 2)))
|
|
876
|
+
distractors.append(str(int(final_number / 2)))
|
|
877
|
+
|
|
878
|
+
# Random nearby numbers
|
|
879
|
+
distractors.append(str(int(final_number + random.randint(2, 10))))
|
|
880
|
+
|
|
881
|
+
return distractors[:3] # Return top 3
|
|
882
|
+
|
|
883
|
+
def _convert_humaneval_format(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
884
|
+
"""Convert HumanEval code generation format."""
|
|
885
|
+
task_id = sample.get("task_id", "unknown")
|
|
886
|
+
prompt = sample.get("prompt", "")
|
|
887
|
+
canonical_solution = sample.get("canonical_solution", "")
|
|
888
|
+
test = sample.get("test", "")
|
|
889
|
+
entry_point = sample.get("entry_point", "")
|
|
890
|
+
|
|
891
|
+
pairs = []
|
|
892
|
+
|
|
893
|
+
# Create a contrastive pair with the coding prompt
|
|
894
|
+
pairs.append(
|
|
895
|
+
{
|
|
896
|
+
"question": f"Complete this Python function:\n\n{prompt}",
|
|
897
|
+
"correct_answer": canonical_solution,
|
|
898
|
+
"incorrect_answer": "# Incorrect or incomplete implementation\npass",
|
|
899
|
+
"metadata": {
|
|
900
|
+
"task_id": task_id,
|
|
901
|
+
"test_cases": test,
|
|
902
|
+
"entry_point": entry_point,
|
|
903
|
+
"benchmark_type": "humaneval",
|
|
904
|
+
"task_type": "code_completion",
|
|
905
|
+
},
|
|
906
|
+
}
|
|
907
|
+
)
|
|
908
|
+
|
|
909
|
+
return pairs
|
|
910
|
+
|
|
911
|
+
def _convert_mbpp_format(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
912
|
+
"""Convert MBPP format (programming problems with code)."""
|
|
913
|
+
# Use the benchmark extractor to get contrastive pairs
|
|
914
|
+
from wisent_guard.core.benchmark_extractors import extract_contrastive_pair
|
|
915
|
+
|
|
916
|
+
try:
|
|
917
|
+
contrastive_data = extract_contrastive_pair("mbpp", sample, None)
|
|
918
|
+
|
|
919
|
+
if contrastive_data:
|
|
920
|
+
return [
|
|
921
|
+
{
|
|
922
|
+
"context": contrastive_data["question"],
|
|
923
|
+
"good_response": contrastive_data["correct_answer"],
|
|
924
|
+
"bad_response": contrastive_data["incorrect_answer"],
|
|
925
|
+
"metadata": {"task_id": sample.get("task_id", ""), "benchmark_type": "mbpp"},
|
|
926
|
+
}
|
|
927
|
+
]
|
|
928
|
+
return []
|
|
929
|
+
except Exception as e:
|
|
930
|
+
print(f" ā ļø Error converting MBPP sample: {e}")
|
|
931
|
+
return []
|
|
932
|
+
|
|
933
|
+
def _convert_gpqa_format(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
934
|
+
"""Convert GPQA format (Question, choice1-4, answer, plus rich metadata)."""
|
|
935
|
+
question = sample.get("Question", "")
|
|
936
|
+
choice1 = sample.get("choice1", "")
|
|
937
|
+
choice2 = sample.get("choice2", "")
|
|
938
|
+
choice3 = sample.get("choice3", "")
|
|
939
|
+
choice4 = sample.get("choice4", "")
|
|
940
|
+
answer = sample.get("answer", "")
|
|
941
|
+
|
|
942
|
+
# Extract letter from answer format like "(A)" or "A"
|
|
943
|
+
import re
|
|
944
|
+
|
|
945
|
+
answer_match = re.search(r"[ABCD]", answer.upper())
|
|
946
|
+
if not answer_match:
|
|
947
|
+
return []
|
|
948
|
+
|
|
949
|
+
answer_letter = answer_match.group()
|
|
950
|
+
|
|
951
|
+
# Map answer letter to choice
|
|
952
|
+
choices_map = {"A": choice1, "B": choice2, "C": choice3, "D": choice4}
|
|
953
|
+
|
|
954
|
+
correct_answer = choices_map.get(answer_letter, "")
|
|
955
|
+
if not correct_answer:
|
|
956
|
+
return []
|
|
957
|
+
|
|
958
|
+
# Create pairs with each incorrect option
|
|
959
|
+
pairs = []
|
|
960
|
+
for letter, choice in choices_map.items():
|
|
961
|
+
if letter != answer_letter and choice:
|
|
962
|
+
pairs.append(
|
|
963
|
+
{
|
|
964
|
+
"context": question,
|
|
965
|
+
"good_response": correct_answer,
|
|
966
|
+
"bad_response": choice,
|
|
967
|
+
"metadata": {
|
|
968
|
+
"answer_key": answer_letter,
|
|
969
|
+
"raw_answer": answer,
|
|
970
|
+
"benchmark_type": "gpqa",
|
|
971
|
+
"subdomain": sample.get("Subdomain", ""),
|
|
972
|
+
"high_level_domain": sample.get("High-level domain", ""),
|
|
973
|
+
"difficulty_estimate": sample.get("Writer's Difficulty Estimate", ""),
|
|
974
|
+
"expert_accuracy": sample.get("Expert Validator Accuracy", ""),
|
|
975
|
+
"explanation": sample.get("Explanation", "")[:200]
|
|
976
|
+
if sample.get("Explanation")
|
|
977
|
+
else "", # Truncate long explanations
|
|
978
|
+
},
|
|
979
|
+
}
|
|
980
|
+
)
|
|
981
|
+
|
|
982
|
+
return pairs
|
|
983
|
+
|
|
984
|
+
def _convert_hle_format(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
985
|
+
"""Convert HLE format (question, answer, answer_type, category)."""
|
|
986
|
+
question = sample.get("question", "")
|
|
987
|
+
answer = sample.get("answer", "")
|
|
988
|
+
answer_type = sample.get("answer_type", "")
|
|
989
|
+
category = sample.get("category", "")
|
|
990
|
+
|
|
991
|
+
if not question or not answer:
|
|
992
|
+
return []
|
|
993
|
+
|
|
994
|
+
# Use the HLE extractor to get contrastive pairs
|
|
995
|
+
from wisent_guard.core.benchmark_extractors import HLEExtractor
|
|
996
|
+
|
|
997
|
+
try:
|
|
998
|
+
extractor = HLEExtractor()
|
|
999
|
+
contrastive_pair = extractor.extract_contrastive_pair(sample)
|
|
1000
|
+
|
|
1001
|
+
if contrastive_pair:
|
|
1002
|
+
return [
|
|
1003
|
+
{
|
|
1004
|
+
"question": contrastive_pair["question"],
|
|
1005
|
+
"good_response": contrastive_pair["correct_answer"],
|
|
1006
|
+
"bad_response": contrastive_pair["incorrect_answer"],
|
|
1007
|
+
"metadata": {
|
|
1008
|
+
"answer_type": answer_type,
|
|
1009
|
+
"category": category,
|
|
1010
|
+
"raw_subject": sample.get("raw_subject", ""),
|
|
1011
|
+
"benchmark_type": "hle",
|
|
1012
|
+
},
|
|
1013
|
+
}
|
|
1014
|
+
]
|
|
1015
|
+
return []
|
|
1016
|
+
except Exception as e:
|
|
1017
|
+
print(f" ā ļø Error converting HLE sample: {e}")
|
|
1018
|
+
return []
|
|
1019
|
+
|
|
1020
|
+
def _convert_squad2_format(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
1021
|
+
"""Convert SQuAD2 format (id, title, context, question, answers)."""
|
|
1022
|
+
context = sample.get("context", "")
|
|
1023
|
+
question = sample.get("question", "")
|
|
1024
|
+
answers = sample.get("answers", {})
|
|
1025
|
+
|
|
1026
|
+
if not context or not question:
|
|
1027
|
+
return []
|
|
1028
|
+
|
|
1029
|
+
# Handle SQuAD2 answer format
|
|
1030
|
+
answer_text = ""
|
|
1031
|
+
if isinstance(answers, dict):
|
|
1032
|
+
answer_texts = answers.get("text", [])
|
|
1033
|
+
if answer_texts and len(answer_texts) > 0:
|
|
1034
|
+
answer_text = answer_texts[0]
|
|
1035
|
+
elif isinstance(answers, list) and len(answers) > 0:
|
|
1036
|
+
if isinstance(answers[0], dict):
|
|
1037
|
+
answer_text = answers[0].get("text", "")
|
|
1038
|
+
else:
|
|
1039
|
+
answer_text = str(answers[0])
|
|
1040
|
+
|
|
1041
|
+
if not answer_text:
|
|
1042
|
+
# For unanswerable questions in SQuAD2, create a pair with empty answer
|
|
1043
|
+
answer_text = "[No answer available]"
|
|
1044
|
+
|
|
1045
|
+
# Create a contrastive pair using question-answering format
|
|
1046
|
+
return [
|
|
1047
|
+
{
|
|
1048
|
+
"question": f"Context: {context}\n\nQuestion: {question}",
|
|
1049
|
+
"good_response": answer_text,
|
|
1050
|
+
"bad_response": "[Incorrect answer]", # Generic bad response for SQuAD2
|
|
1051
|
+
"metadata": {
|
|
1052
|
+
"id": sample.get("id", ""),
|
|
1053
|
+
"title": sample.get("title", ""),
|
|
1054
|
+
"benchmark_type": "squad2",
|
|
1055
|
+
"task_type": "reading_comprehension",
|
|
1056
|
+
},
|
|
1057
|
+
}
|
|
1058
|
+
]
|
|
1059
|
+
|
|
1060
|
+
+    def _convert_winogrande_format(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """Convert Winogrande format (sentence, option1, option2, answer)."""
+        sentence = sample.get("sentence", "")
+        option1 = sample.get("option1", "")
+        option2 = sample.get("option2", "")
+        answer = sample.get("answer", "")
+
+        if not sentence or not option1 or not option2 or not answer:
+            return []
+
+        # Determine correct and incorrect answers
+        if answer == "1":
+            correct_answer = option1
+            incorrect_answer = option2
+        elif answer == "2":
+            correct_answer = option2
+            incorrect_answer = option1
+        else:
+            # If answer format is unexpected, default to option1 as correct
+            correct_answer = option1
+            incorrect_answer = option2
+
+        # Create contrastive pair
+        return [
+            {
+                "question": sentence,  # The sentence with blank to fill
+                "good_response": correct_answer,
+                "bad_response": incorrect_answer,
+                "metadata": {
+                    "option1": option1,
+                    "option2": option2,
+                    "answer": answer,
+                    "benchmark_type": "winogrande",
+                    "task_type": "coreference_resolution",
+                },
+            }
+        ]
+
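For Winogrande, the `answer` label ("1" or "2") simply selects which option becomes the good response; a quick worked example with an invented row:

```python
sample = {
    "sentence": "The trophy didn't fit in the suitcase because the _ was too small.",
    "option1": "trophy",
    "option2": "suitcase",
    "answer": "2",
}

if sample["answer"] == "1":
    correct, incorrect = sample["option1"], sample["option2"]
elif sample["answer"] == "2":
    correct, incorrect = sample["option2"], sample["option1"]
else:  # unexpected label: same option1-first fallback as above
    correct, incorrect = sample["option1"], sample["option2"]

print(correct, "/", incorrect)  # suitcase / trophy
```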
+    def _convert_wikitext_format(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """Convert WikiText format (page)."""
+        page = sample.get("page", "")
+
+        if not page or len(page.strip()) < 50:  # Skip very short pages
+            return []
+
+        # For WikiText, we create language modeling pairs
+        # Split the page into sentences and create good/corrupted pairs
+        sentences = page.split(". ")
+        if len(sentences) < 2:
+            return []
+
+        pairs = []
+        for i, sentence in enumerate(sentences):
+            if len(sentence.strip()) > 20:  # Only use substantial sentences
+                # Create a corrupted version by replacing some words
+                words = sentence.split()
+                if len(words) > 3:
+                    # Simple corruption: duplicate a word in the middle
+                    mid_idx = len(words) // 2
+                    corrupted_words = words.copy()
+                    corrupted_words.insert(mid_idx, words[mid_idx])
+                    corrupted_sentence = " ".join(corrupted_words)
+
+                    pairs.append(
+                        {
+                            "question": "Complete the text naturally:",
+                            "good_response": sentence.strip(),
+                            "bad_response": corrupted_sentence,
+                            "metadata": {
+                                "benchmark_type": "wikitext",
+                                "task_type": "language_modeling",
+                                "sentence_index": i,
+                            },
+                        }
+                    )
+
+                    # Limit to 3 pairs per page to avoid too many
+                    if len(pairs) >= 3:
+                        break
+
+        return pairs
+
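The WikiText "corruption" is deliberately crude: duplicating the middle token yields a bad response that is ungrammatical but topically identical to the good one. A worked example of that transformation on an invented sentence:

```python
sentence = "The cathedral was completed in the fourteenth century"
words = sentence.split()
mid_idx = len(words) // 2          # 8 words -> index 4 ("in")
corrupted = words.copy()
corrupted.insert(mid_idx, words[mid_idx])
print(" ".join(corrupted))
# The cathedral was completed in in the fourteenth century
```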
+    def _convert_webqs_format(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """Convert WebQS format (question, answers list)."""
+        question = sample.get("question", "")
+        answers = sample.get("answers", [])
+
+        if not question or not answers:
+            return []
+
+        # Take the first answer as the correct one
+        correct_answer = answers[0] if answers else ""
+
+        if not correct_answer:
+            return []
+
+        # Generate incorrect answers (simple approach)
+        incorrect_answers = []
+
+        # Strategy 1: Use other answers from the same dataset if available
+        if len(answers) > 1:
+            incorrect_answers.extend(answers[1:3])  # Take up to 2 more answers as distractors
+
+        # Strategy 2: Generate simple incorrect answers
+        if len(incorrect_answers) < 2:
+            # Simple factual distractors
+            incorrect_answers.append("Unknown")
+            incorrect_answers.append("No information available")
+
+        # Create contrastive pairs
+        pairs = []
+        for incorrect in incorrect_answers[:2]:  # Limit to 2 pairs
+            pairs.append(
+                {
+                    "question": question,
+                    "good_response": correct_answer,
+                    "bad_response": incorrect,
+                    "metadata": {"benchmark_type": "webqs", "task_type": "factual_qa", "url": sample.get("url", "")},
+                }
+            )
+
+        return pairs
+
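WebQS here, and NaturalQS and TriviaQA below, share the same two-step distractor strategy: reuse alternate gold answers when the dataset provides them, then pad with generic negatives. A compact standalone sketch of that logic (the helper name and fallback strings are illustrative):

```python
def build_distractors(answers: list[str]) -> list[str]:
    """Alternate gold answers first, generic negatives as padding; at most two."""
    distractors = list(answers[1:3])      # Strategy 1: sibling answers
    if len(distractors) < 2:              # Strategy 2: generic negatives
        distractors.extend(["Unknown", "No information available"])
    return distractors[:2]


print(build_distractors(["1969", "July 1969"]))   # ['July 1969', 'Unknown']
print(build_distractors(["Neil Armstrong"]))      # ['Unknown', 'No information available']
```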
+    def _convert_naturalqs_format(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """Convert NaturalQS format (question, answer as list)."""
+        question = sample.get("question", "")
+        answer_list = sample.get("answer", [])
+
+        if not question or not answer_list:
+            return []
+
+        # Take the first answer as the correct one (shortest/most direct)
+        correct_answer = answer_list[0] if answer_list else ""
+
+        if not correct_answer:
+            return []
+
+        # Generate incorrect answers
+        incorrect_answers = []
+
+        # Strategy 1: Use other answers from the list as distractors if available
+        if len(answer_list) > 1:
+            incorrect_answers.extend(answer_list[1:3])  # Take up to 2 more answers
+
+        # Strategy 2: Generate generic incorrect answers
+        if len(incorrect_answers) < 2:
+            incorrect_answers.append("I don't know the answer to this question.")
+            incorrect_answers.append("This information is not available.")
+
+        # Create contrastive pairs
+        pairs = []
+        for incorrect in incorrect_answers[:2]:  # Limit to 2 pairs
+            pairs.append(
+                {
+                    "context": question,
+                    "good_response": correct_answer,
+                    "bad_response": incorrect,
+                    "metadata": {
+                        "benchmark_type": "naturalqs",
+                        "task_type": "factual_qa",
+                        "total_answers": len(answer_list),
+                    },
+                }
+            )
+
+        return pairs
+
+    def _convert_triviaqa_format(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """Convert TriviaQA format (question, answer as dict with aliases)."""
+        question = sample.get("question", "")
+        answer_dict = sample.get("answer", {})
+
+        if not question or not answer_dict:
+            return []
+
+        # Extract the correct answer from aliases
+        aliases = answer_dict.get("aliases", [])
+        if not aliases:
+            # Fallback to other fields
+            correct_answer = (
+                answer_dict.get("value", "") or answer_dict.get("normalized_value", "") or str(answer_dict)
+            )[:100]  # Truncate if too long
+        else:
+            correct_answer = aliases[0]  # Use first alias as primary answer
+
+        if not correct_answer:
+            return []
+
+        # Generate incorrect answers
+        incorrect_answers = []
+
+        # Strategy 1: Use other aliases as distractors if available
+        if len(aliases) > 1:
+            incorrect_answers.extend(aliases[1:3])  # Take up to 2 more aliases
+
+        # Strategy 2: Generate generic incorrect answers for trivia
+        if len(incorrect_answers) < 2:
+            incorrect_answers.append("Unknown")
+            incorrect_answers.append("I don't know")
+
+        # Create contrastive pairs
+        pairs = []
+        for incorrect in incorrect_answers[:2]:  # Limit to 2 pairs
+            pairs.append(
+                {
+                    "context": question,
+                    "good_response": correct_answer,
+                    "bad_response": incorrect,
+                    "metadata": {
+                        "benchmark_type": "triviaqa",
+                        "task_type": "trivia_qa",
+                        "total_aliases": len(aliases),
+                        "entity_name": answer_dict.get("matched_wiki_entity_name", ""),
+                    },
+                }
+            )
+
+        return pairs
+
+    def _convert_mbpp_format(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """Convert MBPP/HumanEval code generation format (task_id, code, prompt, test)."""
+        task_id = sample.get("task_id", "")
+        code = sample.get("code", "")
+        prompt = sample.get("prompt", "")
+        test = sample.get("test", "")
+
+        # For code generation tasks, we create contrastive pairs based on:
+        # Correct: The reference code solution
+        # Incorrect: A placeholder for incorrect/buggy code (since we don't have real incorrect solutions)
+
+        pairs = []
+
+        # Create a contrastive pair with the coding prompt
+        pairs.append(
+            {
+                "question": f"Write Python code to solve this problem:\n\n{prompt}",
+                "correct_answer": code,
+                "incorrect_answer": "# This is a placeholder for incorrect code\n# In practice, this would be buggy or incomplete code\npass",  # TODO
+                "metadata": {
+                    "task_id": task_id,
+                    "test_cases": test,
+                    "source_file": sample.get("source_file", ""),
+                    "test_imports": sample.get("test_imports", ""),
+                    "test_list": sample.get("test_list", []),
+                    "benchmark_type": "mbpp",
+                    "task_type": "code_generation",
+                    "programming_language": "python",
+                },
+            }
+        )
+
+        return pairs
+
+
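The MBPP converter ships only a placeholder on the negative side (and, unlike the converters above, keys the pair as `correct_answer`/`incorrect_answer`). One possible way to synthesize a genuinely failing negative from the reference solution, shown purely as an illustration and not as anything this package does:

```python
def make_buggy_variant(code: str) -> str:
    """Naive negative: keep the signature, drop the body, return None."""
    lines = code.splitlines()
    header = next((ln for ln in lines if ln.lstrip().startswith("def ")), "def solution():")
    return f"{header}\n    return None  # buggy: real logic removed"


reference = "def add(a, b):\n    return a + b"
print(make_buggy_variant(reference))
# def add(a, b):
#     return None  # buggy: real logic removed
```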
+def main():
+    """Main function for CLI usage."""
+    parser = argparse.ArgumentParser(description="Download complete benchmarks from lm-eval-harness")
+
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument("--benchmarks", nargs="+", help="Specific benchmarks to download")
+    group.add_argument("--all", action="store_true", help="Download all available benchmarks")
+
+    parser.add_argument("--force", action="store_true", help="Force redownload even if exists")
+    parser.add_argument("--download-dir", default="full_benchmarks", help="Directory to save downloads")
+
+    args = parser.parse_args()
+
+    print("🚀 Full Benchmark Downloader")
+    print("=" * 60)
+
+    # Create downloader
+    downloader = FullBenchmarkDownloader(download_dir=args.download_dir)
+
+    # Download benchmarks
+    try:
+        if args.all:
+            benchmarks_to_download = None
+            print(f"📥 Downloading ALL {len(CORE_BENCHMARKS)} available benchmarks")
+        else:
+            benchmarks_to_download = args.benchmarks
+            print(f"📥 Downloading {len(args.benchmarks)} specified benchmarks: {args.benchmarks}")
+
+        results = downloader.download_all_benchmarks(benchmarks=benchmarks_to_download, force=args.force)
+
+        # Print summary
+        print("\n" + "=" * 80)
+        print("📊 FULL BENCHMARK DOWNLOAD SUMMARY")
+        print("=" * 80)
+        print(f"✅ Successful: {len(results['successful'])}")
+        print(f"⏩ Skipped (already exist): {len(results['skipped'])}")
+        print(f"❌ Failed: {len(results['failed'])}")
+        if results["excluded"]:
+            print(f"🚫 Excluded (known unavailable): {len(results['excluded'])}")
+        print(f"⏱️ Total time: {results['total_time'] / 60:.1f} minutes")
+        print(f"📁 Download directory: {downloader.download_dir.absolute()}")
+
+        if results["successful"]:
+            print("\n🎯 Successfully downloaded:")
+            for benchmark in results["successful"]:
+                print(f"   ✅ {benchmark}")
+
+        if results["failed"]:
+            print("\n❌ Failed downloads:")
+            for benchmark in results["failed"]:
+                print(f"   ❌ {benchmark}")
+
+        if results["excluded"]:
+            print("\n🚫 Excluded (known unavailable):")
+            excluded_list = sorted(results["excluded"])
+            for i in range(0, len(excluded_list), 4):  # Show 4 per line
+                line_items = excluded_list[i : i + 4]
+                print(f"   🚫 {', '.join(line_items)}")
+
+        print("\n💾 Complete benchmark data saved in:")
+        print(f"   📁 Data: {downloader.data_dir}")
+        print(f"   📁 Metadata: {downloader.metadata_dir}")
+
+        if results["successful"]:
+            print(f"\n🎉 SUCCESS! Downloaded {len(results['successful'])} complete benchmarks!")
+
+    except Exception as e:
+        print(f"\n❌ Error: {e}")
+        return 1
+
+
+if __name__ == "__main__":
+    exit(main())
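The CLI surface is small: exactly one of `--benchmarks` or `--all` is required, plus optional `--force` and `--download-dir`. A quick way to sanity-check that argument contract in isolation (a hypothetical, standalone re-creation of the parser above):

```python
import argparse

parser = argparse.ArgumentParser(description="Download complete benchmarks from lm-eval-harness")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--benchmarks", nargs="+")
group.add_argument("--all", action="store_true")
parser.add_argument("--force", action="store_true")
parser.add_argument("--download-dir", default="full_benchmarks")

print(parser.parse_args(["--benchmarks", "winogrande", "triviaqa", "--force"]))
# Namespace(all=False, benchmarks=['winogrande', 'triviaqa'], download_dir='full_benchmarks', force=True)
```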