wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (237) hide show
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.2.dist-info/METADATA +67 -0
  215. wisent-0.5.2.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,652 @@
1
+ import datetime
2
+ import json
3
+ import os
4
+ from enum import Enum
5
+ from typing import Any, Dict, List, Optional, Union
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from wisent.core.activations import Activations
11
+ from wisent.core.classifier.classifier import Classifier
12
+
13
+ from .contrastive_pairs import ContrastivePairSet
14
+ from .steering_method import CAA
15
+
16
+
17
class SteeringType(Enum):
    """Back-ends that a ``SteeringMethod`` can be configured with.

    ``SteeringMethod.train`` forwards the member's string value as the
    ``model_type`` of the classifier it builds, while ``CAA`` switches
    ``SteeringMethod`` onto its vector-based code path instead.
    """

    # Classifier-backed (legacy) back-ends.
    LOGISTIC = "logistic"
    MLP = "mlp"
    CUSTOM = "custom"

    # Vector-based steering.
    CAA = "caa"
22
+
23
+
24
+ class SteeringMethod:
25
+ """
26
+ Legacy classifier-based steering method for backward compatibility.
27
+ For new vector-based steering, use steering_method.CAA directly.
28
+ """
29
+
30
def __init__(self, method_type: SteeringType, device=None, threshold=0.5):
    """
    Set up an untrained steering method.

    Args:
        method_type: Which back-end to use; ``SteeringType.CAA`` selects the
            vector-based path, every other member the classifier path
        device: Device handed to the vector back-end here and to the
            classifier when it is created later
        threshold: Score at or above which activations count as harmful
    """
    # Core configuration; the classifier is only created by train()/load_model().
    self.method_type = method_type
    self.device = device
    self.threshold = threshold
    self.classifier = None

    # Vector-based back-end exists only for CAA; otherwise it stays None.
    self.is_vector_based = method_type == SteeringType.CAA
    self.vector_steering = CAA(device=device) if self.is_vector_based else None

    # Harmful-response logging is opt-in (see enable_response_logging).
    self.enable_logging = False
    self.log_file_path = "./harmful_responses.json"

    # Book-keeping used by optimize_parameters().
    self.original_parameters = {}
    self.optimization_history = []
50
+
51
def train(
    self, contrastive_pair_set: ContrastivePairSet, layer_index: Optional[int] = None, **kwargs
) -> Dict[str, Any]:
    """
    Fit this steering method on a set of contrastive pairs.

    Args:
        contrastive_pair_set: Contrastive pairs carrying activations
        layer_index: Layer to train on; mandatory for vector-based methods,
            unused on the classifier path
        **kwargs: Extra options passed to the classifier's ``fit``; they are
            not forwarded on the vector-based path

    Returns:
        Dictionary with training metrics

    Raises:
        ValueError: If a vector-based method gets no ``layer_index``, or the
            classifier path has fewer than 4 training examples
    """
    if not self.is_vector_based:
        # Legacy classifier path: needs a minimum amount of data to fit on.
        features, labels = contrastive_pair_set.prepare_classifier_data()
        if len(features) < 4:
            raise ValueError(f"Need at least 4 training examples, got {len(features)}")

        self.classifier = Classifier(
            model_type=self.method_type.value,
            device=self.device,
            threshold=self.threshold,
        )
        return self.classifier.fit(features, labels, **kwargs)

    # Vector-based path: delegate entirely to the vector back-end.
    if layer_index is None:
        raise ValueError("layer_index required for vector-based steering methods")
    return self.vector_steering.train(contrastive_pair_set, layer_index)
83
+
84
def apply_steering(self, activations: torch.Tensor, strength: float = 1.0) -> torch.Tensor:
    """
    Steer activations through the vector back-end.

    Args:
        activations: Input activations
        strength: Steering strength

    Returns:
        Steered activations, as produced by the vector back-end

    Raises:
        ValueError: If this instance is classifier-based
    """
    if self.is_vector_based:
        return self.vector_steering.apply_steering(activations, strength)
    raise ValueError("apply_steering only available for vector-based methods")
99
+
100
def get_steering_vector(self) -> Optional[torch.Tensor]:
    """Return the vector back-end's steering vector, or None for classifier-based methods."""
    return self.vector_steering.get_steering_vector() if self.is_vector_based else None
