wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (237) hide show
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.1.dist-info/METADATA +67 -0
  215. wisent-0.5.1.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,429 @@
1
+ import logging
2
+ import itertools
3
+ from typing import Dict, List, Tuple, Any, Optional
4
+ from dataclasses import dataclass, field
5
+ import numpy as np
6
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
7
+
8
+ from .contrastive_pairs import ContrastivePairSet
9
+ from .steering import SteeringMethod, SteeringType
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
def detect_model_layers(model) -> int:
    """
    Detect the number of layers in a model.

    Three strategies are tried in order and the first one that yields a
    positive count wins:

    1. A layer-count attribute on ``config`` (``num_hidden_layers``,
       ``n_layer``, ``num_layers`` or ``n_layers``).
    2. The length of a known layer container (``model.layers``,
       ``transformer.h`` or ``encoder.layer``).
    3. The highest layer index found in the names yielded by
       ``named_modules()``.

    Args:
        model: The model object to inspect. If it has an ``hf_model``
            attribute, that wrapped object is inspected instead.

    Returns:
        Number of layers in the model, or 32 when nothing could be detected
        or an error occurred while inspecting the model.
    """
    try:
        # Unwrap project-level wrappers that expose the underlying model.
        hf_model = model.hf_model if hasattr(model, 'hf_model') else model

        # Method 1: Check config for common layer count attributes
        if hasattr(hf_model, 'config'):
            config = hf_model.config

            # Different models use different names for layer count
            for attr in ('num_hidden_layers', 'n_layer', 'num_layers', 'n_layers'):
                if hasattr(config, attr):
                    layer_count = getattr(config, attr)
                    if isinstance(layer_count, int) and layer_count > 0:
                        logger.info(f"Detected {layer_count} layers from config.{attr}")
                        return layer_count

        # Method 2: Count actual layer modules
        if hasattr(hf_model, 'model') and hasattr(hf_model.model, 'layers'):
            # Llama/Mistral style: model.layers
            layer_count = len(hf_model.model.layers)
            logger.info(f"Detected {layer_count} layers from model.layers")
            return layer_count
        elif hasattr(hf_model, 'transformer') and hasattr(hf_model.transformer, 'h'):
            # GPT style: transformer.h
            layer_count = len(hf_model.transformer.h)
            logger.info(f"Detected {layer_count} layers from transformer.h")
            return layer_count
        elif hasattr(hf_model, 'encoder') and hasattr(hf_model.encoder, 'layer'):
            # BERT style: encoder.layer
            layer_count = len(hf_model.encoder.layer)
            logger.info(f"Detected {layer_count} layers from encoder.layer")
            return layer_count

        # Method 3: Derive the count from module names such as "layers.0.attn".
        # Names are matched segment by segment so that a container at the very
        # start of the name ("layers.0") is recognised, and only the index that
        # directly follows the outermost container is used; deeper numeric
        # segments (e.g. "...experts.7") are not layer indices.
        layer_containers = ('layers', 'h', 'layer', 'block', 'blocks')
        layer_count = 0
        for name, _ in hf_model.named_modules():
            parts = name.split('.')
            for container, index in zip(parts, parts[1:]):
                if container in layer_containers and index.isdigit():
                    layer_count = max(layer_count, int(index) + 1)
                    break

        if layer_count > 0:
            logger.info(f"Detected {layer_count} layers from module names")
            return layer_count

        # Fallback: Conservative default
        logger.warning("Could not detect layer count, using default of 32")
        return 32

    except Exception as e:
        # Best-effort helper: any inspection failure falls back to the default.
        logger.warning(f"Error detecting layer count: {e}, using default of 32")
        return 32
