wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.2.dist-info/METADATA +67 -0
  215. wisent-0.5.2.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,211 @@
1
+ """
2
+ File-based task implementation for loading custom datasets from JSON files.
3
+
4
+ This allows users to easily test the optimization pipeline with their own datasets
5
+ without needing to implement complex task classes or modify the core system.
6
+ """
7
+
8
import json
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from ..benchmark_extractors import GSM8KExtractor
from ..task_interface import TaskInterface
14
+
15
+
16
class FileTask(TaskInterface):
    """Task that loads question/answer samples from a local JSON file.

    The file must contain a JSON list of objects, each with a ``question``
    (or ``problem``) string and a non-null ``answer``. Samples are read and
    validated by :meth:`load_data` and cached on first access through
    :meth:`test_docs`.
    """

    def __init__(self, file_path: str, task_name: Optional[str] = None, limit: Optional[int] = None):
        """
        Initialize a file-based task. The file is not read here.

        Args:
            file_path: Path to JSON file containing the dataset
            task_name: Optional custom name for the task (defaults to the
                lower-cased filename stem)
            limit: Optional limit on number of samples to load
        """
        self.file_path = Path(file_path)
        self._limit = limit
        self._data: Optional[List[Dict[str, Any]]] = None  # Cache for loaded data
        self._extractor = GSM8KExtractor()  # Reuse GSM8K extractor for QA format

        # A falsy task_name (None or "") falls back to the filename stem.
        if task_name:
            self._task_name = task_name
        else:
            self._task_name = self.file_path.stem.lower()

    def load_data(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Load and validate samples from the JSON file.

        Args:
            limit: Maximum number of samples to return. A falsy value (None or
                0) falls back to the limit given at construction; if that is
                also falsy, every sample is returned.

        Returns:
            The (possibly truncated) list of sample dicts.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file is not valid JSON, does not contain a JSON
                list, or contains a sample rejected by :meth:`validate_sample`.
            RuntimeError: If reading the file fails for any other reason.
        """
        if not self.file_path.exists():
            raise FileNotFoundError(f"Dataset file not found: {self.file_path}")

        try:
            with open(self.file_path, encoding="utf-8") as f:
                data = json.load(f)
        except json.JSONDecodeError as e:
            # Chain the original error so the parser position is not lost.
            raise ValueError(f"Invalid JSON in file {self.file_path}: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Failed to load file {self.file_path}: {e}") from e

        # Ensure data is a list
        if not isinstance(data, list):
            raise ValueError(f"JSON file must contain a list of objects, got {type(data).__name__}")

        # Every sample is validated, including ones the limit will cut off,
        # so a malformed file is rejected regardless of the limit in use.
        for i, sample in enumerate(data):
            if not self.validate_sample(sample):
                raise ValueError(f"Invalid sample at index {i}: {sample}")

        effective_limit = limit or self._limit
        if effective_limit:
            # Slicing already clamps to len(data).
            data = data[:effective_limit]

        return data

    def get_extractor(self) -> GSM8KExtractor:
        """Get the benchmark extractor for this task."""
        return self._extractor

    def get_name(self) -> str:
        """Get the task name."""
        return self._task_name

    def get_description(self) -> str:
        """Get the task description."""
        return f"Custom dataset loaded from {self.file_path.name}"

    def get_categories(self) -> List[str]:
        """Get the task categories."""
        return ["custom", "file_based", "text_generation"]

    def validate_sample(self, sample: Dict[str, Any]) -> bool:
        """
        Validate that a sample has the required format.

        Expected format:
            {
                "question": "Question text",
                "answer": "Expected answer"
            }

        Optional fields:
            - "problem": Alternative to "question"
            - Any other fields will be preserved but ignored

        Returns:
            True if ``sample`` is a dict with a non-empty string question (or
            problem) and an ``answer`` that is present and not None.
        """
        if not isinstance(sample, dict):
            return False

        # Check for question field (or alternative names)
        question = sample.get("question") or sample.get("problem")
        if not question or not isinstance(question, str):
            return False

        # Check for answer field; falsy-but-present answers such as 0 are valid.
        answer = sample.get("answer")
        if answer is None:
            return False

        return True

    # Methods to match lm-eval interface
    def has_validation_docs(self) -> bool:
        """Check if task has validation documents."""
        return False  # File tasks don't have separate validation sets

    def has_test_docs(self) -> bool:
        """Check if task has test documents."""
        return True  # All samples are considered test docs

    def test_docs(self) -> List[Dict[str, Any]]:
        """Get test documents, loading them (with the constructor limit) on first call."""
        if self._data is None:
            self._data = self.load_data()
        return self._data

    def validation_docs(self) -> List[Dict[str, Any]]:
        """Get validation documents."""
        return []  # No separate validation set

    def doc_to_text(self, doc: Dict[str, Any]) -> str:
        """Convert document to text prompt."""
        question = doc.get("question") or doc.get("problem", "")
        return f"Question: {question}\nAnswer:"

    def get_task_info(self) -> Dict[str, Any]:
        """Get information about the file task.

        ``num_samples`` is the cached sample count once data has been loaded
        (including 0 for an empty file), or ``"unknown"`` before loading. This
        method never triggers a load itself.
        """
        return {
            "task_name": self._task_name,
            "description": self.get_description(),
            "source": str(self.file_path),
            "task_type": "text_generation",
            "evaluation_method": "exact_match",
            # `is not None`, not truthiness: an empty loaded list is 0 samples.
            "num_samples": len(self._data) if self._data is not None else "unknown",
        }
147
+
148
+
149
def create_file_task(file_path: str, task_name: Optional[str] = None) -> "Callable[[Optional[int]], FileTask]":
    """
    Create a task factory function for a file-based task.

    This is the recommended way to create file tasks for registration. The
    dataset file is not opened here; FileTask only reads it when data is
    loaded.

    Args:
        file_path: Path to the JSON dataset file
        task_name: Optional custom name for the task

    Returns:
        A factory ``task_factory(limit=None)`` that builds a new FileTask bound
        to ``file_path`` and ``task_name`` on every call.
    """

    # Forward-reference string annotation: the return type is only needed by
    # type checkers and must not be evaluated each time a factory is built.
    def task_factory(limit: Optional[int] = None) -> "FileTask":
        return FileTask(file_path=file_path, task_name=task_name, limit=limit)

    return task_factory
167
+
168
+
169
def register_file_task(task_name: str, file_path: str, registry=None):
    """
    Register a file-based task under ``task_name``.

    Args:
        task_name: Name to register the task under
        file_path: Path to the JSON dataset file
        registry: Optional registry to use instead of the global one. It may
            be an object exposing ``register(name, factory)`` or a mutable
            mapping (``registry[name] = factory``). When None, the task is
            registered through ``task_interface.register_task``.

    Returns:
        The task factory that was registered.

    Raises:
        TypeError: If ``registry`` is neither None, an object with a callable
            ``register`` attribute, nor a mutable mapping.
    """
    from collections.abc import MutableMapping

    task_factory = create_file_task(file_path, task_name)

    if registry is None:
        from ..task_interface import register_task

        register_task(task_name, task_factory)
    elif callable(getattr(registry, "register", None)):
        registry.register(task_name, task_factory)
    elif isinstance(registry, MutableMapping):
        registry[task_name] = task_factory
    else:
        # Previously any registry was silently ignored; fail loudly instead of
        # registering somewhere the caller did not ask for.
        raise TypeError(
            f"Unsupported registry type {type(registry).__name__}: expected an "
            f"object with a 'register' method or a mutable mapping"
        )

    return task_factory
182
+
183
+
184
def load_tasks_from_directory(directory: str, pattern: str = "*.json", prefix: str = "") -> List[str]:
    """
    Register every matching JSON file in a directory as a file-based task.

    Matches are processed in sorted order so registration is deterministic
    (``Path.glob`` yields entries in arbitrary order). Matches that are not
    regular files, such as a sub-directory named ``foo.json``, are skipped.
    A file that fails to register produces a printed warning and is skipped;
    the remaining files are still processed.

    Args:
        directory: Directory to search for JSON files
        pattern: File pattern to match (default: "*.json")
        prefix: Optional prefix to add to task names

    Returns:
        The names (``prefix + file stem``, lower-cased) of registered tasks.

    Raises:
        FileNotFoundError: If ``directory`` does not exist.
        ValueError: If ``directory`` is not a directory.
    """
    directory_path = Path(directory)

    if not directory_path.exists():
        raise FileNotFoundError(f"Directory not found: {directory}")

    if not directory_path.is_dir():
        raise ValueError(f"Path is not a directory: {directory}")

    loaded_tasks: List[str] = []

    for json_file in sorted(directory_path.glob(pattern)):
        if not json_file.is_file():
            continue
        task_name = f"{prefix}{json_file.stem}".lower()
        try:
            register_file_task(task_name, str(json_file))
        except Exception as e:
            # Best-effort bulk load: one bad file must not abort the others.
            print(f"Warning: Failed to load task from {json_file}: {e}")
            continue
        loaded_tasks.append(task_name)

    return loaded_tasks
@@ -0,0 +1,180 @@
1
+ """
2
+ HLE (Human-Level Evaluation) task implementation for task-agnostic architecture.
3
+ """
4
+
5
+ from typing import Dict, Any, List, Optional
6
+ from datasets import load_dataset
7
+ from ..task_interface import TaskInterface
8
+ from ..benchmark_extractors import HLEExtractor
9
+
10
+
11
class HLETask(TaskInterface):
    """HLE (Human-Level Evaluation) task implementation.

    Loads the ``test`` split of the ``cais/hle`` HuggingFace dataset, drops
    items that carry an image, optionally filters by category and answer
    type, and converts each remaining item into a flat dict.
    """

    def __init__(self, category_filter: Optional[str] = None, answer_type_filter: Optional[str] = None,
                 limit: Optional[int] = None):
        """Initialize HLE task.

        Args:
            category_filter: Keep only items whose ``category`` equals this
                value (Math, Physics, CS, etc.)
            answer_type_filter: Keep only items whose ``answer_type`` equals
                this value ('exactMatch' or 'multipleChoice')
            limit: Maximum number of examples to load
        """
        self.dataset_name = "cais/hle"
        self.category_filter = category_filter
        self.answer_type_filter = answer_type_filter
        self.limit = limit
        self._extractor = HLEExtractor()
        self._data: Optional[List[Dict[str, Any]]] = None  # Cache for loaded data

    def load_data(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Load, filter, and convert HLE data from HuggingFace datasets.

        Args:
            limit: Maximum number of items to return. A falsy value falls back
                to the limit given at construction; if that is also falsy,
                every matching item is returned. The limit is applied after
                filtering.
        """
        dataset = load_dataset(self.dataset_name, split="test")

        # Filter out multimodal examples for initial implementation
        text_only_data = [
            item for item in dataset
            if not item.get('image') and not item.get('image_1') and not item.get('image_2')
        ]

        # Apply additional filters
        filtered_data = self._filter_and_process(text_only_data)

        # Apply limit
        effective_limit = limit or self.limit
        if effective_limit:
            filtered_data = filtered_data[:effective_limit]

        return filtered_data

    def _filter_and_process(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Filter data by category and answer type, and convert to internal format.

        Missing source fields default to ``''``. Multiple-choice items also
        get a ``parsed_choices`` list extracted from the question text.
        """
        filtered_data = []

        for item in data:
            # Apply category filter
            if self.category_filter and item.get('category') != self.category_filter:
                continue

            # Apply answer type filter
            if self.answer_type_filter and item.get('answer_type') != self.answer_type_filter:
                continue

            # Convert to internal format
            processed_item = {
                'question': item.get('question', ''),
                'answer': item.get('answer', ''),
                'answer_type': item.get('answer_type', ''),
                'category': item.get('category', ''),
                'raw_subject': item.get('raw_subject', ''),
                'rationale': item.get('rationale', ''),
                'author_name': item.get('author_name', ''),
                'id': item.get('id', ''),
                'metadata': {
                    'canary': item.get('canary', ''),
                    'dataset': self.dataset_name
                }
            }

            # For multiple choice, parse choices from question text if needed
            if item.get('answer_type') == 'multipleChoice':
                processed_item['parsed_choices'] = self._parse_choices_from_question(item.get('question', ''))

            filtered_data.append(processed_item)

        return filtered_data

    def _parse_choices_from_question(self, question: str) -> List[str]:
        """Parse multiple-choice options from question text.

        An option must begin a line (optionally indented) with a letter A-E
        followed by ``.`` or ``)``, e.g. ``"A. foo"`` or ``"B) bar"``.
        Anchoring to the line start prevents false matches inside prose, such
        as the ``"D. student"`` in ``"a Ph.D. student"``. Because ``$`` matches
        at every line end under MULTILINE, each option's text runs to the end
        of the line it starts on. The first format that yields any match wins.

        Returns:
            Options normalised to ``"<letter>. <text>"``, or an empty list.
        """
        import re

        patterns = [
            r'^[ \t]*([A-E])\.\s+(.+?)(?=\n[A-E]\.|$)',  # "A. option" format
            r'^[ \t]*([A-E])\)\s+(.+?)(?=\n[A-E]\)|$)',  # "A) option" format
        ]

        for pattern in patterns:
            matches = re.findall(pattern, question, re.MULTILINE | re.DOTALL)
            if matches:
                return [f"{letter}. {text.strip()}" for letter, text in matches]

        return []

    def get_extractor(self) -> HLEExtractor:
        """Get the HLE benchmark extractor."""
        return self._extractor

    def get_name(self) -> str:
        """Get the task name."""
        return "hle"

    def get_description(self) -> str:
        """Get the task description, noting any active filters."""
        desc = "HLE (Human-Level Evaluation): Multimodal benchmark for human-level reasoning across multiple domains"
        if self.category_filter:
            desc += f" (filtered to {self.category_filter})"
        if self.answer_type_filter:
            desc += f" ({self.answer_type_filter} questions only)"
        return desc

    def get_categories(self) -> List[str]:
        """Get the task categories."""
        return ["reasoning", "knowledge", "multimodal", "evaluation"]

    # Methods to match lm-eval interface
    def has_validation_docs(self) -> bool:
        """Check if task has validation documents."""
        return False  # HLE doesn't have separate validation sets

    def has_test_docs(self) -> bool:
        """Check if task has test documents."""
        return True  # All samples are considered test docs

    def test_docs(self) -> List[Dict[str, Any]]:
        """Get test documents, loading them (with the constructor limit) on first call."""
        if self._data is None:
            self._data = self.load_data()
        return self._data

    def validation_docs(self) -> List[Dict[str, Any]]:
        """Get validation documents."""
        return []  # No separate validation set

    def doc_to_text(self, doc: Dict[str, Any]) -> str:
        """Convert document to text prompt."""
        # For HLE, the question already contains the choices for multiple choice
        return doc.get('question', '')
149
+
150
+
151
class HLEExactMatchTask(HLETask):
    """HLE variant restricted to questions whose answer_type is 'exactMatch'."""

    def __init__(self, category_filter: Optional[str] = None, limit: Optional[int] = None):
        # The answer-type filter is fixed; the other settings pass through.
        super().__init__(
            category_filter=category_filter,
            answer_type_filter='exactMatch',
            limit=limit,
        )

    def get_name(self) -> str:
        """Return the registry name of this variant."""
        return "hle_exact_match"

    def get_description(self) -> str:
        """Return a human-readable description, mentioning any category filter."""
        base = "HLE Exact Match: Text-based questions requiring exact string matching"
        suffix = f" (filtered to {self.category_filter})" if self.category_filter else ""
        return base + suffix
165
+
166
+
167
class HLEMultipleChoiceTask(HLETask):
    """HLE variant restricted to questions whose answer_type is 'multipleChoice'."""

    def __init__(self, category_filter: Optional[str] = None, limit: Optional[int] = None):
        # The answer-type filter is fixed; the other settings pass through.
        super().__init__(
            category_filter=category_filter,
            answer_type_filter='multipleChoice',
            limit=limit,
        )

    def get_name(self) -> str:
        """Return the registry name of this variant."""
        return "hle_multiple_choice"

    def get_description(self) -> str:
        """Return a human-readable description, mentioning any category filter."""
        base = "HLE Multiple Choice: Questions with A/B/C/D/E answer options"
        suffix = f" (filtered to {self.category_filter})" if self.category_filter else ""
        return base + suffix
@@ -0,0 +1,119 @@
1
+ """
2
+ HMMT (Harvard-MIT Math Tournament) task implementation for task-agnostic architecture.
3
+ """
4
+
5
+ from typing import Dict, Any, List, Optional
6
+ from ..task_interface import TaskInterface
7
+ from ..benchmark_extractors import GSM8KExtractor
8
+ import datasets
9
+
10
+
11
class HMMTTask(TaskInterface):
    """Contest-math task backed by Harvard-MIT Math Tournament problem sets."""

    # Per-competition settings: HuggingFace dataset id, split, the column names
    # holding the problem and answer, and a short description.
    DATASET_CONFIGS = {
        "feb_2025": {
            "source": "MathArena/hmmt_feb_2025",
            "split": "train",
            "fields": {"problem": "problem", "answer": "answer"},
            "description": "30 high-difficulty HMMT February 2025 contest problems",
        }
    }

    def __init__(self, competition: str = "feb_2025", limit: Optional[int] = None):
        """
        Set up the task for one HMMT competition.

        Args:
            competition: Key into DATASET_CONFIGS ("feb_2025", the default and latest)
            limit: Maximum number of samples to load

        Raises:
            ValueError: If ``competition`` has no entry in DATASET_CONFIGS.
        """
        configs = self.DATASET_CONFIGS
        if competition not in configs:
            raise ValueError(
                f"HMMT competition '{competition}' not supported. Available: {list(configs.keys())}"
            )

        self.competition = competition
        self.config = configs[competition]
        self._limit = limit
        self._data = None  # Cache for loaded data
        self._extractor = GSM8KExtractor()  # Shared with GSM8K-style QA tasks

    def load_data(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Fetch the competition's problems from HuggingFace.

        The limit is applied to the raw dataset before normalisation; a falsy
        ``limit`` falls back to the constructor limit. Each returned row keeps
        all original columns and, when present, gains ``Problem``/``question``
        copies of the problem and ``Answer``/``answer`` copies of the answer.
        """
        cfg = self.config
        dataset = datasets.load_dataset(cfg["source"], split=cfg["split"])

        cap = limit or self._limit
        if cap:
            dataset = dataset.select(range(min(cap, len(dataset))))

        problem_key = cfg["fields"]["problem"]
        answer_key = cfg["fields"]["answer"]

        def normalize(record):
            source = dict(record)
            row = dict(source)
            if problem_key in source:
                row["Problem"] = row["question"] = source[problem_key]
            if answer_key in source:
                row["Answer"] = row["answer"] = source[answer_key]
            return row

        return [normalize(record) for record in dataset]

    def get_task_info(self) -> Dict[str, Any]:
        """Summarise the task: name, competition, dataset source and evaluation style."""
        cfg = self.config
        label = "hmmt" if self.competition == "feb_2025" else f"hmmt_{self.competition}"
        return {
            "task_name": label,
            "competition": self.competition,
            "description": cfg["description"],
            "source": cfg["source"],
            "task_type": "text_generation",
            "evaluation_method": "mathematical_equivalence",
        }

    def validate_sample(self, sample: Dict[str, Any]) -> bool:
        """Return True when ``sample`` contains both configured problem and answer keys."""
        columns = self.config["fields"]
        problem_key, answer_key = columns["problem"], columns["answer"]
        return problem_key in sample and answer_key in sample

    def get_extractor(self) -> GSM8KExtractor:
        """Return the answer extractor used for scoring."""
        return self._extractor

    def get_name(self) -> str:
        """Return "hmmt" for the default competition, else "hmmt_<competition>"."""
        if self.competition == "feb_2025":
            return "hmmt"
        return f"hmmt_{self.competition}"

    def get_description(self) -> str:
        """Return a human-readable description built from the competition key."""
        pretty = self.competition.replace('_', ' ').title()
        return f"HMMT {pretty} contest problems requiring advanced mathematical reasoning"

    def get_categories(self) -> List[str]:
        """Return the category tags for this task."""
        return ["mathematics", "reasoning", "contest", "text_generation"]

    @classmethod
    def get_supported_competitions(cls) -> List[str]:
        """Return the competition keys available in DATASET_CONFIGS."""
        return [name for name in cls.DATASET_CONFIGS.keys()]
118
+
119
+