wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.2.dist-info/METADATA +67 -0
  215. wisent-0.5.2.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,132 @@
1
+ """
2
+ Task-agnostic interface for benchmark integration.
3
+
4
+ This module provides a unified interface for integrating different benchmarks
5
+ without depending on lm-evaluation-harness.
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from typing import Any, Dict, List, Optional, Type
10
+
11
+ from .benchmark_extractors import BenchmarkExtractor
12
+
13
+
14
class TaskInterface(ABC):
    """Contract that every benchmark task implementation must satisfy.

    A task knows how to load its examples, which benchmark extractor to use
    for them, and how to describe itself (name, description and category
    tags) so it can be listed through the task registry.
    """

    @abstractmethod
    def load_data(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Return the task's examples as a list of dicts.

        ``limit`` optionally caps how many examples are returned; the exact
        handling is left to each implementation.
        """

    @abstractmethod
    def get_extractor(self) -> BenchmarkExtractor:
        """Return the benchmark extractor associated with this task."""

    @abstractmethod
    def get_name(self) -> str:
        """Return the task's name."""

    @abstractmethod
    def get_description(self) -> str:
        """Return a human-readable description of the task."""

    @abstractmethod
    def get_categories(self) -> List[str]:
        """Return category tags for the task, e.g. ``['coding', 'reasoning']``."""
36
+
37
+
38
class TaskRegistry:
    """Registry mapping task names to task classes or factory callables.

    An entry may be a ``TaskInterface`` subclass or any callable that returns
    a task instance (for example ``lambda limit=None: SomeTask(limit=limit)``).
    Entries are instantiated lazily by :meth:`get_task`.
    """

    def __init__(self):
        # name -> task class or factory callable, in registration order
        self._tasks: Dict[str, Type["TaskInterface"]] = {}

    def register_task(self, name: str, task_class: Type["TaskInterface"]):
        """Register ``task_class`` under ``name``, replacing any existing entry.

        Args:
            name: Lookup key used by :meth:`get_task`.
            task_class: A task class or a factory callable returning a task.
        """
        self._tasks[name] = task_class

    @staticmethod
    def _accepts_limit(factory) -> Optional[bool]:
        """Report whether ``factory`` can be called with a ``limit=`` keyword.

        Returns ``None`` when the signature cannot be inspected (some builtins
        and C extensions), so the caller can fall back to trial calls.
        """
        import inspect

        try:
            parameters = inspect.signature(factory).parameters.values()
        except (TypeError, ValueError):
            return None
        for param in parameters:
            if param.kind is inspect.Parameter.VAR_KEYWORD:
                return True
            if param.name == "limit" and param.kind in (
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                inspect.Parameter.KEYWORD_ONLY,
            ):
                return True
        return False

    def get_task(self, name: str, limit: Optional[int] = None) -> "TaskInterface":
        """Instantiate the task registered under ``name``.

        ``limit`` is forwarded as a keyword argument only when the registered
        factory accepts it; factories without a ``limit`` parameter are called
        with no arguments.

        Raises:
            ValueError: If ``name`` is not registered.
            TypeError: If the registered entry is not callable, or if the
                factory itself raises ``TypeError``.
        """
        if name not in self._tasks:
            raise ValueError(f"Task '{name}' not found. Available tasks: {list(self._tasks.keys())}")

        task_factory = self._tasks[name]
        if not callable(task_factory):
            raise TypeError(f"Task '{name}' is registered with a non-callable factory: {task_factory!r}")

        accepts_limit = self._accepts_limit(task_factory)
        if accepts_limit is None:
            # Signature unknown: keep the legacy "try with limit, then without" behaviour.
            try:
                return task_factory(limit=limit)
            except TypeError:
                return task_factory()
        # Decide from the signature instead of catching TypeError, so errors raised
        # *inside* a limit-aware factory are not masked by a silent retry without limit.
        if accepts_limit:
            return task_factory(limit=limit)
        return task_factory()

    def list_tasks(self) -> List[str]:
        """Return all registered task names in registration order."""
        return list(self._tasks.keys())

    def get_task_info(self, name: str) -> Dict[str, Any]:
        """Instantiate task ``name`` and return its name, description and categories.

        Raises:
            ValueError: If ``name`` is not registered.
        """
        task = self.get_task(name)
        return {"name": task.get_name(), "description": task.get_description(), "categories": task.get_categories()}

    def list_task_info(self) -> List[Dict[str, Any]]:
        """Return :meth:`get_task_info` for every registered task."""
        return [self.get_task_info(name) for name in self.list_tasks()]
