wisent 0.5.11__py3-none-any.whl → 0.5.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (225)
  1. wisent/__init__.py +1 -1
  2. wisent/core/activations/__init__.py +26 -0
  3. wisent/core/activations/activations.py +96 -0
  4. wisent/core/activations/activations_collector.py +71 -20
  5. wisent/core/activations/prompt_construction_strategy.py +47 -0
  6. wisent/core/agent/budget.py +2 -2
  7. wisent/core/agent/device_benchmarks.py +1 -1
  8. wisent/core/agent/diagnose/classifier_marketplace.py +8 -8
  9. wisent/core/agent/diagnose/response_diagnostics.py +4 -4
  10. wisent/core/agent/diagnose/synthetic_classifier_option.py +1 -1
  11. wisent/core/agent/diagnose/tasks/task_manager.py +3 -3
  12. wisent/core/agent/diagnose.py +2 -1
  13. wisent/core/autonomous_agent.py +10 -2
  14. wisent/core/benchmark_extractors.py +293 -0
  15. wisent/core/bigcode_integration.py +20 -7
  16. wisent/core/branding.py +108 -0
  17. wisent/core/cli/__init__.py +15 -0
  18. wisent/core/cli/create_steering_vector.py +138 -0
  19. wisent/core/cli/evaluate_responses.py +715 -0
  20. wisent/core/cli/generate_pairs.py +128 -0
  21. wisent/core/cli/generate_pairs_from_task.py +119 -0
  22. wisent/core/cli/generate_responses.py +129 -0
  23. wisent/core/cli/generate_vector_from_synthetic.py +149 -0
  24. wisent/core/cli/generate_vector_from_task.py +147 -0
  25. wisent/core/cli/get_activations.py +191 -0
  26. wisent/core/cli/optimize_classification.py +339 -0
  27. wisent/core/cli/optimize_steering.py +364 -0
  28. wisent/core/cli/tasks.py +182 -0
  29. wisent/core/cli_logger.py +22 -0
  30. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +27 -1
  31. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +49 -1
  32. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +115 -0
  33. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +115 -0
  34. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +111 -0
  35. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +119 -0
  36. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +115 -0
  37. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +114 -0
  38. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +118 -0
  39. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +146 -0
  40. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +129 -0
  41. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +119 -0
  42. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +112 -0
  43. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +113 -0
  44. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livecodebench.py +367 -0
  45. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +115 -0
  46. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +114 -0
  47. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +113 -0
  48. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +112 -0
  49. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +111 -0
  50. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +114 -0
  51. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +113 -0
  52. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +115 -0
  53. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +111 -0
  54. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +113 -0
  55. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +113 -0
  56. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +112 -0
  57. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +116 -0
  58. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +115 -0
  59. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +111 -0
  60. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +111 -0
  61. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +121 -0
  62. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +121 -0
  63. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +111 -0
  64. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +110 -0
  65. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +114 -0
  66. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py +124 -0
  67. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py +111 -0
  68. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +112 -0
  69. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +127 -0
  70. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_gen.py +112 -0
  71. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +117 -0
  72. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +117 -0
  73. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +127 -0
  74. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +119 -0
  75. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +1 -1
  76. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +111 -0
  77. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +114 -0
  78. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +112 -0
  79. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +114 -0
  80. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +114 -0
  81. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +1 -1
  82. wisent/core/data_loaders/__init__.py +235 -0
  83. wisent/core/data_loaders/loaders/lm_loader.py +2 -2
  84. wisent/core/data_loaders/loaders/task_interface_loader.py +300 -0
  85. wisent/{cli/data_loaders/data_loader_rotator.py → core/data_loaders/rotator.py} +1 -1
  86. wisent/core/download_full_benchmarks.py +79 -2
  87. wisent/core/evaluators/benchmark_specific/__init__.py +26 -0
  88. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/evaluator.py +17 -17
  89. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/cpp_sanitizer.py +2 -2
  90. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/java_sanitizer.py +2 -2
  91. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/python_sanitizer.py +2 -2
  92. wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/__init__.py +3 -0
  93. wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/provider.py +305 -0
  94. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/runtime.py +36 -4
  95. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/entrypoint.py +2 -4
  96. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/recipes.py +1 -1
  97. wisent/core/evaluators/benchmark_specific/coding/solution_generator.py +258 -0
  98. wisent/core/evaluators/benchmark_specific/exact_match_evaluator.py +79 -0
  99. wisent/core/evaluators/benchmark_specific/f1_evaluator.py +101 -0
  100. wisent/core/evaluators/benchmark_specific/generation_evaluator.py +197 -0
  101. wisent/core/{log_likelihoods_evaluator.py → evaluators/benchmark_specific/log_likelihoods_evaluator.py} +10 -2
  102. wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py +140 -0
  103. wisent/core/evaluators/benchmark_specific/personalization_evaluator.py +250 -0
  104. wisent/{cli/evaluators/evaluator_rotator.py → core/evaluators/rotator.py} +4 -4
  105. wisent/core/lm_eval_harness_ground_truth.py +3 -2
  106. wisent/core/main.py +57 -0
  107. wisent/core/model_persistence.py +2 -2
  108. wisent/core/models/wisent_model.py +8 -6
  109. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
  110. wisent/core/optuna/steering/steering_optimization.py +1 -1
  111. wisent/core/parser_arguments/__init__.py +10 -0
  112. wisent/core/parser_arguments/agent_parser.py +110 -0
  113. wisent/core/parser_arguments/configure_model_parser.py +7 -0
  114. wisent/core/parser_arguments/create_steering_vector_parser.py +59 -0
  115. wisent/core/parser_arguments/evaluate_parser.py +40 -0
  116. wisent/core/parser_arguments/evaluate_responses_parser.py +10 -0
  117. wisent/core/parser_arguments/full_optimize_parser.py +115 -0
  118. wisent/core/parser_arguments/generate_pairs_from_task_parser.py +33 -0
  119. wisent/core/parser_arguments/generate_pairs_parser.py +29 -0
  120. wisent/core/parser_arguments/generate_responses_parser.py +15 -0
  121. wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +127 -0
  122. wisent/core/parser_arguments/generate_vector_from_task_parser.py +127 -0
  123. wisent/core/parser_arguments/generate_vector_parser.py +90 -0
  124. wisent/core/parser_arguments/get_activations_parser.py +90 -0
  125. wisent/core/parser_arguments/main_parser.py +152 -0
  126. wisent/core/parser_arguments/model_config_parser.py +59 -0
  127. wisent/core/parser_arguments/monitor_parser.py +17 -0
  128. wisent/core/parser_arguments/multi_steer_parser.py +47 -0
  129. wisent/core/parser_arguments/optimize_classification_parser.py +67 -0
  130. wisent/core/parser_arguments/optimize_sample_size_parser.py +58 -0
  131. wisent/core/parser_arguments/optimize_steering_parser.py +147 -0
  132. wisent/core/parser_arguments/synthetic_parser.py +93 -0
  133. wisent/core/parser_arguments/tasks_parser.py +584 -0
  134. wisent/core/parser_arguments/test_nonsense_parser.py +26 -0
  135. wisent/core/parser_arguments/utils.py +111 -0
  136. wisent/core/prompts/core/prompt_formater.py +3 -3
  137. wisent/core/prompts/prompt_stratiegies/direct_completion.py +2 -0
  138. wisent/core/prompts/prompt_stratiegies/instruction_following.py +2 -0
  139. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +2 -0
  140. wisent/core/prompts/prompt_stratiegies/role_playing.py +2 -0
  141. wisent/{cli/steering_methods/steering_rotator.py → core/steering_methods/rotator.py} +4 -4
  142. wisent/core/steering_optimizer.py +45 -21
  143. wisent/{synthetic → core/synthetic}/cleaners/deduper_cleaner.py +3 -3
  144. wisent/{synthetic → core/synthetic}/cleaners/methods/base_dedupers.py +2 -2
  145. wisent/{synthetic → core/synthetic}/cleaners/methods/base_refusalers.py +1 -1
  146. wisent/{synthetic → core/synthetic}/cleaners/pairs_cleaner.py +5 -5
  147. wisent/{synthetic → core/synthetic}/cleaners/refusaler_cleaner.py +4 -4
  148. wisent/{synthetic → core/synthetic}/db_instructions/mini_dp.py +1 -1
  149. wisent/{synthetic → core/synthetic}/generators/diversities/methods/fast_diversity.py +1 -1
  150. wisent/{synthetic → core/synthetic}/generators/pairs_generator.py +38 -12
  151. wisent/core/tasks/livecodebench_task.py +4 -103
  152. wisent/core/timing_calibration.py +1 -1
  153. {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/METADATA +3 -3
  154. wisent-0.5.13.dist-info/RECORD +294 -0
  155. wisent-0.5.13.dist-info/entry_points.txt +2 -0
  156. wisent/benchmarks/coding/providers/livecodebench/provider.py +0 -53
  157. wisent/classifiers/core/atoms.py +0 -747
  158. wisent/classifiers/models/logistic.py +0 -29
  159. wisent/classifiers/models/mlp.py +0 -47
  160. wisent/cli/classifiers/classifier_rotator.py +0 -137
  161. wisent/cli/cli_logger.py +0 -142
  162. wisent/cli/wisent_cli/commands/help_cmd.py +0 -52
  163. wisent/cli/wisent_cli/commands/listing.py +0 -154
  164. wisent/cli/wisent_cli/commands/train_cmd.py +0 -322
  165. wisent/cli/wisent_cli/main.py +0 -93
  166. wisent/cli/wisent_cli/shell.py +0 -80
  167. wisent/cli/wisent_cli/ui.py +0 -69
  168. wisent/cli/wisent_cli/util/aggregations.py +0 -43
  169. wisent/cli/wisent_cli/util/parsing.py +0 -126
  170. wisent/cli/wisent_cli/version.py +0 -4
  171. wisent/opti/methods/__init__.py +0 -0
  172. wisent/synthetic/__init__.py +0 -0
  173. wisent/synthetic/cleaners/__init__.py +0 -0
  174. wisent/synthetic/cleaners/core/__init__.py +0 -0
  175. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  176. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  177. wisent/synthetic/db_instructions/__init__.py +0 -0
  178. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  179. wisent/synthetic/generators/__init__.py +0 -0
  180. wisent/synthetic/generators/core/__init__.py +0 -0
  181. wisent/synthetic/generators/diversities/__init__.py +0 -0
  182. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  183. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  184. wisent-0.5.11.dist-info/RECORD +0 -220
  185. /wisent/{benchmarks → core/evaluators/benchmark_specific/coding}/__init__.py +0 -0
  186. /wisent/{benchmarks/coding → core/evaluators/benchmark_specific/coding/metrics}/__init__.py +0 -0
  187. /wisent/{benchmarks/coding/metrics → core/evaluators/benchmark_specific/coding/metrics/core}/__init__.py +0 -0
  188. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/core/atoms.py +0 -0
  189. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/passk.py +0 -0
  190. /wisent/{benchmarks/coding/metrics/core → core/evaluators/benchmark_specific/coding/output_sanitizer}/__init__.py +0 -0
  191. /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/output_sanitizer/core}/__init__.py +0 -0
  192. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/core/atoms.py +0 -0
  193. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/utils.py +0 -0
  194. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/__init__.py +0 -0
  195. /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/providers}/core/__init__.py +0 -0
  196. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/core/atoms.py +0 -0
  197. /wisent/{benchmarks/coding/providers/core → core/evaluators/benchmark_specific/coding/safe_docker}/__init__.py +0 -0
  198. /wisent/{benchmarks/coding/providers/livecodebench → core/evaluators/benchmark_specific/coding/safe_docker/core}/__init__.py +0 -0
  199. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/atoms.py +0 -0
  200. /wisent/{benchmarks/coding/safe_docker → core/opti}/__init__.py +0 -0
  201. /wisent/{benchmarks/coding/safe_docker → core/opti}/core/__init__.py +0 -0
  202. /wisent/{opti → core/opti}/core/atoms.py +0 -0
  203. /wisent/{classifiers → core/opti/methods}/__init__.py +0 -0
  204. /wisent/{opti → core/opti}/methods/opti_classificator.py +0 -0
  205. /wisent/{opti → core/opti}/methods/opti_steering.py +0 -0
  206. /wisent/{classifiers/core → core/synthetic}/__init__.py +0 -0
  207. /wisent/{classifiers/models → core/synthetic/cleaners}/__init__.py +0 -0
  208. /wisent/{cli → core/synthetic/cleaners/core}/__init__.py +0 -0
  209. /wisent/{synthetic → core/synthetic}/cleaners/core/atoms.py +0 -0
  210. /wisent/{cli/classifiers → core/synthetic/cleaners/methods}/__init__.py +0 -0
  211. /wisent/{cli/data_loaders → core/synthetic/cleaners/methods/core}/__init__.py +0 -0
  212. /wisent/{synthetic → core/synthetic}/cleaners/methods/core/atoms.py +0 -0
  213. /wisent/{cli/evaluators → core/synthetic/db_instructions}/__init__.py +0 -0
  214. /wisent/{cli/steering_methods → core/synthetic/db_instructions/core}/__init__.py +0 -0
  215. /wisent/{synthetic → core/synthetic}/db_instructions/core/atoms.py +0 -0
  216. /wisent/{cli/wisent_cli → core/synthetic/generators}/__init__.py +0 -0
  217. /wisent/{cli/wisent_cli/commands → core/synthetic/generators/core}/__init__.py +0 -0
  218. /wisent/{synthetic → core/synthetic}/generators/core/atoms.py +0 -0
  219. /wisent/{cli/wisent_cli/util → core/synthetic/generators/diversities}/__init__.py +0 -0
  220. /wisent/{opti → core/synthetic/generators/diversities/core}/__init__.py +0 -0
  221. /wisent/{synthetic → core/synthetic}/generators/diversities/core/core.py +0 -0
  222. /wisent/{opti/core → core/synthetic/generators/diversities/methods}/__init__.py +0 -0
  223. {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/WHEEL +0 -0
  224. {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/licenses/LICENSE +0 -0
  225. {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,191 @@
1
+ """Get activations command execution logic."""
2
+
3
+ import sys
4
+ import json
5
+ import os
6
+ import time
7
+
8
+
9
def _parse_layer_spec(layer_spec, num_layers):
    """Turn the ``args.layers`` value into a list of integer layer indices.

    ``None`` selects the middle layer (``num_layers // 2``), ``'all'``
    (case-insensitive, surrounding whitespace ignored) selects
    ``1..num_layers`` inclusive, and anything else is read as a
    comma-separated list of integers. Empty entries such as a trailing comma
    are ignored.

    Raises:
        ValueError: if an entry is not an integer, or no layer is given.
    """
    if layer_spec is None:
        return [num_layers // 2]
    if layer_spec.strip().lower() == 'all':
        return list(range(1, num_layers + 1))

    layers = []
    for token in layer_spec.split(','):
        token = token.strip()
        if not token:
            continue  # tolerate "5,7," or "5,,7"
        try:
            layers.append(int(token))
        except ValueError as err:
            raise ValueError(
                f"Invalid layer '{token}' in layer spec '{layer_spec}': "
                f"expected comma-separated integers or 'all'"
            ) from err
    if not layers:
        raise ValueError(f"No layers given in layer spec '{layer_spec}'")
    return layers


def _resolve_aggregation_key(requested, strategy_enum):
    """Map a CLI token-aggregation name to a member name of ``strategy_enum``.

    Args:
        requested: Name from the command line (case-insensitive): one of
            ``'average'``, ``'final'``, ``'first'``, ``'max'``, ``'min'``.
        strategy_enum: The aggregation-strategy Enum class; only its
            ``__members__`` mapping is inspected.

    Returns:
        A member name usable as ``strategy_enum[name]``. Unknown names fall
        back to ``'MEAN_POOLING'``. ``'min'`` resolves to ``'MIN_POOLING'``
        when the enum defines it; otherwise it falls back to ``'MAX_POOLING'``
        (what ``'min'`` used to be mapped to unconditionally). Both fallbacks
        print a warning to stderr instead of happening silently.
    """
    aggregation_map = {
        'average': 'MEAN_POOLING',
        'final': 'LAST_TOKEN',
        'first': 'FIRST_TOKEN',
        'max': 'MAX_POOLING',
        'min': 'MIN_POOLING',
    }
    key = aggregation_map.get(requested.lower())
    if key is None:
        print(f" ⚠️ Unknown token aggregation '{requested}'; falling back to MEAN_POOLING",
              file=sys.stderr)
        return 'MEAN_POOLING'
    if key == 'MIN_POOLING' and key not in strategy_enum.__members__:
        # NOTE(review): the strategy enum may not provide min pooling; keep the
        # previous MAX_POOLING substitution, but make it visible to the user.
        print(f" ⚠️ MIN_POOLING is not available; falling back to MAX_POOLING for '{requested}'",
              file=sys.stderr)
        return 'MAX_POOLING'
    return key


def _layer_activations_to_json(layers_activations):
    """Convert a ``{layer_name: tensor}`` mapping into JSON-serializable lists.

    A falsy mapping (``None`` or empty) yields ``{}``; entries whose value is
    ``None`` are dropped. Each tensor is moved to CPU before ``tolist()``.
    """
    if not layers_activations:
        return {}
    return {
        layer_str: act.cpu().tolist()
        for layer_str, act in layers_activations.items()
        if act is not None
    }


def execute_get_activations(args):
    """Execute the get-activations command - load pairs and collect activations.

    Reads a JSON file of contrastive pairs and runs ``args.model`` over each
    pair with ``ActivationCollector``. It then writes the pairs to
    ``args.output`` with per-layer activations attached as JSON lists.

    The input is either a dict with a ``'pairs'`` list (and optional
    ``'task_name'`` / ``'trait_label'``) or a bare list of pair dicts.

    Args:
        args: Parsed CLI namespace. Attributes read: ``pairs_file``, ``model``,
            ``device``, ``layers`` (None, ``'all'`` or comma-separated ints),
            ``limit``, ``token_aggregation``, ``prompt_strategy``, ``output``,
            ``timing`` and ``verbose``.

    Any exception raised while running is printed to stderr (with a traceback
    when ``args.verbose``), and the process exits via ``sys.exit(1)``.
    """
    from wisent.core.models.wisent_model import WisentModel
    from wisent.core.activations.activations_collector import ActivationCollector
    from wisent.core.activations.core.atoms import ActivationAggregationStrategy
    from wisent.core.activations.prompt_construction_strategy import PromptConstructionStrategy
    from wisent.core.contrastive_pairs.core.pair import ContrastivePair
    from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse
    from wisent.core.contrastive_pairs.core.set import ContrastivePairSet

    print(f"\n🎨 Collecting activations from contrastive pairs")
    print(f" Input file: {args.pairs_file}")
    print(f" Model: {args.model}")

    start_time = time.time() if args.timing else None

    try:
        # 1. Load pairs from JSON
        print(f"\n📂 Loading contrastive pairs...")
        if not os.path.exists(args.pairs_file):
            raise FileNotFoundError(f"Pairs file not found: {args.pairs_file}")

        # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
        with open(args.pairs_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Handle both formats: dict with 'pairs' key or direct list
        if isinstance(data, dict):
            pairs_list = data.get('pairs', [])
            task_name = data.get('task_name', 'unknown')
            trait_label = data.get('trait_label', task_name)
        else:
            pairs_list = data
            task_name = 'unknown'
            trait_label = 'unknown'

        # Apply limit if specified (a falsy limit such as None or 0 means "no limit")
        if args.limit:
            pairs_list = pairs_list[:args.limit]

        print(f" ✓ Loaded {len(pairs_list)} pairs")

        # 2. Load model
        print(f"\n🤖 Loading model '{args.model}'...")
        model = WisentModel(args.model, device=args.device)
        print(f" ✓ Model loaded with {model.num_layers} layers")

        # 3. Determine layers to collect (the collector API takes layer names as strings)
        layers = _parse_layer_spec(args.layers, model.num_layers)
        layer_strs = [str(layer) for layer in layers]

        print(f"\n🎯 Collecting activations from {len(layers)} layer(s): {layers}")

        # 4. Resolve token aggregation and prompt construction strategies
        aggregation_key = _resolve_aggregation_key(args.token_aggregation, ActivationAggregationStrategy)
        aggregation_strategy = ActivationAggregationStrategy[aggregation_key]

        prompt_strategy_map = {
            'chat_template': PromptConstructionStrategy.CHAT_TEMPLATE,
            'direct_completion': PromptConstructionStrategy.DIRECT_COMPLETION,
            'instruction_following': PromptConstructionStrategy.INSTRUCTION_FOLLOWING,
            'multiple_choice': PromptConstructionStrategy.MULTIPLE_CHOICE,
            'role_playing': PromptConstructionStrategy.ROLE_PLAYING,
        }
        prompt_strategy = prompt_strategy_map.get(args.prompt_strategy.lower())
        if prompt_strategy is None:
            print(f" ⚠️ Unknown prompt strategy '{args.prompt_strategy}'; falling back to chat_template",
                  file=sys.stderr)
            prompt_strategy = PromptConstructionStrategy.CHAT_TEMPLATE

        print(f" Token aggregation: {args.token_aggregation} ({aggregation_key})")
        print(f" Prompt strategy: {args.prompt_strategy}")

        # 5. Create pair set and reconstruct pairs from the JSON records
        pair_set = ContrastivePairSet(name=task_name, task_type=trait_label)

        for idx, pair_data in enumerate(pairs_list):
            try:
                prompt = pair_data['prompt']
                positive_text = pair_data['positive_response']['model_response']
                negative_text = pair_data['negative_response']['model_response']
            except (KeyError, TypeError) as err:
                # Turn an opaque "Error: 'prompt'" into a message that points at the bad record.
                raise ValueError(
                    f"Malformed pair at index {idx} in {args.pairs_file}: expected 'prompt', "
                    f"'positive_response.model_response' and 'negative_response.model_response' "
                    f"({err!r})"
                ) from err

            pair_set.add(ContrastivePair(
                prompt=prompt,
                positive_response=PositiveResponse(model_response=positive_text),
                negative_response=NegativeResponse(model_response=negative_text),
                label=pair_data.get('label', trait_label),
                trait_description=pair_data.get('trait_description', ''),
            ))

        # 6. Collect activations
        print(f"\n⚡ Collecting activations...")
        collector = ActivationCollector(model=model, store_device="cpu")

        total_pairs = len(pair_set.pairs)
        enriched_pairs = []
        for i, pair in enumerate(pair_set.pairs, start=1):
            if args.verbose:
                print(f" Processing pair {i}/{total_pairs}...")

            # Collect activations for all requested layers at once
            enriched_pairs.append(collector.collect_for_pair(
                pair,
                layers=layer_strs,
                aggregation=aggregation_strategy,
                return_full_sequence=False,
                normalize_layers=False,
                prompt_strategy=prompt_strategy,
            ))

        print(f" ✓ Collected activations for {len(enriched_pairs)} pairs")

        # 7. Convert to JSON format
        print(f"\n💾 Saving enriched pairs to '{args.output}'...")
        output_data = {
            'task_name': task_name,
            'trait_label': trait_label,
            'model': args.model,
            'layers': layers,
            'token_aggregation': args.token_aggregation,
            # Strategy actually applied; may differ from the requested name after a fallback.
            'aggregation_strategy': aggregation_key,
            'num_pairs': len(enriched_pairs),
            'pairs': [
                {
                    'prompt': pair.prompt,
                    'positive_response': {
                        'model_response': pair.positive_response.model_response,
                        'layers_activations': _layer_activations_to_json(
                            pair.positive_response.layers_activations
                        ),
                    },
                    'negative_response': {
                        'model_response': pair.negative_response.model_response,
                        'layers_activations': _layer_activations_to_json(
                            pair.negative_response.layers_activations
                        ),
                    },
                    'label': pair.label,
                    'trait_description': pair.trait_description,
                }
                for pair in enriched_pairs
            ],
        }

        # 8. Save to file
        os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2)

        print(f" ✓ Saved enriched pairs to: {args.output}")

        if args.timing:
            elapsed = time.time() - start_time
            print(f" ⏱️ Total time: {elapsed:.2f}s")

        print(f"\n✅ Activation collection completed successfully!\n")

    except Exception as e:
        # Top-level CLI boundary: report, optionally show the traceback, exit non-zero.
        print(f"\n❌ Error: {str(e)}", file=sys.stderr)
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)
@@ -0,0 +1,339 @@
1
+ """Classification optimization command execution logic."""
2
+
3
+ import sys
4
+ import json
5
+ import time
6
+ from typing import List, Dict, Any
7
+
8
+ def execute_optimize_classification(args):
9
+ """
10
+ Execute the optimize-classification command.
11
+
12
+ Optimizes classification parameters across all available tasks:
13
+ - Finds best layer for each task
14
+ - Finds best token aggregation method
15
+ - Finds best detection threshold
16
+ - Saves trained classifiers
17
+
18
+ EFFICIENCY: Collects raw activations ONCE, then applies different aggregation strategies
19
+ to the cached activations without re-running the model.
20
+ """
21
+ from wisent.core.models.wisent_model import WisentModel
22
+ from wisent.core.data_loaders.loaders.lm_loader import LMEvalDataLoader
23
+ from wisent.core.activations.activations_collector import ActivationCollector
24
+ from wisent.core.activations.core.atoms import ActivationAggregationStrategy
25
+ from wisent.core.classifiers.classifiers.models.logistic import LogisticClassifier
26
+ from wisent.core.classifiers.classifiers.core.atoms import ClassifierTrainConfig
27
+ import numpy as np
28
+ import torch
29
+
30
+ print(f"\n{'='*80}")
31
+ print(f"🔍 CLASSIFICATION PARAMETER OPTIMIZATION")
32
+ print(f"{'='*80}")
33
+ print(f" Model: {args.model}")
34
+ print(f" Limit per task: {args.limit}")
35
+ print(f" Optimization metric: {args.optimization_metric}")
36
+ print(f" Device: {args.device or 'auto'}")
37
+ print(f"{'='*80}\n")
38
+
39
+ # 1. Load model
40
+ print(f"📦 Loading model...")
41
+ model = WisentModel(args.model, device=args.device)
42
+ total_layers = model.num_layers
43
+ print(f" ✓ Model loaded with {total_layers} layers\n")
44
+
45
+ # 2. Determine layer range
46
+ if args.layer_range:
47
+ start, end = map(int, args.layer_range.split('-'))
48
+ layers_to_test = list(range(start, end + 1))
49
+ else:
50
+ # Test middle layers by default (more informative)
51
+ start_layer = total_layers // 3
52
+ end_layer = (2 * total_layers) // 3
53
+ layers_to_test = list(range(start_layer, end_layer + 1))
54
+
55
+ print(f"🎯 Testing layers: {layers_to_test[0]} to {layers_to_test[-1]} ({len(layers_to_test)} layers)")
56
+ print(f"🔄 Aggregation methods: {', '.join(args.aggregation_methods)}")
57
+ print(f"📊 Thresholds: {args.threshold_range}\n")
58
+
59
+ # 3. Get list of tasks to optimize
60
+ task_list = [
61
+ "arc_easy", "arc_challenge", "hellaswag",
62
+ "winogrande", "gsm8k"
63
+ ]
64
+
65
+ print(f"📋 Optimizing {len(task_list)} tasks\n")
66
+
67
+ # 4. Initialize data loader
68
+ loader = LMEvalDataLoader()
69
+
70
+ # 5. Results storage
71
+ all_results = {}
72
+ classifiers_saved = {}
73
+
74
+ # 6. Process each task
75
+ for task_idx, task_name in enumerate(task_list, 1):
76
+ print(f"\n{'='*80}")
77
+ print(f"Task {task_idx}/{len(task_list)}: {task_name}")
78
+ print(f"{'='*80}")
79
+
80
+ task_start_time = time.time()
81
+
82
+ try:
83
+ # Load task data
84
+ print(f" 📊 Loading data...")
85
+ result = loader._load_one_task(
86
+ task_name=task_name,
87
+ split_ratio=0.8,
88
+ seed=42,
89
+ limit=args.limit,
90
+ training_limit=None,
91
+ testing_limit=None
92
+ )
93
+
94
+ train_pairs = result['train_qa_pairs']
95
+ test_pairs = result['test_qa_pairs']
96
+
97
+ print(f" ✓ Loaded {len(train_pairs.pairs)} train, {len(test_pairs.pairs)} test pairs")
98
+
99
+ # STEP 1: Collect raw activations ONCE for all layers (full sequence)
100
+ print(f" 🧠 Collecting raw activations (once per pair)...")
101
+ collector = ActivationCollector(model=model, store_device="cpu")
102
+
103
+ # Cache structure: train_cache[pair_idx][layer_str] = {pos: tensor, neg: tensor, pos_tokens: int, neg_tokens: int}
104
+ train_cache = {}
105
+ test_cache = {}
106
+
107
+ layer_strs = [str(l) for l in layers_to_test]
108
+
109
+ # Collect training activations with full sequence
110
+ for pair_idx, pair in enumerate(train_pairs.pairs):
111
+ updated_pair = collector.collect_for_pair(
112
+ pair,
113
+ layers=layer_strs,
114
+ aggregation=None, # Get raw activations without aggregation
115
+ return_full_sequence=True, # Get all token positions
116
+ normalize_layers=False
117
+ )
118
+
119
+ train_cache[pair_idx] = {}
120
+ for layer_str in layer_strs:
121
+ train_cache[pair_idx][layer_str] = {
122
+ 'pos': updated_pair.positive_response.layers_activations.get(layer_str),
123
+ 'neg': updated_pair.negative_response.layers_activations.get(layer_str),
124
+ }
125
+
126
+ # Collect test activations
127
+ for pair_idx, pair in enumerate(test_pairs.pairs):
128
+ updated_pair = collector.collect_for_pair(
129
+ pair,
130
+ layers=layer_strs,
131
+ aggregation=None,
132
+ return_full_sequence=True,
133
+ normalize_layers=False
134
+ )
135
+
136
+ test_cache[pair_idx] = {}
137
+ for layer_str in layer_strs:
138
+ test_cache[pair_idx][layer_str] = {
139
+ 'pos': updated_pair.positive_response.layers_activations.get(layer_str),
140
+ 'neg': updated_pair.negative_response.layers_activations.get(layer_str),
141
+ }
142
+
143
+ print(f" ✓ Cached activations for {len(train_cache)} train and {len(test_cache)} test pairs")
144
+
145
+ # STEP 2: Apply different aggregation strategies to cached activations
146
+ print(f" 🔍 Testing {len(layers_to_test) * len(args.aggregation_methods)} layer/aggregation combinations...")
147
+
148
+ # Aggregation functions
149
+ def aggregate_activations(raw_acts, method):
150
+ """Apply aggregation to raw activation tensor."""
151
+ if raw_acts is None or raw_acts.numel() == 0:
152
+ return None
153
+
154
+ # Handle both 1D (already aggregated) and 2D (sequence, hidden_dim) tensors
155
+ if raw_acts.ndim == 1:
156
+ return raw_acts
157
+ elif raw_acts.ndim == 2:
158
+ if method == 'average':
159
+ return raw_acts.mean(dim=0)
160
+ elif method == 'final':
161
+ return raw_acts[-1]
162
+ elif method == 'first':
163
+ return raw_acts[0]
164
+ elif method == 'max':
165
+ return raw_acts.max(dim=0)[0]
166
+ elif method == 'min':
167
+ return raw_acts.min(dim=0)[0]
168
+ else:
169
+ # Flatten to 2D if needed
170
+ raw_acts = raw_acts.view(-1, raw_acts.shape[-1])
171
+ return aggregate_activations(raw_acts, method)
172
+
173
+ best_score = -1
174
+ best_config = None
175
+ best_classifier = None
176
+
177
+ combinations_tested = 0
178
+ total_combinations = len(layers_to_test) * len(args.aggregation_methods)
179
+
180
+ for layer in layers_to_test:
181
+ layer_str = str(layer)
182
+
183
+ for agg_method in args.aggregation_methods:
184
+ # Apply aggregation to cached activations
185
+ train_pos_acts = []
186
+ train_neg_acts = []
187
+
188
+ for pair_idx in train_cache:
189
+ pos_raw = train_cache[pair_idx][layer_str]['pos']
190
+ neg_raw = train_cache[pair_idx][layer_str]['neg']
191
+
192
+ pos_agg = aggregate_activations(pos_raw, agg_method)
193
+ neg_agg = aggregate_activations(neg_raw, agg_method)
194
+
195
+ if pos_agg is not None:
196
+ train_pos_acts.append(pos_agg.cpu().numpy())
197
+ if neg_agg is not None:
198
+ train_neg_acts.append(neg_agg.cpu().numpy())
199
+
200
+ if len(train_pos_acts) == 0 or len(train_neg_acts) == 0:
201
+ combinations_tested += 1
202
+ continue
203
+
204
+ # Prepare training data
205
+ X_train_pos = np.array(train_pos_acts)
206
+ X_train_neg = np.array(train_neg_acts)
207
+ X_train = np.vstack([X_train_pos, X_train_neg])
208
+ y_train = np.array([1] * len(train_pos_acts) + [0] * len(train_neg_acts))
209
+
210
+ # Train classifier
211
+ classifier = LogisticClassifier(threshold=0.5, device="cpu")
212
+
213
+ config = ClassifierTrainConfig(
214
+ test_size=0.2,
215
+ batch_size=32,
216
+ num_epochs=30,
217
+ learning_rate=0.001,
218
+ monitor="f1",
219
+ random_state=42
220
+ )
221
+
222
+ report = classifier.fit(
223
+ torch.tensor(X_train, dtype=torch.float32),
224
+ torch.tensor(y_train, dtype=torch.float32),
225
+ config=config
226
+ )
227
+
228
+ # Apply aggregation to test set
229
+ test_pos_acts = []
230
+ test_neg_acts = []
231
+
232
+ for pair_idx in test_cache:
233
+ pos_raw = test_cache[pair_idx][layer_str]['pos']
234
+ neg_raw = test_cache[pair_idx][layer_str]['neg']
235
+
236
+ pos_agg = aggregate_activations(pos_raw, agg_method)
237
+ neg_agg = aggregate_activations(neg_raw, agg_method)
238
+
239
+ if pos_agg is not None:
240
+ test_pos_acts.append(pos_agg.cpu().numpy())
241
+ if neg_agg is not None:
242
+ test_neg_acts.append(neg_agg.cpu().numpy())
243
+
244
+ if len(test_pos_acts) == 0 or len(test_neg_acts) == 0:
245
+ combinations_tested += 1
246
+ continue
247
+
248
+ X_test_pos = np.array(test_pos_acts)
249
+ X_test_neg = np.array(test_neg_acts)
250
+ X_test = np.vstack([X_test_pos, X_test_neg])
251
+ y_test = np.array([1] * len(test_pos_acts) + [0] * len(test_neg_acts))
252
+
253
+ # Get predictions
254
+ y_pred_proba = np.array(classifier.predict_proba(X_test))
255
+
256
+ # Test different thresholds
257
+ for threshold in args.threshold_range:
258
+ y_pred = (y_pred_proba > threshold).astype(int)
259
+
260
+ # Calculate metrics
261
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
262
+
263
+ accuracy = accuracy_score(y_test, y_pred)
264
+ f1 = f1_score(y_test, y_pred, zero_division=0)
265
+ precision = precision_score(y_test, y_pred, zero_division=0)
266
+ recall = recall_score(y_test, y_pred, zero_division=0)
267
+
268
+ # Choose metric based on args
269
+ metric_value = {
270
+ 'f1': f1,
271
+ 'accuracy': accuracy,
272
+ 'precision': precision,
273
+ 'recall': recall
274
+ }[args.optimization_metric]
275
+
276
+ if metric_value > best_score:
277
+ best_score = metric_value
278
+ best_config = {
279
+ 'layer': layer,
280
+ 'aggregation': agg_method,
281
+ 'threshold': threshold,
282
+ 'accuracy': float(accuracy),
283
+ 'f1': float(f1),
284
+ 'precision': float(precision),
285
+ 'recall': float(recall)
286
+ }
287
+ best_classifier = classifier
288
+
289
+ combinations_tested += 1
290
+ print(f" Progress: {combinations_tested}/{total_combinations} combinations tested", end='\r')
291
+
292
+ print(f"\n ✅ Best config: layer={best_config['layer']}, agg={best_config['aggregation']}, thresh={best_config['threshold']:.2f}")
293
+ print(f" Metrics: acc={best_config['accuracy']:.3f}, f1={best_config['f1']:.3f}, prec={best_config['precision']:.3f}, rec={best_config['recall']:.3f}")
294
+
295
+ all_results[task_name] = best_config
296
+
297
+ # Note: Classifier saving disabled due to missing .save() method
298
+ # Can be enabled once proper serialization is implemented
299
+
300
+ task_time = time.time() - task_start_time
301
+ print(f" ⏱️ Task completed in {task_time:.1f}s")
302
+
303
+ except Exception as e:
304
+ print(f" ❌ Failed to optimize {task_name}: {e}")
305
+ import traceback
306
+ traceback.print_exc()
307
+ continue
308
+
309
+ # 7. Save results
310
+ print(f"\n{'='*80}")
311
+ print(f"📊 OPTIMIZATION COMPLETE")
312
+ print(f"{'='*80}\n")
313
+
314
+ results_file = args.results_file or f"./optimization_results/classification_results.json"
315
+ import os
316
+ os.makedirs(os.path.dirname(results_file) if os.path.dirname(results_file) else ".", exist_ok=True)
317
+
318
+ output_data = {
319
+ 'model': args.model,
320
+ 'optimization_metric': args.optimization_metric,
321
+ 'layer_range': f"{layers_to_test[0]}-{layers_to_test[-1]}",
322
+ 'aggregation_methods': args.aggregation_methods,
323
+ 'threshold_range': args.threshold_range,
324
+ 'tasks': all_results,
325
+ 'classifiers_saved': classifiers_saved
326
+ }
327
+
328
+ with open(results_file, 'w') as f:
329
+ json.dump(output_data, f, indent=2)
330
+
331
+ print(f"✅ Results saved to: {results_file}\n")
332
+
333
+ # Print summary
334
+ print("📋 SUMMARY BY TASK:")
335
+ print("-" * 80)
336
+ for task_name, config in all_results.items():
337
+ print(f" {task_name:20s} | Layer: {config['layer']:2d} | Agg: {config['aggregation']:8s} | Thresh: {config['threshold']:.2f} | F1: {config['f1']:.3f}")
338
+ print("-" * 80 + "\n")
339
+