wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (237) hide show
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.1.dist-info/METADATA +67 -0
  215. wisent-0.5.1.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
wisent/core/parser.py ADDED
@@ -0,0 +1,1668 @@
1
+ """
2
+ Command-line argument parser for wisent-guard.
3
+ """
4
+
5
+ import argparse
6
+ from typing import List, Optional
7
+
8
+
9
def setup_parser() -> argparse.ArgumentParser:
    """Build the top-level CLI parser and register every subcommand.

    Each subcommand parser is created on the shared subparsers object and then
    passed to its dedicated ``setup_*`` function, which attaches that command's
    own arguments.

    Returns:
        The configured top-level ``argparse.ArgumentParser``. The name of the
        selected subcommand is stored under ``command`` in the parsed namespace.
    """
    root = argparse.ArgumentParser(description="Wisent-Guard: Advanced AI Safety and Alignment Toolkit")

    # Flag defined on the top-level parser rather than on a subcommand.
    root.add_argument("--verbose", action="store_true", help="Enable verbose output")

    commands = root.add_subparsers(dest="command", help="Available commands")

    # (command name, help text, configurator), listed in registration order.
    registry = (
        # Main evaluation pipeline
        ("tasks", "Run evaluation tasks", setup_tasks_parser),
        # Synthetic contrastive pair generation
        ("generate-pairs", "Generate synthetic contrastive pairs", setup_generate_pairs_parser),
        # Generate + train + test
        ("synthetic", "Run synthetic contrastive pair pipeline", setup_synthetic_parser),
        # Nonsense detection testing
        ("test-nonsense", "Test nonsense detection system", setup_test_nonsense_parser),
        # Performance monitoring
        ("monitor", "Performance monitoring and system information", setup_monitor_parser),
        # Autonomous agent interaction
        ("agent", "Interact with autonomous agent", setup_agent_parser),
        # Management of model-specific optimal parameters
        ("model-config", "Manage model-specific optimal parameters", setup_model_config_parser),
        # Setup for new/unsupported models
        (
            "configure-model",
            "Configure tokens and layer access for unsupported models",
            setup_configure_model_parser,
        ),
        # Search for optimal classification parameters
        (
            "optimize-classification",
            "Optimize classification parameters across all tasks",
            setup_classification_optimizer_parser,
        ),
        # Search for optimal steering parameters
        (
            "optimize-steering",
            "Optimize steering parameters for different methods",
            setup_steering_optimizer_parser,
        ),
        # Search for optimal training sample sizes
        (
            "optimize-sample-size",
            "Find optimal training sample size for classifiers",
            setup_sample_size_optimizer_parser,
        ),
        # Classification optimization followed by sample size optimization
        (
            "full-optimize",
            "Run full optimization: classification parameters then sample size",
            setup_full_optimizer_parser,
        ),
        # Steering vector creation without tasks
        (
            "generate-vector",
            "Generate steering vectors from contrastive pairs (file or description)",
            setup_generate_vector_parser,
        ),
        # Combining multiple vectors at inference time
        (
            "multi-steer",
            "Combine multiple steering vectors dynamically at inference time",
            setup_multi_steer_parser,
        ),
        # Single-prompt evaluation for real-time steering assessment
        (
            "evaluate",
            "Evaluate single prompt with steering vector and return quality scores",
            setup_evaluate_parser,
        ),
    )

    # Create each subcommand parser, then let its configurator add arguments.
    for name, summary, configure in registry:
        configure(commands.add_parser(name, help=summary))

    return root
