wisent 0.5.12__py3-none-any.whl → 0.5.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (227) hide show
  1. wisent/__init__.py +1 -1
  2. wisent/core/activations/__init__.py +26 -0
  3. wisent/core/activations/activations.py +96 -0
  4. wisent/core/activations/activations_collector.py +71 -20
  5. wisent/core/activations/prompt_construction_strategy.py +47 -0
  6. wisent/core/agent/__init__.py +1 -18
  7. wisent/core/agent/budget.py +2 -2
  8. wisent/core/agent/device_benchmarks.py +1 -1
  9. wisent/core/agent/diagnose/__init__.py +1 -55
  10. wisent/core/agent/diagnose/classifier_marketplace.py +8 -8
  11. wisent/core/agent/diagnose/response_diagnostics.py +4 -4
  12. wisent/core/agent/diagnose/synthetic_classifier_option.py +1 -1
  13. wisent/core/agent/diagnose/tasks/task_manager.py +3 -3
  14. wisent/core/agent/diagnose.py +2 -1
  15. wisent/core/autonomous_agent.py +10 -2
  16. wisent/core/benchmark_extractors.py +293 -0
  17. wisent/core/bigcode_integration.py +20 -7
  18. wisent/core/branding.py +108 -0
  19. wisent/core/cli/__init__.py +15 -0
  20. wisent/core/cli/create_steering_vector.py +138 -0
  21. wisent/core/cli/evaluate_responses.py +715 -0
  22. wisent/core/cli/generate_pairs.py +128 -0
  23. wisent/core/cli/generate_pairs_from_task.py +119 -0
  24. wisent/core/cli/generate_responses.py +129 -0
  25. wisent/core/cli/generate_vector_from_synthetic.py +149 -0
  26. wisent/core/cli/generate_vector_from_task.py +147 -0
  27. wisent/core/cli/get_activations.py +191 -0
  28. wisent/core/cli/optimize_classification.py +339 -0
  29. wisent/core/cli/optimize_steering.py +364 -0
  30. wisent/core/cli/tasks.py +182 -0
  31. wisent/core/cli_logger.py +22 -0
  32. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +27 -1
  33. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +49 -1
  34. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +115 -0
  35. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +115 -0
  36. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +111 -0
  37. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +119 -0
  38. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +115 -0
  39. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +114 -0
  40. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +118 -0
  41. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +146 -0
  42. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +129 -0
  43. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +119 -0
  44. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +112 -0
  45. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +113 -0
  46. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livecodebench.py +367 -0
  47. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +115 -0
  48. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +114 -0
  49. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +113 -0
  50. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +112 -0
  51. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +111 -0
  52. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +114 -0
  53. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +113 -0
  54. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +115 -0
  55. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +111 -0
  56. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +113 -0
  57. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +113 -0
  58. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +112 -0
  59. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +116 -0
  60. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +115 -0
  61. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +111 -0
  62. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +111 -0
  63. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +121 -0
  64. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +121 -0
  65. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +111 -0
  66. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +110 -0
  67. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +114 -0
  68. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py +124 -0
  69. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py +111 -0
  70. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +112 -0
  71. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +127 -0
  72. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_gen.py +112 -0
  73. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +117 -0
  74. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +117 -0
  75. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +127 -0
  76. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +119 -0
  77. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +1 -1
  78. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +111 -0
  79. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +114 -0
  80. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +112 -0
  81. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +114 -0
  82. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +114 -0
  83. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +1 -1
  84. wisent/core/data_loaders/__init__.py +235 -0
  85. wisent/core/data_loaders/loaders/lm_loader.py +2 -2
  86. wisent/core/data_loaders/loaders/task_interface_loader.py +300 -0
  87. wisent/{cli/data_loaders/data_loader_rotator.py → core/data_loaders/rotator.py} +1 -1
  88. wisent/core/download_full_benchmarks.py +79 -2
  89. wisent/core/evaluators/benchmark_specific/__init__.py +26 -0
  90. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/evaluator.py +17 -17
  91. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/cpp_sanitizer.py +2 -2
  92. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/java_sanitizer.py +2 -2
  93. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/python_sanitizer.py +2 -2
  94. wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/__init__.py +3 -0
  95. wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/provider.py +305 -0
  96. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/runtime.py +36 -4
  97. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/entrypoint.py +2 -4
  98. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/recipes.py +1 -1
  99. wisent/core/evaluators/benchmark_specific/coding/solution_generator.py +258 -0
  100. wisent/core/evaluators/benchmark_specific/exact_match_evaluator.py +79 -0
  101. wisent/core/evaluators/benchmark_specific/f1_evaluator.py +101 -0
  102. wisent/core/evaluators/benchmark_specific/generation_evaluator.py +197 -0
  103. wisent/core/{log_likelihoods_evaluator.py → evaluators/benchmark_specific/log_likelihoods_evaluator.py} +10 -2
  104. wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py +140 -0
  105. wisent/core/evaluators/benchmark_specific/personalization_evaluator.py +250 -0
  106. wisent/{cli/evaluators/evaluator_rotator.py → core/evaluators/rotator.py} +4 -4
  107. wisent/core/lm_eval_harness_ground_truth.py +3 -2
  108. wisent/core/main.py +57 -0
  109. wisent/core/model_persistence.py +2 -2
  110. wisent/core/models/wisent_model.py +6 -6
  111. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
  112. wisent/core/optuna/steering/steering_optimization.py +1 -1
  113. wisent/core/parser_arguments/__init__.py +10 -0
  114. wisent/core/parser_arguments/agent_parser.py +110 -0
  115. wisent/core/parser_arguments/configure_model_parser.py +7 -0
  116. wisent/core/parser_arguments/create_steering_vector_parser.py +59 -0
  117. wisent/core/parser_arguments/evaluate_parser.py +40 -0
  118. wisent/core/parser_arguments/evaluate_responses_parser.py +10 -0
  119. wisent/core/parser_arguments/full_optimize_parser.py +115 -0
  120. wisent/core/parser_arguments/generate_pairs_from_task_parser.py +33 -0
  121. wisent/core/parser_arguments/generate_pairs_parser.py +29 -0
  122. wisent/core/parser_arguments/generate_responses_parser.py +15 -0
  123. wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +127 -0
  124. wisent/core/parser_arguments/generate_vector_from_task_parser.py +127 -0
  125. wisent/core/parser_arguments/generate_vector_parser.py +90 -0
  126. wisent/core/parser_arguments/get_activations_parser.py +90 -0
  127. wisent/core/parser_arguments/main_parser.py +152 -0
  128. wisent/core/parser_arguments/model_config_parser.py +59 -0
  129. wisent/core/parser_arguments/monitor_parser.py +17 -0
  130. wisent/core/parser_arguments/multi_steer_parser.py +47 -0
  131. wisent/core/parser_arguments/optimize_classification_parser.py +67 -0
  132. wisent/core/parser_arguments/optimize_sample_size_parser.py +58 -0
  133. wisent/core/parser_arguments/optimize_steering_parser.py +147 -0
  134. wisent/core/parser_arguments/synthetic_parser.py +93 -0
  135. wisent/core/parser_arguments/tasks_parser.py +584 -0
  136. wisent/core/parser_arguments/test_nonsense_parser.py +26 -0
  137. wisent/core/parser_arguments/utils.py +111 -0
  138. wisent/core/prompts/core/prompt_formater.py +3 -3
  139. wisent/core/prompts/prompt_stratiegies/direct_completion.py +2 -0
  140. wisent/core/prompts/prompt_stratiegies/instruction_following.py +2 -0
  141. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +2 -0
  142. wisent/core/prompts/prompt_stratiegies/role_playing.py +2 -0
  143. wisent/{cli/steering_methods/steering_rotator.py → core/steering_methods/rotator.py} +4 -4
  144. wisent/core/steering_optimizer.py +45 -21
  145. wisent/{synthetic → core/synthetic}/cleaners/deduper_cleaner.py +3 -3
  146. wisent/{synthetic → core/synthetic}/cleaners/methods/base_dedupers.py +2 -2
  147. wisent/{synthetic → core/synthetic}/cleaners/methods/base_refusalers.py +1 -1
  148. wisent/{synthetic → core/synthetic}/cleaners/pairs_cleaner.py +5 -5
  149. wisent/{synthetic → core/synthetic}/cleaners/refusaler_cleaner.py +4 -4
  150. wisent/{synthetic → core/synthetic}/db_instructions/mini_dp.py +1 -1
  151. wisent/{synthetic → core/synthetic}/generators/diversities/methods/fast_diversity.py +1 -1
  152. wisent/{synthetic → core/synthetic}/generators/pairs_generator.py +38 -12
  153. wisent/core/tasks/livecodebench_task.py +4 -103
  154. wisent/core/timing_calibration.py +1 -1
  155. {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/METADATA +3 -3
  156. wisent-0.5.14.dist-info/RECORD +294 -0
  157. wisent-0.5.14.dist-info/entry_points.txt +2 -0
  158. wisent/benchmarks/coding/providers/livecodebench/provider.py +0 -53
  159. wisent/classifiers/core/atoms.py +0 -747
  160. wisent/classifiers/models/logistic.py +0 -29
  161. wisent/classifiers/models/mlp.py +0 -47
  162. wisent/cli/classifiers/classifier_rotator.py +0 -137
  163. wisent/cli/cli_logger.py +0 -142
  164. wisent/cli/wisent_cli/commands/help_cmd.py +0 -52
  165. wisent/cli/wisent_cli/commands/listing.py +0 -154
  166. wisent/cli/wisent_cli/commands/train_cmd.py +0 -322
  167. wisent/cli/wisent_cli/main.py +0 -93
  168. wisent/cli/wisent_cli/shell.py +0 -80
  169. wisent/cli/wisent_cli/ui.py +0 -69
  170. wisent/cli/wisent_cli/util/aggregations.py +0 -43
  171. wisent/cli/wisent_cli/util/parsing.py +0 -126
  172. wisent/cli/wisent_cli/version.py +0 -4
  173. wisent/opti/methods/__init__.py +0 -0
  174. wisent/synthetic/__init__.py +0 -0
  175. wisent/synthetic/cleaners/__init__.py +0 -0
  176. wisent/synthetic/cleaners/core/__init__.py +0 -0
  177. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  178. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  179. wisent/synthetic/db_instructions/__init__.py +0 -0
  180. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  181. wisent/synthetic/generators/__init__.py +0 -0
  182. wisent/synthetic/generators/core/__init__.py +0 -0
  183. wisent/synthetic/generators/diversities/__init__.py +0 -0
  184. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  185. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  186. wisent-0.5.12.dist-info/RECORD +0 -220
  187. /wisent/{benchmarks → core/evaluators/benchmark_specific/coding}/__init__.py +0 -0
  188. /wisent/{benchmarks/coding → core/evaluators/benchmark_specific/coding/metrics}/__init__.py +0 -0
  189. /wisent/{benchmarks/coding/metrics → core/evaluators/benchmark_specific/coding/metrics/core}/__init__.py +0 -0
  190. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/core/atoms.py +0 -0
  191. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/passk.py +0 -0
  192. /wisent/{benchmarks/coding/metrics/core → core/evaluators/benchmark_specific/coding/output_sanitizer}/__init__.py +0 -0
  193. /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/output_sanitizer/core}/__init__.py +0 -0
  194. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/core/atoms.py +0 -0
  195. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/utils.py +0 -0
  196. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/__init__.py +0 -0
  197. /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/providers}/core/__init__.py +0 -0
  198. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/core/atoms.py +0 -0
  199. /wisent/{benchmarks/coding/providers/core → core/evaluators/benchmark_specific/coding/safe_docker}/__init__.py +0 -0
  200. /wisent/{benchmarks/coding/providers/livecodebench → core/evaluators/benchmark_specific/coding/safe_docker/core}/__init__.py +0 -0
  201. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/atoms.py +0 -0
  202. /wisent/{benchmarks/coding/safe_docker → core/opti}/__init__.py +0 -0
  203. /wisent/{benchmarks/coding/safe_docker → core/opti}/core/__init__.py +0 -0
  204. /wisent/{opti → core/opti}/core/atoms.py +0 -0
  205. /wisent/{classifiers → core/opti/methods}/__init__.py +0 -0
  206. /wisent/{opti → core/opti}/methods/opti_classificator.py +0 -0
  207. /wisent/{opti → core/opti}/methods/opti_steering.py +0 -0
  208. /wisent/{classifiers/core → core/synthetic}/__init__.py +0 -0
  209. /wisent/{classifiers/models → core/synthetic/cleaners}/__init__.py +0 -0
  210. /wisent/{cli → core/synthetic/cleaners/core}/__init__.py +0 -0
  211. /wisent/{synthetic → core/synthetic}/cleaners/core/atoms.py +0 -0
  212. /wisent/{cli/classifiers → core/synthetic/cleaners/methods}/__init__.py +0 -0
  213. /wisent/{cli/data_loaders → core/synthetic/cleaners/methods/core}/__init__.py +0 -0
  214. /wisent/{synthetic → core/synthetic}/cleaners/methods/core/atoms.py +0 -0
  215. /wisent/{cli/evaluators → core/synthetic/db_instructions}/__init__.py +0 -0
  216. /wisent/{cli/steering_methods → core/synthetic/db_instructions/core}/__init__.py +0 -0
  217. /wisent/{synthetic → core/synthetic}/db_instructions/core/atoms.py +0 -0
  218. /wisent/{cli/wisent_cli → core/synthetic/generators}/__init__.py +0 -0
  219. /wisent/{cli/wisent_cli/commands → core/synthetic/generators/core}/__init__.py +0 -0
  220. /wisent/{synthetic → core/synthetic}/generators/core/atoms.py +0 -0
  221. /wisent/{cli/wisent_cli/util → core/synthetic/generators/diversities}/__init__.py +0 -0
  222. /wisent/{opti → core/synthetic/generators/diversities/core}/__init__.py +0 -0
  223. /wisent/{synthetic → core/synthetic}/generators/diversities/core/core.py +0 -0
  224. /wisent/{opti/core → core/synthetic/generators/diversities/methods}/__init__.py +0 -0
  225. {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/WHEEL +0 -0
  226. {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/licenses/LICENSE +0 -0
  227. {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,364 @@
1
+ """Steering optimization command execution logic with full strategy optimization."""
2
+
3
+ import sys
4
+ import json
5
+ import time
6
+ import numpy as np
7
+
8
def execute_optimize_steering(args):
    """
    Execute the optimize-steering command.

    Supports multiple subcommands:
    - comprehensive: Run comprehensive steering optimization
    - compare-methods: Compare different steering methods
    - optimize-layer: Find optimal steering layer
    - optimize-strength: Find optimal steering strength
    - auto: Automatically optimize based on classification config
    """
    from wisent.core.models.wisent_model import WisentModel
    from wisent.core.data_loaders.loaders.lm_loader import LMEvalDataLoader

    # Bail out early when no subcommand was provided on the CLI.
    action = getattr(args, 'steering_action', None)
    if action is None:
        print("\nāœ— No steering optimization action specified")
        print("Available actions: comprehensive, compare-methods, optimize-layer, optimize-strength, auto")
        sys.exit(1)

    banner = '=' * 80
    print(f"\n{banner}")
    print(f"šŸŽÆ STEERING PARAMETER OPTIMIZATION: {action.upper()}")
    print(f"{banner}")
    print(f" Model: {args.model}")
    print(f" Device: {args.device or 'auto'}")
    print(f"{banner}\n")

    # Load the model once; every subcommand needs it.
    print(f"šŸ“¦ Loading model...")
    wisent_model = WisentModel(args.model, device=args.device)
    print(f" āœ“ Model loaded with {wisent_model.num_layers} layers\n")

    data_loader = LMEvalDataLoader()

    # Dispatch table: subcommand name -> handler(args, model, loader).
    handlers = {
        'comprehensive': execute_comprehensive,
        'compare-methods': execute_compare_methods,
        'optimize-layer': execute_optimize_layer,
        'optimize-strength': execute_optimize_strength,
        'auto': execute_auto,
    }
    handler = handlers.get(action)
    if handler is None:
        print(f"\nāœ— Unknown steering action: {args.steering_action}")
        sys.exit(1)
    handler(args, wisent_model, data_loader)
57
+
58
+
59
def execute_comprehensive(args, model, loader):
    """Execute comprehensive steering optimization with generation-based evaluation.

    Grid-searches CAA steering configurations over layer, strength, and
    token-position strategy for each requested task, scores each configuration
    by activation alignment on held-out pairs (a proxy — production would
    generate text and grade it), and writes the best configuration per task to
    a JSON results file under ./optimization_results/.

    Args:
        args: Parsed CLI namespace; reads tasks, methods, limit,
            max_time_per_task, verbose, and model.
        model: Loaded WisentModel used for activation collection.
        loader: LMEvalDataLoader used to fetch contrastive pairs per task.
    """
    from wisent.core.steering_methods.methods.caa import CAAMethod
    from wisent.core.activations.activations_collector import ActivationCollector
    from wisent.core.activations.core.atoms import ActivationAggregationStrategy
    # NOTE: dropped unused imports of SteeringPlan and sklearn's
    # accuracy_score — the latter crashed the command when sklearn was not
    # installed despite never being used.
    import torch

    print(f"šŸ” Running comprehensive steering optimization...")
    print(f" Optimizing: Layer, Strength, AND Steering Strategy")

    # Default to a representative benchmark mix when no tasks were given.
    if args.tasks:
        task_list = args.tasks
    else:
        task_list = ["arc_easy", "hellaswag", "winogrande", "gsm8k"]

    print(f" Tasks: {', '.join(task_list)}")
    print(f" Methods: {', '.join(args.methods)}")
    print(f" Limit: {args.limit} samples per task")
    print(f" Time limit: {args.max_time_per_task} minutes per task\n")

    all_results = {}

    # Search grid for the steering configuration.
    layers_to_test = [8, 9, 10, 11, 12]
    strengths_to_test = [0.5, 1.0, 1.5, 2.0]
    strategies_to_test = ["last_only", "first_only", "all_equal", "exponential_decay"]

    for task_idx, task_name in enumerate(task_list, 1):
        print(f"\n{'='*80}")
        print(f"Task {task_idx}/{len(task_list)}: {task_name}")
        print(f"{'='*80}")

        task_start_time = time.time()

        try:
            # Load task data split into train (vector fitting) / test (scoring).
            print(f" šŸ“Š Loading task data...")
            result = loader._load_one_task(
                task_name=task_name,
                split_ratio=0.8,
                seed=42,
                limit=args.limit,
                training_limit=None,
                testing_limit=None
            )

            train_pairs = result['train_qa_pairs']
            test_pairs = result['test_qa_pairs']

            print(f" āœ“ Loaded {len(train_pairs.pairs)} train, {len(test_pairs.pairs)} test pairs")

            print(f"\n šŸ” Testing CAA method across layers, strengths, AND strategies...")
            print(f" Total configurations: {len(layers_to_test)} layers Ɨ {len(strengths_to_test)} strengths Ɨ {len(strategies_to_test)} strategies = {len(layers_to_test) * len(strengths_to_test) * len(strategies_to_test)}")

            best_score = 0
            best_config = None
            method_results = {}
            configs_tested = 0
            aggregation = ActivationAggregationStrategy.MEAN_POOLING
            deadline = task_start_time + args.max_time_per_task * 60
            # BUGFIX: the original `break` only exited the innermost
            # (strategy) loop, so the time limit did not stop the search over
            # the remaining strengths and layers. The flag aborts all levels.
            time_up = False

            for layer in layers_to_test:
                if time_up:
                    break
                for strength in strengths_to_test:
                    if time_up:
                        break
                    for strategy in strategies_to_test:
                        if time.time() > deadline:
                            print(f" ā° Time limit reached")
                            time_up = True
                            break

                        try:
                            configs_tested += 1
                            layer_str = str(layer)

                            # Step 1: collect contrastive activations on the
                            # training pairs and fit a CAA steering vector.
                            collector = ActivationCollector(model=model, store_device="cpu")

                            pos_acts = []
                            neg_acts = []
                            for pair in train_pairs.pairs:
                                pos_act, neg_act = _collect_pair_layer_acts(
                                    collector, pair, layer_str, aggregation
                                )
                                if pos_act is not None:
                                    pos_acts.append(pos_act)
                                if neg_act is not None:
                                    neg_acts.append(neg_act)

                            if len(pos_acts) == 0 or len(neg_acts) == 0:
                                continue

                            caa_method = CAAMethod(kwargs={"normalize": True})
                            steering_vector = caa_method.train_for_layer(pos_acts, neg_acts)

                            # Step 2: evaluate with generation (simplified
                            # evaluation using activation alignment). In
                            # production this would actually generate text and
                            # evaluate quality.
                            test_scores = []
                            for pair in test_pairs.pairs:
                                pos_act, neg_act = _collect_pair_layer_acts(
                                    collector, pair, layer_str, aggregation
                                )
                                # BUGFIX: the original indexed the negative
                                # response's activation dict without checking
                                # the key, raising KeyError (silently swallowed
                                # as a failed config) when the layer was absent.
                                if pos_act is None or neg_act is None:
                                    continue

                                # Apply steering with strategy weighting.
                                strategy_weight = get_strategy_weight(strategy, position=0.5)  # Mid-position for evaluation

                                pos_steered = pos_act + (strength * strategy_weight) * steering_vector
                                neg_steered = neg_act + (strength * strategy_weight) * steering_vector

                                # Score: positive should be more aligned with
                                # the steering direction than the negative.
                                pos_score = torch.dot(pos_steered.flatten(), steering_vector.flatten()).item()
                                neg_score = torch.dot(neg_steered.flatten(), steering_vector.flatten()).item()
                                test_scores.append(1.0 if pos_score > neg_score else 0.0)

                            if len(test_scores) > 0:
                                avg_score = np.mean(test_scores)
                                if avg_score > best_score:
                                    best_score = avg_score
                                    best_config = {
                                        'layer': layer,
                                        'strength': strength,
                                        'strategy': strategy,
                                        'accuracy': avg_score
                                    }

                            if configs_tested % 10 == 0 and args.verbose:
                                print(f" Tested {configs_tested} configurations...", end='\r')

                        except Exception as e:
                            # Best-effort search: a failing configuration is
                            # reported (when verbose) and skipped.
                            if args.verbose:
                                print(f" Error at layer={layer}, strength={strength}, strategy={strategy}: {e}")
                            continue

            if best_config:
                print(f"\n āœ… Best configuration found:")
                print(f" Method: CAA")
                print(f" Layer: {best_config['layer']}")
                print(f" Strength: {best_config['strength']}")
                print(f" Strategy: {best_config['strategy']} ⭐")
                print(f" Accuracy: {best_config['accuracy']:.3f}")

                method_results['CAA'] = {
                    'optimal_layer': best_config['layer'],
                    'optimal_strength': best_config['strength'],
                    'optimal_strategy': best_config['strategy'],
                    'accuracy': best_config['accuracy'],
                    'f1': best_config['accuracy']
                }
            else:
                # Fall back to a neutral default so downstream consumers
                # always find a CAA entry for the task.
                print(f"\n āš ļø No valid configuration found")
                method_results['CAA'] = {
                    'optimal_layer': 10,
                    'optimal_strength': 1.0,
                    'optimal_strategy': 'last_only',
                    'accuracy': 0.5,
                    'f1': 0.5
                }

            all_results[task_name] = {
                'methods': method_results,
                'best_method': 'CAA',
                'best_layer': method_results['CAA']['optimal_layer'],
                'best_strength': method_results['CAA']['optimal_strength'],
                'best_strategy': method_results['CAA']['optimal_strategy']
            }

            task_time = time.time() - task_start_time
            print(f"\n ā±ļø Task completed in {task_time:.1f}s (tested {configs_tested} configurations)")

        except Exception as e:
            print(f" āŒ Failed to optimize {task_name}: {e}")
            import traceback
            traceback.print_exc()
            continue

    # Save results
    print(f"\n{'='*80}")
    print(f"šŸ“Š COMPREHENSIVE OPTIMIZATION COMPLETE")
    print(f"{'='*80}\n")

    results_file = f"./optimization_results/steering_comprehensive_{args.model.replace('/', '_')}.json"
    import os
    os.makedirs(os.path.dirname(results_file), exist_ok=True)

    output_data = {
        'model': args.model,
        'tasks': all_results,
        'methods_tested': args.methods,
        'limit': args.limit,
        'optimization_dimensions': ['layer', 'strength', 'strategy']
    }

    with open(results_file, 'w') as f:
        json.dump(output_data, f, indent=2)

    print(f"āœ… Results saved to: {results_file}\n")

    # Print summary
    print("šŸ“‹ SUMMARY BY TASK:")
    print("-" * 100)
    for task_name, config in all_results.items():
        print(f" {task_name:20s} | Method: {config['best_method']:10s} | Layer: {config['best_layer']:2d} | Strength: {config['best_strength']:.2f} | Strategy: {config['best_strategy']:18s}")
    print("-" * 100 + "\n")


def _collect_pair_layer_acts(collector, pair, layer_str, aggregation):
    """Collect aggregated activations for one contrastive pair at one layer.

    Returns a (positive, negative) activation tuple; either element is None
    when the collector produced no activation for that layer (missing dict or
    missing key — both handled, unlike the original inline code).
    """
    updated_pair = collector.collect_for_pair(
        pair,
        layers=[layer_str],
        aggregation=aggregation,
        return_full_sequence=False,
        normalize_layers=False
    )
    pos_layers = updated_pair.positive_response.layers_activations or {}
    neg_layers = updated_pair.negative_response.layers_activations or {}
    return pos_layers.get(layer_str), neg_layers.get(layer_str)
285
+
286
+
287
def get_strategy_weight(strategy: str, position: float) -> float:
    """
    Calculate steering weight based on strategy and token position.

    Args:
        strategy: Steering strategy name
        position: Token position as fraction (0.0 = start, 1.0 = end)

    Returns:
        Weight multiplier for steering vector
    """
    decay_rate = 3.0  # shared rate for the exponential strategies
    rules = {
        "last_only": lambda p: 1.0 if p >= 0.9 else 0.0,
        "first_only": lambda p: 1.0 if p <= 0.1 else 0.0,
        "all_equal": lambda p: 1.0,
        "exponential_decay": lambda p: np.exp(-decay_rate * p),
        "exponential_growth": lambda p: np.exp(decay_rate * p),
        "linear_decay": lambda p: 1.0 - p,
        "linear_growth": lambda p: p,
    }
    rule = rules.get(strategy)
    # Unknown strategies fall back to uniform weighting (all_equal).
    return rule(position) if rule is not None else 1.0
314
+
315
+
316
def execute_compare_methods(args, model, loader):
    """Compare candidate steering methods on a single task's training data."""
    for line in (
        f"šŸ” Comparing steering methods for task: {args.task}\n",
        f" Methods: {', '.join(args.methods)}",
        f" Limit: {args.limit} samples\n",
    ):
        print(line)

    # Fetch the task's contrastive pairs with a fixed 80/20 split.
    task_data = loader._load_one_task(
        task_name=args.task,
        split_ratio=0.8,
        seed=42,
        limit=args.limit,
        training_limit=None,
        testing_limit=None
    )

    train_pair_count = len(task_data['train_qa_pairs'].pairs)
    print(f"āœ… Loaded {train_pair_count} train pairs\n")
    print("āš ļø Full method comparison requires implementation of HPR, DAC, BiPO, KSteering")
    print(" Currently only CAA is fully implemented")
334
+
335
+
336
def execute_optimize_layer(args, model, loader):
    """Report the steering-layer optimization plan (not yet implemented)."""
    # Single write; "\n".join reproduces the blank-line spacing of the
    # original per-line prints exactly.
    lines = (
        f"šŸŽÆ Optimizing steering layer for task: {args.task}\n",
        f" Method: {args.method}",
        f" Strength: {args.strength}\n",
        "āš ļø Layer optimization not yet fully implemented",
        f" This would optimize layer for {args.method} method",
    )
    print("\n".join(lines))
344
+
345
+
346
def execute_optimize_strength(args, model, loader):
    """Report the steering-strength optimization plan (not yet implemented)."""
    low, high = args.strength_range[0], args.strength_range[1]
    lines = (
        f"šŸ’Ŗ Optimizing steering strength for task: {args.task}\n",
        f" Method: {args.method}",
        f" Strength range: {low} to {high}\n",
        "āš ļø Strength optimization not yet fully implemented",
        f" This would optimize strength for {args.method} method",
    )
    # One write; "\n".join matches the original per-line prints exactly.
    print("\n".join(lines))
354
+
355
+
356
def execute_auto(args, model, loader):
    """Report the automatic steering optimization plan (not yet implemented)."""
    messages = (
        f"šŸ¤– Running automatic steering optimization...\n",
        f" Methods: {', '.join(args.methods)}",
        f" Strength range: {args.strength_range}\n",
        "āš ļø Auto optimization not yet fully implemented",
        " This would use classification results to guide steering optimization",
    )
    for message in messages:
        print(message)
364
+
@@ -0,0 +1,182 @@
1
+ """Tasks command execution logic."""
2
+
3
+ import sys
4
+ import os
5
+ import json
6
+ import numpy as np
7
+
8
+
9
def execute_tasks(args) -> None:
    """Execute the tasks command - train classifier on benchmark tasks.

    Pipeline: load contrastive QA pairs for the requested lm-eval task,
    collect per-pair hidden-state activations at a single layer, train a
    binary classifier (label 1 = positive/correct response, 0 = negative),
    print training metrics, and optionally persist the classifier and a
    JSON training report.

    Reads from ``args``: task_names, model, layer, classifier_type,
    split_ratio, seed, limit, training_limit, testing_limit, device,
    token_aggregation, detection_threshold, save_classifier, output,
    verbose.

    Exits the process with status 1 if any step raises.
    """
    # Heavy imports are deferred into the function so other CLI
    # subcommands do not pay the model/classifier import cost.
    from wisent.core.data_loaders.loaders.lm_loader import LMEvalDataLoader
    from wisent.core.models.wisent_model import WisentModel
    from wisent.core.activations.activations_collector import ActivationCollector
    from wisent.core.activations.core.atoms import ActivationAggregationStrategy
    from wisent.core.classifiers.classifiers.models.logistic import LogisticClassifier
    from wisent.core.classifiers.classifiers.models.mlp import MLPClassifier
    from wisent.core.classifiers.classifiers.core.atoms import ClassifierTrainConfig
    from wisent.core.model_persistence import ModelPersistence, create_classifier_metadata

    print(f"\nšŸŽÆ Starting classifier training on task: {args.task_names}")
    print(f" Model: {args.model}")
    print(f" Layer: {args.layer}")
    print(f" Classifier type: {args.classifier_type}")

    try:
        # 1. Load task data using LMEvalDataLoader
        # NOTE(review): _load_one_task is a private loader API — confirm no
        # public entry point exists before depending on it further.
        print(f"\nšŸ“Š Loading task '{args.task_names}'...")
        loader = LMEvalDataLoader()
        result = loader._load_one_task(
            task_name=args.task_names,
            split_ratio=args.split_ratio,
            seed=args.seed,
            limit=args.limit,
            training_limit=args.training_limit,
            testing_limit=args.testing_limit
        )

        # Use training pairs for classifier training
        pair_set = result['train_qa_pairs']
        print(f" āœ“ Loaded {len(pair_set.pairs)} training pairs")

        # 3. Load model
        print(f"\nšŸ¤– Loading model '{args.model}'...")
        model = WisentModel(args.model, device=args.device)
        print(f" āœ“ Model loaded with {model.num_layers} layers")

        # 4. Parse layer specification (CLI may pass the layer as a string)
        layer = int(args.layer) if isinstance(args.layer, str) else args.layer
        print(f"\n🧠 Extracting activations from layer {layer}...")

        # 5. Collect activations for all pairs; activations are kept on CPU
        # so collection does not accumulate device memory.
        collector = ActivationCollector(model=model, store_device="cpu")

        # Map parser values to enum members
        aggregation_map = {
            'average': 'MEAN_POOLING',
            'final': 'LAST_TOKEN',
            'first': 'FIRST_TOKEN',
            'max': 'MAX_POOLING',
            'min': 'MAX_POOLING',  # Fallback to MAX_POOLING for min
        }
        # Unknown aggregation names silently fall back to MEAN_POOLING.
        aggregation_key = aggregation_map.get(args.token_aggregation.lower(), 'MEAN_POOLING')
        aggregation_strategy = ActivationAggregationStrategy[aggregation_key]

        positive_activations = []
        negative_activations = []

        # Convert layer int to string for activation collection
        layer_str = str(layer)

        for i, pair in enumerate(pair_set.pairs):
            if i % 10 == 0:
                # '\r' keeps the progress counter on a single console line.
                print(f" Processing pair {i+1}/{len(pair_set.pairs)}...", end='\r')

            # Collect for positive (correct) response
            updated_pair = collector.collect_for_pair(
                pair,
                layers=[layer_str],
                aggregation=aggregation_strategy,
                return_full_sequence=False,
                normalize_layers=False
            )

            # Extract activations from positive and negative responses.
            # Pairs missing this layer's activations are skipped, so the two
            # classes can end up with different sample counts.
            if updated_pair.positive_response.layers_activations and layer_str in updated_pair.positive_response.layers_activations:
                act = updated_pair.positive_response.layers_activations[layer_str]
                if act is not None:
                    positive_activations.append(act.cpu().numpy())

            if updated_pair.negative_response.layers_activations and layer_str in updated_pair.negative_response.layers_activations:
                act = updated_pair.negative_response.layers_activations[layer_str]
                if act is not None:
                    negative_activations.append(act.cpu().numpy())

        print(f"\n āœ“ Collected {len(positive_activations)} positive and {len(negative_activations)} negative activations")

        # 6. Prepare training data: positives stacked first, then negatives,
        # with matching 1/0 labels in the same order.
        print(f"\nšŸŽÆ Preparing training data...")
        X_positive = np.array(positive_activations)
        X_negative = np.array(negative_activations)
        X = np.vstack([X_positive, X_negative])
        y = np.array([1] * len(positive_activations) + [0] * len(negative_activations))

        print(f" Training set: {X.shape[0]} samples, {X.shape[1]} features")
        print(f" Positive samples: {sum(y == 1)}, Negative samples: {sum(y == 0)}")

        # 7. Create and train classifier
        print(f"\nšŸ‹ļø Training {args.classifier_type} classifier...")
        if args.classifier_type == 'logistic':
            classifier = LogisticClassifier(threshold=args.detection_threshold, device=args.device)
        elif args.classifier_type == 'mlp':
            classifier = MLPClassifier(threshold=args.detection_threshold, device=args.device)
        else:
            raise ValueError(f"Unknown classifier type: {args.classifier_type}")

        # Training configuration: epochs/batch size/learning rate are fixed
        # here; the held-out fraction mirrors the loader's split_ratio.
        train_config = ClassifierTrainConfig(
            test_size=1.0 - args.split_ratio,
            num_epochs=50,
            batch_size=32,
            learning_rate=1e-3,
            monitor='f1',
            random_state=args.seed
        )

        # Train the classifier
        report = classifier.fit(X, y, config=train_config)

        # 8. Print results
        print(f"\nšŸ“ˆ Training completed!")
        print(f" Best epoch: {report.best_epoch}/{report.epochs_ran}")
        print(f" Final metrics:")
        print(f" • Accuracy: {report.final.accuracy:.4f}")
        print(f" • Precision: {report.final.precision:.4f}")
        print(f" • Recall: {report.final.recall:.4f}")
        print(f" • F1 Score: {report.final.f1:.4f}")
        print(f" • AUC: {report.final.auc:.4f}")

        # 9. Save classifier if requested
        if args.save_classifier:
            print(f"\nšŸ’¾ Saving classifier to '{args.save_classifier}'...")

            # Create metadata describing how this classifier was produced,
            # so it can be matched to a model/task/layer when reloaded.
            metadata = create_classifier_metadata(
                model_name=args.model,
                task_name=args.task_names,
                layer=layer,
                classifier_type=args.classifier_type,
                training_accuracy=report.final.accuracy,
                training_samples=len(X),
                token_aggregation=args.token_aggregation,
                detection_threshold=args.detection_threshold
            )

            # Save using model persistence
            save_path = ModelPersistence.save_classifier(
                classifier=classifier,
                layer=layer,
                save_path=args.save_classifier,
                metadata=metadata
            )
            print(f" āœ“ Classifier saved to: {save_path}")

        # 10. Save output artifacts if requested
        if args.output:
            print(f"\nšŸ“ Saving artifacts to '{args.output}'...")
            os.makedirs(args.output, exist_ok=True)

            # Save training report as JSON alongside any other artifacts.
            report_path = os.path.join(args.output, 'training_report.json')
            with open(report_path, 'w') as f:
                json.dump(report.asdict(), f, indent=2)
            print(f" āœ“ Training report saved to: {report_path}")

        print(f"\nāœ… Task completed successfully!\n")

    except Exception as e:
        # Broad catch is deliberate at this CLI boundary: report the error
        # and exit non-zero; full traceback only with --verbose.
        print(f"\nāŒ Error: {str(e)}", file=sys.stderr)
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)
@@ -0,0 +1,22 @@
1
+ """Simple CLI logger replacement for removed wisent.cli module."""
2
+
3
+ import logging
4
+
5
+
6
def setup_logger(name: str) -> logging.Logger:
    """Return a named logger, configuring it on first use.

    A stream handler with a timestamped format and INFO level are attached
    only when the logger has no handlers yet, so repeated calls are
    idempotent and never stack duplicate handlers.
    """
    log = logging.getLogger(name)
    if log.handlers:
        return log
    stream = logging.StreamHandler()
    stream.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    log.addHandler(stream)
    log.setLevel(logging.INFO)
    return log
16
+
17
+
18
def bind(logger: logging.Logger, **kwargs) -> logging.Logger:
    """Attach contextual fields to *logger* (no-op shim).

    The removed ``wisent.cli`` implementation presumably enriched log
    records with the given keyword context; this replacement ignores
    ``kwargs`` and hands the logger back unchanged.
    """
    return logger
@@ -68,6 +68,7 @@ class LMEvalBenchmarkExtractor(ABC):
68
68
  cls,
69
69
  lm_eval_task_data: ConfigurableTask,
70
70
  limit: int | None = None,
71
+ preferred_doc: str | None = None,
71
72
  ) -> list[dict[str, Any]]:
72
73
  """
73
74
  Load labeled documents from the most appropriate split with a clear
@@ -75,6 +76,9 @@ class LMEvalBenchmarkExtractor(ABC):
75
76
 
76
77
  validation → test → train → fewshot
77
78
 
79
+ If preferred_doc is provided, that source will be tried first before
80
+ falling back to the default order.
81
+
78
82
  If none are available, attempts a dataset fallback using
79
83
  'datasets.load_dataset' with the task's declared metadata
80
84
  (e.g., 'dataset_path'/'dataset_name', 'dataset_config_name',
@@ -86,6 +90,10 @@ class LMEvalBenchmarkExtractor(ABC):
86
90
  limit:
87
91
  Optional maximum number of documents to return.
88
92
  Values <= 0 are treated as "no limit".
93
+ preferred_doc:
94
+ Optional preferred document source. Valid values:
95
+ "validation", "test", "training", "fewshot".
96
+ If provided, this source will be tried first.
89
97
 
90
98
  returns:
91
99
  A list of document dictionaries.
@@ -98,17 +106,35 @@ class LMEvalBenchmarkExtractor(ABC):
98
106
  """
99
107
  max_items = cls._normalize_limit(limit)
100
108
 
101
- preferred_sources: Sequence[tuple[str, str]] = (
109
+ # Map preferred_doc string to the tuple format
110
+ doc_source_map = {
111
+ "validation": ("has_validation_docs", "validation_docs"),
112
+ "test": ("has_test_docs", "test_docs"),
113
+ "training": ("has_training_docs", "training_docs"),
114
+ "fewshot": ("has_fewshot_docs", "fewshot_docs"),
115
+ }
116
+
117
+ # Build preferred_sources based on preferred_doc
118
+ default_order: Sequence[tuple[str, str]] = (
102
119
  ("has_validation_docs", "validation_docs"),
103
120
  ("has_test_docs", "test_docs"),
104
121
  ("has_training_docs", "training_docs"),
105
122
  ("has_fewshot_docs", "fewshot_docs"),
106
123
  )
107
124
 
125
+ if preferred_doc and preferred_doc in doc_source_map:
126
+ # Put preferred source first, then other sources
127
+ preferred_source = doc_source_map[preferred_doc]
128
+ other_sources = [s for s in default_order if s != preferred_source]
129
+ preferred_sources = (preferred_source,) + tuple(other_sources)
130
+ else:
131
+ preferred_sources = default_order
132
+
108
133
  for has_method, docs_method in preferred_sources:
109
134
  if cls._has_true(lm_eval_task_data, has_method) and cls._has_callable(
110
135
  lm_eval_task_data, docs_method
111
136
  ):
137
+ print(f"loaded from {docs_method}")
112
138
  docs_iter = getattr(lm_eval_task_data, docs_method)()
113
139
  docs_list = cls._coerce_docs_to_dicts(docs_iter, max_items)
114
140
  if docs_list: