wisent 0.7.701__py3-none-any.whl → 0.7.901__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (330)
  1. wisent/__init__.py +1 -1
  2. wisent/core/activations/activation_cache.py +393 -0
  3. wisent/core/activations/activations.py +3 -3
  4. wisent/core/activations/activations_collector.py +9 -5
  5. wisent/core/activations/classifier_inference_strategy.py +12 -11
  6. wisent/core/activations/extraction_strategy.py +256 -84
  7. wisent/core/classifiers/classifiers/core/atoms.py +3 -2
  8. wisent/core/cli/__init__.py +2 -1
  9. wisent/core/cli/agent/apply_steering.py +5 -7
  10. wisent/core/cli/agent/train_classifier.py +19 -7
  11. wisent/core/cli/check_linearity.py +35 -3
  12. wisent/core/cli/cluster_benchmarks.py +4 -6
  13. wisent/core/cli/create_steering_vector.py +6 -4
  14. wisent/core/cli/diagnose_vectors.py +7 -4
  15. wisent/core/cli/estimate_unified_goodness_time.py +6 -4
  16. wisent/core/cli/generate_pairs_from_task.py +9 -56
  17. wisent/core/cli/geometry_search.py +137 -0
  18. wisent/core/cli/get_activations.py +1 -1
  19. wisent/core/cli/method_optimizer.py +4 -3
  20. wisent/core/cli/modify_weights.py +3 -2
  21. wisent/core/cli/optimize_sample_size.py +1 -1
  22. wisent/core/cli/optimize_steering.py +14 -16
  23. wisent/core/cli/optimize_weights.py +2 -1
  24. wisent/core/cli/preview_pairs.py +203 -0
  25. wisent/core/cli/steering_method_trainer.py +3 -3
  26. wisent/core/cli/tasks.py +19 -76
  27. wisent/core/cli/train_unified_goodness.py +3 -3
  28. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
  29. wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
  30. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
  31. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
  32. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
  33. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
  34. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
  35. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
  36. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
  37. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
  38. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
  39. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
  40. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
  41. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
  42. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
  43. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
  44. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
  45. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
  46. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
  47. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
  48. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
  49. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
  50. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
  51. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
  52. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
  53. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
  54. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
  55. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
  56. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
  57. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
  58. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
  59. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
  60. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
  61. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
  62. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
  63. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
  64. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
  65. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
  66. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
  67. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
  68. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
  69. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
  70. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
  71. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
  72. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
  73. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
  74. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
  75. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
  76. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
  77. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
  78. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
  79. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
  80. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
  81. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
  82. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
  83. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
  84. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +2 -2
  85. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
  86. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
  87. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
  88. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
  89. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
  90. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
  91. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +2 -2
  92. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
  93. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
  94. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
  95. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
  96. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
  102. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
  103. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
  104. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
  105. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
  106. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
  107. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
  108. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
  109. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
  110. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
  111. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
  112. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
  113. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
  114. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
  115. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
  116. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
  117. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
  118. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
  119. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
  120. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
  121. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
  122. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
  123. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
  124. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
  125. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
  126. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
  127. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
  128. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
  129. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
  130. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
  131. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
  132. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
  133. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
  134. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
  135. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
  136. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
  137. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
  138. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
  139. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
  140. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
  141. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
  142. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
  143. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
  144. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
  145. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
  146. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
  147. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
  148. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
  149. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
  150. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
  151. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
  152. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
  153. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
  154. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
  155. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
  156. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
  157. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
  158. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
  159. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
  160. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
  161. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
  162. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
  163. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
  164. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
  165. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
  166. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
  167. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
  168. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
  169. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
  170. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
  171. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
  172. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
  173. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
  174. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
  175. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
  176. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
  177. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
  178. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
  179. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
  180. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
  181. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
  182. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
  183. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
  184. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
  185. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
  186. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
  187. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
  188. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
  189. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
  190. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
  191. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
  192. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
  193. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
  194. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
  195. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
  196. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
  197. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
  198. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
  199. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
  200. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
  201. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
  202. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
  203. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
  204. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
  205. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
  206. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
  207. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
  208. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
  209. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
  210. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
  211. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
  212. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
  213. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
  214. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
  215. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
  216. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
  217. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
  218. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
  219. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
  220. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
  221. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
  222. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
  223. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
  224. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
  225. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
  226. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
  227. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
  228. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
  229. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
  230. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
  231. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
  232. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
  233. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
  234. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
  235. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
  236. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
  237. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
  238. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
  239. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
  240. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
  241. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
  242. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
  243. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
  244. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
  245. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
  246. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
  247. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
  248. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
  249. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
  250. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
  251. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +1 -3
  252. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
  253. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
  254. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
  255. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
  256. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
  257. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
  258. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
  259. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
  260. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
  261. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
  262. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
  263. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
  264. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
  265. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
  266. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
  267. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
  268. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
  269. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
  270. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
  271. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
  272. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
  273. wisent/core/data_loaders/loaders/lm_loader.py +12 -1
  274. wisent/core/geometry_runner.py +995 -0
  275. wisent/core/geometry_search_space.py +237 -0
  276. wisent/core/hyperparameter_optimizer.py +1 -1
  277. wisent/core/main.py +3 -0
  278. wisent/core/models/core/atoms.py +5 -3
  279. wisent/core/models/wisent_model.py +1 -1
  280. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
  281. wisent/core/parser_arguments/check_linearity_parser.py +12 -2
  282. wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
  283. wisent/core/parser_arguments/generate_vector_from_task_parser.py +2 -2
  284. wisent/core/parser_arguments/geometry_search_parser.py +61 -0
  285. wisent/core/parser_arguments/main_parser.py +8 -0
  286. wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
  287. wisent/core/steering.py +5 -3
  288. wisent/core/steering_methods/methods/hyperplane.py +2 -1
  289. wisent/core/synthetic/generators/nonsense_generator.py +30 -18
  290. wisent/core/trainers/steering_trainer.py +2 -2
  291. wisent/core/utils/device.py +27 -27
  292. wisent/core/utils/layer_combinations.py +70 -0
  293. wisent/examples/__init__.py +1 -0
  294. wisent/examples/scripts/__init__.py +1 -0
  295. wisent/examples/scripts/count_all_benchmarks.py +121 -0
  296. wisent/examples/scripts/discover_directions.py +469 -0
  297. wisent/examples/scripts/extract_benchmark_info.py +71 -0
  298. wisent/examples/scripts/generate_paper_data.py +384 -0
  299. wisent/examples/scripts/intervention_validation.py +626 -0
  300. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +324 -0
  301. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +92 -0
  302. wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +324 -0
  303. wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +92 -0
  304. wisent/examples/scripts/results/test_afrimgsm_pairs.json +92 -0
  305. wisent/examples/scripts/results/test_afrimmlu_evaluation.json +324 -0
  306. wisent/examples/scripts/results/test_afrimmlu_pairs.json +92 -0
  307. wisent/examples/scripts/search_all_short_names.py +31 -0
  308. wisent/examples/scripts/test_all_benchmarks.py +138 -0
  309. wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
  310. wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
  311. wisent/examples/scripts/test_nonsense_baseline.py +261 -0
  312. wisent/examples/scripts/test_one_benchmark.py +324 -0
  313. wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
  314. wisent/examples/scripts/threshold_analysis.py +434 -0
  315. wisent/examples/scripts/visualization_gallery.py +582 -0
  316. wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
  317. wisent/parameters/lm_eval/category_directions.json +137 -0
  318. wisent/parameters/lm_eval/repair_plan.json +282 -0
  319. wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
  320. wisent/parameters/lm_eval/working_benchmarks.json +206 -0
  321. wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
  322. wisent/tests/test_detector_accuracy.py +1 -1
  323. wisent/tests/visualize_geometry.py +1 -1
  324. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/METADATA +1 -1
  325. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/RECORD +329 -295
  326. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
  327. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/WHEEL +0 -0
  328. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/entry_points.txt +0 -0
  329. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/licenses/LICENSE +0 -0
  330. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/top_level.txt +0 -0
@@ -33,7 +33,6 @@ STRATEGIES = [
     "chat_mean",
     "chat_first",
     "chat_last",
-    "chat_gen_point",
     "chat_max_norm",
     "chat_weighted",
     "role_play",
@@ -134,9 +133,9 @@ def get_weighted_mean_answer_act(model, tokenizer, text: str, answer: str, layer
     hidden = outputs.hidden_states[layer][0]
     if num_answer_tokens > 0 and num_answer_tokens < hidden.shape[0]:
         answer_hidden = hidden[-num_answer_tokens-1:-1, :]
-        weights = torch.exp(-torch.arange(answer_hidden.shape[0], dtype=torch.float32) * 0.5)
+        weights = torch.exp(-torch.arange(answer_hidden.shape[0], dtype=answer_hidden.dtype, device=answer_hidden.device) * 0.5)
         weights = weights / weights.sum()
-        weighted_mean = (answer_hidden * weights.unsqueeze(1).to(answer_hidden.device)).sum(dim=0)
+        weighted_mean = (answer_hidden * weights.unsqueeze(1)).sum(dim=0)
         return weighted_mean.cpu().float()
     return hidden[-1].cpu().float()

@@ -156,8 +155,6 @@ def get_activation(model, tokenizer, prompt: str, response: str, layer: int, dev
         return get_first_answer_token_act(model, tokenizer, text, response, layer, device)
     elif strategy == "chat_last":
         return get_last_token_act(model, tokenizer, text, layer, device)
-    elif strategy == "chat_gen_point":
-        return get_generation_point_act(model, tokenizer, text, response, layer, device)
     elif strategy == "chat_max_norm":
         return get_max_norm_answer_act(model, tokenizer, text, response, layer, device)
     elif strategy == "chat_weighted":
@@ -348,7 +345,8 @@ def execute_cluster_benchmarks(args):

     logger.info(f"Loading {model}...")
     tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
-    dtype = torch.bfloat16 if device == 'cuda' else torch.float16
+    from wisent.core.utils.device import device_optimized_dtype
+    dtype = device_optimized_dtype(device)
     llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=dtype, device_map=device, trust_remote_code=True)

     layers = get_layers_to_test(llm)
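
Note: several call sites in this release replace hard-coded dtype logic with helpers from wisent/core/utils/device.py (+27 -27 above). A minimal sketch of what device_optimized_dtype presumably does, mirroring the inline expression it replaces in this hunk; the shipped helper may pick float32 on CPU or special-case MPS:

import torch

def device_optimized_dtype(device: str) -> torch.dtype:
    # Sketch only: mirrors the old inline expression `bfloat16 if cuda else float16`.
    # The actual implementation in wisent/core/utils/device.py is not shown in this diff.
    return torch.bfloat16 if device == "cuda" else torch.float16
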
@@ -8,6 +8,7 @@ import torch
 from collections import defaultdict

 from wisent.core.errors import SteeringMethodUnknownError, VectorQualityTooLowError
+from wisent.core.utils.device import preferred_dtype


 def execute_create_steering_vector(args):
@@ -46,20 +47,21 @@ def execute_create_steering_vector(args):

     # Structure: {layer_str: {"positive": [tensors], "negative": [tensors]}}
     layer_activations = defaultdict(lambda: {"positive": [], "negative": []})
+    dtype = preferred_dtype()

     for pair in pairs_list:
         # Extract positive activations
         pos_layers = pair['positive_response'].get('layers_activations', {})
         for layer_str, activation_list in pos_layers.items():
             if activation_list is not None:
-                tensor = torch.tensor(activation_list, dtype=torch.float32)
+                tensor = torch.tensor(activation_list, dtype=dtype)
                 layer_activations[layer_str]["positive"].append(tensor)

         # Extract negative activations
         neg_layers = pair['negative_response'].get('layers_activations', {})
         for layer_str, activation_list in neg_layers.items():
             if activation_list is not None:
-                tensor = torch.tensor(activation_list, dtype=torch.float32)
+                tensor = torch.tensor(activation_list, dtype=dtype)
                 layer_activations[layer_str]["negative"].append(tensor)

     available_layers = sorted(layer_activations.keys(), key=int)
@@ -232,7 +234,7 @@ def execute_create_steering_vector(args):
     # If multiple layers, save the first one (or could save all and let user specify)
     if len(steering_vectors) == 1:
         layer_str = list(steering_vectors.keys())[0]
-        vector_tensor = torch.tensor(steering_vectors[layer_str], dtype=torch.float32)
+        vector_tensor = torch.tensor(steering_vectors[layer_str], dtype=dtype)
         torch.save({
             'steering_vector': vector_tensor,
             'layer_index': int(layer_str),
@@ -251,7 +253,7 @@ def execute_create_steering_vector(args):
     # Save multiple layers - save each to separate file
     for layer_str in steering_vectors.keys():
         layer_output = args.output.replace('.pt', f'_layer_{layer_str}.pt')
-        vector_tensor = torch.tensor(steering_vectors[layer_str], dtype=torch.float32)
+        vector_tensor = torch.tensor(steering_vectors[layer_str], dtype=dtype)
         torch.save({
             'steering_vector': vector_tensor,
             'layer_index': int(layer_str),
@@ -6,6 +6,7 @@ import os
 import math

 import torch
+from wisent.core.utils.device import preferred_dtype


 def execute_diagnose_vectors(args):
@@ -227,10 +228,11 @@ def _run_cone_analysis(
         return

     # Convert to tensors if needed
+    dtype = preferred_dtype()
     if not isinstance(pos_acts, torch.Tensor):
-        pos_acts = torch.tensor(pos_acts, dtype=torch.float32)
+        pos_acts = torch.tensor(pos_acts, dtype=dtype)
     if not isinstance(neg_acts, torch.Tensor):
-        neg_acts = torch.tensor(neg_acts, dtype=torch.float32)
+        neg_acts = torch.tensor(neg_acts, dtype=dtype)

     print(f" Positive samples: {pos_acts.shape[0]}")
     print(f" Negative samples: {neg_acts.shape[0]}")
@@ -342,10 +344,11 @@ def _run_geometry_analysis(
         return

     # Convert to tensors
+    dtype = preferred_dtype()
     if not isinstance(pos_acts, torch.Tensor):
-        pos_acts = torch.tensor(pos_acts, dtype=torch.float32)
+        pos_acts = torch.tensor(pos_acts, dtype=dtype)
     if not isinstance(neg_acts, torch.Tensor):
-        neg_acts = torch.tensor(neg_acts, dtype=torch.float32)
+        neg_acts = torch.tensor(neg_acts, dtype=dtype)

     print(f" Positive samples: {pos_acts.shape[0]}")
     print(f" Negative samples: {neg_acts.shape[0]}")
@@ -141,8 +141,10 @@ def estimate_runtime(
     results = {}

     # 1. Model loading (one-time)
-    if device == 'cpu':
-        model_time = TIME_ESTIMATES['model_load_cpu']
+    if device == 'cpu' or device == 'auto':
+        from wisent.core.utils.device import resolve_default_device
+        actual_device = resolve_default_device() if device == 'auto' else device
+        model_time = TIME_ESTIMATES['model_load_cpu'] if actual_device == 'cpu' else TIME_ESTIMATES['model_load_gpu']
     else:
         model_time = TIME_ESTIMATES['model_load_gpu']
     results['model_loading'] = model_time
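
Note: the new 'auto' device option resolves through resolve_default_device(). A hedged sketch of its assumed contract (prefer CUDA, then MPS, then CPU), consistent with the '--device' choices added below; the actual implementation in wisent/core/utils/device.py is not shown in this diff:

import torch

def resolve_default_device() -> str:
    # Assumed detection order; details may differ in the shipped device.py.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
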
@@ -269,8 +271,8 @@ def main():
         help="Skip evaluation phase"
     )
     parser.add_argument(
-        "--device", choices=["cuda", "cpu"], default="cuda",
-        help="Device for computation"
+        "--device", choices=["cuda", "cpu", "mps", "auto"], default="auto",
+        help="Device for computation (auto = detect best available)"
     )
     parser.add_argument(
         "--show-breakdown", action="store_true",
@@ -4,8 +4,6 @@ import sys
 import json
 import os

-from wisent.core.errors import InvalidDataFormatError
-

 def execute_generate_pairs_from_task(args):
     """Execute the generate-pairs-from-task command - load and save contrastive pairs from a task."""
@@ -14,9 +12,8 @@ def execute_generate_pairs_from_task(args):
     if hasattr(args, 'task_name') and args.task_name:
         args.task_name = expand_task_if_skill_or_risk(args.task_name)

-    from wisent.core.contrastive_pairs.huggingface_pairs.hf_extractor_manifest import HF_EXTRACTORS
     from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_pairs_generation import (
-        lm_build_contrastive_pairs,
+        build_contrastive_pairs,
     )

     print(f"\nšŸ“Š Generating contrastive pairs from task: {args.task_name}")
@@ -26,58 +23,14 @@ def execute_generate_pairs_from_task(args):

     try:
         print(f"\nšŸ”„ Loading task '{args.task_name}'...")
-
-        # Check if task is in HuggingFace manifest (doesn't need lm-eval loading)
-        task_name_lower = args.task_name.lower()
-        is_hf_task = task_name_lower in {k.lower() for k in HF_EXTRACTORS.keys()}
-
-        if is_hf_task:
-            # HuggingFace task - skip lm-eval loading, go directly to extractor
-            print(f" Found in HuggingFace manifest, using HF extractor...")
-            print(f" šŸ”Ø Building contrastive pairs...")
-            pairs = lm_build_contrastive_pairs(
-                task_name=args.task_name,
-                lm_eval_task=None,  # HF extractors don't need lm_eval_task
-                limit=args.limit,
-            )
-            pairs_task_name = args.task_name
-        else:
-            # lm-eval task - load via LMEvalDataLoader
-            from wisent.core.data_loaders.loaders.lm_loader import LMEvalDataLoader
-            loader = LMEvalDataLoader()
-            task_obj = loader.load_lm_eval_task(args.task_name)
-
-            # Handle both lm-eval tasks (dict or ConfigurableTask)
-            if isinstance(task_obj, dict):
-                # lm-eval task group with subtasks
-                if len(task_obj) != 1:
-                    keys = ", ".join(sorted(task_obj.keys()))
-                    raise InvalidDataFormatError(
-                        reason=f"Task '{args.task_name}' returned {len(task_obj)} subtasks ({keys}). "
-                        "Specify an explicit subtask, e.g. 'benchmark/subtask'."
-                    )
-                (subname, task), = task_obj.items()
-                pairs_task_name = subname
-
-                # Generate contrastive pairs using lm-eval interface
-                print(f" šŸ”Ø Building contrastive pairs...")
-                pairs = lm_build_contrastive_pairs(
-                    task_name=pairs_task_name,
-                    lm_eval_task=task,
-                    limit=args.limit,
-                )
-            else:
-                # Single lm-eval task (ConfigurableTask), not wrapped in dict
-                task = task_obj
-                pairs_task_name = args.task_name
-
-                # Generate contrastive pairs using lm-eval interface
-                print(f" šŸ”Ø Building contrastive pairs...")
-                pairs = lm_build_contrastive_pairs(
-                    task_name=pairs_task_name,
-                    lm_eval_task=task,
-                    limit=args.limit,
-                )
+        print(f" šŸ”Ø Building contrastive pairs...")
+
+        # Use unified loader - handles HF, lm-eval, and group tasks automatically
+        pairs = build_contrastive_pairs(
+            task_name=args.task_name,
+            limit=args.limit,
+        )
+        pairs_task_name = args.task_name

     print(f" āœ“ Generated {len(pairs)} contrastive pairs")

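Note: the branching removed above is absorbed by the unified loader added to lm_task_pairs_generation.py (+173 -6 in the file list). A usage sketch based only on the call visible in this hunk; the task name and values are illustrative, and any further keyword arguments are not shown in this diff:

from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_pairs_generation import build_contrastive_pairs

# Per the comment in the diff, the loader dispatches to HF extractors,
# single lm-eval tasks, and task groups internally.
pairs = build_contrastive_pairs(task_name="truthfulqa_mc1", limit=50)
print(f"Generated {len(pairs)} contrastive pairs")
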
@@ -0,0 +1,137 @@
+"""Run geometry search across benchmarks to find unified goodness direction."""
+
+import json
+import sys
+import os
+from pathlib import Path
+
+
+def execute_geometry_search(args):
+    """Execute the geometry-search command."""
+    print(f"\n{'='*60}")
+    print("GEOMETRY SEARCH")
+    print(f"{'='*60}")
+    print(f"Model: {args.model}")
+    print(f"Output: {args.output}")
+    print(f"Pairs per benchmark: {args.pairs_per_benchmark}")
+    print(f"Max layer combo size: {args.max_layer_combo_size}")
+
+    # Import dependencies
+    from wisent.core.models.wisent_model import WisentModel
+    from wisent.core.geometry_search_space import GeometrySearchSpace, GeometrySearchConfig
+    from wisent.core.geometry_runner import GeometryRunner
+    from wisent.core.activations.extraction_strategy import ExtractionStrategy
+
+    # Parse strategies
+    if args.strategies:
+        strategy_names = [s.strip() for s in args.strategies.split(',')]
+        strategies = [ExtractionStrategy(s) for s in strategy_names]
+        print(f"Strategies: {strategy_names}")
+    else:
+        strategies = None  # Use default (all 7)
+        print("Strategies: all 7 default strategies")
+
+    # Parse benchmarks
+    if args.benchmarks:
+        if args.benchmarks.endswith('.txt'):
+            with open(args.benchmarks) as f:
+                benchmarks = [line.strip() for line in f if line.strip()]
+        else:
+            benchmarks = [b.strip() for b in args.benchmarks.split(',')]
+        print(f"Benchmarks: {len(benchmarks)} specified")
+    else:
+        benchmarks = None  # Use default (all)
+        print("Benchmarks: all available")
+
+    # Create config
+    config = GeometrySearchConfig(
+        pairs_per_benchmark=args.pairs_per_benchmark,
+        max_layer_combo_size=args.max_layer_combo_size,
+        random_seed=args.seed,
+        cache_activations=True,
+        cache_dir=args.cache_dir,
+    )
+
+    # Create search space
+    search_space = GeometrySearchSpace(
+        models=[args.model],
+        strategies=strategies,
+        benchmarks=benchmarks,
+        config=config,
+    )
+
+    print(f"\n{search_space.summary()}")
+
+    # Load model
+    print(f"\nLoading model {args.model}...")
+    model = WisentModel(args.model, device=args.device)
+    print(f"Model loaded: {model.num_layers} layers, hidden_size={model.hidden_size}")
+
+    # Create runner
+    cache_dir = args.cache_dir or f"/tmp/wisent_geometry_cache_{args.model.replace('/', '_')}"
+    runner = GeometryRunner(search_space, model, cache_dir=cache_dir)
+
+    # Run search
+    print(f"\nStarting geometry search...")
+    results = runner.run(show_progress=True)
+
+    # Save results
+    output_path = Path(args.output)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    results.save(str(output_path))
+    print(f"\nResults saved to: {output_path}")
+
+    # Print summary
+    print(f"\n{'='*60}")
+    print("SUMMARY")
+    print(f"{'='*60}")
+    print(f"Total time: {results.total_time_seconds / 3600:.2f} hours")
+    print(f" Extraction: {results.extraction_time_seconds / 3600:.2f} hours")
+    print(f" Testing: {results.test_time_seconds / 60:.1f} minutes")
+    print(f"Benchmarks tested: {results.benchmarks_tested}")
+    print(f"Strategies tested: {results.strategies_tested}")
+    print(f"Layer combos tested: {results.layer_combos_tested}")
+
+    print(f"\nStructure distribution:")
+    for struct, count in sorted(results.get_structure_distribution().items(), key=lambda x: -x[1]):
+        pct = 100 * count / results.layer_combos_tested
+        print(f" {struct}: {count} ({pct:.1f}%)")
+
+    print(f"\nTop 10 by linear score:")
+    for r in results.get_best_by_linear_score(10):
+        print(f" {r.benchmark}/{r.strategy} layers={r.layers}: linear={r.linear_score:.3f} best={r.best_structure}")
+
+    print(f"\nTop 10 by cone score:")
+    for r in results.get_best_by_structure('cone', 10):
+        print(f" {r.benchmark}/{r.strategy} layers={r.layers}: cone={r.cone_score:.3f} best={r.best_structure}")
+
+    # Summary by benchmark
+    print(f"\nSummary by benchmark (avg linear score):")
+    by_bench = results.get_summary_by_benchmark()
+    sorted_benches = sorted(by_bench.items(), key=lambda x: -x[1]['mean'])[:20]
+    for bench, stats in sorted_benches:
+        print(f" {bench}: mean={stats['mean']:.3f} max={stats['max']:.3f}")
+
+    print(f"\n{'='*60}")
+    print("CONCLUSION")
+    print(f"{'='*60}")
+
+    # Determine if unified direction exists
+    dist = results.get_structure_distribution()
+    total = sum(dist.values())
+    linear_pct = 100 * dist.get('linear', 0) / total if total > 0 else 0
+    cone_pct = 100 * dist.get('cone', 0) / total if total > 0 else 0
+    orthogonal_pct = 100 * dist.get('orthogonal', 0) / total if total > 0 else 0
+
+    if linear_pct > 50:
+        print(f"UNIFIED LINEAR DIRECTION EXISTS ({linear_pct:.1f}% linear)")
+        print("Recommendation: Use CAA with the best layer/strategy combination")
+    elif cone_pct > 30:
+        print(f"CONE STRUCTURE DETECTED ({cone_pct:.1f}% cone)")
+        print("Recommendation: Use PRISM with multi-directional steering")
+    elif orthogonal_pct > 50:
+        print(f"ORTHOGONAL STRUCTURE ({orthogonal_pct:.1f}% orthogonal)")
+        print("Recommendation: No unified direction - use per-benchmark directions or TITAN")
+    else:
+        print("MIXED STRUCTURE - no clear unified direction")
+        print("Recommendation: Use TITAN for adaptive multi-component steering")
@@ -90,7 +90,7 @@ def execute_get_activations(args):

     # 6. Collect activations
     print(f"\n⚔ Collecting activations...")
-    collector = ActivationCollector(model=model, store_device="cpu")
+    collector = ActivationCollector(model=model)

     enriched_pairs = []
     for i, pair in enumerate(pair_set.pairs):
@@ -24,6 +24,7 @@ import torch
 from wisent.core.activations.activations_collector import ActivationCollector
 from wisent.core.activations.extraction_strategy import ExtractionStrategy
 from wisent.core.activations.core.atoms import LayerActivations
+from wisent.core.utils.device import resolve_default_device

 from wisent.core.contrastive_pairs.core.pair import ContrastivePair
 from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
@@ -175,7 +176,7 @@ class MethodOptimizer:
         self,
         model,
         method_name: str,
-        device: str = "cpu",
+        device: str | None = None,
         verbose: bool = True,
     ):
         """
@@ -189,7 +190,7 @@ class MethodOptimizer:
         """
         self.model = model
         self.method_name = method_name.lower()
-        self.device = device
+        self.device = device or resolve_default_device()
         self.verbose = verbose

         # Validate method exists
@@ -250,7 +251,7 @@ class MethodOptimizer:
            "mean_pooling": ExtractionStrategy.CHAT_MEAN,
            "first_token": ExtractionStrategy.CHAT_FIRST,
            "max_pooling": ExtractionStrategy.CHAT_MAX_NORM,
-           "continuation_token": ExtractionStrategy.CHAT_GEN_POINT,
+           "continuation_token": ExtractionStrategy.CHAT_FIRST,  # First answer token
        }

        prompt_strat_map = {
@@ -14,6 +14,7 @@ import time
 from pathlib import Path
 import torch

+from wisent.core.utils.device import resolve_default_device
 from wisent.core.cli_logger import setup_logger, bind
 from wisent.core.models.wisent_model import WisentModel
 from wisent.core.weight_modification import (
@@ -72,7 +73,7 @@ def execute_modify_weights(args):

     if vector_path.suffix == '.pt':
         # Load PyTorch format (from train-unified-goodness or similar)
-        checkpoint = torch.load(args.steering_vectors, map_location='cpu', weights_only=False)
+        checkpoint = torch.load(args.steering_vectors, map_location=resolve_default_device(), weights_only=False)

         # Handle different .pt file formats
         if 'steering_vectors' in checkpoint:
@@ -354,7 +355,7 @@ def execute_modify_weights(args):

     execute_train_unified_goodness(unified_args)

-    checkpoint = torch.load(unified_args.output, map_location='cpu', weights_only=False)
+    checkpoint = torch.load(unified_args.output, map_location=resolve_default_device(), weights_only=False)

     if 'steering_vectors' in checkpoint:
         raw_vectors = checkpoint['steering_vectors']
@@ -87,7 +87,7 @@ def execute_optimize_sample_size(args):
     # Get extraction strategy from args
     extraction_strategy = ExtractionStrategy(getattr(args, 'extraction_strategy', 'chat_last'))

-    collector = ActivationCollector(model=model, store_device="cpu")
+    collector = ActivationCollector(model=model)

     # Collect test activations for all test pairs (ONCE)
     X_test_list = []
@@ -77,7 +77,7 @@ def _run_optuna_search_for_task(

     try:
         # Collect activations
-        collector = ActivationCollector(model=model, store_device="cpu")
+        collector = ActivationCollector(model=model)
         pos_acts = []
         neg_acts = []

@@ -389,7 +389,7 @@ def execute_comprehensive(args, model, loader):
        "first_token": ExtractionStrategy.CHAT_FIRST,
        "max_pooling": ExtractionStrategy.CHAT_MAX_NORM,
        "choice_token": ExtractionStrategy.MC_BALANCED,
-       "continuation_token": ExtractionStrategy.CHAT_GEN_POINT,
+       "continuation_token": ExtractionStrategy.CHAT_FIRST,  # First answer token
     }
     if hasattr(args, 'search_token_aggregations') and args.search_token_aggregations:
         token_agg_names = [x.strip() for x in args.search_token_aggregations.split(',')]
@@ -610,7 +610,7 @@ def execute_comprehensive(args, model, loader):
     layer_str = str(layer)

     # Step 1: Generate steering vector using CAA with current token aggregation
-    collector = ActivationCollector(model=model, store_device="cpu")
+    collector = ActivationCollector(model=model)

     pos_acts = []
     neg_acts = []
@@ -1456,7 +1456,7 @@ def execute_compare_methods(args, model, loader):

     # Collect activations once for all methods
     layer_str = str(args.layer)
-    collector = ActivationCollector(model=model, store_device="cpu")
+    collector = ActivationCollector(model=model)

     print("šŸŽÆ Collecting training activations (ONCE)...")
     pos_acts = []
@@ -1719,7 +1719,7 @@ def execute_optimize_layer(args, model, loader):
         print("Aborted by user.")
         return {"action": "optimize-layer", "status": "aborted", "reason": "user declined reduced search"}

-    collector = ActivationCollector(model=model, store_device="cpu")
+    collector = ActivationCollector(model=model)
     layer_results = {}
     best_layer = None
     best_accuracy = 0.0
@@ -1986,7 +1986,7 @@ def execute_optimize_strength(args, model, loader):

     # Collect activations ONCE
     layer_str = str(args.layer)
-    collector = ActivationCollector(model=model, store_device="cpu")
+    collector = ActivationCollector(model=model)

     print("šŸŽÆ Collecting training activations (ONCE)...")
     pos_acts = []
@@ -2277,7 +2277,7 @@ def execute_auto(args, model, loader):
     print(f" Testing {len(strengths_to_test)} strengths: {strengths_to_test[0]:.2f} to {strengths_to_test[-1]:.2f}")
     print(f" Total configurations: {len(layers_to_test) * len(strengths_to_test)}\n")

-    collector = ActivationCollector(model=model, store_device="cpu")
+    collector = ActivationCollector(model=model)
     all_results = {}
     best_config = None
     best_accuracy = 0.0
@@ -2575,16 +2575,15 @@ def execute_personalization(args, model):
     min_strength, max_strength = args.strength_range
     strengths_to_test = np.linspace(min_strength, max_strength, 7)

-    # Token aggregation strategies to test - ALL 5 strategies
+    # Token aggregation strategies to test
     token_aggregations_to_test = [
         ExtractionStrategy.CHAT_LAST,
         ExtractionStrategy.CHAT_MEAN,
         ExtractionStrategy.CHAT_FIRST,
         ExtractionStrategy.CHAT_MAX_NORM,
-        ExtractionStrategy.CHAT_GEN_POINT,
     ]

-    # Prompt construction strategies to test - ALL 5 strategies
+    # Prompt construction strategies to test
     prompt_constructions_to_test = [
         ExtractionStrategy.CHAT_LAST,
         ExtractionStrategy.CHAT_LAST,
@@ -2655,7 +2654,7 @@ def execute_personalization(args, model):
     print(flush=True)

     # Initialize activation collector
-    collector = ActivationCollector(model=model, store_device="cpu")
+    collector = ActivationCollector(model=model)

     # Track results for all configurations
     all_results = {}
@@ -3108,16 +3107,15 @@ def execute_multi_personalization(args, model):
     min_strength, max_strength = args.strength_range
     strengths_to_test = np.linspace(min_strength, max_strength, 7)

-    # Token aggregation strategies to test - ALL 5 strategies
+    # Token aggregation strategies to test
     token_aggregations_to_test = [
         ExtractionStrategy.CHAT_LAST,
         ExtractionStrategy.CHAT_MEAN,
         ExtractionStrategy.CHAT_FIRST,
         ExtractionStrategy.CHAT_MAX_NORM,
-        ExtractionStrategy.CHAT_GEN_POINT,
     ]

-    # Prompt construction strategies to test - ALL 5 strategies
+    # Prompt construction strategies to test
     prompt_constructions_to_test = [
         ExtractionStrategy.CHAT_LAST,
         ExtractionStrategy.CHAT_LAST,
@@ -3176,7 +3174,7 @@ def execute_multi_personalization(args, model):
     print(f"\nšŸ“ Test prompts: {test_prompts}", flush=True)

     # Initialize collector
-    collector = ActivationCollector(model=model, store_device="cpu")
+    collector = ActivationCollector(model=model)

     # Track results
     all_results = {}
@@ -3565,7 +3563,7 @@ def execute_universal(args, model, loader):
     optimizer = MethodOptimizer(
         model=model,
         method_name=method_name,
-        device=args.device or "cpu",
+        device=args.device if hasattr(args, "device") and args.device else None,
         verbose=args.verbose if hasattr(args, "verbose") else True,
     )

@@ -28,6 +28,7 @@ from dataclasses import dataclass
 from typing import Any, Callable

 import torch
+from wisent.core.utils.device import resolve_default_device


 def upload_to_s3(local_path: str, s3_bucket: str, s3_key: str) -> bool:
@@ -661,7 +662,7 @@ def _generate_steering_vectors(args, num_pairs: int, num_layers: int = None) ->
     execute_train_unified_goodness(vector_args)

     # Load the .pt file
-    checkpoint = torch.load(temp_output_pt, map_location='cpu', weights_only=False)
+    checkpoint = torch.load(temp_output_pt, map_location=resolve_default_device(), weights_only=False)

     # Handle different checkpoint formats
     if 'all_layer_vectors' in checkpoint: