wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (237) hide show
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.2.dist-info/METADATA +67 -0
  215. wisent-0.5.2.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1738 @@
1
+ """
2
+ Dataset-Agnostic Optimization Pipeline with Optuna
3
+
4
+ This script builds a reproducible pipeline that:
5
+ 1. Trains probes and learns steering vectors on the training split
6
+ 2. Selects the best layer, probe type, steering method, and hyperparameters on validation split via Optuna
7
+ 3. Evaluates once on the test split with the single best configuration determined on validation
8
+
9
+ Key features:
10
+ - Optuna-based hyperparameter optimization with pruners
11
+ - Activation caching for efficiency
12
+ - Configurable datasets for train/val/test splits
13
+ - Steering evaluation with model re-forwarding
14
+ - Reproducibility bundle generation
15
+ """
16
+
17
+ import hashlib
18
+ import json
19
+ import logging
20
+ import os
21
+ import pickle
22
+ from dataclasses import asdict, dataclass, field
23
+ from datetime import datetime
24
+ from pathlib import Path
25
+ from typing import Any, Optional
26
+
27
+ import numpy as np
28
+ import optuna
29
+ import torch
30
+ from optuna.pruners import MedianPruner, SuccessiveHalvingPruner
31
+ from optuna.samplers import TPESampler
32
+ from safetensors.torch import save_file as safetensors_save
33
+ from tqdm import tqdm
34
+
35
+ # Optional WandB integration
36
+ try:
37
+ import wandb
38
+
39
+ WANDB_AVAILABLE = True
40
+ except ImportError:
41
+ WANDB_AVAILABLE = False
42
+ from wisent.core.contrastive_pairs.contrastive_pair import ContrastivePair
43
+ from wisent.core.contrastive_pairs.contrastive_pair_set import ContrastivePairSet
44
+ from wisent.core.optuna.steering import data_utils, metrics
45
+ from wisent.core.response import Response
46
+ from wisent.core.steering_methods.dac import DAC
47
+ from wisent.core.task_interface import get_task
48
+ from wisent.core.utils.device import empty_device_cache, preferred_dtype, resolve_default_device, resolve_device
49
+
50
+ logger = logging.getLogger(__name__)
51
+
52
+
53
+ @dataclass
54
+ class OptimizationConfig:
55
+ """Configuration for dataset-agnostic optimization pipeline."""
56
+
57
+ model_name: str = "realtreetune/rho-1b-sft-GSM8K"
58
+ device: str = field(default_factory=resolve_default_device)
59
+
60
+ train_dataset: str = "gsm8k"
61
+ val_dataset: str = "gsm8k"
62
+ test_dataset: str = "gsm8k"
63
+
64
+ # Training configuration
65
+ train_limit: int = 50 # How many training samples to load
66
+ contrastive_pairs_limit: int = 20 # How many contrastive pairs to extract for steering training
67
+
68
+ # Evaluation configuration
69
+ val_limit: int = 50 # How many validation samples to load
70
+ test_limit: int = 100 # How many test samples to load
71
+
72
+ layer_search_range: tuple[int, int] = (15, 20)
73
+ probe_type: str = "logistic_regression" # Fixed probe type
74
+ steering_methods: list[str] = field(default_factory=lambda: ["dac", "caa"]) # TODO add more
75
+
76
+ # Optuna study configuration
77
+ study_name: str = "optimization_pipeline"
78
+ db_url: str = field(
79
+ default_factory=lambda: f"sqlite:///{os.path.dirname(os.path.dirname(__file__))}/optuna_studies.db"
80
+ )
81
+ n_trials: int = 50
82
+ n_startup_trials: int = 10 # Random exploration before TPE kicks in
83
+ sampler: str = "TPE"
84
+ pruner: str = "MedianPruner"
85
+
86
+ # WandB configuration
87
+ wandb_project: str = "wisent-guard-optimization"
88
+ use_wandb: bool = False # TODO
89
+
90
+ batch_size: int = 8
91
+ max_length: int = 512
92
+ max_new_tokens: int = 256
93
+ seed: int = 42
94
+
95
+ temperature: float = 0.0
96
+ do_sample: bool = False
97
+
98
+ output_dir: str = "outputs/optimization_pipeline"
99
+ cache_dir: str = "cache/optimization_pipeline"
100
+
101
+ max_layers_to_search: int = 6
102
+ early_stopping_patience: int = 10
103
+
104
+ def to_dict(self) -> dict[str, Any]:
105
+ """Convert to dictionary for serialization."""
106
+ return asdict(self)
107
+
108
+
109
+ class ActivationCache:
110
+ """Efficient activation caching system with proper cache keys."""
111
+
112
+ def __init__(self, cache_dir: str):
113
+ self.cache_dir = Path(cache_dir)
114
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
115
+ self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
116
+
117
+ def _generate_cache_key(
118
+ self, split: str, layer_id: int, tokenization_config: dict[str, Any], prompt_variant: str = "default"
119
+ ) -> str:
120
+ """Generate unique cache key for activations."""
121
+ config_str = json.dumps(tokenization_config, sort_keys=True)
122
+ key_data = f"{split}_{layer_id}_{config_str}_{prompt_variant}"
123
+ return hashlib.md5(key_data.encode()).hexdigest()
124
+
125
+ def _get_cache_path(self, cache_key: str) -> Path:
126
+ """Get cache file path for key."""
127
+ return self.cache_dir / f"activations_{cache_key}.pkl"
128
+
129
+ def has_cached_activations(
130
+ self, split: str, layer_id: int, tokenization_config: dict[str, Any], prompt_variant: str = "default"
131
+ ) -> bool:
132
+ """Check if activations are cached."""
133
+ cache_key = self._generate_cache_key(split, layer_id, tokenization_config, prompt_variant)
134
+ return self._get_cache_path(cache_key).exists()
135
+
136
+ def save_activations(
137
+ self,
138
+ activations: np.ndarray,
139
+ labels: np.ndarray,
140
+ split: str,
141
+ layer_id: int,
142
+ tokenization_config: dict[str, Any],
143
+ prompt_variant: str = "default",
144
+ ):
145
+ """Save activations to cache."""
146
+ cache_key = self._generate_cache_key(split, layer_id, tokenization_config, prompt_variant)
147
+ cache_path = self._get_cache_path(cache_key)
148
+
149
+ cache_data = {
150
+ "activations": activations,
151
+ "labels": labels,
152
+ "metadata": {
153
+ "split": split,
154
+ "layer_id": layer_id,
155
+ "tokenization_config": tokenization_config,
156
+ "prompt_variant": prompt_variant,
157
+ "timestamp": datetime.now().isoformat(),
158
+ "shape": activations.shape,
159
+ },
160
+ }
161
+
162
+ with open(cache_path, "wb") as f:
163
+ pickle.dump(cache_data, f)
164
+
165
+ self.logger.info(f"Cached activations for {split} layer {layer_id}: {activations.shape}")
166
+
167
+ def load_activations(
168
+ self, split: str, layer_id: int, tokenization_config: dict[str, Any], prompt_variant: str = "default"
169
+ ) -> tuple[np.ndarray, np.ndarray]:
170
+ """Load activations from cache."""
171
+ cache_key = self._generate_cache_key(split, layer_id, tokenization_config, prompt_variant)
172
+ cache_path = self._get_cache_path(cache_key)
173
+
174
+ if not cache_path.exists():
175
+ raise FileNotFoundError(f"No cached activations found for key: {cache_key}")
176
+
177
+ with open(cache_path, "rb") as f:
178
+ cache_data = pickle.load(f)
179
+
180
+ self.logger.info(f"Loaded cached activations for {split} layer {layer_id}: {cache_data['activations'].shape}")
181
+ return cache_data["activations"], cache_data["labels"]
182
+
183
+
184
+ class OptimizationPipeline:
185
+ """Main optimization pipeline using Optuna for hyperparameter search."""
186
+
187
+ def __init__(self, config: OptimizationConfig):
188
+ self.config = config
189
+ self.device = resolve_device(config.device)
190
+ self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
191
+
192
+ # Setup output directories
193
+ self.output_dir = Path(config.output_dir)
194
+ self.output_dir.mkdir(parents=True, exist_ok=True)
195
+
196
+ # Initialize cache
197
+ self.cache = ActivationCache(config.cache_dir)
198
+
199
+ # Initialize WandB if configured
200
+ self.wandb_run = None
201
+ if config.use_wandb:
202
+ if not WANDB_AVAILABLE:
203
+ raise ImportError(
204
+ "WandB integration enabled but wandb is not installed. Install with: pip install wandb"
205
+ )
206
+ self._init_wandb()
207
+
208
+ self.model = None
209
+ self.tokenizer = None
210
+ self.train_samples = None
211
+ self.val_samples = None
212
+ self.test_samples = None
213
+ # Store task documents for BigCode evaluation
214
+ self.train_task_docs = None
215
+ self.val_task_docs = None
216
+ self.test_task_docs = None
217
+ self.tokenization_config = {
218
+ "max_length": config.max_length,
219
+ "padding": True,
220
+ "truncation": True,
221
+ "return_tensors": "pt",
222
+ }
223
+
224
+ @property
225
+ def is_coding_task(self) -> bool:
226
+ """Check if the current task requires code execution evaluation."""
227
+ from ...parameters.task_config import CODING_TASKS
228
+ from ..bigcode_integration import is_bigcode_task
229
+
230
+ val_dataset = getattr(self.config, "val_dataset", None)
231
+ if not val_dataset:
232
+ return False
233
+
234
+ return val_dataset.lower() in CODING_TASKS or is_bigcode_task(val_dataset)
235
+
236
+ def run_optimization(self) -> dict[str, Any]:
237
+ """Run the complete optimization pipeline."""
238
+ self.logger.info("=" * 80)
239
+ self.logger.info("🚀 STARTING OPTIMIZATION PIPELINE WITH OPTUNA")
240
+ self.logger.info("=" * 80)
241
+
242
+ # Create timestamped run directory
243
+ self.run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
244
+ self.run_dir = self.output_dir / f"run_{self.run_timestamp}"
245
+ self.run_dir.mkdir(parents=True, exist_ok=True)
246
+ self.logger.info(f"📁 Run directory: {self.run_dir}")
247
+
248
+ self._setup_experiment()
249
+ study = self._create_optuna_study()
250
+ study.optimize(self._objective_function, n_trials=self.config.n_trials)
251
+ best_trial = study.best_trial
252
+ final_results = self._final_evaluation(best_trial)
253
+ self._save_reproducibility_bundle(study, final_results)
254
+
255
+ # Log final results to WandB
256
+ self._log_final_results_to_wandb(study, final_results)
257
+
258
+ self.logger.info("✅ Optimization completed successfully!")
259
+ return final_results
260
+
261
+ def _setup_experiment(self):
262
+ """Setup model, tokenizer, and load datasets."""
263
+ self.logger.info("📊 Setting up experiment...")
264
+
265
+ # Load model and tokenizer with memory optimizations
266
+ from transformers import AutoModelForCausalLM, AutoTokenizer
267
+
268
+ self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
269
+
270
+ # Load model with memory optimizations (same as comprehensive evaluation)
271
+ self.model = AutoModelForCausalLM.from_pretrained(
272
+ self.config.model_name,
273
+ torch_dtype=preferred_dtype(self.device.type),
274
+ low_cpu_mem_usage=True,
275
+ )
276
+
277
+ self.model.to(self.device)
278
+ self.model.eval() # Set to evaluation mode
279
+
280
+ if self.tokenizer.pad_token is None:
281
+ self.tokenizer.pad_token = self.tokenizer.eos_token
282
+
283
+ # Set left padding for decoder-only models (same as comprehensive evaluation)
284
+ self.tokenizer.padding_side = "left"
285
+
286
+ # Load datasets
287
+ self.train_samples = data_utils.load_dataset_samples(self.config.train_dataset, self.config.train_limit)
288
+ self.val_samples = data_utils.load_dataset_samples(self.config.val_dataset, self.config.val_limit)
289
+ self.test_samples = data_utils.load_dataset_samples(self.config.test_dataset, self.config.test_limit)
290
+
291
+ # Store task documents for BigCode evaluation (coding tasks)
292
+ self.train_task_docs = self.train_samples
293
+ self.val_task_docs = self.val_samples
294
+ self.test_task_docs = self.test_samples
295
+
296
+ self.logger.info(
297
+ f"Loaded {len(self.train_samples)} train, {len(self.val_samples)} val, {len(self.test_samples)} test samples"
298
+ )
299
+
300
+ # Pre-cache activations for all layers on all splits
301
+ self._precache_activations()
302
+
303
+ def _precache_activations(self):
304
+ """Pre-cache activations for all layers and splits to improve efficiency."""
305
+ self.logger.info("🔄 Pre-caching activations for efficiency...")
306
+
307
+ layer_range = range(self.config.layer_search_range[0], self.config.layer_search_range[1] + 1)
308
+
309
+ splits_data = [("train", self.train_samples), ("val", self.val_samples), ("test", self.test_samples)]
310
+
311
+ for split_name, samples in splits_data:
312
+ for layer_id in layer_range:
313
+ if not self.cache.has_cached_activations(split_name, layer_id, self.tokenization_config):
314
+ self.logger.info(f"Caching activations for {split_name} split, layer {layer_id}")
315
+
316
+ dataset_name = {
317
+ "train": self.config.train_dataset,
318
+ "val": self.config.val_dataset,
319
+ "test": self.config.test_dataset,
320
+ }[split_name]
321
+
322
+ activations, labels = self._create_probe_data(samples, layer_id, dataset_name)
323
+
324
+ self.cache.save_activations(activations, labels, split_name, layer_id, self.tokenization_config)
325
+ else:
326
+ self.logger.info(f"Activations already cached for {split_name} split, layer {layer_id}")
327
+
328
+ def _create_probe_data(
329
+ self, samples: list[dict], layer_id: int, dataset_name: str
330
+ ) -> tuple[np.ndarray, np.ndarray]:
331
+ """Create contrastive probe training data for a specific layer."""
332
+ self.logger.info(f"Creating probe data from {len(samples)} samples for {dataset_name} on layer {layer_id}")
333
+
334
+ # Get task for the specified dataset
335
+ task = get_task(dataset_name)
336
+ extractor = task.get_extractor()
337
+ self.logger.debug(f"Using task: {task.__class__.__name__}, extractor: {extractor.__class__.__name__}")
338
+
339
+ texts = []
340
+ labels = []
341
+ success_count = 0
342
+ fail_count = 0
343
+
344
+ for i, sample in enumerate(samples):
345
+ try:
346
+ # Extract QA pair
347
+ contrastive_pair = extractor.extract_contrastive_pair(sample, task)
348
+
349
+ # Skip samples where contrastive pair extraction failed
350
+ if not contrastive_pair:
351
+ self.logger.debug(f"Sample {i + 1}: No contrastive pair extracted from keys: {list(sample.keys())}")
352
+ fail_count += 1
353
+ continue
354
+
355
+ success_count += 1
356
+ self.logger.debug(f"Sample {i + 1}: Successfully extracted contrastive pair")
357
+
358
+ except Exception as e:
359
+ self.logger.error(f"Sample {i + 1}: Exception during contrastive pair extraction: {e}")
360
+ fail_count += 1
361
+ continue
362
+
363
+ question = contrastive_pair["question"]
364
+ correct_answer = contrastive_pair["correct_answer"]
365
+ incorrect_answer = contrastive_pair["incorrect_answer"]
366
+
367
+ # Log contrastive pair details
368
+ self.logger.debug(f"Contrastive pair - Question: ...{question[-50:]}")
369
+ self.logger.debug(f"Contrastive pair - Correct: {correct_answer}, Incorrect: {incorrect_answer}")
370
+
371
+ correct_text = f"{question} {correct_answer}"
372
+ texts.append(correct_text)
373
+ labels.append(1)
374
+
375
+ incorrect_text = f"{question} {incorrect_answer}"
376
+ texts.append(incorrect_text)
377
+ labels.append(0)
378
+
379
+ self.logger.info(
380
+ f"Probe data creation: {success_count} successful, {fail_count} failed. Generated {len(texts)} texts."
381
+ )
382
+
383
+ if len(texts) == 0:
384
+ self.logger.error("No texts generated for activation extraction! All contrastive pair extractions failed.")
385
+ return np.array([]), np.array([])
386
+
387
+ activations = data_utils.extract_activations_with_hook(
388
+ self.model, self.tokenizer, texts, layer_id, self.config.batch_size, self.config.max_length, self.device
389
+ )
390
+
391
+ return activations, np.array(labels)
392
+
393
+ def _create_optuna_study(self) -> optuna.Study:
394
+ """Create Optuna study with SQLite persistence and specified sampler/pruner."""
395
+ self.logger.info("📋 Creating Optuna study with SQLite persistence...")
396
+ self.logger.info(f"Database: {self.config.db_url}")
397
+ self.logger.info(f"Study name: {self.config.study_name}")
398
+ self.logger.info(f"🎲 Warmup: {self.config.n_startup_trials} random trials before TPE sampling")
399
+
400
+ # Setup sampler
401
+ if self.config.sampler == "TPE":
402
+ sampler = TPESampler(seed=self.config.seed, n_startup_trials=self.config.n_startup_trials)
403
+ elif self.config.sampler == "Random":
404
+ sampler = optuna.samplers.RandomSampler(seed=self.config.seed)
405
+ else:
406
+ sampler = TPESampler(seed=self.config.seed, n_startup_trials=self.config.n_startup_trials)
407
+
408
+ # Setup pruner
409
+ if self.config.pruner == "MedianPruner":
410
+ pruner = MedianPruner(n_startup_trials=5, n_warmup_steps=10)
411
+ elif self.config.pruner == "SuccessiveHalvingPruner":
412
+ pruner = SuccessiveHalvingPruner()
413
+ else:
414
+ pruner = MedianPruner(n_startup_trials=5, n_warmup_steps=10)
415
+
416
+ # Create study with SQLite storage
417
+ study = optuna.create_study(
418
+ study_name=self.config.study_name,
419
+ storage=self.config.db_url,
420
+ direction="maximize", # Maximize validation accuracy
421
+ sampler=sampler,
422
+ pruner=pruner,
423
+ load_if_exists=True, # Continue existing study if it exists
424
+ )
425
+
426
+ self.logger.info(f"Study created/loaded with {len(study.trials)} existing trials")
427
+
428
+ return study
429
+
430
+ def _save_steering_vector_dual_format(self, steering_instance, pt_path: Path, safetensors_path: Path) -> bool:
431
+ """Save steering vector in both .pt and safetensors formats."""
432
+ # Save in original .pt format first (preserves all metadata)
433
+ if not steering_instance.save_steering_vector(str(pt_path)):
434
+ self.logger.warning("Failed to save steering vector - method may not be trained")
435
+ return False
436
+
437
+ self.logger.info(f"💾 Saved best steering vector to: {pt_path.name}")
438
+
439
+ # Also save in safetensors format for HuggingFace compatibility
440
+ try:
441
+ # Load the .pt file and extract steering vector
442
+ data = torch.load(str(pt_path), map_location="cpu", weights_only=False)
443
+ if isinstance(data, dict) and "steering_vector" in data:
444
+ # Save just the steering vector in safetensors format
445
+ safetensors_save({"steering_vector": data["steering_vector"]}, str(safetensors_path))
446
+ self.logger.info(f"💾 Also saved as safetensors: {safetensors_path.name}")
447
+ return True
448
+ self.logger.warning("Unexpected .pt file structure, safetensors conversion skipped")
449
+ return True # .pt save was successful
450
+ except Exception as e:
451
+ self.logger.warning(f"Could not create safetensors version: {e}")
452
+ return True # .pt save was successful
453
+
454
+ def _objective_function(self, trial: optuna.Trial) -> float:
455
+ """Optuna objective function for hyperparameter optimization."""
456
+ try:
457
+ # Sample hyperparameters
458
+ layer_id = trial.suggest_int(
459
+ "layer_id", self.config.layer_search_range[0], self.config.layer_search_range[1]
460
+ )
461
+
462
+ # Fixed probe type and regularization
463
+ probe_type = self.config.probe_type # Always logistic_regression
464
+ probe_c = 1.0 # Default regularization strength
465
+
466
+ steering_method = trial.suggest_categorical("steering_method", self.config.steering_methods)
467
+
468
+ if steering_method == "dac":
469
+ steering_alpha = trial.suggest_float("steering_alpha", 0.1, 5.0)
470
+ entropy_threshold = trial.suggest_float("entropy_threshold", 0.5, 2.0)
471
+ ptop = trial.suggest_float("ptop", 0.2, 0.8)
472
+ max_alpha = trial.suggest_float("max_alpha", 1.0, 5.0)
473
+ elif steering_method == "caa":
474
+ steering_alpha = trial.suggest_float("steering_alpha", 0.1, 5.0)
475
+
476
+ probe_score = self._train_and_evaluate_probe(trial, layer_id, probe_type, probe_c)
477
+
478
+ # Don't prune based on probe score - focus optimization on steering parameters
479
+
480
+ # Build clean hyperparameters dictionary
481
+ if steering_method == "dac":
482
+ hyperparams = {
483
+ "steering_alpha": steering_alpha,
484
+ "entropy_threshold": entropy_threshold,
485
+ "ptop": ptop,
486
+ "max_alpha": max_alpha,
487
+ }
488
+ elif steering_method == "caa":
489
+ hyperparams = {
490
+ "steering_alpha": steering_alpha,
491
+ }
492
+ else:
493
+ raise ValueError(f"Unsupported steering method: {steering_method}")
494
+
495
+ steering_method_instance = self._train_steering_method(trial, steering_method, layer_id, hyperparams)
496
+
497
+ validation_accuracy = self._evaluate_steering_on_validation(
498
+ steering_method_instance, steering_method, layer_id, hyperparams, trial.number, trial
499
+ )
500
+
501
+ trial.report(validation_accuracy, step=1)
502
+
503
+ # Log to WandB
504
+ metrics = {"validation_accuracy": validation_accuracy, "probe_score": probe_score}
505
+ self._log_trial_to_wandb(trial, metrics)
506
+
507
+ return validation_accuracy
508
+
509
+ except Exception as e:
510
+ self.logger.error(f"Trial failed: {e}")
511
+ return 0.0
512
+
513
+ def _train_and_evaluate_probe(self, trial: optuna.Trial, layer_id: int, probe_type: str, probe_c: float) -> float:
514
+ """Train probe on training data and evaluate on validation data using cached activations."""
515
+ # Load cached training activations
516
+ X_train, y_train = self.cache.load_activations("train", layer_id, self.tokenization_config)
517
+
518
+ # Train probe
519
+ if probe_type == "logistic_regression":
520
+ from sklearn.linear_model import LogisticRegression
521
+
522
+ probe = LogisticRegression(C=probe_c, random_state=self.config.seed, max_iter=1000)
523
+ probe.fit(X_train, y_train)
524
+ else:
525
+ raise ValueError(f"Unsupported probe type: {probe_type}")
526
+
527
+ # Evaluate on validation data using cached activations
528
+ X_val, y_val = self.cache.load_activations("val", layer_id, self.tokenization_config)
529
+
530
+ from sklearn.metrics import roc_auc_score
531
+
532
+ y_pred_proba = probe.predict_proba(X_val)[:, 1]
533
+ return roc_auc_score(y_val, y_pred_proba) if len(np.unique(y_val)) > 1 else 0.5
534
+
535
+ # Don't store the probe object - it can't be JSON serialized
536
+ # The probe will be retrained in the final evaluation if needed
537
+
538
+ def _train_steering_method(
539
+ self, trial: optuna.Trial, method_name: str, layer_id: int, hyperparams: dict[str, Any]
540
+ ) -> Any:
541
+ """Train steering method on training data."""
542
+ # Use contrastive_pairs_limit with bounds checking
543
+ contrastive_limit = min(self.config.contrastive_pairs_limit, len(self.train_samples))
544
+ contrastive_pairs = self._create_contrastive_pairs(
545
+ self.train_samples, layer_id, self.config.train_dataset, limit=contrastive_limit
546
+ )
547
+
548
+ if method_name == "dac":
549
+ # Create DAC instance
550
+ dac = DAC(
551
+ entropy_threshold=hyperparams["entropy_threshold"],
552
+ ptop=hyperparams["ptop"],
553
+ max_alpha=hyperparams["max_alpha"],
554
+ )
555
+
556
+ # Train DAC
557
+ dac.train(contrastive_pairs, layer_id)
558
+ return dac
559
+
560
+ if method_name == "caa":
561
+ # Create CAA instance
562
+ from wisent.core.steering_methods.caa import CAA
563
+
564
+ caa = CAA(device=self.device)
565
+
566
+ # Train CAA
567
+ caa.train(contrastive_pairs, layer_id)
568
+ return caa
569
+
570
+ raise ValueError(f"Unsupported steering method: {method_name}")
571
+
572
+ def _create_contrastive_pairs(
573
+ self, samples: list[dict], layer_id: int, dataset_name: str, limit: Optional[int] = None
574
+ ) -> ContrastivePairSet:
575
+ """Create contrastive pairs with activations for steering training."""
576
+ contrastive_pairs = []
577
+ task = get_task(dataset_name)
578
+ extractor = task.get_extractor()
579
+
580
+ samples_to_use = samples[:limit] if limit else samples
581
+
582
+ for sample in samples_to_use:
583
+ contrastive_pair = extractor.extract_contrastive_pair(sample, task)
584
+ if contrastive_pair:
585
+ # Log contrastive pair details
586
+ self.logger.debug(f"Creating contrastive pair - Question: ...{contrastive_pair['question'][-50:]}")
587
+ self.logger.debug(
588
+ f"Creating contrastive pair - Correct: {contrastive_pair['correct_answer']}, Incorrect: {contrastive_pair['incorrect_answer']}"
589
+ )
590
+
591
+ positive_response = Response(text=contrastive_pair["correct_answer"], label=1)
592
+ negative_response = Response(text=contrastive_pair["incorrect_answer"], label=0)
593
+
594
+ pair = ContrastivePair(
595
+ prompt=contrastive_pair["question"],
596
+ positive_response=positive_response,
597
+ negative_response=negative_response,
598
+ )
599
+ contrastive_pairs.append(pair)
600
+
601
+ pair_set = ContrastivePairSet(name=f"{dataset_name}_training", pairs=contrastive_pairs)
602
+
603
+ # Extract activations for all pairs in batches
604
+ if pair_set.pairs:
605
+ all_texts = []
606
+ text_to_pair_mapping = []
607
+
608
+ for pair_idx, pair in enumerate(pair_set.pairs):
609
+ pos_text = f"{pair.prompt} {pair.positive_response.text}"
610
+ neg_text = f"{pair.prompt} {pair.negative_response.text}"
611
+
612
+ all_texts.extend([pos_text, neg_text])
613
+ text_to_pair_mapping.extend([(pair_idx, "positive"), (pair_idx, "negative")])
614
+
615
+ all_activations = self._extract_batch_activations(all_texts, layer_id)
616
+
617
+ for text_idx, (pair_idx, response_type) in enumerate(text_to_pair_mapping):
618
+ activation = all_activations[text_idx]
619
+
620
+ if response_type == "positive":
621
+ pair_set.pairs[pair_idx].positive_response.activations = activation
622
+ else:
623
+ pair_set.pairs[pair_idx].negative_response.activations = activation
624
+
625
+ return pair_set
626
+
627
+ def _extract_batch_activations(self, texts: list[str], layer_id: int) -> list[torch.Tensor]:
628
+ """Extract activations for multiple texts in batches."""
629
+ if not texts:
630
+ return []
631
+
632
+ all_activations = []
633
+ batch_size = self.config.batch_size
634
+
635
+ for i in range(0, len(texts), batch_size):
636
+ batch_texts = texts[i : i + batch_size]
637
+
638
+ inputs = self.tokenizer(
639
+ batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=self.config.max_length
640
+ ).to(self.device)
641
+
642
+ batch_activations = []
643
+
644
+ def batch_hook_fn(module, input, output):
645
+ with torch.no_grad():
646
+ hidden_states = output[0] if isinstance(output, tuple) else output
647
+
648
+ last_token_acts = hidden_states[:, -1, :].detach().clone()
649
+ batch_activations.append(last_token_acts)
650
+
651
+ if hasattr(self.model, "transformer"):
652
+ target_layer = self.model.transformer.h[layer_id]
653
+ elif hasattr(self.model, "model"):
654
+ target_layer = self.model.model.layers[layer_id]
655
+ else:
656
+ raise ValueError("Unknown model architecture")
657
+
658
+ handle = target_layer.register_forward_hook(batch_hook_fn)
659
+
660
+ try:
661
+ with torch.no_grad():
662
+ _ = self.model(**inputs)
663
+ finally:
664
+ handle.remove()
665
+ empty_device_cache(self.device.type)
666
+
667
+ if batch_activations:
668
+ batch_tensor = batch_activations[0]
669
+ for j in range(batch_tensor.shape[0]):
670
+ all_activations.append(batch_tensor[j].unsqueeze(0))
671
+
672
+ return all_activations
673
+
674
+ def _extract_single_activation(self, text: str, layer_id: int) -> torch.Tensor:
675
+ """Extract activation for a single text."""
676
+ activations = self._extract_batch_activations([text], layer_id)
677
+ return activations[0] if activations else torch.zeros(1, self.model.config.hidden_size, device=self.device)
678
+
679
+ def _evaluate_steering_on_validation(
680
+ self,
681
+ steering_instance: Any,
682
+ method_name: str,
683
+ layer_id: int,
684
+ hyperparams: dict[str, Any],
685
+ trial_number: int = 0,
686
+ trial=None,
687
+ ) -> float:
688
+ """Evaluate steering method on validation data by re-running forward passes."""
689
+ if steering_instance is None:
690
+ return 0.0
691
+
692
+ # Generate predictions with steering applied
693
+ predictions = []
694
+ ground_truths = []
695
+ task_docs = [] # Preserve original task documents for BigCode evaluation
696
+
697
+ task = get_task(self.config.val_dataset)
698
+ extractor = task.get_extractor()
699
+
700
+ # Collect all questions for batched processing (use ALL validation samples)
701
+ questions = []
702
+ ground_truths = []
703
+ valid_samples = [] # Keep track of samples that produce valid QA pairs
704
+
705
+ for sample in tqdm(
706
+ self.val_samples, desc="Extracting validation QA pairs", leave=False
707
+ ): # Use all validation samples for reliable evaluation
708
+ qa_pair = extractor.extract_qa_pair(sample, task)
709
+ if not qa_pair:
710
+ continue
711
+
712
+ question = qa_pair["formatted_question"]
713
+ ground_truth = qa_pair["correct_answer"]
714
+ questions.append(question)
715
+ ground_truths.append(ground_truth)
716
+ valid_samples.append(sample) # Store the original sample
717
+
718
+ # Generate predictions using batched approach
719
+ if questions:
720
+ if steering_instance is None:
721
+ predictions = self._generate_baseline_batched(questions)
722
+ else:
723
+ # Extract the appropriate strength parameter based on method
724
+ if method_name == "dac":
725
+ # DAC uses steering_alpha as base strength multiplier
726
+ strength = hyperparams.get("steering_alpha", 1.0)
727
+ else:
728
+ # CAA and other methods use steering_alpha directly
729
+ strength = hyperparams["steering_alpha"]
730
+
731
+ predictions = self._generate_with_steering_batched(steering_instance, questions, strength, layer_id)
732
+
733
+ # Log sample predictions for debugging
734
+ for i, (pred, gt) in enumerate(zip(predictions[:3], ground_truths[:3])):
735
+ self.logger.debug(f"{method_name.upper()} Sample {i} - Model: ...{pred[-50:] if pred else 'None'}")
736
+ self.logger.debug(f"{method_name.upper()} Sample {i} - Ground truth: {gt}")
737
+ else:
738
+ predictions = []
739
+
740
+ if not predictions:
741
+ return 0.0
742
+
743
+ # Save detailed validation results to JSON
744
+ self._save_detailed_validation_results(
745
+ questions,
746
+ ground_truths,
747
+ predictions,
748
+ trial_number,
749
+ trial=trial,
750
+ steering_method=method_name,
751
+ layer_id=layer_id,
752
+ hyperparams=hyperparams,
753
+ )
754
+
755
+ # Prepare task docs for BigCode evaluation (if coding task)
756
+ task_docs = valid_samples[: len(predictions)] if valid_samples else []
757
+
758
+ # Evaluate benchmark performance (with task docs for coding tasks)
759
+ benchmark_metrics = metrics.evaluate_benchmark_performance(
760
+ predictions, ground_truths, self.config.val_dataset, task_docs=task_docs
761
+ )
762
+
763
+ return benchmark_metrics.get("accuracy", 0.0)
764
+
765
+ def _generate_with_dac_steering(self, dac: DAC, question: str, alpha: float, layer_id: int) -> str:
766
+ """Generate response with DAC steering applied."""
767
+ # Use the general steering method which calls DAC's apply_steering
768
+ return self._generate_with_steering(dac, question, alpha, layer_id)
769
+
770
+ def _generate_with_caa_steering(self, caa, question: str, alpha: float, layer_id: int) -> str:
771
+ """Generate response with CAA steering applied."""
772
+ if not hasattr(caa, "steering_vector") or caa.steering_vector is None:
773
+ return self._generate_baseline(question)
774
+
775
+ return self._generate_with_steering_hook(question, caa.steering_vector, layer_id, alpha)
776
+
777
+ def _generate_with_steering_hook(
778
+ self, question: str, steering_vector: torch.Tensor, layer_id: int, alpha: float
779
+ ) -> str:
780
+ """Generate response with steering vector applied via hook (re-runs forward pass)."""
781
+ inputs = self.tokenizer(question, return_tensors="pt").to(self.device)
782
+
783
+ def steering_hook(module, input, output):
784
+ """Hook that applies steering vector during forward pass."""
785
+ if isinstance(output, tuple):
786
+ hidden_states = output[0]
787
+ # Apply steering to the last token
788
+ hidden_states[:, -1, :] += alpha * steering_vector.to(hidden_states.device)
789
+ return (hidden_states, *output[1:])
790
+ hidden_states = output
791
+ hidden_states[:, -1, :] += alpha * steering_vector.to(hidden_states.device)
792
+ return hidden_states
793
+
794
+ # Register hook on target layer
795
+ if hasattr(self.model, "transformer"):
796
+ target_layer = self.model.transformer.h[layer_id]
797
+ elif hasattr(self.model, "model"):
798
+ target_layer = self.model.model.layers[layer_id]
799
+ else:
800
+ raise ValueError("Unknown model architecture")
801
+
802
+ handle = target_layer.register_forward_hook(steering_hook)
803
+
804
+ try:
805
+ with torch.no_grad():
806
+ outputs = self.model.generate(
807
+ **inputs,
808
+ max_new_tokens=self.config.max_new_tokens,
809
+ do_sample=self.config.do_sample,
810
+ temperature=self.config.temperature if self.config.do_sample else 1.0,
811
+ pad_token_id=self.tokenizer.eos_token_id,
812
+ eos_token_id=self.tokenizer.eos_token_id,
813
+ )
814
+ finally:
815
+ handle.remove()
816
+
817
+ response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1] :], skip_special_tokens=True)
818
+ return response.strip()
819
+
820
+ def _generate_baseline(self, question: str) -> str:
821
+ """Generate baseline response without steering."""
822
+ inputs = self.tokenizer(question, return_tensors="pt").to(self.device)
823
+
824
+ with torch.no_grad():
825
+ outputs = self.model.generate(
826
+ **inputs,
827
+ max_new_tokens=self.config.max_new_tokens,
828
+ do_sample=self.config.do_sample,
829
+ temperature=self.config.temperature if self.config.do_sample else 1.0,
830
+ pad_token_id=self.tokenizer.eos_token_id,
831
+ eos_token_id=self.tokenizer.eos_token_id,
832
+ )
833
+
834
+ response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1] :], skip_special_tokens=True)
835
+ return response.strip()
836
+
837
+ def _generate_baseline_batched(self, questions: list[str]) -> list[str]: # TODO
838
+ """Generate baseline responses in batches without steering."""
839
+ if not questions:
840
+ return []
841
+
842
+ batch_size = self.config.batch_size
843
+ all_responses = []
844
+
845
+ # Process questions in batches
846
+ for i in tqdm(range(0, len(questions), batch_size), desc="Generating baseline predictions", leave=False):
847
+ batch_questions = questions[i : i + batch_size]
848
+
849
+ # Batch tokenization with padding
850
+ inputs = self.tokenizer(
851
+ batch_questions, return_tensors="pt", padding=True, truncation=True, max_length=self.config.max_length
852
+ ).to(self.device)
853
+
854
+ with torch.no_grad():
855
+ outputs = self.model.generate(
856
+ **inputs,
857
+ max_new_tokens=self.config.max_new_tokens,
858
+ do_sample=self.config.do_sample,
859
+ temperature=self.config.temperature if self.config.do_sample else 1.0,
860
+ pad_token_id=self.tokenizer.eos_token_id,
861
+ eos_token_id=self.tokenizer.eos_token_id,
862
+ )
863
+
864
+ # Decode responses
865
+ batch_responses = []
866
+ for j, output in enumerate(outputs):
867
+ input_length = inputs.input_ids[j].shape[0]
868
+ response = self.tokenizer.decode(output[input_length:], skip_special_tokens=True)
869
+ batch_responses.append(response.strip())
870
+
871
+ all_responses.extend(batch_responses)
872
+
873
+ return all_responses
874
+
875
+ def _generate_with_steering_batched(
876
+ self, steering_instance: Any, questions: list[str], alpha: float, layer_id: int
877
+ ) -> list[str]:
878
+ """Generate responses with steering applied in batches using apply_steering()."""
879
+ if not questions:
880
+ return []
881
+
882
+ batch_size = self.config.batch_size
883
+ all_responses = []
884
+
885
+ # Process questions in batches
886
+ for i in tqdm(range(0, len(questions), batch_size), desc="Generating steered predictions", leave=False):
887
+ batch_questions = questions[i : i + batch_size]
888
+
889
+ # Batch tokenization with padding
890
+ inputs = self.tokenizer(
891
+ batch_questions, return_tensors="pt", padding=True, truncation=True, max_length=self.config.max_length
892
+ ).to(self.device)
893
+
894
+ def steering_hook(module, input, output):
895
+ """Hook that applies steering using the steering method's apply_steering()."""
896
+ hidden_states = output[0] if isinstance(output, tuple) else output
897
+
898
+ # Apply steering using the method's apply_steering() function
899
+ steered = steering_instance.apply_steering(hidden_states, strength=alpha)
900
+
901
+ if isinstance(output, tuple):
902
+ return (steered, *output[1:])
903
+ return steered
904
+
905
+ # Register hook on target layer
906
+ if hasattr(self.model, "transformer"):
907
+ if layer_id >= len(self.model.transformer.h):
908
+ raise ValueError(f"layer_id {layer_id} exceeds model layers")
909
+ target_layer = self.model.transformer.h[layer_id]
910
+ elif hasattr(self.model, "model"):
911
+ if layer_id >= len(self.model.model.layers):
912
+ raise ValueError(f"layer_id {layer_id} exceeds model layers")
913
+ target_layer = self.model.model.layers[layer_id]
914
+ else:
915
+ raise ValueError("Unknown model architecture")
916
+
917
+ handle = target_layer.register_forward_hook(steering_hook)
918
+
919
+ try:
920
+ with torch.no_grad():
921
+ outputs = self.model.generate(
922
+ **inputs,
923
+ max_new_tokens=self.config.max_new_tokens,
924
+ do_sample=self.config.do_sample,
925
+ temperature=self.config.temperature if self.config.do_sample else 1.0,
926
+ pad_token_id=self.tokenizer.eos_token_id,
927
+ eos_token_id=self.tokenizer.eos_token_id,
928
+ )
929
+
930
+ # Decode responses
931
+ batch_responses = []
932
+ for j, output in enumerate(outputs):
933
+ input_length = inputs.input_ids[j].shape[0]
934
+ response = self.tokenizer.decode(output[input_length:], skip_special_tokens=True)
935
+ batch_responses.append(response.strip())
936
+
937
+ all_responses.extend(batch_responses)
938
+
939
+ finally:
940
+ # Always remove the hook
941
+ handle.remove()
942
+
943
+ return all_responses
944
+
945
+ def _final_evaluation(self, best_trial: optuna.Trial) -> dict[str, Any]:
946
+ """Run final evaluation on test split with best configuration."""
947
+ self.logger.info("🏆 Running final evaluation with best configuration...")
948
+
949
+ # Extract best hyperparameters
950
+ # Handle both real trials and FixedTrial objects
951
+ if hasattr(best_trial, "params") and best_trial.params:
952
+ best_params = best_trial.params
953
+ elif hasattr(best_trial, "_params"):
954
+ best_params = best_trial._params
955
+ else:
956
+ # Fallback - this shouldn't happen
957
+ raise ValueError("Cannot access trial parameters")
958
+ layer_id = best_params["layer_id"]
959
+
960
+ self.logger.info(f"Best configuration: {best_params}")
961
+
962
+ # Re-train best probe and steering method on training data
963
+ from sklearn.linear_model import LogisticRegression
964
+
965
+ # Train best probe with fixed probe_c
966
+ X_train, y_train = self.cache.load_activations("train", layer_id, self.tokenization_config)
967
+ probe = LogisticRegression(C=1.0, random_state=self.config.seed, max_iter=1000) # Fixed probe_c
968
+ probe.fit(X_train, y_train)
969
+
970
+ # Train best steering method
971
+ steering_method = best_params.get("steering_method", "caa") # Default to CAA if missing
972
+ steering_instance = self._train_steering_method(best_trial, steering_method, layer_id, best_params)
973
+
974
+ # Save the best steering vector in both formats
975
+ if steering_instance and hasattr(steering_instance, "save_steering_vector"):
976
+ pt_path = self.run_dir / "best_steering_vector.pt"
977
+ safetensors_path = self.run_dir / "best_steering_vector.safetensors"
978
+ self._save_steering_vector_dual_format(steering_instance, pt_path, safetensors_path)
979
+
980
+ # Generate baseline predictions (no steering)
981
+ self.logger.info("Generating baseline predictions...")
982
+ baseline_predictions, test_ground_truths, test_questions, test_task_docs = self._generate_test_predictions(
983
+ None, None, layer_id, 0.0
984
+ )
985
+
986
+ # Generate steered predictions
987
+ self.logger.info("Generating steered predictions...")
988
+
989
+ # Extract the appropriate strength parameter based on method and available parameters
990
+ method_name = best_params.get("steering_method", "caa") # Default to CAA if missing
991
+ if method_name == "dac":
992
+ # DAC can use base_strength or steering_alpha, with fallback to 1.0
993
+ strength = best_params.get("base_strength", best_params.get("steering_alpha", 1.0))
994
+ else:
995
+ # CAA and other methods use steering_alpha
996
+ strength = best_params["steering_alpha"]
997
+
998
+ steered_predictions, _, _, _ = self._generate_test_predictions(
999
+ steering_instance, method_name, layer_id, strength
1000
+ )
1001
+
1002
+ # Save detailed test results to JSON
1003
+ if test_questions and test_ground_truths and baseline_predictions and steered_predictions:
1004
+ self._save_detailed_test_results(
1005
+ test_questions,
1006
+ test_ground_truths,
1007
+ baseline_predictions,
1008
+ steered_predictions,
1009
+ best_trial=best_trial,
1010
+ best_params=best_params,
1011
+ layer_id=layer_id,
1012
+ steering_method=method_name,
1013
+ )
1014
+
1015
+ # Calculate benchmark metrics (with real task docs for coding tasks)
1016
+ baseline_benchmark_metrics = metrics.evaluate_benchmark_performance(
1017
+ baseline_predictions, test_ground_truths, self.config.test_dataset, task_docs=test_task_docs
1018
+ )
1019
+ steered_benchmark_metrics = metrics.evaluate_benchmark_performance(
1020
+ steered_predictions, test_ground_truths, self.config.test_dataset, task_docs=test_task_docs
1021
+ )
1022
+
1023
+ # Evaluate probe on test data
1024
+ X_test, y_test = self.cache.load_activations("test", layer_id, self.tokenization_config)
1025
+ test_probe_metrics = self._evaluate_probe_metrics(probe, X_test, y_test)
1026
+
1027
+ # Calculate improvement
1028
+ accuracy_improvement = steered_benchmark_metrics.get("accuracy", 0.0) - baseline_benchmark_metrics.get(
1029
+ "accuracy", 0.0
1030
+ )
1031
+
1032
+ final_results = {
1033
+ "best_trial_params": best_params,
1034
+ "best_validation_score": getattr(best_trial, "value", None),
1035
+ "baseline_benchmark_metrics": baseline_benchmark_metrics,
1036
+ "steered_benchmark_metrics": steered_benchmark_metrics,
1037
+ "accuracy_improvement": accuracy_improvement,
1038
+ "test_probe_metrics": test_probe_metrics,
1039
+ "config": self.config.to_dict(),
1040
+ "num_test_samples": len(test_ground_truths),
1041
+ }
1042
+
1043
+ # Log final results
1044
+ self.logger.info("=" * 60)
1045
+ self.logger.info("🏆 FINAL TEST RESULTS")
1046
+ self.logger.info("=" * 60)
1047
+ self.logger.info(f"Baseline accuracy: {baseline_benchmark_metrics.get('accuracy', 0.0):.4f}")
1048
+ self.logger.info(f"Steered accuracy: {steered_benchmark_metrics.get('accuracy', 0.0):.4f}")
1049
+ self.logger.info(f"Improvement: {accuracy_improvement:+.4f}")
1050
+ self.logger.info(f"Probe AUC: {test_probe_metrics.get('auc', 0.5):.4f}")
1051
+ self.logger.info(f"Test samples: {len(test_ground_truths)}")
1052
+ self.logger.info("=" * 60)
1053
+
1054
+ return final_results
1055
+
1056
+    def _generate_test_predictions(
+        self, steering_instance: Any, method_name: str, layer_id: int, alpha: float
+    ) -> tuple[list[str], list[str], list[str], list[dict]]:
+        """Generate predictions on test data using batched generation."""
+        # Collect all questions and ground truths for batching
+        questions = []
+        ground_truths = []
+        valid_samples = []  # Keep track of samples that produce valid QA pairs
+
+        task = get_task(self.config.test_dataset)
+        extractor = task.get_extractor()
+
+        for sample in self.test_samples:
+            qa_pair = extractor.extract_qa_pair(sample, task)
+            if not qa_pair:
+                continue
+
+            question = qa_pair["formatted_question"]
+            ground_truth = qa_pair["correct_answer"]
+            questions.append(question)
+            ground_truths.append(ground_truth)
+            valid_samples.append(sample)  # Store the original sample
+
+        # Process all questions with appropriate batched method
+        if questions:
+            try:
+                if steering_instance is None:
+                    # Baseline generation - use batched method
+                    predictions = self._generate_baseline_batched(questions)
+                else:
+                    # Use unified batched generation with apply_steering()
+                    predictions = self._generate_with_steering_batched(steering_instance, questions, alpha, layer_id)
+
+                # Log sample predictions for debugging
+                for i, (pred, gt) in enumerate(zip(predictions[:3], ground_truths[:3])):
+                    self.logger.debug(f"Test Sample {i} - Model: ...{pred[-50:] if pred else 'None'}")
+                    self.logger.debug(f"Test Sample {i} - Ground truth: {gt}")
+
+            except Exception as e:
+                self.logger.warning(f"Batched generation failed for test: {e}")
+                predictions = ["Error"] * len(questions)
+        else:
+            predictions = []
+
+        return predictions, ground_truths, questions, valid_samples
+
+    def _evaluate_probe_metrics(self, probe, X_test: np.ndarray, y_test: np.ndarray) -> dict[str, float]:
+        """Evaluate probe metrics."""
+        from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
+
+        y_pred = probe.predict(X_test)
+        y_pred_proba = probe.predict_proba(X_test)[:, 1]
+
+        return {
+            "accuracy": accuracy_score(y_test, y_pred),
+            "precision": precision_score(y_test, y_pred, zero_division=0),
+            "recall": recall_score(y_test, y_pred, zero_division=0),
+            "f1": f1_score(y_test, y_pred, zero_division=0),
+            "auc": roc_auc_score(y_test, y_pred_proba) if len(np.unique(y_test)) > 1 else 0.5,
+        }
+
+    def _create_experiment_metadata(
+        self,
+        trial=None,
+        steering_method: Optional[str] = None,
+        layer_id: Optional[int] = None,
+        hyperparams: Optional[dict] = None,
+    ):
+        """Create comprehensive experiment metadata for detailed results."""
+        import platform
+        from datetime import datetime
+
+        metadata = {
+            "trial_info": {
+                "trial_number": trial.number if trial else None,
+                "trial_params": dict(trial.params) if trial else {},
+                "trial_state": str(getattr(trial, "state", "RUNNING")) if trial else None,
+            },
+            "model_config": {
+                "model_name": self.config.model_name,
+                "device": self.config.device,
+                "is_coding_task": self.is_coding_task,
+            },
+            "dataset_config": {
+                "train_dataset": self.config.train_dataset,
+                "val_dataset": self.config.val_dataset,
+                "test_dataset": self.config.test_dataset,
+                "train_limit": self.config.train_limit,
+                "val_limit": self.config.val_limit,
+                "test_limit": self.config.test_limit,
+                "contrastive_pairs_limit": self.config.contrastive_pairs_limit,
+            },
+            "steering_config": {
+                "steering_method": steering_method,
+                "layer_id": layer_id,
+                "hyperparams": hyperparams or {},
+                "layer_search_range": self.config.layer_search_range,
+                "probe_type": self.config.probe_type,
+                "available_steering_methods": self.config.steering_methods,
+            },
+            "optimization_config": {
+                "study_name": self.config.study_name,
+                "sampler": self.config.sampler,
+                "pruner": self.config.pruner,
+                "n_trials": self.config.n_trials,
+                "n_startup_trials": self.config.n_startup_trials,
+            },
+            "generation_config": {
+                "batch_size": self.config.batch_size,
+                "max_length": self.config.max_length,
+                "max_new_tokens": self.config.max_new_tokens,
+                "temperature": self.config.temperature,
+                "do_sample": self.config.do_sample,
+            },
+            "run_info": {
+                "timestamp": datetime.now().isoformat(),
+                "run_dir": str(self.run_dir),
+                "output_dir": self.config.output_dir,
+                "cache_dir": self.config.cache_dir,
+                "platform": platform.platform(),
+                "python_version": platform.python_version(),
+            },
+            "wandb_config": {
+                "use_wandb": self.config.use_wandb,
+                "wandb_project": self.config.wandb_project,
+            }
+            if hasattr(self.config, "use_wandb")
+            else {},
+        }
+
+        return metadata
+
+    def _save_detailed_validation_results(
+        self,
+        questions: list[str],
+        ground_truths: list[str],
+        predictions: list[str],
+        trial_number: int,
+        trial=None,
+        steering_method: Optional[str] = None,
+        layer_id: Optional[int] = None,
+        hyperparams: Optional[dict] = None,
+    ):
+        """Save detailed validation results to JSON file with experiment metadata."""
+        detailed_results = []
+
+        # For coding tasks, use the same BigCode evaluation as accuracy calculation
+        if self.is_coding_task:
+            # Use evaluate_benchmark_performance to get consistent BigCode evaluation
+            eval_results = metrics.evaluate_benchmark_performance(
+                predictions, ground_truths, task_name=self.config.val_dataset, task_docs=self.val_task_docs
+            )
+
+            # Extract individual correctness from evaluation details
+            eval_details = eval_results.get("evaluation_details", [])
+
+            for i, (question, correct_answer, model_answer) in enumerate(zip(questions, ground_truths, predictions)):
+                # Get correctness from BigCode evaluation if available
+                is_correct = eval_details[i]["is_correct"] if i < len(eval_details) else False
+
+                detailed_results.append(
+                    {
+                        "row": i,
+                        "question": question,
+                        "correct_answer": correct_answer,
+                        "model_answer": model_answer,
+                        "is_correct": is_correct,
+                        "evaluation_method": eval_results.get("evaluation_method", "unknown"),
+                        "extracted_code": eval_details[i].get("prediction", model_answer)
+                        if i < len(eval_details)
+                        else model_answer,
+                        "execution_error": eval_details[i].get("execution_error") if i < len(eval_details) else None,
+                    }
+                )
+        else:
+            # For non-coding tasks, process each result
+            for i, (question, correct_answer, model_answer) in enumerate(zip(questions, ground_truths, predictions)):
+                # Use standard evaluation via metrics module
+                is_correct = metrics.evaluate_response_correctness(
+                    model_answer, correct_answer, self.config.val_dataset
+                )
+
+                result_entry = {
+                    "row": i,
+                    "question": question,
+                    "correct_answer": correct_answer,
+                    "model_answer": model_answer,
+                    "is_correct": is_correct,
+                    "evaluation_method": "string_comparison",
+                }
+
+                # Add MC-specific fields if this is a multiple choice task
+                if self._should_use_multiple_choice_evaluation():
+                    # Extract MC diagnostics directly without custom evaluation
+                    import re
+
+                    # Extract available answers from question (A. choice, B. choice, etc.)
+                    available_answers = []
+                    choice_pattern = r"([A-E])\.\s+(.+?)(?=\n[A-E]\.|$)"
+                    matches = re.findall(choice_pattern, question, re.MULTILINE | re.DOTALL)
+                    for letter, choice_text in matches:
+                        available_answers.append(f"{letter}. {choice_text.strip()}")
+
+                    # Extract model's selected letter from model answer
+                    model_selected_letter = "?"
+                    model_letter_match = re.search(r"\b([A-E])\b", model_answer.upper())
+                    if model_letter_match:
+                        model_selected_letter = model_letter_match.group(1)
+
+                    result_entry["available_answers"] = available_answers
+                    result_entry["correct_choice_letter"] = correct_answer
+                    result_entry["model_selected_letter"] = model_selected_letter
+
+                detailed_results.append(result_entry)
+
+        # Create experiment metadata
+        experiment_metadata = self._create_experiment_metadata(trial, steering_method, layer_id, hyperparams)
+
+        # Create final results structure with metadata
+        final_results = {"experiment_metadata": experiment_metadata, "results": detailed_results}
+
+        # Save to JSON file
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        filename = f"validation_detailed_results_trial_{trial_number:03d}_{timestamp}.json"
+        filepath = self.run_dir / filename
+
+        with open(filepath, "w", encoding="utf-8") as f:
+            json.dump(final_results, f, indent=2, ensure_ascii=False)
+
+        self.logger.info(f"💾 Saved detailed validation results to: {filename}")
+
+    def _save_detailed_test_results(
+        self,
+        questions: list[str],
+        ground_truths: list[str],
+        baseline_predictions: list[str],
+        steered_predictions: list[str],
+        best_trial=None,
+        best_params: Optional[dict] = None,
+        layer_id: Optional[int] = None,
+        steering_method: Optional[str] = None,
+    ):
+        """Save detailed test results to JSON file with both baseline and steered answers and experiment metadata."""
+        detailed_results = []
+
+        # For coding tasks, use BigCode evaluation consistently
+        if self.is_coding_task:
+            # Evaluate baseline predictions with BigCode
+            baseline_eval_results = metrics.evaluate_benchmark_performance(
+                baseline_predictions, ground_truths, task_name=self.config.test_dataset, task_docs=self.test_task_docs
+            )
+
+            # Evaluate steered predictions with BigCode
+            steered_eval_results = metrics.evaluate_benchmark_performance(
+                steered_predictions, ground_truths, task_name=self.config.test_dataset, task_docs=self.test_task_docs
+            )
+
+            baseline_details = baseline_eval_results.get("evaluation_details", [])
+            steered_details = steered_eval_results.get("evaluation_details", [])
+
+            for i, (question, correct_answer, baseline_answer, steered_answer) in enumerate(
+                zip(questions, ground_truths, baseline_predictions, steered_predictions)
+            ):
+                # Get correctness from BigCode evaluation
+                is_baseline_correct = baseline_details[i]["is_correct"] if i < len(baseline_details) else False
+                is_correct = steered_details[i]["is_correct"] if i < len(steered_details) else False
+
+                detailed_results.append(
+                    {
+                        "row": i,
+                        "question": question,
+                        "correct_answer": correct_answer,
+                        "baseline_model_answer": baseline_answer,
+                        "model_answer": steered_answer,
+                        "is_baseline_correct": is_baseline_correct,
+                        "is_correct": is_correct,
+                        "evaluation_method": steered_eval_results.get("evaluation_method", "bigcode_execution"),
+                        "baseline_extracted_code": baseline_details[i].get("prediction", baseline_answer)
+                        if i < len(baseline_details)
+                        else baseline_answer,
+                        "steered_extracted_code": steered_details[i].get("prediction", steered_answer)
+                        if i < len(steered_details)
+                        else steered_answer,
+                        "baseline_execution_error": baseline_details[i].get("execution_error")
+                        if i < len(baseline_details)
+                        else None,
+                        "steered_execution_error": steered_details[i].get("execution_error")
+                        if i < len(steered_details)
+                        else None,
+                    }
+                )
+        else:
+            # For non-coding tasks, process each result
+            for i, (question, correct_answer, baseline_answer, steered_answer) in enumerate(
+                zip(questions, ground_truths, baseline_predictions, steered_predictions)
+            ):
+                # Use standard evaluation for both baseline and steered answers
+                is_baseline_correct = metrics.evaluate_response_correctness(
+                    baseline_answer, correct_answer, self.config.test_dataset
+                )
+                is_correct = metrics.evaluate_response_correctness(
+                    steered_answer, correct_answer, self.config.test_dataset
+                )
+
+                result_entry = {
+                    "row": i,
+                    "question": question,
+                    "correct_answer": correct_answer,
+                    "baseline_model_answer": baseline_answer,
+                    "model_answer": steered_answer,
+                    "is_baseline_correct": is_baseline_correct,
+                    "is_correct": is_correct,
+                    "evaluation_method": "string_comparison",
+                }
+
+                # Add MC-specific fields if this is a multiple choice task
+                if self._should_use_multiple_choice_evaluation():
+                    # Extract MC diagnostics directly without custom evaluation
+                    import re
+
+                    # Extract available answers from question (A. choice, B. choice, etc.)
+                    available_answers = []
+                    choice_pattern = r"([A-E])\.\s+(.+?)(?=\n[A-E]\.|$)"
+                    matches = re.findall(choice_pattern, question, re.MULTILINE | re.DOTALL)
+                    for letter, choice_text in matches:
+                        available_answers.append(f"{letter}. {choice_text.strip()}")
+
+                    # Extract steered model's selected letter
+                    steered_selected_letter = "?"
+                    steered_letter_match = re.search(r"\b([A-E])\b", steered_answer.upper())
+                    if steered_letter_match:
+                        steered_selected_letter = steered_letter_match.group(1)
+
+                    # Extract baseline model's selected letter
+                    baseline_selected_letter = "?"
+                    baseline_letter_match = re.search(r"\b([A-E])\b", baseline_answer.upper())
+                    if baseline_letter_match:
+                        baseline_selected_letter = baseline_letter_match.group(1)
+
+                    result_entry["available_answers"] = available_answers
+                    result_entry["correct_choice_letter"] = correct_answer
+                    result_entry["model_selected_letter"] = steered_selected_letter
+                    result_entry["baseline_model_selected_letter"] = baseline_selected_letter
+
+                detailed_results.append(result_entry)
+
+        # Create experiment metadata for test results
+        experiment_metadata = self._create_experiment_metadata(
+            trial=best_trial,
+            steering_method=steering_method or (best_params.get("steering_method") if best_params else None),
+            layer_id=layer_id,
+            hyperparams=best_params,
+        )
+
+        # Create final results structure with metadata
+        final_results = {"experiment_metadata": experiment_metadata, "results": detailed_results}
+
+        # Save to JSON file
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        filename = f"test_detailed_results_{timestamp}.json"
+        filepath = self.run_dir / filename
+
+        with open(filepath, "w", encoding="utf-8") as f:
+            json.dump(final_results, f, indent=2, ensure_ascii=False)
+
+        self.logger.info(f"💾 Saved detailed test results to: {filename}")
+        return filename
+
+    def _save_reproducibility_bundle(self, study: optuna.Study, final_results: dict[str, Any]):
+        """Save complete reproducibility bundle."""
+
+        # Save Optuna study
+        study_path = self.run_dir / f"optuna_study_{self.run_timestamp}.db"
+        study.study_name = str(study_path)
+
+        # Save configuration
+        config_path = self.run_dir / f"config_{self.run_timestamp}.json"
+        with open(config_path, "w") as f:
+            json.dump(self.config.to_dict(), f, indent=2)
+
+        # Save final results
+        results_path = self.run_dir / f"final_results_{self.run_timestamp}.json"
+        with open(results_path, "w") as f:
+            json.dump(final_results, f, indent=2, default=str)
+
+        # Save best configuration
+        best_config = {
+            "best_params": study.best_trial.params,
+            "best_value": study.best_trial.value,
+            "model_name": self.config.model_name,
+            "random_seed": self.config.seed,
+            "commit_hash": self._get_git_commit_hash(),
+            "timestamp": self.run_timestamp,
+        }
+
+        best_config_path = self.run_dir / f"best_configuration_{self.run_timestamp}.json"
+        with open(best_config_path, "w") as f:
+            json.dump(best_config, f, indent=2)
+
+        # Save study trials summary
+        trials_df = study.trials_dataframe()
+        trials_path = self.run_dir / f"study_trials_{self.run_timestamp}.csv"
+        trials_df.to_csv(trials_path, index=False)
+
+        self.logger.info(f"💾 Reproducibility bundle saved to: {self.run_dir}")
+        self.logger.info(f"📊 Study database: {study_path}")
+        self.logger.info(f"⚙️ Configuration: {config_path}")
+        self.logger.info(f"🏆 Results: {results_path}")
+        self.logger.info(f"🎯 Best config: {best_config_path}")
+
+        # Log steering vector if it exists (prefer safetensors format)
+        safetensors_path = self.run_dir / "best_steering_vector.safetensors"
+        pt_path = self.run_dir / "best_steering_vector.pt"
+
+        if safetensors_path.exists():
+            self.logger.info(f"🧭 Steering vector: {safetensors_path.name}")
+        elif pt_path.exists():
+            self.logger.info(f"🧭 Steering vector: {pt_path.name}")
+
+    def _get_git_commit_hash(self) -> Optional[str]:
+        """Get current git commit hash for reproducibility."""
+        try:
+            import subprocess
+
+            result = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True)
+            if result.returncode == 0:
+                return result.stdout.strip()
+        except Exception:
+            pass
+        return None
+
+    def evaluate_only(self, best_params: dict[str, Any]) -> dict[str, Any]:
+        """Run evaluation only with provided parameters.
+
+        Args:
+            best_params: Dictionary of hyperparameters to use for evaluation
+
+        Returns:
+            Dictionary containing evaluation results
+        """
+        self.logger.info("🔬 Running evaluation-only mode with provided parameters")
+        self.logger.info(f"Parameters: {best_params}")
+
+        # Setup experiment if not already done
+        if self.model is None:
+            self._setup_experiment()
+
+        # Create timestamped run directory for evaluation-only mode
+        if not hasattr(self, "run_dir"):
+            self.run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            self.run_dir = self.output_dir / f"evaluate_only_{self.run_timestamp}"
+            self.run_dir.mkdir(parents=True, exist_ok=True)
+            self.logger.info(f"📁 Evaluation directory: {self.run_dir}")
+
+        # Create a complete mock trial with all expected parameters
+        from optuna.trial import FixedTrial
+
+        # Ensure we have all required parameters for _final_evaluation
+        complete_params = {
+            "layer_id": best_params.get("layer_id", 15),
+            "probe_type": best_params.get("probe_type", "logistic_regression"),
+            "probe_c": best_params.get("probe_c", 1.0),
+            "steering_method": best_params.get("steering_method", "caa"),
+            "steering_alpha": best_params.get("steering_alpha", 0.5),
+        }
+
+        # Add method-specific parameters if needed
+        if complete_params["steering_method"] == "dac":
+            complete_params.update(
+                {
+                    "entropy_threshold": best_params.get("entropy_threshold", 1.5),
+                    "ptop": best_params.get("ptop", 0.5),
+                    "max_alpha": best_params.get("max_alpha", 2.0),
+                }
+            )
+
+        fixed_trial = FixedTrial(complete_params)
+
+        # Fix FixedTrial params access issue
+        if not hasattr(fixed_trial, "params"):
+            fixed_trial.params = complete_params
+
+        # Run final evaluation
+        return self._final_evaluation(fixed_trial)
+
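A minimal usage sketch of `evaluate_only()`, since the docstring above describes it as the evaluation-only entry point. The import path is an assumption (this hunk does not show the module's location), and the config keyword and hyperparameter values are illustrative, mirroring the defaults the method itself falls back to:

```python
# Illustrative sketch only; the import path is assumed, not taken from this diff.
from wisent.optuna_pipeline import OptimizationConfig, OptimizationPipeline  # hypothetical path

config = OptimizationConfig(test_limit=50)
pipeline = OptimizationPipeline(config)
try:
    # Skip the Optuna search and run the final test evaluation with fixed hyperparameters.
    results = pipeline.evaluate_only(
        {"layer_id": 15, "steering_method": "caa", "steering_alpha": 0.5}
    )
    print(f"Steered accuracy: {results['steered_benchmark_metrics']['accuracy']:.4f}")
finally:
    pipeline.cleanup_memory()
```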
+    @classmethod
+    def from_saved_study(
+        cls, study_path: str, config_path: Optional[str] = None, override_config: Optional[dict[str, Any]] = None
+    ):
+        """Create pipeline from saved study and optionally saved config.
+
+        Args:
+            study_path: Path to the SQLite study database
+            config_path: Optional path to saved configuration JSON
+            override_config: Optional dict of config values to override
+
+        Returns:
+            Tuple of (pipeline, study) ready for evaluation
+        """
+        # Load config if provided
+        if config_path:
+            with open(config_path) as f:
+                config_dict = json.load(f)
+            # Apply any overrides
+            if override_config:
+                config_dict.update(override_config)
+            config = OptimizationConfig(**config_dict)
+        else:
+            # Create minimal config with overrides
+            config = OptimizationConfig(**(override_config or {}))
+
+        # Load study
+        from pathlib import Path
+
+        study_name = Path(study_path).stem
+        study = optuna.load_study(study_name=study_name, storage=f"sqlite:///{study_path}")
+
+        pipeline = cls(config)
+        return pipeline, study
+
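A sketch of the resume-and-evaluate workflow that `from_saved_study()` enables. The import path and file paths are placeholders; the file names merely follow the pattern written by `_save_reproducibility_bundle()` above:

```python
# Illustrative sketch only; module path and file paths are assumed placeholders.
from wisent.optuna_pipeline import OptimizationPipeline  # hypothetical path

pipeline, study = OptimizationPipeline.from_saved_study(
    study_path="results/optuna_study_20240101_120000.db",
    config_path="results/config_20240101_120000.json",
    override_config={"test_limit": 100},
)
# Re-run the final test evaluation with the best hyperparameters from the saved study.
results = pipeline.evaluate_only(study.best_params)
```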
+    def evaluate_on_dataset(
+        self, best_params: dict[str, Any], dataset_name: str, dataset_limit: Optional[int] = None
+    ) -> dict[str, Any]:
+        """Evaluate best parameters on a different dataset.
+
+        Args:
+            best_params: Dictionary of hyperparameters to use
+            dataset_name: Name of dataset to evaluate on
+            dataset_limit: Optional limit on number of samples
+
+        Returns:
+            Dictionary containing evaluation results on the new dataset
+        """
+        # Temporarily override dataset configuration
+        original_test_dataset = self.config.test_dataset
+        original_test_limit = self.config.test_limit
+
+        self.config.test_dataset = dataset_name
+        self.config.test_limit = dataset_limit or self.config.test_limit
+
+        self.logger.info(f"📊 Evaluating on {dataset_name} with {self.config.test_limit} samples")
+
+        # Reload test samples for new dataset
+        from . import data_utils
+
+        self.test_samples = data_utils.load_dataset_samples(self.config.test_dataset, self.config.test_limit)
+
+        # Run evaluation
+        results = self.evaluate_only(best_params)
+
+        # Restore original config
+        self.config.test_dataset = original_test_dataset
+        self.config.test_limit = original_test_limit
+
+        return results
+
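Continuing the sketch above, `evaluate_on_dataset()` can be used to check how a tuned configuration transfers to another benchmark. The hyperparameter values remain illustrative; the dataset name is one of the multiple-choice tasks the pipeline recognizes elsewhere in this file:

```python
# Illustrative sketch only; parameter values are placeholders.
best_params = {"layer_id": 15, "steering_method": "caa", "steering_alpha": 0.5}

# Temporarily points config.test_dataset at the new benchmark, reloads test samples,
# runs evaluate_only(), then restores the original test configuration.
transfer_results = pipeline.evaluate_on_dataset(best_params, "truthfulqa_mc1", dataset_limit=100)
print(f"Transfer improvement: {transfer_results['accuracy_improvement']:+.4f}")
```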
+    def cleanup_memory(self):
+        """Clean up GPU/MPS memory."""
+        if hasattr(self, "model") and self.model is not None:
+            del self.model
+            self.model = None
+        if hasattr(self, "tokenizer") and self.tokenizer is not None:
+            del self.tokenizer
+            self.tokenizer = None
+
+        # Finish WandB run
+        if self.wandb_run is not None:
+            wandb.finish()
+            self.wandb_run = None
+
+        # Clean up device memory
+        empty_device_cache(self.device.type)
+
+        import gc
+
+        gc.collect()
+
+    def _init_wandb(self):
+        """Initialize WandB for experiment tracking."""
+        try:
+            self.wandb_run = wandb.init(
+                project=self.config.wandb_project,
+                name=f"{self.config.study_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
+                config=self.config.to_dict(),
+                tags=["optuna", "steering", "optimization"],
+                reinit=True,
+            )
+            self.logger.info(f"WandB initialized: {wandb.run.url}")
+        except Exception as e:
+            # Don't silently disable - user explicitly requested WandB
+            raise RuntimeError(
+                f"Failed to initialize WandB: {e}\n"
+                f"Possible solutions:\n"
+                f"1. Run 'wandb login' to authenticate\n"
+                f"2. Check your internet connection\n"
+                f"3. Verify project name: {self.config.wandb_project}\n"
+                f"4. Set use_wandb=False to disable WandB"
+            ) from e
+
+    def _log_trial_to_wandb(self, trial: optuna.Trial, metrics: dict[str, float]):
+        """Log trial results to WandB."""
+        if not self.config.use_wandb or self.wandb_run is None:
+            return
+
+        try:
+            # Log trial parameters and metrics
+            log_data = {f"trial/{k}": v for k, v in trial.params.items()}
+            log_data.update({f"metrics/{k}": v for k, v in metrics.items()})
+            log_data["trial/number"] = trial.number
+
+            wandb.log(log_data)
+        except Exception as e:
+            self.logger.warning(f"Failed to log trial to WandB: {e}")
+
+    def _log_final_results_to_wandb(self, study: optuna.Study, final_results: dict[str, Any]):
+        """Log final optimization results to WandB."""
+        if not self.config.use_wandb or self.wandb_run is None:
+            return
+
+        try:
+            # Log best trial results
+            best_params = {f"best/{k}": v for k, v in study.best_params.items()}
+            best_metrics = {
+                "best/validation_accuracy": study.best_value,
+                "best/baseline_accuracy": final_results["baseline_benchmark_metrics"]["accuracy"],
+                "best/steered_accuracy": final_results["steered_benchmark_metrics"]["accuracy"],
+                "best/accuracy_improvement": final_results["accuracy_improvement"],
+                "study/n_trials": len(study.trials),
+                "study/n_complete_trials": len(
+                    [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
+                ),
+            }
+
+            wandb.log({**best_params, **best_metrics})
+
+            # Log optimization history
+            trial_values = [t.value for t in study.trials if t.value is not None]
+            if trial_values:
+                wandb.log(
+                    {
+                        "optimization/best_value_so_far": max(trial_values),
+                        "optimization/mean_trial_value": np.mean(trial_values),
+                        "optimization/std_trial_value": np.std(trial_values),
+                    }
+                )
+
+        except Exception as e:
+            self.logger.warning(f"Failed to log final results to WandB: {e}")
+
+    def _should_use_multiple_choice_evaluation(self) -> bool:
+        """Determine if we should use multiple choice evaluation for this dataset."""
+        # Use multiple choice evaluation for TruthfulQA and other MC tasks
+        return self.config.test_dataset.lower() in ["truthfulqa_mc1", "truthfulqa", "mmlu"]
+
+
+ def main():
+     """Main entry point for optimization pipeline."""
+     # Setup logging
+     logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+
+     # Create configuration
+     config = OptimizationConfig(
+         train_limit=100,
+         contrastive_pairs_limit=30,  # Bounded by train_limit
+         val_limit=50,
+         test_limit=50,
+         n_trials=20,
+         layer_search_range=(10, 15),
+     )
+
+     # Run optimization
+     pipeline = OptimizationPipeline(config)
+     try:
+         results = pipeline.run_optimization()
+
+         print("🎉 Optimization completed!")
+         print(f"Best validation score: {results['best_validation_score']:.4f}")
+         print(f"Test accuracy: {results['steered_benchmark_metrics']['accuracy']:.4f}")
+         print(f"Accuracy improvement: {results['accuracy_improvement']:+.4f}")
+
+     finally:
+         # Clean up memory
+         pipeline.cleanup_memory()
+
+
+ if __name__ == "__main__":
+     main()