wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of wisent might be problematic.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.1.dist-info/METADATA +67 -0
- wisent-0.5.1.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
wisent/core/optuna/classifier/optuna_classifier_optimizer.py
@@ -0,0 +1,606 @@
+"""
+Optuna-based classifier optimization for efficient hyperparameter search.
+
+This module provides a modern, efficient optimization system that pre-generates
+activations once and uses intelligent caching to avoid redundant training.
+"""
+
+import logging
+import time
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import numpy as np
+import optuna
+import torch
+from optuna.pruners import MedianPruner
+from optuna.samplers import TPESampler
+
+from wisent_guard.core.classifier.classifier import Classifier
+from wisent_guard.core.utils.device import resolve_default_device
+
+from .activation_generator import ActivationData, ActivationGenerator, GenerationConfig
+from .classifier_cache import CacheConfig, ClassifierCache
+
+
+def get_model_dtype(model) -> torch.dtype:
+    """
+    Extract model's native dtype from parameters.
+
+    Args:
+        model: PyTorch model or wisent_guard Model wrapper
+
+    Returns:
+        The model's native dtype
+    """
+    # Handle wisent_guard Model wrapper
+    if hasattr(model, "hf_model"):
+        model_params = model.hf_model.parameters()
+    else:
+        model_params = model.parameters()
+
+    try:
+        return next(model_params).dtype
+    except StopIteration:
+        # Fallback if no parameters found
+        return torch.float32
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ClassifierOptimizationConfig:
+    """Configuration for Optuna classifier optimization."""
+
+    # Model configuration
+    model_name: str = "Qwen/Qwen3-0.6B"
+    device: str = "auto"  # "auto", "cuda", "cpu", "mps"
+    model_dtype: Optional[torch.dtype] = None  # Auto-detect if None
+
+    # Optuna settings
+    n_trials: int = 100
+    timeout: Optional[float] = None
+    n_jobs: int = 1
+    sampler_seed: int = 42
+
+    # Model type search space
+    model_types: list[str] = None
+
+    # Hyperparameter ranges
+    hidden_dim_range: tuple[int, int] = (32, 512)
+    threshold_range: tuple[float, float] = (0.3, 0.9)
+
+    # Training settings
+    num_epochs_range: tuple[int, int] = (20, 100)
+    learning_rate_range: tuple[float, float] = (1e-4, 1e-2)
+    batch_size_options: list[int] = None
+
+    # Evaluation settings
+    cv_folds: int = 3
+    test_size: float = 0.2
+    random_state: int = 42
+
+    # Optimization objective
+    primary_metric: str = "f1"  # "accuracy", "f1", "auc", "precision", "recall"
+
+    # Pruning settings
+    enable_pruning: bool = True
+    pruning_patience: int = 10
+
+    def __post_init__(self):
+        if self.model_types is None:
+            self.model_types = ["logistic", "mlp"]
+        if self.batch_size_options is None:
+            self.batch_size_options = [16, 32, 64]
+
+        # Auto-detect device if needed
+        if self.device == "auto":
+            self.device = resolve_default_device()
+
+
+@dataclass
+class OptimizationResult:
+    """Result from Optuna optimization."""
+
+    best_params: dict[str, Any]
+    best_value: float
+    best_classifier: Classifier
+    study: optuna.Study
+    trial_results: list[dict[str, Any]]
+    optimization_time: float
+    cache_hits: int
+    cache_misses: int
+
+    def get_best_config(self) -> dict[str, Any]:
+        """Get the best configuration found."""
+        if not self.best_params:
+            return {
+                "model_type": "unknown",
+                "layer": -1,
+                "aggregation": "unknown",
+                "threshold": 0.0,
+                "hyperparameters": {},
+            }
+
+        return {
+            "model_type": self.best_params["model_type"],
+            "layer": self.best_params["layer"],
+            "aggregation": self.best_params["aggregation"],
+            "threshold": self.best_params["threshold"],
+            "hyperparameters": {
+                k: v
+                for k, v in self.best_params.items()
+                if k not in ["model_type", "layer", "aggregation", "threshold"]
+            },
+        }
+
+
+class OptunaClassifierOptimizer:
+    """
+    Optuna-based classifier optimizer with efficient caching and pre-generation.
+
+    Key features:
+    - Pre-generates activations once for all trials
+    - Uses intelligent model caching to avoid retraining
+    - Supports both logistic and MLP classifiers
+    - Multi-objective optimization with pruning
+    - Cross-validation for robust evaluation
+    """
+
+    def __init__(
+        self,
+        optimization_config: ClassifierOptimizationConfig,
+        generation_config: GenerationConfig,
+        cache_config: CacheConfig,
+    ):
+        self.opt_config = optimization_config
+        self.gen_config = generation_config
+        self.cache_config = cache_config
+
+        self.activation_generator = ActivationGenerator(generation_config)
+        self.classifier_cache = ClassifierCache(cache_config)
+
+        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
+
+        # Statistics tracking
+        self.cache_hits = 0
+        self.cache_misses = 0
+        self.activation_data: dict[str, ActivationData] = {}
+
+    def optimize(
+        self, model, contrastive_pairs: list, task_name: str, model_name: str, limit: int
+    ) -> OptimizationResult:
+        """
+        Run Optuna-based classifier optimization.
+
+        Args:
+            model: Language model
+            contrastive_pairs: Training contrastive pairs
+            task_name: Name of the task
+            model_name: Name of the model
+            limit: Data limit used
+
+        Returns:
+            OptimizationResult with best configuration and classifier
+        """
+        self.logger.info(f"Starting Optuna classifier optimization for {task_name}")
+        layer_range = self.gen_config.layer_search_range[1] - self.gen_config.layer_search_range[0] + 1
+        self.logger.info(
+            f"Configuration: {self.opt_config.n_trials} trials, layers {self.gen_config.layer_search_range[0]}-{self.gen_config.layer_search_range[1]} ({layer_range} layers)"
+        )
+
+        # Detect or use configured model dtype
+        detected_dtype = get_model_dtype(model)
+        self.model_dtype = self.opt_config.model_dtype if self.opt_config.model_dtype is not None else detected_dtype
+        self.logger.info(f"Using model dtype: {self.model_dtype} (detected: {detected_dtype})")
+
+        start_time = time.time()
+
+        # Step 1: Pre-generate all activations
+        self.logger.info("Pre-generating activations for all layers and aggregation methods...")
+        self.activation_data = self.activation_generator.generate_from_contrastive_pairs(
+            model=model, contrastive_pairs=contrastive_pairs, task_name=task_name, model_name=model_name, limit=limit
+        )
+
+        if not self.activation_data:
+            raise ValueError("No activation data generated - cannot proceed with optimization")
+
+        self.logger.info(f"Generated {len(self.activation_data)} activation datasets")
+
+        # Step 2: Set up Optuna study
+        sampler = TPESampler(seed=self.opt_config.sampler_seed)
+        pruner = (
+            MedianPruner(n_startup_trials=5, n_warmup_steps=self.opt_config.pruning_patience)
+            if self.opt_config.enable_pruning
+            else None
+        )
+
+        study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner)
+
+        # Step 3: Run optimization
+        self.logger.info("Starting Optuna trials...")
+
+        def objective(trial):
+            return self._objective_function(trial, task_name, model_name)
+
+        study.optimize(
+            objective,
+            n_trials=self.opt_config.n_trials,
+            timeout=self.opt_config.timeout,
+            n_jobs=self.opt_config.n_jobs,
+            show_progress_bar=True,
+        )
+
+        # Step 4: Get best results
+        completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
+
+        if not completed_trials:
+            self.logger.warning("No trials completed successfully - all trials were pruned or failed")
+            # Show trial states for debugging
+            trial_states = {}
+            for trial in study.trials:
+                state = trial.state.name
+                trial_states[state] = trial_states.get(state, 0) + 1
+            self.logger.warning(f"Trial states: {trial_states}")
+
+            # Return a dummy result for debugging
+            dummy_result = OptimizationResult(
+                best_params={},
+                best_value=0.0,
+                best_classifier=None,
+                study=study,
+                trial_results=[],
+                optimization_time=time.time() - start_time,
+                cache_hits=self.cache_hits,
+                cache_misses=self.cache_misses,
+            )
+            return dummy_result
+
+        best_params = study.best_params
+        best_value = study.best_value
+
+        self.logger.info(f"Best trial: {best_params} -> {self.opt_config.primary_metric}={best_value:.4f}")
+
+        # Step 5: Train final model with best parameters
+        best_classifier = self._train_final_classifier(best_params, task_name, model_name)
+
+        optimization_time = time.time() - start_time
+
+        # Step 6: Collect trial results
+        trial_results = []
+        for trial in study.trials:
+            if trial.state == optuna.trial.TrialState.COMPLETE:
+                trial_results.append(
+                    {
+                        "trial_number": trial.number,
+                        "params": trial.params,
+                        "value": trial.value,
+                        "duration": trial.duration.total_seconds() if trial.duration else None,
+                    }
+                )
+
+        result = OptimizationResult(
+            best_params=best_params,
+            best_value=best_value,
+            best_classifier=best_classifier,
+            study=study,
+            trial_results=trial_results,
+            optimization_time=optimization_time,
+            cache_hits=self.cache_hits,
+            cache_misses=self.cache_misses,
+        )
+
+        self.logger.info(
+            f"Optimization completed in {optimization_time:.1f}s "
+            f"({self.cache_hits} cache hits, {self.cache_misses} cache misses)"
+        )
+
+        return result
+
+    def _objective_function(self, trial: optuna.Trial, task_name: str, model_name: str) -> float:
+        """
+        Optuna objective function for a single trial.
+
+        Args:
+            trial: Optuna trial object
+            task_name: Task name
+            model_name: Model name
+
+        Returns:
+            Objective value to maximize
+        """
+        # Sample hyperparameters directly (following steering pattern)
+        model_type = trial.suggest_categorical("model_type", self.opt_config.model_types)
+
+        # Layer and aggregation from pre-generated activation data
+        available_layers = set()
+        available_aggregations = set()
+
+        for key in self.activation_data.keys():
+            parts = key.split("_")
+            if len(parts) >= 4:  # layer_X_agg_Y
+                layer = int(parts[1])
+                agg = parts[3]
+                available_layers.add(layer)
+                available_aggregations.add(agg)
+
+        layer = trial.suggest_categorical("layer", sorted(available_layers))
+        aggregation = trial.suggest_categorical("aggregation", sorted(available_aggregations))
+
+        # Classification threshold
+        threshold = trial.suggest_float(
+            "threshold", self.opt_config.threshold_range[0], self.opt_config.threshold_range[1]
+        )
+
+        # Training hyperparameters
+        num_epochs = trial.suggest_int(
+            "num_epochs", self.opt_config.num_epochs_range[0], self.opt_config.num_epochs_range[1]
+        )
+
+        learning_rate = trial.suggest_float(
+            "learning_rate", self.opt_config.learning_rate_range[0], self.opt_config.learning_rate_range[1], log=True
+        )
+
+        batch_size = trial.suggest_categorical("batch_size", self.opt_config.batch_size_options)
+
+        # Model-specific hyperparameters (conditional logic like steering)
+        hyperparams = {"num_epochs": num_epochs, "learning_rate": learning_rate, "batch_size": batch_size}
+
+        if model_type == "mlp":
+            # MLP-specific parameters
+            hyperparams["hidden_dim"] = trial.suggest_int(
+                "hidden_dim", self.opt_config.hidden_dim_range[0], self.opt_config.hidden_dim_range[1], step=32
+            )
+
+        # Combine all parameters
+        params = {
+            "model_type": model_type,
+            "layer": layer,
+            "aggregation": aggregation,
+            "threshold": threshold,
+            **hyperparams,
+        }
+
+        # Get activation data for this configuration
+        activation_key = f"layer_{params['layer']}_agg_{params['aggregation']}"
+
+        if activation_key not in self.activation_data:
+            self.logger.warning(f"No activation data for {activation_key}")
+            raise optuna.TrialPruned()
+
+        activation_data = self.activation_data[activation_key]
+        X, y = activation_data.to_tensors(device=self.gen_config.device, dtype=self.model_dtype)
+        print(f"DEBUG: Training data shape: X.shape={X.shape}, y.shape={y.shape}, dtype={X.dtype}")
+
+        # Generate cache key
+        data_hash = self.classifier_cache.compute_data_hash(X, y)
+        cache_key = self.classifier_cache.get_cache_key(
+            model_name=model_name,
+            task_name=task_name,
+            model_type=params["model_type"],
+            layer=params["layer"],
+            aggregation=params["aggregation"],
+            threshold=params["threshold"],
+            hyperparameters={
+                k: v for k, v in params.items() if k not in ["model_type", "layer", "aggregation", "threshold"]
+            },
+            data_hash=data_hash,
+        )
+
+        # Try to load from cache
+        cached_classifier = self.classifier_cache.load_classifier(cache_key)
+        if cached_classifier is not None:
+            self.cache_hits += 1
+            # Evaluate cached classifier
+            return self._evaluate_classifier(cached_classifier, X, y, params["threshold"])
+
+        self.cache_misses += 1
+
+        # Train new classifier
+        classifier = self._train_classifier(params, X, y, trial)
+
+        if classifier is None:
+            raise optuna.TrialPruned()
+
+        # Evaluate classifier
+        score = self._evaluate_classifier(classifier, X, y, params["threshold"])
+
+        # Save to cache if training was successful
+        if score > 0:
+            try:
+                performance_metrics = {self.opt_config.primary_metric: score}
+
+                self.classifier_cache.save_classifier(
+                    cache_key=cache_key,
+                    classifier=classifier,
+                    model_name=model_name,
+                    task_name=task_name,
+                    layer=params["layer"],
+                    aggregation=params["aggregation"],
+                    threshold=params["threshold"],
+                    hyperparameters={
+                        k: v for k, v in params.items() if k not in ["model_type", "layer", "aggregation", "threshold"]
+                    },
+                    performance_metrics=performance_metrics,
+                    training_samples=len(X),
+                    data_hash=data_hash,
+                )
+            except Exception as e:
+                self.logger.warning(f"Failed to cache classifier: {e}")
+
+        return score
+
+    def _train_classifier(
+        self, params: dict[str, Any], X: np.ndarray, y: np.ndarray, trial: Optional[optuna.Trial] = None
+    ) -> Optional[Classifier]:
+        """
+        Train a classifier with the given parameters.
+
+        Args:
+            params: Hyperparameters
+            X: Training features
+            y: Training labels
+            trial: Optuna trial for pruning
+
+        Returns:
+            Trained classifier or None if training failed
+        """
+        try:
+            # Create classifier (don't pass hidden_dim to constructor)
+            classifier_kwargs = {
+                "model_type": params["model_type"],
+                "threshold": params["threshold"],
+                "device": self.gen_config.device if self.gen_config.device else "auto",
+                "dtype": self.model_dtype,
+            }
+
+            print(
+                f"Preparing to train {params['model_type']} classifier with {len(X)} samples (dtype: {self.model_dtype})"
+            )
+            classifier = Classifier(**classifier_kwargs)
+
+            # Train classifier
+            training_kwargs = {
+                "num_epochs": params["num_epochs"],
+                "learning_rate": params["learning_rate"],
+                "batch_size": params["batch_size"],
+                "test_size": self.opt_config.test_size,
+                "random_state": self.opt_config.random_state,
+            }
+
+            if params["model_type"] == "mlp":
+                training_kwargs["hidden_dim"] = params["hidden_dim"]
+
+            # Add pruning callback if trial is provided
+            if trial and self.opt_config.enable_pruning:
+                # TODO: Implement pruning callback for early stopping
+                pass
+
+            print(f"About to fit classifier with kwargs: {training_kwargs}")
+            results = classifier.fit(X, y, **training_kwargs)
+            print(f"Training results: {results}")
+
+            accuracy = results.get("accuracy", 0)
+            if accuracy <= 0.35:  # More permissive threshold - only prune very poor performance
+                self.logger.debug(f"Classifier performance too low ({accuracy:.3f}), pruning")
+                print(f"Classifier pruned - accuracy too low: {accuracy:.3f}")
+                return None
+
+            self.logger.debug(f"Classifier training successful - accuracy: {accuracy:.3f}")
+            print(f"Classifier training successful - accuracy: {accuracy:.3f}")
+
+            return classifier
+
+        except Exception as e:
+            print(f"EXCEPTION during classifier training: {e}")
+            import traceback
+
+            traceback.print_exc()
+            self.logger.debug(f"Training failed with params {params}: {e}")
+            return None
+
+    def _evaluate_classifier(self, classifier: Classifier, X: np.ndarray, y: np.ndarray, threshold: float) -> float:
+        """
+        Evaluate classifier performance.
+
+        Args:
+            classifier: Trained classifier
+            X: Features
+            y: Labels
+            threshold: Classification threshold
+
+        Returns:
+            Performance score based on primary metric
+        """
+        try:
+            print(f"DEBUG: Evaluation data shape: X.shape={X.shape}, y.shape={y.shape}, dtype={X.dtype}")
+
+            # Set threshold
+            classifier.set_threshold(threshold)
+
+            # Get predictions
+            results = classifier.evaluate(X, y)
+            print(f"Evaluation results: {results}")
+            print(f"Looking for primary metric '{self.opt_config.primary_metric}' in results")
+
+            # Return primary metric
+            score = results.get(self.opt_config.primary_metric, 0.0)
+            print(f"Score extracted: {score}")
+            return float(score)
+
+        except Exception as e:
+            print(f"EXCEPTION during evaluation: {e}")
+            import traceback
+
+            traceback.print_exc()
+            self.logger.debug(f"Evaluation failed: {e}")
+            return 0.0
+
+    def _train_final_classifier(self, best_params: dict[str, Any], task_name: str, model_name: str) -> Classifier:
+        """Train the final classifier with best parameters."""
+        # Get activation data
+        activation_key = f"layer_{best_params['layer']}_agg_{best_params['aggregation']}"
+        activation_data = self.activation_data[activation_key]
+        X, y = activation_data.to_tensors(device=self.gen_config.device, dtype=self.model_dtype)
+
+        # Try cache first
+        data_hash = self.classifier_cache.compute_data_hash(X, y)
+        cache_key = self.classifier_cache.get_cache_key(
+            model_name=model_name,
+            task_name=task_name,
+            model_type=best_params["model_type"],
+            layer=best_params["layer"],
+            aggregation=best_params["aggregation"],
+            threshold=best_params["threshold"],
+            hyperparameters={
+                k: v for k, v in best_params.items() if k not in ["model_type", "layer", "aggregation", "threshold"]
+            },
+            data_hash=data_hash,
+        )
+
+        cached_classifier = self.classifier_cache.load_classifier(cache_key)
+        if cached_classifier is not None:
+            self.logger.info("Using cached classifier for final model")
+            return cached_classifier
+
+        # Train new classifier
+        self.logger.info("Training final classifier with best parameters")
+        classifier = self._train_classifier(best_params, X, y)
+
+        if classifier is None:
+            raise ValueError("Failed to train final classifier")
+
+        return classifier
+
+    def get_optimization_summary(self, result: OptimizationResult) -> dict[str, Any]:
+        """Get a comprehensive optimization summary."""
+        return {
+            "best_configuration": result.get_best_config(),
+            "best_score": result.best_value,
+            "optimization_time_seconds": result.optimization_time,
+            "total_trials": len(result.trial_results),
+            "cache_efficiency": {
+                "hits": result.cache_hits,
+                "misses": result.cache_misses,
+                "hit_rate": result.cache_hits / (result.cache_hits + result.cache_misses)
+                if (result.cache_hits + result.cache_misses) > 0
+                else 0,
+            },
+            "activation_data_info": {
+                key: {
+                    "samples": data.activations.shape[0],
+                    "features": data.activations.shape[1]
+                    if len(data.activations.shape) > 1
+                    else data.activations.shape[0],
+                    "layer": data.layer,
+                    "aggregation": data.aggregation,
+                }
+                for key, data in self.activation_data.items()
+            },
+            "study_info": {
+                "n_trials": len(result.study.trials),
+                "best_trial": result.study.best_trial.number,
+                "pruned_trials": len([t for t in result.study.trials if t.state == optuna.trial.TrialState.PRUNED]),
+            },
+        }
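For orientation, below is a minimal usage sketch of the new optimizer, based only on the signatures visible in the diff above. Note that the module imports Classifier and resolve_default_device from wisent_guard, so that package must also be importable. The GenerationConfig and CacheConfig constructor arguments shown here are assumptions (only layer_search_range and device are visibly read in the diff); consult activation_generator.py and classifier_cache.py in this release for the actual fields, and prepare model and pairs with the package's own loaders.

    # Hypothetical usage sketch; not part of the package.
    from wisent.core.optuna.classifier.activation_generator import GenerationConfig
    from wisent.core.optuna.classifier.classifier_cache import CacheConfig
    from wisent.core.optuna.classifier.optuna_classifier_optimizer import (
        ClassifierOptimizationConfig,
        OptunaClassifierOptimizer,
    )

    opt_config = ClassifierOptimizationConfig(n_trials=50, primary_metric="f1")
    gen_config = GenerationConfig(layer_search_range=(4, 12), device="auto")  # assumed fields
    cache_config = CacheConfig()  # constructor arguments unknown; assumed default-constructible

    optimizer = OptunaClassifierOptimizer(opt_config, gen_config, cache_config)
    result = optimizer.optimize(
        model=model,              # a language model or wisent_guard Model wrapper, loaded elsewhere
        contrastive_pairs=pairs,  # contrastive pairs prepared elsewhere
        task_name="winogrande",
        model_name="Qwen/Qwen3-0.6B",
        limit=500,
    )
    print(result.get_best_config())
    print(optimizer.get_optimization_summary(result))

The flow mirrors the six steps in optimize(): activations are generated once up front, each Optuna trial then only trains (or loads from cache) a small classifier head, and the best parameters are re-fit into the final classifier returned on the result object.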