96
+
97
+
98
+ def setup_tasks_parser(parser):
99
+ """Set up the tasks subcommand parser."""
100
+
101
+ # Task listing options (mutually exclusive with task execution)
102
+ list_group = parser.add_mutually_exclusive_group()
103
+ list_group.add_argument(
104
+ "--list-tasks",
105
+ action="store_true",
106
+ help="List all 37 available benchmark tasks organized by priority (excludes 28 known problematic benchmarks)",
107
+ )
108
+ list_group.add_argument(
109
+ "--task-info", type=str, metavar="TASK_NAME", help="Show detailed information about a specific task"
110
+ )
111
+ list_group.add_argument("--all", action="store_true", help="Run all 37 available benchmarks automatically")
112
+
113
+ # Task execution argument (optional when using listing commands or --all)
114
+ parser.add_argument(
115
+ "task_names",
116
+ nargs="?",
117
+ help="Comma-separated list of available task names (37 working benchmarks), or path to CSV/JSON file with --from-csv/--from-json (not needed with --all)",
118
+ )
119
+
120
+ # Skills/risks based task selection
121
+ parser.add_argument(
122
+ "--skills", type=str, nargs="+", help="Select tasks by skill categories (e.g., coding, mathematics, reasoning)"
123
+ )
124
+ parser.add_argument(
125
+ "--risks",
126
+ type=str,
127
+ nargs="+",
128
+ help="Select tasks by risk categories (e.g., harmfulness, toxicity, hallucination)",
129
+ )
130
+ parser.add_argument(
131
+ "--num-tasks",
132
+ type=int,
133
+ default=None,
134
+ help="Number of tasks to randomly select from matched tasks (default: all)",
135
+ )
136
+ parser.add_argument(
137
+ "--min-quality-score",
138
+ type=int,
139
+ default=2,
140
+ choices=[1, 2, 3, 4, 5],
141
+ help="Minimum quality score for tasks when using --skills/--risks (default: 2)",
142
+ )
143
+ parser.add_argument(
144
+ "--task-seed", type=int, default=None, help="Random seed for task selection (for reproducibility)"
145
+ )
146
+
147
+ # Mixed sampling from multiple benchmarks
148
+ parser.add_argument(
149
+ "--tag",
150
+ type=str,
151
+ nargs="+",
152
+ help="Sample randomly from all benchmarks with these tags (e.g., --tag coding). Creates a mixed dataset from multiple benchmarks.",
153
+ )
154
+ parser.add_argument(
155
+ "--mixed-samples",
156
+ type=int,
157
+ default=1000,
158
+ help="Total number of samples to collect when using --tag (default: 1000)",
159
+ )
160
+ parser.add_argument(
161
+ "--tag-mode",
162
+ type=str,
163
+ choices=["any", "all"],
164
+ default="any",
165
+ help="Whether benchmarks must have ANY or ALL specified tags (default: any)",
166
+ )
167
+
168
+ # Cross-benchmark evaluation
169
+ parser.add_argument(
170
+ "--train-task", type=str, help="Task/benchmark to train on (can be a task name or --tag for mixed)"
171
+ )
172
+ parser.add_argument(
173
+ "--eval-task", type=str, help="Task/benchmark to evaluate on (can be a task name or --tag for mixed)"
174
+ )
175
+ parser.add_argument(
176
+ "--train-tag", type=str, nargs="+", help="Tags for training data when using cross-benchmark evaluation"
177
+ )
178
+ parser.add_argument(
179
+ "--eval-tag", type=str, nargs="+", help="Tags for evaluation data when using cross-benchmark evaluation"
180
+ )
181
+ parser.add_argument(
182
+ "--cross-benchmark",
183
+ action="store_true",
184
+ help="Enable cross-benchmark evaluation mode (train on one, eval on another)",
185
+ )
186
+
187
+ # Synthetic pair generation
188
+ parser.add_argument(
189
+ "--synthetic", action="store_true", help="Generate synthetic contrastive pairs from a trait description"
190
+ )
191
+ parser.add_argument(
192
+ "--trait",
193
+ type=str,
194
+ help="Natural language description of desired model behavior (e.g., 'hallucinates less', 'more factual', 'less verbose')",
195
+ )
196
+ parser.add_argument(
197
+ "--num-synthetic-pairs", type=int, default=30, help="Number of synthetic pairs to generate (default: 30)"
198
+ )
199
+ parser.add_argument("--save-synthetic", type=str, help="Path to save generated synthetic pairs as JSON")
200
+ parser.add_argument(
201
+ "--load-synthetic", type=str, help="Path to load previously generated synthetic pairs from JSON"
202
+ )
203
+
204
+ parser.add_argument("--model", type=str, default="meta-llama/Llama-3.1-8B-Instruct", help="Model name or path")
205
+ parser.add_argument(
206
+ "--layer",
207
+ type=str,
208
+ default="15",
209
+ help="Layer(s) to extract activations from. Can be a single layer (15), range (14-16), or comma-separated list (14,15,16)",
210
+ )
211
+ parser.add_argument("--shots", type=int, default=0, help="Number of few-shot examples")
212
+ parser.add_argument("--split-ratio", type=float, default=0.8, help="Train/test split ratio")
213
+ parser.add_argument("--limit", type=int, default=None, help="Limit number of documents per task")
214
+ parser.add_argument(
215
+ "--training-limit",
216
+ type=int,
217
+ default=None,
218
+ help="Limit number of training documents (overrides limit for training)",
219
+ )
220
+ parser.add_argument(
221
+ "--testing-limit",
222
+ type=int,
223
+ default=None,
224
+ help="Limit number of testing documents (overrides limit for testing)",
225
+ )
226
+ parser.add_argument("--output", type=str, default="./results", help="Output directory for results")
227
+ parser.add_argument(
228
+ "--classifier-type", type=str, choices=["logistic", "mlp"], default="logistic", help="Type of classifier"
229
+ )
230
+ parser.add_argument("--max-new-tokens", type=int, default=300, help="Maximum new tokens for generation")
231
+ parser.add_argument("--device", type=str, default=None, help="Device to run on")
232
+ parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
233
+ parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
234
+ parser.add_argument(
235
+ "--token-aggregation",
236
+ type=str,
237
+ choices=["average", "final", "first", "max", "min"],
238
+ default="average",
239
+ help="How to aggregate token scores for classification",
240
+ )
241
+ parser.add_argument(
242
+ "--ground-truth-method",
243
+ type=str,
244
+ choices=[
245
+ "none",
246
+ "exact_match",
247
+ "substring_match",
248
+ "user_specified",
249
+ "interactive",
250
+ "manual_review",
251
+ "good",
252
+ "lm-eval-harness",
253
+ ],
254
+ default="lm-eval-harness",
255
+ help="Method for ground truth evaluation. 'lm-eval-harness' uses lm-eval-harness tasks for evaluation (default for most tasks), 'none' skips evaluation, 'exact_match' and 'substring_match' are problematic for free-form generation, 'user_specified' allows manual labeling, 'interactive' prompts for y/n labeling, 'manual_review' marks for review, 'good' marks everything as truthful (for debugging)",
256
+ )
257
+ parser.add_argument(
258
+ "--user-labels",
259
+ type=str,
260
+ nargs="*",
261
+ default=None,
262
+ help="User-specified ground truth labels for responses ('truthful' or 'hallucination'). Used with --ground-truth-method user_specified",
263
+ )
264
+
265
+ # File input arguments
266
+ parser.add_argument(
267
+ "--from-csv",
268
+ action="store_true",
269
+ help="Load task data from CSV file. Requires columns: question, correct_answer, incorrect_answer",
270
+ )
271
+ parser.add_argument(
272
+ "--from-json",
273
+ action="store_true",
274
+ help="Load task data from JSON file. Expected format: list of objects with question, correct_answer, incorrect_answer",
275
+ )
276
+ parser.add_argument(
277
+ "--question-col", type=str, default="question", help="Column name for questions in CSV file (default: question)"
278
+ )
279
+ parser.add_argument(
280
+ "--correct-col",
281
+ type=str,
282
+ default="correct_answer",
283
+ help="Column name for correct answers in CSV file (default: correct_answer)",
284
+ )
285
+ parser.add_argument(
286
+ "--incorrect-col",
287
+ type=str,
288
+ default="incorrect_answer",
289
+ help="Column name for incorrect answers in CSV file (default: incorrect_answer)",
290
+ )
291
+
292
+ # Optimization arguments
293
+ parser.add_argument(
294
+ "--optimize",
295
+ action="store_true",
296
+ help="Enable hyperparameter optimization. When enabled, will find optimal layer, threshold, and aggregation method",
297
+ )
298
+ parser.add_argument(
299
+ "--optimize-layers",
300
+ type=str,
301
+ default="all",
302
+ help="Layer range for optimization (e.g., '8-24' or '10,15,20' or 'all'). Default: all (uses all model layers)",
303
+ )
304
+ parser.add_argument(
305
+ "--optimize-metric",
306
+ type=str,
307
+ choices=["accuracy", "f1", "precision", "recall", "auc"],
308
+ default="f1",
309
+ help="Metric to optimize for. Default: f1",
310
+ )
311
+ parser.add_argument(
312
+ "--optimize-max-combinations",
313
+ type=int,
314
+ default=100,
315
+ help="Maximum number of hyperparameter combinations to test. Default: 100",
316
+ )
317
+ parser.add_argument(
318
+ "--auto-optimize",
319
+ action="store_true",
320
+ help="Automatically enable optimization when layer is not specified or is -1",
321
+ )
322
+
323
+ # Dataset validation arguments
324
+ parser.add_argument(
325
+ "--allow-small-dataset",
326
+ action="store_true",
327
+ help="Allow training with datasets smaller than 4 samples (may cause training issues)",
328
+ )
329
+
330
+ # Detection handling arguments
331
+ parser.add_argument(
332
+ "--detection-action",
333
+ type=str,
334
+ choices=["pass_through", "replace_with_placeholder", "regenerate_until_safe"],
335
+ default="pass_through",
336
+ help="Action to take when problematic content is detected (default: pass_through)",
337
+ )
338
+ parser.add_argument(
339
+ "--placeholder-message",
340
+ type=str,
341
+ default=None,
342
+ help="Custom placeholder message for detected content (if not specified, uses default)",
343
+ )
344
+ parser.add_argument(
345
+ "--max-regeneration-attempts",
346
+ type=int,
347
+ default=3,
348
+ help="Maximum attempts to regenerate safe content (default: 3)",
349
+ )
350
+ parser.add_argument(
351
+ "--detection-threshold",
352
+ type=float,
353
+ default=0.6,
354
+ help="Threshold for classification (higher = more strict detection) (default: 0.6)",
355
+ )
356
+ parser.add_argument("--log-detections", action="store_true", help="Enable logging of detection events")
357
+
358
+ # Code execution security arguments
359
+ parser.add_argument(
360
+ "--trust-code-execution",
361
+ action="store_true",
362
+ help="⚠️ UNSAFE: Allow code execution without Docker in trusted sandbox environments (e.g., RunPod containers). Use only in secure, isolated environments!",
363
+ )
364
+
365
+ # Steering mode arguments
366
+ parser.add_argument(
367
+ "--steering-mode", action="store_true", help="Enable steering mode (uses CAA vectors instead of classification)"
368
+ )
369
+ parser.add_argument(
370
+ "--steering-strength", type=float, default=1.0, help="Strength of steering vector application (default: 1.0)"
371
+ )
372
+
373
+ # Steering method selection
374
+ parser.add_argument(
375
+ "--steering-method",
376
+ type=str,
377
+ default="CAA",
378
+ choices=["CAA", "HPR", "DAC", "BiPO", "KSteering"],
379
+ help="Steering method to use",
380
+ )
381
+
382
+ # Steering output mode selection
383
+ parser.add_argument(
384
+ "--output-mode",
385
+ type=str,
386
+ default="both",
387
+ choices=["likelihoods", "responses", "both"],
388
+ help="Type of comparison to show: 'likelihoods' for log-likelihood comparison only, 'responses' for response generation only, 'both' for both (default: both)",
389
+ )
390
+
391
+ # HPR-specific parameters
392
+ parser.add_argument("--hpr-beta", type=float, default=1.0, help="Beta parameter for HPR method")
393
+
394
+ # DAC-specific parameters
395
+ parser.add_argument("--dac-dynamic-control", action="store_true", help="Enable dynamic control for DAC method")
396
+ parser.add_argument(
397
+ "--dac-entropy-threshold", type=float, default=1.0, help="Entropy threshold for DAC dynamic control"
398
+ )
399
+
400
+ # BiPO-specific parameters
401
+ parser.add_argument("--bipo-beta", type=float, default=0.1, help="Beta parameter for BiPO method")
402
+ parser.add_argument("--bipo-learning-rate", type=float, default=5e-4, help="Learning rate for BiPO method")
403
+ parser.add_argument("--bipo-epochs", type=int, default=100, help="Number of epochs for BiPO training")
404
+
405
+ # K-Steering-specific parameters
406
+ parser.add_argument(
407
+ "--ksteering-num-labels", type=int, default=6, help="Number of labels for K-steering classifier"
408
+ )
409
+ parser.add_argument(
410
+ "--ksteering-hidden-dim", type=int, default=512, help="Hidden dimension for K-steering classifier"
411
+ )
412
+ parser.add_argument(
413
+ "--ksteering-learning-rate", type=float, default=1e-3, help="Learning rate for K-steering classifier training"
414
+ )
415
+ parser.add_argument(
416
+ "--ksteering-classifier-epochs",
417
+ type=int,
418
+ default=100,
419
+ help="Number of epochs for K-steering classifier training",
420
+ )
421
+ parser.add_argument(
422
+ "--ksteering-target-labels",
423
+ type=str,
424
+ default="0",
425
+ help="Comma-separated target label indices for K-steering (e.g., '0,1,2')",
426
+ )
427
+ parser.add_argument(
428
+ "--ksteering-avoid-labels",
429
+ type=str,
430
+ default="",
431
+ help="Comma-separated avoid label indices for K-steering (e.g., '3,4,5')",
432
+ )
433
+ parser.add_argument(
434
+ "--ksteering-alpha", type=float, default=50.0, help="Alpha parameter (step size) for K-steering"
435
+ )
436
+
437
+ # Token steering arguments
438
+ parser.add_argument("--enable-token-steering", action="store_true", help="Enable token-level steering control")
439
+ parser.add_argument(
440
+ "--token-steering-strategy",
441
+ type=str,
442
+ default="last_only",
443
+ choices=[
444
+ "last_only",
445
+ "first_only",
446
+ "all_equal",
447
+ "exponential_decay",
448
+ "exponential_growth",
449
+ "linear_decay",
450
+ "linear_growth",
451
+ "custom",
452
+ ],
453
+ help="Token steering strategy (default: last_only)",
454
+ )
455
+ parser.add_argument(
456
+ "--token-decay-rate",
457
+ type=float,
458
+ default=0.5,
459
+ help="Decay rate for exponential token steering strategies (0-1, default: 0.5)",
460
+ )
461
+ parser.add_argument(
462
+ "--token-min-strength",
463
+ type=float,
464
+ default=0.1,
465
+ help="Minimum steering strength for token strategies (default: 0.1)",
466
+ )
467
+ parser.add_argument(
468
+ "--token-max-strength",
469
+ type=float,
470
+ default=1.0,
471
+ help="Maximum steering strength for token strategies (default: 1.0)",
472
+ )
473
+ parser.add_argument(
474
+ "--token-apply-to-prompt",
475
+ action="store_true",
476
+ help="Apply steering to prompt tokens as well as generated tokens",
477
+ )
478
+ parser.add_argument(
479
+ "--token-prompt-strength-multiplier",
480
+ type=float,
481
+ default=0.1,
482
+ help="Strength multiplier for prompt tokens (default: 0.1)",
483
+ )
484
+
485
+ # Training/Inference mode arguments
486
+ parser.add_argument(
487
+ "--train-only",
488
+ action="store_true",
489
+ help="Training-only mode: train classifiers/vectors and save them, skip inference",
490
+ )
491
+ parser.add_argument(
492
+ "--inference-only",
493
+ action="store_true",
494
+ help="Inference-only mode: load pre-trained classifiers/vectors and use for monitoring/steering",
495
+ )
496
+ parser.add_argument(
497
+ "--save-classifier",
498
+ type=str,
499
+ default=None,
500
+ help="Path to save trained classifier(s). In multi-layer mode, saves one file per layer with layer suffix",
501
+ )
502
+ parser.add_argument(
503
+ "--load-classifier",
504
+ type=str,
505
+ default=None,
506
+ help="Path to load pre-trained classifier(s). In multi-layer mode, expects files with layer suffix",
507
+ )
508
+ parser.add_argument(
509
+ "--classifier-dir",
510
+ type=str,
511
+ default="./models",
512
+ help="Directory for saving/loading classifiers and vectors (default: ./models)",
513
+ )
514
+
515
+ # Prompt construction and token targeting strategy arguments
516
+ parser.add_argument(
517
+ "--prompt-construction-strategy",
518
+ type=str,
519
+ choices=["multiple_choice", "role_playing", "direct_completion", "instruction_following"],
520
+ default="multiple_choice",
521
+ help="Strategy for constructing prompts from question-answer pairs (default: multiple_choice)",
522
+ )
523
+ parser.add_argument(
524
+ "--token-targeting-strategy",
525
+ type=str,
526
+ choices=["choice_token", "continuation_token", "last_token", "first_token", "mean_pooling", "max_pooling"],
527
+ default="choice_token",
528
+ help="Strategy for targeting tokens during activation extraction (default: choice_token)",
529
+ )
530
+
531
+ # Normalization options
532
+ parser.add_argument("--normalize-mode", action="store_true", help="Enable normalization mode (legacy flag)")
533
+ parser.add_argument(
534
+ "--normalization-method",
535
+ type=str,
536
+ default="none",
537
+ choices=["none", "l2_unit", "cross_behavior", "layer_wise_mean"],
538
+ help="Vector normalization method to apply",
539
+ )
540
+ parser.add_argument("--target-norm", type=float, default=None, help="Target norm for certain normalization methods")
541
+
542
+ # Nonsense detection options
543
+ parser.add_argument(
544
+ "--enable-nonsense-detection",
545
+ action="store_true",
546
+ help="Enable nonsense detection to stop lobotomized responses",
547
+ )
548
+ parser.add_argument(
549
+ "--max-word-length",
550
+ type=int,
551
+ default=20,
552
+ help="Maximum reasonable word length for nonsense detection (default: 20)",
553
+ )
554
+ parser.add_argument(
555
+ "--repetition-threshold",
556
+ type=float,
557
+ default=0.7,
558
+ help="Threshold for repetitive content detection (0-1, default: 0.7)",
559
+ )
560
+ parser.add_argument(
561
+ "--gibberish-threshold",
562
+ type=float,
563
+ default=0.3,
564
+ help="Threshold for gibberish word detection (0-1, default: 0.3)",
565
+ )
566
+ parser.add_argument(
567
+ "--disable-dictionary-check",
568
+ action="store_true",
569
+ help="Disable dictionary-based word validation (faster but less accurate)",
570
+ )
571
+ parser.add_argument(
572
+ "--nonsense-action",
573
+ type=str,
574
+ default="regenerate",
575
+ choices=["regenerate", "stop", "flag"],
576
+ help="Action when nonsense is detected: regenerate, stop generation, or flag for review",
577
+ )
578
+
579
+ # Performance monitoring options
580
+ parser.add_argument(
581
+ "--enable-memory-tracking", action="store_true", help="Enable memory usage tracking and reporting"
582
+ )
583
+ parser.add_argument(
584
+ "--enable-latency-tracking", action="store_true", help="Enable latency/timing tracking and reporting"
585
+ )
586
+ parser.add_argument(
587
+ "--memory-sampling-interval", type=float, default=0.1, help="Memory sampling interval in seconds (default: 0.1)"
588
+ )
589
+ parser.add_argument("--track-gpu-memory", action="store_true", help="Track GPU memory usage (requires CUDA)")
590
+ parser.add_argument(
591
+ "--detailed-performance-report",
592
+ action="store_true",
593
+ help="Generate detailed performance report with all metrics",
594
+ )
595
+ parser.add_argument("--export-performance-csv", type=str, default=None, help="Export performance data to CSV file")
596
+ parser.add_argument(
597
+ "--show-memory-usage", action="store_true", help="Show current memory usage without full tracking"
598
+ )
599
+ parser.add_argument("--show-timing-summary", action="store_true", help="Show timing summary after evaluation")
600
+
601
+ # Test-time activation saving/loading options
602
+ parser.add_argument(
603
+ "--save-test-activations", type=str, default=None, help="Save test activations to file for future use"
604
+ )
605
+ parser.add_argument(
606
+ "--load-test-activations", type=str, default=None, help="Load test activations from file instead of computing"
607
+ )
608
+
609
+ # Priority-aware benchmark selection options
610
+ parser.add_argument(
611
+ "--priority",
612
+ type=str,
613
+ default="all",
614
+ choices=["all", "high", "medium", "low"],
615
+ help="Priority level for benchmark selection (default: all)",
616
+ )
617
+ parser.add_argument(
618
+ "--fast-only", action="store_true", help="Only use fast benchmarks (high priority, < 13.5s loading time)"
619
+ )
620
+ parser.add_argument(
621
+ "--time-budget",
622
+ type=float,
623
+ default=None,
624
+ help="Time budget in minutes for benchmark selection (auto-selects fast benchmarks)",
625
+ )
626
+ parser.add_argument(
627
+ "--max-benchmarks",
628
+ type=int,
629
+ default=None,
630
+ help="Maximum number of benchmarks to select (combines with priority filtering)",
631
+ )
632
+ parser.add_argument(
633
+ "--smart-selection", action="store_true", help="Use smart benchmark selection based on relevance and priority"
634
+ )
635
+ parser.add_argument(
636
+ "--prefer-fast",
637
+ action="store_true",
638
+ help="Prefer fast benchmarks in selection when multiple options are available",
639
+ )
640
+
641
+ parser.add_argument(
642
+ "--save-steering-vector", type=str, default=None, help="Path to save the computed steering vector"
643
+ )
644
+ parser.add_argument(
645
+ "--load-steering-vector", type=str, default=None, help="Path to load a pre-computed steering vector"
646
+ )
647
+
648
+ # Additional output options
649
+ parser.add_argument("--csv-output", type=str, default=None, help="Path to save results in CSV format")
650
+ parser.add_argument("--evaluation-report", type=str, default=None, help="Path to save evaluation report")
651
+ parser.add_argument("--continue-on-error", action="store_true", help="Continue processing other tasks if one fails")
652
+
653
+ # Benchmark caching arguments
654
+ parser.add_argument(
655
+ "--cache-benchmark",
656
+ action="store_true",
657
+ default=True,
658
+ help="Cache the benchmark data locally for faster future access (default: True)",
659
+ )
660
+ parser.add_argument("--no-cache", dest="cache_benchmark", action="store_false", help="Disable benchmark caching")
661
+ parser.add_argument(
662
+ "--use-cached", action="store_true", default=True, help="Use cached benchmark data if available (default: True)"
663
+ )
664
+ parser.add_argument(
665
+ "--force-download", action="store_true", help="Force fresh download even if cached version exists"
666
+ )
667
+ parser.add_argument(
668
+ "--cache-dir",
669
+ type=str,
670
+ default="./benchmark_cache",
671
+ help="Directory to store cached benchmark data (default: ./benchmark_cache)",
672
+ )
673
+ parser.add_argument("--cache-status", action="store_true", help="Show cache status and exit")
674
+ parser.add_argument("--cleanup-cache", type=int, metavar="DAYS", help="Clean up cache entries older than DAYS days")
675
+
676
+
677
def parse_layers_from_arg(layer_arg: str, model=None) -> List[int]:
    """
    Turn a ``--layer`` style argument into a concrete list of layer indices.

    Args:
        layer_arg: Layer specification such as "15", "14-16", "14,15,16",
            "all", or the literal "-1" (sentinel for auto-optimization).
        model: Model object; only consulted when the specification resolves
            to "all" and the number of layers has to be detected.

    Returns:
        List of layer indices. The sentinel "-1" is returned as ``[-1]``.

    Raises:
        ValueError: If the specification means "all layers" but no model was
            supplied, so the layer count cannot be determined.
    """
    # The auto-optimization sentinel is passed through untouched.
    if layer_arg == "-1":
        return [-1]

    parsed = parse_layer_range(layer_arg, model)
    if parsed is not None:
        return parsed

    # parse_layer_range returned None, i.e. the "all" case: the layer count
    # has to come from the model itself.
    if model is None:
        raise ValueError("Cannot determine layer range without model instance")

    from .hyperparameter_optimizer import detect_model_layers

    return list(range(detect_model_layers(model)))