105
+
106
def predict(self, activations) -> float:
    """
    Run the trained classifier's ``predict`` on the given activations.

    Args:
        activations: Activation tensor or Activations object

    Returns:
        Prediction score (0 = harmless, 1 = harmful)

    Raises:
        ValueError: For vector-based methods, or when no classifier has been
            trained or loaded yet
    """
    if self.is_vector_based:
        raise ValueError("predict not available for vector-based methods")

    trained = self.classifier
    if trained is None:
        raise ValueError("SteeringMethod not trained. Call train() first.")
    return trained.predict(activations)
123
+
124
def predict_proba(self, activations) -> float:
    """
    Run the trained classifier's ``predict_proba`` on the given activations.

    Args:
        activations: Activation tensor or Activations object

    Returns:
        Probability score (0.0-1.0)

    Raises:
        ValueError: For vector-based methods, or when no classifier has been
            trained or loaded yet
    """
    if self.is_vector_based:
        raise ValueError("predict_proba not available for vector-based methods")

    trained = self.classifier
    if trained is None:
        raise ValueError("SteeringMethod not trained. Call train() first.")
    return trained.predict_proba(activations)
141
+
142
def is_harmful(self, activations, detailed=False) -> Union[bool, Dict[str, Any]]:
    """
    Decide whether activations look harmful by thresholding ``predict_proba``.

    Args:
        activations: Activation tensor or Activations object
        detailed: When true, return the verdict together with the score,
            threshold and method type instead of the bare verdict

    Returns:
        The verdict alone, or a dictionary with keys ``is_harmful``,
        ``probability``, ``threshold`` and ``method_type``

    Raises:
        ValueError: For vector-based methods, or when no classifier has been
            trained or loaded yet
    """
    if self.is_vector_based:
        raise ValueError("is_harmful not available for vector-based methods")
    if self.classifier is None:
        raise ValueError("SteeringMethod not trained. Call train() first.")

    score = self.predict_proba(activations)
    flagged = score >= self.threshold  # the threshold itself counts as harmful

    if not detailed:
        return flagged
    return {
        "is_harmful": flagged,
        "probability": score,
        "threshold": self.threshold,
        "method_type": self.method_type.value,
    }
171
+
172
def check_safety(self, text: str, model, layer) -> Dict[str, Any]:
    """
    Extract activations for ``text`` and classify them with ``is_harmful``.

    Any exception raised along the way is swallowed and reported through an
    ``error`` entry; in that case the result says ``is_harmful: False``.

    Args:
        text: Text to check
        model: Object providing ``extract_activations(text, layer)``
        layer: Layer passed to the model; its ``index`` is echoed in the result

    Returns:
        The detailed ``is_harmful`` dictionary extended with ``text`` (cut to
        100 characters plus "..." when longer), ``text_length`` and
        ``layer_index``; or a dictionary with ``is_harmful``, ``probability``,
        ``error`` and ``text`` when something failed
    """
    try:
        raw_activations = model.extract_activations(text, layer)
        wrapped = Activations(tensor=raw_activations, layer=layer)

        verdict = self.is_harmful(wrapped, detailed=True)

        # Attach a short preview of the text plus its size and the layer used.
        preview = text[:100] + "..." if len(text) > 100 else text
        verdict.update(
            {
                "text": preview,
                "text_length": len(text),
                "layer_index": layer.index,
            }
        )
        return verdict

    except Exception as e:
        # Failures are reported as "not harmful" with the error message attached.
        failure = {
            "is_harmful": False,
            "probability": 0.0,
            "error": str(e),
            "text": text[:100] + "..." if len(text) > 100 else text,
        }
        return failure
