wisent 0.7.701__py3-none-any.whl → 0.7.901__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (330) hide show
  1. wisent/__init__.py +1 -1
  2. wisent/core/activations/activation_cache.py +393 -0
  3. wisent/core/activations/activations.py +3 -3
  4. wisent/core/activations/activations_collector.py +9 -5
  5. wisent/core/activations/classifier_inference_strategy.py +12 -11
  6. wisent/core/activations/extraction_strategy.py +256 -84
  7. wisent/core/classifiers/classifiers/core/atoms.py +3 -2
  8. wisent/core/cli/__init__.py +2 -1
  9. wisent/core/cli/agent/apply_steering.py +5 -7
  10. wisent/core/cli/agent/train_classifier.py +19 -7
  11. wisent/core/cli/check_linearity.py +35 -3
  12. wisent/core/cli/cluster_benchmarks.py +4 -6
  13. wisent/core/cli/create_steering_vector.py +6 -4
  14. wisent/core/cli/diagnose_vectors.py +7 -4
  15. wisent/core/cli/estimate_unified_goodness_time.py +6 -4
  16. wisent/core/cli/generate_pairs_from_task.py +9 -56
  17. wisent/core/cli/geometry_search.py +137 -0
  18. wisent/core/cli/get_activations.py +1 -1
  19. wisent/core/cli/method_optimizer.py +4 -3
  20. wisent/core/cli/modify_weights.py +3 -2
  21. wisent/core/cli/optimize_sample_size.py +1 -1
  22. wisent/core/cli/optimize_steering.py +14 -16
  23. wisent/core/cli/optimize_weights.py +2 -1
  24. wisent/core/cli/preview_pairs.py +203 -0
  25. wisent/core/cli/steering_method_trainer.py +3 -3
  26. wisent/core/cli/tasks.py +19 -76
  27. wisent/core/cli/train_unified_goodness.py +3 -3
  28. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
  29. wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
  30. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
  31. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
  32. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
  33. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
  34. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
  35. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
  36. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
  37. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
  38. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
  39. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
  40. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
  41. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
  42. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
  43. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
  44. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
  45. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
  46. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
  47. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
  48. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
  49. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
  50. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
  51. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
  52. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
  53. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
  54. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
  55. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
  56. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
  57. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
  58. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
  59. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
  60. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
  61. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
  62. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
  63. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
  64. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
  65. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
  66. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
  67. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
  68. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
  69. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
  70. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
  71. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
  72. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
  73. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
  74. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
  75. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
  76. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
  77. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
  78. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
  79. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
  80. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
  81. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
  82. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
  83. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
  84. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +2 -2
  85. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
  86. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
  87. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
  88. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
  89. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
  90. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
  91. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +2 -2
  92. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
  93. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
  94. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
  95. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
  96. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
  102. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
  103. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
  104. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
  105. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
  106. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
  107. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
  108. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
  109. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
  110. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
  111. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
  112. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
  113. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
  114. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
  115. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
  116. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
  117. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
  118. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
  119. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
  120. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
  121. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
  122. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
  123. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
  124. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
  125. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
  126. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
  127. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
  128. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
  129. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
  130. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
  131. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
  132. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
  133. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
  134. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
  135. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
  136. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
  137. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
  138. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
  139. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
  140. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
  141. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
  142. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
  143. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
  144. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
  145. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
  146. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
  147. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
  148. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
  149. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
  150. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
  151. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
  152. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
  153. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
  154. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
  155. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
  156. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
  157. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
  158. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
  159. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
  160. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
  161. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
  162. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
  163. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
  164. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
  165. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
  166. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
  167. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
  168. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
  169. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
  170. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
  171. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
  172. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
  173. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
  174. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
  175. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
  176. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
  177. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
  178. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
  179. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
  180. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
  181. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
  182. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
  183. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
  184. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
  185. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
  186. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
  187. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
  188. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
  189. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
  190. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
  191. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
  192. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
  193. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
  194. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
  195. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
  196. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
  197. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
  198. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
  199. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
  200. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
  201. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
  202. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
  203. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
  204. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
  205. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
  206. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
  207. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
  208. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
  209. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
  210. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
  211. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
  212. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
  213. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
  214. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
  215. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
  216. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
  217. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
  218. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
  219. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
  220. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
  221. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
  222. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
  223. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
  224. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
  225. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
  226. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
  227. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
  228. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
  229. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
  230. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
  231. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
  232. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
  233. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
  234. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
  235. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
  236. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
  237. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
  238. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
  239. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
  240. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
  241. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
  242. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
  243. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
  244. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
  245. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
  246. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
  247. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
  248. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
  249. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
  250. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
  251. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +1 -3
  252. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
  253. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
  254. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
  255. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
  256. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
  257. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
  258. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
  259. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
  260. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
  261. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
  262. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
  263. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
  264. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
  265. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
  266. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
  267. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
  268. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
  269. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
  270. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
  271. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
  272. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
  273. wisent/core/data_loaders/loaders/lm_loader.py +12 -1
  274. wisent/core/geometry_runner.py +995 -0
  275. wisent/core/geometry_search_space.py +237 -0
  276. wisent/core/hyperparameter_optimizer.py +1 -1
  277. wisent/core/main.py +3 -0
  278. wisent/core/models/core/atoms.py +5 -3
  279. wisent/core/models/wisent_model.py +1 -1
  280. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
  281. wisent/core/parser_arguments/check_linearity_parser.py +12 -2
  282. wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
  283. wisent/core/parser_arguments/generate_vector_from_task_parser.py +2 -2
  284. wisent/core/parser_arguments/geometry_search_parser.py +61 -0
  285. wisent/core/parser_arguments/main_parser.py +8 -0
  286. wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
  287. wisent/core/steering.py +5 -3
  288. wisent/core/steering_methods/methods/hyperplane.py +2 -1
  289. wisent/core/synthetic/generators/nonsense_generator.py +30 -18
  290. wisent/core/trainers/steering_trainer.py +2 -2
  291. wisent/core/utils/device.py +27 -27
  292. wisent/core/utils/layer_combinations.py +70 -0
  293. wisent/examples/__init__.py +1 -0
  294. wisent/examples/scripts/__init__.py +1 -0
  295. wisent/examples/scripts/count_all_benchmarks.py +121 -0
  296. wisent/examples/scripts/discover_directions.py +469 -0
  297. wisent/examples/scripts/extract_benchmark_info.py +71 -0
  298. wisent/examples/scripts/generate_paper_data.py +384 -0
  299. wisent/examples/scripts/intervention_validation.py +626 -0
  300. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +324 -0
  301. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +92 -0
  302. wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +324 -0
  303. wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +92 -0
  304. wisent/examples/scripts/results/test_afrimgsm_pairs.json +92 -0
  305. wisent/examples/scripts/results/test_afrimmlu_evaluation.json +324 -0
  306. wisent/examples/scripts/results/test_afrimmlu_pairs.json +92 -0
  307. wisent/examples/scripts/search_all_short_names.py +31 -0
  308. wisent/examples/scripts/test_all_benchmarks.py +138 -0
  309. wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
  310. wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
  311. wisent/examples/scripts/test_nonsense_baseline.py +261 -0
  312. wisent/examples/scripts/test_one_benchmark.py +324 -0
  313. wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
  314. wisent/examples/scripts/threshold_analysis.py +434 -0
  315. wisent/examples/scripts/visualization_gallery.py +582 -0
  316. wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
  317. wisent/parameters/lm_eval/category_directions.json +137 -0
  318. wisent/parameters/lm_eval/repair_plan.json +282 -0
  319. wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
  320. wisent/parameters/lm_eval/working_benchmarks.json +206 -0
  321. wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
  322. wisent/tests/test_detector_accuracy.py +1 -1
  323. wisent/tests/visualize_geometry.py +1 -1
  324. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/METADATA +1 -1
  325. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/RECORD +329 -295
  326. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
  327. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/WHEEL +0 -0
  328. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/entry_points.txt +0 -0
  329. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/licenses/LICENSE +0 -0
  330. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,237 @@
