wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff reflects the changes between publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.
Potentially problematic release: this version of wisent has been flagged as possibly problematic.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.2.dist-info/METADATA +67 -0
- wisent-0.5.2.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
wisent/core/tasks/file_task.py
@@ -0,0 +1,211 @@
+"""
+File-based task implementation for loading custom datasets from JSON files.
+
+This allows users to easily test the optimization pipeline with their own datasets
+without needing to implement complex task classes or modify the core system.
+"""
+
+import json
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from ..benchmark_extractors import GSM8KExtractor
+from ..task_interface import TaskInterface
+
+
+class FileTask(TaskInterface):
+    """Task that loads data from a JSON file."""
+
+    def __init__(self, file_path: str, task_name: Optional[str] = None, limit: Optional[int] = None):
+        """
+        Initialize a file-based task.
+
+        Args:
+            file_path: Path to JSON file containing the dataset
+            task_name: Optional custom name for the task (defaults to filename)
+            limit: Optional limit on number of samples to load
+        """
+        self.file_path = Path(file_path)
+        self._limit = limit
+        self._data = None  # Cache for loaded data
+        self._extractor = GSM8KExtractor()  # Reuse GSM8K extractor for QA format
+
+        # Set task name
+        if task_name:
+            self._task_name = task_name
+        else:
+            self._task_name = self.file_path.stem.lower()
+
+    def load_data(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
+        """Load data from the JSON file."""
+        if not self.file_path.exists():
+            raise FileNotFoundError(f"Dataset file not found: {self.file_path}")
+
+        try:
+            with open(self.file_path, encoding="utf-8") as f:
+                data = json.load(f)
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Invalid JSON in file {self.file_path}: {e}")
+        except Exception as e:
+            raise RuntimeError(f"Failed to load file {self.file_path}: {e}")
+
+        # Ensure data is a list
+        if not isinstance(data, list):
+            raise ValueError(f"JSON file must contain a list of objects, got {type(data).__name__}")
+
+        # Validate samples
+        for i, sample in enumerate(data):
+            if not self.validate_sample(sample):
+                raise ValueError(f"Invalid sample at index {i}: {sample}")
+
+        # Apply limit
+        effective_limit = limit or self._limit
+        if effective_limit:
+            data = data[: min(effective_limit, len(data))]
+
+        return data
+
+    def get_extractor(self) -> GSM8KExtractor:
+        """Get the benchmark extractor for this task."""
+        return self._extractor
+
+    def get_name(self) -> str:
+        """Get the task name."""
+        return self._task_name
+
+    def get_description(self) -> str:
+        """Get the task description."""
+        return f"Custom dataset loaded from {self.file_path.name}"
+
+    def get_categories(self) -> List[str]:
+        """Get the task categories."""
+        return ["custom", "file_based", "text_generation"]
+
+    def validate_sample(self, sample: Dict[str, Any]) -> bool:
+        """
+        Validate that a sample has the required format.
+
+        Expected format:
+        {
+            "question": "Question text",
+            "answer": "Expected answer"
+        }
+
+        Optional fields:
+        - "problem": Alternative to "question"
+        - Any other fields will be preserved but ignored
+        """
+        if not isinstance(sample, dict):
+            return False
+
+        # Check for question field (or alternative names)
+        question = sample.get("question") or sample.get("problem")
+        if not question or not isinstance(question, str):
+            return False
+
+        # Check for answer field
+        answer = sample.get("answer")
+        if answer is None:
+            return False
+
+        return True
+
+    # Methods to match lm-eval interface
+    def has_validation_docs(self) -> bool:
+        """Check if task has validation documents."""
+        return False  # File tasks don't have separate validation sets
+
+    def has_test_docs(self) -> bool:
+        """Check if task has test documents."""
+        return True  # All samples are considered test docs
+
+    def test_docs(self) -> List[Dict[str, Any]]:
+        """Get test documents."""
+        if self._data is None:
+            self._data = self.load_data()
+        return self._data
+
+    def validation_docs(self) -> List[Dict[str, Any]]:
+        """Get validation documents."""
+        return []  # No separate validation set
+
+    def doc_to_text(self, doc: Dict[str, Any]) -> str:
+        """Convert document to text prompt."""
+        question = doc.get("question") or doc.get("problem", "")
+        return f"Question: {question}\nAnswer:"
+
+    def get_task_info(self) -> Dict[str, Any]:
+        """Get information about the file task."""
+        return {
+            "task_name": self._task_name,
+            "description": self.get_description(),
+            "source": str(self.file_path),
+            "task_type": "text_generation",
+            "evaluation_method": "exact_match",
+            "num_samples": len(self.test_docs()) if self._data else "unknown",
+        }
+
+
+def create_file_task(file_path: str, task_name: Optional[str] = None) -> callable:
+    """
+    Create a task factory function for a file-based task.
+
+    This is the recommended way to create file tasks for registration.
+
+    Args:
+        file_path: Path to the JSON dataset file
+        task_name: Optional custom name for the task
+
+    Returns:
+        A factory function that creates FileTask instances
+    """
+
+    def task_factory(limit: Optional[int] = None) -> FileTask:
+        return FileTask(file_path=file_path, task_name=task_name, limit=limit)
+
+    return task_factory
+
+
+def register_file_task(task_name: str, file_path: str, registry=None):
+    """
+    Register a file-based task with the global task registry.
+
+    Args:
+        task_name: Name to register the task under
+        file_path: Path to the JSON dataset file
+        registry: Optional registry to use (defaults to global registry)
+    """
+    from ..task_interface import register_task
+
+    task_factory = create_file_task(file_path, task_name)
+    register_task(task_name, task_factory)
+
+
+def load_tasks_from_directory(directory: str, pattern: str = "*.json", prefix: str = ""):
+    """
+    Load all JSON files in a directory as tasks.
+
+    Args:
+        directory: Directory to search for JSON files
+        pattern: File pattern to match (default: "*.json")
+        prefix: Optional prefix to add to task names
+    """
+    directory_path = Path(directory)
+
+    if not directory_path.exists():
+        raise FileNotFoundError(f"Directory not found: {directory}")
+
+    if not directory_path.is_dir():
+        raise ValueError(f"Path is not a directory: {directory}")
+
+    loaded_tasks = []
+
+    for json_file in directory_path.glob(pattern):
+        try:
+            task_name = f"{prefix}{json_file.stem}".lower()
+            register_file_task(task_name, str(json_file))
+            loaded_tasks.append(task_name)
+        except Exception as e:
+            print(f"Warning: Failed to load task from {json_file}: {e}")
+
+    return loaded_tasks
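
For orientation, a short usage sketch of the file-task API added above (not part of the diff; the import path wisent.core.tasks.file_task is inferred from the file listing, and the JSON layout follows the format documented in validate_sample()):

# Hypothetical usage sketch (assumed import path; not part of the release).
import json
from pathlib import Path

from wisent.core.tasks.file_task import FileTask, register_file_task

# validate_sample() expects a list of {"question" (or "problem"), "answer"} objects.
Path("my_math.json").write_text(json.dumps([
    {"question": "What is 2 + 2?", "answer": "4"},
    {"problem": "Compute 3 * 7.", "answer": "21"},  # "problem" is an accepted alias
]), encoding="utf-8")

task = FileTask("my_math.json", task_name="my_math", limit=10)
print(task.get_name())                        # my_math
print(task.doc_to_text(task.test_docs()[0]))  # Question: What is 2 + 2?\nAnswer:

# Or make it available through the global task registry:
register_file_task("my_math", "my_math.json")
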
wisent/core/tasks/hle_task.py
@@ -0,0 +1,180 @@
+"""
+HLE (Human-Level Evaluation) task implementation for task-agnostic architecture.
+"""
+
+from typing import Dict, Any, List, Optional
+from datasets import load_dataset
+from ..task_interface import TaskInterface
+from ..benchmark_extractors import HLEExtractor
+
+
+class HLETask(TaskInterface):
+    """HLE (Human-Level Evaluation) task implementation."""
+
+    def __init__(self, category_filter: Optional[str] = None, answer_type_filter: Optional[str] = None,
+                 limit: Optional[int] = None):
+        """Initialize HLE task.
+
+        Args:
+            category_filter: Filter by category (Math, Physics, CS, etc.)
+            answer_type_filter: Filter by answer type ('exactMatch' or 'multipleChoice')
+            limit: Maximum number of examples to load
+        """
+        self.dataset_name = "cais/hle"
+        self.category_filter = category_filter
+        self.answer_type_filter = answer_type_filter
+        self.limit = limit
+        self._extractor = HLEExtractor()
+        self._data = None  # Cache for loaded data
+
+    def load_data(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
+        """Load HLE data from HuggingFace datasets."""
+        dataset = load_dataset(self.dataset_name, split="test")
+
+        # Filter out multimodal examples for initial implementation
+        text_only_data = [
+            item for item in dataset
+            if not item.get('image') and not item.get('image_1') and not item.get('image_2')
+        ]
+
+        # Apply additional filters
+        filtered_data = self._filter_and_process(text_only_data)
+
+        # Apply limit
+        effective_limit = limit or self.limit
+        if effective_limit:
+            filtered_data = filtered_data[:effective_limit]
+
+        return filtered_data
+
+    def _filter_and_process(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Filter data by category and answer type, and convert to internal format."""
+        filtered_data = []
+
+        for item in data:
+            # Apply category filter
+            if self.category_filter and item.get('category') != self.category_filter:
+                continue
+
+            # Apply answer type filter
+            if self.answer_type_filter and item.get('answer_type') != self.answer_type_filter:
+                continue
+
+            # Convert to internal format
+            processed_item = {
+                'question': item.get('question', ''),
+                'answer': item.get('answer', ''),
+                'answer_type': item.get('answer_type', ''),
+                'category': item.get('category', ''),
+                'raw_subject': item.get('raw_subject', ''),
+                'rationale': item.get('rationale', ''),
+                'author_name': item.get('author_name', ''),
+                'id': item.get('id', ''),
+                'metadata': {
+                    'canary': item.get('canary', ''),
+                    'dataset': self.dataset_name
+                }
+            }
+
+            # For multiple choice, parse choices from question text if needed
+            if item.get('answer_type') == 'multipleChoice':
+                processed_item['parsed_choices'] = self._parse_choices_from_question(item.get('question', ''))
+
+            filtered_data.append(processed_item)
+
+        return filtered_data
+
+    def _parse_choices_from_question(self, question: str) -> List[str]:
+        """Parse multiple choice options from question text."""
+        # Look for patterns like "A. ", "B. ", etc.
+        import re
+        choices = []
+        patterns = [
+            r'([A-E])\.\s+(.+?)(?=\n[A-E]\.|$)',  # "A. option" format
+            r'([A-E])\)\s+(.+?)(?=\n[A-E]\)|$)',  # "A) option" format
+        ]
+
+        for pattern in patterns:
+            matches = re.findall(pattern, question, re.MULTILINE | re.DOTALL)
+            if matches:
+                choices = [f"{letter}. {text.strip()}" for letter, text in matches]
+                break
+
+        return choices
+
+    def get_extractor(self) -> HLEExtractor:
+        """Get the HLE benchmark extractor."""
+        return self._extractor
+
+    def get_name(self) -> str:
+        """Get the task name."""
+        return "hle"
+
+    def get_description(self) -> str:
+        """Get the task description."""
+        desc = "HLE (Human-Level Evaluation): Multimodal benchmark for human-level reasoning across multiple domains"
+        if self.category_filter:
+            desc += f" (filtered to {self.category_filter})"
+        if self.answer_type_filter:
+            desc += f" ({self.answer_type_filter} questions only)"
+        return desc
+
+    def get_categories(self) -> List[str]:
+        """Get the task categories."""
+        return ["reasoning", "knowledge", "multimodal", "evaluation"]
+
+    # Methods to match lm-eval interface
+    def has_validation_docs(self) -> bool:
+        """Check if task has validation documents."""
+        return False  # HLE doesn't have separate validation sets
+
+    def has_test_docs(self) -> bool:
+        """Check if task has test documents."""
+        return True  # All samples are considered test docs
+
+    def test_docs(self) -> List[Dict[str, Any]]:
+        """Get test documents."""
+        if self._data is None:
+            self._data = self.load_data()
+        return self._data
+
+    def validation_docs(self) -> List[Dict[str, Any]]:
+        """Get validation documents."""
+        return []  # No separate validation set
+
+    def doc_to_text(self, doc: Dict[str, Any]) -> str:
+        """Convert document to text prompt."""
+        # For HLE, the question already contains the choices for multiple choice
+        return doc.get('question', '')
+
+
+class HLEExactMatchTask(HLETask):
+    """HLE task filtered to exact match questions only."""
+
+    def __init__(self, category_filter: Optional[str] = None, limit: Optional[int] = None):
+        super().__init__(category_filter=category_filter, answer_type_filter='exactMatch', limit=limit)
+
+    def get_name(self) -> str:
+        return "hle_exact_match"
+
+    def get_description(self) -> str:
+        desc = "HLE Exact Match: Text-based questions requiring exact string matching"
+        if self.category_filter:
+            desc += f" (filtered to {self.category_filter})"
+        return desc
+
+
+class HLEMultipleChoiceTask(HLETask):
+    """HLE task filtered to multiple choice questions only."""
+
+    def __init__(self, category_filter: Optional[str] = None, limit: Optional[int] = None):
+        super().__init__(category_filter=category_filter, answer_type_filter='multipleChoice', limit=limit)
+
+    def get_name(self) -> str:
+        return "hle_multiple_choice"
+
+    def get_description(self) -> str:
+        desc = "HLE Multiple Choice: Questions with A/B/C/D/E answer options"
+        if self.category_filter:
+            desc += f" (filtered to {self.category_filter})"
+        return desc
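
A similar hypothetical sketch for the HLE classes above; the import path is again inferred from the file listing, and running it requires access to the cais/hle dataset through the datasets library:

# Hypothetical usage sketch (assumed import path; not part of the release).
from wisent.core.tasks.hle_task import HLETask, HLEExactMatchTask

# Text-only Math questions with exact-match answers, capped at 50 examples.
task = HLETask(category_filter="Math", answer_type_filter="exactMatch", limit=50)
docs = task.test_docs()           # first call triggers load_data() and caches the result
print(task.get_description())     # ends with "(filtered to Math) (exactMatch questions only)"
print(task.doc_to_text(docs[0]))  # raw question text (choices are embedded for multiple choice)

# The subclasses pin answer_type_filter for you:
em = HLEExactMatchTask(category_filter="Physics", limit=20)
assert em.get_name() == "hle_exact_match"
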
wisent/core/tasks/hmmt_task.py
@@ -0,0 +1,119 @@
+"""
+HMMT (Harvard-MIT Math Tournament) task implementation for task-agnostic architecture.
+"""
+
+from typing import Dict, Any, List, Optional
+from ..task_interface import TaskInterface
+from ..benchmark_extractors import GSM8KExtractor
+import datasets
+
+
+class HMMTTask(TaskInterface):
+    """HMMT (Harvard-MIT Math Tournament) mathematical contest task implementation."""
+
+    # Dataset configurations for different HMMT competitions
+    DATASET_CONFIGS = {
+        "feb_2025": {
+            "source": "MathArena/hmmt_feb_2025",
+            "split": "train",
+            "fields": {"problem": "problem", "answer": "answer"},
+            "description": "30 high-difficulty HMMT February 2025 contest problems"
+        }
+    }
+
+    def __init__(self, competition: str = "feb_2025", limit: Optional[int] = None):
+        """
+        Initialize HMMT task for specified competition.
+
+        Args:
+            competition: HMMT competition to load ("feb_2025"). Default: "feb_2025" (latest)
+            limit: Maximum number of samples to load
+        """
+        if competition not in self.DATASET_CONFIGS:
+            available = list(self.DATASET_CONFIGS.keys())
+            raise ValueError(f"HMMT competition '{competition}' not supported. Available: {available}")
+
+        self.competition = competition
+        self.config = self.DATASET_CONFIGS[competition]
+        self._limit = limit
+        self._data = None  # Cache for loaded data
+        self._extractor = GSM8KExtractor()  # Reuse enhanced GSM8K extractor
+
+    def load_data(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
+        """Load HMMT data from HuggingFace for specified competition."""
+        # Load dataset based on competition configuration
+        dataset = datasets.load_dataset(
+            self.config["source"],
+            split=self.config["split"]
+        )
+
+        # Apply limit
+        effective_limit = limit or self._limit
+        if effective_limit:
+            dataset = dataset.select(range(min(effective_limit, len(dataset))))
+
+        # Convert to list and normalize field names
+        data = [dict(item) for item in dataset]
+
+        # Normalize field names for consistent processing
+        normalized_data = []
+        problem_field = self.config["fields"]["problem"]
+        answer_field = self.config["fields"]["answer"]
+
+        for item in data:
+            normalized_item = dict(item)  # Keep all original fields
+
+            # Ensure consistent field names for extractor
+            if problem_field in item:
+                normalized_item["Problem"] = item[problem_field]
+                normalized_item["question"] = item[problem_field]  # For question/answer format
+
+            if answer_field in item:
+                normalized_item["Answer"] = item[answer_field]
+                normalized_item["answer"] = item[answer_field]  # For question/answer format
+
+            normalized_data.append(normalized_item)
+
+        return normalized_data
+
+
+    def get_task_info(self) -> Dict[str, Any]:
+        """Get information about the HMMT task."""
+        return {
+            "task_name": f"hmmt_{self.competition}" if self.competition != "feb_2025" else "hmmt",
+            "competition": self.competition,
+            "description": self.config["description"],
+            "source": self.config["source"],
+            "task_type": "text_generation",
+            "evaluation_method": "mathematical_equivalence"
+        }
+
+    def validate_sample(self, sample: Dict[str, Any]) -> bool:
+        """Validate that a sample has required HMMT fields."""
+        problem_field = self.config["fields"]["problem"]
+        answer_field = self.config["fields"]["answer"]
+
+        return all(field in sample for field in [problem_field, answer_field])
+
+    def get_extractor(self) -> GSM8KExtractor:
+        """Get the benchmark extractor for this task."""
+        return self._extractor
+
+    def get_name(self) -> str:
+        """Get the task name."""
+        return f"hmmt_{self.competition}" if self.competition != "feb_2025" else "hmmt"
+
+    def get_description(self) -> str:
+        """Get the task description."""
+        return f"HMMT {self.competition.replace('_', ' ').title()} contest problems requiring advanced mathematical reasoning"
+
+    def get_categories(self) -> List[str]:
+        """Get the task categories."""
+        return ["mathematics", "reasoning", "contest", "text_generation"]
+
+    @classmethod
+    def get_supported_competitions(cls) -> List[str]:
+        """Get list of supported HMMT competitions."""
+        return list(cls.DATASET_CONFIGS.keys())
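
And a hypothetical sketch for the HMMT task above, under the same assumptions (inferred import path; loading pulls MathArena/hmmt_feb_2025 from HuggingFace):

# Hypothetical usage sketch (assumed import path; not part of the release).
from wisent.core.tasks.hmmt_task import HMMTTask

print(HMMTTask.get_supported_competitions())  # ['feb_2025']

task = HMMTTask(competition="feb_2025", limit=5)
samples = task.load_data()
# load_data() keeps the original fields and adds normalized aliases:
# "problem" -> "Problem"/"question", "answer" -> "Answer"/"answer".
first = samples[0]
assert task.validate_sample(first)
print(first["question"], "->", first["answer"])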