wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.2.dist-info/METADATA +67 -0
- wisent-0.5.2.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,364 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Mixed Benchmark Sampler for tag-based random sampling across multiple benchmarks.
|
|
3
|
+
|
|
4
|
+
This module enables training and evaluation on random samples from multiple benchmarks
|
|
5
|
+
that share common tags (e.g., 'coding', 'reasoning', 'math').
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import random
|
|
9
|
+
import logging
|
|
10
|
+
from typing import List, Dict, Any, Optional, Set, Tuple
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from collections import defaultdict
|
|
13
|
+
|
|
14
|
+
# Suppress BigCode debug output.
# NOTE(review): this replaces ``builtins.print`` process-wide at import time, so
# ANY print() whose text contains one of the filter substrings below is dropped,
# not only BigCode's output.
import builtins

# The genuine print. If an earlier import (a reload of this module, or another
# module following the same ``builtins._original_print`` convention) already
# stashed it, reuse that rather than capturing an already-wrapped print.
_original_print = getattr(builtins, '_original_print', builtins.print)


def _quiet_print(*args, **kwargs):
    """Filter out BigCode debug messages.

    The positional args are joined with single spaces (for matching only). The
    call is dropped when that text contains any filter substring; otherwise it is
    forwarded unchanged (args and kwargs) to the original print.
    """
    message = ' '.join(str(arg) for arg in args)
    if any(x in message for x in ['DEBUG', 'Available tasks:', 'ERROR extracting', 'bigcode_eval']):
        return
    _original_print(*args, **kwargs)


# Store the genuine print (NOT the current ``builtins.print``, which may already
# be a wrapper on reload) so later importers can find it, then patch.
builtins._original_print = _original_print
builtins.print = _quiet_print
|
|
28
|
+
|
|
29
|
+
try:
|
|
30
|
+
from .lm_harness_integration.only_benchmarks import CORE_BENCHMARKS
|
|
31
|
+
except ImportError:
|
|
32
|
+
# Try alternative import path
|
|
33
|
+
import sys
|
|
34
|
+
import os
|
|
35
|
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
36
|
+
sys.path.insert(0, os.path.join(current_dir, "lm-harness-integration"))
|
|
37
|
+
from only_benchmarks import CORE_BENCHMARKS
|
|
38
|
+
|
|
39
|
+
from .contrastive_pairs import ContrastivePairSet
|
|
40
|
+
from .managed_cached_benchmarks import ManagedCachedBenchmarks, get_managed_cache
|
|
41
|
+
|
|
42
|
+
logger = logging.getLogger(name=__name__)