79
+
80
+
81
# Process-wide registry shared by the module-level helper functions.
_task_registry: TaskRegistry = TaskRegistry()
83
+
84
+
85
def register_task(name: str, task_class: Type[TaskInterface]):
    """Add ``task_class`` to the global registry under ``name``.

    Delegates to the module-level ``TaskRegistry``; an existing entry with the
    same name is overwritten.
    """
    registry = _task_registry
    registry.register_task(name, task_class)
88
+
89
+
90
def get_task(name: str, limit: Optional[int] = None) -> TaskInterface:
    """Resolve ``name`` to a task instance.

    ``name`` is treated as a path to a custom dataset when it contains a path
    separator (``/`` or ``\\``) or ends with ``.json``; such names are loaded
    with ``FileTask``. Any other name is looked up in the global registry
    after the built-in tasks have been registered.

    Args:
        name: Registered task name or path to a custom dataset file.
        limit: Optional cap forwarded to the task's constructor/factory.

    Raises:
        ValueError: If ``name`` is neither a file path nor a registered task.
    """
    _ensure_tasks_registered()

    if "/" in name or "\\" in name or name.endswith(".json"):
        from .tasks.file_task import FileTask

        return FileTask(name, limit=limit)

    available = _task_registry.list_tasks()
    if name not in available:
        raise ValueError(
            f"Task '{name}' not found in registry. Available tasks: {available}. To load a custom dataset, provide a file path ending with .json"
        )
    # Membership is checked up front (rather than catching ValueError) so that a
    # ValueError raised while *constructing* the task propagates unchanged instead
    # of being misreported as an unknown task name.
    return _task_registry.get_task(name, limit=limit)
109
+
110
+
111
def list_tasks() -> List[str]:
    """Return the names of every registered task, registering built-ins first."""
    _ensure_tasks_registered()
    names = _task_registry.list_tasks()
    return names
115
+
116
+
117
def get_task_info(name: str) -> Dict[str, Any]:
    """Return the name, description and categories of the registered task ``name``.

    Raises:
        ValueError: If ``name`` is not registered.
    """
    # Register built-ins first, as get_task()/list_tasks() do; without this a
    # built-in name fails on a registry that has not been populated yet.
    _ensure_tasks_registered()
    return _task_registry.get_task_info(name)
120
+
121
+
122
def list_task_info() -> List[Dict[str, Any]]:
    """Return name/description/categories for every registered task."""
    # Register built-ins first; otherwise this returns [] before the tasks
    # package has been imported.
    _ensure_tasks_registered()
    return _task_registry.list_task_info()
125
+
126
+
127
def _ensure_tasks_registered():
    """Make sure the built-in tasks are present in the global registry.

    Importing the ``tasks`` package calls ``register_all_tasks()`` at import
    time. The import is done unconditionally: Python caches imported modules,
    so repeated calls are cheap, and the built-ins are still loaded when a
    caller has registered a custom task before anything imported ``tasks``
    (a plain "registry is empty" check would skip loading in that case).
    """
    # Needed for CLI usage, where the tasks package is not imported elsewhere.
    from . import tasks  # noqa: F401