83
+
84
+
85
def get_default_layer_range(total_layers: int, use_all: bool = True) -> List[int]:
    """
    Build the default list of layer indices to search over.

    Args:
        total_layers: Total number of layers in the model
        use_all: When True every layer index is returned; when False only the
            middle half (first and last quarter dropped) is returned

    Returns:
        List of 0-indexed layer indices to optimize over
    """
    if not use_all:
        # Middle layers only: skip the first and the last quarter.
        first = max(0, total_layers // 4)
        stop = min(total_layers, (3 * total_layers) // 4)
        return [*range(first, stop)]

    # Every layer, 0-indexed.
    return [*range(total_layers)]
104
+
105
+
106
@dataclass
class OptimizationConfig:
    """Configuration for hyperparameter optimization."""

    # Layer indices to search. None means the optimizer detects the model's
    # layer count and searches every layer.
    layer_range: Optional[List[int]] = None

    # Token aggregation methods to try
    aggregation_methods: List[str] = field(default_factory=lambda: ["average", "final", "first", "max", "min"])

    # Threshold range to search (for classification)
    threshold_range: List[float] = field(default_factory=lambda: [0.3, 0.4, 0.5, 0.6, 0.7, 0.8])

    # Classifier types to try
    classifier_types: List[str] = field(default_factory=lambda: ["logistic"])

    # Performance metric to optimize
    metric: str = "f1"  # Options: "accuracy", "f1", "precision", "recall", "auc"

    # Cross-validation folds (if 0, uses simple train/val split)
    cv_folds: int = 0

    # Validation split ratio (used when cv_folds=0)
    val_split: float = 0.2

    # Maximum number of combinations to try (for performance)
    max_combinations: int = 100

    # Random seed for reproducibility
    seed: int = 42
136
+
137
+
138
@dataclass
class OptimizationResult:
    """Result of hyperparameter optimization."""

    # Hyperparameters of the best-scoring combination
    best_layer: int
    best_aggregation: str
    best_threshold: float
    best_classifier_type: str

    # Value of the optimized metric for the best combination
    best_score: float

    # Metric name -> value for the best combination
    best_metrics: Dict[str, float]

    # All tested combinations and their scores
    all_results: List[Dict[str, Any]] = field(default_factory=list)

    # Configuration used for optimization (None when not recorded)
    config: Optional[OptimizationConfig] = None
154
+
155
+
156
class HyperparameterOptimizer:
    """Optimizes hyperparameters for the guard system.

    Searches the cartesian product of layers, token aggregation methods,
    classification thresholds and classifier types described by an
    ``OptimizationConfig`` and keeps the combination with the highest value of
    the configured metric.
    """

    # Metric keys produced by ``_evaluate_combination``; ``config.metric`` must be one of them.
    _SUPPORTED_METRICS = ("accuracy", "f1", "precision", "recall", "auc")

    def __init__(self, config: Optional[OptimizationConfig] = None):
        """
        Create an optimizer.

        Args:
            config: Search configuration; a default ``OptimizationConfig`` is
                used when omitted.
        """
        self.config = config or OptimizationConfig()
        # Seeds NumPy's global RNG so the random sub-sampling in optimize() is reproducible.
        np.random.seed(self.config.seed)

    def optimize(
        self,
        model,
        train_pair_set: ContrastivePairSet,
        test_pair_set: ContrastivePairSet,
        device: Optional[str] = None,
        verbose: bool = False
    ) -> OptimizationResult:
        """
        Optimize hyperparameters for the guard system.

        Args:
            model: The model to use for training
            train_pair_set: Training contrastive pairs
            test_pair_set: Test contrastive pairs for evaluation
            device: Device to run on
            verbose: Whether to print progress

        Returns:
            OptimizationResult with best hyperparameters and performance

        Raises:
            ValueError: If the configured metric is not supported, the layer
                range is empty, or no combination could be evaluated.
        """
        metric = self.config.metric
        # Fail fast: an unknown metric would otherwise make every combination
        # fail only after its classifier had already been trained.
        if metric not in self._SUPPORTED_METRICS:
            raise ValueError(
                f"Unsupported optimization metric '{metric}'; "
                f"expected one of {list(self._SUPPORTED_METRICS)}"
            )

        # Auto-detect layer range if not provided
        layer_range = self.config.layer_range
        if layer_range is None:
            total_layers = detect_model_layers(model)
            layer_range = get_default_layer_range(total_layers, use_all=True)
            auto_detected = True
        else:
            total_layers = None
            auto_detected = False

        layer_range = list(layer_range)
        if not layer_range:
            raise ValueError("layer_range is empty; no layers to optimize over")

        if verbose:
            if auto_detected:
                print(f" • Auto-detected {total_layers} model layers")
                print(f" • Using all layers for optimization: {layer_range[0]}-{layer_range[-1]}")
            print("\n🔍 Starting hyperparameter optimization...")
            print(f" • Layers to test: {len(layer_range)} (range: {layer_range[0]}-{layer_range[-1]})")
            print(f" • Aggregation methods: {len(self.config.aggregation_methods)}")
            print(f" • Thresholds: {len(self.config.threshold_range)}")
            print(f" • Classifier types: {len(self.config.classifier_types)}")
            print(f" • Optimization metric: {metric}")

        # Generate all combinations of hyperparameters
        combinations = list(itertools.product(
            layer_range,
            self.config.aggregation_methods,
            self.config.threshold_range,
            self.config.classifier_types
        ))

        # Limit combinations if too many
        if len(combinations) > self.config.max_combinations:
            if verbose:
                print(f" • Too many combinations ({len(combinations)}), sampling {self.config.max_combinations}")
            # np.random.choice needs 1-D input, which a list of tuples is not,
            # so sample positions and pick the tuples by index instead.
            chosen = np.random.choice(
                len(combinations),
                size=self.config.max_combinations,
                replace=False
            )
            combinations = [combinations[int(idx)] for idx in chosen]

        if verbose:
            print(f" • Testing {len(combinations)} combinations...")

        best_score = -np.inf
        best_result = None
        all_results = []

        for done, (layer, aggregation, threshold, classifier_type) in enumerate(combinations, start=1):
            try:
                if verbose and done % 10 == 0:
                    print(f" • Progress: {done}/{len(combinations)} combinations tested")

                # Train and evaluate this combination
                result = self._evaluate_combination(
                    model=model,
                    train_pair_set=train_pair_set,
                    test_pair_set=test_pair_set,
                    layer=layer,
                    aggregation=aggregation,
                    threshold=threshold,
                    classifier_type=classifier_type,
                    device=device
                )

                all_results.append(result)

                # Check if this is the best so far
                score = result[metric]
                if score > best_score:
                    best_score = score
                    best_result = result

                    if verbose:
                        print(f" • New best: layer={layer}, agg={aggregation}, thresh={threshold:.2f}, {metric}={score:.3f}")

            except Exception as e:
                # Best effort: one failing combination must not abort the whole search.
                logger.warning(f"Failed to evaluate combination (layer={layer}, agg={aggregation}, thresh={threshold}, type={classifier_type}): {e}")
                continue

        if best_result is None:
            raise ValueError("No valid combinations found during optimization")

        # Create optimization result
        optimization_result = OptimizationResult(
            best_layer=best_result['layer'],
            best_aggregation=best_result['aggregation'],
            best_threshold=best_result['threshold'],
            best_classifier_type=best_result['classifier_type'],
            best_score=best_result[metric],
            best_metrics={
                'accuracy': best_result['accuracy'],
                'f1': best_result['f1'],
                'precision': best_result['precision'],
                'recall': best_result['recall'],
                'auc': best_result.get('auc', 0.0)
            },
            all_results=all_results,
            config=self.config
        )

        if verbose:
            print("\n✅ Optimization complete!")
            print(f" • Best layer: {optimization_result.best_layer}")
            print(f" • Best aggregation: {optimization_result.best_aggregation}")
            print(f" • Best threshold: {optimization_result.best_threshold:.2f}")
            print(f" • Best classifier: {optimization_result.best_classifier_type}")
            print(f" • Best {metric}: {optimization_result.best_score:.3f}")
            print(f" • Tested {len(all_results)} valid combinations")

        return optimization_result

    def _evaluate_combination(
        self,
        model,
        train_pair_set: ContrastivePairSet,
        test_pair_set: ContrastivePairSet,
        layer: int,
        aggregation: str,
        threshold: float,
        classifier_type: str,
        device: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Evaluate a single hyperparameter combination.

        NOTE(review): ``model``, ``layer`` and ``aggregation`` are not used for
        training here; they are only echoed into the returned dictionary. The
        classifier is trained on whatever activations the pair sets already
        carry, so those activations must match the combination being tested.

        Args:
            model: The model to use
            train_pair_set: Training data
            test_pair_set: Test data
            layer: Layer index to use
            aggregation: Token aggregation method
            threshold: Classification threshold
            classifier_type: Type of classifier ("logistic"; any other value selects the MLP)
            device: Device to run on

        Returns:
            Dictionary with the combination, its accuracy/f1/precision/recall/auc
            and the raw training and test results

        Raises:
            ValueError: If no test pair produced a prediction.
        """
        # Train classifier with this combination
        steering_type = SteeringType.LOGISTIC if classifier_type == "logistic" else SteeringType.MLP
        steering_method = SteeringMethod(method_type=steering_type, device=device)

        # Extract activations for training (this should be done by the activation collector)
        # For now, assume the pair set already has activations
        training_results = steering_method.train(train_pair_set)

        # Evaluate on test set
        test_results = steering_method.evaluate(test_pair_set)

        # Score every usable pair once; the probabilities feed both the
        # thresholded predictions and the AUC below.
        predictions = []
        true_labels = []
        prob_scores = []

        for pair in test_pair_set.pairs:
            if not (hasattr(pair.positive_response, 'activations') and hasattr(pair.negative_response, 'activations')):
                continue

            # Get classifier predictions for both responses
            pos_features = pair.positive_response.activations.extract_features_for_classifier()
            neg_features = pair.negative_response.activations.extract_features_for_classifier()

            # Predict probabilities
            pos_prob = steering_method.classifier.predict_proba([pos_features.numpy()])[0]
            neg_prob = steering_method.classifier.predict_proba([neg_features.numpy()])[0]

            # Positive response should be classified as 0 (harmless)
            # Negative response should be classified as 1 (harmful)
            predictions.extend([1 if pos_prob > threshold else 0, 1 if neg_prob > threshold else 0])
            true_labels.extend([0, 1])
            prob_scores.extend([pos_prob, neg_prob])

        if not predictions:
            raise ValueError("No valid predictions generated")

        # Calculate metrics
        accuracy = accuracy_score(true_labels, predictions)
        f1 = f1_score(true_labels, predictions, zero_division=0)
        precision = precision_score(true_labels, predictions, zero_division=0)
        recall = recall_score(true_labels, predictions, zero_division=0)

        # Calculate AUC if possible; it is optional, so a failure degrades to 0.0.
        try:
            auc = roc_auc_score(true_labels, prob_scores) if len(set(true_labels)) > 1 else 0.0
        except Exception as e:
            logger.debug(f"Could not compute AUC: {e}")
            auc = 0.0

        return {
            'layer': layer,
            'aggregation': aggregation,
            'threshold': threshold,
            'classifier_type': classifier_type,
            'accuracy': accuracy,
            'f1': f1,
            'precision': precision,
            'recall': recall,
            'auc': auc,
            'training_results': training_results,
            'test_results': test_results
        }

    @staticmethod
    def from_config_dict(config_dict: Dict[str, Any]) -> 'HyperparameterOptimizer':
        """Create optimizer from configuration dictionary.

        Args:
            config_dict: Keyword arguments for ``OptimizationConfig``.

        Returns:
            A ``HyperparameterOptimizer`` using the resulting configuration.
        """
        config = OptimizationConfig(**config_dict)
        return HyperparameterOptimizer(config)

    def save_results(self, result: OptimizationResult, filepath: str):
        """Save optimization results to file.

        Args:
            result: The optimization result to persist.
            filepath: Destination path of the JSON file.

        Raises:
            TypeError: If the result contains a value that cannot be converted to JSON.
        """
        import json

        def _json_default(value):
            # NumPy scalars and arrays have direct JSON equivalents.
            if isinstance(value, np.generic):
                return value.item()
            if isinstance(value, np.ndarray):
                return value.tolist()
            raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

        # Convert result to serializable format
        result_dict = {
            'best_hyperparameters': {
                'layer': result.best_layer,
                'aggregation': result.best_aggregation,
                'threshold': result.best_threshold,
                'classifier_type': result.best_classifier_type
            },
            'best_score': result.best_score,
            'best_metrics': result.best_metrics,
            'optimization_config': {
                'layer_range': self.config.layer_range,
                'aggregation_methods': self.config.aggregation_methods,
                'threshold_range': self.config.threshold_range,
                'classifier_types': self.config.classifier_types,
                'metric': self.config.metric,
                'max_combinations': self.config.max_combinations
            },
            'all_results': result.all_results
        }

        # Serialize before opening the file so a serialization error cannot
        # leave a truncated file behind.
        payload = json.dumps(result_dict, indent=2, default=_json_default)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(payload)

        logger.info(f"Optimization results saved to {filepath}")