706
+
707
+
708
def parse_layer_range(layer_range_str: str, model=None) -> Optional[List[int]]:
    """
    Parse a layer specification string into a list of layer indices.

    Accepted forms (whitespace around the string and around each token is
    ignored):
        - "all" (any letter case): every layer, signalled by returning None
        - "15": a single layer
        - "-1": a single negative index
        - "8-24": an inclusive range
        - "10,15,20": a comma-separated list
        - "1-3,5,8-10": any comma-separated mix of single layers and ranges

    Args:
        layer_range_str: The specification string.
        model: Not used by this function; accepted so callers can pass it
            uniformly.

    Returns:
        List of layer indices in the order given (duplicates are kept), or
        None if the specification is "all" (to be auto-detected later).

    Raises:
        ValueError: If any token is empty or is not an integer / integer range.
    """
    spec = layer_range_str.strip()
    if spec.lower() == "all":
        # Return None to signal auto-detection
        return None

    layers: List[int] = []
    for token in spec.split(","):
        token = token.strip()
        # Search for the range separator from index 1 so that a leading "-"
        # is read as the sign of the first number, not as a separator.
        dash = token.find("-", 1)
        if dash == -1:
            # Single layer (int() raises ValueError on empty/invalid tokens)
            layers.append(int(token))
        else:
            # Inclusive range "start-end"
            start = int(token[:dash])
            end = int(token[dash + 1 :])
            layers.extend(range(start, end + 1))
    return layers
