wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of wisent might be problematic.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.1.dist-info/METADATA +67 -0
  215. wisent-0.5.1.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
wisent/core/optuna/classifier/optuna_classifier_optimizer.py
@@ -0,0 +1,606 @@
+ """
+ Optuna-based classifier optimization for efficient hyperparameter search.
+
+ This module provides a modern, efficient optimization system that pre-generates
+ activations once and uses intelligent caching to avoid redundant training.
+ """
+
+ import logging
+ import time
+ from dataclasses import dataclass
+ from typing import Any, Optional
+
+ import numpy as np
+ import optuna
+ import torch
+ from optuna.pruners import MedianPruner
+ from optuna.samplers import TPESampler
+
+ from wisent_guard.core.classifier.classifier import Classifier
+ from wisent_guard.core.utils.device import resolve_default_device
+
+ from .activation_generator import ActivationData, ActivationGenerator, GenerationConfig
+ from .classifier_cache import CacheConfig, ClassifierCache
+
+
+ def get_model_dtype(model) -> torch.dtype:
+     """
+     Extract model's native dtype from parameters.
+
+     Args:
+         model: PyTorch model or wisent_guard Model wrapper
+
+     Returns:
+         The model's native dtype
+     """
+     # Handle wisent_guard Model wrapper
+     if hasattr(model, "hf_model"):
+         model_params = model.hf_model.parameters()
+     else:
+         model_params = model.parameters()
+
+     try:
+         return next(model_params).dtype
+     except StopIteration:
+         # Fallback if no parameters found
+         return torch.float32
+
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class ClassifierOptimizationConfig:
+     """Configuration for Optuna classifier optimization."""
+
+     # Model configuration
+     model_name: str = "Qwen/Qwen3-0.6B"
+     device: str = "auto"  # "auto", "cuda", "cpu", "mps"
+     model_dtype: Optional[torch.dtype] = None  # Auto-detect if None
+
+     # Optuna settings
+     n_trials: int = 100
+     timeout: Optional[float] = None
+     n_jobs: int = 1
+     sampler_seed: int = 42
+
+     # Model type search space
+     model_types: list[str] = None
+
+     # Hyperparameter ranges
+     hidden_dim_range: tuple[int, int] = (32, 512)
+     threshold_range: tuple[float, float] = (0.3, 0.9)
+
+     # Training settings
+     num_epochs_range: tuple[int, int] = (20, 100)
+     learning_rate_range: tuple[float, float] = (1e-4, 1e-2)
+     batch_size_options: list[int] = None
+
+     # Evaluation settings
+     cv_folds: int = 3
+     test_size: float = 0.2
+     random_state: int = 42
+
+     # Optimization objective
+     primary_metric: str = "f1"  # "accuracy", "f1", "auc", "precision", "recall"
+
+     # Pruning settings
+     enable_pruning: bool = True
+     pruning_patience: int = 10
+
+     def __post_init__(self):
+         if self.model_types is None:
+             self.model_types = ["logistic", "mlp"]
+         if self.batch_size_options is None:
+             self.batch_size_options = [16, 32, 64]
+
+         # Auto-detect device if needed
+         if self.device == "auto":
+             self.device = resolve_default_device()
+
+
+ @dataclass
+ class OptimizationResult:
+     """Result from Optuna optimization."""
+
+     best_params: dict[str, Any]
+     best_value: float
+     best_classifier: Classifier
+     study: optuna.Study
+     trial_results: list[dict[str, Any]]
+     optimization_time: float
+     cache_hits: int
+     cache_misses: int
+
+     def get_best_config(self) -> dict[str, Any]:
+         """Get the best configuration found."""
+         if not self.best_params:
+             return {
+                 "model_type": "unknown",
+                 "layer": -1,
+                 "aggregation": "unknown",
+                 "threshold": 0.0,
+                 "hyperparameters": {},
+             }
+
+         return {
+             "model_type": self.best_params["model_type"],
+             "layer": self.best_params["layer"],
+             "aggregation": self.best_params["aggregation"],
+             "threshold": self.best_params["threshold"],
+             "hyperparameters": {
+                 k: v
+                 for k, v in self.best_params.items()
+                 if k not in ["model_type", "layer", "aggregation", "threshold"]
+             },
+         }
+
+
+ class OptunaClassifierOptimizer:
+     """
+     Optuna-based classifier optimizer with efficient caching and pre-generation.
+
+     Key features:
+     - Pre-generates activations once for all trials
+     - Uses intelligent model caching to avoid retraining
+     - Supports both logistic and MLP classifiers
+     - Multi-objective optimization with pruning
+     - Cross-validation for robust evaluation
+     """
+
+     def __init__(
+         self,
+         optimization_config: ClassifierOptimizationConfig,
+         generation_config: GenerationConfig,
+         cache_config: CacheConfig,
+     ):
+         self.opt_config = optimization_config
+         self.gen_config = generation_config
+         self.cache_config = cache_config
+
+         self.activation_generator = ActivationGenerator(generation_config)
+         self.classifier_cache = ClassifierCache(cache_config)
+
+         self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
+
+         # Statistics tracking
+         self.cache_hits = 0
+         self.cache_misses = 0
+         self.activation_data: dict[str, ActivationData] = {}
+
+     def optimize(
+         self, model, contrastive_pairs: list, task_name: str, model_name: str, limit: int
+     ) -> OptimizationResult:
+         """
+         Run Optuna-based classifier optimization.
+
+         Args:
+             model: Language model
+             contrastive_pairs: Training contrastive pairs
+             task_name: Name of the task
+             model_name: Name of the model
+             limit: Data limit used
+
+         Returns:
+             OptimizationResult with best configuration and classifier
+         """
+         self.logger.info(f"Starting Optuna classifier optimization for {task_name}")
+         layer_range = self.gen_config.layer_search_range[1] - self.gen_config.layer_search_range[0] + 1
+         self.logger.info(
+             f"Configuration: {self.opt_config.n_trials} trials, layers {self.gen_config.layer_search_range[0]}-{self.gen_config.layer_search_range[1]} ({layer_range} layers)"
+         )
+
+         # Detect or use configured model dtype
+         detected_dtype = get_model_dtype(model)
+         self.model_dtype = self.opt_config.model_dtype if self.opt_config.model_dtype is not None else detected_dtype
+         self.logger.info(f"Using model dtype: {self.model_dtype} (detected: {detected_dtype})")
+
+         start_time = time.time()
+
+         # Step 1: Pre-generate all activations
+         self.logger.info("Pre-generating activations for all layers and aggregation methods...")
+         self.activation_data = self.activation_generator.generate_from_contrastive_pairs(
+             model=model, contrastive_pairs=contrastive_pairs, task_name=task_name, model_name=model_name, limit=limit
+         )
+
+         if not self.activation_data:
+             raise ValueError("No activation data generated - cannot proceed with optimization")
+
+         self.logger.info(f"Generated {len(self.activation_data)} activation datasets")
+
+         # Step 2: Set up Optuna study
+         sampler = TPESampler(seed=self.opt_config.sampler_seed)
+         pruner = (
+             MedianPruner(n_startup_trials=5, n_warmup_steps=self.opt_config.pruning_patience)
+             if self.opt_config.enable_pruning
+             else None
+         )
+
+         study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner)
+
+         # Step 3: Run optimization
+         self.logger.info("Starting Optuna trials...")
+
+         def objective(trial):
+             return self._objective_function(trial, task_name, model_name)
+
+         study.optimize(
+             objective,
+             n_trials=self.opt_config.n_trials,
+             timeout=self.opt_config.timeout,
+             n_jobs=self.opt_config.n_jobs,
+             show_progress_bar=True,
+         )
+
+         # Step 4: Get best results
+         completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
+
+         if not completed_trials:
+             self.logger.warning("No trials completed successfully - all trials were pruned or failed")
+             # Show trial states for debugging
+             trial_states = {}
+             for trial in study.trials:
+                 state = trial.state.name
+                 trial_states[state] = trial_states.get(state, 0) + 1
+             self.logger.warning(f"Trial states: {trial_states}")
+
+             # Return a dummy result for debugging
+             dummy_result = OptimizationResult(
+                 best_params={},
+                 best_value=0.0,
+                 best_classifier=None,
+                 study=study,
+                 trial_results=[],
+                 optimization_time=time.time() - start_time,
+                 cache_hits=self.cache_hits,
+                 cache_misses=self.cache_misses,
+             )
+             return dummy_result
+
+         best_params = study.best_params
+         best_value = study.best_value
+
+         self.logger.info(f"Best trial: {best_params} -> {self.opt_config.primary_metric}={best_value:.4f}")
+
+         # Step 5: Train final model with best parameters
+         best_classifier = self._train_final_classifier(best_params, task_name, model_name)
+
+         optimization_time = time.time() - start_time
+
+         # Step 6: Collect trial results
+         trial_results = []
+         for trial in study.trials:
+             if trial.state == optuna.trial.TrialState.COMPLETE:
+                 trial_results.append(
+                     {
+                         "trial_number": trial.number,
+                         "params": trial.params,
+                         "value": trial.value,
+                         "duration": trial.duration.total_seconds() if trial.duration else None,
+                     }
+                 )
+
+         result = OptimizationResult(
+             best_params=best_params,
+             best_value=best_value,
+             best_classifier=best_classifier,
+             study=study,
+             trial_results=trial_results,
+             optimization_time=optimization_time,
+             cache_hits=self.cache_hits,
+             cache_misses=self.cache_misses,
+         )
+
+         self.logger.info(
+             f"Optimization completed in {optimization_time:.1f}s "
+             f"({self.cache_hits} cache hits, {self.cache_misses} cache misses)"
+         )
+
+         return result
+
+     def _objective_function(self, trial: optuna.Trial, task_name: str, model_name: str) -> float:
+         """
+         Optuna objective function for a single trial.
+
+         Args:
+             trial: Optuna trial object
+             task_name: Task name
+             model_name: Model name
+
+         Returns:
+             Objective value to maximize
+         """
+         # Sample hyperparameters directly (following steering pattern)
+         model_type = trial.suggest_categorical("model_type", self.opt_config.model_types)
+
+         # Layer and aggregation from pre-generated activation data
+         available_layers = set()
+         available_aggregations = set()
+
+         for key in self.activation_data.keys():
+             parts = key.split("_")
+             if len(parts) >= 4:  # layer_X_agg_Y
+                 layer = int(parts[1])
+                 agg = parts[3]
+                 available_layers.add(layer)
+                 available_aggregations.add(agg)
+
+         layer = trial.suggest_categorical("layer", sorted(available_layers))
+         aggregation = trial.suggest_categorical("aggregation", sorted(available_aggregations))
+
+         # Classification threshold
+         threshold = trial.suggest_float(
+             "threshold", self.opt_config.threshold_range[0], self.opt_config.threshold_range[1]
+         )
+
+         # Training hyperparameters
+         num_epochs = trial.suggest_int(
+             "num_epochs", self.opt_config.num_epochs_range[0], self.opt_config.num_epochs_range[1]
+         )
+
+         learning_rate = trial.suggest_float(
+             "learning_rate", self.opt_config.learning_rate_range[0], self.opt_config.learning_rate_range[1], log=True
+         )
+
+         batch_size = trial.suggest_categorical("batch_size", self.opt_config.batch_size_options)
+
+         # Model-specific hyperparameters (conditional logic like steering)
+         hyperparams = {"num_epochs": num_epochs, "learning_rate": learning_rate, "batch_size": batch_size}
+
+         if model_type == "mlp":
+             # MLP-specific parameters
+             hyperparams["hidden_dim"] = trial.suggest_int(
+                 "hidden_dim", self.opt_config.hidden_dim_range[0], self.opt_config.hidden_dim_range[1], step=32
+             )
+
+         # Combine all parameters
+         params = {
+             "model_type": model_type,
+             "layer": layer,
+             "aggregation": aggregation,
+             "threshold": threshold,
+             **hyperparams,
+         }
+
+         # Get activation data for this configuration
+         activation_key = f"layer_{params['layer']}_agg_{params['aggregation']}"
+
+         if activation_key not in self.activation_data:
+             self.logger.warning(f"No activation data for {activation_key}")
+             raise optuna.TrialPruned()
+
+         activation_data = self.activation_data[activation_key]
+         X, y = activation_data.to_tensors(device=self.gen_config.device, dtype=self.model_dtype)
+         print(f"DEBUG: Training data shape: X.shape={X.shape}, y.shape={y.shape}, dtype={X.dtype}")
+
+         # Generate cache key
+         data_hash = self.classifier_cache.compute_data_hash(X, y)
+         cache_key = self.classifier_cache.get_cache_key(
+             model_name=model_name,
+             task_name=task_name,
+             model_type=params["model_type"],
+             layer=params["layer"],
+             aggregation=params["aggregation"],
+             threshold=params["threshold"],
+             hyperparameters={
+                 k: v for k, v in params.items() if k not in ["model_type", "layer", "aggregation", "threshold"]
+             },
+             data_hash=data_hash,
+         )
+
+         # Try to load from cache
+         cached_classifier = self.classifier_cache.load_classifier(cache_key)
+         if cached_classifier is not None:
+             self.cache_hits += 1
+             # Evaluate cached classifier
+             return self._evaluate_classifier(cached_classifier, X, y, params["threshold"])
+
+         self.cache_misses += 1
+
+         # Train new classifier
+         classifier = self._train_classifier(params, X, y, trial)
+
+         if classifier is None:
+             raise optuna.TrialPruned()
+
+         # Evaluate classifier
+         score = self._evaluate_classifier(classifier, X, y, params["threshold"])
+
+         # Save to cache if training was successful
+         if score > 0:
+             try:
+                 performance_metrics = {self.opt_config.primary_metric: score}
+
+                 self.classifier_cache.save_classifier(
+                     cache_key=cache_key,
+                     classifier=classifier,
+                     model_name=model_name,
+                     task_name=task_name,
+                     layer=params["layer"],
+                     aggregation=params["aggregation"],
+                     threshold=params["threshold"],
+                     hyperparameters={
+                         k: v for k, v in params.items() if k not in ["model_type", "layer", "aggregation", "threshold"]
+                     },
+                     performance_metrics=performance_metrics,
+                     training_samples=len(X),
+                     data_hash=data_hash,
+                 )
+             except Exception as e:
+                 self.logger.warning(f"Failed to cache classifier: {e}")
+
+         return score
+
+     def _train_classifier(
+         self, params: dict[str, Any], X: np.ndarray, y: np.ndarray, trial: Optional[optuna.Trial] = None
+     ) -> Optional[Classifier]:
+         """
+         Train a classifier with the given parameters.
+
+         Args:
+             params: Hyperparameters
+             X: Training features
+             y: Training labels
+             trial: Optuna trial for pruning
+
+         Returns:
+             Trained classifier or None if training failed
+         """
+         try:
+             # Create classifier (don't pass hidden_dim to constructor)
+             classifier_kwargs = {
+                 "model_type": params["model_type"],
+                 "threshold": params["threshold"],
+                 "device": self.gen_config.device if self.gen_config.device else "auto",
+                 "dtype": self.model_dtype,
+             }
+
+             print(
+                 f"Preparing to train {params['model_type']} classifier with {len(X)} samples (dtype: {self.model_dtype})"
+             )
+             classifier = Classifier(**classifier_kwargs)
+
+             # Train classifier
+             training_kwargs = {
+                 "num_epochs": params["num_epochs"],
+                 "learning_rate": params["learning_rate"],
+                 "batch_size": params["batch_size"],
+                 "test_size": self.opt_config.test_size,
+                 "random_state": self.opt_config.random_state,
+             }
+
+             if params["model_type"] == "mlp":
+                 training_kwargs["hidden_dim"] = params["hidden_dim"]
+
+             # Add pruning callback if trial is provided
+             if trial and self.opt_config.enable_pruning:
+                 # TODO: Implement pruning callback for early stopping
+                 pass
+
+             print(f"About to fit classifier with kwargs: {training_kwargs}")
+             results = classifier.fit(X, y, **training_kwargs)
+             print(f"Training results: {results}")
+
+             accuracy = results.get("accuracy", 0)
+             if accuracy <= 0.35:  # More permissive threshold - only prune very poor performance
+                 self.logger.debug(f"Classifier performance too low ({accuracy:.3f}), pruning")
+                 print(f"Classifier pruned - accuracy too low: {accuracy:.3f}")
+                 return None
+
+             self.logger.debug(f"Classifier training successful - accuracy: {accuracy:.3f}")
+             print(f"Classifier training successful - accuracy: {accuracy:.3f}")
+
+             return classifier
+
+         except Exception as e:
+             print(f"EXCEPTION during classifier training: {e}")
+             import traceback
+
+             traceback.print_exc()
+             self.logger.debug(f"Training failed with params {params}: {e}")
+             return None
+
+     def _evaluate_classifier(self, classifier: Classifier, X: np.ndarray, y: np.ndarray, threshold: float) -> float:
+         """
+         Evaluate classifier performance.
+
+         Args:
+             classifier: Trained classifier
+             X: Features
+             y: Labels
+             threshold: Classification threshold
+
+         Returns:
+             Performance score based on primary metric
+         """
+         try:
+             print(f"DEBUG: Evaluation data shape: X.shape={X.shape}, y.shape={y.shape}, dtype={X.dtype}")
+
+             # Set threshold
+             classifier.set_threshold(threshold)
+
+             # Get predictions
+             results = classifier.evaluate(X, y)
+             print(f"Evaluation results: {results}")
+             print(f"Looking for primary metric '{self.opt_config.primary_metric}' in results")
+
+             # Return primary metric
+             score = results.get(self.opt_config.primary_metric, 0.0)
+             print(f"Score extracted: {score}")
+             return float(score)
+
+         except Exception as e:
+             print(f"EXCEPTION during evaluation: {e}")
+             import traceback
+
+             traceback.print_exc()
+             self.logger.debug(f"Evaluation failed: {e}")
+             return 0.0
+
+     def _train_final_classifier(self, best_params: dict[str, Any], task_name: str, model_name: str) -> Classifier:
+         """Train the final classifier with best parameters."""
+         # Get activation data
+         activation_key = f"layer_{best_params['layer']}_agg_{best_params['aggregation']}"
+         activation_data = self.activation_data[activation_key]
+         X, y = activation_data.to_tensors(device=self.gen_config.device, dtype=self.model_dtype)
+
+         # Try cache first
+         data_hash = self.classifier_cache.compute_data_hash(X, y)
+         cache_key = self.classifier_cache.get_cache_key(
+             model_name=model_name,
+             task_name=task_name,
+             model_type=best_params["model_type"],
+             layer=best_params["layer"],
+             aggregation=best_params["aggregation"],
+             threshold=best_params["threshold"],
+             hyperparameters={
+                 k: v for k, v in best_params.items() if k not in ["model_type", "layer", "aggregation", "threshold"]
+             },
+             data_hash=data_hash,
+         )
+
+         cached_classifier = self.classifier_cache.load_classifier(cache_key)
+         if cached_classifier is not None:
+             self.logger.info("Using cached classifier for final model")
+             return cached_classifier
+
+         # Train new classifier
+         self.logger.info("Training final classifier with best parameters")
+         classifier = self._train_classifier(best_params, X, y)
+
+         if classifier is None:
+             raise ValueError("Failed to train final classifier")
+
+         return classifier
+
+     def get_optimization_summary(self, result: OptimizationResult) -> dict[str, Any]:
+         """Get a comprehensive optimization summary."""
+         return {
+             "best_configuration": result.get_best_config(),
+             "best_score": result.best_value,
+             "optimization_time_seconds": result.optimization_time,
+             "total_trials": len(result.trial_results),
+             "cache_efficiency": {
+                 "hits": result.cache_hits,
+                 "misses": result.cache_misses,
+                 "hit_rate": result.cache_hits / (result.cache_hits + result.cache_misses)
+                 if (result.cache_hits + result.cache_misses) > 0
+                 else 0,
+             },
+             "activation_data_info": {
+                 key: {
+                     "samples": data.activations.shape[0],
+                     "features": data.activations.shape[1]
+                     if len(data.activations.shape) > 1
+                     else data.activations.shape[0],
+                     "layer": data.layer,
+                     "aggregation": data.aggregation,
+                 }
+                 for key, data in self.activation_data.items()
+             },
+             "study_info": {
+                 "n_trials": len(result.study.trials),
+                 "best_trial": result.study.best_trial.number,
+                 "pruned_trials": len([t for t in result.study.trials if t.state == optuna.trial.TrialState.PRUNED]),
+             },
+         }
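
For orientation, below is a minimal usage sketch of the optimizer module added in this release, based only on the class and method signatures visible in the diff above. The constructor arguments assumed for GenerationConfig and CacheConfig, and the placeholder model and contrastive pairs, are illustrative assumptions rather than confirmed wisent API.

# Hypothetical usage sketch -- not part of the packaged code above.
from wisent.core.optuna.classifier.optuna_classifier_optimizer import (
    ClassifierOptimizationConfig,
    OptunaClassifierOptimizer,
)
from wisent.core.optuna.classifier.activation_generator import GenerationConfig
from wisent.core.optuna.classifier.classifier_cache import CacheConfig

opt_config = ClassifierOptimizationConfig(n_trials=25, primary_metric="f1")
gen_config = GenerationConfig()   # assumed defaults; must expose .device and .layer_search_range
cache_config = CacheConfig()      # assumed defaults

model = ...                       # a loaded language model or wisent model wrapper; loading not shown
contrastive_pairs = ...           # training contrastive pairs prepared elsewhere in the package

optimizer = OptunaClassifierOptimizer(opt_config, gen_config, cache_config)
result = optimizer.optimize(
    model=model,
    contrastive_pairs=contrastive_pairs,
    task_name="winogrande",
    model_name=opt_config.model_name,
    limit=200,
)
print(optimizer.get_optimization_summary(result))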