212
+
213
def enable_response_logging(self, log_file_path: str = "./harmful_responses.json") -> None:
    """
    Enable logging of harmful responses.

    Turns logging on, remembers the log path and makes a best-effort attempt
    to create the parent directory and an empty JSON list at that path. An
    existing log file is never modified. Filesystem errors are deliberately
    ignored here; they surface later as a ``False`` return from
    ``log_harmful_response``.

    Args:
        log_file_path: Path to the log file
    """
    self.enable_logging = True
    self.log_file_path = log_file_path

    # A bare filename has an empty dirname; os.makedirs("") would raise, so
    # only create a directory when the path actually names one.
    parent_dir = os.path.dirname(log_file_path)
    if parent_dir:
        try:
            os.makedirs(parent_dir, exist_ok=True)
        except OSError:
            pass  # best effort: the file creation below will fail as well

    # Mode "x" creates the file only if it is absent, so an existing log can
    # never be truncated, even if it appears between a check and the open.
    try:
        with open(log_file_path, "x") as f:
            json.dump([], f)
    except OSError:
        pass  # already exists (FileExistsError) or not creatable: leave as is
236
+
237
def log_harmful_response(
    self,
    prompt: str,
    response: str,
    probability: float,
    category: str = "harmful",
    additional_info: Optional[Dict] = None,
) -> bool:
    """
    Log a harmful response to the JSON log file.

    The log file holds one JSON list; the new entry is appended to whatever
    list is already stored there. A missing or unparsable file is treated as
    an empty log. Nothing is written unless logging was enabled first.

    Args:
        prompt: The original prompt
        response: The generated response
        probability: The probability score that triggered detection
        category: The category of harmful content detected
        additional_info: Optional extra fields merged into the entry; keys
            that collide with the standard fields overwrite them

    Returns:
        True when the entry was written, False when logging is disabled or
        anything went wrong (the existing log is left untouched in that case)
    """
    if not self.enable_logging:
        return False

    try:
        log_entry = {
            "timestamp": datetime.datetime.now().isoformat(),
            "prompt": prompt,
            "response": response,
            "probability": float(probability),
            "category": category,
            "threshold": float(self.threshold),
            "method_type": self.method_type.value,
        }

        if additional_info:
            log_entry.update(additional_info)

        # Read existing log entries; start over if the file is absent or corrupt.
        try:
            with open(self.log_file_path) as f:
                log_entries = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            log_entries = []

        log_entries.append(log_entry)

        # Serialize BEFORE opening the file for writing: opening with "w"
        # truncates it, so a value that cannot be serialized (e.g. inside
        # additional_info) would otherwise wipe every earlier entry.
        serialized = json.dumps(log_entries, indent=2)
        with open(self.log_file_path, "w") as f:
            f.write(serialized)

        return True

    except Exception:
        # Documented contract is a success flag, not an exception.
        return False
