wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic; see the registry page for more details.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.1.dist-info/METADATA +67 -0
  215. wisent-0.5.1.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,316 @@
1
+ """Multi-steering functionality for combining multiple steering vectors."""
2
+
3
+ import sys
4
+ import torch
5
+ from typing import List, Tuple, Optional, Dict, Any
6
+ from pathlib import Path
7
+
8
+ from .layer import Layer
9
+ from .model import Model
10
+ from .steering_methods.caa import CAA
11
+ from .steering_methods.dac import DAC
12
+ from .utils.device import resolve_default_device
13
+
14
+
15
class MultiSteeringError(Exception):
    """Raised when steering vectors cannot be loaded, combined, or applied."""
18
+
19
+
20
class MultiSteering:
    """Handles multi-steering vector combination and application.

    Typical flow: ``load_vectors`` -> ``combine_vectors`` -> ``apply_steering``.
    Calling ``load_vectors`` again discards any previously combined vector and
    target layer, so a stale combination can never be applied to new inputs.
    """

    def __init__(self, device: str | None = None, method: str = "CAA"):
        """Initialize multi-steering handler.

        Args:
            device: Device to use for computations (cpu/cuda/mps). Falls back to
                ``resolve_default_device()`` when not given.
            method: Steering method to use for combination ("CAA" or "DAC").
                Any value other than "CAA" is handled by the DAC code path.
        """
        self.device = device or resolve_default_device()
        self.method = method
        self.loaded_vectors: list = []
        self.weights: list = []
        self.combined_vector = None
        self.layer = None

    @staticmethod
    def _parse_spec(spec: str) -> tuple:
        """Split a "path:weight" spec into ``(path, weight)`` on its LAST colon.

        Splitting on the last colon lets paths that themselves contain a colon
        (e.g. Windows drive letters such as ``C:/vecs/a.pt:0.5``) parse correctly.

        Raises:
            MultiSteeringError: If there is no colon, the path is empty, or the
                weight is not a number.
        """
        vector_path, sep, raw_weight = spec.rpartition(":")
        if not sep or not vector_path:
            raise MultiSteeringError(f"Invalid vector specification: {spec}. Expected format: path:weight")
        try:
            weight = float(raw_weight)
        except ValueError as e:
            raise MultiSteeringError(f"Invalid weight in {spec}. Must be a number.") from e
        return vector_path, weight

    def _as_tensor(self, steering_vector) -> torch.Tensor:
        """Return *steering_vector* as a tensor placed on ``self.device``."""
        if isinstance(steering_vector, torch.Tensor):
            return steering_vector.to(self.device)
        return torch.tensor(steering_vector, device=self.device)

    def load_vectors(self, vector_specs: List[str]) -> None:
        """Load and validate steering vectors from file paths.

        Each file must contain a dict with a non-None ``"steering_vector"``
        entry; its optional ``"layer_index"`` entry is used to determine the
        target layer. All vectors that carry a layer index must agree on it.

        Args:
            vector_specs: List of "path:weight" specifications

        Raises:
            MultiSteeringError: If vectors cannot be loaded or are incompatible
        """
        if not vector_specs:
            raise MultiSteeringError("No vectors specified")

        # Start from a clean slate: a previously combined vector / target layer
        # belongs to the old inputs and must not leak into later calls.
        self.loaded_vectors = []
        self.weights = []
        self.combined_vector = None
        self.layer = None
        layers_found = set()

        for spec in vector_specs:
            vector_path, weight = self._parse_spec(spec)

            if not Path(vector_path).exists():
                raise MultiSteeringError(f"Vector file not found: {vector_path}")

            print(f"Loading vector from {vector_path} with weight {weight}")

            try:
                # NOTE(review): torch.load unpickles the file, which can execute
                # arbitrary code; only load vector files from trusted sources.
                vector_data = torch.load(vector_path, map_location=self.device)
            except Exception as e:
                raise MultiSteeringError(f"Failed to load vector from {vector_path}: {e}") from e

            if not isinstance(vector_data, dict):
                raise MultiSteeringError(f"Invalid vector format in {vector_path}")

            layer = vector_data.get("layer_index", None)
            if vector_data.get("steering_vector", None) is None:
                raise MultiSteeringError(f"No steering vector found in {vector_path}")
            if layer is not None:
                layers_found.add(layer)

            self.loaded_vectors.append(vector_data)
            self.weights.append(weight)

            print(f" ✓ Loaded vector from layer {layer}")

        # Validate compatibility: exactly one distinct layer across all vectors.
        if len(layers_found) > 1:
            raise MultiSteeringError(f"Vectors from different layers cannot be combined: {layers_found}")

        if not layers_found:
            raise MultiSteeringError("No layer information found in vectors")

        self.layer = Layer(next(iter(layers_found)))

        print(f"\nUsing {self.method} method for vector combination")
        print(f"Target layer: {self.layer.index}")

    def combine_vectors(self, normalize: bool = True) -> torch.Tensor:
        """Combine loaded vectors using appropriate method.

        With method "CAA" the vectors are combined through
        ``CAA.combine_behaviors``; any other method uses
        ``DAC.combine_steering_vectors``.

        Args:
            normalize: Whether to normalize the combined vector (passed as
                ``normalize_result`` to CAA, ``normalize_weights`` to DAC)

        Returns:
            Combined steering vector as tensor

        Raises:
            MultiSteeringError: If combination fails
        """
        if not self.loaded_vectors:
            raise MultiSteeringError("No vectors loaded")

        print(f"\n🔄 Combining {len(self.loaded_vectors)} vectors using {self.method}")

        vectors = [self._as_tensor(data["steering_vector"]) for data in self.loaded_vectors]

        if self.method == "CAA":
            caa = CAA(device=self.device)
            # Register each vector under a unique name and weight it by that name.
            caa.behavior_vectors = {f"vector_{i}": vec for i, vec in enumerate(vectors)}
            behavior_weights = {f"vector_{i}": w for i, w in enumerate(self.weights)}
            self.combined_vector = caa.combine_behaviors(behavior_weights, normalize_result=normalize)
        else:  # DAC or mixed methods
            self.combined_vector = DAC.combine_steering_vectors(
                vectors, self.weights, normalize_weights=normalize
            )

        print(f" ✓ Combined vector shape: {self.combined_vector.shape}")
        print(f" ✓ Combined vector norm: {torch.norm(self.combined_vector).item():.4f}")

        return self.combined_vector

    def _find_layer_module(self, model: Model):
        """Return the decoder layer module of ``model.hf_model`` at ``self.layer.index``.

        Supports models exposing ``model.layers`` or ``transformer.h``.

        Raises:
            MultiSteeringError: If no layer list is found or the index is out of range.
        """
        hf_model = model.hf_model
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            layers = hf_model.model.layers
        elif hasattr(hf_model, "transformer") and hasattr(hf_model.transformer, "h"):
            layers = hf_model.transformer.h
        else:
            raise MultiSteeringError("Could not find model layers")

        try:
            return layers[self.layer.index]
        except IndexError as e:
            raise MultiSteeringError(f"Layer index {self.layer.index} is out of range for this model") from e

    def apply_steering(self, model: Model, prompt: str, max_new_tokens: int = 100,
                       temperature: float = 0.7, top_p: float = 0.9) -> str:
        """Apply the combined steering vector to generate text.

        A forward hook is registered on the target layer for the duration of
        generation and is always removed afterwards.

        Args:
            model: Model to use for generation
            prompt: Input prompt
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter

        Returns:
            Generated text

        Raises:
            MultiSteeringError: If steering fails
        """
        if self.combined_vector is None:
            raise MultiSteeringError("No combined vector available. Call combine_vectors() first.")

        if self.layer is None:
            raise MultiSteeringError("No layer information available")

        print(f"\n🎯 Applying combined steering vector at layer {self.layer.index}")
        print(f"Prompt: {prompt}")
        print("=" * 50)

        # CAA and DAC are configured identically; only the class differs.
        method_cls = CAA if self.method == "CAA" else DAC
        steering_method = method_cls(device=self.device)
        steering_method.steering_vector = self.combined_vector
        steering_method.layer_index = self.layer.index
        steering_method.is_trained = True

        def steering_hook(module, module_input, output):
            # Some layers return a tuple whose first element is the hidden states;
            # steer only that element and pass the rest through unchanged.
            hidden_states = output[0] if isinstance(output, tuple) else output
            steered = steering_method.apply_steering(hidden_states, strength=1.0)
            if isinstance(output, tuple):
                return (steered,) + output[1:]
            return steered

        layer_module = self._find_layer_module(model)
        handle = layer_module.register_forward_hook(steering_hook)

        try:
            output, _ = model.generate(
                prompt=prompt,
                layer_index=self.layer.index,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
            )
            return output
        except Exception as e:
            raise MultiSteeringError(f"Failed to apply steering: {e}") from e
        finally:
            # Never leave the hook attached, whether generation succeeded or not.
            handle.remove()
254
+
255
+
256
def run_multi_steer(
    vector_specs: List[str],
    model_name: str,
    prompt: str,
    method: str = "CAA",
    layer: Optional[int] = None,
    max_new_tokens: int = 100,
    temperature: float = 0.7,
    top_p: float = 0.9,
    device: str | None = None,
    verbose: bool = True,
) -> str:
    """Convenience function to run multi-steering end to end.

    Loads the model, loads and combines (with normalization) the steering
    vectors, then generates text for ``prompt`` with the combined vector applied.

    Args:
        vector_specs: List of "path:weight" specifications
        model_name: Name of model to load
        prompt: Input prompt
        method: Steering method to use ("CAA" or "DAC")
        layer: Target layer that overrides the one stored in the vectors. The
            vector files must still carry layer metadata, because
            ``MultiSteering.load_vectors`` validates it before this override
            is applied.
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature
        top_p: Top-p sampling parameter
        device: Device to use; defaults to ``resolve_default_device()``
        verbose: Whether to print progress. When False, everything written to
            stdout during the run (including MultiSteering's own progress
            messages) is discarded.

    Returns:
        Generated text

    Raises:
        MultiSteeringError: If vectors cannot be loaded, combined, or applied.
    """
    import contextlib
    import io

    # MultiSteering prints unconditionally, so honouring verbose=False without
    # changing that class means capturing stdout for the whole run.
    # NOTE: redirect_stdout swaps sys.stdout process-wide while active.
    output_ctx = contextlib.nullcontext() if verbose else contextlib.redirect_stdout(io.StringIO())

    with output_ctx:
        print(f"\n🚀 Loading model: {model_name}")

        chosen_device = device or resolve_default_device()
        model = Model(model_name, device=chosen_device)

        multi_steer = MultiSteering(device=chosen_device, method=method)
        multi_steer.load_vectors(vector_specs)

        # Override the layer inferred from the vector files, if requested.
        if layer is not None:
            multi_steer.layer = Layer(layer)
            print(f"Overriding layer to: {layer}")

        multi_steer.combine_vectors(normalize=True)

        return multi_steer.apply_steering(
            model=model,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
@@ -0,0 +1,57 @@
1
+ """
2
+ Optuna-based Optimization Framework for Wisent Guard
3
+
4
+ This module provides Optuna-based hyperparameter optimization for both steering and classifier systems:
5
+
6
+ STEERING OPTIMIZATION:
7
+ 1. Hyperparameter Optimization: Optuna-driven search for best steering parameters
8
+ 2. Evaluation Pipeline: Comprehensive evaluation on multiple datasets
9
+ 3. Reproducibility: Complete experiment tracking and reproduction
10
+
11
+ CLASSIFIER OPTIMIZATION:
12
+ 1. Activation Pre-generation: Efficient caching of model activations
13
+ 2. Model Training: Optimized logistic regression and MLP classifiers
14
+ 3. Intelligent Caching: Avoid retraining identical configurations
15
+ 4. Cross-validation: Robust performance evaluation
16
+
17
+ Key components:
18
+ - Steering: OptimizationPipeline, OptimizationConfig, metrics
19
+ - Classifier: OptunaClassifierOptimizer, GenerationConfig, CacheConfig
20
+ """
21
+
22
+ # Steering optimization components
23
+ # Classifier optimization components
24
+ from wisent_guard.core.optuna.classifier import (
25
+ ActivationGenerator,
26
+ CacheConfig,
27
+ ClassifierCache,
28
+ ClassifierOptimizationConfig as ClassifierOptimizationConfig,
29
+ GenerationConfig,
30
+ OptimizationResult,
31
+ OptunaClassifierOptimizer,
32
+ )
33
+ from wisent_guard.core.optuna.steering.metrics import (
34
+ calculate_comprehensive_metrics,
35
+ evaluate_benchmark_performance,
36
+ evaluate_probe_performance,
37
+ generate_performance_summary,
38
+ )
39
+ from wisent_guard.core.optuna.steering.optuna_pipeline import OptimizationConfig, OptimizationPipeline
40
+
41
+ __all__ = [
42
+ # Steering optimization
43
+ "OptimizationConfig",
44
+ "OptimizationPipeline",
45
+ "calculate_comprehensive_metrics",
46
+ "evaluate_benchmark_performance",
47
+ "evaluate_probe_performance",
48
+ "generate_performance_summary",
49
+ # Classifier optimization
50
+ "OptunaClassifierOptimizer",
51
+ "ClassifierOptimizationConfig",
52
+ "GenerationConfig",
53
+ "CacheConfig",
54
+ "ActivationGenerator",
55
+ "ClassifierCache",
56
+ "OptimizationResult",
57
+ ]
@@ -0,0 +1,25 @@
1
"""
Classifier optimization driven by Optuna.

Speeds up the search by pre-generating model activations and caching
results, so repeated configurations are cheap to evaluate.
"""

from .activation_generator import (
    ActivationData,
    ActivationGenerator,
    GenerationConfig,
)
from .classifier_cache import (
    CacheConfig,
    CacheMetadata,
    ClassifierCache,
)
from .optuna_classifier_optimizer import (
    ClassifierOptimizationConfig,
    OptimizationResult,
    OptunaClassifierOptimizer,
)

# Public API, grouped by the submodule each name comes from.
__all__ = [
    # activation_generator
    "ActivationGenerator",
    "GenerationConfig",
    "ActivationData",
    # classifier_cache
    "ClassifierCache",
    "CacheConfig",
    "CacheMetadata",
    # optuna_classifier_optimizer
    "OptunaClassifierOptimizer",
    "ClassifierOptimizationConfig",
    "OptimizationResult",
]
+ ]