wisent 0.7.701__py3-none-any.whl → 0.7.1045__py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
Files changed (391)
  1. wisent/__init__.py +1 -1
  2. wisent/comparison/__init__.py +1 -0
  3. wisent/comparison/detect_bos_features.py +275 -0
  4. wisent/comparison/fgaa.py +465 -0
  5. wisent/comparison/lora.py +669 -0
  6. wisent/comparison/lora_dpo.py +592 -0
  7. wisent/comparison/main.py +444 -0
  8. wisent/comparison/ours.py +76 -0
  9. wisent/comparison/sae.py +304 -0
  10. wisent/comparison/utils.py +381 -0
  11. wisent/core/activations/activation_cache.py +393 -0
  12. wisent/core/activations/activations.py +3 -3
  13. wisent/core/activations/activations_collector.py +12 -7
  14. wisent/core/activations/classifier_inference_strategy.py +12 -11
  15. wisent/core/activations/extraction_strategy.py +260 -84
  16. wisent/core/classifiers/classifiers/core/atoms.py +3 -2
  17. wisent/core/cli/__init__.py +2 -1
  18. wisent/core/cli/agent/train_classifier.py +16 -3
  19. wisent/core/cli/check_linearity.py +35 -3
  20. wisent/core/cli/cluster_benchmarks.py +4 -6
  21. wisent/core/cli/create_steering_vector.py +6 -4
  22. wisent/core/cli/diagnose_vectors.py +7 -4
  23. wisent/core/cli/estimate_unified_goodness_time.py +6 -4
  24. wisent/core/cli/generate_pairs_from_task.py +9 -56
  25. wisent/core/cli/generate_vector_from_task.py +11 -20
  26. wisent/core/cli/geometry_search.py +137 -0
  27. wisent/core/cli/get_activations.py +2 -2
  28. wisent/core/cli/method_optimizer.py +4 -3
  29. wisent/core/cli/modify_weights.py +3 -2
  30. wisent/core/cli/optimize_sample_size.py +1 -1
  31. wisent/core/cli/optimize_steering.py +14 -16
  32. wisent/core/cli/optimize_weights.py +2 -1
  33. wisent/core/cli/preview_pairs.py +203 -0
  34. wisent/core/cli/steering_method_trainer.py +3 -3
  35. wisent/core/cli/tasks.py +19 -76
  36. wisent/core/cli/train_unified_goodness.py +3 -3
  37. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
  38. wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
  39. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
  40. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
  41. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
  42. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
  43. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
  44. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
  45. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
  46. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
  47. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
  48. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
  49. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
  50. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
  51. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
  52. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
  53. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
  54. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
  55. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
  56. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
  57. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
  58. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
  59. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
  60. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
  61. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
  62. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
  63. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
  64. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
  65. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
  66. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
  67. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
  68. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
  69. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
  70. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
  71. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
  72. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
  73. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
  74. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
  75. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
  76. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
  77. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
  78. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
  79. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
  80. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
  81. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
  82. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
  83. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
  84. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
  85. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
  86. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
  87. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
  88. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
  89. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
  90. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
  91. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
  92. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
  93. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +22 -5
  94. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
  95. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
  96. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +10 -3
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
  102. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
  103. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
  104. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
  105. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
  106. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
  107. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
  108. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
  109. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
  110. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
  111. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
  112. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
  113. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
  114. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
  115. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
  116. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
  117. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
  118. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
  119. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
  120. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
  121. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
  122. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
  123. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
  124. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
  125. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
  126. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
  127. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
  128. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
  129. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
  130. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
  131. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
  132. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
  133. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
  134. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
  135. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
  136. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
  137. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
  138. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
  139. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
  140. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
  141. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
  142. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
  143. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
  144. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
  145. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
  146. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
  147. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
  148. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
  149. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
  150. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
  151. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
  152. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
  153. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
  154. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
  155. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
  156. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
  157. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
  158. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
  159. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
  160. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
  161. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
  162. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
  163. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
  164. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
  165. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
  166. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
  167. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
  168. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
  169. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
  170. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
  171. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
  172. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
  173. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
  174. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
  175. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
  176. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
  177. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
  178. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
  179. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
  180. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
  181. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
  182. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
  183. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
  184. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
  185. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
  186. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
  187. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
  188. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
  189. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
  190. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
  191. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
  192. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
  193. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
  194. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
  195. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
  196. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
  197. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
  198. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
  199. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
  200. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
  201. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
  202. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
  203. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
  204. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
  205. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
  206. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
  207. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
  208. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
  209. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
  210. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
  211. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
  212. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
  213. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
  214. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
  215. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
  216. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
  217. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
  218. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
  219. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
  220. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
  221. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
  222. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
  223. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
  224. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
  225. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
  226. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
  227. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
  228. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
  229. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
  230. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
  231. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
  232. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
  233. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
  234. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
  235. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
  236. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
  237. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
  238. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
  239. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
  240. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
  241. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
  242. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
  243. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
  244. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
  245. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
  246. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
  247. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
  248. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
  249. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
  250. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
  251. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
  252. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
  253. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
  254. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
  255. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
  256. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
  257. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
  258. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
  259. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
  260. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +9 -4
  261. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
  262. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
  263. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
  264. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
  265. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
  266. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
  267. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
  268. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
  269. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
  270. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
  271. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
  272. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
  273. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
  274. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
  275. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
  276. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
  277. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
  278. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
  279. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
  280. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
  281. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
  282. wisent/core/data_loaders/loaders/lm_loader.py +12 -1
  283. wisent/core/geometry_runner.py +995 -0
  284. wisent/core/geometry_search_space.py +237 -0
  285. wisent/core/hyperparameter_optimizer.py +1 -1
  286. wisent/core/main.py +3 -0
  287. wisent/core/models/core/atoms.py +5 -3
  288. wisent/core/models/wisent_model.py +1 -1
  289. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
  290. wisent/core/parser_arguments/check_linearity_parser.py +12 -2
  291. wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
  292. wisent/core/parser_arguments/generate_vector_from_task_parser.py +6 -13
  293. wisent/core/parser_arguments/geometry_search_parser.py +61 -0
  294. wisent/core/parser_arguments/get_activations_parser.py +5 -14
  295. wisent/core/parser_arguments/main_parser.py +8 -0
  296. wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
  297. wisent/core/steering.py +5 -3
  298. wisent/core/steering_methods/methods/hyperplane.py +2 -1
  299. wisent/core/synthetic/generators/nonsense_generator.py +30 -18
  300. wisent/core/trainers/steering_trainer.py +2 -2
  301. wisent/core/utils/device.py +27 -27
  302. wisent/core/utils/layer_combinations.py +70 -0
  303. wisent/examples/__init__.py +1 -0
  304. wisent/examples/scripts/__init__.py +1 -0
  305. wisent/examples/scripts/count_all_benchmarks.py +121 -0
  306. wisent/examples/scripts/discover_directions.py +469 -0
  307. wisent/examples/scripts/extract_benchmark_info.py +71 -0
  308. wisent/examples/scripts/search_all_short_names.py +31 -0
  309. wisent/examples/scripts/test_all_benchmarks.py +138 -0
  310. wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
  311. wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
  312. wisent/examples/scripts/test_nonsense_baseline.py +261 -0
  313. wisent/examples/scripts/test_one_benchmark.py +324 -0
  314. wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
  315. wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
  316. wisent/parameters/lm_eval/category_directions.json +137 -0
  317. wisent/parameters/lm_eval/repair_plan.json +282 -0
  318. wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
  319. wisent/parameters/lm_eval/working_benchmarks.json +206 -0
  320. wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
  321. wisent/tests/test_detector_accuracy.py +1 -1
  322. wisent/tests/visualize_geometry.py +1 -1
  323. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
  324. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/RECORD +328 -358
  325. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
  326. wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
  327. wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
  328. wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
  329. wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
  330. wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
  331. wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
  332. wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
  333. wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
  334. wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
  335. wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
  336. wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
  337. wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
  338. wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
  339. wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
  340. wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
  341. wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
  342. wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
  343. wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
  344. wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
  345. wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
  346. wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
  347. wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
  348. wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
  349. wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
  350. wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
  351. wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
  352. wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
  353. wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
  354. wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
  355. wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
  356. wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
  357. wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
  358. wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
  359. wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
  360. wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
  361. wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
  362. wisent/examples/scripts/1/test_cola_pairs.json +0 -8
  363. wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
  364. wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
  365. wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
  366. wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
  367. wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
  368. wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
  369. wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
  370. wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
  371. wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
  372. wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
  373. wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
  374. wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
  375. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
  376. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
  377. wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
  378. wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
  379. wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
  380. wisent/examples/scripts/2/test_atis_pairs.json +0 -8
  381. wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
  382. wisent/examples/scripts/2/test_babi_pairs.json +0 -8
  383. wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
  384. wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
  385. wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
  386. wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
  387. wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
  388. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
  389. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
  390. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
  391. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
wisent/comparison/fgaa.py (new file)
@@ -0,0 +1,465 @@
+ """
+ FGAA (Feature Guided Activation Addition) steering method.
+
+ Implements the method from "Interpretable Steering of Large Language Models
+ with Feature Guided Activation Additions" (arXiv:2501.09929).
+
+ Uses Gemma Scope SAEs and pre-computed effect approximators.
+ """
+
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ from wisent.comparison.utils import (
+     apply_steering_to_model,
+     remove_steering,
+     convert_to_lm_eval_format,
+     generate_contrastive_pairs,
+     load_model_and_tokenizer,
+     load_sae,
+     SAE_CONFIGS,
+ )
+
+ if TYPE_CHECKING:
+     from wisent.core.models.wisent_model import WisentModel
+
+ __all__ = ["generate_steering_vector", "apply_steering_to_model", "remove_steering", "convert_to_lm_eval_format"]
+
+
+ # BOS feature indices - these features activate most strongly on the BOS token
+ # Paper features from Appendix G (5 features)
+ BOS_FEATURES_PAPER = {
+     "google/gemma-2-2b": [11087, 3220, 11752, 12160, 11498],
+     "google/gemma-2-9b": [],  # Not listed in paper
+ }
+
+ # Detected features from running detect_bos_features.py (top 12 by mean activation)
+ BOS_FEATURES_DETECTED = {
+     "google/gemma-2-2b": [1041, 7507, 11087, 3220, 11767, 11752, 14669, 6889, 12160, 13700, 2747, 11498],
+     "google/gemma-2-9b": [8032, 11906, 7768, 14845, 14483, 10562, 8892, 9151, 5721, 15738, 5285, 13895],
+ }
+
+ # FGAA-specific: effect approximator config (adapter files)
+ FGAA_ADAPTER_FILES = {
+     "google/gemma-2-2b": "adapter_2b_layer_12.pt",
+     "google/gemma-2-9b": "adapter_9b_layer_12.pt",
+ }
+
+
+ def load_effect_approximator(model_name: str, device: str = "cuda:0") -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Load the pre-trained effect approximator (adapter) from HuggingFace.
+
+     The adapter contains:
+     - W: [d_model, d_sae] - maps SAE feature space to model activation space
+     - b: [d_sae] - bias term
+
+     Args:
+         model_name: HuggingFace model name
+         device: Device to load on
+
+     Returns:
+         Tuple of (W, b) tensors
+     """
+     if model_name not in FGAA_ADAPTER_FILES:
+         raise ValueError(f"No effect approximator for model '{model_name}'")
+
+     adapter_file = FGAA_ADAPTER_FILES[model_name]
+
+     print(f" Loading adapter from schalnev/sae-ts-effects / {adapter_file}")
+     path = hf_hub_download(
+         repo_id="schalnev/sae-ts-effects",
+         filename=adapter_file,
+         repo_type="dataset",
+     )
+
+     adapter = torch.load(path, map_location=device, weights_only=False)
+
+     # Adapter is OrderedDict with 'W' and 'b'
+     W = adapter["W"].to(device)  # [d_model, d_sae]
+     b = adapter["b"].to(device)  # [d_sae]
+
+     print(f" Adapter W shape: {W.shape}, b shape: {b.shape}")
+
+     return W, b
+
+
+ def compute_v_diff(
+     model,
+     tokenizer,
+     sae,
+     pairs: list[dict],
+     layer_idx: int,
+     device: str,
+ ) -> torch.Tensor:
+     """
+     Compute v_diff: the difference vector between positive and negative examples in SAE space.
+
+     v_diff = mean(f(h_l(x+))) - mean(f(h_l(x-)))
+
+     Args:
+         model: HuggingFace model
+         tokenizer: Tokenizer
+         sae: SAE object from sae_lens
+         pairs: List of contrastive pairs
+         layer_idx: Layer to extract activations from
+         device: Device
+
+     Returns:
+         v_diff tensor of shape [d_sae]
+     """
+     pos_features_list = []
+     neg_features_list = []
+
+     print(f" Computing v_diff from {len(pairs)} pairs...")
+
+     for i, pair in enumerate(pairs):
+         prompt = pair["prompt"]
+         pos_response = pair["positive_response"]["model_response"]
+         neg_response = pair["negative_response"]["model_response"]
+
+         pos_text = f"{prompt} {pos_response}"
+         neg_text = f"{prompt} {neg_response}"
+
+         # Get activations and encode through SAE
+         pos_acts = _get_residual_stream_activations(model, tokenizer, pos_text, layer_idx, device)
+         pos_acts = pos_acts.to(device).to(sae.W_enc.dtype)
+         # SAE encode: latents = (x - b_dec) @ W_enc + b_enc
+         pos_latents = sae.encode(pos_acts)
+         # Mean over sequence dimension
+         pos_features_list.append(pos_latents.mean(dim=1).detach())  # [1, d_sae]
+
+         neg_acts = _get_residual_stream_activations(model, tokenizer, neg_text, layer_idx, device)
+         neg_acts = neg_acts.to(device).to(sae.W_enc.dtype)
+         neg_latents = sae.encode(neg_acts)
+         neg_features_list.append(neg_latents.mean(dim=1).detach())
+
+         if (i + 1) % 10 == 0:
+             print(f" Processed {i + 1}/{len(pairs)} pairs")
+
+     # Stack and compute mean
+     pos_features = torch.cat(pos_features_list, dim=0)  # [num_pairs, d_sae]
+     neg_features = torch.cat(neg_features_list, dim=0)
+
+     v_diff = pos_features.mean(dim=0) - neg_features.mean(dim=0)  # [d_sae]
+
+     print(f" v_diff computed, shape: {v_diff.shape}")
+     print(f" v_diff stats: mean={v_diff.mean():.6f}, std={v_diff.std():.6f}, "
+           f"min={v_diff.min():.6f}, max={v_diff.max():.6f}")
+
+     return v_diff
+
+
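For illustration only (toy tensors, not part of the fgaa.py diff): the core of compute_v_diff is a mean difference in SAE feature space. A minimal sketch with a stand-in encoder in place of the real Gemma Scope SAE:

import torch

d_model, d_sae, seq_len = 32, 128, 10
W_enc = torch.randn(d_model, d_sae)              # stand-in for sae.W_enc

def toy_encode(acts: torch.Tensor) -> torch.Tensor:
    # stand-in for sae.encode: project activations into feature space
    return torch.relu(acts @ W_enc)

pos_feats, neg_feats = [], []
for _ in range(4):                               # four toy "pairs"
    pos_acts = torch.randn(1, seq_len, d_model)  # stand-in for layer activations
    neg_acts = torch.randn(1, seq_len, d_model)
    pos_feats.append(toy_encode(pos_acts).mean(dim=1))  # [1, d_sae]
    neg_feats.append(toy_encode(neg_acts).mean(dim=1))

v_diff = torch.cat(pos_feats).mean(dim=0) - torch.cat(neg_feats).mean(dim=0)
print(v_diff.shape)  # torch.Size([128])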
+ def compute_v_target(
+     v_diff: torch.Tensor,
+     sparsity: torch.Tensor,
+     model_name: str,
+     bos_features_source: str = "detected",
+     density_threshold: float = 0.01,
+     top_k_positive: int = 50,
+     top_k_negative: int = 0,
+ ) -> torch.Tensor:
+     """
+     Compute v_target by filtering v_diff.
+
+     Three filtering stages:
+     1. Density filtering: zero out features with activation density > threshold
+     2. BOS token filtering: zero out features that activate mainly on BOS token
+     3. Top-k selection: keep top positive and negative features
+
+     Args:
+         v_diff: Difference vector in SAE space [d_sae]
+         sparsity: Feature sparsity/density values from SAE [d_sae]
+         model_name: Model name to look up BOS features
+         bos_features_source: Source of BOS features - "paper" (5 features), "detected" (12 features), or "none"
+         density_threshold: Zero out features with density above this (default 0.01)
+         top_k_positive: Number of top positive features to keep
+         top_k_negative: Number of top negative features to keep (paper uses 0)
+
+     Returns:
+         v_target tensor of shape [d_sae]
+     """
+     v_filtered = v_diff.clone()
+
+     # Stage 1: Density filtering
+     # Zero out features that are too commonly activated (not specific enough)
+     if sparsity is not None:
+         density_mask = sparsity > density_threshold
+         num_filtered = density_mask.sum().item()
+         v_filtered[density_mask] = 0
+         print(f" Density filtering: zeroed {num_filtered} features (density > {density_threshold})")
+
+     # Stage 2: BOS filtering
+     # Zero out features that activate mainly on BOS tokens
+     if bos_features_source == "paper":
+         bos_features = BOS_FEATURES_PAPER.get(model_name, [])
+     elif bos_features_source == "detected":
+         bos_features = BOS_FEATURES_DETECTED.get(model_name, [])
+     else:  # "none"
+         bos_features = []
+     if bos_features:
+         for idx in bos_features:
+             v_filtered[idx] = 0
+         print(f" BOS filtering: zeroed {len(bos_features)} features {bos_features}")
+     else:
+         print(f" BOS filtering: no known BOS features for {model_name}")
+
+     # Stage 3: Top-k selection
+     v_target = torch.zeros_like(v_filtered)
+
+     # Get top positive features
+     if top_k_positive > 0:
+         pos_values = v_filtered.clone()
+         pos_values[pos_values < 0] = 0
+         top_pos_values, top_pos_indices = pos_values.topk(min(top_k_positive, (pos_values > 0).sum().item()))
+         v_target[top_pos_indices] = v_filtered[top_pos_indices]
+         print(f" Selected top {len(top_pos_indices)} positive features")
+
+     # Get top negative features (paper uses 0)
+     if top_k_negative > 0:
+         neg_values = -v_filtered.clone()
+         neg_values[neg_values < 0] = 0
+         top_neg_values, top_neg_indices = neg_values.topk(min(top_k_negative, (neg_values > 0).sum().item()))
+         v_target[top_neg_indices] = v_filtered[top_neg_indices]
+         print(f" Selected top {len(top_neg_indices)} negative features")
+
+     num_nonzero = (v_target != 0).sum().item()
+     print(f" v_target: {num_nonzero} non-zero features")
+
+     return v_target
+
+
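A toy walk-through of the three filtering stages above (values, threshold, and BOS index invented for the example; the real code uses the SAE's sparsity statistics and the BOS feature lists at the top of the file):

import torch

v_diff = torch.tensor([0.9, -0.4, 0.7, 0.2, 0.6])
sparsity = torch.tensor([0.5, 0.001, 0.002, 0.003, 0.004])  # per-feature activation density
bos_features = [2]                                           # pretend feature 2 fires on BOS

v = v_diff.clone()
v[sparsity > 0.01] = 0           # stage 1: drops the dense feature at index 0
for idx in bos_features:         # stage 2: drops the BOS feature at index 2
    v[idx] = 0

v_target = torch.zeros_like(v)   # stage 3: keep only the top-k positive features
_, top_idx = v.clamp(min=0).topk(2)
v_target[top_idx] = v[top_idx]
print(v_target)                  # tensor([0.0000, 0.0000, 0.0000, 0.2000, 0.6000])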
+ def compute_v_opt(
+     v_target: torch.Tensor,
+     W: torch.Tensor,
+     b: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Compute v_opt using the effect approximator.
+
+     From paper: v_opt = (W @ v_target_norm) / ||W @ v_target_norm|| - (W @ b) / ||W @ b||
+
+     Args:
+         v_target: Target vector in SAE space [d_sae]
+         W: Effect approximator weight matrix [d_model, d_sae]
+         b: Effect approximator bias [d_sae]
+
+     Returns:
+         v_opt tensor of shape [d_model]
+     """
+     # L1 normalize v_target (as specified in paper)
+     v_target_norm = v_target / (v_target.abs().sum() + 1e-8)
+
+     # W is [d_model, d_sae], v_target_norm is [d_sae]
+     # W @ v_target_norm -> [d_model]
+     Wv = W @ v_target_norm
+     Wv_normalized = Wv / (Wv.norm() + 1e-8)
+
+     # Bias term: W @ b -> [d_model]
+     Wb = W @ b
+     Wb_normalized = Wb / (Wb.norm() + 1e-8)
+
+     # Final v_opt (paper formula)
+     v_opt = Wv_normalized - Wb_normalized
+
+     print(f" v_opt computed, shape: {v_opt.shape}, norm: {v_opt.norm():.6f}")
+
+     return v_opt
+
+
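And a small numeric sketch of the compute_v_opt formula (random W and b standing in for the downloaded adapter, not from the package):

import torch

d_model, d_sae = 8, 32
W = torch.randn(d_model, d_sae)               # stand-in for the adapter weight
b = torch.randn(d_sae)                        # stand-in for the adapter bias
v_target = torch.zeros(d_sae)
v_target[[3, 17]] = torch.tensor([0.6, 0.2])  # a sparse toy target

v_norm = v_target / (v_target.abs().sum() + 1e-8)  # L1 normalization
Wv, Wb = W @ v_norm, W @ b
v_opt = Wv / (Wv.norm() + 1e-8) - Wb / (Wb.norm() + 1e-8)
print(v_opt.shape)  # torch.Size([8]); a difference of two unit vectors, so its norm is at most 2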
+ def _get_residual_stream_activations(
+     model,
+     tokenizer,
+     text: str,
+     layer_idx: int,
+     device: str,
+ ) -> torch.Tensor:
+     """
+     Get residual stream activations from a specific layer.
+
+     Uses output_hidden_states=True (same as wisent's ActivationCollector).
+
+     Args:
+         model: HuggingFace model
+         tokenizer: Tokenizer
+         text: Input text
+         layer_idx: Layer index (0-indexed)
+         device: Device
+
+     Returns:
+         Tensor of shape (1, seq_len, d_model)
+     """
+     inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         out = model(**inputs, output_hidden_states=True, use_cache=False)
+
+     # hidden_states is tuple: (embedding, layer0, layer1, ..., layerN)
+     # layer_idx=0 -> hs[1], layer_idx=12 -> hs[13]
+     hs = out.hidden_states
+     return hs[layer_idx + 1]  # +1 because hs[0] is embedding layer
+
+
+ def generate_steering_vector(
+     task: str,
+     model_name: str,
+     output_path: str | Path,
+     trait_label: str = "correctness",
+     num_pairs: int = 50,
+     method: str = "fgaa",
+     layers: str | None = None,
+     device: str = "cuda:0",
+     keep_intermediate: bool = False,
+     density_threshold: float = 0.01,
+     top_k_positive: int = 50,
+     top_k_negative: int = 0,
+     bos_features_source: str = "detected",
+     **kwargs,  # Accept additional kwargs for compatibility (e.g., extraction_strategy)
+ ) -> Path:
+     """
+     Generate a steering vector using the FGAA method.
+
+     Args:
+         task: lm-eval task name (e.g., 'boolq', 'cb')
+         model_name: HuggingFace model name (must be Gemma 2B or 9B)
+         output_path: Where to save the steering vector
+         trait_label: Label for the trait being steered
+         num_pairs: Number of contrastive pairs to use
+         method: Method name (should be 'fgaa')
+         layers: Layer(s) to use (e.g., '12' or '10,11,12')
+         device: Device to run on
+         keep_intermediate: Whether to keep intermediate files
+         density_threshold: Density threshold for filtering (default 0.01)
+         top_k_positive: Number of top positive features to keep
+         top_k_negative: Number of top negative features to keep
+         bos_features_source: Source of BOS features - "paper" (5), "detected" (12), or "none"
+
+     Returns:
+         Path to the saved steering vector
+     """
+     import gc
+
+     output_path = Path(output_path)
+
+     if model_name not in SAE_CONFIGS:
+         raise ValueError(
+             f"No SAE config for model '{model_name}'. "
+             f"Supported models: {list(SAE_CONFIGS.keys())}"
+         )
+
+     config = SAE_CONFIGS[model_name]
+
+     # Parse layers
+     if layers is None:
+         layer_indices = [config["default_layer"]]
+     elif layers == "all":
+         layer_indices = list(range(config["num_layers"]))
+     else:
+         layer_indices = [int(l.strip()) for l in layers.split(",")]
+
+     # Step 1: Generate contrastive pairs
+     print(f"Step 1: Generating contrastive pairs from task: {task}")
+     pairs, pairs_file = generate_contrastive_pairs(task, num_pairs)
+     print(f" Loaded {len(pairs)} contrastive pairs")
+
+     # Step 2: Load model
+     print(f"\nStep 2: Loading model {model_name}...")
+     model, tokenizer = load_model_and_tokenizer(model_name, device)
+
+     # Step 3: Load effect approximator (shared across layers)
+     print(f"\nStep 3: Loading effect approximator...")
+     W, b = load_effect_approximator(model_name, device=device)
+
+     steering_vectors = {}
+     feature_info = {}
+
+     for layer_idx in layer_indices:
+         print(f"\nStep 4: Processing layer {layer_idx}")
+
+         # Load SAE for this layer
+         sae, sparsity = load_sae(model_name, layer_idx, device=device)
+
+         # Compute v_diff
+         print(f"\nStep 5: Computing v_diff for layer {layer_idx}...")
+         v_diff = compute_v_diff(model, tokenizer, sae, pairs, layer_idx, device)
+
+         # Compute v_target
+         print(f"\nStep 6: Computing v_target for layer {layer_idx}...")
+         v_target = compute_v_target(
+             v_diff,
+             sparsity,
+             model_name,
+             bos_features_source=bos_features_source,
+             density_threshold=density_threshold,
+             top_k_positive=top_k_positive,
+             top_k_negative=top_k_negative,
+         )
+
+         # Compute v_opt
+         print(f"\nStep 7: Computing v_opt for layer {layer_idx}...")
+         v_opt = compute_v_opt(v_target, W, b)
+
+         steering_vectors[str(layer_idx)] = v_opt.cpu().float().tolist()
+
+         # Store feature info
+         nonzero_mask = v_target != 0
+         nonzero_indices = nonzero_mask.nonzero().squeeze(-1).tolist()
+         feature_info[str(layer_idx)] = {
+             "num_selected_features": len(nonzero_indices) if isinstance(nonzero_indices, list) else 1,
+             "selected_feature_indices": nonzero_indices[:20] if isinstance(nonzero_indices, list) else [nonzero_indices],
+             "v_diff_stats": {
+                 "mean": v_diff.mean().item(),
+                 "std": v_diff.std().item(),
+                 "min": v_diff.min().item(),
+                 "max": v_diff.max().item(),
+             },
+         }
+
+         # Cleanup SAE
+         del sae, sparsity, v_diff, v_target
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     # Cleanup
+     del model, W, b
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
+     if not keep_intermediate:
+         import os
+         os.unlink(pairs_file)
+
+     # Save results
+     result = {
+         "steering_vectors": steering_vectors,
+         "layers": [str(l) for l in layer_indices],
+         "model": model_name,
+         "method": "fgaa",
+         "trait_label": trait_label,
+         "task": task,
+         "num_pairs": len(pairs),
+         "fgaa_params": {
+             "density_threshold": density_threshold,
+             "top_k_positive": top_k_positive,
+             "top_k_negative": top_k_negative,
+             "bos_features_source": bos_features_source,
+         },
+         "feature_info": feature_info,
+     }
+
+     output_path.parent.mkdir(parents=True, exist_ok=True)
+     with open(output_path, "w") as f:
+         json.dump(result, f, indent=2)
+
+     print(f"\nSaved FGAA steering vector to {output_path}")
+     return output_path
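For orientation, a hedged sketch of how this entry point might be invoked; the task, output path, and layer below are illustrative, and only keyword names visible in the signature above are used:

from wisent.comparison.fgaa import generate_steering_vector

vector_path = generate_steering_vector(
    task="boolq",                        # any lm-eval task with a pair extractor
    model_name="google/gemma-2-2b",      # must appear in SAE_CONFIGS and FGAA_ADAPTER_FILES
    output_path="vectors/boolq_fgaa.json",
    layers="12",                         # the layer the Gemma Scope SAE / adapter targets
    num_pairs=50,
    top_k_positive=50,
    bos_features_source="detected",
)
print(vector_path)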