wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.1.dist-info/METADATA +67 -0
- wisent-0.5.1.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Task-agnostic interface for benchmark integration.
|
|
3
|
+
|
|
4
|
+
This module provides a unified interface for integrating different benchmarks
|
|
5
|
+
without depending on lm-evaluation-harness.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Type
|
|
10
|
+
|
|
11
|
+
from .benchmark_extractors import BenchmarkExtractor
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TaskInterface(ABC):
    """Abstract base class every benchmark task implements.

    A task knows how to load its examples, which ``BenchmarkExtractor`` goes
    with it, and how to describe itself (name, description, categories).
    Concrete subclasses must override all five methods below.
    """

    @abstractmethod
    def load_data(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Return the task's examples as a list of dicts; ``limit`` optionally caps how many."""
        ...

    @abstractmethod
    def get_extractor(self) -> BenchmarkExtractor:
        """Return the ``BenchmarkExtractor`` associated with this task."""
        ...

    @abstractmethod
    def get_name(self) -> str:
        """Return the task's display name."""
        ...

    @abstractmethod
    def get_description(self) -> str:
        """Return a human-readable description of the task."""
        ...

    @abstractmethod
    def get_categories(self) -> List[str]:
        """Return category labels for the task, e.g. ``['coding', 'reasoning']``."""
        ...
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class TaskRegistry:
    """Registry mapping task names to factories that build task instances.

    A factory is any callable returning a ``TaskInterface``: typically the task
    class itself, or a small function/lambda that binds constructor arguments.
    """

    def __init__(self):
        # name -> factory; dict insertion order is the order list_tasks() reports.
        self._tasks: Dict[str, Callable[..., "TaskInterface"]] = {}

    def register_task(self, name: str, task_class: Callable[..., "TaskInterface"]):
        """Register ``task_class`` (a task class or factory callable) under ``name``.

        Registering an existing name replaces the previous factory.
        """
        self._tasks[name] = task_class

    def get_task(self, name: str, limit: Optional[int] = None) -> "TaskInterface":
        """Instantiate the task registered under ``name``.

        Args:
            name: Registered task name.
            limit: Optional sample limit. It is passed as ``limit=`` only when the
                factory's signature accepts it; otherwise the factory is called
                with no arguments.

        Returns:
            A new task instance built by the registered factory.

        Raises:
            ValueError: If no task is registered under ``name``.
            Any exception raised by the factory itself propagates unchanged.
        """
        if name not in self._tasks:
            raise ValueError(f"Task '{name}' not found. Available tasks: {list(self._tasks.keys())}")

        task_factory = self._tasks[name]
        accepts_limit = self._accepts_limit(task_factory)

        if accepts_limit is None:
            # Signature cannot be introspected (e.g. some C-implemented callables):
            # keep the historical behaviour of trying ``limit=`` first.
            try:
                return task_factory(limit=limit)
            except TypeError:
                return task_factory()

        # Decide from the signature instead of catching TypeError: a TypeError raised
        # *inside* the factory must not be mistaken for "does not take limit", which
        # would silently drop the limit and run the constructor a second time.
        if accepts_limit:
            return task_factory(limit=limit)
        return task_factory()

    @staticmethod
    def _accepts_limit(factory: Callable[..., Any]) -> Optional[bool]:
        """Return whether ``factory`` accepts a ``limit=`` keyword, or None if unknown."""
        try:
            params = inspect.signature(factory).parameters
        except (TypeError, ValueError):
            return None
        limit_param = params.get("limit")
        if limit_param is not None and limit_param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
        # A ``**kwargs`` catch-all also accepts ``limit=``.
        return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())

    def list_tasks(self) -> List[str]:
        """Return all registered task names in registration order."""
        return list(self._tasks.keys())

    def get_task_info(self, name: str) -> Dict[str, Any]:
        """Instantiate task ``name`` (without a limit) and return its name, description and categories."""
        task = self.get_task(name)
        return {"name": task.get_name(), "description": task.get_description(), "categories": task.get_categories()}

    def list_task_info(self) -> List[Dict[str, Any]]:
        """Return ``get_task_info`` for every registered task, in registration order."""
        return [self.get_task_info(name) for name in self.list_tasks()]
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# Single process-wide registry shared by the module-level helper functions below.
_task_registry: TaskRegistry = TaskRegistry()
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def register_task(name: str, task_class: Callable[..., TaskInterface]):
    """Register ``task_class`` under ``name`` in the global task registry.

    ``task_class`` may be a ``TaskInterface`` subclass or any factory callable
    that returns one (e.g. a lambda binding constructor arguments). When the
    task is later requested it may be called with a ``limit`` keyword argument.
    """
    _task_registry.register_task(name, task_class)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def get_task(name: str, limit: Optional[int] = None) -> TaskInterface:
    """Resolve ``name`` to a task instance.

    Names containing a path separator or ending in ``.json`` are loaded as
    custom dataset files through ``FileTask``; any other name is looked up in
    the global registry after making sure the built-in tasks are registered.

    Args:
        name: Registered task name, or path to a JSON dataset file.
        limit: Optional maximum number of samples, forwarded to the task.

    Raises:
        ValueError: If ``name`` is neither a file path nor a registered task.
            Errors raised while constructing a registered task propagate
            unchanged instead of being reported as "not found".
    """
    _ensure_tasks_registered()

    if "/" in name or "\\" in name or name.endswith(".json"):
        # Imported lazily: the tasks package imports this module at load time,
        # so a top-level import here would be circular.
        from .tasks.file_task import FileTask

        return FileTask(name, limit=limit)

    available = _task_registry.list_tasks()
    if name not in available:
        raise ValueError(
            f"Task '{name}' not found in registry. Available tasks: {available}. To load a custom dataset, provide a file path ending with .json"
        )
    return _task_registry.get_task(name, limit=limit)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def list_tasks() -> List[str]:
    """Return every registered task name, registering the built-in tasks first if needed."""
    _ensure_tasks_registered()
    names = _task_registry.list_tasks()
    return names
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def get_task_info(name: str) -> Dict[str, Any]:
    """Return the name, description and categories of task ``name``.

    The task is instantiated (without a limit) to query its metadata. Built-in
    tasks are registered first, matching ``get_task``/``list_tasks``; otherwise
    a call made before any other lookup would not find them.

    Raises:
        ValueError: If no task is registered under ``name``.
    """
    _ensure_tasks_registered()
    return _task_registry.get_task_info(name)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def list_task_info() -> List[Dict[str, Any]]:
    """Return name/description/categories for every registered task.

    Built-in tasks are registered first, matching ``get_task``/``list_tasks``;
    otherwise calling this before any other lookup would return an empty list.
    Each task is instantiated to query its metadata.
    """
    _ensure_tasks_registered()
    return _task_registry.list_task_info()
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _ensure_tasks_registered():
    """Make sure the built-in tasks are present in the global registry.

    Importing the ``tasks`` package runs ``register_all_tasks()`` as a side
    effect. The import is done unconditionally: Python caches imported modules,
    so after the first call this is a cheap lookup. An "is the registry empty?"
    guard would wrongly skip the built-ins whenever a user had registered a
    custom task before the first lookup.
    """
    # Crucial for CLI usage, where the tasks package is not imported elsewhere.
    from . import tasks  # noqa: F401
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Task selector for choosing tasks based on skills and risks tags.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import random
|
|
8
|
+
import logging
|
|
9
|
+
from typing import List, Dict, Any, Optional, Set
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
# Module logger, named after this module's dotted import path.
logger: logging.Logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TaskSelector:
    """Select benchmark tasks by skill and risk tags.

    Tag vocabularies and per-task metadata are read once, at construction,
    from ``skills.json``, ``risks.json`` and ``tasks.json``.
    """

    def __init__(self):
        """Load skill/risk vocabularies and task metadata from ``parameters/tasks``."""
        # NOTE(review): this resolves to <package>/parameters/tasks (two directories up
        # from this module). If those JSON files are not shipped with the package, every
        # load fails, the errors are logged, and the selector has no tasks -- confirm packaging.
        self.base_path = Path(__file__).parent.parent / "parameters" / "tasks"
        self.skills = self._load_json("skills.json")
        self.risks = self._load_json("risks.json")
        self.tasks_data = self._load_json("tasks.json")
        # Presumed shape, inferred from how it is read below:
        #   {"tasks": {task_name: {"tags": [...], "quality_score": <number>}, ...}}
        self.tasks = self.tasks_data.get("tasks", {})

    def _load_json(self, filename: str) -> Any:
        """Parse ``filename`` from ``self.base_path``.

        Loading is best-effort: on any failure the error is logged and an empty
        container is returned ({} for tasks.json, [] otherwise), so the selector
        stays usable without metadata.
        """
        filepath = self.base_path / filename
        try:
            # JSON text is UTF-8; do not depend on the platform's locale encoding.
            with open(filepath, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception as e:
            logger.error("Failed to load %s: %s", filename, e)
            return {} if filename == "tasks.json" else []

    def get_available_skills(self) -> List[str]:
        """Return the skill vocabulary loaded from ``skills.json``."""
        return self.skills

    def get_available_risks(self) -> List[str]:
        """Return the risk vocabulary loaded from ``risks.json``."""
        return self.risks

    def find_tasks_by_tags(
        self,
        skills: Optional[List[str]] = None,
        risks: Optional[List[str]] = None,
        min_quality_score: int = 2
    ) -> List[str]:
        """
        Find tasks that match the given skills and/or risks.

        Skills and risks are pooled into one tag set; a task matches when it
        carries at least one of those tags and its ``quality_score`` (0 when
        missing) is at least ``min_quality_score``. With no tags requested,
        every task passing the quality threshold is returned.

        Args:
            skills: List of skill tags to match
            risks: List of risk tags to match
            min_quality_score: Minimum quality score for tasks (default: 2)

        Returns:
            Matching task names, in the order they appear in the task metadata
        """
        if not skills and not risks:
            # Return all tasks if no criteria specified
            return [
                task_name for task_name, task_data in self.tasks.items()
                if task_data.get("quality_score", 0) >= min_quality_score
            ]

        required_tags = set()
        if skills:
            required_tags.update(skills)
        if risks:
            required_tags.update(risks)

        matched_tasks = []
        for task_name, task_data in self.tasks.items():
            if task_data.get("quality_score", 0) < min_quality_score:
                continue
            # Any overlap between the task's tags and the requested tags is a match.
            if required_tags.intersection(task_data.get("tags", [])):
                matched_tasks.append(task_name)

        return matched_tasks

    def select_random_tasks(
        self,
        skills: Optional[List[str]] = None,
        risks: Optional[List[str]] = None,
        num_tasks: Optional[int] = None,
        min_quality_score: int = 2,
        seed: Optional[int] = None
    ) -> List[str]:
        """
        Select random tasks based on skills/risks criteria.

        Args:
            skills: List of skill tags to match
            risks: List of risk tags to match
            num_tasks: Number of tasks to select (None = all matching tasks)
            min_quality_score: Minimum quality score for tasks
            seed: Random seed for reproducibility. Seeding uses a private
                generator and leaves the global ``random`` state untouched.

        Returns:
            List of randomly selected task names; all matches, in metadata
            order, when ``num_tasks`` is None or not smaller than the number
            of matches
        """
        matched_tasks = self.find_tasks_by_tags(skills, risks, min_quality_score)

        if not matched_tasks:
            logger.warning("No tasks found matching skills=%s, risks=%s", skills, risks)
            return []

        # A seeded private generator gives the same sample that random.seed(seed)
        # followed by random.sample would, without reseeding the process-wide RNG
        # other code may rely on. Unseeded calls keep using the module-level RNG.
        rng = random if seed is None else random.Random(seed)

        if num_tasks is None or num_tasks >= len(matched_tasks):
            selected = matched_tasks
        else:
            selected = rng.sample(matched_tasks, num_tasks)

        logger.info("Selected %d tasks from %d matching tasks", len(selected), len(matched_tasks))
        return selected

    def validate_skills_and_risks(
        self,
        skills: Optional[List[str]] = None,
        risks: Optional[List[str]] = None
    ) -> Dict[str, List[str]]:
        """
        Validate provided skills and risks against available options.

        Returns:
            Dictionary with 'invalid_skills' and 'invalid_risks' lists (empty
            when everything is known or nothing was provided)
        """
        invalid = {"invalid_skills": [], "invalid_risks": []}

        if skills:
            valid_skills = set(self.skills)
            invalid["invalid_skills"] = [s for s in skills if s not in valid_skills]

        if risks:
            valid_risks = set(self.risks)
            invalid["invalid_risks"] = [r for r in risks if r not in valid_risks]

        return invalid
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def get_tasks_for_skills_and_risks(
    skills: Optional[List[str]] = None,
    risks: Optional[List[str]] = None,
    num_tasks: Optional[int] = None,
    min_quality_score: int = 2,
    seed: Optional[int] = None
) -> List[str]:
    """
    One-shot helper: build a ``TaskSelector``, warn about unknown tags, then select.

    Unknown skills or risks are only logged as warnings; they are still
    forwarded unchanged to the selection step.

    Args:
        skills: Skill tags to match
        risks: Risk tags to match
        num_tasks: How many tasks to pick (None = every match)
        min_quality_score: Lowest acceptable task quality score
        seed: Seed for reproducible sampling

    Returns:
        Selected task names
    """
    task_selector = TaskSelector()

    problems = task_selector.validate_skills_and_risks(skills, risks)
    unknown_skills = problems["invalid_skills"]
    unknown_risks = problems["invalid_risks"]
    if unknown_skills:
        logger.warning(f"Invalid skills: {unknown_skills}")
    if unknown_risks:
        logger.warning(f"Invalid risks: {unknown_risks}")

    selection_kwargs = dict(
        skills=skills,
        risks=risks,
        num_tasks=num_tasks,
        min_quality_score=min_quality_score,
        seed=seed,
    )
    return task_selector.select_random_tasks(**selection_kwargs)
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Task implementations for wisent.
|
|
3
|
+
|
|
4
|
+
This package contains task-agnostic implementations for various benchmarks.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from ..task_interface import register_task
|
|
8
|
+
from .aime_task import AIMETask
|
|
9
|
+
from .hle_task import HLEExactMatchTask, HLEMultipleChoiceTask, HLETask
|
|
10
|
+
from .hmmt_task import HMMTTask
|
|
11
|
+
from .livecodebench_task import LiveCodeBenchTask
|
|
12
|
+
from .livemathbench_task import LiveMathBenchTask
|
|
13
|
+
from .lm_eval_task import (
|
|
14
|
+
AppsTask,
|
|
15
|
+
CodexglueCodeToTextGoTask,
|
|
16
|
+
CodexglueCodeToTextJavascriptTask,
|
|
17
|
+
CodexglueCodeToTextJavaTask,
|
|
18
|
+
CodexglueCodeToTextPhpTask,
|
|
19
|
+
CodexglueCodeToTextPythonTask,
|
|
20
|
+
CodexglueCodeToTextRubyTask,
|
|
21
|
+
ConalaTask,
|
|
22
|
+
ConcodeTask,
|
|
23
|
+
DS1000Task,
|
|
24
|
+
GSM8KTask,
|
|
25
|
+
HumanEvalPlusTask,
|
|
26
|
+
HumanEvalTask,
|
|
27
|
+
InstructHumanEvalTask,
|
|
28
|
+
MBPPPlusTask,
|
|
29
|
+
MBPPTask,
|
|
30
|
+
MercuryTask,
|
|
31
|
+
MMLUTask,
|
|
32
|
+
MultipleCppTask,
|
|
33
|
+
MultipleGoTask,
|
|
34
|
+
MultipleJavaTask,
|
|
35
|
+
MultipleJsTask,
|
|
36
|
+
MultiplePyTask,
|
|
37
|
+
MultipleRsTask,
|
|
38
|
+
RecodeTask,
|
|
39
|
+
Squad2Task,
|
|
40
|
+
TruthfulQATask,
|
|
41
|
+
)
|
|
42
|
+
from .math500_task import Math500Task
|
|
43
|
+
from .polymath_task import PolyMathTask
|
|
44
|
+
from .supergpqa_task import SuperGPQABiologyTask, SuperGPQAChemistryTask, SuperGPQAPhysicsTask, SuperGPQATask
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def register_all_tasks():
    """Populate the global task registry with every built-in benchmark.

    Registration order is preserved, since the registry reports names in
    insertion order. Plain task classes are registered as-is and instantiated
    by the registry; variants that pin constructor arguments (release, year,
    competition, language, difficulty) are registered as small factories that
    accept ``limit`` and forward the fixed arguments.
    """

    def _bind(task_cls, **fixed_kwargs):
        """Return ``factory(limit=None)`` calling ``task_cls(limit=limit, **fixed_kwargs)``."""

        def factory(limit=None):
            return task_cls(limit=limit, **fixed_kwargs)

        return factory

    # LiveCodeBench, pinned to release_v1.
    register_task("livecodebench", _bind(LiveCodeBenchTask, release_version="release_v1"))

    # Task classes registered directly: common lm-eval tasks first, then coding tasks.
    plain_classes = (
        ("gsm8k", GSM8KTask),
        ("truthfulqa_mc1", TruthfulQATask),
        ("mmlu", MMLUTask),
        ("mbpp", MBPPTask),
        ("humaneval", HumanEvalTask),
        ("mbpp_plus", MBPPPlusTask),
        ("instructhumaneval", InstructHumanEvalTask),
        ("humaneval_plus", HumanEvalPlusTask),
        ("conala", ConalaTask),
        ("concode", ConcodeTask),
        ("mercury", MercuryTask),
        ("apps", AppsTask),
        ("ds1000", DS1000Task),
        ("multiple_py", MultiplePyTask),
        ("multiple_js", MultipleJsTask),
        ("multiple_java", MultipleJavaTask),
        ("multiple_cpp", MultipleCppTask),
        ("multiple_rs", MultipleRsTask),
        ("multiple_go", MultipleGoTask),
        ("codexglue_code_to_text_python", CodexglueCodeToTextPythonTask),
        ("codexglue_code_to_text_go", CodexglueCodeToTextGoTask),
        ("codexglue_code_to_text_ruby", CodexglueCodeToTextRubyTask),
        ("codexglue_code_to_text_java", CodexglueCodeToTextJavaTask),
        ("codexglue_code_to_text_javascript", CodexglueCodeToTextJavascriptTask),
        ("codexglue_code_to_text_php", CodexglueCodeToTextPhpTask),
        ("recode", RecodeTask),
        ("squad2", Squad2Task),
    )
    for task_name, task_cls in plain_classes:
        register_task(task_name, task_cls)

    # Factory-wrapped variants; an un-suffixed name is the default variant of its family.
    bound_variants = (
        # HLE
        ("hle", HLETask, {}),
        ("hle_exact_match", HLEExactMatchTask, {}),
        ("hle_multiple_choice", HLEMultipleChoiceTask, {}),
        # MATH-500 and its aliases
        ("math500", Math500Task, {}),
        ("math", Math500Task, {}),
        ("hendrycks_math", Math500Task, {}),
        # AIME: default is the latest year (2025)
        ("aime", AIMETask, {"year": "2025"}),
        ("aime2025", AIMETask, {"year": "2025"}),
        ("aime2024", AIMETask, {"year": "2024"}),
        # HMMT: default is the latest competition
        ("hmmt", HMMTTask, {"competition": "feb_2025"}),
        ("hmmt_feb_2025", HMMTTask, {"competition": "feb_2025"}),
        # PolyMath: default is English, medium difficulty
        ("polymath", PolyMathTask, {"language": "en", "difficulty": "medium"}),
        ("polymath_en_medium", PolyMathTask, {"language": "en", "difficulty": "medium"}),
        ("polymath_zh_medium", PolyMathTask, {"language": "zh", "difficulty": "medium"}),
        ("polymath_en_high", PolyMathTask, {"language": "en", "difficulty": "high"}),
        ("polymath_zh_high", PolyMathTask, {"language": "zh", "difficulty": "high"}),
        # LiveMathBench (CNMO 2024): default is English
        ("livemathbench", LiveMathBenchTask, {"language": "en"}),
        ("livemathbench_cnmo_en", LiveMathBenchTask, {"language": "en"}),
        ("livemathbench_cnmo_zh", LiveMathBenchTask, {"language": "zh"}),
        # SuperGPQA (scientific reasoning): default covers all subjects
        ("supergpqa", SuperGPQATask, {}),
        ("supergpqa_physics", SuperGPQAPhysicsTask, {}),
        ("supergpqa_chemistry", SuperGPQAChemistryTask, {}),
        ("supergpqa_biology", SuperGPQABiologyTask, {}),
    )
    for task_name, task_cls, fixed_kwargs in bound_variants:
        register_task(task_name, _bind(task_cls, **fixed_kwargs))
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
# Importing this package registers every built-in task as a side effect.
register_all_tasks()


__all__ = [
    "AIMETask", "AppsTask",
    "CodexglueCodeToTextGoTask", "CodexglueCodeToTextJavaTask",
    "CodexglueCodeToTextJavascriptTask", "CodexglueCodeToTextPhpTask",
    "CodexglueCodeToTextPythonTask", "CodexglueCodeToTextRubyTask",
    "ConalaTask", "ConcodeTask", "DS1000Task", "GSM8KTask",
    "HLEExactMatchTask", "HLEMultipleChoiceTask", "HLETask", "HMMTTask",
    "HumanEvalPlusTask", "HumanEvalTask", "InstructHumanEvalTask",
    "LiveCodeBenchTask", "LiveMathBenchTask",
    "MBPPPlusTask", "MBPPTask", "MMLUTask", "Math500Task", "MercuryTask",
    "MultipleCppTask", "MultipleGoTask", "MultipleJavaTask",
    "MultipleJsTask", "MultiplePyTask", "MultipleRsTask",
    "PolyMathTask", "RecodeTask", "Squad2Task",
    "SuperGPQABiologyTask", "SuperGPQAChemistryTask",
    "SuperGPQAPhysicsTask", "SuperGPQATask",
    "TruthfulQATask",
    "register_all_tasks",
]
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AIME task implementation for task-agnostic architecture.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Dict, Any, List, Optional
|
|
6
|
+
from ..task_interface import TaskInterface
|
|
7
|
+
from ..benchmark_extractors import GSM8KExtractor
|
|
8
|
+
import datasets
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AIMETask(TaskInterface):
    """General AIME mathematical contest task implementation.

    Loads the AIME problem set for a configured year from HuggingFace and
    exposes it through ``TaskInterface`` and lm-eval style accessors
    (``test_docs``, ``doc_to_text``, ...).
    """

    # Dataset configurations for different years. ``fields`` maps the
    # canonical roles ("problem", "answer") to the column names used by each
    # upstream dataset.
    DATASET_CONFIGS = {
        "2024": {
            "source": "Maxwell-Jia/AIME_2024",
            "split": "train",
            "fields": {"problem": "Problem", "answer": "Answer"},
            "description": "30 high-difficulty AIME contest problems from 2024"
        },
        "2025": {
            "source": "MathArena/aime_2025",
            "split": "train",
            "fields": {"problem": "problem", "answer": "answer"},
            "description": "30 high-difficulty AIME contest problems from 2025 (MathArena)"
        }
    }

    def __init__(self, year: str = "2025", limit: Optional[int] = None):
        """
        Initialize AIME task for specified year.

        Args:
            year: AIME year to load ("2024", "2025"). Default: "2025" (latest)
            limit: Maximum number of samples to load

        Raises:
            ValueError: If ``year`` is not a key of ``DATASET_CONFIGS``.
        """
        if year not in self.DATASET_CONFIGS:
            available = list(self.DATASET_CONFIGS.keys())
            raise ValueError(f"AIME year '{year}' not supported. Available: {available}")

        self.year = year
        self.config = self.DATASET_CONFIGS[year]
        self._limit = limit
        self._data = None  # Cache for loaded data (filled lazily by test_docs)
        self._extractor = GSM8KExtractor()  # Reuse enhanced GSM8K extractor

    def load_data(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Load AIME data from HuggingFace for the configured year.

        Args:
            limit: Maximum number of samples to return. When ``None``, the
                limit given to the constructor is used. When both are ``None``,
                every sample is returned.

        Returns:
            A list of sample dicts. Each keeps all original dataset columns.
            When the configured source columns are present, it also carries the
            aliases ``Problem``/``question`` and ``Answer``/``answer``.
        """
        dataset = datasets.load_dataset(
            self.config["source"],
            split=self.config["split"]
        )

        # Compare against None instead of relying on truthiness: an explicit
        # limit of 0 must yield no samples, not silently load everything.
        effective_limit = self._limit if limit is None else limit
        if effective_limit is not None:
            dataset = dataset.select(range(min(effective_limit, len(dataset))))

        return [self._normalize_item(dict(item)) for item in dataset]

    def _normalize_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``item`` with canonical problem/answer aliases added."""
        normalized_item = dict(item)  # Keep all original fields
        problem_field = self.config["fields"]["problem"]
        answer_field = self.config["fields"]["answer"]

        # Ensure consistent field names for the extractor and for
        # question/answer style consumers.
        if problem_field in item:
            normalized_item["Problem"] = item[problem_field]
            normalized_item["question"] = item[problem_field]

        if answer_field in item:
            normalized_item["Answer"] = item[answer_field]
            normalized_item["answer"] = item[answer_field]

        return normalized_item

    def get_task_info(self) -> Dict[str, Any]:
        """Get information about the AIME task."""
        return {
            "task_name": self.get_name(),
            "year": self.year,
            "description": self.config["description"],
            "source": self.config["source"],
            "task_type": "text_generation",
            "evaluation_method": "mathematical_equivalence"
        }

    def validate_sample(self, sample: Dict[str, Any]) -> bool:
        """Return True if ``sample`` has this year's problem and answer columns."""
        problem_field = self.config["fields"]["problem"]
        answer_field = self.config["fields"]["answer"]

        return all(field in sample for field in [problem_field, answer_field])

    def get_extractor(self) -> GSM8KExtractor:
        """Get the benchmark extractor for this task."""
        return self._extractor

    def get_name(self) -> str:
        """Get the task name: "aime" for 2025, "aime<year>" otherwise."""
        return f"aime{self.year}" if self.year != "2025" else "aime"

    def get_description(self) -> str:
        """Get the task description."""
        return f"AIME {self.year} contest problems requiring advanced mathematical reasoning"

    def get_categories(self) -> List[str]:
        """Get the task categories."""
        return ["mathematics", "reasoning", "contest", "text_generation"]

    # Methods to match lm-eval interface
    def has_validation_docs(self) -> bool:
        """Check if task has validation documents."""
        return False  # AIME doesn't have separate validation sets

    def has_test_docs(self) -> bool:
        """Check if task has test documents."""
        return True  # All samples are considered test docs

    def test_docs(self) -> List[Dict[str, Any]]:
        """Get test documents, loading and caching them on first call."""
        if self._data is None:
            self._data = self.load_data()
        return self._data

    def validation_docs(self) -> List[Dict[str, Any]]:
        """Get validation documents."""
        return []  # No separate validation set

    def doc_to_text(self, doc: Dict[str, Any]) -> str:
        """Convert document to text prompt.

        Prefers the normalized ``Problem`` alias. It falls back to the dataset's
        native problem column so that un-normalized docs still yield a prompt.
        """
        return doc.get('Problem', doc.get(self.config["fields"]["problem"], ''))