@@ -0,0 +1,189 @@
1
+ """
2
+ Task selector for choosing tasks based on skills and risks tags.
3
+ """
4
+
5
+ import json
6
+ import os
7
+ import random
8
+ import logging
9
+ from typing import List, Dict, Any, Optional, Set
10
+ from pathlib import Path
11
+
12
# Module logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
13
+
14
+
15
class TaskSelector:
    """Select tasks based on skills and risks criteria.

    Metadata is read at construction time from ``skills.json``, ``risks.json``
    and ``tasks.json`` in the package's ``parameters/tasks`` directory.
    ``tasks.json`` is expected to contain a top-level ``"tasks"`` mapping of
    task name -> metadata dict with optional ``"tags"`` (list) and
    ``"quality_score"`` (number) keys. The skills/risks files are presumably
    JSON lists of tag names (TODO confirm against the shipped files).
    """

    def __init__(self):
        """Load skills, risks and task metadata from ``parameters/tasks``."""
        self.base_path = Path(__file__).parent.parent / "parameters" / "tasks"
        self.skills = self._load_json("skills.json")
        self.risks = self._load_json("risks.json")
        self.tasks_data = self._load_json("tasks.json")
        self.tasks = self.tasks_data.get("tasks", {})

    def _load_json(self, filename: str) -> Any:
        """Load ``filename`` from :attr:`base_path` as UTF-8 JSON.

        On a missing/unreadable file or invalid JSON, the error is logged and
        an empty fallback is returned: ``{}`` for ``tasks.json``, ``[]`` for
        any other file.
        """
        filepath = self.base_path / filename
        try:
            # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
            with open(filepath, "r", encoding="utf-8") as f:
                return json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            # Best effort: degrade to an empty container of the expected shape.
            logger.error("Failed to load %s: %s", filename, e)
            return {} if filename == "tasks.json" else []

    def get_available_skills(self) -> List[str]:
        """Return the skill tags loaded from ``skills.json``."""
        return self.skills

    def get_available_risks(self) -> List[str]:
        """Return the risk tags loaded from ``risks.json``."""
        return self.risks

    def find_tasks_by_tags(
        self,
        skills: Optional[List[str]] = None,
        risks: Optional[List[str]] = None,
        min_quality_score: int = 2
    ) -> List[str]:
        """
        Find tasks that match the given skills and/or risks.

        A task matches when its ``quality_score`` (default 0) is at least
        ``min_quality_score`` and it carries at least one of the requested
        tags. With no skills and no risks, every task passing the quality
        threshold is returned.

        Args:
            skills: List of skill tags to match
            risks: List of risk tags to match
            min_quality_score: Minimum quality score for tasks (default: 2)

        Returns:
            List of task names that match the criteria, in metadata order
        """
        required_tags = set(skills or ()) | set(risks or ())
        return [
            task_name
            for task_name, task_data in self.tasks.items()
            if task_data.get("quality_score", 0) >= min_quality_score
            and (not required_tags or required_tags.intersection(task_data.get("tags", [])))
        ]

    @staticmethod
    def _choose_tasks(candidates: List[str], num_tasks: Optional[int], seed: Optional[int]) -> List[str]:
        """Pick ``num_tasks`` random entries of ``candidates`` (all of them if None or too many).

        A seeded pick uses a private ``random.Random(seed)`` so the global RNG
        state is left untouched; it yields the same sample that
        ``random.seed(seed); random.sample(...)`` would.
        """
        if num_tasks is None or num_tasks >= len(candidates):
            return candidates
        rng = random.Random(seed) if seed is not None else random
        return rng.sample(candidates, num_tasks)

    def select_random_tasks(
        self,
        skills: Optional[List[str]] = None,
        risks: Optional[List[str]] = None,
        num_tasks: Optional[int] = None,
        min_quality_score: int = 2,
        seed: Optional[int] = None
    ) -> List[str]:
        """
        Select random tasks based on skills/risks criteria.

        Args:
            skills: List of skill tags to match
            risks: List of risk tags to match
            num_tasks: Number of tasks to select (None = all matching tasks)
            min_quality_score: Minimum quality score for tasks
            seed: Random seed for reproducibility (does not reseed the global RNG)

        Returns:
            List of randomly selected task names (empty if nothing matches)
        """
        matched_tasks = self.find_tasks_by_tags(skills, risks, min_quality_score)

        if not matched_tasks:
            logger.warning(f"No tasks found matching skills={skills}, risks={risks}")
            return []

        selected = self._choose_tasks(matched_tasks, num_tasks, seed)

        logger.info(f"Selected {len(selected)} tasks from {len(matched_tasks)} matching tasks")
        return selected

    def validate_skills_and_risks(
        self,
        skills: Optional[List[str]] = None,
        risks: Optional[List[str]] = None
    ) -> Dict[str, List[str]]:
        """
        Validate provided skills and risks against available options.

        Returns:
            Dictionary with 'invalid_skills' and 'invalid_risks' lists, each
            holding the supplied tags not found in the loaded metadata
        """
        invalid = {"invalid_skills": [], "invalid_risks": []}

        if skills:
            valid_skills = set(self.skills)
            invalid["invalid_skills"] = [s for s in skills if s not in valid_skills]

        if risks:
            valid_risks = set(self.risks)
            invalid["invalid_risks"] = [r for r in risks if r not in valid_risks]

        return invalid