731
+
732
+
733
def aggregate_token_scores(token_scores: List[float], method: str) -> float:
    """
    Aggregate token scores using the specified method.

    Args:
        token_scores: List of token scores (probabilities). Every entry must
            be a plain int or float.
        method: Aggregation method ("average", "final", "first", "max", "min").
            Any other value falls back to "average".

    Returns:
        Aggregated score, or 0.5 when ``token_scores`` is empty.

    Raises:
        ValueError: If an entry is None, exposes an ``item`` attribute
            (treated as a tensor), or is not an int/float.
    """
    if not token_scores:
        return 0.5

    # Validate strictly: invalid entries are reported, never converted or
    # silently dropped, so the resulting list is as long as the input.
    clean_scores = []
    for i, score in enumerate(token_scores):
        if score is None:
            raise ValueError(
                f"Token score at index {i} is None! This indicates a bug in the classifier output handling."
            )
        if hasattr(score, "item"):  # Handle tensors
            raise ValueError(
                f"Token score at index {i} is a tensor ({type(score)})! Expected float but got tensor: {score}"
            )
        if not isinstance(score, (int, float)):
            raise ValueError(
                f"Token score at index {i} has invalid type: {type(score)}. Expected float but got {type(score).__name__}: {score}"
            )
        clean_scores.append(float(score))

    if method == "final":
        return clean_scores[-1]
    if method == "first":
        return clean_scores[0]
    if method == "max":
        return max(clean_scores)
    if method == "min":
        return min(clean_scores)
    # "average", and the default for any unknown method
    return sum(clean_scores) / len(clean_scores)