1
+ """
2
+ Configuration for geometry search space.
3
+
4
+ Defines all parameters to search over when testing if a unified "goodness"
5
+ direction exists across benchmarks.
6
+
7
+ Strategy:
8
+ - Extract activations for ALL layers once per (benchmark, strategy) pair
9
+ - Cache activations to disk/memory
10
+ - Test all layer combinations from cached activations (fast, just tensor math)
11
+ - This reduces extraction time from O(layer_combos) to O(1) per benchmark
12
+ """
13
+
14
+ from dataclasses import dataclass, field
15
+ from typing import List, Optional, Dict, Any
16
+ from enum import Enum
17
+ from pathlib import Path
18
+ import json
19
+
20
+ from wisent.core.activations.extraction_strategy import ExtractionStrategy
21
+ from wisent.core.utils.layer_combinations import get_layer_combinations
22
+ from wisent.core.benchmark_registry import get_all_benchmarks
23
+ from wisent.core.activations.activation_cache import ActivationCache, CachedActivations
24
+
25
+
26
@dataclass
class GeometrySearchConfig:
    """Configuration for a single geometry search run.

    All fields hold plain values so the config can be round-tripped through
    ``to_dict`` / ``from_dict`` (and therefore through JSON).
    """

    # Pairs settings
    pairs_per_benchmark: int = 50
    random_seed: int = 42

    # Layer settings
    max_layer_combo_size: int = 3

    # Caching
    cache_activations: bool = True
    cache_dir: Optional[str] = None

    # Estimation
    estimated_time_per_extraction_seconds: float = 120.0  # ~2 min per (benchmark, strategy)

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a ``{name: value}`` dict, in declaration order.

        The keys are derived from the dataclass field table (the same table
        ``from_dict`` filters against), so a newly added field is serialized
        automatically instead of being silently dropped.
        """
        return {name: getattr(self, name) for name in self.__dataclass_fields__}

    @classmethod
    def from_dict(cls, data: Optional[Dict[str, Any]]) -> "GeometrySearchConfig":
        """Build a config from a dict, ignoring keys that are not fields.

        Args:
            data: Mapping of field names to values. ``None`` or an empty dict
                yields a config with all defaults (e.g. when a JSON document
                stores ``"config": null``).

        Returns:
            A new ``GeometrySearchConfig``; fields missing from ``data`` keep
            their default values.
        """
        if not data:
            return cls()
        known_fields = cls.__dataclass_fields__
        return cls(**{key: value for key, value in data.items() if key in known_fields})
57
+
58
+
59
class GeometrySearchSpace:
    """
    Search space configuration for geometry testing.

    Combines:
    - Models to test
    - Extraction strategies
    - Layer combinations
    - Benchmarks

    With activation caching:
    - Extract ALL layers once per (benchmark, strategy)
    - Test layer combinations from cache (no re-extraction needed)
    """

    # Default models to test
    DEFAULT_MODELS = [
        "meta-llama/Llama-3.2-1B-Instruct",
        "meta-llama/Llama-2-7b-chat-hf",
        "Qwen/Qwen3-8B",
        "openai/gpt-oss-20b",
    ]

    # Extraction strategies for instruct models
    INSTRUCT_STRATEGIES = [
        ExtractionStrategy.CHAT_MEAN,
        ExtractionStrategy.CHAT_FIRST,
        ExtractionStrategy.CHAT_LAST,
        ExtractionStrategy.CHAT_MAX_NORM,
        ExtractionStrategy.CHAT_WEIGHTED,
        ExtractionStrategy.ROLE_PLAY,
        ExtractionStrategy.MC_BALANCED,
    ]

    # Extraction strategies for base models
    BASE_STRATEGIES = [
        ExtractionStrategy.COMPLETION_LAST,
        ExtractionStrategy.COMPLETION_MEAN,
        ExtractionStrategy.MC_COMPLETION,
    ]

    def __init__(
        self,
        models: Optional[List[str]] = None,
        strategies: Optional[List[ExtractionStrategy]] = None,
        benchmarks: Optional[List[str]] = None,
        config: Optional[GeometrySearchConfig] = None,
    ):
        """
        Initialize the search space.

        Args:
            models: List of model names to test. Defaults to DEFAULT_MODELS.
            strategies: List of extraction strategies. Defaults to INSTRUCT_STRATEGIES.
            benchmarks: List of benchmarks. Defaults to all available benchmarks.
            config: Search configuration (pairs, caching, etc.)

        A ``None`` or empty list for any of the first three arguments selects
        the corresponding default.
        """
        # Copy the class-level defaults so that mutating an instance's lists
        # can never change the defaults seen by other instances.
        self.models = models or list(self.DEFAULT_MODELS)
        self.strategies = strategies or list(self.INSTRUCT_STRATEGIES)
        self.benchmarks = benchmarks or get_all_benchmarks()
        self.config = config or GeometrySearchConfig()

    def get_layer_combinations_for_model(self, model_name: str, num_layers: int) -> List[List[int]]:
        """
        Get all layer combinations to test for a given model.

        Args:
            model_name: Name of the model (currently not used by the computation)
            num_layers: Number of layers in the model

        Returns:
            List of layer combinations, as produced by ``get_layer_combinations``
            for ``num_layers`` and the configured ``max_layer_combo_size``.
        """
        return get_layer_combinations(num_layers, self.config.max_layer_combo_size)

    def get_extraction_count(self) -> int:
        """
        Calculate number of activation extractions needed (with caching).

        With caching, we extract ALL layers once per (benchmark, strategy).
        Layer combinations are tested from cache without re-extraction.

        Returns:
            Number of (benchmark, strategy) pairs = extraction operations
        """
        return len(self.benchmarks) * len(self.strategies)

    def get_total_configurations(self, num_layers: int) -> int:
        """
        Calculate total number of configurations to test.

        Total = strategies * layer_combos * benchmarks
        (Layer combos are tested from cached activations)

        Args:
            num_layers: Number of layers in the model

        Returns:
            Number of (strategy, layer combination, benchmark) triples.
        """
        # Imported here because the module-level imports only bring in
        # get_layer_combinations, not the counting helper.
        from wisent.core.utils.layer_combinations import get_layer_combinations_count

        layer_combos = get_layer_combinations_count(num_layers, self.config.max_layer_combo_size)
        return len(self.strategies) * layer_combos * len(self.benchmarks)

    def estimate_time_hours(self) -> float:
        """
        Estimate total time for geometry search (per model).

        With caching:
        - Extract once per (benchmark, strategy)
        - Layer combo testing is fast (from cache)

        Returns:
            Estimated hours per model
        """
        total_seconds = self.get_extraction_count() * self.config.estimated_time_per_extraction_seconds
        return total_seconds / 3600

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to dictionary (strategies are stored by their enum value)."""
        return {
            "models": self.models,
            "strategies": [s.value for s in self.strategies],
            "benchmarks": self.benchmarks,
            "config": self.config.to_dict(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GeometrySearchSpace":
        """Deserialize from dictionary.

        Missing, ``None`` or empty entries fall back to the constructor
        defaults, so a document containing e.g. ``"strategies": null`` or
        ``"config": null`` loads instead of raising.
        """
        strategies = [ExtractionStrategy(s) for s in (data.get("strategies") or [])]
        config = GeometrySearchConfig.from_dict(data.get("config") or {})
        return cls(
            models=data.get("models"),
            strategies=strategies or None,
            benchmarks=data.get("benchmarks"),
            config=config,
        )

    def summary(self) -> str:
        """Return a human-readable summary of the search space."""
        lines = [
            "Geometry Search Space:",
            f" Models: {len(self.models)}",
            f" Strategies: {len(self.strategies)}",
            f" Benchmarks: {len(self.benchmarks)}",
            f" Pairs per benchmark: {self.config.pairs_per_benchmark}",
            f" Max layer combo size: {self.config.max_layer_combo_size}",
            f" Cache activations: {self.config.cache_activations}",
            "",
            f" Extractions needed (per model): {self.get_extraction_count()}",
            f" Estimated time (per model): {self.estimate_time_hours():.1f} hours",
        ]
        return "\n".join(lines)

    def save(self, path: str) -> None:
        """Save search space to JSON file (written as UTF-8)."""
        # Explicit encoding keeps the file readable regardless of the
        # platform's default locale encoding.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load(cls, path: str) -> "GeometrySearchSpace":
        """Load search space from JSON file (read as UTF-8)."""
        with open(path, encoding="utf-8") as f:
            return cls.from_dict(json.load(f))
220
+
221
+
222
# Shared search space built with every default (constructed when this module is imported).
DEFAULT_SEARCH_SPACE = GeometrySearchSpace()


if __name__ == "__main__":
    # Script entry point: show what the default search space looks like.
    space = GeometrySearchSpace()
    print(space.summary())
    print()

    # Worked example for a 16-layer model (Llama-3.2-1B).
    num_layers = 16
    layer_combos = space.get_layer_combinations_for_model("test", num_layers)
    combo_total = len(layer_combos)
    print(f"For a {num_layers}-layer model:")
    print(f" Layer combinations: {combo_total}")
    print(f" Total configs to test: {space.get_total_configurations(num_layers)}")
@@ -370,7 +370,7 @@ class HyperparameterOptimizer:
370
370
  prompt_strategy = prompt_strategy_map.get(prompt_construction_strategy, ExtractionStrategy.CHAT_LAST)
371
371
 
372
372
  # Create activation collector
373
- collector = ActivationCollector(model=model, store_device="cpu")
373
+ collector = ActivationCollector(model=model)
374
374
  layer_str = str(layer)
375
375
 
376
376
  # Collect activations for training pairs
wisent/core/main.py CHANGED
@@ -13,6 +13,7 @@ from wisent.core.cli import execute_tasks, execute_generate_pairs_from_task, exe
13
13
  from wisent.core.cli.train_unified_goodness import execute_train_unified_goodness
14
14
  from wisent.core.cli.check_linearity import execute_check_linearity
15
15
  from wisent.core.cli.cluster_benchmarks import execute_cluster_benchmarks
16
+ from wisent.core.cli.geometry_search import execute_geometry_search
16
17
 
17
18
 
18
19
  def _should_show_banner() -> bool:
@@ -95,6 +96,8 @@ def main():
95
96
  execute_check_linearity(args)
96
97
  elif args.command == 'cluster-benchmarks':
97
98
  execute_cluster_benchmarks(args)
99
+ elif args.command == 'geometry-search':
100
+ execute_geometry_search(args)
98
101
  else:
99
102
  print(f"\n✗ Command '{args.command}' is not yet implemented")
100
103
  sys.exit(1)
@@ -7,6 +7,7 @@ import torch
7
7
  from typing import Mapping
8
8
 
9
9
  from wisent.core.errors import InvalidValueError, InvalidRangeError
10
+ from wisent.core.utils.device import preferred_dtype
10
11
 
11
12
  if TYPE_CHECKING:
12
13
  from wisent.core.activations.core.atoms import RawActivationMap
@@ -213,12 +214,13 @@ class SteeringPlan:
213
214
  """
214
215
  if n < 0:
215
216
  raise InvalidRangeError(param_name="n", actual=n, min_val=0)
217
+ dtype = preferred_dtype()
216
218
  if n == 0:
217
- return torch.empty(0, dtype=torch.float32)
219
+ return torch.empty(0, dtype=dtype)
218
220
  if weights is None:
219
- return torch.full((n,), 1.0 / n, dtype=torch.float32)
221
+ return torch.full((n,), 1.0 / n, dtype=dtype)
220
222
 
221
- w = torch.as_tensor(weights, dtype=torch.float32)
223
+ w = torch.as_tensor(weights, dtype=dtype)
222
224
  if w.numel() != n:
223
225
  raise InvalidValueError(param_name="weights length", actual=w.numel(), expected=f"{n} (number of activation maps)")
224
226
  s = float(w.sum())
@@ -89,7 +89,7 @@ class WisentModel:
89
89
  optional preloaded model (skips from_pretrained if provided).
90
90
  """
91
91
  self.model_name = model_name
92
- self.device = device or resolve_default_device()
92
+ self.device = resolve_default_device() if device is None or device == "auto" else device
93
93
 
94
94
  # Determine appropriate dtype and settings for the device
95
95
  load_kwargs = {
@@ -17,7 +17,7 @@ from optuna.pruners import MedianPruner
17
17
  from optuna.samplers import TPESampler
18
18
 
19
19
  from wisent.core.classifier.classifier import Classifier
20
- from wisent.core.utils.device import resolve_default_device
20
+ from wisent.core.utils.device import resolve_default_device, preferred_dtype
21
21
  from wisent.core.errors import NoActivationDataError, ClassifierCreationError
22
22
 
23
23
  from .activation_generator import ActivationData, ActivationGenerator, GenerationConfig
@@ -44,7 +44,7 @@ def get_model_dtype(model) -> torch.dtype:
44
44
  return next(model_params).dtype
45
45
  except StopIteration:
46
46
  # Fallback if no parameters found
47
- return torch.float32
47
+ return preferred_dtype()
48
48
 
49
49
 
50
50
  logger = logging.getLogger(__name__)
@@ -1,5 +1,7 @@
1
1
  """Parser for check-linearity command."""
2
2
 
3
+ from wisent.core.activations.extraction_strategy import ExtractionStrategy
4
+
3
5
 
4
6
  def setup_check_linearity_parser(parser):
5
7
  """Set up the check-linearity command parser."""
@@ -9,6 +11,14 @@ def setup_check_linearity_parser(parser):
9
11
  help='Path to JSON file containing contrastive pairs'
10
12
  )
11
13
 
14
+ parser.add_argument(
15
+ '--extraction-strategy',
16
+ type=str,
17
+ default=None,
18
+ choices=ExtractionStrategy.list_all(),
19
+ help=f'Extraction strategy to use. If not specified, tests multiple strategies. Options: {", ".join(ExtractionStrategy.list_all())}'
20
+ )
21
+
12
22
  parser.add_argument(
13
23
  '--model',
14
24
  type=str,
@@ -19,8 +29,8 @@ def setup_check_linearity_parser(parser):
19
29
  parser.add_argument(
20
30
  '--device',
21
31
  type=str,
22
- default='cuda',
23
- help='Device to run model on (cuda, mps, cpu)'
32
+ default='auto',
33
+ help='Device to run model on (auto, cuda, mps, cpu)'
24
34
  )
25
35
 
26
36
  parser.add_argument(
@@ -40,8 +40,8 @@ def setup_generate_vector_from_synthetic_parser(parser: argparse.ArgumentParser)
40
40
  parser.add_argument(
41
41
  "--device",
42
42
  type=str,
43
- default="cpu",
44
- help="Device to use (e.g., 'cpu', 'cuda', 'cuda:0')"
43
+ default="auto",
44
+ help="Device to use (e.g., 'auto', 'cpu', 'cuda', 'cuda:0', 'mps')"
45
45
  )
46
46
 
47
47
  # Pair generation
@@ -46,8 +46,8 @@ def setup_generate_vector_from_task_parser(parser: argparse.ArgumentParser) -> N
46
46
  parser.add_argument(
47
47
  "--device",
48
48
  type=str,
49
- default="cpu",
50
- help="Device to use (e.g., 'cpu', 'cuda', 'cuda:0')"
49
+ default="auto",
50
+ help="Device to use (e.g., 'auto', 'cpu', 'cuda', 'cuda:0', 'mps')"
51
51
  )
52
52
 
53
53
  # Pair generation
@@ -0,0 +1,61 @@
1
+ """Parser for geometry-search command."""
2
+
3
+ import argparse
4
+
5
+
6
+ def setup_geometry_search_parser(parser: argparse.ArgumentParser) -> None:
7
+ """Set up the geometry-search command parser."""
8
+ parser.add_argument(
9
+ "--model",
10
+ type=str,
11
+ required=True,
12
+ help="Model name or path (e.g., meta-llama/Llama-3.2-1B-Instruct)",
13
+ )
14
+ parser.add_argument(
15
+ "--output",
16
+ type=str,
17
+ default="/home/ubuntu/output/geometry_results.json",
18
+ help="Output path for results JSON",
19
+ )
20
+ parser.add_argument(
21
+ "--pairs-per-benchmark",
22
+ type=int,
23
+ default=50,
24
+ help="Number of pairs to sample per benchmark (default: 50)",
25
+ )
26
+ parser.add_argument(
27
+ "--max-layer-combo-size",
28
+ type=int,
29
+ default=3,
30
+ help="Maximum layers in combination (default: 3 = individual + pairs + triplets)",
31
+ )
32
+ parser.add_argument(
33
+ "--strategies",
34
+ type=str,
35
+ default=None,
36
+ help="Comma-separated list of strategies (default: all 7)",
37
+ )
38
+ parser.add_argument(
39
+ "--benchmarks",
40
+ type=str,
41
+ default=None,
42
+ help="Comma-separated list of benchmarks, or path to .txt file (default: all)",
43
+ )
44
+ parser.add_argument(
45
+ "--cache-dir",
46
+ type=str,
47
+ default=None,
48
+ help="Directory for activation cache (default: /tmp/wisent_geometry_cache_<model>)",
49
+ )
50
+ parser.add_argument(
51
+ "--seed",
52
+ type=int,
53
+ default=42,
54
+ help="Random seed for reproducibility (default: 42)",
55
+ )
56
+ parser.add_argument(
57
+ "--device",
58
+ type=str,
59
+ default="auto",
60
+ help="Device for model (auto/cuda/mps/cpu, default: auto)",
61
+ )
@@ -40,6 +40,7 @@ from wisent.core.parser_arguments.train_unified_goodness_parser import setup_tra
40
40
  from wisent.core.parser_arguments.optimize_parser import setup_optimize_parser
41
41
  from wisent.core.parser_arguments.check_linearity_parser import setup_check_linearity_parser
42
42
  from wisent.core.parser_arguments.cluster_benchmarks_parser import setup_cluster_benchmarks_parser
43
+ from wisent.core.parser_arguments.geometry_search_parser import setup_geometry_search_parser
43
44
 
44
45
 
45
46
  def setup_parser() -> argparse.ArgumentParser:
@@ -225,4 +226,11 @@ def setup_parser() -> argparse.ArgumentParser:
225
226
  )
226
227
  setup_cluster_benchmarks_parser(cluster_benchmarks_parser)
227
228
 
229
+ # Geometry search command - search for unified goodness direction across all benchmarks
230
+ geometry_search_parser = subparsers.add_parser(
231
+ "geometry-search",
232
+ help="Search for unified goodness direction across benchmarks (analyzes structure: linear/cone/orthogonal)"
233
+ )
234
+ setup_geometry_search_parser(geometry_search_parser)
235
+
228
236
  return parser
@@ -32,8 +32,8 @@ def setup_train_unified_goodness_parser(parser: argparse.ArgumentParser) -> None
32
32
  parser.add_argument(
33
33
  "--device",
34
34
  type=str,
35
- default="cuda",
36
- help="Device to use (e.g., 'cpu', 'cuda', 'cuda:0')"
35
+ default="auto",
36
+ help="Device to use (e.g., 'auto', 'cpu', 'cuda', 'cuda:0', 'mps')"
37
37
  )
38
38
 
39
39
  # Benchmark selection
wisent/core/steering.py CHANGED
@@ -477,11 +477,13 @@ class SteeringMethod:
477
477
  # Get prediction from steering method
478
478
  prediction = self.predict_proba(activation)
479
479
 
480
- # Convert to tensor for loss computation
480
+ # Convert to tensor for loss computation (use activation's dtype)
481
481
  if not isinstance(prediction, torch.Tensor):
482
- prediction = torch.tensor(prediction, dtype=torch.float32, device=self.device)
482
+ from wisent.core.utils.device import preferred_dtype
483
+ pred_dtype = activation.dtype if isinstance(activation, torch.Tensor) else preferred_dtype()
484
+ prediction = torch.tensor(prediction, dtype=pred_dtype, device=self.device)
483
485
 
484
- target = torch.tensor(label, dtype=torch.float32, device=self.device)
486
+ target = torch.tensor(label, dtype=prediction.dtype, device=self.device)
485
487
 
486
488
  # Binary cross-entropy loss
487
489
  loss = F.binary_cross_entropy_with_logits(prediction.unsqueeze(0), target.unsqueeze(0))
@@ -6,6 +6,7 @@ import numpy as np
6
6
 
7
7
  from wisent.core.steering_methods.core.atoms import PerLayerBaseSteeringMethod
8
8
  from wisent.core.errors import InsufficientDataError
9
+ from wisent.core.utils.device import preferred_dtype
9
10
 
10
11
  __all__ = [
11
12
  "HyperplaneMethod",
@@ -61,7 +62,7 @@ class HyperplaneMethod(PerLayerBaseSteeringMethod):
61
62
  clf.fit(X, y)
62
63
 
63
64
  # Use classifier weights as steering vector
64
- v = torch.tensor(clf.coef_[0], dtype=torch.float32)
65
+ v = torch.tensor(clf.coef_[0], dtype=preferred_dtype())
65
66
 
66
67
  if bool(self.kwargs.get("normalize", True)):
67
68
  v = self._safe_l2_normalize(v)
@@ -16,16 +16,6 @@ __all__ = [
16
16
  class ProgrammaticNonsenseGenerator:
17
17
  """Generate nonsense contrastive pairs programmatically without using LLM."""
18
18
 
19
- # Word list for word salad mode
20
- WORD_LIST = [
21
- "purple", "elephant", "calculator", "yesterday", "moon", "basket", "thinking",
22
- "telephone", "mountain", "running", "quickly", "tomorrow", "happiness", "keyboard",
23
- "window", "dancing", "coffee", "planet", "singing", "computer", "orange", "flying",
24
- "bicycle", "dream", "ocean", "pencil", "laughing", "cloud", "table", "walking",
25
- "music", "river", "chair", "jumping", "sun", "book", "swimming", "star", "door",
26
- "cooking", "tree", "writing", "sky", "flower", "playing", "rain", "paper", "sleeping"
27
- ]
28
-
29
19
  def __init__(
30
20
  self,
31
21
  nonsense_mode: str,
@@ -46,6 +36,18 @@ class ProgrammaticNonsenseGenerator:
46
36
  self.contrastive_set_name = contrastive_set_name
47
37
  self.trait_label = trait_label
48
38
  self.trait_description = trait_description
39
+ self._valid_words = None
40
+
41
+ def set_tokenizer(self, tokenizer) -> None:
42
+ """Extract valid words from tokenizer vocabulary."""
43
+ vocab = tokenizer.get_vocab()
44
+ valid_words = []
45
+ for token, token_id in vocab.items():
46
+ decoded = tokenizer.decode([token_id])
47
+ clean = decoded.strip()
48
+ if clean.isalpha() and len(clean) > 1 and len(clean) < 15:
49
+ valid_words.append(clean)
50
+ self._valid_words = list(set(valid_words))
49
51
 
50
52
  def generate(self, num_pairs: int = 10) -> ContrastivePairSet:
51
53
  """
@@ -108,11 +110,14 @@ class ProgrammaticNonsenseGenerator:
108
110
 
109
111
  def _generate_repetitive(self) -> str:
110
112
  """Generate pathologically repetitive text."""
113
+ if self._valid_words is None:
114
+ raise ValueError("Tokenizer must be set. Call set_tokenizer() first.")
115
+
111
116
  # Pick a random word or phrase
112
117
  choices = [
113
118
  random.choice(string.ascii_lowercase), # Single letter
114
- random.choice(self.WORD_LIST), # Single word
115
- ' '.join(random.sample(self.WORD_LIST, 2)), # Two-word phrase
119
+ random.choice(self._valid_words), # Single word
120
+ ' '.join(random.sample(self._valid_words, 2)), # Two-word phrase
116
121
  ]
117
122
  unit = random.choice(choices)
118
123
 
@@ -121,13 +126,20 @@ class ProgrammaticNonsenseGenerator:
121
126
  return ' '.join([unit] * repetitions)
122
127
 
123
128
  def _generate_word_salad(self) -> str:
124
- """Generate word salad (real words, no meaning)."""
125
- num_words = random.randint(8, 15)
126
- words = random.choices(self.WORD_LIST, k=num_words)
127
- return ' '.join(words)
129
+ """Generate word salad (random tokens from tokenizer vocabulary)."""
130
+ num_words = random.randint(3, 10)
131
+
132
+ if self._valid_words is not None:
133
+ words = random.choices(self._valid_words, k=num_words)
134
+ return ' '.join(words)
135
+
136
+ raise ValueError("Tokenizer must be set to generate word salad. Call set_tokenizer() first.")
128
137
 
129
138
  def _generate_mixed(self) -> str:
130
139
  """Generate mixed nonsense (combination of all types)."""
140
+ if self._valid_words is None:
141
+ raise ValueError("Tokenizer must be set. Call set_tokenizer() first.")
142
+
131
143
  components = []
132
144
 
133
145
  # Add 2-4 different types of nonsense
@@ -140,11 +152,11 @@ class ProgrammaticNonsenseGenerator:
140
152
  length = random.randint(5, 15)
141
153
  components.append(''.join(random.choices(string.ascii_lowercase, k=length)))
142
154
  elif mode == 'repetitive':
143
- word = random.choice(self.WORD_LIST)
155
+ word = random.choice(self._valid_words)
144
156
  reps = random.randint(3, 6)
145
157
  components.append(' '.join([word] * reps))
146
158
  else: # word_salad
147
159
  num_words = random.randint(3, 6)
148
- components.append(' '.join(random.choices(self.WORD_LIST, k=num_words)))
160
+ components.append(' '.join(random.choices(self._valid_words, k=num_words)))
149
161
 
150
162
  return ' '.join(components)
@@ -48,8 +48,8 @@ class WisentSteeringTrainer(BaseSteeringTrainer):
48
48
  model: WisentModel to use for activation collection.
49
49
  pair_set: ContrastivePairSet with pairs to use for collection and training.
50
50
  steering_method: BaseSteeringMethod instance to use for training.
51
- store_device: Device to store collected activations on (default "cpu").
52
- dtype: Optional torch.dtype to cast collected activations to (default None, meaning no cast).
51
+ store_device: Device to store collected activations on (default: "cpu" to avoid GPU OOM).
52
+ dtype: Optional torch.dtype to cast collected activations to.
53
53
  """
54
54
 
55
55
  model: WisentModel