wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of wisent might be problematic.
Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.1.dist-info/METADATA +67 -0
  215. wisent-0.5.1.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
wisent/core/agent/diagnose/create_classifier.py (new file)
@@ -0,0 +1,1154 @@
+"""
+On-the-Fly Classifier Creation System for Autonomous Agent
+
+This module handles:
+- Dynamic training of new classifiers for specific issue types
+- Automatic training data generation for different problem domains
+- Classifier optimization and validation
+- Integration with the autonomous agent system
+"""
+
+import time
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple
+
+from wisent_guard.core.classifier.classifier import ActivationClassifier, Classifier
+
+from ...activations import Activations
+from ...layer import Layer
+from ...model import Model
+from ...model_persistence import ModelPersistence, create_classifier_metadata
+
+
+@dataclass
+class TrainingConfig:
+    """Configuration for classifier training."""
+
+    issue_type: str
+    layer: int
+    classifier_type: str = "logistic"
+    threshold: float = 0.5
+    model_name: str = ""
+    training_samples: int = 100
+    test_split: float = 0.2
+    optimization_metric: str = "f1"
+    save_path: Optional[str] = None
+
+
+@dataclass
+class TrainingResult:
+    """Result of classifier training."""
+
+    classifier: Classifier
+    config: TrainingConfig
+    performance_metrics: Dict[str, float]
+    training_time: float
+    save_path: Optional[str] = None
+
+
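A minimal sketch of how these dataclasses are filled in (illustration only, not part of the packaged file; all field values here are assumptions):

    config = TrainingConfig(
        issue_type="hallucination",
        layer=15,
        classifier_type="mlp",
        training_samples=200,
        save_path="classifiers/hallucination_layer15.pkl",  # hypothetical path
    )

Within this module, TrainingResult is only ever constructed by ClassifierCreator below; callers consume it.
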
+class ClassifierCreator:
+    """Creates new classifiers on demand for the autonomous agent."""
+
+    def __init__(self, model: Model):
+        """
+        Initialize the classifier creator.
+
+        Args:
+            model: The language model to use for training
+        """
+        self.model = model
+
+    def create_classifier_for_issue_type(
+        self, issue_type: str, layer: int, config: Optional[TrainingConfig] = None
+    ) -> TrainingResult:
+        """
+        Create a new classifier for a specific issue type.
+
+        Args:
+            issue_type: Type of issue to detect (e.g., "hallucination", "quality")
+            layer: Model layer to use for activation extraction
+            config: Optional training configuration
+
+        Returns:
+            TrainingResult with the trained classifier and metrics
+        """
+        print(f"🏋️ Creating classifier for {issue_type} at layer {layer}...")
+
+        # Use provided config or create default
+        if config is None:
+            config = TrainingConfig(issue_type=issue_type, layer=layer, model_name=self.model.name)
+
+        start_time = time.time()
+
+        # Generate training data
+        print(" 📊 Generating training data...")
+        training_data = self._generate_training_data(issue_type, config.training_samples)
+
+        # Extract activations
+        print(" 🧠 Extracting activations...")
+        harmful_activations, harmless_activations = self._extract_activations_from_data(training_data, layer)
+
+        # Train classifier
+        print(" 🎯 Training classifier...")
+        classifier = self._train_classifier(harmful_activations, harmless_activations, config)
+
+        # Evaluate performance
+        print(" 📈 Evaluating performance...")
+        metrics = self._evaluate_classifier(classifier, harmful_activations, harmless_activations)
+
+        training_time = time.time() - start_time
+
+        # Save classifier if path provided
+        save_path = None
+        if config.save_path:
+            print(" 💾 Saving classifier...")
+            save_path = self._save_classifier(classifier, config, metrics)
+
+        result = TrainingResult(
+            classifier=classifier.classifier,  # Return the base classifier
+            config=config,
+            performance_metrics=metrics,
+            training_time=training_time,
+            save_path=save_path,
+        )
+
+        print(
+            f" ✅ Classifier created in {training_time:.2f}s "
+            f"(F1: {metrics.get('f1', 0):.3f}, Accuracy: {metrics.get('accuracy', 0):.3f})"
+        )
+
+        return result
+
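A sketch of the end-to-end call (illustration only; assumes model is an already-constructed wisent Model):

    creator = ClassifierCreator(model)
    result = creator.create_classifier_for_issue_type("hallucination", layer=15)
    print(result.performance_metrics.get("f1", 0.0), f"{result.training_time:.2f}s")

Data generation, activation extraction, training, and evaluation all happen inside the single call; passing a TrainingConfig with save_path set additionally persists the trained classifier.
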
+    def create_multi_layer_classifiers(
+        self, issue_type: str, layers: List[int], save_base_path: Optional[str] = None
+    ) -> Dict[int, TrainingResult]:
+        """
+        Create classifiers for multiple layers for the same issue type.
+
+        Args:
+            issue_type: Type of issue to detect
+            layers: List of layers to create classifiers for
+            save_base_path: Base path for saving classifiers
+
+        Returns:
+            Dictionary mapping layer indices to training results
+        """
+        print(f"🔄 Creating multi-layer classifiers for {issue_type}...")
+
+        results = {}
+
+        for layer in layers:
+            config = TrainingConfig(
+                issue_type=issue_type,
+                layer=layer,
+                model_name=self.model.name,
+                save_path=f"{save_base_path}_layer_{layer}.pkl" if save_base_path else None,
+            )
+
+            result = self.create_classifier_for_issue_type(issue_type, layer, config)
+            results[layer] = result
+
+        print(f" ✅ Created {len(results)} classifiers across layers {layers}")
+        return results
+
+    def optimize_classifier_for_performance(
+        self,
+        issue_type: str,
+        layer_range: Optional[Tuple[int, int]] = None,
+        classifier_types: Optional[List[str]] = None,
+        target_metric: str = "f1",
+        min_target_score: float = 0.7,
+    ) -> TrainingResult:
+        """
+        Optimize classifier by testing different configurations.
+
+        Args:
+            issue_type: Type of issue to detect
+            layer_range: Range of layers to test (start, end). If None, auto-detect all model layers
+            classifier_types: Types of classifiers to test
+            target_metric: Metric to optimize for
+            min_target_score: Minimum acceptable score
+
+        Returns:
+            Best performing classifier configuration
+        """
+        print(f"🎯 Optimizing classifier for {issue_type}...")
+
+        if classifier_types is None:
+            classifier_types = ["logistic", "mlp"]
+
+        # Auto-detect layer range if not provided
+        if layer_range is None:
+            from ..hyperparameter_optimizer import detect_model_layers
+
+            total_layers = detect_model_layers(self.model)
+            layer_range = (0, total_layers - 1)
+            print(f" 📊 Auto-detected {total_layers} layers, testing range {layer_range[0]}-{layer_range[1]}")
+
+        best_result = None
+        best_score = 0.0
+
+        layers_to_test = range(layer_range[0], layer_range[1] + 1, 2)  # Test every 2nd layer
+
+        for layer in layers_to_test:
+            for classifier_type in classifier_types:
+                config = TrainingConfig(
+                    issue_type=issue_type, layer=layer, classifier_type=classifier_type, model_name=self.model.name
+                )
+
+                try:
+                    result = self.create_classifier_for_issue_type(issue_type, layer, config)
+                    score = result.performance_metrics.get(target_metric, 0.0)
+
+                    print(f" Layer {layer}, {classifier_type}: {target_metric}={score:.3f}")
+
+                    if score > best_score:
+                        best_score = score
+                        best_result = result
+
+                    # Early stopping if we hit the target
+                    if score >= min_target_score:
+                        print(f" 🎉 Target score reached: {score:.3f}")
+                        break
+
+                except Exception as e:
+                    print(f" ❌ Failed layer {layer}, {classifier_type}: {e}")
+                    continue
+
+            # Break outer loop if target reached
+            if best_score >= min_target_score:
+                break
+
+        if best_result is None:
+            raise RuntimeError(f"Failed to create any working classifier for {issue_type}")
+
+        print(
+            f" ✅ Best configuration: Layer {best_result.config.layer}, "
+            f"{best_result.config.classifier_type}, {target_metric}={best_score:.3f}"
+        )
+
+        return best_result
+
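A sketch of a bounded search (illustration only; the argument values are assumptions). Since layers are stepped by 2 and both loops short-circuit once min_target_score is reached, this call tests at most 7 layers × 2 classifier types:

    best = creator.optimize_classifier_for_performance(
        "quality",
        layer_range=(8, 20),  # tests layers 8, 10, ..., 20
        classifier_types=["logistic", "mlp"],
        target_metric="f1",
        min_target_score=0.8,
    )
    print(best.config.layer, best.config.classifier_type)
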
+    async def create_classifier_for_issue_with_benchmarks(
+        self,
+        issue_type: str,
+        relevant_benchmarks: List[str],
+        layer: int = 15,
+        num_samples: int = 50,
+        config: Optional[TrainingConfig] = None,
+    ) -> TrainingResult:
+        """
+        Create a classifier using specific benchmarks for better contrastive pairs.
+
+        Args:
+            issue_type: Type of issue to detect (e.g., "hallucination", "quality")
+            relevant_benchmarks: List of benchmark names to use for training data
+            layer: Model layer to use for activation extraction (default: 15)
+            num_samples: Number of training samples to generate
+            config: Optional training configuration
+
+        Returns:
+            TrainingResult with the trained classifier and metrics
+        """
+        print(f"🎯 Creating {issue_type} classifier using benchmarks: {relevant_benchmarks}")
+
+        # Use provided config or create default
+        if config is None:
+            config = TrainingConfig(
+                issue_type=issue_type, layer=layer, model_name=self.model.name, training_samples=num_samples
+            )
+
+        start_time = time.time()
+
+        # Generate training data using the provided benchmarks
+        print(" 📊 Loading benchmark-specific training data...")
+        training_data = []
+
+        try:
+            # Load data from the relevant benchmarks
+            benchmark_data = self._load_benchmark_data(relevant_benchmarks, num_samples)
+            training_data.extend(benchmark_data)
+            print(f" ✅ Loaded {len(benchmark_data)} examples from benchmarks")
+        except Exception as e:
+            print(f" ⚠️ Failed to load benchmark data: {e}")
+
+        # If we don't have enough data from benchmarks, supplement with synthetic data
+        if len(training_data) < num_samples // 2:
+            print(" 🧪 Supplementing with synthetic training data...")
+            try:
+                synthetic_data = self._generate_synthetic_training_data(issue_type, num_samples - len(training_data))
+                training_data.extend(synthetic_data)
+                print(f" ✅ Added {len(synthetic_data)} synthetic examples")
+            except Exception as e:
+                print(f" ⚠️ Failed to generate synthetic data: {e}")
+
+        if not training_data:
+            raise ValueError(f"No training data available for {issue_type}")
+
+        print(f" 📈 Total training examples: {len(training_data)}")
+
+        # Extract activations
+        print(" 🧠 Extracting activations...")
+        harmful_activations, harmless_activations = self._extract_activations_from_data(training_data, layer)
+
+        # Train classifier
+        print(" 🎯 Training classifier...")
+        classifier = self._train_classifier(harmful_activations, harmless_activations, config)
+
+        # Evaluate performance
+        print(" 📈 Evaluating performance...")
+        metrics = self._evaluate_classifier(classifier, harmful_activations, harmless_activations)
+
+        training_time = time.time() - start_time
+
+        # Save classifier if path provided
+        save_path = None
+        if config.save_path:
+            print(" 💾 Saving classifier...")
+            save_path = self._save_classifier(classifier, config, metrics)
+
+        result = TrainingResult(
+            classifier=classifier.classifier,  # Return the base classifier
+            config=config,
+            performance_metrics=metrics,
+            training_time=training_time,
+            save_path=save_path,
+        )
+
+        print(
+            f" ✅ Benchmark-based classifier created in {training_time:.2f}s "
+            f"(F1: {metrics.get('f1', 0):.3f}, Accuracy: {metrics.get('accuracy', 0):.3f})"
+        )
+        print(f" 📊 Used benchmarks: {relevant_benchmarks}")
+
+        return result
+
+    async def create_combined_benchmark_classifier(
+        self, benchmark_names: List[str], classifier_params: "ClassifierParams", config: Optional[TrainingConfig] = None
+    ) -> TrainingResult:
+        """
+        Create a classifier trained on combined data from multiple benchmarks.
+
+        Args:
+            benchmark_names: List of benchmark names to combine training data from
+            classifier_params: Model-determined classifier parameters
+            config: Optional training configuration
+
+        Returns:
+            TrainingResult with the trained combined classifier
+        """
+        print(f"🏗️ Creating combined classifier from {len(benchmark_names)} benchmarks...")
+        print(f" 📊 Benchmarks: {benchmark_names}")
+        print(f" 🧠 Using layer {classifier_params.optimal_layer}, {classifier_params.training_samples} samples")
+
+        # Create config from classifier_params
+        if config is None:
+            config = TrainingConfig(
+                issue_type=f"quality_combined_{'_'.join(sorted(benchmark_names))}",
+                layer=classifier_params.optimal_layer,
+                classifier_type=classifier_params.classifier_type,
+                threshold=classifier_params.classification_threshold,
+                training_samples=classifier_params.training_samples,
+                model_name=self.model.name,
+            )
+
+        start_time = time.time()
+
+        # Generate combined training data from all benchmarks
+        print(" 📊 Loading and combining benchmark training data...")
+        combined_training_data = await self._load_combined_benchmark_data(
+            benchmark_names, classifier_params.training_samples
+        )
+
+        print(f" 📈 Loaded {len(combined_training_data)} combined training examples")
+
+        # Extract activations
+        print(" 🧠 Extracting activations...")
+        harmful_activations, harmless_activations = self._extract_activations_from_data(
+            combined_training_data, classifier_params.optimal_layer
+        )
+
+        # Train classifier
+        print(" 🎯 Training combined classifier...")
+        classifier = self._train_classifier(harmful_activations, harmless_activations, config)
+
+        # Evaluate performance
+        print(" 📈 Evaluating performance...")
+        metrics = self._evaluate_classifier(classifier, harmful_activations, harmless_activations)
+
+        training_time = time.time() - start_time
+
+        # Save classifier if path provided
+        save_path = None
+        if config.save_path:
+            print(" 💾 Saving combined classifier...")
+            save_path = self._save_classifier(classifier, config, metrics)
+
+        result = TrainingResult(
+            classifier=classifier.classifier,
+            config=config,
+            performance_metrics=metrics,
+            training_time=training_time,
+            save_path=save_path,
+        )
+
+        print(
+            f" ✅ Combined classifier created in {training_time:.2f}s "
+            f"(F1: {metrics.get('f1', 0):.3f}, Accuracy: {metrics.get('accuracy', 0):.3f})"
+        )
+
+        return result
+
+    async def _load_combined_benchmark_data(
+        self, benchmark_names: List[str], total_samples: int
+    ) -> List[Dict[str, Any]]:
+        """
+        Load and combine training data from multiple benchmarks.
+
+        Args:
+            benchmark_names: List of benchmark names to load data from
+            total_samples: Total number of training samples to create
+
+        Returns:
+            Combined list of training examples with balanced sampling
+        """
+        combined_data = []
+        samples_per_benchmark = max(1, total_samples // len(benchmark_names))
+
+        print(f" 📊 Loading ~{samples_per_benchmark} samples per benchmark")
+
+        for benchmark_name in benchmark_names:
+            try:
+                print(f" 🔄 Loading data from {benchmark_name}...")
+                benchmark_data = self._load_benchmark_data([benchmark_name], samples_per_benchmark)
+                combined_data.extend(benchmark_data)
+                print(f" ✅ Loaded {len(benchmark_data)} samples from {benchmark_name}")
+
+            except Exception as e:
+                print(f" ⚠️ Failed to load {benchmark_name}: {e}")
+                # Continue with other benchmarks
+                continue
+
+        # If we don't have enough samples, pad with synthetic data
+        if len(combined_data) < total_samples:
+            remaining_samples = total_samples - len(combined_data)
+            print(f" 🔧 Generating {remaining_samples} synthetic samples to reach target")
+            synthetic_data = self._generate_synthetic_training_data("quality", remaining_samples)
+            combined_data.extend(synthetic_data)
+
+        # Shuffle the combined data to ensure good mixing
+        import random
+
+        random.shuffle(combined_data)
+
+        # Trim to exact target if we have too many
+        combined_data = combined_data[:total_samples]
+
+        print(f" ✅ Final combined dataset: {len(combined_data)} samples")
+        return combined_data
+
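The balanced-sampling arithmetic above is worth spelling out (illustration only): with total_samples=100 and three benchmarks, each benchmark contributes max(1, 100 // 3) == 33 samples, so even a full load yields 99 examples and is padded with one synthetic sample before the shuffled pool is trimmed to the target:

    total_samples, n_benchmarks = 100, 3
    samples_per_benchmark = max(1, total_samples // n_benchmarks)
    assert samples_per_benchmark == 33
    assert samples_per_benchmark * n_benchmarks == 99  # shortfall of 1 is synthesized
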
+    async def create_classifier_for_issue(self, issue_type: str, layer: int = 15) -> TrainingResult:
+        """
+        Create a classifier for an issue type (async version for compatibility).
+
+        Args:
+            issue_type: Type of issue to detect
+            layer: Model layer to use for activation extraction
+
+        Returns:
+            TrainingResult with the trained classifier
+        """
+        return self.create_classifier_for_issue_type(issue_type, layer)
+
+    def _generate_training_data(self, issue_type: str, num_samples: int) -> List[Dict[str, Any]]:
+        """
+        Generate training data dynamically for a specific issue type using relevant benchmarks.
+
+        Args:
+            issue_type: Type of issue to generate data for
+            num_samples: Number of training samples to generate
+
+        Returns:
+            List of training examples with harmful/harmless pairs
+        """
+        print(f" 📊 Loading dynamic training data for {issue_type}...")
+
+        # Try to find relevant benchmarks for the issue type (using default 5-minute budget)
+        relevant_benchmarks = self._find_relevant_benchmarks(issue_type)
+
+        if relevant_benchmarks:
+            print(f" 🎯 Found {len(relevant_benchmarks)} relevant benchmarks: {relevant_benchmarks[:3]}...")
+            return self._load_benchmark_data(relevant_benchmarks, num_samples)
+        print(" 🤖 No specific benchmarks found, using synthetic generation...")
+        return self._generate_synthetic_training_data(issue_type, num_samples)
+
+    def _find_relevant_benchmarks(self, issue_type: str, time_budget_minutes: float = 5.0) -> List[str]:
+        """Find relevant benchmarks for the given issue type based on time budget with priority-aware selection."""
+        from ..budget import calculate_max_tasks_for_time_budget
+        from .tasks.task_relevance import find_relevant_tasks
+
+        try:
+            # Calculate max tasks using budget system
+            max_tasks = calculate_max_tasks_for_time_budget(
+                task_type="benchmark_evaluation", time_budget_minutes=time_budget_minutes
+            )
+
+            print(f" 🕐 Time budget: {time_budget_minutes:.1f}min → max {max_tasks} tasks")
+
+            # Use priority-aware intelligent benchmark selection
+            try:
+                # Import priority-aware selection function
+                import os
+                import sys
+
+                sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "lm-harness-integration"))
+                from only_benchmarks import find_most_relevant_benchmarks
+
+                # Use priority-aware selection with time budget
+                relevant_results = find_most_relevant_benchmarks(
+                    prompt=issue_type,
+                    top_k=max_tasks,
+                    priority="all",
+                    fast_only=False,
+                    time_budget_minutes=time_budget_minutes,
+                    prefer_fast=True,  # Prefer fast benchmarks for agent use
+                )
+
+                # Extract benchmark names
+                relevant_benchmarks = [result["benchmark"] for result in relevant_results]
+
+                if relevant_benchmarks:
+                    print(f" 📊 Found {len(relevant_benchmarks)} priority-aware benchmarks for '{issue_type}':")
+                    for i, result in enumerate(relevant_results[:3]):
+                        priority_str = f" (priority: {result.get('priority', 'unknown')})"
+                        loading_time_str = f" (loading time: {result.get('loading_time', 60.0):.1f}s)"
+                        print(f" {i + 1}. {result['benchmark']}{priority_str}{loading_time_str}")
+                    if len(relevant_benchmarks) > 3:
+                        print(f" ... and {len(relevant_benchmarks) - 3} more")
+
+                    return relevant_benchmarks
+
+            except Exception as priority_error:
+                print(f" ⚠️ Priority-aware selection failed: {priority_error}")
+                print(" 🔄 Falling back to legacy task relevance...")
+
+            # Fallback to legacy system
+            relevant_task_results = find_relevant_tasks(
+                query=issue_type, max_results=max_tasks, min_relevance_score=0.1
+            )
+
+            # Extract just the task names
+            candidate_benchmarks = [task_name for task_name, score in relevant_task_results]
+
+            # Use priority-aware budget optimization
+            from ..budget import optimize_benchmarks_for_budget
+
+            relevant_benchmarks = optimize_benchmarks_for_budget(
+                task_candidates=candidate_benchmarks,
+                time_budget_minutes=time_budget_minutes,
+                max_tasks=max_tasks,
+                prefer_fast=True,  # Agent prefers fast benchmarks
+            )
+
+            if relevant_benchmarks:
+                print(f" 📊 Found {len(relevant_benchmarks)} relevant benchmarks for '{issue_type}':")
+                # Show the scores for the selected benchmarks
+                for i, (task_name, score) in enumerate(relevant_task_results[:3]):
+                    if task_name in relevant_benchmarks:
+                        print(f" {i + 1}. {task_name} (relevance: {score:.3f})")
+                if len(relevant_benchmarks) > 3:
+                    print(f" ... and {len(relevant_benchmarks) - 3} more")
+
+            return relevant_benchmarks
+
+        except Exception as e:
+            print(f" ⚠️ Error finding relevant benchmarks: {e}")
+            print(" ⚠️ Using fallback tasks")
+            # Minimal fallback to high priority fast benchmarks
+            return ["mmlu", "truthfulqa_mc1", "hellaswag"]
+
+    def _extract_benchmark_concepts(self, benchmark_names: List[str]) -> Dict[str, List[str]]:
+        """Extract semantic concepts from benchmark names."""
+        concepts = {}
+
+        for name in benchmark_names:
+            # Extract concepts from benchmark name
+            name_concepts = []
+            name_lower = name.lower()
+
+            # Split on common separators and extract meaningful tokens
+            tokens = name_lower.replace("_", " ").replace("-", " ").split()
+
+            # Filter out common non-semantic tokens
+            semantic_tokens = []
+            skip_tokens = {
+                "the",
+                "and",
+                "or",
+                "of",
+                "in",
+                "on",
+                "at",
+                "to",
+                "for",
+                "with",
+                "by",
+                "from",
+                "as",
+                "is",
+                "are",
+                "was",
+                "were",
+                "be",
+                "been",
+                "being",
+                "have",
+                "has",
+                "had",
+                "do",
+                "does",
+                "did",
+                "will",
+                "would",
+                "could",
+                "should",
+                "may",
+                "might",
+                "can",
+                "light",
+                "full",
+                "val",
+                "test",
+                "dev",
+                "mc1",
+                "mc2",
+                "mt",
+                "cot",
+                "fewshot",
+                "zeroshot",
+                "generate",
+                "until",
+                "multiple",
+                "choice",
+                "group",
+                "subset",
+            }
+
+            for token in tokens:
+                if len(token) > 2 and token not in skip_tokens and token.isalpha():
+                    semantic_tokens.append(token)
+
+            # Extract domain-specific concepts
+            domain_concepts = self._extract_domain_concepts(name_lower, semantic_tokens)
+            name_concepts.extend(domain_concepts)
+
+            concepts[name] = list(set(name_concepts))  # Remove duplicates
+
+        return concepts
+
+    def _extract_domain_concepts(self, benchmark_name: str, tokens: List[str]) -> List[str]:
+        """Extract domain-specific concepts directly from benchmark name components."""
+        concepts = []
+
+        # Add all meaningful tokens as concepts
+        for token in tokens:
+            if len(token) > 2:
+                concepts.append(token)
+
+        # Extract compound concept meanings from token combinations
+        name_parts = benchmark_name.lower().split("_")
+
+        # Generate concept combinations
+        for i, part in enumerate(name_parts):
+            if len(part) > 2:
+                concepts.append(part)
+
+            # Look for meaningful compound concepts
+            if i < len(name_parts) - 1:
+                next_part = name_parts[i + 1]
+                if len(next_part) > 2:
+                    compound = f"{part}_{next_part}"
+                    concepts.append(compound)
+
+        # Extract semantic root words
+        for token in tokens:
+            root_concepts = self._extract_semantic_roots(token)
+            concepts.extend(root_concepts)
+
+        return list(set(concepts))  # Remove duplicates
+
+    def _extract_semantic_roots(self, word: str) -> List[str]:
+        """Extract semantic root concepts from a word."""
+        roots = []
+
+        # Simple morphological analysis
+        # Remove common suffixes to find roots
+        suffixes = [
+            "ing",
+            "tion",
+            "sion",
+            "ness",
+            "ment",
+            "able",
+            "ible",
+            "ful",
+            "less",
+            "ly",
+            "al",
+            "ic",
+            "ous",
+            "ive",
+        ]
+
+        root = word
+        for suffix in suffixes:
+            if word.endswith(suffix) and len(word) > len(suffix) + 2:
+                root = word[: -len(suffix)]
+                break
+
+        if root != word and len(root) > 2:
+            roots.append(root)
+
+        # Add the original word
+        roots.append(word)
+
+        return roots
+
+    def _calculate_benchmark_relevance(self, issue_type: str, benchmark_concepts: Dict[str, List[str]]) -> List[str]:
+        """Calculate relevance scores using semantic similarity."""
+        # Calculate relevance scores
+        benchmark_scores = []
+
+        for benchmark_name, concepts in benchmark_concepts.items():
+            score = self._calculate_semantic_similarity(issue_type, benchmark_name, concepts)
+
+            if score > 0:
+                benchmark_scores.append((benchmark_name, score))
+
+        # Sort by relevance score
+        benchmark_scores.sort(key=lambda x: x[1], reverse=True)
+
+        return [name for name, score in benchmark_scores]
+
+    def _calculate_semantic_similarity(self, issue_type: str, benchmark_name: str, concepts: List[str]) -> float:
+        """Calculate semantic similarity between issue type and benchmark."""
+        issue_lower = issue_type.lower()
+        benchmark_lower = benchmark_name.lower()
+
+        score = 0.0
+
+        # Direct name matching (highest weight)
+        if issue_lower in benchmark_lower or benchmark_lower in issue_lower:
+            score += 5.0
+
+        # Concept matching
+        for concept in concepts:
+            concept_lower = concept.lower()
+
+            # Exact concept match
+            if issue_lower == concept_lower:
+                score += 4.0
+            # Partial concept match
+            elif issue_lower in concept_lower or concept_lower in issue_lower:
+                score += 2.0
+            # Semantic similarity check
+            elif self._are_semantically_similar(issue_lower, concept_lower):
+                score += 1.5
+
+        # Token-level similarity in benchmark name
+        benchmark_tokens = benchmark_lower.replace("_", " ").replace("-", " ").split()
+        issue_tokens = issue_lower.replace("_", " ").replace("-", " ").split()
+
+        for issue_token in issue_tokens:
+            for benchmark_token in benchmark_tokens:
+                if len(issue_token) > 2 and len(benchmark_token) > 2:
+                    if issue_token == benchmark_token:
+                        score += 3.0
+                    elif issue_token in benchmark_token or benchmark_token in issue_token:
+                        score += 1.0
+                    elif self._are_semantically_similar(issue_token, benchmark_token):
+                        score += 0.5
+
+        return score
+
+    def _are_semantically_similar(self, term1: str, term2: str) -> bool:
+        """Check if two terms are semantically similar using algorithmic methods."""
+        if len(term1) < 3 or len(term2) < 3:
+            return False
+
+        # Character-level similarity
+        overlap = len(set(term1) & set(term2))
+        min_len = min(len(term1), len(term2))
+        char_similarity = overlap / min_len
+
+        # Substring similarity
+        longer, shorter = (term1, term2) if len(term1) > len(term2) else (term2, term1)
+        substring_match = shorter in longer
+
+        # Prefix/suffix similarity
+        prefix_len = 0
+        suffix_len = 0
+
+        for i in range(min(len(term1), len(term2))):
+            if term1[i] == term2[i]:
+                prefix_len += 1
+            else:
+                break
+
+        for i in range(1, min(len(term1), len(term2)) + 1):
+            if term1[-i] == term2[-i]:
+                suffix_len += 1
+            else:
+                break
+
+        affix_similarity = (prefix_len + suffix_len) / max(len(term1), len(term2))
+
+        # Combined similarity score
+        return char_similarity > 0.6 or substring_match or affix_similarity > 0.4 or prefix_len >= 3 or suffix_len >= 3
+
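A worked example of the heuristic (illustration only, with creator as in the earlier sketch):

    sim = creator._are_semantically_similar
    assert sim("reasoning", "reason")  # all 6 characters of "reason" occur in "reasoning": 6/6 > 0.6
    assert not sim("logic", "math")    # no shared characters, prefix, suffix, or substring

Note that the character-overlap test ignores order, so anagrams score 1.0; this is a deliberately loose filter rather than a true semantic measure.
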
+    def _prioritize_benchmarks(self, relevant_benchmarks: List[str]) -> List[str]:
+        """Prioritize benchmarks algorithmically based on naming patterns and characteristics."""
+        benchmark_scores = []
+
+        for benchmark in relevant_benchmarks:
+            score = self._calculate_benchmark_quality_score(benchmark)
+            benchmark_scores.append((benchmark, score))
+
+        # Sort by quality score (higher is better)
+        benchmark_scores.sort(key=lambda x: x[1], reverse=True)
+        return [benchmark for benchmark, score in benchmark_scores]
+
+    def _calculate_benchmark_quality_score(self, benchmark_name: str) -> float:
+        """Calculate quality score for a benchmark based on naming patterns and characteristics."""
+        score = 0.0
+        benchmark_lower = benchmark_name.lower()
+
+        # Length heuristic - moderate length names tend to be well-established
+        name_length = len(benchmark_name)
+        if 8 <= name_length <= 25:
+            score += 2.0
+        elif name_length < 8:
+            score += 0.5  # Very short names might be too simple
+        else:
+            score += 1.0  # Very long names might be overly specific
+
+        # Component analysis
+        parts = benchmark_lower.split("_")
+        num_parts = len(parts)
+
+        # Well-structured benchmarks often have 2-3 parts
+        if 2 <= num_parts <= 3:
+            score += 2.0
+        elif num_parts == 1:
+            score += 1.5  # Simple names can be good too
+        else:
+            score += 0.5  # Too many parts might indicate over-specification
+
+        # Indicator of established benchmarks (avoid hardcoding specific names)
+        quality_indicators = [
+            # Multiple choice indicators (often well-validated)
+            ("mc1", 1.5),
+            ("mc2", 1.5),
+            ("multiple_choice", 1.5),
+            # Evaluation methodology indicators
+            ("eval", 1.0),
+            ("test", 1.0),
+            ("benchmark", 1.0),
+            # Language understanding indicators
+            ("language", 1.0),
+            ("understanding", 1.0),
+            ("comprehension", 1.0),
+            # Logic and reasoning indicators
+            ("logic", 1.0),
+            ("reasoning", 1.0),
+            ("deduction", 1.0),
+            # Knowledge assessment indicators
+            ("knowledge", 1.0),
+            ("question", 1.0),
+            ("answer", 1.0),
+        ]
+
+        for indicator, points in quality_indicators:
+            if indicator in benchmark_lower:
+                score += points
+
+        # Penalize very specialized or experimental indicators
+        experimental_indicators = [
+            "experimental",
+            "pilot",
+            "demo",
+            "sample",
+            "tiny",
+            "mini",
+            "subset",
+            "light",
+            "debug",
+            "test_only",
+        ]
+
+        for indicator in experimental_indicators:
+            if indicator in benchmark_lower:
+                score -= 1.0
+
+        # Bonus for domain diversity indicators
+        domain_indicators = ["multilingual", "global", "cross", "multi", "diverse"]
+
+        for indicator in domain_indicators:
+            if indicator in benchmark_lower:
+                score += 0.5
+
+        return max(0.0, score)  # Ensure non-negative score
+
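A worked score (illustration only): for "truthfulqa_mc1", the name is 14 characters long (+2.0), splits into two underscore parts (+2.0), and contains the "mc1" indicator (+1.5); no experimental or domain indicators match, giving 5.5:

    assert creator._calculate_benchmark_quality_score("truthfulqa_mc1") == 5.5
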
+    def _load_benchmark_data(self, benchmarks: List[str], num_samples: int) -> List[Dict[str, Any]]:
+        """Load training data from multiple relevant benchmarks."""
+        from .tasks import TaskManager
+
+        training_data = []
+        samples_per_benchmark = max(1, num_samples // len(benchmarks))
+
+        # Create task manager instance
+        task_manager = TaskManager()
+
+        for benchmark in benchmarks:
+            try:
+                print(f" 🔄 Loading from {benchmark}...")
+
+                # Load benchmark task using TaskManager
+                task_data = task_manager.load_task(benchmark, limit=samples_per_benchmark * 3)
+                docs = task_manager.split_task_data(task_data, split_ratio=1.0)[0]
+
+                # Extract QA pairs using existing system
+                from ...contrastive_pairs.contrastive_pair_set import ContrastivePairSet
+
+                qa_pairs = ContrastivePairSet.extract_qa_pairs_from_task_docs(benchmark, task_data, docs)
+
+                # Convert to training format
+                for pair in qa_pairs[:samples_per_benchmark]:
+                    if self._is_valid_pair(pair):
+                        training_data.append(
+                            {
+                                "prompt": pair.get("question", f"Context from {benchmark}"),
+                                "harmful_response": pair.get("incorrect_answer", ""),
+                                "harmless_response": pair.get("correct_answer", ""),
+                                "source": benchmark,
+                            }
+                        )
+
+                print(f" ✅ Loaded {len(qa_pairs[:samples_per_benchmark])} examples from {benchmark}")
+
+                # Stop if we have enough data
+                if len(training_data) >= num_samples:
+                    break
+
+            except Exception as e:
+                print(f" ⚠️ Failed to load {benchmark}: {e}")
+                continue
+
+        if not training_data:
+            print(" ❌ Failed to load from any benchmarks, falling back to synthetic...")
+            return self._generate_synthetic_training_data("unknown", num_samples)
+
+        print(
+            f" ✅ Loaded {len(training_data)} examples from {len([b for b in benchmarks if any(b in item.get('source', '') for item in training_data)])} benchmarks"
+        )
+        return training_data[:num_samples]
+
+    def _is_valid_pair(self, pair: Dict[str, Any]) -> bool:
+        """Check if a QA pair is valid for training."""
+        return (
+            pair.get("correct_answer")
+            and pair.get("incorrect_answer")
+            and len(pair.get("correct_answer", "")) > 5
+            and len(pair.get("incorrect_answer", "")) > 5
+            and pair.get("correct_answer") != pair.get("incorrect_answer")
+        )
+
+    def _generate_synthetic_training_data(self, issue_type: str, num_samples: int) -> List[Dict[str, Any]]:
+        """Generate synthetic training data for issue types."""
+        from ...contrastive_pairs.generate_synthetically import SyntheticContrastivePairGenerator
+
+        try:
+            generator = SyntheticContrastivePairGenerator(self.model)
+
+            # Create trait description based on issue type
+            trait_descriptions = {
+                "quality": "providing clear, accurate, and well-structured responses",
+                "harmful": "giving safe, helpful, and constructive advice",
+                "bias": "responding with fairness and avoiding stereotypes",
+                "coherence": "maintaining logical flow and clear communication",
+                "hallucination": "providing factually accurate information",
+            }
+
+            trait_description = trait_descriptions.get(issue_type, f"avoiding {issue_type} issues in responses")
+
+            # Generate synthetic pairs
+            synthetic_pairs = generator.generate_contrastive_pair_set(
+                trait_description=trait_description, num_pairs=num_samples, name=f"synthetic_{issue_type}"
+            )
+
+            # Convert to training format
+            training_data = []
+            for pair in synthetic_pairs.pairs[:num_samples]:
+                training_data.append(
+                    {
+                        "prompt": pair.prompt or f"Context for {issue_type} detection",
+                        "harmful_response": pair.negative_response,
+                        "harmless_response": pair.positive_response,
+                    }
+                )
+
+            print(f" ✅ Generated {len(training_data)} synthetic examples for {issue_type}")
+            return training_data
+
+        except Exception as e:
+            print(f" ❌ Failed to generate synthetic data: {e}")
+            raise ValueError(f"Cannot generate training data for issue type: {issue_type}")
+
+    def _extract_activations_from_data(
+        self, training_data: List[Dict[str, Any]], layer: int
+    ) -> Tuple[List[Activations], List[Activations]]:
+        """
+        Extract activations from training data.
+
+        Args:
+            training_data: List of training examples
+            layer: Layer to extract activations from
+
+        Returns:
+            Tuple of (harmful_activations, harmless_activations)
+        """
+        harmful_activations = []
+        harmless_activations = []
+
+        layer_obj = Layer(index=layer, type="transformer")
+
+        for example in training_data:
+            # Extract harmful activation
+            harmful_tensor = self.model.extract_activations(example["harmful_response"], layer_obj)
+            harmful_activation = Activations(tensor=harmful_tensor, layer=layer_obj)
+            harmful_activations.append(harmful_activation)
+
+            # Extract harmless activation
+            harmless_tensor = self.model.extract_activations(example["harmless_response"], layer_obj)
+            harmless_activation = Activations(tensor=harmless_tensor, layer=layer_obj)
+            harmless_activations.append(harmless_activation)
+
+        return harmful_activations, harmless_activations
+
+    def _train_classifier(
+        self, harmful_activations: List[Activations], harmless_activations: List[Activations], config: TrainingConfig
+    ) -> ActivationClassifier:
+        """
+        Train a classifier on the activation data.
+
+        Args:
+            harmful_activations: List of harmful activations
+            harmless_activations: List of harmless activations
+            config: Training configuration
+
+        Returns:
+            Trained ActivationClassifier
+        """
+        classifier = ActivationClassifier(
+            model_type=config.classifier_type, threshold=config.threshold, device=self.model.device
+        )
+
+        classifier.train_on_activations(harmful_activations, harmless_activations)
+
+        return classifier
+
+    def _evaluate_classifier(
+        self,
+        classifier: ActivationClassifier,
+        harmful_activations: List[Activations],
+        harmless_activations: List[Activations],
+    ) -> Dict[str, float]:
+        """
+        Evaluate classifier performance.
+
+        Args:
+            classifier: Trained classifier
+            harmful_activations: Test harmful activations
+            harmless_activations: Test harmless activations
+
+        Returns:
+            Dictionary of performance metrics
+        """
+        # Use a portion of data for testing
+        test_size = min(10, len(harmful_activations) // 5)  # 20% of the data, capped at 10
+
+        test_harmful = harmful_activations[-test_size:]
+        test_harmless = harmless_activations[-test_size:]
+
+        return classifier.evaluate_on_activations(test_harmful, test_harmless)
+
+    def _save_classifier(
+        self, classifier: ActivationClassifier, config: TrainingConfig, metrics: Dict[str, float]
+    ) -> str:
+        """
+        Save classifier with metadata.
+
+        Args:
+            classifier: Trained classifier
+            config: Training configuration
+            metrics: Performance metrics
+
+        Returns:
+            Path where classifier was saved
+        """
+        # Create metadata
+        metadata = create_classifier_metadata(
+            model_name=config.model_name,
+            task_name=config.issue_type,
+            layer=config.layer,
+            classifier_type=config.classifier_type,
+            training_accuracy=metrics.get("accuracy", 0.0),
+            training_samples=config.training_samples,
+            token_aggregation="final",  # Default for our system
+            detection_threshold=config.threshold,
+            f1=metrics.get("f1", 0.0),
+            precision=metrics.get("precision", 0.0),
+            recall=metrics.get("recall", 0.0),
+            auc=metrics.get("auc", 0.0),
+        )
+
+        # Save using ModelPersistence
+        save_path = ModelPersistence.save_classifier(classifier.classifier, config.layer, config.save_path, metadata)
+
+        return save_path
+
+
+def create_classifier_on_demand(
+    model: Model, issue_type: str, layer: Optional[int] = None, save_path: Optional[str] = None, optimize: bool = False
+) -> TrainingResult:
+    """
+    Convenience function to create a classifier on demand.
+
+    Args:
+        model: Language model to use
+        issue_type: Type of issue to detect
+        layer: Specific layer to use (auto-optimized if None)
+        save_path: Path to save the classifier
+        optimize: Whether to optimize for best performance
+
+    Returns:
+        TrainingResult with the created classifier
+    """
+    creator = ClassifierCreator(model)
+
+    if optimize or layer is None:
+        # Optimize to find best configuration
+        result = creator.optimize_classifier_for_performance(issue_type)
+
+        # Save if path provided
+        if save_path:
+            result.config.save_path = save_path
+            # Re-wrap the trained base classifier so that it, rather than a
+            # freshly initialized one, is what gets persisted.
+            wrapper = ActivationClassifier(device=model.device)
+            wrapper.classifier = result.classifier
+            result.save_path = creator._save_classifier(wrapper, result.config, result.performance_metrics)
+
+        return result
+    # Use specified layer
+    config = TrainingConfig(issue_type=issue_type, layer=layer, save_path=save_path, model_name=model.name)
+
+    return creator.create_classifier_for_issue_type(issue_type, layer, config)
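
Putting it together, a minimal end-to-end sketch (illustration only; the import path follows this file's location in the wheel, and load_model is a hypothetical stand-in for however a wisent Model is obtained):

    from wisent.core.agent.diagnose.create_classifier import create_classifier_on_demand

    model = load_model("meta-llama/Llama-3.1-8B")  # hypothetical helper and model id
    result = create_classifier_on_demand(
        model,
        issue_type="hallucination",
        layer=15,  # pinning a layer skips the optimization sweep
        save_path="hallucination_l15.pkl",
    )
    print(f"F1={result.performance_metrics.get('f1', 0.0):.3f} in {result.training_time:.2f}s")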