wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.2.dist-info/METADATA +67 -0
- wisent-0.5.2.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,595 @@
+"""
+Managed Cached Benchmarks service for intelligent dataset downloading and caching.
+
+This service controls how much of each benchmark is downloaded based on the limit parameter:
+- If limit=5, download only 5 samples
+- If limit=3 and we have 5 cached, reuse cached samples
+- If limit=10 and we have 5 cached, download 5 more
+- Hard errors for unsupported benchmarks, no fallbacks
+"""
+
+import json
+import logging
+import os
+from collections.abc import Iterator
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from .benchmark_extractors import EXTRACTORS, get_extractor
+
+logger = logging.getLogger(__name__)
+
+
+class BenchmarkError(Exception):
+    """Base exception for benchmark-related errors."""
+
+
+class UnsupportedBenchmarkError(BenchmarkError):
+    """Raised when benchmark has no adapter."""
+
+
+class SampleNormalizationError(BenchmarkError):
+    """Raised when sample normalization fails."""
+
+
+class InsufficientSamplesError(BenchmarkError):
+    """Raised when benchmark doesn't have enough samples."""
+
+
+class CacheCorruptionError(BenchmarkError):
+    """Raised when cache data is corrupted."""
+
+
+@dataclass
+class CacheInfo:
+    """Information about cached benchmark data."""
+
+    task_name: str
+    samples_count: int
+    last_updated: datetime
+    cache_version: str
+    chunks: List[str]  # List of chunk filenames
+
+
+@dataclass
+class CacheMetadata:
+    """Global cache metadata."""
+
+    version: str
+    created_at: datetime
+    last_cleanup: datetime
+    tasks: Dict[str, CacheInfo]
+
+
+class ManagedCachedBenchmarks:
+    """
+    Service for intelligent benchmark downloading and caching.
+
+    Features:
+    - Downloads only what's needed based on limit parameter
+    - Reuses cached data when possible
+    - Incremental downloads for growing limits
+    - Hard errors for unsupported benchmarks
+    - Chunk-based storage for efficiency
+    """
+
+    CACHE_VERSION = "1.0"
+    CHUNK_SIZE = 25  # Samples per chunk
+    MAX_CACHE_AGE_DAYS = 30
+    SUPPORTED_BENCHMARKS = None  # Will be initialized in __init__
+
+    def __init__(self, cache_dir: str = "./benchmark_cache"):
+        """
+        Initialize the managed cache service.
+
+        Args:
+            cache_dir: Directory to store cached benchmark data
+        """
+        self.cache_dir = Path(cache_dir)
+        self.cache_dir.mkdir(exist_ok=True)
+
+        self.metadata_file = self.cache_dir / "metadata.json"
+        self._metadata = self._load_metadata()
+
+        # Initialize supported benchmarks including BigCode tasks
+        if ManagedCachedBenchmarks.SUPPORTED_BENCHMARKS is None:
+            supported = set(EXTRACTORS.keys())
+            try:
+                from .bigcode_integration import BigCodeTaskLoader
+
+                loader = BigCodeTaskLoader()
+                supported.update(loader.TASK_MAPPING.keys())
+            except ImportError:
+                pass
+            ManagedCachedBenchmarks.SUPPORTED_BENCHMARKS = supported
+
+        # Validate all supported benchmarks have extractors
+        self._validate_extractor_registry()
+
+    def _validate_extractor_registry(self):
+        """Ensure every supported benchmark has a working extractor."""
+        for benchmark in self.SUPPORTED_BENCHMARKS:
+            try:
+                extractor = get_extractor(benchmark)
+                if not hasattr(extractor, "extract_qa_pair"):
+                    raise AttributeError(f"Extractor for {benchmark} missing extract_qa_pair method")
+            except Exception as e:
+                raise BenchmarkError(f"Invalid extractor for supported benchmark '{benchmark}': {e}")
+
+    def get_task_samples(self, task_name: str, limit: int, force_fresh: bool = False) -> List[Dict[str, Any]]:
+        """
+        Get samples for a task, using intelligent caching.
+
+        Args:
+            task_name: Name of the benchmark task
+            limit: Number of samples needed
+            force_fresh: Force fresh download even if cached
+
+        Returns:
+            List of normalized sample dictionaries
+
+        Raises:
+            UnsupportedBenchmarkError: If task has no extractor
+            InsufficientSamplesError: If benchmark doesn't have enough samples
+            SampleNormalizationError: If sample extraction fails
+        """
+        # Hard error for unsupported benchmarks
+        if task_name not in self.SUPPORTED_BENCHMARKS:
+            raise UnsupportedBenchmarkError(
+                f"No extractor found for benchmark '{task_name}'. "
+                f"Supported benchmarks: {sorted(self.SUPPORTED_BENCHMARKS)}"
+            )
+
+        if limit <= 0:
+            return []
+
+        logger.info(f"Getting {limit} samples for task '{task_name}'")
+
+        # Check cache status
+        cached_count = self._get_cached_sample_count(task_name)
+        logger.info(f"Found {cached_count} cached samples for '{task_name}'")
+
+        if force_fresh:
+            logger.info(f"Force fresh download requested for '{task_name}'")
+            self._clear_task_cache(task_name)
+            cached_count = 0
+
+        # Decision logic
+        if cached_count >= limit:
+            # Case 1: We have enough - load from cache
+            logger.info(f"Loading {limit} samples from cache for '{task_name}'")
+            return self._load_cached_samples(task_name, limit)
+
+        if cached_count > 0 and limit <= cached_count * 2:
+            # Case 2: We have some, need a bit more - incremental download
+            needed = limit - cached_count
+            logger.info(f"Incremental download: need {needed} more samples for '{task_name}'")
+
+            new_samples = self._download_samples(task_name, needed, start_offset=cached_count)
+            self._append_to_cache(task_name, new_samples)
+
+            return self._load_cached_samples(task_name, limit)
+
+        # Case 3: Major mismatch - fresh download
+        logger.info(f"Fresh download: getting {limit} samples for '{task_name}'")
+        self._clear_task_cache(task_name)
+
+        new_samples = self._download_samples(task_name, limit, start_offset=0)
+        self._save_to_cache(task_name, new_samples)
+
+        return new_samples
+
+    def _get_cached_sample_count(self, task_name: str) -> int:
+        """Get number of cached samples for a task."""
+        if task_name not in self._metadata.tasks:
+            return 0
+        return self._metadata.tasks[task_name].samples_count
+
+    def _download_samples(self, task_name: str, limit: int, start_offset: int = 0) -> List[Dict[str, Any]]:
+        """
+        Download samples from a benchmark task.
+
+        Args:
+            task_name: Name of the benchmark
+            limit: Number of samples to download
+            start_offset: Offset to start downloading from
+
+        Returns:
+            List of normalized samples
+        """
+        logger.info(f"Downloading {limit} samples for '{task_name}' (offset: {start_offset})")
+
+        # Get extractor (hard error if not found)
+        extractor = get_extractor(task_name)
+
+        # Load raw task from lm-eval, BigCode, or TaskInterface
+        try:
+            task = self._load_lm_eval_task(task_name)
+        except Exception as e:
+            # Check if it's a BigCode task
+            from .bigcode_integration import BigCodeTaskLoader
+
+            loader = BigCodeTaskLoader()
+            if loader.is_bigcode_task(task_name):
+                task = self._load_bigcode_task(task_name)
+            # Check if it's a TaskInterface task (like AIME, HLE, etc.)
+            elif self._is_taskinterface_task(task_name):
+                task = self._load_taskinterface_task(task_name, limit=start_offset + limit)
+            else:
+                raise BenchmarkError(f"Failed to load task '{task_name}' from lm-eval: {e}")
+
+        # Get sample iterator
+        try:
+            sample_iterator = self._get_task_sample_iterator(task, start_offset + limit)
+        except Exception as e:
+            raise BenchmarkError(f"Failed to get samples from task '{task_name}': {e}")
+
+        # Skip to start offset
+        for _ in range(start_offset):
+            try:
+                next(sample_iterator)
+            except StopIteration:
+                raise InsufficientSamplesError(
+                    f"Task '{task_name}' only has {start_offset} samples, cannot skip to offset {start_offset}"
+                )
+
+        # Extract samples
+        samples = []
+        for i in range(limit):
+            try:
+                raw_sample = next(sample_iterator)
+            except StopIteration:
+                raise InsufficientSamplesError(
+                    f"Task '{task_name}' only has {start_offset + i} samples, but {start_offset + limit} were requested"
+                )
+
+            # Extract contrastive pair using extractor (includes both correct and incorrect answers)
+            try:
+                qa_pair = extractor.extract_contrastive_pair(raw_sample, task)
+                if qa_pair is None:
+                    raise ValueError("Extractor returned None")
+            except Exception as e:
+                raise SampleNormalizationError(f"Failed to normalize sample {start_offset + i} from '{task_name}': {e}")
+
+            samples.append(
+                {
+                    "id": f"sample_{start_offset + i:03d}",
+                    "raw_data": raw_sample,
+                    "normalized": qa_pair,
+                    "extracted_at": datetime.now().isoformat(),
+                }
+            )
+
+        logger.info(f"Successfully downloaded {len(samples)} samples for '{task_name}'")
+        return samples
+
+    def _load_bigcode_task(self, task_name: str):
+        """Load task from bigcode-evaluation-harness."""
+        from .bigcode_integration import BigCodeTaskLoader
+
+        loader = BigCodeTaskLoader()
+
+        # For APPS, we need to check if HF_ALLOW_CODE_EVAL is set
+        if task_name == "apps" and os.environ.get("HF_ALLOW_CODE_EVAL") != "1":
+            print(f"\n⚠️ Task '{task_name}' requires code evaluation permission.")
+            print("This task will execute model-generated code which could be potentially harmful.")
+            print("Please review the safety information at: https://arxiv.org/abs/2107.03374")
+            response = input("\nDo you want to enable code evaluation? (yes/no): ").strip().lower()
+
+            if response == "yes":
+                os.environ["HF_ALLOW_CODE_EVAL"] = "1"
+                print("✅ Code evaluation enabled for this session.")
+            else:
+                raise BenchmarkError(f"Code evaluation permission denied for task '{task_name}'")
+
+        return loader.load_task(task_name)
+
+    def _load_lm_eval_task(self, task_name: str):
+        """Load task from lm-eval-harness."""
+        try:
+            from lm_eval.tasks import get_task_dict
+
+            # First check if it's a BigCode task before trying lm-eval
+            from .bigcode_integration import BigCodeTaskLoader
+
+            loader = BigCodeTaskLoader()
+            if loader.is_bigcode_task(task_name):
+                raise ValueError(f"Task '{task_name}' is a BigCode task. Use --bigcode flag or BigCodeTaskLoader")
+
+            # Check if we need HF_ALLOW_CODE_EVAL for code evaluation tasks
+            code_eval_tasks = ["mbpp", "mbpp_plus", "humaneval", "humaneval_plus"]
+            if task_name in code_eval_tasks and os.environ.get("HF_ALLOW_CODE_EVAL") != "1":
+                print(f"\n⚠️ Task '{task_name}' requires code evaluation permission.")
+                print("This task will execute model-generated code which could be potentially harmful.")
+                print("Please review the safety information at: https://arxiv.org/abs/2107.03374")
+                response = input("\nDo you want to enable code evaluation? (yes/no): ").strip().lower()
+
+                if response == "yes":
+                    os.environ["HF_ALLOW_CODE_EVAL"] = "1"
+                    print("✅ Code evaluation enabled for this session.")
+                else:
+                    raise BenchmarkError(f"Code evaluation permission denied for task '{task_name}'")
+
+            task_dict = get_task_dict([task_name])
+            if task_name not in task_dict:
+                raise ValueError(f"Task '{task_name}' not found in lm-eval")
+
+            return task_dict[task_name]
+        except ImportError as e:
+            raise BenchmarkError("lm-evaluation-harness not available") from e
+
+    def _is_taskinterface_task(self, task_name: str) -> bool:
+        """Check if task is a TaskInterface-based task by checking the task registry."""
+        from .task_interface import list_tasks
+
+        return task_name in list_tasks()
+
+    def _load_taskinterface_task(self, task_name: str, limit: Optional[int] = None):
+        """Load TaskInterface task using the central task registry."""
+        from .task_interface import get_task
+
+        try:
+            return get_task(task_name, limit=limit)
+        except Exception as e:
+            raise BenchmarkError(f"Failed to load TaskInterface task '{task_name}': {e}")
+
+    def _get_task_sample_iterator(self, task, limit: int) -> Iterator[Dict[str, Any]]:
+        """Get iterator over task samples."""
+        # Try different document sources in order of preference
+        if hasattr(task, "validation_docs") and task.has_validation_docs():
+            docs = task.validation_docs()
+        elif hasattr(task, "test_docs") and task.has_test_docs():
+            docs = task.test_docs()
+        elif hasattr(task, "training_docs") and task.has_training_docs():
+            docs = task.training_docs()
+        else:
+            raise BenchmarkError("No document source available for task")
+
+        # Convert to iterator and limit
+        doc_iter = iter(docs)
+        for i, doc in enumerate(doc_iter):
+            if i >= limit:
+                break
+            yield doc
+
+    def _save_to_cache(self, task_name: str, samples: List[Dict[str, Any]]):
+        """Save samples to cache in chunks."""
+        task_dir = self.cache_dir / task_name
+        task_dir.mkdir(exist_ok=True)
+
+        # Save in chunks
+        chunks = []
+        for i in range(0, len(samples), self.CHUNK_SIZE):
+            chunk_samples = samples[i : i + self.CHUNK_SIZE]
+            chunk_filename = f"samples_{i + 1}_to_{i + len(chunk_samples)}.json"
+            chunk_path = task_dir / chunk_filename
+
+            with open(chunk_path, "w") as f:
+                json.dump(chunk_samples, f, indent=2)
+
+            chunks.append(chunk_filename)
+
+        # Update metadata
+        cache_info = CacheInfo(
+            task_name=task_name,
+            samples_count=len(samples),
+            last_updated=datetime.now(),
+            cache_version=self.CACHE_VERSION,
+            chunks=chunks,
+        )
+
+        self._metadata.tasks[task_name] = cache_info
+        self._save_metadata()
+
+        logger.info(f"Saved {len(samples)} samples to cache for '{task_name}' in {len(chunks)} chunks")
+
+    def _append_to_cache(self, task_name: str, new_samples: List[Dict[str, Any]]):
+        """Append new samples to existing cache."""
+        if task_name not in self._metadata.tasks:
+            return self._save_to_cache(task_name, new_samples)
+
+        task_dir = self.cache_dir / task_name
+        cache_info = self._metadata.tasks[task_name]
+
+        # Load existing samples
+        existing_samples = self._load_all_cached_samples(task_name)
+
+        # Combine and re-save in chunks
+        all_samples = existing_samples + new_samples
+        self._save_to_cache(task_name, all_samples)
+
+    def _load_cached_samples(self, task_name: str, limit: int) -> List[Dict[str, Any]]:
+        """Load cached samples up to limit."""
+        if task_name not in self._metadata.tasks:
+            return []
+
+        cache_info = self._metadata.tasks[task_name]
+        task_dir = self.cache_dir / task_name
+
+        samples = []
+        samples_loaded = 0
+
+        for chunk_filename in cache_info.chunks:
+            if samples_loaded >= limit:
+                break
+
+            chunk_path = task_dir / chunk_filename
+            if not chunk_path.exists():
+                raise CacheCorruptionError(f"Missing chunk file: {chunk_path}")
+
+            try:
+                with open(chunk_path) as f:
+                    chunk_samples = json.load(f)
+            except Exception as e:
+                raise CacheCorruptionError(f"Corrupted chunk file {chunk_path}: {e}")
+
+            # Add samples until we reach the limit
+            for sample in chunk_samples:
+                if samples_loaded >= limit:
+                    break
+                samples.append(sample)
+                samples_loaded += 1
+
+        logger.info(f"Loaded {len(samples)} cached samples for '{task_name}'")
+        return samples
+
+    def _load_all_cached_samples(self, task_name: str) -> List[Dict[str, Any]]:
+        """Load all cached samples for a task."""
+        if task_name not in self._metadata.tasks:
+            return []
+
+        cache_info = self._metadata.tasks[task_name]
+        return self._load_cached_samples(task_name, cache_info.samples_count)
+
+    def _clear_task_cache(self, task_name: str):
+        """Clear cache for a specific task."""
+        task_dir = self.cache_dir / task_name
+
+        if task_dir.exists():
+            import shutil
+
+            shutil.rmtree(task_dir)
+
+        if task_name in self._metadata.tasks:
+            del self._metadata.tasks[task_name]
+            self._save_metadata()
+
+        logger.info(f"Cleared cache for task '{task_name}'")
+
+    def _load_metadata(self) -> CacheMetadata:
+        """Load cache metadata."""
+        if not self.metadata_file.exists():
+            return CacheMetadata(
+                version=self.CACHE_VERSION, created_at=datetime.now(), last_cleanup=datetime.now(), tasks={}
+            )
+
+        try:
+            with open(self.metadata_file) as f:
+                data = json.load(f)
+
+            # Convert datetime strings back to datetime objects
+            tasks = {}
+            for task_name, task_data in data.get("tasks", {}).items():
+                tasks[task_name] = CacheInfo(
+                    task_name=task_data["task_name"],
+                    samples_count=task_data["samples_count"],
+                    last_updated=datetime.fromisoformat(task_data["last_updated"]),
+                    cache_version=task_data["cache_version"],
+                    chunks=task_data["chunks"],
+                )
+
+            return CacheMetadata(
+                version=data.get("version", self.CACHE_VERSION),
+                created_at=datetime.fromisoformat(data["created_at"]),
+                last_cleanup=datetime.fromisoformat(data["last_cleanup"]),
+                tasks=tasks,
+            )
+
+        except Exception as e:
+            logger.warning(f"Failed to load cache metadata: {e}")
+            return CacheMetadata(
+                version=self.CACHE_VERSION, created_at=datetime.now(), last_cleanup=datetime.now(), tasks={}
+            )
+
+    def _save_metadata(self):
+        """Save cache metadata."""
+        # Convert to serializable format
+        tasks_data = {}
+        for task_name, cache_info in self._metadata.tasks.items():
+            tasks_data[task_name] = {
+                "task_name": cache_info.task_name,
+                "samples_count": cache_info.samples_count,
+                "last_updated": cache_info.last_updated.isoformat(),
+                "cache_version": cache_info.cache_version,
+                "chunks": cache_info.chunks,
+            }
+
+        data = {
+            "version": self._metadata.version,
+            "created_at": self._metadata.created_at.isoformat(),
+            "last_cleanup": self._metadata.last_cleanup.isoformat(),
+            "tasks": tasks_data,
+        }
+
+        with open(self.metadata_file, "w") as f:
+            json.dump(data, f, indent=2)
+
+    def cache_status(self) -> Dict[str, Any]:
+        """Get comprehensive cache status."""
+        total_samples = sum(info.samples_count for info in self._metadata.tasks.values())
+        total_size = sum(
+            sum(
+                (self.cache_dir / task_name / chunk).stat().st_size
+                for chunk in info.chunks
+                if (self.cache_dir / task_name / chunk).exists()
+            )
+            for task_name, info in self._metadata.tasks.items()
+        )
+
+        return {
+            "cache_version": self._metadata.version,
+            "total_tasks": len(self._metadata.tasks),
+            "total_samples": total_samples,
+            "total_size_bytes": total_size,
+            "total_size_mb": round(total_size / (1024 * 1024), 2),
+            "created_at": self._metadata.created_at.isoformat(),
+            "last_cleanup": self._metadata.last_cleanup.isoformat(),
+            "tasks": {
+                task_name: {
+                    "samples_count": info.samples_count,
+                    "last_updated": info.last_updated.isoformat(),
+                    "chunks": len(info.chunks),
+                }
+                for task_name, info in self._metadata.tasks.items()
+            },
+        }
+
+    def cleanup_cache(self, max_age_days: int = None):
+        """Clean up old cache entries."""
+        if max_age_days is None:
+            max_age_days = self.MAX_CACHE_AGE_DAYS
+
+        cutoff_date = datetime.now() - timedelta(days=max_age_days)
+        tasks_to_remove = []
+
+        for task_name, cache_info in self._metadata.tasks.items():
+            if cache_info.last_updated < cutoff_date:
+                tasks_to_remove.append(task_name)
+
+        for task_name in tasks_to_remove:
+            self._clear_task_cache(task_name)
+
+        self._metadata.last_cleanup = datetime.now()
+        self._save_metadata()
+
+        logger.info(f"Cleaned up {len(tasks_to_remove)} old cache entries")
+        return len(tasks_to_remove)
+
+    def preload_tasks(self, task_limits: Dict[str, int]):
+        """Preload multiple tasks with specified limits."""
+        results = {}
+
+        for task_name, limit in task_limits.items():
+            try:
+                samples = self.get_task_samples(task_name, limit)
+                results[task_name] = {"status": "success", "samples_loaded": len(samples), "requested_limit": limit}
+                logger.info(f"Preloaded {len(samples)} samples for '{task_name}'")
+            except Exception as e:
+                results[task_name] = {"status": "error", "error": str(e), "requested_limit": limit}
+                logger.error(f"Failed to preload '{task_name}': {e}")
+
+        return results
+
+
+# Global instance
+_managed_cache = None
+
+
+def get_managed_cache(cache_dir: str = "./benchmark_cache") -> ManagedCachedBenchmarks:
+    """Get the global managed cache instance."""
+    global _managed_cache
+    if _managed_cache is None:
+        _managed_cache = ManagedCachedBenchmarks(cache_dir)
+    return _managed_cache