@dataclass()
class BenchmarkSample:
    """One item drawn from a named benchmark, together with that benchmark's tags."""

    # Name of the benchmark the item was drawn from.
    benchmark_name: str
    # Raw sample payload, as handed back by the managed benchmark cache.
    sample_data: Dict[str, Any]
    # Tags configured for the source benchmark.
    tags: List[str]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class MixedBenchmarkSampler:
    """
    Samples randomly from multiple benchmarks based on tags.

    This creates more robust classifiers by training on diverse data
    from multiple sources rather than a single benchmark.

    Tag and benchmark metadata come from the module-level ``CORE_BENCHMARKS``
    mapping (each entry may carry a ``"tags"`` list); samples are fetched
    through the managed benchmark cache.
    """

    def __init__(self, cache_dir: str = "./benchmark_cache"):
        """
        Initialize the mixed benchmark sampler.

        Args:
            cache_dir: Directory for cached benchmark data
        """
        self.cache_dir = cache_dir
        self.managed_cache = get_managed_cache(cache_dir)
        self._benchmark_registry = self._build_benchmark_registry()

    def _build_benchmark_registry(self) -> Dict[str, List[str]]:
        """
        Build a registry mapping each tag to the benchmarks that carry it.

        Benchmarks are listed in ``CORE_BENCHMARKS`` iteration order; a tag
        repeated in one benchmark's config is recorded only once.
        """
        tag_to_benchmarks = defaultdict(list)

        for benchmark_name, config in CORE_BENCHMARKS.items():
            # ``or []`` also tolerates an explicit ``"tags": None``;
            # dict.fromkeys de-duplicates while keeping order.
            for tag in dict.fromkeys(config.get("tags") or []):
                tag_to_benchmarks[tag].append(benchmark_name)

        return dict(tag_to_benchmarks)

    def get_benchmarks_by_tag(self, tag: str) -> List[str]:
        """
        Get all benchmarks that have a specific tag.

        Returns a fresh list, so callers may modify it without corrupting the
        internal registry.
        """
        return list(self._benchmark_registry.get(tag, []))

    def get_benchmarks_by_tags(self, tags: List[str], mode: str = "any") -> List[str]:
        """
        Get benchmarks that match the given tags.

        The result order is deterministic (registry order, first occurrence
        wins). Seeded sampling downstream depends on this; a ``set``'s
        iteration order would vary between processes because of string hash
        randomization.

        Args:
            tags: List of tags to match
            mode: "any" (benchmark has at least one tag) or "all" (benchmark has all tags)

        Returns:
            List of benchmark names matching the criteria

        Raises:
            ValueError: If ``mode`` is neither "any" nor "all".
        """
        if mode == "any":
            # Benchmarks that have ANY of the specified tags, de-duplicated in order.
            ordered: Dict[str, None] = {}
            for tag in tags:
                for benchmark_name in self._benchmark_registry.get(tag, []):
                    ordered.setdefault(benchmark_name, None)
            return list(ordered)

        elif mode == "all":
            # Benchmarks that have ALL of the specified tags
            if not tags:
                return []

            other_tag_sets = [set(self._benchmark_registry.get(tag, [])) for tag in tags[1:]]
            return [
                benchmark_name
                for benchmark_name in dict.fromkeys(self._benchmark_registry.get(tags[0], []))
                if all(benchmark_name in tag_set for tag_set in other_tag_sets)
            ]

        else:
            raise ValueError(f"Invalid mode: {mode}. Use 'any' or 'all'")

    @staticmethod
    def _weighted_sample_without_replacement(
        samples: List[BenchmarkSample],
        benchmark_weights: Dict[str, float],
        k: int,
        rng: Any
    ) -> List[BenchmarkSample]:
        """
        Draw up to ``k`` distinct samples, favouring benchmarks with larger weights.

        Uses Efraimidis-Spirakis keys (``u ** (1 / w)``, keep the ``k`` largest).
        Each sample object is therefore chosen at most once and can never end
        up in both the train and the test split. Benchmarks missing from
        ``benchmark_weights`` get weight 1.0. Samples whose benchmark weight is
        <= 0 are excluded.
        """
        keyed = []
        for sample in samples:
            weight = float(benchmark_weights.get(sample.benchmark_name, 1.0))
            if weight <= 0:
                continue
            keyed.append((rng.random() ** (1.0 / weight), sample))
        # Sort on the key only: BenchmarkSample defines no ordering.
        keyed.sort(key=lambda pair: pair[0], reverse=True)
        return [sample for _, sample in keyed[:k]]

    def sample_mixed_dataset(
        self,
        tags: List[str],
        total_samples: int,
        split_ratio: float = 0.8,
        random_seed: Optional[int] = None,
        tag_mode: str = "any",
        benchmark_weights: Optional[Dict[str, float]] = None
    ) -> Tuple[List[BenchmarkSample], List[BenchmarkSample]]:
        """
        Sample a mixed dataset from benchmarks matching the given tags.

        Args:
            tags: Tags to filter benchmarks (e.g., ["coding", "python"])
            total_samples: Total number of samples to collect (must be >= 0)
            split_ratio: Fraction of the selected samples that go to the train
                split, in [0, 1]; the remainder forms the test split
            random_seed: Random seed for reproducibility. When given, a private
                ``random.Random`` is used, so the global ``random`` state is
                neither reseeded nor consumed. When None, the module-level
                ``random`` functions are used.
            tag_mode: "any" or "all" for tag matching
            benchmark_weights: Optional relative sampling weights per benchmark
                (default 1.0). Selection is weighted *without replacement*, so
                a sample is used at most once and cannot leak across splits.
                Benchmarks with weight <= 0 are excluded.

        Returns:
            Tuple of (train_samples, test_samples)

        Raises:
            ValueError: If ``total_samples``/``split_ratio`` are out of range,
                no benchmark matches the tags, or no samples could be loaded.
        """
        if total_samples < 0:
            raise ValueError(f"total_samples must be >= 0, got {total_samples}")
        if not 0.0 <= split_ratio <= 1.0:
            raise ValueError(f"split_ratio must be in [0, 1], got {split_ratio}")

        # The ``random`` module exposes sample/shuffle/random like a Random instance.
        rng = random.Random(random_seed) if random_seed is not None else random

        # Get matching benchmarks
        matching_benchmarks = self.get_benchmarks_by_tags(tags, mode=tag_mode)

        if not matching_benchmarks:
            raise ValueError(f"No benchmarks found with tags {tags} (mode={tag_mode})")

        logger.info(f"Found {len(matching_benchmarks)} benchmarks matching tags {tags}")
        logger.info(f"Matching benchmarks: {matching_benchmarks[:10]}...")  # Show first 10

        # Collect all available samples from matching benchmarks
        all_samples: List[BenchmarkSample] = []
        benchmark_sample_counts: Dict[str, int] = {}

        # Skip benchmarks that require code execution permission
        code_execution_benchmarks = {"apps", "ds1000", "mercury"}

        # Per-benchmark fetch limit; loop-invariant. The divisor counts all
        # matching benchmarks, including those skipped below.
        samples_per_benchmark = max(10, total_samples // len(matching_benchmarks))

        for benchmark_name in matching_benchmarks:
            # Skip benchmarks that require code execution for safety
            if benchmark_name in code_execution_benchmarks:
                logger.info(f"Skipping {benchmark_name} (requires code execution permission)")
                continue

            try:
                cached_samples = self.managed_cache.get_task_samples(
                    task_name=benchmark_name,
                    limit=samples_per_benchmark,
                    force_fresh=False
                )

                benchmark_tags = CORE_BENCHMARKS[benchmark_name].get("tags") or []

                # Convert to BenchmarkSample objects. Each sample gets its own copy
                # of the tags, so edits downstream cannot leak into CORE_BENCHMARKS.
                for sample in cached_samples:
                    all_samples.append(BenchmarkSample(
                        benchmark_name=benchmark_name,
                        sample_data=sample,
                        tags=list(benchmark_tags)
                    ))

                benchmark_sample_counts[benchmark_name] = len(cached_samples)

            except Exception as e:
                logger.warning(f"Failed to load samples from {benchmark_name}: {e}")
                continue

        if not all_samples:
            raise ValueError(f"No samples could be loaded from any benchmark with tags {tags}")

        logger.info(f"Collected {len(all_samples)} total samples from {len(benchmark_sample_counts)} benchmarks")
        for benchmark, count in benchmark_sample_counts.items():
            logger.debug(f" {benchmark}: {count} samples")

        # Select without replacement (optionally weighted), then shuffle
        if benchmark_weights:
            selected = self._weighted_sample_without_replacement(
                all_samples, benchmark_weights, total_samples, rng
            )
        elif len(all_samples) > total_samples:
            selected = rng.sample(all_samples, total_samples)
        else:
            selected = list(all_samples)

        if len(selected) < total_samples:
            # Fewer samples than requested: use all of them and log a warning
            logger.warning(f"Only {len(selected)} samples available, requested {total_samples}")

        rng.shuffle(selected)

        # Split into train/test
        split_point = int(len(selected) * split_ratio)
        train_samples = selected[:split_point]
        test_samples = selected[split_point:]

        # Log distribution
        train_dist = defaultdict(int)
        test_dist = defaultdict(int)

        for sample in train_samples:
            train_dist[sample.benchmark_name] += 1

        for sample in test_samples:
            test_dist[sample.benchmark_name] += 1

        logger.info(f"Train set: {len(train_samples)} samples from {len(train_dist)} benchmarks")
        logger.info(f"Test set: {len(test_samples)} samples from {len(test_dist)} benchmarks")

        return train_samples, test_samples

    def extract_contrastive_pairs_from_mixed_samples(
        self,
        samples: List[BenchmarkSample]
    ) -> List[Dict[str, Any]]:
        """
        Extract contrastive pairs from mixed benchmark samples.

        Each pair is a *copy* of the sample's ``sample_data["normalized"]``
        dict, annotated with ``source_benchmark`` and ``tags``. The cached
        sample data itself is never modified.

        Args:
            samples: List of BenchmarkSample objects

        Returns:
            List of contrastive pairs with question, correct_answer, incorrect_answer
        """
        required_keys = ("question", "correct_answer", "incorrect_answer")
        contrastive_pairs = []

        for sample in samples:
            try:
                # Each sample already has normalized QA pair from managed cache
                qa_pair = sample.sample_data.get("normalized", {})

                if qa_pair and all(k in qa_pair for k in required_keys):
                    # Copy before annotating: the same dict may be shared by the
                    # cache or by other samples.
                    pair = dict(qa_pair)
                    pair["source_benchmark"] = sample.benchmark_name
                    pair["tags"] = list(sample.tags)
                    contrastive_pairs.append(pair)
                else:
                    logger.warning(f"Invalid QA pair from {sample.benchmark_name}")

            except Exception as e:
                logger.warning(f"Failed to extract pair from {sample.benchmark_name}: {e}")
                continue

        logger.info(f"Extracted {len(contrastive_pairs)} contrastive pairs from mixed samples")

        return contrastive_pairs

    def create_mixed_contrastive_pair_set(
        self,
        tags: List[str],
        total_samples: int,
        name: Optional[str] = None,
        **kwargs
    ) -> ContrastivePairSet:
        """
        Create a ContrastivePairSet from mixed benchmark samples.

        Train and test samples from ``sample_mixed_dataset`` are merged back
        together, so ``split_ratio`` has no effect on the result.

        Args:
            tags: Tags to filter benchmarks
            total_samples: Number of samples to include
            name: Name for the pair set (auto-generated if not provided)
            **kwargs: Additional arguments for sample_mixed_dataset

        Returns:
            ContrastivePairSet ready for training
        """
        # Sample mixed dataset
        train_samples, test_samples = self.sample_mixed_dataset(
            tags=tags,
            total_samples=total_samples,
            **kwargs
        )

        # Extract contrastive pairs
        all_samples = train_samples + test_samples
        contrastive_pairs = self.extract_contrastive_pairs_from_mixed_samples(all_samples)

        # Create name if not provided
        if name is None:
            name = f"mixed_{'_'.join(tags)}_{total_samples}_samples"

        # Create ContrastivePairSet
        return ContrastivePairSet.from_contrastive_pairs(
            name=name,
            contrastive_pairs=contrastive_pairs,
            task_type="mixed_benchmark"
        )
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def sample_benchmarks_by_tag(
    tag: str,
    samples_per_benchmark: int = 10,
    max_benchmarks: Optional[int] = None,
    random_seed: Optional[int] = None
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Convenience function to sample from all benchmarks with a specific tag.

    Args:
        tag: Tag to filter benchmarks (e.g., "coding")
        samples_per_benchmark: Number of samples requested from each benchmark
        max_benchmarks: Maximum number of benchmarks to sample from; a falsy
            value (None or 0) means no limit
        random_seed: Seed for choosing which benchmarks to keep when
            ``max_benchmarks`` applies. A private ``random.Random`` is used, so
            the process-wide ``random`` state is not reseeded.

    Returns:
        Dictionary mapping benchmark names to their samples; benchmarks that
        fail to load are logged and omitted.
    """
    sampler = MixedBenchmarkSampler()

    # Get all benchmarks with the tag
    benchmarks = sampler.get_benchmarks_by_tag(tag)

    if max_benchmarks and len(benchmarks) > max_benchmarks:
        # Random(seed).sample yields the same choice as random.seed(seed) followed
        # by random.sample, but without touching the global RNG that unrelated
        # code relies on.
        rng = random.Random(random_seed) if random_seed is not None else random
        benchmarks = rng.sample(benchmarks, max_benchmarks)

    # Sample from each benchmark
    results = {}
    # NOTE(review): uses get_managed_cache()'s default cache dir, which may
    # differ from the sampler's "./benchmark_cache" — confirm this is intended.
    cache = get_managed_cache()

    for benchmark_name in benchmarks:
        try:
            samples = cache.get_task_samples(
                task_name=benchmark_name,
                limit=samples_per_benchmark,
                force_fresh=False
            )
            results[benchmark_name] = samples
            logger.info(f"Sampled {len(samples)} from {benchmark_name}")

        except Exception as e:
            logger.warning(f"Failed to sample from {benchmark_name}: {e}")
            continue

    return results
|
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model Configuration Manager for storing and retrieving optimal parameters per model.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
from typing import Dict, Any, Optional, List
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
import hashlib
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class NumpyEncoder(json.JSONEncoder):
    """Custom JSON encoder to handle numpy types.

    Converts numpy integer, floating and bool scalars to the matching Python
    scalar, and ndarrays to (nested) lists. Anything else falls through to the
    base encoder, which raises TypeError.
    """
    def default(self, obj):
        if isinstance(obj, np.integer):  # covers np.int64 and every other integer width
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            # numpy bools (e.g. results of comparisons stored in metrics) are not
            # JSON-serializable by default and would make json.dump fail.
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ModelConfigManager:
|
|
29
|
+
"""Manages model-specific configuration files for optimal parameters."""
|
|
30
|
+
|
|
31
|
+
def __init__(self, config_dir: Optional[str] = None):
    """
    Set up the manager and make sure its config directory exists.

    Args:
        config_dir: Directory holding the per-model JSON config files. When
            None, ``~/.wisent-guard/model_configs`` is used.
    """
    self.config_dir = (
        config_dir
        if config_dir is not None
        else os.path.join(os.path.expanduser("~"), ".wisent-guard", "model_configs")
    )

    # exist_ok=True makes this a no-op when the directory is already there.
    os.makedirs(self.config_dir, exist_ok=True)
|
|
47
|
+
|
|
48
|
+
def _sanitize_model_name(self, model_name: str) -> str:
    """
    Turn a model identifier into a string usable as a file name.

    Path and drive separators ("/", "\\", ":") become underscores; any other
    character that is neither alphanumeric nor one of "._-" is dropped.

    Args:
        model_name: Original model name (e.g., "meta-llama/Llama-3.1-8B-Instruct")

    Returns:
        Sanitized filename (e.g., "meta-llama_Llama-3.1-8B-Instruct")
    """
    separators = {"/", "\\", ":"}
    kept = []
    for ch in model_name:
        if ch in separators:
            kept.append("_")
        elif ch.isalnum() or ch in "._-":
            kept.append(ch)
    return "".join(kept)
|
|
63
|
+
|
|
64
|
+
def _get_config_path(self, model_name: str) -> str:
    """Return ``<config_dir>/<sanitized model name>.json`` for ``model_name``."""
    file_name = self._sanitize_model_name(model_name) + ".json"
    return os.path.join(self.config_dir, file_name)
|
|
68
|
+
|
|
69
|
+
def save_model_config(
    self,
    model_name: str,
    classification_layer: int,
    steering_layer: Optional[int] = None,
    token_aggregation: str = "average",
    detection_threshold: float = 0.6,
    optimization_method: str = "manual",
    optimization_metrics: Optional[Dict[str, Any]] = None,
    task_specific_overrides: Optional[Dict[str, Dict[str, Any]]] = None
) -> str:
    """
    Save optimal parameters for a model.

    The JSON is first written to ``<config path>.tmp`` and then moved over the
    real config file with ``os.replace``. A failure while serializing (e.g. a
    value in ``optimization_metrics`` that json cannot encode) therefore leaves
    any previously saved config intact instead of truncating it.

    Args:
        model_name: Name/path of the model
        classification_layer: Optimal layer for classification
        steering_layer: Optimal layer for steering (defaults to classification_layer)
        token_aggregation: Token aggregation method
        detection_threshold: Detection threshold
        optimization_method: How these parameters were determined
        optimization_metrics: Metrics from optimization process
        task_specific_overrides: Task-specific parameter overrides

    Returns:
        Path to the saved config file

    Raises:
        Exception: Any error while writing/serializing is logged and re-raised.
    """
    if steering_layer is None:
        steering_layer = classification_layer

    config_data = {
        "model_name": model_name,
        "created_date": datetime.now().isoformat(),
        "optimization_method": optimization_method,
        "optimal_parameters": {
            "classification_layer": classification_layer,
            "steering_layer": steering_layer,
            "token_aggregation": token_aggregation,
            "detection_threshold": detection_threshold
        },
        "task_specific_overrides": task_specific_overrides or {},
        "optimization_metrics": optimization_metrics or {},
        "config_version": "1.0"
    }

    config_path = self._get_config_path(model_name)
    # Sibling temp file (same directory, hence same filesystem) so os.replace is atomic.
    tmp_path = f"{config_path}.tmp"

    try:
        try:
            with open(tmp_path, 'w') as f:
                json.dump(config_data, f, indent=2, cls=NumpyEncoder)
            os.replace(tmp_path, config_path)
        except BaseException:
            # Discard the partial temp file; the existing config (if any) is untouched.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

        logger.info(f"✅ Model configuration saved: {config_path}")
        logger.info(f" • Classification layer: {classification_layer}")
        logger.info(f" • Steering layer: {steering_layer}")
        logger.info(f" • Token aggregation: {token_aggregation}")
        logger.info(f" • Detection threshold: {detection_threshold}")

        return config_path

    except Exception as e:
        logger.error(f"❌ Failed to save model configuration: {e}")
        raise
|
|
131
|
+
|
|
132
|
+
def load_model_config(self, model_name: str) -> Optional[Dict[str, Any]]:
    """
    Load optimal parameters for a model.

    Args:
        model_name: Name/path of the model

    Returns:
        Configuration dictionary if found, None otherwise. None is also
        returned (with a warning) when the file cannot be read, is not
        valid JSON, or its top-level value is not a JSON object.
    """
    config_path = self._get_config_path(model_name)

    # EAFP: open directly instead of checking existence first, which avoids
    # a race between the check and the open. A missing file is the normal
    # "no saved config" case and is not worth a warning.
    try:
        # JSON text is UTF-8 by spec; don't depend on the locale default.
        with open(config_path, 'r', encoding='utf-8') as f:
            config_data = json.load(f)
    except FileNotFoundError:
        return None
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning(f"⚠️ Failed to load model configuration: {e}")
        return None

    # Callers use dict methods (.get, [...]) on the result, so reject
    # well-formed JSON whose top level is not an object.
    if not isinstance(config_data, dict):
        logger.warning(
            f"⚠️ Failed to load model configuration: expected a JSON object in {config_path}"
        )
        return None

    logger.debug(f"📄 Loaded model configuration: {config_path}")
    return config_data
|
|
157
|
+
|
|
158
|
+
def has_model_config(self, model_name: str) -> bool:
    """Report whether a configuration file is present for ``model_name``.

    Only the existence of the resolved config path is checked; the file's
    contents are not read or validated.
    """
    return os.path.exists(self._get_config_path(model_name))
|
|
162
|
+
|
|
163
|
+
def update_model_config(self, model_name: str, config_data: Dict[str, Any]) -> str:
    """
    Update an existing model configuration.

    ``config_data["updated_date"]`` is set to the current time (the
    caller's dictionary is modified in place) and the whole dictionary is
    then written as JSON to the model's config path.

    Args:
        model_name: Name/path of the model
        config_data: Updated configuration dictionary

    Returns:
        Path to the saved config file

    Raises:
        Exception: Any serialization or I/O error is logged and re-raised.
            Serialization happens before the file is opened, so a value
            that cannot be encoded leaves the existing file untouched.
    """
    config_path = self._get_config_path(model_name)

    # Update timestamp
    config_data["updated_date"] = datetime.now().isoformat()

    try:
        # Encode first: opening with 'w' truncates the file immediately, so
        # dumping straight into it would wipe the previous configuration if
        # encoding fails partway through (e.g. an unsupported value type).
        payload = json.dumps(config_data, indent=2, cls=NumpyEncoder)
        with open(config_path, 'w') as f:
            f.write(payload)

        logger.info(f"✅ Model configuration updated: {config_path}")
        return config_path

    except Exception as e:
        logger.error(f"❌ Failed to update model configuration: {e}")
        raise
|
|
189
|
+
|
|
190
|
+
def get_optimal_parameters(
    self,
    model_name: str,
    task_name: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """
    Get optimal parameters for a model, with optional task-specific overrides.

    The base ``optimal_parameters`` section is copied (the loaded config is
    not modified). When ``task_name`` is given, that task's entry from
    ``task_specific_overrides`` is merged on top of the copy.

    Args:
        model_name: Name/path of the model
        task_name: Specific task name for overrides

    Returns:
        Dictionary of optimal parameters or None if no config exists
    """
    config = self.load_model_config(model_name)
    if not config:
        return None

    # Start with a copy of the base parameters. ``or {}`` also covers a key
    # that is present but null in the JSON file, which ``.get(key, {})``
    # alone would not.
    optimal_params = dict(config.get("optimal_parameters") or {})

    # Apply task-specific overrides if available
    if task_name:
        all_overrides = config.get("task_specific_overrides") or {}
        optimal_params.update(all_overrides.get(task_name) or {})

    return optimal_params
|
|
218
|
+
|
|
219
|
+
def get_optimal_sample_size(
    self,
    model_name: str,
    task_name: str,
    layer: int
) -> Optional[int]:
    """
    Get optimal sample size for a specific task and layer.

    Reads ``optimal_sample_sizes[task_name][str(layer)]`` from the loaded
    model configuration.

    Args:
        model_name: Name/path of the model
        task_name: Task name
        layer: Layer index

    Returns:
        Optimal sample size or None if not found (including when any level
        of the nested section is missing or null)
    """
    config = self.load_model_config(model_name)
    if not config:
        return None

    # Navigate the nested structure: optimal_sample_sizes[task][layer].
    # ``or {}`` tolerates sections that are present but null in the file.
    sizes_by_task = config.get("optimal_sample_sizes") or {}
    task_sizes = sizes_by_task.get(task_name) or {}

    # JSON object keys are always strings, so look the layer up as a string.
    return task_sizes.get(str(layer))
|
|
249
|
+
|
|
250
|
+
def list_model_configs(self) -> List[Dict[str, Any]]:
    """
    List all available model configurations.

    Every ``*.json`` file directly inside ``self.config_dir`` is read and
    summarised. Files that cannot be read or parsed are logged with a
    warning and skipped.

    Returns:
        List of configuration summaries, one per readable file, ordered by
        filename. Empty if the config directory does not exist.
    """
    configs: List[Dict[str, Any]] = []

    if not os.path.exists(self.config_dir):
        return configs

    # os.listdir order is arbitrary; sort so the listing is reproducible.
    for filename in sorted(os.listdir(self.config_dir)):
        if not filename.endswith('.json'):
            continue

        config_path = os.path.join(self.config_dir, filename)
        try:
            # JSON text is UTF-8 by spec; don't depend on the locale default.
            with open(config_path, 'r', encoding='utf-8') as f:
                config_data = json.load(f)

            # ``or {}`` keeps a file whose optimal_parameters is null from
            # being reported as unreadable.
            optimal = config_data.get("optimal_parameters") or {}
            configs.append({
                "model_name": config_data.get("model_name", "unknown"),
                "created_date": config_data.get("created_date", "unknown"),
                "optimization_method": config_data.get("optimization_method", "unknown"),
                "classification_layer": optimal.get("classification_layer"),
                "steering_layer": optimal.get("steering_layer"),
                "config_file": filename
            })

        except Exception as e:
            # Best-effort listing: one bad file must not hide the others.
            logger.warning(f"⚠️ Failed to read config file {filename}: {e}")

    return configs
|
|
283
|
+
|
|
284
|
+
def remove_model_config(self, model_name: str) -> bool:
    """
    Remove a model configuration.

    Args:
        model_name: Name/path of the model

    Returns:
        True if removed successfully, False otherwise. A missing config is
        logged as a warning; any other OS error is logged as an error.
    """
    config_path = self._get_config_path(model_name)

    # EAFP: attempt the removal directly instead of checking existence
    # first, so a file deleted concurrently is still reported as "not found"
    # rather than as a failed removal.
    try:
        os.remove(config_path)
    except FileNotFoundError:
        logger.warning(f"⚠️ No configuration found for model: {model_name}")
        return False
    except OSError as e:
        logger.error(f"❌ Failed to remove model configuration: {e}")
        return False

    logger.info(f"✅ Removed model configuration: {config_path}")
    return True
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
# Convenience functions for easy access.
# Shared manager instance, created lazily by get_default_manager().
_default_manager = None

def get_default_manager() -> ModelConfigManager:
    """Return the shared ModelConfigManager, constructing it on first call."""
    global _default_manager
    if _default_manager is not None:
        return _default_manager
    _default_manager = ModelConfigManager()
    return _default_manager
|
|
319
|
+
|
|
320
|
+
def save_model_config(model_name: str, **kwargs) -> str:
    """Persist a model configuration through the shared default manager.

    Keyword arguments are forwarded unchanged to the manager's
    ``save_model_config``; its return value (the config path) is passed back.
    """
    manager = get_default_manager()
    return manager.save_model_config(model_name, **kwargs)
|
|
323
|
+
|
|
324
|
+
def load_model_config(model_name: str) -> Optional[Dict[str, Any]]:
    """Read a model configuration through the shared default manager.

    Returns whatever the manager's ``load_model_config`` returns (a dict, or
    None when no configuration is available).
    """
    manager = get_default_manager()
    return manager.load_model_config(model_name)
|
|
327
|
+
|
|
328
|
+
def get_optimal_parameters(model_name: str, task_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Look up a model's optimal parameters through the shared default manager.

    ``task_name``, when given, selects task-specific overrides; both
    arguments are forwarded positionally to the manager method.
    """
    manager = get_default_manager()
    return manager.get_optimal_parameters(model_name, task_name)
|