779
+
780
+
781
def setup_generate_pairs_parser(parser):
    """
    Register the options of the generate-pairs subcommand.

    Args:
        parser: argparse parser (or subparser) that receives the options.
            ``--trait`` and ``--output`` are registered as required.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        (
            "--trait",
            dict(type=str, required=True, help="Natural language description of the desired trait or behavior"),
        ),
        (
            "--num-pairs",
            dict(type=int, default=30, help="Number of contrastive pairs to generate (default: 30)"),
        ),
        (
            "--output",
            dict(type=str, required=True, help="Output file path for the generated pairs (JSON format)"),
        ),
        (
            "--model",
            dict(
                type=str,
                default="meta-llama/Llama-3.1-8B-Instruct",
                help="Model name or path to use for generation",
            ),
        ),
        ("--device", dict(type=str, default=None, help="Device to run on")),
        ("--verbose", dict(action="store_true", help="Enable verbose output")),
        (
            "--similarity-threshold",
            dict(
                type=float,
                default=0.8,
                help="Similarity threshold for deduplication (0-1, higher = more strict)",
            ),
        ),
        ("--timing", dict(action="store_true", help="Show detailed timing for each generation step")),
        (
            "--max-workers",
            dict(type=int, default=4, help="Number of parallel workers for generation (default: 4)"),
        ),
    )
    for flag, settings in option_specs:
        parser.add_argument(flag, **settings)
807
+
808
+
809
def setup_synthetic_parser(parser):
    """
    Register the options of the synthetic subcommand.

    Args:
        parser: argparse parser (or subparser) that receives the options.
            Exactly one of ``--trait`` / ``--pairs-file`` must be given; they
            are registered in a required mutually exclusive group.
    """
    # Pair source: generate from a trait description, or load from a file.
    pair_source = parser.add_mutually_exclusive_group(required=True)
    pair_source.add_argument(
        "--trait",
        help="Natural language description of the desired trait or behavior (generates new pairs)",
        type=str,
    )
    pair_source.add_argument("--pairs-file", help="Path to existing JSON file with contrastive pairs", type=str)

    add = parser.add_argument

    # Pair generation (only relevant together with --trait)
    add(
        "--num-pairs",
        help="Number of contrastive pairs to generate (default: 30, only used with --trait)",
        default=30,
        type=int,
    )
    add(
        "--save-pairs",
        help="Save generated pairs to this file (optional, only used with --trait)",
        default=None,
        type=str,
    )

    # Model / device selection
    add("--model", help="Model name or path", default="meta-llama/Llama-3.1-8B-Instruct", type=str)
    add("--device", help="Device to run on", default=None, type=str)

    # Training and evaluation
    add("--layer", help="Layer(s) to extract activations from", default="15", type=str)
    add(
        "--steering-method",
        help="Steering method to use",
        choices=["CAA", "HPR", "DAC", "BiPO", "KSteering"],
        default="CAA",
        type=str,
    )
    add("--steering-strength", help="Strength of steering vector application", default=1.0, type=float)
    add("--test-questions", help="Number of test questions to generate for evaluation", default=5, type=int)

    # Output
    add("--output", help="Output directory for results", default="./results", type=str)
    add("--verbose", help="Enable verbose output", action="store_true")

    # K-Steering
    add(
        "--ksteering-target-labels",
        help="Comma-separated target label indices for K-steering",
        default="0",
        type=str,
    )
    add(
        "--ksteering-avoid-labels",
        help="Comma-separated avoid label indices for K-steering",
        default="",
        type=str,
    )
    add("--ksteering-alpha", help="Alpha parameter for K-steering", default=50.0, type=float)

    # Nonsense detection
    add(
        "--enable-nonsense-detection",
        help="Enable nonsense detection to stop lobotomized responses",
        action="store_true",
    )
    add(
        "--max-word-length",
        help="Maximum reasonable word length for nonsense detection (default: 20)",
        default=20,
        type=int,
    )
    add(
        "--repetition-threshold",
        help="Threshold for repetitive content detection (0-1, default: 0.7)",
        default=0.7,
        type=float,
    )
    add(
        "--gibberish-threshold",
        help="Threshold for gibberish word detection (0-1, default: 0.3)",
        default=0.3,
        type=float,
    )
    add(
        "--disable-dictionary-check",
        help="Disable dictionary-based word validation (faster but less accurate)",
        action="store_true",
    )
    add(
        "--nonsense-action",
        help="Action when nonsense is detected: regenerate, stop generation, or flag for review",
        choices=["regenerate", "stop", "flag"],
        default="regenerate",
        type=str,
    )
899
+
900
+
901
def setup_test_nonsense_parser(parser):
    """
    Register the options of the test-nonsense subcommand.

    Args:
        parser: argparse parser (or subparser) that receives the optional
            positional ``text`` plus the detection thresholds and flags.
    """
    # (name, add_argument keyword arguments), registered in this order.
    argument_specs = (
        (
            "text",
            dict(type=str, nargs="?", help="Text to analyze (if not provided, will use interactive mode)"),
        ),
        (
            "--max-word-length",
            dict(type=int, default=20, help="Maximum reasonable word length (default: 20)"),
        ),
        (
            "--repetition-threshold",
            dict(
                type=float,
                default=0.7,
                help="Threshold for repetitive content detection (0-1, default: 0.7)",
            ),
        ),
        (
            "--gibberish-threshold",
            dict(
                type=float,
                default=0.3,
                help="Threshold for gibberish word detection (0-1, default: 0.3)",
            ),
        ),
        (
            "--disable-dictionary-check",
            dict(action="store_true", help="Disable dictionary-based word validation"),
        ),
        ("--verbose", dict(action="store_true", help="Show detailed analysis")),
        ("--examples", dict(action="store_true", help="Test with built-in example texts")),
    )
    for name, settings in argument_specs:
        parser.add_argument(name, **settings)
924
+
925
+
926
def setup_monitor_parser(parser):
    """
    Register the options of the monitor subcommand.

    Args:
        parser: argparse parser (or subparser) that receives the options.
    """
    add = parser.add_argument

    # One-shot reports and checks
    add("--memory-info", help="Show current memory usage information", action="store_true")
    add("--system-info", help="Show system information and capabilities", action="store_true")
    add("--benchmark", help="Run performance benchmarks", action="store_true")
    add("--test-gpu", help="Test GPU availability and memory", action="store_true")

    # Continuous monitoring and its tuning knobs
    add("--continuous", help="Continuous monitoring mode (Ctrl+C to stop)", action="store_true")
    add("--interval", help="Monitoring interval in seconds (default: 1.0)", default=1.0, type=float)
    add("--export-csv", help="Export monitoring data to CSV file", default=None, type=str)
    add(
        "--duration",
        help="Duration for continuous monitoring in seconds (default: 60)",
        default=60,
        type=int,
    )
    add("--track-gpu", help="Include GPU monitoring (requires CUDA)", action="store_true")
    add("--detailed", help="Show detailed monitoring information", action="store_true")
940
+
941
+
942
def setup_agent_parser(parser):
    """
    Set up the agent subcommand parser.

    Registers the positional ``prompt`` plus options for model selection,
    steering, normalization, the per-method steering parameters (HPR, DAC,
    BiPO, K-Steering) and the quality control system.

    Quality control is on by default; ``--no-quality-control`` turns it off.
    (``--enable-quality-control`` alone could never disable it, because it is
    a ``store_true`` flag whose default is already True.)

    Args:
        parser: argparse parser (or subparser) that receives the arguments.
    """
    parser.add_argument("prompt", type=str, help="Prompt to send to the autonomous agent")
    parser.add_argument("--model", type=str, default="meta-llama/Llama-3.1-8B-Instruct", help="Model to use")
    parser.add_argument("--layer", type=int, help="Layer to use (overrides parameter file)")
    parser.add_argument(
        "--quality-threshold", type=float, default=0.3, help="Quality threshold for classifiers (default: 0.3)"
    )
    parser.add_argument(
        "--time-budget",
        type=float,
        default=10.0,
        help="Time budget in minutes for creating classifiers (default: 10.0)",
    )
    parser.add_argument("--max-attempts", type=int, default=3, help="Maximum improvement attempts (default: 3)")
    parser.add_argument(
        "--max-classifiers", type=int, default=None, help="Maximum classifiers to use (default: no limit)"
    )
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output")

    # Steering method arguments
    parser.add_argument(
        "--steering-method",
        type=str,
        default="CAA",
        choices=["CAA", "HPR", "DAC", "BiPO", "KSteering"],
        help="Steering method to use (default: CAA)",
    )
    parser.add_argument(
        "--steering-strength", type=float, default=1.0, help="Strength of steering vector application (default: 1.0)"
    )
    parser.add_argument("--steering-mode", action="store_true", help="Enable steering mode")

    # Normalization parameters
    # NOTE(review): these choices differ from the --normalization-method
    # choices registered earlier in this file ("cross_behavior",
    # "layer_wise_mean") - confirm which set the agent code path expects.
    parser.add_argument("--normalize-mode", action="store_true", help="Enable normalization of steering vectors")
    parser.add_argument(
        "--normalization-method",
        type=str,
        default="none",
        choices=["none", "l2_unit", "l2_norm", "max_norm"],
        help="Normalization method for steering vectors (default: none)",
    )
    parser.add_argument("--target-norm", type=float, default=None, help="Target norm for steering vectors")

    # HPR (Householder Pseudo-Rotation) parameters
    parser.add_argument("--hpr-beta", type=float, default=1.0, help="Beta parameter for HPR steering (default: 1.0)")

    # DAC (Dynamic Activation Composition) parameters
    parser.add_argument("--dac-dynamic-control", action="store_true", help="Enable dynamic control for DAC steering")
    parser.add_argument(
        "--dac-entropy-threshold", type=float, default=1.0, help="Entropy threshold for DAC steering (default: 1.0)"
    )

    # BiPO (Bi-directional Preference Optimization) parameters
    parser.add_argument("--bipo-beta", type=float, default=0.1, help="Beta parameter for BiPO steering (default: 0.1)")
    parser.add_argument(
        "--bipo-learning-rate", type=float, default=5e-4, help="Learning rate for BiPO steering (default: 5e-4)"
    )
    parser.add_argument(
        "--bipo-epochs", type=int, default=100, help="Number of epochs for BiPO steering (default: 100)"
    )

    # KSteering parameters
    parser.add_argument(
        "--ksteering-num-labels", type=int, default=6, help="Number of labels for K-steering (default: 6)"
    )
    parser.add_argument(
        "--ksteering-hidden-dim", type=int, default=512, help="Hidden dimension for K-steering (default: 512)"
    )
    parser.add_argument(
        "--ksteering-learning-rate", type=float, default=1e-3, help="Learning rate for K-steering (default: 1e-3)"
    )
    parser.add_argument(
        "--ksteering-classifier-epochs", type=int, default=100, help="Classifier epochs for K-steering (default: 100)"
    )
    parser.add_argument(
        "--ksteering-target-labels",
        type=str,
        default="0",
        help="Target labels for K-steering (comma-separated, default: '0')",
    )
    parser.add_argument(
        "--ksteering-avoid-labels",
        type=str,
        default="",
        help="Avoid labels for K-steering (comma-separated, default: '')",
    )
    parser.add_argument(
        "--ksteering-alpha", type=float, default=50.0, help="Alpha parameter for K-steering (default: 50.0)"
    )

    # Quality Control System parameters
    parser.add_argument(
        "--enable-quality-control",
        action="store_true",
        default=True,
        help="Enable new quality control system (default: True)",
    )
    # Companion switch writing to the same destination, so the default-on
    # behavior can actually be turned off from the command line.
    parser.add_argument(
        "--no-quality-control",
        dest="enable_quality_control",
        action="store_false",
        help="Disable the quality control system (overrides --enable-quality-control)",
    )
    parser.add_argument(
        "--max-quality-attempts",
        type=int,
        default=5,
        help="Maximum attempts to achieve acceptable quality (default: 5)",
    )
    parser.add_argument(
        "--show-parameter-reasoning", action="store_true", help="Display model's reasoning for parameter choices"
    )
1049
+
1050
+
1051
def setup_classification_optimizer_parser(parser):
    """
    Register the arguments of the classification-optimizer subcommand.

    Args:
        parser: argparse parser (or subparser) that receives the positional
            ``model`` and the optimization, saving and calibration options.
    """
    add = parser.add_argument

    add("model", help="Model name or path to optimize", type=str)

    # Search space and budget
    add("--limit", help="Maximum samples per task (default: 1000)", default=1000, type=int)
    add(
        "--optimization-metric",
        help="Metric to optimize (default: f1)",
        choices=["f1", "accuracy", "precision", "recall"],
        default="f1",
        type=str,
    )
    add(
        "--max-time-per-task",
        help="Maximum time per task in minutes (default: 15.0)",
        default=15.0,
        type=float,
    )
    add(
        "--layer-range",
        help="Layer range to test (e.g., '10-20', if None uses all layers)",
        default=None,
        type=str,
    )
    add(
        "--aggregation-methods",
        help="Token aggregation methods to test",
        default=["average", "final", "first", "max", "min"],
        nargs="+",
        type=str,
    )
    add(
        "--threshold-range",
        help="Detection thresholds to test",
        default=[0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        nargs="+",
        type=float,
    )
    add("--device", help="Device to run on", default=None, type=str)

    # Result and classifier persistence
    add("--results-file", help="Custom file path for saving results", default=None, type=str)
    add("--no-save", help="Don't save results to model config", action="store_true")
    add("--save-logs-json", help="Save detailed optimization logs to JSON file", default=None, type=str)
    add(
        "--save-classifiers",
        help="Save best classifiers for each task (default: True)",
        default=True,
        action="store_true",
    )
    # Shares its destination with --save-classifiers so it can switch it off.
    add(
        "--no-save-classifiers",
        help="Don't save classifiers (overrides --save-classifiers)",
        action="store_false",
        dest="save_classifiers",
    )
    add(
        "--classifiers-dir",
        help="Directory to save classifiers (default: ./optimized_classifiers/model_name/)",
        default=None,
        type=str,
    )

    # Timing calibration
    add(
        "--skip-timing-estimation",
        help="Skip timing estimation and proceed without time warnings",
        action="store_true",
    )
    add("--calibration-file", help="File to save/load calibration data", default=None, type=str)
    add(
        "--calibrate-only",
        help="Only run calibration and exit (saves to --calibration-file if provided)",
        action="store_true",
    )
1115
+
1116
+
1117
def setup_configure_model_parser(parser):
    """
    Register the arguments of the configure-model subcommand.

    Args:
        parser: argparse parser (or subparser) that receives the positional
            ``model`` and the ``--force`` flag.
    """
    add = parser.add_argument
    add("model", help="Model name to configure", type=str)
    add(
        "--force",
        help="Force reconfiguration even if model already has a config",
        action="store_true",
    )
1121
+
1122
+
1123
def setup_steering_optimizer_parser(parser):
    """
    Register the steering-optimizer subcommand and its nested actions.

    The chosen action is stored in ``steering_action``; the available actions
    are "auto", "compare-methods", "optimize-layer", "optimize-strength" and
    "comprehensive". ``--device`` and ``--verbose`` are registered on
    ``parser`` itself rather than on the individual action parsers.

    Args:
        parser: argparse parser (or subparser) that receives the actions.
    """
    actions = parser.add_subparsers(dest="steering_action", help="Steering optimization actions")

    def method_choices():
        # Fresh list per option, exactly as separate literals would produce.
        return ["CAA", "HPR", "DAC", "BiPO", "KSteering"]

    def add_model(target):
        # Every action takes the model as its positional argument.
        target.add_argument("model", type=str, help="Model name or path")

    # --- auto: runs after classification optimization ---
    auto = actions.add_parser("auto", help="Automatically optimize steering based on classification config")
    add_model(auto)
    auto.add_argument(
        "--task",
        help="Specific task to optimize (defaults to all classification-optimized tasks)",
        default=None,
        type=str,
    )
    auto.add_argument(
        "--methods",
        help="Steering methods to test (default: CAA, HPR)",
        default=["CAA", "HPR"],
        choices=method_choices(),
        nargs="+",
        type=str,
    )
    auto.add_argument("--limit", help="Maximum samples for testing (default: 100)", default=100, type=int)
    auto.add_argument("--max-time", help="Maximum time in minutes (default: 60)", default=60.0, type=float)
    auto.add_argument(
        "--strength-range",
        help="Steering strengths to test (default: 0.5 1.0 1.5 2.0)",
        default=[0.5, 1.0, 1.5, 2.0],
        nargs="+",
        type=float,
    )
    auto.add_argument(
        "--layer-range",
        help="Explicit layer range to search (e.g., '0-5' or '0,2,4'). If not specified, uses classification layer or defaults to 0-5",
        default=None,
        type=str,
    )

    # --- compare-methods ---
    compare = actions.add_parser("compare-methods", help="Compare different steering methods for a task")
    add_model(compare)
    compare.add_argument(
        "--task",
        help="Task to optimize steering for (default: truthfulqa_mc1)",
        default="truthfulqa_mc1",
        type=str,
    )
    compare.add_argument(
        "--methods",
        help="Steering methods to compare",
        default=["CAA", "HPR"],
        choices=method_choices(),
        nargs="+",
        type=str,
    )
    compare.add_argument("--limit", help="Maximum samples for testing (default: 100)", default=100, type=int)
    compare.add_argument(
        "--max-time",
        help="Maximum optimization time in minutes (default: 30.0)",
        default=30.0,
        type=float,
    )

    # --- optimize-layer ---
    by_layer = actions.add_parser("optimize-layer", help="Find optimal steering layer for a method")
    add_model(by_layer)
    by_layer.add_argument(
        "--task",
        help="Task to optimize for (default: truthfulqa_mc1)",
        default="truthfulqa_mc1",
        type=str,
    )
    by_layer.add_argument(
        "--method",
        help="Steering method to use (default: CAA)",
        choices=method_choices(),
        default="CAA",
        type=str,
    )
    by_layer.add_argument("--layer-range", help="Layer range to search (e.g., '10-20')", default=None, type=str)
    by_layer.add_argument(
        "--strength",
        help="Fixed steering strength during layer search (default: 1.0)",
        default=1.0,
        type=float,
    )
    by_layer.add_argument("--limit", help="Maximum samples for testing (default: 100)", default=100, type=int)

    # --- optimize-strength ---
    by_strength = actions.add_parser("optimize-strength", help="Find optimal steering strength")
    add_model(by_strength)
    by_strength.add_argument(
        "--task",
        help="Task to optimize for (default: truthfulqa_mc1)",
        default="truthfulqa_mc1",
        type=str,
    )
    by_strength.add_argument(
        "--method",
        help="Steering method to use (default: CAA)",
        choices=method_choices(),
        default="CAA",
        type=str,
    )
    by_strength.add_argument(
        "--layer",
        help="Steering layer to use (defaults to classification layer)",
        default=None,
        type=int,
    )
    by_strength.add_argument(
        "--strength-range",
        help="Min and max strength to test (default: 0.1 2.0)",
        default=[0.1, 2.0],
        nargs=2,
        type=float,
    )
    by_strength.add_argument(
        "--strength-steps",
        help="Number of strength values to test (default: 10)",
        default=10,
        type=int,
    )
    by_strength.add_argument("--limit", help="Maximum samples for testing (default: 100)", default=100, type=int)

    # --- comprehensive ---
    full_run = actions.add_parser("comprehensive", help="Run comprehensive steering optimization")
    add_model(full_run)
    full_run.add_argument(
        "--tasks",
        help="Tasks to optimize (defaults to classification-optimized tasks)",
        default=None,
        nargs="+",
        type=str,
    )
    full_run.add_argument(
        "--methods",
        help="Steering methods to test",
        default=["CAA", "HPR"],
        choices=method_choices(),
        nargs="+",
        type=str,
    )
    full_run.add_argument("--limit", help="Sample limit per task (default: 100)", default=100, type=int)
    full_run.add_argument(
        "--max-time-per-task",
        help="Time limit per task in minutes (default: 20.0)",
        default=20.0,
        type=float,
    )
    full_run.add_argument("--no-save", help="Don't save results to model config", action="store_true")

    # Options registered on the steering-optimizer parser itself
    parser.add_argument("--device", help="Device to run on", default=None, type=str)
    parser.add_argument("--verbose", help="Enable verbose output", action="store_true")
1260
+
1261
+
1262
def setup_model_config_parser(parser):
    """Attach the model-config actions and shared options to ``parser``.

    One sub-parser is created per action (``save``, ``list``, ``show``,
    ``remove``, ``test``); the selected action name is stored under the
    ``config_action`` destination. ``--config-dir`` and ``--verbose`` are
    registered on ``parser`` itself, after the sub-parsers.

    Args:
        parser: The argparse parser for the model-config command.
    """
    model_argument = ("model", dict(type=str, help="Model name or path"))

    # Each entry: (action name, action help text, ordered argument specs).
    action_table = (
        (
            "save",
            "Save optimal parameters for a model",
            (
                model_argument,
                (
                    "--classification-layer",
                    dict(type=int, required=True, help="Optimal layer for classification"),
                ),
                (
                    "--steering-layer",
                    dict(
                        type=int,
                        default=None,
                        help="Optimal layer for steering (defaults to classification layer)",
                    ),
                ),
                (
                    "--token-aggregation",
                    dict(
                        type=str,
                        default="average",
                        choices=["average", "final", "first", "max", "min"],
                        help="Token aggregation method",
                    ),
                ),
                (
                    "--detection-threshold",
                    dict(type=float, default=0.6, help="Detection threshold"),
                ),
                (
                    "--optimization-method",
                    dict(type=str, default="manual", help="How these parameters were determined"),
                ),
                (
                    "--metrics",
                    dict(type=str, default=None, help="JSON string with optimization metrics"),
                ),
            ),
        ),
        (
            "list",
            "List all saved model configurations",
            (
                (
                    "--detailed",
                    dict(action="store_true", help="Show detailed configuration information"),
                ),
            ),
        ),
        (
            "show",
            "Show configuration for a specific model",
            (
                model_argument,
                (
                    "--task",
                    dict(type=str, default=None, help="Show task-specific overrides if available"),
                ),
            ),
        ),
        (
            "remove",
            "Remove configuration for a model",
            (
                model_argument,
                (
                    "--confirm",
                    dict(action="store_true", help="Confirm removal without prompting"),
                ),
            ),
        ),
        (
            "test",
            "Test if saved configuration works",
            (
                model_argument,
                (
                    "--task",
                    dict(
                        type=str,
                        default="truthfulqa_mc1",
                        help="Task to test with (default: truthfulqa_mc1)",
                    ),
                ),
                (
                    "--limit",
                    dict(type=int, default=5, help="Number of samples to test with (default: 5)"),
                ),
                (
                    "--device",
                    dict(type=str, default=None, help="Device to run on"),
                ),
            ),
        ),
    )

    action_parsers = parser.add_subparsers(dest="config_action", help="Model configuration actions")
    for action_name, action_help, argument_specs in action_table:
        action_parser = action_parsers.add_parser(action_name, help=action_help)
        for flag, options in argument_specs:
            action_parser.add_argument(flag, **options)

    # Options that live on the model-config parser itself rather than on an action.
    shared_specs = (
        (
            "--config-dir",
            dict(
                type=str,
                default=None,
                help="Custom directory for configuration files (default: ~/.wisent-guard/model_configs/)",
            ),
        ),
        ("--verbose", dict(action="store_true", help="Enable verbose output")),
    )
    for flag, options in shared_specs:
        parser.add_argument(flag, **options)
1318
+
1319
+
1320
def setup_sample_size_optimizer_parser(parser):
    """Register the sample-size-optimizer arguments on ``parser``.

    Adds the positional ``model`` plus the required ``--task``, ``--layer``
    and ``--token-aggregation`` options, followed by the classification,
    steering and general optimization options, in that order.

    Args:
        parser: The argparse parser for the sample-size-optimizer command.
    """
    argument_specs = (
        # Required inputs
        ("model", dict(type=str, help="Model name or path to optimize")),
        ("--task", dict(type=str, required=True, help="Task to optimize for (REQUIRED)")),
        ("--layer", dict(type=int, required=True, help="Layer index to use (REQUIRED)")),
        (
            "--token-aggregation",
            dict(
                type=str,
                required=True,
                choices=["average", "final", "first", "max", "min"],
                help="Token aggregation method (REQUIRED)",
            ),
        ),
        # Classification-specific
        (
            "--threshold",
            dict(type=float, default=0.5, help="Detection threshold for classification (default: 0.5)"),
        ),
        # Steering mode
        (
            "--steering-mode",
            dict(action="store_true", help="Optimize for steering instead of classification"),
        ),
        (
            "--steering-method",
            dict(
                type=str,
                default="CAA",
                choices=["CAA", "CAA_L2", "HPR", "DAC", "BiPO", "KSteering"],
                help="Steering method to use (default: CAA)",
            ),
        ),
        (
            "--steering-strength",
            dict(type=float, default=1.0, help="Steering strength to use (default: 1.0)"),
        ),
        (
            "--token-targeting-strategy",
            dict(
                type=str,
                default="LAST_TOKEN",
                choices=["CHOICE_TOKEN", "LAST_TOKEN", "FIRST_TOKEN", "ALL_TOKENS"],
                help="Token targeting strategy for steering (default: LAST_TOKEN)",
            ),
        ),
        # Common optimization parameters
        (
            "--sample-sizes",
            dict(
                type=int,
                nargs="+",
                default=[5, 10, 20, 50, 100, 200, 500],
                help="Sample sizes to test (default: 5 10 20 50 100 200 500)",
            ),
        ),
        ("--test-size", dict(type=int, default=200, help="Fixed test set size (default: 200)")),
        ("--test-split", dict(type=float, default=0.2, help="DEPRECATED: Use --test-size instead")),
        ("--seed", dict(type=int, default=42, help="Random seed for reproducibility (default: 42)")),
        (
            "--limit",
            dict(type=int, default=None, help="Maximum number of samples to load from dataset"),
        ),
        ("--save-plot", dict(action="store_true", help="Save performance plot")),
        (
            "--no-save-config",
            dict(action="store_true", help="Don't save optimal sample size to model config"),
        ),
        ("--device", dict(type=str, default=None, help="Device to run on")),
        ("--verbose", dict(action="store_true", help="Enable verbose output")),
        (
            "--force",
            dict(
                action="store_true",
                help="Force optimization even without matching classifier parameters",
            ),
        ),
    )

    for flag, options in argument_specs:
        parser.add_argument(flag, **options)
1375
+
1376
+
1377
def setup_full_optimizer_parser(parser):
    """Register the full-optimize arguments on ``parser``.

    ``--tasks``, ``--skills`` and ``--risks`` are placed in one mutually
    exclusive group, so at most one of them may be given. All other options
    are registered directly on ``parser``.

    Args:
        parser: The argparse parser for the full-optimize command.
    """
    parser.add_argument("model", type=str, help="Model name or path to optimize")

    # Task selection: the three selectors exclude each other.
    selector_specs = (
        ("--tasks", dict(type=str, nargs="+", help="Specific tasks to optimize")),
        (
            "--skills",
            dict(
                type=str,
                nargs="+",
                help="Select tasks by skill categories (e.g., coding, mathematics, reasoning)",
            ),
        ),
        (
            "--risks",
            dict(
                type=str,
                nargs="+",
                help="Select tasks by risk categories (e.g., harmfulness, toxicity, hallucination)",
            ),
        ),
    )
    selector_group = parser.add_mutually_exclusive_group()
    for flag, options in selector_specs:
        selector_group.add_argument(flag, **options)

    option_specs = (
        # General limit, used unless a specific limit below is supplied
        (
            "--limit",
            dict(
                type=int,
                default=100,
                help="Sample limit for all optimizations (default: 100). Can be overridden by specific limits below",
            ),
        ),
        # Specific limits
        (
            "--classification-limit",
            dict(
                type=int,
                default=None,
                help="Sample limit for classification optimization (overrides --limit)",
            ),
        ),
        (
            "--sample-size-limit",
            dict(
                type=int,
                default=None,
                help="Sample limit for sample size optimization (overrides --limit)",
            ),
        ),
        (
            "--steering-limit",
            dict(
                type=int,
                default=None,
                help="Sample limit for steering optimization (overrides --limit)",
            ),
        ),
        (
            "--sample-sizes",
            dict(
                type=int,
                nargs="+",
                default=[5, 10, 20, 50, 100, 200, 500],
                help="Sample sizes to test (default: 5 10 20 50 100 200 500)",
            ),
        ),
        # Stage skipping
        (
            "--skip-classification",
            dict(
                action="store_true",
                help="Skip classification optimization and use existing config",
            ),
        ),
        ("--skip-sample-size", dict(action="store_true", help="Skip sample size optimization")),
        (
            "--skip-classifier-training",
            dict(action="store_true", help="Skip final classifier training step"),
        ),
        (
            "--skip-control-vectors",
            dict(action="store_true", help="Skip control vector training step"),
        ),
        # Steering optimization
        ("--skip-steering", dict(action="store_true", help="Skip steering optimization")),
        (
            "--steering-methods",
            dict(
                type=str,
                nargs="+",
                choices=["CAA", "HPR", "DAC", "BiPO", "KSteering"],
                default=["CAA", "HPR", "DAC", "BiPO", "KSteering"],
                help="Steering methods to test (default: all methods with parameter variations)",
            ),
        ),
        (
            "--steering-layer-range",
            dict(
                type=str,
                default=None,
                help="Layer range for steering optimization (e.g., '0-5')",
            ),
        ),
        (
            "--steering-strength-range",
            dict(
                type=float,
                nargs="+",
                default=[0.5, 1.0, 1.5, 2.0],
                help="Steering strengths to test (default: 0.5 1.0 1.5 2.0)",
            ),
        ),
        # Task sampling
        (
            "--num-tasks",
            dict(
                type=int,
                default=None,
                help="Number of tasks to randomly select from matched tasks (default: all)",
            ),
        ),
        (
            "--min-quality-score",
            dict(
                type=int,
                default=2,
                choices=[1, 2, 3, 4, 5],
                help="Minimum quality score for tasks (default: 2)",
            ),
        ),
        (
            "--task-seed",
            dict(
                type=int,
                default=None,
                help="Random seed for task selection (for reproducibility)",
            ),
        ),
        (
            "--max-time-per-task",
            dict(
                type=float,
                default=20.0,
                help="Maximum time per task in minutes (default: 20.0)",
            ),
        ),
        ("--device", dict(type=str, default=None, help="Device to run on")),
        ("--verbose", dict(action="store_true", help="Enable verbose output")),
        ("--save-plots", dict(action="store_true", help="Save plots for both optimizations")),
        # Timing calibration
        (
            "--skip-timing-estimation",
            dict(
                action="store_true",
                help="Skip timing estimation and proceed without time warnings",
            ),
        ),
        (
            "--calibration-file",
            dict(type=str, default=None, help="File to save/load calibration data"),
        ),
        (
            "--calibrate-only",
            dict(
                action="store_true",
                help="Only run calibration and exit (saves to --calibration-file if provided)",
            ),
        ),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
1489
+
1490
+
1491
def setup_configure_model_parser(parser):
    """Register the configure-model arguments on ``parser``.

    Adds the positional ``model`` name and the ``--force`` switch.

    Args:
        parser: The argparse parser for the configure-model command.
    """
    register = parser.add_argument
    register(
        "model",
        type=str,
        help="Model name to configure",
    )
    register(
        "--force",
        action="store_true",
        help="Force reconfiguration even if model already has a config",
    )
1495
+
1496
+
1497
def setup_generate_vector_parser(parser):
    """Register the generate-vector arguments on ``parser``.

    ``--from-pairs`` and ``--from-description`` share one optional, mutually
    exclusive group. The multi-property, model, steering-method, output,
    pair-generation, method-specific and activation-extraction options are
    then added directly to ``parser``. ``--output`` is the only required
    option.

    Args:
        parser: The argparse parser for the generate-vector command.
    """
    # Single-property pair sources: at most one may be supplied.
    pair_source_specs = (
        (
            "--from-pairs",
            dict(
                type=str,
                metavar="FILE",
                help="Path to JSON file containing contrastive pairs (single property)",
            ),
        ),
        (
            "--from-description",
            dict(
                type=str,
                metavar="TRAIT",
                help="Natural language description of the trait (single property)",
            ),
        ),
    )
    pair_source_group = parser.add_mutually_exclusive_group(required=False)
    for flag, options in pair_source_specs:
        pair_source_group.add_argument(flag, **options)

    option_specs = (
        # Multi-property support
        (
            "--multi-property",
            dict(action="store_true", help="Enable multi-property steering (DAC only)"),
        ),
        (
            "--property-files",
            dict(
                type=str,
                nargs="+",
                metavar="NAME:FILE:LAYER",
                help="Property definitions from files (format: property_name:pairs_file:layer)",
            ),
        ),
        (
            "--property-descriptions",
            dict(
                type=str,
                nargs="+",
                metavar="NAME:DESC:LAYER",
                help="Property definitions from descriptions (format: property_name:description:layer)",
            ),
        ),
        # Model configuration
        (
            "--model",
            dict(type=str, default="distilgpt2", help="Model name or path (default: distilgpt2)"),
        ),
        (
            "--device",
            dict(type=str, default=None, help="Device to run on (default: auto-detect)"),
        ),
        # Steering method configuration
        (
            "--method",
            dict(
                type=str,
                default="DAC",
                choices=["DAC", "CAA", "HPR", "BiPO", "ControlVectorSteering"],
                help="Steering method to use (default: DAC)",
            ),
        ),
        (
            "--layer",
            dict(type=int, default=0, help="Layer index to apply steering (default: 0)"),
        ),
        # Output configuration
        (
            "--output",
            dict(type=str, required=True, help="Output path for the generated steering vector"),
        ),
        # Pair generation (relevant with --from-description, per the help text)
        (
            "--num-pairs",
            dict(
                type=int,
                default=30,
                help="Number of pairs to generate when using --from-description (default: 30)",
            ),
        ),
        (
            "--save-pairs",
            dict(
                type=str,
                default=None,
                help="Save generated pairs to this file when using --from-description",
            ),
        ),
        # Method-specific parameters
        (
            "--dynamic-control",
            dict(action="store_true", help="Enable dynamic control for DAC method"),
        ),
        (
            "--entropy-threshold",
            dict(
                type=float,
                default=1.0,
                help="Entropy threshold for DAC method (default: 1.0)",
            ),
        ),
        (
            "--beta",
            dict(type=float, default=1.0, help="Beta parameter for HPR method (default: 1.0)"),
        ),
        # Activation extraction configuration
        (
            "--prompt-construction",
            dict(
                type=str,
                default="multiple_choice",
                choices=["multiple_choice", "role_playing", "direct_completion", "instruction_following"],
                help="Strategy for constructing prompts from question-answer pairs (default: multiple_choice)",
            ),
        ),
        (
            "--token-targeting",
            dict(
                type=str,
                default="choice_token",
                choices=[
                    "choice_token",
                    "continuation_token",
                    "last_token",
                    "first_token",
                    "mean_pooling",
                    "max_pooling",
                ],
                help="Strategy for targeting tokens in activation extraction (default: choice_token)",
            ),
        ),
        # General options
        ("--verbose", dict(action="store_true", help="Enable verbose output")),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
1584
+
1585
+
1586
def setup_multi_steer_parser(parser):
    """Register the multi-steer arguments on ``parser``.

    ``--vector`` uses the ``append`` action, so every occurrence adds one
    ``PATH:WEIGHT`` string to a list. ``--vector``, ``--model``, ``--layer``
    and ``--prompt`` are required; the remaining options have defaults.

    Args:
        parser: The argparse parser for the multi-steer command.
    """
    argument_specs = (
        # Vector inputs (repeatable)
        (
            "--vector",
            dict(
                type=str,
                action="append",
                required=True,
                metavar="PATH:WEIGHT",
                help="Path to steering vector and its weight (format: path/to/vector.pt:0.5). Can be specified multiple times.",
            ),
        ),
        # Model configuration
        ("--model", dict(type=str, required=True, help="Model name or path")),
        (
            "--layer",
            dict(type=int, required=True, help="Layer index to apply combined steering"),
        ),
        (
            "--device",
            dict(type=str, default=None, help="Device to run on (default: auto-detect)"),
        ),
        # Steering method configuration
        (
            "--method",
            dict(
                type=str,
                default="CAA",
                choices=["CAA", "DAC"],
                help="Steering method to use for combination (default: CAA)",
            ),
        ),
        # Generation configuration
        (
            "--prompt",
            dict(type=str, required=True, help="Prompt to generate with combined steering"),
        ),
        (
            "--max-new-tokens",
            dict(type=int, default=100, help="Maximum new tokens to generate (default: 100)"),
        ),
        # Weight normalization
        (
            "--normalize-weights",
            dict(action="store_true", help="Normalize weights to sum to 1.0"),
        ),
        (
            "--allow-unnormalized",
            dict(
                action="store_true",
                help="Allow weights that don't sum to 1.0 (for stronger effects)",
            ),
        ),
        (
            "--target-norm",
            dict(
                type=float,
                default=None,
                help="Scale the combined vector to have this norm (e.g., 10.0)",
            ),
        ),
        # Output options
        (
            "--save-combined",
            dict(type=str, default=None, help="Save the combined steering vector to this path"),
        ),
        (
            "--verbose",
            dict(action="store_true", help="Enable verbose output showing weight calculations"),
        ),
    )

    for flag, options in argument_specs:
        parser.add_argument(flag, **options)
1630
+
1631
+
1632
+ def setup_evaluate_parser(parser):
1633
+ """Set up the evaluate subcommand parser for single-prompt evaluation."""
1634
+
1635
+ # Required arguments
1636
+ parser.add_argument("--vector", type=str, required=True, help="Path to steering vector file (.pt)")
1637
+ parser.add_argument("--prompt", type=str, required=True, help="Prompt to evaluate")
1638
+ parser.add_argument(
1639
+ "--model", type=str, required=True, help="Model name or path (used for both generation and evaluation)"
1640
+ )
1641
+ parser.add_argument("--trait", type=str, required=True, help="Trait name (e.g., 'catholic', 'cynical')")
1642
+
1643
+ # Optional model configuration
1644
+ parser.add_argument("--device", type=str, default=None, help="Device to run on (default: auto-detect)")
1645
+
1646
+ # Optional steering parameters
1647
+ parser.add_argument(
1648
+ "--steering-strength", type=float, default=2.0, help="Steering strength to apply (default: 2.0)"
1649
+ )
1650
+ parser.add_argument("--max-new-tokens", type=int, default=100, help="Maximum new tokens to generate (default: 100)")
1651
+ parser.add_argument(
1652
+ "--trait-description",
1653
+ type=str,
1654
+ default=None,
1655
+ help="Optional description of the trait (default: use trait name)",
1656
+ )
1657
+
1658
+ # Optional threshold parameters
1659
+ parser.add_argument(
1660
+ "--trait-threshold", type=float, default=None, help="Minimum trait quality threshold (-1 to 1 scale)"
1661
+ )
1662
+ parser.add_argument(
1663
+ "--answer-threshold", type=float, default=None, help="Minimum answer quality threshold (0 to 1 scale)"
1664
+ )
1665
+
1666
+ # Output options
1667
+ parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
1668
+ parser.add_argument("--json", action="store_true", help="Output results as JSON")