290
+
291
def get_logged_responses(self, limit: Optional[int] = None, category: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Read back entries from the harmful-response log, newest first.

    Args:
        limit: Maximum number of entries to return; None, zero or a negative
            value returns every matching entry
        category: Only return entries whose ``category`` equals this value
            (None for all categories)

    Returns:
        List of log entries; empty when logging is disabled, the log file
        does not exist, or reading/processing it fails
    """
    if not self.enable_logging:
        return []

    def _timestamp_of(record):
        # Entries without a timestamp sort as the empty string, i.e. oldest.
        return record.get("timestamp", "")

    try:
        if not os.path.exists(self.log_file_path):
            return []

        with open(self.log_file_path) as f:
            records = json.load(f)

        if category is not None:
            records = [record for record in records if record.get("category") == category]

        records.sort(key=_timestamp_of, reverse=True)

        wants_all = limit is None or limit <= 0
        return records if wants_all else records[:limit]

    except Exception:
        return []
329
+
330
def optimize_parameters(
    self,
    model,
    target_layer,
    pair_set: ContrastivePairSet,
    learning_rate: float = 1e-4,
    num_epochs: int = 10,
    regularization_strength: float = 0.01,
    batch_size: int = 4,
) -> Dict[str, Any]:
    """
    Optimize model parameters to improve steering effectiveness.

    Runs Adam over the parameters of the target layer only, using
    ``_compute_steering_loss`` on mini-batches of the pair set's classifier
    data, and finally loads the parameter snapshot taken in the epoch with
    the lowest average loss. The layer's parameters are snapshotted through
    ``_store_original_parameters`` before anything is changed.

    Args:
        model: Model object to optimize
        target_layer: Layer to optimize (an object with ``index`` or a plain index)
        pair_set: ContrastivePairSet with training data
        learning_rate: Learning rate for optimization
        num_epochs: Number of optimization epochs
        regularization_strength: L2 regularization strength
        batch_size: Number of examples per optimization step (must be >= 1)

    Returns:
        Dictionary with optimization results. On any failure a dictionary
        with ``error`` and ``parameters_optimized: False`` is returned
        instead of raising; parameters changed before the failure are not
        rolled back by this method.
    """
    try:
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        # Get the target layer module for optimization
        layer_module = self._get_layer_module(model, target_layer)
        if layer_module is None:
            raise ValueError(f"Could not find layer {target_layer} in model")

        # Keep a copy of the untouched parameters before training modifies them
        self._store_original_parameters(layer_module)

        # Extract activations for the pair set and turn them into training data
        pair_set.extract_activations_with_model(model, target_layer)
        X_tensors, y_labels = pair_set.prepare_classifier_data()

        # Only the target layer's parameters are handed to the optimizer
        optimizer = torch.optim.Adam(layer_module.parameters(), lr=learning_rate)

        best_steering_loss = float("inf")
        best_parameters = None

        for _ in range(num_epochs):
            epoch_loss = 0.0
            num_batches = 0

            for start in range(0, len(X_tensors), batch_size):
                batch_X = X_tensors[start : start + batch_size]
                batch_y = y_labels[start : start + batch_size]

                optimizer.zero_grad()
                loss = self._compute_steering_loss(batch_X, batch_y, layer_module, regularization_strength)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                num_batches += 1

            avg_loss = epoch_loss / num_batches if num_batches > 0 else 0.0

            # Track best parameters. The snapshot is detached so it is plain
            # data: it neither keeps the autograd graph alive nor gets
            # recorded as part of later computations.
            if avg_loss < best_steering_loss:
                best_steering_loss = avg_loss
                best_parameters = {
                    name: param.detach().clone() for name, param in layer_module.named_parameters()
                }

        # Load best parameters in place, without recording the copy in autograd
        if best_parameters is not None:
            with torch.no_grad():
                for name, param in layer_module.named_parameters():
                    if name in best_parameters:
                        param.copy_(best_parameters[name])

        optimization_result = {
            "target_layer": target_layer.index if hasattr(target_layer, "index") else target_layer,
            "final_loss": best_steering_loss,
            "epochs": num_epochs,
            "learning_rate": learning_rate,
            "regularization_strength": regularization_strength,
            "parameters_optimized": True,
        }

        self.optimization_history.append(optimization_result)

        return optimization_result

    except Exception as e:
        return {"error": str(e), "parameters_optimized": False}
427
+
428
def _get_layer_module(self, model, layer):
    """
    Look up the module that implements ``layer`` inside ``model``.

    ``model`` may be a wrapper exposing ``hf_model`` or the model itself, and
    ``layer`` may be an object exposing ``index`` or the index itself. Two
    layouts are recognised: ``<model>.model.layers`` and
    ``<model>.transformer.h``. The first layout wins when both are present.

    Returns:
        The layer module, or None when the layout is unknown, the index is
        not below the number of layers, or the lookup raises
    """
    try:
        backbone = getattr(model, "hf_model", model)
        position = getattr(layer, "index", layer)

        if hasattr(backbone, "model") and hasattr(backbone.model, "layers"):
            # Llama-style model
            stack = backbone.model.layers
        elif hasattr(backbone, "transformer") and hasattr(backbone.transformer, "h"):
            # GPT-style model
            stack = backbone.transformer.h
        else:
            return None

        return stack[position] if position < len(stack) else None
    except Exception:
        return None
446
+
447
def _store_original_parameters(self, module):
    """
    Snapshot a module's parameters so they can be restored later.

    The snapshot is stored in ``self.original_parameters`` under a key
    derived from ``id(module)``; storing the same module again overwrites
    its previous snapshot. The module itself is remembered as well, so that
    ``restore_original_parameters`` knows where each snapshot belongs.

    Args:
        module: Module whose ``named_parameters()`` are copied
    """
    key = f"module_{id(module)}"

    # Created lazily so instances built before this attribute existed still work.
    owners = getattr(self, "_parameter_owners", None)
    if owners is None:
        owners = {}
        self._parameter_owners = owners
    owners[key] = module

    # detach(): the snapshot is plain data and must not hold on to autograd state.
    self.original_parameters[key] = {
        name: param.detach().clone() for name, param in module.named_parameters()
    }

def _compute_steering_loss(self, batch_X, batch_y, layer_module, regularization_strength):
    """
    Compute loss for steering optimization.

    The loss is the sum of the per-example binary cross-entropy between
    ``self.predict_proba(activation)`` and its label, plus
    ``regularization_strength`` times the summed L2 norms of the layer's
    parameters, all divided by the batch length.

    Args:
        batch_X: Batch of activation tensors
        batch_y: Batch of labels
        layer_module: Layer module being optimized
        regularization_strength: L2 regularization strength

    Returns:
        Loss tensor
    """
    total_loss = 0.0

    for activation, label in zip(batch_X, batch_y):
        prediction = self.predict_proba(activation)

        # Convert to tensor for loss computation
        if not isinstance(prediction, torch.Tensor):
            prediction = torch.tensor(prediction, dtype=torch.float32, device=self.device)

        target = torch.tensor(label, dtype=torch.float32, device=self.device)

        # predict_proba returns a probability, not a logit, so the plain
        # binary cross-entropy is the matching loss (the *_with_logits
        # variant would apply a second sigmoid). Clamping keeps values that
        # drift marginally outside [0, 1] from being rejected.
        probability = prediction.reshape(1).clamp(0.0, 1.0)
        total_loss = total_loss + F.binary_cross_entropy(probability, target.reshape(1))

    # Add L2 regularization
    l2_reg = 0.0
    for param in layer_module.parameters():
        l2_reg = l2_reg + torch.norm(param, p=2)

    total_loss = total_loss + regularization_strength * l2_reg

    # Average over batch; an empty batch contributes the regularizer only
    # instead of dividing by zero.
    return total_loss / max(len(batch_X), 1)

def restore_original_parameters(self) -> bool:
    """
    Restore original parameters.

    Copies every snapshot taken by ``_store_original_parameters`` back into
    the module it was taken from. Snapshots stay stored afterwards, so the
    call can be repeated.

    Returns:
        Success flag: True when at least one module was restored, False when
        there was nothing to restore or copying failed (modules handled
        before the failure remain restored)
    """
    owners = getattr(self, "_parameter_owners", None) or {}
    restored_any = False

    try:
        with torch.no_grad():
            for key, saved in self.original_parameters.items():
                module = owners.get(key)
                if module is None:
                    # Snapshot without a known owner: nothing to copy it into.
                    continue
                for name, param in module.named_parameters():
                    if name in saved:
                        param.copy_(saved[name])
                restored_any = True
        return restored_any
    except Exception:
        # Documented contract is a success flag, not an exception.
        return False
504
+
505
def get_optimization_summary(self) -> Dict[str, Any]:
    """
    Summarize the optimization runs recorded on this instance.

    Returns:
        Dictionary with the number of recorded runs, the history list itself
        (not a copy), whether any original parameters are stored, the method
        type value and the threshold
    """
    history = self.optimization_history
    summary = {
        "total_optimizations": len(history),
        "optimization_history": history,
        "has_original_parameters": len(self.original_parameters) > 0,
        "method_type": self.method_type.value,
        "threshold": self.threshold,
    }
    return summary
519
+
520
def evaluate(self, contrastive_pair_set: ContrastivePairSet) -> Dict[str, Any]:
    """
    Score the trained classifier against a set of contrastive pairs.

    Negative activations are treated as the harmful class (expected at or
    above the threshold) and positive activations as the harmless class
    (expected below it).

    Args:
        contrastive_pair_set: Set of contrastive pairs for evaluation

    Returns:
        Dictionary with detection rate, false positive rate, precision,
        recall, F1, accuracy, the four confusion-matrix counts, both sample
        counts and the threshold. Ratios with an empty denominator are 0.

    Raises:
        ValueError: If no classifier has been trained or loaded yet
    """
    if self.classifier is None:
        raise ValueError("SteeringMethod not trained. Call train() first.")

    pos_activations, neg_activations = contrastive_pair_set.get_activation_pairs()

    # Score the harmless (positive) side first, then the harmful (negative) side.
    harmless_scores = [self.predict_proba(activation) for activation in pos_activations]
    harmful_scores = [self.predict_proba(activation) for activation in neg_activations]

    cutoff = self.threshold

    # Confusion matrix, with "harmful" as the positive class of the detector.
    true_positives = len([score for score in harmful_scores if score >= cutoff])
    false_positives = len([score for score in harmless_scores if score >= cutoff])
    true_negatives = len([score for score in harmless_scores if score < cutoff])
    false_negatives = len([score for score in harmful_scores if score < cutoff])

    n_harmless = len(harmless_scores)
    n_harmful = len(harmful_scores)

    detection_rate = true_positives / n_harmful if harmful_scores else 0
    false_positive_rate = false_positives / n_harmless if harmless_scores else 0

    flagged_total = true_positives + false_positives
    precision = true_positives / flagged_total if flagged_total > 0 else 0

    harmful_total = true_positives + false_negatives
    recall = true_positives / harmful_total if harmful_total > 0 else 0

    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0

    if harmless_scores or harmful_scores:
        accuracy = (true_positives + true_negatives) / (n_harmless + n_harmful)
    else:
        accuracy = 0

    return {
        "detection_rate": detection_rate,
        "false_positive_rate": false_positive_rate,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "accuracy": accuracy,
        "true_positives": true_positives,
        "false_positives": false_positives,
        "true_negatives": true_negatives,
        "false_negatives": false_negatives,
        "num_positive_samples": n_harmless,
        "num_negative_samples": n_harmful,
        "threshold": self.threshold,
    }
590
+
591
def save_model(self, save_path: str) -> bool:
    """
    Persist the trained classifier via its own ``save_model``.

    Args:
        save_path: Path to save the model

    Returns:
        True on success; False when there is no classifier or saving raised
    """
    trained = self.classifier
    if trained is None:
        return False

    try:
        trained.save_model(save_path)
    except Exception:
        return False
    return True
609
+
610
def load_model(self, model_path: str) -> bool:
    """
    Build the classifier from a saved model file.

    The classifier is constructed with this instance's method type value,
    device and threshold plus ``model_path``; ``self.classifier`` is only
    replaced when construction succeeds.

    Args:
        model_path: Path to the saved model

    Returns:
        True on success; False when constructing the classifier raised
    """
    try:
        settings = {
            "model_type": self.method_type.value,
            "device": self.device,
            "threshold": self.threshold,
            "model_path": model_path,
        }
        self.classifier = Classifier(**settings)
    except Exception:
        return False
    return True
627
+
628
@classmethod
def create_and_train(
    cls,
    method_type: SteeringType,
    contrastive_pair_set: ContrastivePairSet,
    device: Optional[str] = None,
    threshold: float = 0.5,
    **training_kwargs,
) -> "SteeringMethod":
    """
    Build an instance and train it immediately.

    Args:
        method_type: Type of steering method
        contrastive_pair_set: Training data
        device: Device to use
        threshold: Classification threshold
        **training_kwargs: Passed through to ``train`` unchanged (for
            vector-based methods this is where ``layer_index`` goes)

    Returns:
        The trained instance; the metrics returned by ``train`` are discarded
    """
    instance = cls(method_type=method_type, device=device, threshold=threshold)
    instance.train(contrastive_pair_set, **training_kwargs)
    return instance
@@ -0,0 +1,26 @@
1
+ """
2
+ Steering methods for wisent-guard.
3
+
4
+ This module provides a unified interface for various steering methods
5
+ by importing them from the steering_methods package.
6
+ """
7
+
8
+ # Import all steering methods from the new package
9
+ from .steering_methods import (
10
+ SteeringMethod,
11
+ CAA,
12
+ HPR,
13
+ DAC,
14
+ BiPO,
15
+ KSteering
16
+ )
17
+
18
+ # Re-export for backward compatibility
19
# Names exported by `from ... import *`; one entry per name imported above.
__all__ = [
    "SteeringMethod",
    "CAA",
    "HPR",
    "DAC",
    "BiPO",
    "KSteering",
]
File without changes
File without changes