151
+
152
+
153
def get_tasks_for_skills_and_risks(
    skills: Optional[List[str]] = None,
    risks: Optional[List[str]] = None,
    num_tasks: Optional[int] = None,
    min_quality_score: int = 2,
    seed: Optional[int] = None
) -> List[str]:
    """
    Build a fresh TaskSelector and return tasks matching the requested tags.

    Unknown skill or risk tags are reported with a warning but do not abort
    the call; they are still passed on to the tag matching.

    Args:
        skills: Skill tags to match
        risks: Risk tags to match
        num_tasks: How many tasks to pick (None = every match)
        min_quality_score: Lowest acceptable task quality score
        seed: Seed for the random pick

    Returns:
        Selected task names
    """
    chooser = TaskSelector()

    invalid = chooser.validate_skills_and_risks(skills, risks)
    if invalid["invalid_skills"]:
        logger.warning(f"Invalid skills: {invalid['invalid_skills']}")
    if invalid["invalid_risks"]:
        logger.warning(f"Invalid risks: {invalid['invalid_risks']}")

    chosen = chooser.select_random_tasks(skills, risks, num_tasks, min_quality_score, seed)
    return chosen
@@ -0,0 +1,175 @@
1
+ """
2
+ Task implementations for wisent-guard.
3
+
4
+ This package contains task-agnostic implementations for various benchmarks.
5
+ """
6
+
7
+ from ..task_interface import register_task
8
+ from .aime_task import AIMETask
9
+ from .hle_task import HLEExactMatchTask, HLEMultipleChoiceTask, HLETask
10
+ from .hmmt_task import HMMTTask
11
+ from .livecodebench_task import LiveCodeBenchTask
12
+ from .livemathbench_task import LiveMathBenchTask
13
+ from .lm_eval_task import (
14
+ AppsTask,
15
+ CodexglueCodeToTextGoTask,
16
+ CodexglueCodeToTextJavascriptTask,
17
+ CodexglueCodeToTextJavaTask,
18
+ CodexglueCodeToTextPhpTask,
19
+ CodexglueCodeToTextPythonTask,
20
+ CodexglueCodeToTextRubyTask,
21
+ ConalaTask,
22
+ ConcodeTask,
23
+ DS1000Task,
24
+ GSM8KTask,
25
+ HumanEvalPlusTask,
26
+ HumanEvalTask,
27
+ InstructHumanEvalTask,
28
+ MBPPPlusTask,
29
+ MBPPTask,
30
+ MercuryTask,
31
+ MMLUTask,
32
+ MultipleCppTask,
33
+ MultipleGoTask,
34
+ MultipleJavaTask,
35
+ MultipleJsTask,
36
+ MultiplePyTask,
37
+ MultipleRsTask,
38
+ RecodeTask,
39
+ Squad2Task,
40
+ TruthfulQATask,
41
+ )
42
+ from .math500_task import Math500Task
43
+ from .polymath_task import PolyMathTask
44
+ from .supergpqa_task import SuperGPQABiologyTask, SuperGPQAChemistryTask, SuperGPQAPhysicsTask, SuperGPQATask
45
+
46
+
47
def register_all_tasks():
    """Populate the global task registry with every built-in benchmark.

    Each entry pairs a registry name with either a task class or a small
    factory that pins constructor arguments (release, year, competition,
    language, difficulty) and forwards the optional ``limit``. Several names
    are aliases for the same configuration (e.g. ``"aime"``/``"aime2025"``).
    Entries are registered in table order.
    """
    builtin_tasks = (
        # LiveCodeBench
        ("livecodebench", lambda limit=None: LiveCodeBenchTask(release_version="release_v1", limit=limit)),
        # Common lm-eval tasks
        ("gsm8k", GSM8KTask),
        ("truthfulqa_mc1", TruthfulQATask),
        ("mmlu", MMLUTask),
        # Coding tasks
        ("mbpp", MBPPTask),
        ("humaneval", HumanEvalTask),
        ("mbpp_plus", MBPPPlusTask),
        ("instructhumaneval", InstructHumanEvalTask),
        ("humaneval_plus", HumanEvalPlusTask),
        ("conala", ConalaTask),
        ("concode", ConcodeTask),
        ("mercury", MercuryTask),
        ("apps", AppsTask),
        ("ds1000", DS1000Task),
        ("multiple_py", MultiplePyTask),
        ("multiple_js", MultipleJsTask),
        ("multiple_java", MultipleJavaTask),
        ("multiple_cpp", MultipleCppTask),
        ("multiple_rs", MultipleRsTask),
        ("multiple_go", MultipleGoTask),
        ("codexglue_code_to_text_python", CodexglueCodeToTextPythonTask),
        ("codexglue_code_to_text_go", CodexglueCodeToTextGoTask),
        ("codexglue_code_to_text_ruby", CodexglueCodeToTextRubyTask),
        ("codexglue_code_to_text_java", CodexglueCodeToTextJavaTask),
        ("codexglue_code_to_text_javascript", CodexglueCodeToTextJavascriptTask),
        ("codexglue_code_to_text_php", CodexglueCodeToTextPhpTask),
        ("recode", RecodeTask),
        ("squad2", Squad2Task),
        # HLE
        ("hle", lambda limit=None: HLETask(limit=limit)),
        ("hle_exact_match", lambda limit=None: HLEExactMatchTask(limit=limit)),
        ("hle_multiple_choice", lambda limit=None: HLEMultipleChoiceTask(limit=limit)),
        # MATH-500 and its aliases
        ("math500", lambda limit=None: Math500Task(limit=limit)),
        ("math", lambda limit=None: Math500Task(limit=limit)),
        ("hendrycks_math", lambda limit=None: Math500Task(limit=limit)),
        # AIME: "aime" defaults to the latest year (2025)
        ("aime", lambda limit=None: AIMETask(year="2025", limit=limit)),
        ("aime2025", lambda limit=None: AIMETask(year="2025", limit=limit)),
        ("aime2024", lambda limit=None: AIMETask(year="2024", limit=limit)),
        # HMMT: "hmmt" defaults to the latest competition
        ("hmmt", lambda limit=None: HMMTTask(competition="feb_2025", limit=limit)),
        ("hmmt_feb_2025", lambda limit=None: HMMTTask(competition="feb_2025", limit=limit)),
        # PolyMath: "polymath" defaults to English, medium difficulty
        ("polymath", lambda limit=None: PolyMathTask(language="en", difficulty="medium", limit=limit)),
        ("polymath_en_medium", lambda limit=None: PolyMathTask(language="en", difficulty="medium", limit=limit)),
        ("polymath_zh_medium", lambda limit=None: PolyMathTask(language="zh", difficulty="medium", limit=limit)),
        ("polymath_en_high", lambda limit=None: PolyMathTask(language="en", difficulty="high", limit=limit)),
        ("polymath_zh_high", lambda limit=None: PolyMathTask(language="zh", difficulty="high", limit=limit)),
        # LiveMathBench (CNMO 2024): "livemathbench" defaults to English
        ("livemathbench", lambda limit=None: LiveMathBenchTask(language="en", limit=limit)),
        ("livemathbench_cnmo_en", lambda limit=None: LiveMathBenchTask(language="en", limit=limit)),
        ("livemathbench_cnmo_zh", lambda limit=None: LiveMathBenchTask(language="zh", limit=limit)),
        # SuperGPQA: "supergpqa" covers all subjects
        ("supergpqa", lambda limit=None: SuperGPQATask(limit=limit)),
        ("supergpqa_physics", lambda limit=None: SuperGPQAPhysicsTask(limit=limit)),
        ("supergpqa_chemistry", lambda limit=None: SuperGPQAChemistryTask(limit=limit)),
        ("supergpqa_biology", lambda limit=None: SuperGPQABiologyTask(limit=limit)),
    )
    for task_name, factory in builtin_tasks:
        register_task(task_name, factory)
127
+
128
+
129
# Importing this package populates the global registry as a side effect.
register_all_tasks()
131
+
132
+
133
__all__ = [
    # Math / competition tasks
    "AIMETask", "AppsTask",
    # CodeXGLUE code-to-text
    "CodexglueCodeToTextGoTask", "CodexglueCodeToTextJavaTask", "CodexglueCodeToTextJavascriptTask",
    "CodexglueCodeToTextPhpTask", "CodexglueCodeToTextPythonTask", "CodexglueCodeToTextRubyTask",
    # Code generation / QA
    "ConalaTask", "ConcodeTask", "DS1000Task", "GSM8KTask",
    # HLE, HMMT, HumanEval family
    "HLEExactMatchTask", "HLEMultipleChoiceTask", "HLETask", "HMMTTask",
    "HumanEvalPlusTask", "HumanEvalTask", "InstructHumanEvalTask",
    # Live benchmarks, MBPP, MMLU, MATH
    "LiveCodeBenchTask", "LiveMathBenchTask", "MBPPPlusTask", "MBPPTask", "MMLUTask",
    "Math500Task", "MercuryTask",
    # MultiPL-E
    "MultipleCppTask", "MultipleGoTask", "MultipleJavaTask", "MultipleJsTask",
    "MultiplePyTask", "MultipleRsTask",
    # Remaining tasks
    "PolyMathTask", "RecodeTask", "Squad2Task",
    "SuperGPQABiologyTask", "SuperGPQAChemistryTask", "SuperGPQAPhysicsTask", "SuperGPQATask",
    "TruthfulQATask",
    # Registration entry point
    "register_all_tasks",
]
@@ -0,0 +1,141 @@
1
+ """
2
+ AIME task implementation for task-agnostic architecture.
3
+ """
4
+
5
+ from typing import Dict, Any, List, Optional
6
+ from ..task_interface import TaskInterface
7
+ from ..benchmark_extractors import GSM8KExtractor
8
+ import datasets
9
+
10
+
11
class AIMETask(TaskInterface):
    """General AIME mathematical contest task implementation.

    Loads one contest year's problems from HuggingFace and normalizes each
    row so that ``Problem``/``question`` and ``Answer``/``answer`` are present
    regardless of the column names used by that year's source dataset.
    """

    # Year used when none is requested; its task name is the bare "aime".
    DEFAULT_YEAR = "2025"

    # Dataset configurations for different years: HuggingFace source id,
    # split to load, the source's own problem/answer column names, and a
    # human-readable description.
    DATASET_CONFIGS = {
        "2024": {
            "source": "Maxwell-Jia/AIME_2024",
            "split": "train",
            "fields": {"problem": "Problem", "answer": "Answer"},
            "description": "30 high-difficulty AIME contest problems from 2024"
        },
        "2025": {
            "source": "MathArena/aime_2025",
            "split": "train",
            "fields": {"problem": "problem", "answer": "answer"},
            "description": "30 high-difficulty AIME contest problems from 2025 (MathArena)"
        }
    }

    def __init__(self, year: str = DEFAULT_YEAR, limit: Optional[int] = None):
        """
        Initialize AIME task for specified year.

        Args:
            year: AIME year to load ("2024", "2025"). An int such as 2024 is
                also accepted. Default: "2025" (latest).
            limit: Maximum number of samples to load; None means no limit.

        Raises:
            ValueError: If ``year`` has no entry in ``DATASET_CONFIGS``.
        """
        # Config keys are strings; accept ints too so AIMETask(year=2024) works.
        year = str(year)
        if year not in self.DATASET_CONFIGS:
            available = list(self.DATASET_CONFIGS.keys())
            raise ValueError(f"AIME year '{year}' not supported. Available: {available}")

        self.year = year
        self.config = self.DATASET_CONFIGS[year]
        self._limit = limit
        self._data = None  # Cache for loaded data, filled lazily by test_docs()
        self._extractor = GSM8KExtractor()  # Reuse enhanced GSM8K extractor

    def load_data(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Load AIME data from HuggingFace for the configured year.

        Args:
            limit: Maximum number of samples to return. When None, the limit
                given to the constructor is used; when that is also None,
                every row is returned. An explicit 0 returns no rows.

        Returns:
            One dict per row holding all original dataset fields plus the
            aliases ``Problem``/``question`` and ``Answer``/``answer`` (each
            added only when the configured source column is present).
        """
        dataset = datasets.load_dataset(
            self.config["source"],
            split=self.config["split"]
        )

        # Compare against None rather than relying on truthiness, so that an
        # explicit limit of 0 is honoured instead of meaning "everything".
        effective_limit = limit if limit is not None else self._limit
        if effective_limit is not None:
            dataset = dataset.select(range(min(effective_limit, len(dataset))))

        problem_field = self.config["fields"]["problem"]
        answer_field = self.config["fields"]["answer"]

        normalized_data = []
        for item in dataset:
            normalized_item = dict(item)  # Keep all original fields

            # Ensure consistent field names for the extractor and for the
            # question/answer format, whatever the source column is called.
            if problem_field in item:
                normalized_item["Problem"] = item[problem_field]
                normalized_item["question"] = item[problem_field]

            if answer_field in item:
                normalized_item["Answer"] = item[answer_field]
                normalized_item["answer"] = item[answer_field]

            normalized_data.append(normalized_item)

        return normalized_data

    def get_task_info(self) -> Dict[str, Any]:
        """Get information about the AIME task."""
        return {
            "task_name": self.get_name(),
            "year": self.year,
            "description": self.config["description"],
            "source": self.config["source"],
            "task_type": "text_generation",
            "evaluation_method": "mathematical_equivalence"
        }

    def validate_sample(self, sample: Dict[str, Any]) -> bool:
        """Return True if ``sample`` has this year's problem and answer columns."""
        problem_field = self.config["fields"]["problem"]
        answer_field = self.config["fields"]["answer"]

        return all(field in sample for field in (problem_field, answer_field))

    def get_extractor(self) -> GSM8KExtractor:
        """Get the benchmark extractor for this task."""
        return self._extractor

    def get_name(self) -> str:
        """Get the task name: "aime" for the default year, else "aime<year>"."""
        return "aime" if self.year == self.DEFAULT_YEAR else f"aime{self.year}"

    def get_description(self) -> str:
        """Get the task description."""
        return f"AIME {self.year} contest problems requiring advanced mathematical reasoning"

    def get_categories(self) -> List[str]:
        """Get the task categories."""
        return ["mathematics", "reasoning", "contest", "text_generation"]

    # Methods to match lm-eval interface
    def has_validation_docs(self) -> bool:
        """Check if task has validation documents."""
        return False  # AIME doesn't have separate validation sets

    def has_test_docs(self) -> bool:
        """Check if task has test documents."""
        return True  # All samples are considered test docs

    def test_docs(self) -> List[Dict[str, Any]]:
        """Get test documents, loading and caching them on first call."""
        if self._data is None:
            self._data = self.load_data()
        return self._data

    def validation_docs(self) -> List[Dict[str, Any]]:
        """Get validation documents."""
        return []  # No separate validation set

    def doc_to_text(self, doc: Dict[str, Any]) -> str:
        """Convert document to text prompt.

        Prefers the normalized ``Problem`` key; falls back to this year's raw
        problem column so un-normalized rows still yield their prompt.
        """
        return doc.get("Problem", doc.get(self.config["fields"]["problem"], ""))