wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.1.dist-info/METADATA +67 -0
  215. wisent-0.5.1.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
wisent/core/optuna/steering/steering_optimization.py
@@ -0,0 +1,1111 @@
+"""
+Steering optimization module for improving benchmark performance.
+
+This module handles training and optimizing different steering methods that can
+improve model performance on benchmarks by steering internal activations.
+"""
+
+import logging
+import traceback
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from tqdm import tqdm
+
+from wisent_guard.core.activations.core import ActivationAggregationStrategy
+from wisent_guard.core.classifier.classifier import Classifier
+from wisent_guard.core.contrastive_pairs.contrastive_pair import ContrastivePair
+from wisent_guard.core.contrastive_pairs.contrastive_pair_set import ContrastivePairSet
+from wisent_guard.core.optuna.classifier import (
+    CacheConfig,
+    ClassifierCache,
+    ClassifierOptimizationConfig,
+    GenerationConfig,
+    OptunaClassifierOptimizer,
+)
+from wisent_guard.core.optuna.steering import data_utils, metrics
+from wisent_guard.core.response import Response
+from wisent_guard.core.steering_methods.dac import DAC
+from wisent_guard.core.task_interface import get_task
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SteeringMethodConfig(ABC):
+    """Base configuration for steering methods."""
+
+    method_name: str = "base"
+    layers: List[int] = None
+    strengths: List[float] = None
+
+    def __post_init__(self):
+        if self.layers is None:
+            self.layers = []
+        if self.strengths is None:
+            self.strengths = [1.0]
+
+
+@dataclass
+class DACConfig(SteeringMethodConfig):
+    """Configuration for DAC (Dynamic Activation Composition) steering method."""
+
+    method_name: str = "dac"
+    entropy_thresholds: List[float] = None
+    ptop_values: List[float] = None
+    max_alpha_values: List[float] = None
+
+    def __post_init__(self):
+        super().__post_init__()
+        if self.entropy_thresholds is None:
+            self.entropy_thresholds = [1.0]
+        if self.ptop_values is None:
+            self.ptop_values = [0.4]
+        if self.max_alpha_values is None:
+            self.max_alpha_values = [2.0]
+
+
+@dataclass
+class SteeringResult:
+    """Results from training and evaluating a steering method configuration."""
+
+    method_name: str
+    layer: int
+    hyperparameters: Dict[str, Any]
+    benchmark_metrics: Dict[str, float]
+    training_success: bool
+    training_stats: Dict[str, Any] = None
+    baseline_metrics: Dict[str, float] = None
+    comparative_metrics: Dict[str, Any] = None
+
+
+class SteeringMethodTrainer(ABC):
+    """Abstract base class for training different steering methods."""
+
+    @abstractmethod
+    def create_method_instance(self, hyperparams: Dict[str, Any], device: str) -> Any:
+        """Create an instance of the steering method with given hyperparameters."""
+
+    @abstractmethod
+    def train_method(
+        self,
+        method_instance: Any,
+        train_samples: List[Dict],
+        layer: int,
+        model,
+        tokenizer,
+        device: str,
+        task_name: str = "gsm8k",
+        max_new_tokens: int = 200,
+    ) -> Tuple[bool, Dict[str, Any]]:
+        """Train the steering method on training data."""
+
+    @abstractmethod
+    def apply_steering_and_evaluate(
+        self,
+        method_instance: Any,
+        evaluation_samples: List[Dict],
+        layer: int,
+        strength: float,
+        model,
+        tokenizer,
+        device: str,
+        batch_size: int,
+        max_length: int,
+        task_name: str = "gsm8k",
+        max_new_tokens: int = 200,
+    ) -> Tuple[List[str], List[str]]:
+        """Apply steering and generate predictions for evaluation."""
+
+
+class DACTrainer(SteeringMethodTrainer):
+    """Trainer for DAC (Dynamic Activation Composition) steering method."""
+
+    def create_method_instance(self, hyperparams: Dict[str, Any], device: str) -> DAC:
+        """Create DAC instance with specified hyperparameters."""
+        return DAC(
+            device=device,
+            dynamic_control=True,
+            entropy_threshold=hyperparams.get("entropy_threshold", 1.0),
+            ptop=hyperparams.get("ptop", 0.4),
+            max_alpha=hyperparams.get("max_alpha", 2.0),
+        )
+
+    def train_method(
+        self,
+        dac_instance: DAC,
+        train_samples: List[Dict],
+        layer: int,
+        model,
+        tokenizer,
+        device: str,
+        task_name: str = "gsm8k",
+        max_new_tokens: int = 200,
+    ) -> Tuple[bool, Dict[str, Any]]:
+        """Train DAC on training data to create steering vectors."""
+        try:
+            # Set model reference for KL computation
+            dac_instance.set_model_reference(model)
+
+            # Extract contrastive pairs from training data using task's extractor
+            contrastive_pairs = data_utils.get_task_contrastive_pairs(train_samples, task_name)
+
+            if not contrastive_pairs:
+                logger.warning(f"No contrastive pairs extracted from {task_name} training data")
+                return False, {"error": "No contrastive pairs"}
+
+            # Convert to ContrastivePairSet format
+            pair_set = self._create_pair_set_from_extracted_pairs(contrastive_pairs, layer, model, tokenizer, device)
+
+            # Train DAC
+            training_result = dac_instance.train(pair_set, layer)
+
+            success = training_result.get("success", False)
+            logger.debug(f"DAC training on layer {layer}: {'Success' if success else 'Failed'}")
+
+            return success, training_result
+
+        except Exception as e:
+            logger.error(f"DAC training failed on layer {layer}: {e}")
+            return False, {"error": str(e)}
+
+    def apply_steering_and_evaluate(
+        self,
+        dac_instance: DAC,
+        evaluation_samples: List[Dict],
+        layer: int,
+        strength: float,
+        model,
+        tokenizer,
+        device: str,
+        batch_size: int,
+        max_length: int,
+        task_name: str = "gsm8k",
+        max_new_tokens: int = 200,
+    ) -> Tuple[List[str], List[str]]:
+        """Apply DAC steering and generate predictions using task extractor."""
+
+        predictions = []
+        ground_truths = []
+
+        # Get the task and its extractor
+        task = get_task(task_name)
+        extractor = task.get_extractor()
+
+        # Pre-extract all questions and answers (optimization)
+        questions = []
+        answers = []
+
+        for sample in evaluation_samples:
+            qa_pair = extractor.extract_qa_pair(sample, task)
+            if not qa_pair:
+                logger.warning(f"Skipping sample - extractor couldn't extract QA pair: {sample.keys()}")
+                continue
+            questions.append(qa_pair["formatted_question"])
+            answers.append(qa_pair["correct_answer"])
+
+        # Process questions with steering in batches (optimized approach)
+        ground_truths.extend(answers)
+
+        # Handle different model architectures
+        if hasattr(model, "model") and hasattr(model.model, "layers"):
+            # LLaMA-style models
+            layer_module = model.model.layers[layer]
+        elif hasattr(model, "transformer") and hasattr(model.transformer, "h"):
+            # GPT2-style models
+            layer_module = model.transformer.h[layer]
+        else:
+            raise ValueError("Unsupported model architecture for DAC steering")
+
+        # Process in batches with steering
+        for i in tqdm(range(0, len(questions), batch_size), desc="Generating predictions with steering"):
+            batch_questions = questions[i : i + batch_size]
+
+            # First, get actual lengths (before padding) for proper steering
+            actual_lengths = []
+            for question in batch_questions:
+                tokens = tokenizer(question, return_tensors="pt")
+                actual_lengths.append(tokens["input_ids"].shape[1])
+
+            # Create batched steering hook that handles variable lengths
+            def create_batched_steering_hook(actual_lengths):
+                def steering_hook(module, input, output):
+                    hidden_states = output[0]  # [batch_size, seq_len, hidden_dim]
+
+                    # Apply steering to each sample's actual last token
+                    for j, actual_length in enumerate(actual_lengths):
+                        if j < hidden_states.shape[0]:  # Safety check for batch size
+                            # Get the actual last token (before padding)
+                            last_token = hidden_states[j : j + 1, actual_length - 1 : actual_length, :]
+                            steered = dac_instance.apply_steering(last_token, strength=strength)
+                            hidden_states[j : j + 1, actual_length - 1 : actual_length, :] = steered
+
+                    return (hidden_states,) + output[1:]
+
+                return steering_hook
+
+            # Register the batched hook
+            batched_hook = create_batched_steering_hook(actual_lengths)
+            handle = layer_module.register_forward_hook(batched_hook)
+
+            try:
+                # Tokenize batch with padding for generation
+                inputs = tokenizer(
+                    batch_questions, return_tensors="pt", padding=True, truncation=True, max_length=max_length
+                ).to(device)
+
+                with torch.no_grad():
+                    outputs = model.generate(
+                        **inputs,
+                        max_new_tokens=max_new_tokens,
+                        do_sample=True,
+                        temperature=0.7,
+                        pad_token_id=tokenizer.eos_token_id,
+                        use_cache=False,  # Disable cache to avoid cache_position errors
+                    )
+
+                # Decode responses for each item in batch
+                for j, (output, question) in enumerate(zip(outputs, batch_questions)):
+                    response = tokenizer.decode(output, skip_special_tokens=True)
+                    prediction = response[len(question) :].strip()
+                    predictions.append(prediction)
+
+            finally:
+                handle.remove()
+
+        return predictions, ground_truths
+
+    def _create_pair_set_from_extracted_pairs(
+        self, extracted_pairs: List[Dict], layer_index: int, model, tokenizer, device: str
+    ) -> ContrastivePairSet:
+        """Convert extracted pairs to ContrastivePairSet format with proper activation extraction."""
+        pair_set = ContrastivePairSet(name="dac_training", task_type="mathematical_reasoning")
+
+        logger.info(f"Creating {len(extracted_pairs)} contrastive pairs for layer {layer_index}")
+
+        for pair_data in tqdm(extracted_pairs, desc="Creating contrastive pairs"):
+            # Extract data from GSM8K format
+            try:
+                question = pair_data["question"]
+                correct_answer = pair_data["correct_answer"]
+                incorrect_answer = pair_data["incorrect_answer"]
+
+                # Extract activations for correct and incorrect responses
+                correct_activations = self._extract_activations_for_text(
+                    f"{question} {correct_answer}", layer_index, model, tokenizer, device
+                )
+                incorrect_activations = self._extract_activations_for_text(
+                    f"{question} {incorrect_answer}", layer_index, model, tokenizer, device
+                )
+
+                # Create Response objects
+                positive_response = Response(text=correct_answer, activations=correct_activations)
+                negative_response = Response(text=incorrect_answer, activations=incorrect_activations)
+
+                # Create ContrastivePair
+                contrastive_pair = ContrastivePair(
+                    prompt=question, positive_response=positive_response, negative_response=negative_response
+                )
+
+                pair_set.pairs.append(contrastive_pair)
+
+            except Exception as e:
+                logger.warning(f"Failed to create contrastive pair: {e}")
+                continue
+
+        logger.info(f"Successfully created ContrastivePairSet with {len(pair_set.pairs)} pairs")
+        return pair_set
+
+    def _extract_activations_for_text(self, text: str, layer_index: int, model, tokenizer, device: str) -> torch.Tensor:
+        """Extract activations from a specific layer for given text."""
+        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128).to(device)
+
+        activations = []
+
+        def hook(module, input, output):
+            # Extract the last token's activations
+            hidden_states = output[0]
+            last_token_activations = hidden_states[:, -1, :]
+            activations.append(last_token_activations.detach().cpu())
+
+        # Handle different model architectures
+        if hasattr(model, "model") and hasattr(model.model, "layers"):
+            # LLaMA-style models
+            layer_module = model.model.layers[layer_index]
+        elif hasattr(model, "transformer") and hasattr(model.transformer, "h"):
+            # GPT2-style models
+            layer_module = model.transformer.h[layer_index]
+        else:
+            raise ValueError("Unsupported model architecture for activation extraction")
+
+        handle = layer_module.register_forward_hook(hook)
+
+        with torch.no_grad():
+            model(**inputs)
+
+        handle.remove()
+        return activations[0].squeeze(0)
+
+
+class SteeringOptimizer:
+    """
+    Optimizes steering methods for improving benchmark performance.
+
+    The steering optimization process:
+    1. Train steering methods on training data
+    2. Evaluate steering performance on validation data using benchmark metrics
+    3. Select best configuration based on benchmark performance
+    4. Test final steering method on test data
+    """
+
+    def __init__(self, cache_config: Optional[CacheConfig] = None):
+        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
+        self.trainers = {"dac": DACTrainer()}
+
+        # Initialize classifier cache for reusing trained classifiers
+        if cache_config is None:
+            cache_config = CacheConfig(cache_dir="./steering_classifier_cache")
+        self.classifier_cache = ClassifierCache(cache_config)
+
+        # Session-level classifier caching for current optimization run
+        self._session_classifier = None  # Best classifier for current session
+        self._session_classifier_metadata = {}  # Layer, model_type, performance, etc.
+        self._session_cache_key = None  # Track current session
+
+    def register_trainer(self, method_name: str, trainer: SteeringMethodTrainer):
+        """Register a new steering method trainer."""
+        self.trainers[method_name] = trainer
+        self.logger.info(f"Registered trainer for steering method: {method_name}")
+
+    def optimize_steering_hyperparameters(
+        self,
+        config: SteeringMethodConfig,
+        classifier_optimization_config: ClassifierOptimizationConfig,
+        train_samples: List[Dict],
+        validation_samples: List[Dict],
+        model,
+        tokenizer,
+        device: str,
+        batch_size: int = 32,
+        max_length: int = 512,
+        task_name: str = "gsm8k",
+        max_new_tokens: int = 200,
+    ) -> Tuple[Dict[str, Any], List[SteeringResult]]:
+        """
+        Optimize hyperparameters for a steering method using grid search.
+
+        Args:
+            config: Steering method configuration with hyperparameter ranges
+            classifier_optimization_config: Configuration for classifier optimization
+            train_samples: Training samples for method training
+            validation_samples: Validation samples for evaluation
+            model: Language model
+            tokenizer: Model tokenizer
+            device: Device to run on
+            batch_size: Batch size for processing
+            max_length: Maximum sequence length
+            task_name: Task name for evaluation
+            max_new_tokens: Maximum tokens to generate
+
+        Returns:
+            Tuple of (best_config, all_results)
+        """
+        method_name = config.method_name
+
+        if method_name not in self.trainers:
+            raise ValueError(f"No trainer registered for method: {method_name}")
+
+        trainer = self.trainers[method_name]
+
+        # Load best classifier once at the start of optimization
+        self.logger.info("Loading/training classifier for evaluation...")
+        contrastive_pairs = data_utils.get_task_contrastive_pairs(train_samples, task_name)
+
+        classifier = self.load_or_find_best_classifier(
+            model=model, optimization_config=classifier_optimization_config, contrastive_pairs=contrastive_pairs
+        )
+
+        if classifier is None:
+            raise ValueError(
+                f"Could not load or train classifier for {classifier_optimization_config.model_name}/{task_name}"
+            )
+
+        self.logger.info(f"Using classifier: {self._session_classifier_metadata}")
+
+        # Collect baseline predictions once for all trials
+        self.logger.info("Collecting baseline predictions for comparison...")
+        baseline_predictions, ground_truths = self.collect_baseline_predictions(
+            validation_samples, model, tokenizer, classifier, device, batch_size, max_length, task_name, max_new_tokens
+        )
+
+        # Calculate baseline metrics with integrated classifier scoring
+        classifier_scorer = lambda predictions, description: self.score_predictions_with_classifier(
+            predictions, model, tokenizer, device, max_length, description
+        )
+        baseline_benchmark_metrics = metrics.evaluate_benchmark_performance(
+            baseline_predictions, ground_truths, task_name, classifier_scorer=classifier_scorer
+        )
+        self.logger.info(f"Baseline performance: {baseline_benchmark_metrics}")
+
+        # Generate all hyperparameter combinations
+        hyperparameter_combinations = self._generate_hyperparameter_combinations(config)
+
+        self.logger.info(f"Starting {method_name} optimization with {len(hyperparameter_combinations)} configurations")
+
+        best_config = None
+        best_score = -1
+        all_results = []
+
+        for i, (layer, strength, hyperparams) in enumerate(
+            tqdm(hyperparameter_combinations, desc="Optimizing steering hyperparameters")
+        ):
+            self.logger.debug(
+                f"Testing {method_name} config {i + 1}/{len(hyperparameter_combinations)}: "
+                f"layer={layer}, strength={strength}, hyperparams={hyperparams}"
+            )
+
+            try:
+                # Create method instance
+                method_instance = trainer.create_method_instance(hyperparams, device)
+
+                # Train the method
+                training_success, training_stats = trainer.train_method(
+                    method_instance, train_samples, layer, model, tokenizer, device, task_name, max_new_tokens
+                )
+
+                if not training_success:
+                    self.logger.warning(f"Training failed for config {i + 1}")
+                    result = SteeringResult(
+                        method_name=method_name,
+                        layer=layer,
+                        hyperparameters={**hyperparams, "strength": strength},
+                        benchmark_metrics={"accuracy": 0.0},
+                        training_success=False,
+                        training_stats=training_stats,
+                    )
+                    all_results.append(result)
+                    continue
+
+                # Evaluate on validation data with steering
+                steered_predictions, steered_ground_truths = trainer.apply_steering_and_evaluate(
+                    method_instance,
+                    validation_samples,
+                    layer,
+                    strength,
+                    model,
+                    tokenizer,
+                    device,
+                    batch_size,
+                    max_length,
+                    task_name,
+                    max_new_tokens,
+                )
+
+                # Compare baseline vs steered predictions using enhanced metrics
+                enhanced_metrics = self.compare_predictions(
+                    baseline_predictions,
+                    steered_predictions,
+                    ground_truths,
+                    model,
+                    tokenizer,
+                    device,
+                    max_length,
+                    task_name,
+                )
+
+                # Extract steered metrics for compatibility
+                benchmark_metrics = enhanced_metrics["steered"]
+                baseline_metrics_for_result = enhanced_metrics["baseline"]
+                comparative_metrics = enhanced_metrics["improvement"]
+
+                result = SteeringResult(
+                    method_name=method_name,
+                    layer=layer,
+                    hyperparameters={**hyperparams, "strength": strength},
+                    benchmark_metrics=benchmark_metrics,
+                    baseline_metrics=baseline_metrics_for_result,
+                    comparative_metrics=comparative_metrics,
+                    training_success=True,
+                    training_stats=training_stats,
+                )
+                all_results.append(result)
+
+                # Standard Optuna practice: optimize steered accuracy directly
+                steered_accuracy = benchmark_metrics.get("accuracy", 0.0)
+                baseline_accuracy = baseline_metrics_for_result.get("accuracy", 0.0)
+                improvement_delta = steered_accuracy - baseline_accuracy
+
+                if steered_accuracy > best_score:
+                    best_score = steered_accuracy
+                    best_config = {
+                        "method": method_name,
+                        "layer": layer,
+                        "strength": strength,
+                        **hyperparams,
+                        "benchmark_metrics": benchmark_metrics,
+                        "baseline_metrics": baseline_metrics_for_result,
+                        "method_instance": method_instance,
+                    }
+
+                self.logger.debug(
+                    f"Config {i + 1} - Baseline: {baseline_accuracy:.3f}, "
+                    f"Steered: {steered_accuracy:.3f}, Delta: {improvement_delta:+.3f}"
+                )
+
+            except Exception as e:
+                self.logger.error(f"Failed to evaluate config {i + 1}: {e}")
+                result = SteeringResult(
+                    method_name=method_name,
+                    layer=layer,
+                    hyperparameters={**hyperparams, "strength": strength},
+                    benchmark_metrics={"accuracy": 0.0},
+                    baseline_metrics=baseline_benchmark_metrics,
+                    comparative_metrics={"accuracy_delta": 0.0, "improvement_rate": 0.0},
+                    training_success=False,
+                    training_stats={"error": str(e)},
+                )
+                all_results.append(result)
+                continue
+
+        if best_config is None:
+            self.logger.warning("No successful steering configuration found")
+            # Return a default configuration
+            best_config = {
+                "method": method_name,
+                "layer": config.layers[0] if config.layers else 0,
+                "strength": config.strengths[0] if config.strengths else 1.0,
+                "benchmark_metrics": {"accuracy": 0.0},
+                "method_instance": None,
+            }
+        else:
+            steered_acc = best_config["benchmark_metrics"]["accuracy"]
+            baseline_acc = best_config.get("baseline_metrics", {}).get("accuracy", 0.0)
+            improvement = steered_acc - baseline_acc
+
+            self.logger.info(
+                f"Best {method_name} config (optimized for steered accuracy): "
+                f"layer={best_config['layer']}, steered={steered_acc:.3f} "
+                f"(baseline={baseline_acc:.3f}, Δ={improvement:+.3f})"
+            )
+
+        return best_config, all_results
+
+    def _generate_hyperparameter_combinations(
+        self, config: SteeringMethodConfig
+    ) -> List[Tuple[int, float, Dict[str, Any]]]:
+        """Generate all combinations of hyperparameters for grid search."""
+        combinations = []
+
+        if isinstance(config, DACConfig):
+            # Generate DAC hyperparameter combinations
+            for layer in config.layers:
+                for strength in config.strengths:
+                    for entropy_threshold in config.entropy_thresholds:
+                        for ptop in config.ptop_values:
+                            for max_alpha in config.max_alpha_values:
+                                hyperparams = {
+                                    "entropy_threshold": entropy_threshold,
+                                    "ptop": ptop,
+                                    "max_alpha": max_alpha,
+                                }
+                                combinations.append((layer, strength, hyperparams))
+        else:
+            # Generic handling for other steering methods
+            for layer in config.layers:
+                for strength in config.strengths:
+                    combinations.append((layer, strength, {}))
+
+        return combinations
+
+    def collect_baseline_predictions(
+        self,
+        evaluation_samples: List[Dict],
+        model,
+        tokenizer,
+        classifier: Classifier,
+        device: str,
+        batch_size: int,
+        max_length: int,
+        task_name: str,
+        max_new_tokens: int = 200,
+    ) -> Tuple[List[str], List[str]]:
+        """
+        Collect unsteered model predictions for baseline comparison.
+        Uses the same evaluation logic as steered evaluation but without steering hooks.
+
+        Args:
+            evaluation_samples: Samples to evaluate
+            model: Language model
+            tokenizer: Model tokenizer
+            classifier: Trained classifier for evaluation
+            device: Device to run on
+            batch_size: Batch size for processing
+            max_length: Maximum sequence length
+            task_name: Task name for evaluation
+            max_new_tokens: Maximum tokens to generate
+
+        Returns:
+            Tuple of (predictions, ground_truths)
+        """
+        predictions = []
+        ground_truths = []
+
+        # Get the task and its extractor
+        task = get_task(task_name)
+        extractor = task.get_extractor()
+
+        # Pre-extract all questions and answers (optimization)
+        questions = []
+        answers = []
+
+        for sample in evaluation_samples:
+            qa_pair = extractor.extract_qa_pair(sample, task)
+            if not qa_pair:
+                self.logger.warning(f"Skipping sample - extractor couldn't extract QA pair: {sample.keys()}")
+                continue
+            questions.append(qa_pair["formatted_question"])
+            answers.append(qa_pair["correct_answer"])
+
+        # Process questions WITHOUT steering in batches
+        ground_truths.extend(answers)
+
+        # Process in batches without steering
+        for i in tqdm(range(0, len(questions), batch_size), desc="Generating baseline predictions"):
+            batch_questions = questions[i : i + batch_size]
+
+            # Tokenize batch with padding for generation
+            inputs = tokenizer(
+                batch_questions, return_tensors="pt", padding=True, truncation=True, max_length=max_length
+            ).to(device)
+
+            with torch.no_grad():
+                outputs = model.generate(
+                    **inputs,
+                    max_new_tokens=max_new_tokens,
+                    do_sample=True,
+                    temperature=0.7,
+                    pad_token_id=tokenizer.eos_token_id,
+                    use_cache=False,  # Disable cache to avoid cache_position errors
+                )
+
+            # Decode responses for each item in batch
+            for j, (output, question) in enumerate(zip(outputs, batch_questions)):
+                response = tokenizer.decode(output, skip_special_tokens=True)
+                prediction = response[len(question) :].strip()
+                predictions.append(prediction)
+
+        return predictions, ground_truths
+
+    def _extract_activation_for_text(
+        self,
+        text: str,
+        layer_index: int,
+        aggregation_strategy: str,
+        model,
+        tokenizer,
+        device: str,
+        max_length: int = 512,
+    ) -> torch.Tensor:
+        """
+        Extract activation from text at specified layer with aggregation.
+
+        Args:
+            text: Input text to extract activation from
+            layer_index: Layer index to extract from
+            aggregation_strategy: Aggregation strategy string (e.g., "mean_pooling")
+            model: Language model
+            tokenizer: Model tokenizer
+            device: Device to run on
+            max_length: Maximum sequence length
+
+        Returns:
+            Aggregated activation tensor
+        """
+        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
+        activations = []
+
+        def hook(module, input, output):
+            # Extract hidden states from the layer
+            hidden_states = output[0] if isinstance(output, tuple) else output
+            activations.append(hidden_states.detach().cpu())
+
+        # Handle different model architectures
+        if hasattr(model, "model") and hasattr(model.model, "layers"):
+            # LLaMA-style models
+            layer_module = model.model.layers[layer_index]
+        elif hasattr(model, "transformer") and hasattr(model.transformer, "h"):
+            # GPT2-style models
+            layer_module = model.transformer.h[layer_index]
+        else:
+            raise ValueError("Unsupported model architecture for activation extraction")
+
+        # Register hook and run forward pass
+        handle = layer_module.register_forward_hook(hook)
+        try:
+            with torch.no_grad():
+                _ = model(**inputs)
+        finally:
+            handle.remove()
+
+        if not activations:
+            raise ValueError("No activations extracted")
+
+        # Get the activation tensor [1, seq_len, hidden_dim]
+        activation_tensor = activations[0]
+
+        # Apply aggregation strategy
+        if (
+            aggregation_strategy == "mean_pooling"
+            or aggregation_strategy == ActivationAggregationStrategy.MEAN_POOLING.value
+        ):
+            aggregated = torch.mean(activation_tensor, dim=1)  # [1, hidden_dim]
+        elif (
+            aggregation_strategy == "last_token"
+            or aggregation_strategy == ActivationAggregationStrategy.LAST_TOKEN.value
+        ):
+            aggregated = activation_tensor[:, -1, :]  # [1, hidden_dim]
+        elif (
+            aggregation_strategy == "first_token"
+            or aggregation_strategy == ActivationAggregationStrategy.FIRST_TOKEN.value
+        ):
+            aggregated = activation_tensor[:, 0, :]  # [1, hidden_dim]
+        elif (
+            aggregation_strategy == "max_pooling"
+            or aggregation_strategy == ActivationAggregationStrategy.MAX_POOLING.value
+        ):
+            aggregated = torch.max(activation_tensor, dim=1)[0]  # [1, hidden_dim]
+        else:
+            # Default to mean pooling if unknown
+            self.logger.warning(f"Unknown aggregation strategy {aggregation_strategy}, using mean pooling")
+            aggregated = torch.mean(activation_tensor, dim=1)
+
+        return aggregated.squeeze(0)  # Return [hidden_dim] tensor
+
+    def score_predictions_with_classifier(
+        self,
+        predictions: List[str],
+        model,
+        tokenizer,
+        device: str,
+        max_length: int = 512,
+        description: str = "predictions",
+    ) -> List[float]:
+        """
+        Score predictions using the cached classifier.
+
+        This is the core feature that was requested - using the optimized classifier
+        to score unsteered vs steered generations.
+
+        Args:
+            predictions: Text predictions to score
+            model: Language model for activation extraction
+            tokenizer: Model tokenizer
+            device: Device to run on
+            max_length: Maximum sequence length
+            description: Description for logging
+
+        Returns:
+            List of classifier scores/probabilities for each prediction
+        """
+        if self._session_classifier is None:
+            self.logger.warning("No cached classifier available for scoring")
+            return [0.5] * len(predictions)  # Return neutral scores
+
+        if not predictions:
+            self.logger.debug("No predictions to score")
+            return []
+
+        # Get classifier metadata
+        layer = self._session_classifier_metadata.get("layer", 12)
+        aggregation = self._session_classifier_metadata.get("aggregation", "mean_pooling")
+
+        self.logger.info(
+            f"Scoring {len(predictions)} {description} with cached classifier (layer={layer}, aggregation={aggregation})"
+        )
+
+        confidence_scores = []
+
+        # Process predictions in batches for efficiency
+        batch_size = 8  # Smaller batch size to avoid OOM
+        for i in range(0, len(predictions), batch_size):
+            batch_predictions = predictions[i : i + batch_size]
+            batch_activations = []
+
+            # Extract activations for each prediction in the batch
+            for pred_text in batch_predictions:
+                try:
+                    # Extract activation for this prediction text
+                    activation = self._extract_activation_for_text(
+                        text=pred_text,
+                        layer_index=layer,
+                        aggregation_strategy=aggregation,
+                        model=model,
+                        tokenizer=tokenizer,
+                        device=device,
+                        max_length=max_length,
+                    )
+                    batch_activations.append(activation)
+
+                except Exception as e:
+                    self.logger.debug(f"Failed to extract activation for prediction: {e}")
+                    # Use neutral score for failed extractions
+                    confidence_scores.append(0.5)
+                    continue
+
+            if batch_activations:
+                try:
+                    # Stack activations into batch tensor
+                    batch_tensor = torch.stack(batch_activations)
+
+                    # Convert to numpy for sklearn classifier
+                    batch_numpy = batch_tensor.detach().cpu().numpy()
+
+                    # Get prediction probabilities from classifier
+                    probabilities = self._session_classifier.predict_proba(batch_numpy)
+
+                    # Extract confidence scores (probability for positive class)
+                    # Assuming binary classification with class 1 as positive
+                    if probabilities.shape[1] > 1:
+                        batch_scores = probabilities[:, 1].tolist()  # Probability of positive class
+                    else:
+                        batch_scores = probabilities[:, 0].tolist()  # Single class probability
+
+                    confidence_scores.extend(batch_scores)
+
+                except Exception as e:
+                    self.logger.warning(f"Failed to score batch of activations: {e}")
+                    # Add neutral scores for failed batch
+                    confidence_scores.extend([0.5] * len(batch_activations))
+
+        # Ensure we have scores for all predictions
+        while len(confidence_scores) < len(predictions):
+            confidence_scores.append(0.5)  # Pad with neutral scores if needed
+
+        # Truncate if we have too many scores (shouldn't happen)
+        confidence_scores = confidence_scores[: len(predictions)]
+
+        # Log statistics
+        avg_score = sum(confidence_scores) / len(confidence_scores) if confidence_scores else 0.5
+        self.logger.debug(
+            f"Generated {len(confidence_scores)} classifier confidence scores for {description} (avg={avg_score:.3f})"
+        )
+
+        return confidence_scores
+
+    def compare_predictions(
+        self,
+        baseline_predictions: List[str],
+        steered_predictions: List[str],
+        ground_truths: List[str],
+        model,
+        tokenizer,
+        device: str,
+        max_length: int = 512,
+        task_name: str = "gsm8k",
+    ) -> Dict[str, Any]:
+        """
+        Compare baseline vs steered predictions using benchmark metrics and classifier scores.
+
+        Args:
+            baseline_predictions: Unsteered model predictions
+            steered_predictions: Steered model predictions
+            ground_truths: Ground truth answers
+            model: Language model for classifier scoring
+            tokenizer: Model tokenizer
+            device: Device to run on
+            max_length: Maximum sequence length
+            task_name: Task name for evaluation metrics
+
+        Returns:
+            Enhanced metrics with baseline vs steered comparison including classifier scores
+        """
+        # Create classifier scorer function for metrics integration
+        classifier_scorer = lambda predictions, description: self.score_predictions_with_classifier(
+            predictions, model, tokenizer, device, max_length, description
+        )
+
+        # Calculate standard benchmark metrics with integrated classifier confidence scores
+        baseline_metrics = metrics.evaluate_benchmark_performance(
+            baseline_predictions, ground_truths, task_name, classifier_scorer=classifier_scorer
+        )
+        steered_metrics = metrics.evaluate_benchmark_performance(
+            steered_predictions, ground_truths, task_name, classifier_scorer=classifier_scorer
+        )
+
+        # Extract classifier scores from integrated metrics
+        baseline_scores = [
+            detail.get("classifier_confidence", 0.5) for detail in baseline_metrics.get("evaluation_details", [])
+        ]
+        steered_scores = [
+            detail.get("classifier_confidence", 0.5) for detail in steered_metrics.get("evaluation_details", [])
+        ]
+
+        # Calculate improvement metrics
+        accuracy_delta = steered_metrics.get("accuracy", 0) - baseline_metrics.get("accuracy", 0)
+        f1_delta = steered_metrics.get("f1", 0) - baseline_metrics.get("f1", 0)
+
+        # Calculate classifier score improvements
+        avg_baseline_score = sum(baseline_scores) / len(baseline_scores) if baseline_scores else 0.0
+        avg_steered_score = sum(steered_scores) / len(steered_scores) if steered_scores else 0.0
+        classifier_score_delta = avg_steered_score - avg_baseline_score
+
+        return {
+            "baseline": {
+                "accuracy": baseline_metrics.get("accuracy", 0.0),
+                "f1": baseline_metrics.get("f1", 0.0),
+                "classifier_scores": baseline_scores,
+                "avg_classifier_score": avg_baseline_score,
+                "predictions": baseline_predictions,
+            },
+            "steered": {
+                "accuracy": steered_metrics.get("accuracy", 0.0),
+                "f1": steered_metrics.get("f1", 0.0),
+                "classifier_scores": steered_scores,
+                "avg_classifier_score": avg_steered_score,
+                "predictions": steered_predictions,
+            },
+            "improvement": {
+                "accuracy_delta": accuracy_delta,
+                "f1_delta": f1_delta,
+                "classifier_score_delta": classifier_score_delta,
+            },
+        }
+
+    def load_or_find_best_classifier(
+        self,
+        model,
+        optimization_config: Optional[ClassifierOptimizationConfig] = None,
+        model_name: Optional[str] = None,
+        task_name: Optional[str] = None,
+        contrastive_pairs: Optional[List] = None,
+        force_reoptimize: bool = False,
+    ) -> Optional[Classifier]:
+        """
+        Load or train the best classifier for current steering session.
+
+        On first call: Run full classifier optimization and cache result for session
+        On subsequent calls: Return cached classifier from current session
+
+        Args:
+            model: Language model (wisent_guard Model wrapper)
+            optimization_config: Primary configuration source
+            model_name: Fallback model name if optimization_config not provided
+            task_name: Fallback task name if optimization_config not provided
+            contrastive_pairs: Training data for classifier optimization
+            force_reoptimize: Force reoptimization even if session classifier exists
+
+        Returns:
+            Best trained classifier or None if optimization failed
+        """
+        # Extract configuration
+        if optimization_config is not None:
+            model_name = optimization_config.model_name
+            task_name = getattr(optimization_config, "task_name", task_name)
+            limit = getattr(optimization_config, "data_limit", 100)
+        else:
+            limit = 100  # Default data limit
+
+        if not model_name or not task_name:
+            raise ValueError("model_name and task_name must be provided either via optimization_config or directly")
+
+        # Create session cache key
+        session_cache_key = f"{model_name}_{task_name}"
+
+        # Check if we already have a classifier for this session
+        if (
+            not force_reoptimize
+            and self._session_classifier is not None
+            and self._session_cache_key == session_cache_key
+        ):
+            self.logger.info("Using cached classifier from current session")
+            return self._session_classifier
+
+        # First call or forced reoptimization - run classifier optimization
+        self.logger.info("Running classifier optimization (first trial in session)")
+
+        if not contrastive_pairs:
+            self.logger.error("contrastive_pairs required for classifier optimization")
+            return None
+
+        try:
+            # Create configuration for classifier optimization if not provided
+            if optimization_config is None:
+                optimization_config = ClassifierOptimizationConfig(
+                    model_name=model_name,
+                    device="auto",
+                    n_trials=20,  # Reasonable number for steering optimization
+                    model_types=["logistic", "mlp"],
+                    primary_metric="f1",
+                )
+
+            # Create generation config for activation pre-generation
+            generation_config = GenerationConfig(
+                layer_search_range=(0, 23),  # Will be auto-detected from model
+                aggregation_methods=[
+                    ActivationAggregationStrategy.MEAN_POOLING,
+                    ActivationAggregationStrategy.LAST_TOKEN,
+                    ActivationAggregationStrategy.FIRST_TOKEN,
+                    ActivationAggregationStrategy.MAX_POOLING,
+                ],
+                cache_dir="./cache/steering_activations",
+                device=optimization_config.device,
+                batch_size=32,
+            )
+
+            # Create classifier optimizer
+            classifier_optimizer = OptunaClassifierOptimizer(
+                optimization_config=optimization_config,
+                generation_config=generation_config,
+                cache_config=self.classifier_cache.config,
+            )
+
+            # Run classifier optimization
+            self.logger.info(f"Optimizing classifier for {model_name}/{task_name} with {len(contrastive_pairs)} pairs")
+            result = classifier_optimizer.optimize(
+                model=model,
+                contrastive_pairs=contrastive_pairs,
+                task_name=task_name,
+                model_name=model_name,
+                limit=limit,
+            )
+
+            if result.best_value > 0:
+                # Get the best configuration and classifier
+                best_config = result.get_best_config()
+                best_classifier = result.best_classifier
+
+                # Cache for current session
+                self._session_classifier = best_classifier
+                self._session_classifier_metadata = {
+                    "layer": best_config["layer"],
+                    "aggregation": best_config["aggregation"],
+                    "model_type": best_config["model_type"],
+                    "threshold": best_config["threshold"],
+                    "f1_score": result.best_value,
+                    "hyperparameters": best_config.get("hyperparameters", {}),
+                }
+                self._session_cache_key = session_cache_key
+
+                self.logger.info(
+                    f"Cached best classifier for session: layer_{best_config['layer']} "
+                    f"{best_config['model_type']} (F1: {result.best_value:.3f})"
+                )
+
+                return best_classifier
+            self.logger.warning("Classifier optimization failed - no successful trials")
+            return None
+
+        except Exception as e:
+            self.logger.error(f"Failed to run classifier optimization: {e}")
+            traceback.print_exc()
+            return None
+
+    def get_cache_info(self) -> Dict[str, Any]:
+        """Get information about cached classifiers."""
+        return self.classifier_cache.get_cache_info()
+
+    def clear_classifier_cache(self, keep_recent_hours: float = 24.0) -> int:
+        """Clear old cached classifiers."""
+        return self.classifier_cache.clear_cache(keep_recent_hours=keep_recent_hours)
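For orientation, here is a minimal usage sketch of the DACConfig and SteeringOptimizer classes added in this file, based only on the signatures visible in the diff above. The model, tokenizer, and sample lists (names prefixed with my_) are placeholders assumed to be supplied by the caller, and the import path mirrors the wisent_guard module paths used inside the file, which may differ from the layout of the wisent wheel.

# Sketch only: "my_*" names are placeholders, not part of the package.
from wisent_guard.core.optuna.classifier import ClassifierOptimizationConfig
from wisent_guard.core.optuna.steering.steering_optimization import DACConfig, SteeringOptimizer

# DAC grid to search: candidate layers, steering strengths, and DAC hyperparameters.
steering_config = DACConfig(
    layers=[10, 12, 14],
    strengths=[0.5, 1.0],
    entropy_thresholds=[1.0],
    ptop_values=[0.4],
    max_alpha_values=[2.0],
)

# Classifier optimization settings, mirroring the defaults constructed inside
# load_or_find_best_classifier when no config is passed.
classifier_config = ClassifierOptimizationConfig(
    model_name="my-model-name",
    device="auto",
    n_trials=20,
    model_types=["logistic", "mlp"],
    primary_metric="f1",
)

optimizer = SteeringOptimizer()  # uses ./steering_classifier_cache by default
best_config, all_results = optimizer.optimize_steering_hyperparameters(
    config=steering_config,
    classifier_optimization_config=classifier_config,
    train_samples=my_train_samples,
    validation_samples=my_validation_samples,
    model=my_model,
    tokenizer=my_tokenizer,
    device="cuda",
    task_name="gsm8k",
)
print(best_config["layer"], best_config["benchmark_metrics"])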