wisent 0.7.701__py3-none-any.whl → 0.7.901__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (330)
  1. wisent/__init__.py +1 -1
  2. wisent/core/activations/activation_cache.py +393 -0
  3. wisent/core/activations/activations.py +3 -3
  4. wisent/core/activations/activations_collector.py +9 -5
  5. wisent/core/activations/classifier_inference_strategy.py +12 -11
  6. wisent/core/activations/extraction_strategy.py +256 -84
  7. wisent/core/classifiers/classifiers/core/atoms.py +3 -2
  8. wisent/core/cli/__init__.py +2 -1
  9. wisent/core/cli/agent/apply_steering.py +5 -7
  10. wisent/core/cli/agent/train_classifier.py +19 -7
  11. wisent/core/cli/check_linearity.py +35 -3
  12. wisent/core/cli/cluster_benchmarks.py +4 -6
  13. wisent/core/cli/create_steering_vector.py +6 -4
  14. wisent/core/cli/diagnose_vectors.py +7 -4
  15. wisent/core/cli/estimate_unified_goodness_time.py +6 -4
  16. wisent/core/cli/generate_pairs_from_task.py +9 -56
  17. wisent/core/cli/geometry_search.py +137 -0
  18. wisent/core/cli/get_activations.py +1 -1
  19. wisent/core/cli/method_optimizer.py +4 -3
  20. wisent/core/cli/modify_weights.py +3 -2
  21. wisent/core/cli/optimize_sample_size.py +1 -1
  22. wisent/core/cli/optimize_steering.py +14 -16
  23. wisent/core/cli/optimize_weights.py +2 -1
  24. wisent/core/cli/preview_pairs.py +203 -0
  25. wisent/core/cli/steering_method_trainer.py +3 -3
  26. wisent/core/cli/tasks.py +19 -76
  27. wisent/core/cli/train_unified_goodness.py +3 -3
  28. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
  29. wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
  30. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
  31. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
  32. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
  33. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
  34. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
  35. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
  36. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
  37. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
  38. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
  39. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
  40. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
  41. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
  42. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
  43. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
  44. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
  45. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
  46. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
  47. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
  48. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
  49. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
  50. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
  51. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
  52. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
  53. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
  54. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
  55. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
  56. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
  57. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
  58. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
  59. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
  60. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
  61. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
  62. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
  63. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
  64. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
  65. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
  66. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
  67. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
  68. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
  69. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
  70. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
  71. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
  72. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
  73. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
  74. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
  75. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
  76. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
  77. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
  78. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
  79. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
  80. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
  81. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
  82. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
  83. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
  84. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +2 -2
  85. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
  86. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
  87. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
  88. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
  89. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
  90. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
  91. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +2 -2
  92. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
  93. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
  94. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
  95. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
  96. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
  102. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
  103. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
  104. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
  105. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
  106. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
  107. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
  108. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
  109. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
  110. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
  111. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
  112. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
  113. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
  114. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
  115. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
  116. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
  117. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
  118. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
  119. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
  120. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
  121. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
  122. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
  123. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
  124. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
  125. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
  126. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
  127. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
  128. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
  129. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
  130. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
  131. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
  132. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
  133. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
  134. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
  135. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
  136. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
  137. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
  138. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
  139. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
  140. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
  141. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
  142. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
  143. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
  144. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
  145. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
  146. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
  147. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
  148. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
  149. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
  150. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
  151. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
  152. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
  153. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
  154. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
  155. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
  156. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
  157. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
  158. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
  159. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
  160. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
  161. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
  162. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
  163. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
  164. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
  165. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
  166. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
  167. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
  168. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
  169. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
  170. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
  171. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
  172. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
  173. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
  174. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
  175. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
  176. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
  177. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
  178. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
  179. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
  180. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
  181. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
  182. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
  183. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
  184. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
  185. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
  186. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
  187. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
  188. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
  189. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
  190. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
  191. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
  192. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
  193. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
  194. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
  195. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
  196. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
  197. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
  198. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
  199. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
  200. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
  201. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
  202. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
  203. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
  204. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
  205. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
  206. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
  207. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
  208. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
  209. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
  210. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
  211. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
  212. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
  213. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
  214. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
  215. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
  216. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
  217. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
  218. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
  219. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
  220. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
  221. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
  222. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
  223. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
  224. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
  225. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
  226. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
  227. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
  228. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
  229. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
  230. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
  231. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
  232. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
  233. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
  234. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
  235. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
  236. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
  237. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
  238. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
  239. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
  240. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
  241. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
  242. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
  243. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
  244. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
  245. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
  246. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
  247. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
  248. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
  249. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
  250. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
  251. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +1 -3
  252. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
  253. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
  254. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
  255. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
  256. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
  257. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
  258. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
  259. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
  260. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
  261. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
  262. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
  263. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
  264. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
  265. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
  266. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
  267. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
  268. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
  269. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
  270. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
  271. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
  272. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
  273. wisent/core/data_loaders/loaders/lm_loader.py +12 -1
  274. wisent/core/geometry_runner.py +995 -0
  275. wisent/core/geometry_search_space.py +237 -0
  276. wisent/core/hyperparameter_optimizer.py +1 -1
  277. wisent/core/main.py +3 -0
  278. wisent/core/models/core/atoms.py +5 -3
  279. wisent/core/models/wisent_model.py +1 -1
  280. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
  281. wisent/core/parser_arguments/check_linearity_parser.py +12 -2
  282. wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
  283. wisent/core/parser_arguments/generate_vector_from_task_parser.py +2 -2
  284. wisent/core/parser_arguments/geometry_search_parser.py +61 -0
  285. wisent/core/parser_arguments/main_parser.py +8 -0
  286. wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
  287. wisent/core/steering.py +5 -3
  288. wisent/core/steering_methods/methods/hyperplane.py +2 -1
  289. wisent/core/synthetic/generators/nonsense_generator.py +30 -18
  290. wisent/core/trainers/steering_trainer.py +2 -2
  291. wisent/core/utils/device.py +27 -27
  292. wisent/core/utils/layer_combinations.py +70 -0
  293. wisent/examples/__init__.py +1 -0
  294. wisent/examples/scripts/__init__.py +1 -0
  295. wisent/examples/scripts/count_all_benchmarks.py +121 -0
  296. wisent/examples/scripts/discover_directions.py +469 -0
  297. wisent/examples/scripts/extract_benchmark_info.py +71 -0
  298. wisent/examples/scripts/generate_paper_data.py +384 -0
  299. wisent/examples/scripts/intervention_validation.py +626 -0
  300. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +324 -0
  301. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +92 -0
  302. wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +324 -0
  303. wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +92 -0
  304. wisent/examples/scripts/results/test_afrimgsm_pairs.json +92 -0
  305. wisent/examples/scripts/results/test_afrimmlu_evaluation.json +324 -0
  306. wisent/examples/scripts/results/test_afrimmlu_pairs.json +92 -0
  307. wisent/examples/scripts/search_all_short_names.py +31 -0
  308. wisent/examples/scripts/test_all_benchmarks.py +138 -0
  309. wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
  310. wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
  311. wisent/examples/scripts/test_nonsense_baseline.py +261 -0
  312. wisent/examples/scripts/test_one_benchmark.py +324 -0
  313. wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
  314. wisent/examples/scripts/threshold_analysis.py +434 -0
  315. wisent/examples/scripts/visualization_gallery.py +582 -0
  316. wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
  317. wisent/parameters/lm_eval/category_directions.json +137 -0
  318. wisent/parameters/lm_eval/repair_plan.json +282 -0
  319. wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
  320. wisent/parameters/lm_eval/working_benchmarks.json +206 -0
  321. wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
  322. wisent/tests/test_detector_accuracy.py +1 -1
  323. wisent/tests/visualize_geometry.py +1 -1
  324. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/METADATA +1 -1
  325. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/RECORD +329 -295
  326. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
  327. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/WHEEL +0 -0
  328. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/entry_points.txt +0 -0
  329. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/licenses/LICENSE +0 -0
  330. {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/top_level.txt +0 -0
@@ -4,15 +4,19 @@ Unified extraction strategies for activation collection.
 
 These strategies combine prompt construction and token extraction into a single
 unified approach, based on empirical testing of what actually works.
 
-The strategies are:
+CHAT STRATEGIES (require chat template - for instruct models):
 - chat_mean: Chat template prompt, mean of answer tokens
 - chat_first: Chat template prompt, first answer token
 - chat_last: Chat template prompt, last token
-- chat_gen_point: Chat template prompt, token before answer (generation decision point)
 - chat_max_norm: Chat template prompt, token with max norm in answer
 - chat_weighted: Chat template prompt, position-weighted mean (earlier tokens weighted more)
 - role_play: "Behave like person who answers Q with A" format, last token
 - mc_balanced: Multiple choice with balanced A/B assignment, last token
+
+BASE MODEL STRATEGIES (no chat template - for base models like gemma-2b, gemma-9b):
+- completion_last: Direct Q+A completion, last token
+- completion_mean: Direct Q+A completion, mean of answer tokens
+- mc_completion: Multiple choice without chat template, A/B token
 """
 
 from enum import Enum
@@ -35,10 +39,7 @@ class ExtractionStrategy(str, Enum):
     """Chat template prompt with Q+A, extract first answer token."""
 
     CHAT_LAST = "chat_last"
-    """Chat template prompt with Q+A, extract last token."""
-
-    CHAT_GEN_POINT = "chat_gen_point"
-    """Chat template prompt with Q+A, extract token before answer starts (decision point)."""
+    """Chat template prompt with Q+A, extract EOT token (has seen full answer)."""
 
     CHAT_MAX_NORM = "chat_max_norm"
     """Chat template prompt with Q+A, extract token with max norm in answer region."""
@@ -47,22 +48,34 @@ class ExtractionStrategy(str, Enum):
     """Chat template prompt with Q+A, position-weighted mean (earlier tokens weighted more)."""
 
     ROLE_PLAY = "role_play"
-    """'Behave like person who answers Q with A' format, extract last token."""
+    """'Behave like person who answers Q with A' format, extract EOT token."""
 
     MC_BALANCED = "mc_balanced"
-    """Multiple choice format with balanced A/B assignment, extract last token."""
+    """Multiple choice format with balanced A/B assignment, extract the A/B choice token."""
+
+    # Base model strategies (no chat template required)
+    COMPLETION_LAST = "completion_last"
+    """Direct Q+A completion without chat template, extract last token. For base models."""
+
+    COMPLETION_MEAN = "completion_mean"
+    """Direct Q+A completion without chat template, extract mean of answer tokens. For base models."""
+
+    MC_COMPLETION = "mc_completion"
+    """Multiple choice without chat template, extract A/B token. For base models."""
 
     @property
     def description(self) -> str:
         descriptions = {
             ExtractionStrategy.CHAT_MEAN: "Chat template with mean of answer tokens",
             ExtractionStrategy.CHAT_FIRST: "Chat template with first answer token",
-            ExtractionStrategy.CHAT_LAST: "Chat template with last token",
-            ExtractionStrategy.CHAT_GEN_POINT: "Chat template with generation decision point",
+            ExtractionStrategy.CHAT_LAST: "Chat template with EOT token",
             ExtractionStrategy.CHAT_MAX_NORM: "Chat template with max-norm answer token",
             ExtractionStrategy.CHAT_WEIGHTED: "Chat template with position-weighted mean",
-            ExtractionStrategy.ROLE_PLAY: "Role-playing format with last token",
-            ExtractionStrategy.MC_BALANCED: "Balanced multiple choice with last token",
+            ExtractionStrategy.ROLE_PLAY: "Role-playing format with EOT token",
+            ExtractionStrategy.MC_BALANCED: "Balanced multiple choice with A/B token",
+            ExtractionStrategy.COMPLETION_LAST: "Direct completion with last token (base models)",
+            ExtractionStrategy.COMPLETION_MEAN: "Direct completion with mean of answer tokens (base models)",
+            ExtractionStrategy.MC_COMPLETION: "Multiple choice completion with A/B token (base models)",
         }
         return descriptions.get(self, "Unknown strategy")
@@ -75,6 +88,77 @@ class ExtractionStrategy(str, Enum):
75
88
  def list_all(cls) -> list[str]:
76
89
  """List all strategy names."""
77
90
  return [s.value for s in cls]
91
+
92
+ @classmethod
93
+ def for_tokenizer(cls, tokenizer, prefer_mc: bool = False) -> "ExtractionStrategy":
94
+ """
95
+ Select the appropriate strategy based on whether tokenizer supports chat template.
96
+
97
+ Args:
98
+ tokenizer: The tokenizer to check
99
+ prefer_mc: If True, prefer multiple choice strategies (mc_balanced/mc_completion)
100
+
101
+ Returns:
102
+ Appropriate strategy for the tokenizer type
103
+ """
104
+ has_chat = hasattr(tokenizer, "apply_chat_template") and callable(getattr(tokenizer, "apply_chat_template"))
105
+
106
+ if has_chat:
107
+ return cls.MC_BALANCED if prefer_mc else cls.CHAT_LAST
108
+ else:
109
+ return cls.MC_COMPLETION if prefer_mc else cls.COMPLETION_LAST
110
+
111
+ @classmethod
112
+ def is_base_model_strategy(cls, strategy: "ExtractionStrategy") -> bool:
113
+ """Check if a strategy is designed for base models (no chat template)."""
114
+ return strategy in (cls.COMPLETION_LAST, cls.COMPLETION_MEAN, cls.MC_COMPLETION)
115
+
116
+ @classmethod
117
+ def get_equivalent_for_model_type(cls, strategy: "ExtractionStrategy", tokenizer) -> "ExtractionStrategy":
118
+ """
119
+ Get the equivalent strategy for the given tokenizer type.
120
+
121
+ If strategy requires chat template but tokenizer doesn't have it,
122
+ returns the base model equivalent. And vice versa.
123
+
124
+ Args:
125
+ strategy: The requested strategy
126
+ tokenizer: The tokenizer to check
127
+
128
+ Returns:
129
+ The appropriate strategy for the tokenizer
130
+ """
131
+ has_chat = hasattr(tokenizer, "apply_chat_template") and callable(getattr(tokenizer, "apply_chat_template"))
132
+ is_base_strategy = cls.is_base_model_strategy(strategy)
133
+
134
+ if has_chat and is_base_strategy:
135
+ # Tokenizer has chat but strategy is for base model - upgrade to chat version
136
+ mapping = {
137
+ cls.COMPLETION_LAST: cls.CHAT_LAST,
138
+ cls.COMPLETION_MEAN: cls.CHAT_MEAN,
139
+ cls.MC_COMPLETION: cls.MC_BALANCED,
140
+ }
141
+ return mapping.get(strategy, strategy)
142
+
143
+ elif not has_chat and not is_base_strategy:
144
+ # Tokenizer is base model but strategy requires chat - downgrade to base version
145
+ mapping = {
146
+ cls.CHAT_LAST: cls.COMPLETION_LAST,
147
+ cls.CHAT_FIRST: cls.COMPLETION_LAST,
148
+ cls.CHAT_MEAN: cls.COMPLETION_MEAN,
149
+ cls.CHAT_MAX_NORM: cls.COMPLETION_LAST,
150
+ cls.CHAT_WEIGHTED: cls.COMPLETION_MEAN,
151
+ cls.ROLE_PLAY: cls.COMPLETION_LAST,
152
+ cls.MC_BALANCED: cls.MC_COMPLETION,
153
+ }
154
+ return mapping.get(strategy, cls.COMPLETION_LAST)
155
+
156
+ return strategy
157
+
158
+
159
+ def tokenizer_has_chat_template(tokenizer) -> bool:
160
+ """Check if tokenizer supports chat template."""
161
+ return hasattr(tokenizer, "apply_chat_template") and callable(getattr(tokenizer, "apply_chat_template"))
78
162
 
79
163
 
80
164
  # Random tokens for role_play strategy (deterministic based on prompt hash)
@@ -88,6 +172,7 @@ def build_extraction_texts(
88
172
  tokenizer,
89
173
  other_response: Optional[str] = None,
90
174
  is_positive: bool = True,
175
+ auto_convert_strategy: bool = True,
91
176
  ) -> Tuple[str, str, Optional[str]]:
92
177
  """
93
178
  Build the full text for activation extraction based on strategy.
@@ -97,8 +182,9 @@ def build_extraction_texts(
97
182
  prompt: The user prompt/question
98
183
  response: The response to extract activations for
99
184
  tokenizer: The tokenizer (needs apply_chat_template for chat strategies)
100
- other_response: For mc_balanced, the other response option
101
- is_positive: For mc_balanced, whether 'response' is the positive option
185
+ other_response: For mc_balanced/mc_completion, the other response option
186
+ is_positive: For mc_balanced/mc_completion, whether 'response' is the positive option
187
+ auto_convert_strategy: If True, automatically convert strategy to match tokenizer type
102
188
 
103
189
  Returns:
104
190
  Tuple of (full_text, answer_text, prompt_only_text)
@@ -106,31 +192,40 @@ def build_extraction_texts(
106
192
  - answer_text: The answer portion (for strategies that need it)
107
193
  - prompt_only_text: Prompt without answer (for boundary detection)
108
194
  """
195
+ # Auto-convert strategy if needed
196
+ if auto_convert_strategy:
197
+ original_strategy = strategy
198
+ strategy = ExtractionStrategy.get_equivalent_for_model_type(strategy, tokenizer)
199
+ if strategy != original_strategy:
200
+ import warnings
201
+ warnings.warn(
202
+ f"Strategy {original_strategy.value} not compatible with tokenizer, "
203
+ f"using {strategy.value} instead.",
204
+ UserWarning
205
+ )
109
206
 
110
207
  if strategy in (ExtractionStrategy.CHAT_MEAN, ExtractionStrategy.CHAT_FIRST,
111
- ExtractionStrategy.CHAT_LAST, ExtractionStrategy.CHAT_GEN_POINT,
112
- ExtractionStrategy.CHAT_MAX_NORM, ExtractionStrategy.CHAT_WEIGHTED):
208
+ ExtractionStrategy.CHAT_LAST, ExtractionStrategy.CHAT_MAX_NORM,
209
+ ExtractionStrategy.CHAT_WEIGHTED):
113
210
  # All chat_* strategies use the same prompt construction
114
- if hasattr(tokenizer, "apply_chat_template"):
115
- try:
116
- prompt_only = tokenizer.apply_chat_template(
117
- [{"role": "user", "content": prompt}],
118
- tokenize=False,
119
- add_generation_prompt=True,
120
- )
121
- full_text = tokenizer.apply_chat_template(
122
- [{"role": "user", "content": prompt},
123
- {"role": "assistant", "content": response}],
124
- tokenize=False,
125
- add_generation_prompt=False,
126
- )
127
- except (ValueError, KeyError):
128
- # Fallback for models without chat template
129
- prompt_only = prompt
130
- full_text = f"{prompt} {response}"
131
- else:
132
- prompt_only = prompt
133
- full_text = f"{prompt} {response}"
211
+ if not hasattr(tokenizer, "apply_chat_template"):
212
+ raise ValueError(
213
+ f"Strategy {strategy.value} requires a tokenizer with apply_chat_template. "
214
+ f"Base models without chat templates are not supported for chat_* strategies. "
215
+ f"Use completion_last, completion_mean, or mc_completion instead."
216
+ )
217
+
218
+ prompt_only = tokenizer.apply_chat_template(
219
+ [{"role": "user", "content": prompt}],
220
+ tokenize=False,
221
+ add_generation_prompt=True,
222
+ )
223
+ full_text = tokenizer.apply_chat_template(
224
+ [{"role": "user", "content": prompt},
225
+ {"role": "assistant", "content": response}],
226
+ tokenize=False,
227
+ add_generation_prompt=False,
228
+ )
134
229
 
135
230
  return full_text, response, prompt_only
136
231
 
@@ -139,25 +234,23 @@ def build_extraction_texts(
139
234
  random_token = ROLE_PLAY_TOKENS[hash(prompt) % len(ROLE_PLAY_TOKENS)]
140
235
  instruction = f"Behave like a person that would answer {prompt} with {response}"
141
236
 
142
- if hasattr(tokenizer, "apply_chat_template"):
143
- try:
144
- prompt_only = tokenizer.apply_chat_template(
145
- [{"role": "user", "content": instruction}],
146
- tokenize=False,
147
- add_generation_prompt=True,
148
- )
149
- full_text = tokenizer.apply_chat_template(
150
- [{"role": "user", "content": instruction},
151
- {"role": "assistant", "content": random_token}],
152
- tokenize=False,
153
- add_generation_prompt=False,
154
- )
155
- except (ValueError, KeyError):
156
- prompt_only = instruction
157
- full_text = f"{instruction} {random_token}"
158
- else:
159
- prompt_only = instruction
160
- full_text = f"{instruction} {random_token}"
237
+ if not hasattr(tokenizer, "apply_chat_template"):
238
+ raise ValueError(
239
+ f"Strategy {strategy.value} requires a tokenizer with apply_chat_template. "
240
+ f"Use completion_last or mc_completion for base models."
241
+ )
242
+
243
+ prompt_only = tokenizer.apply_chat_template(
244
+ [{"role": "user", "content": instruction}],
245
+ tokenize=False,
246
+ add_generation_prompt=True,
247
+ )
248
+ full_text = tokenizer.apply_chat_template(
249
+ [{"role": "user", "content": instruction},
250
+ {"role": "assistant", "content": random_token}],
251
+ tokenize=False,
252
+ add_generation_prompt=False,
253
+ )
161
254
 
162
255
  return full_text, random_token, prompt_only
163
256
 
@@ -188,28 +281,66 @@ def build_extraction_texts(
188
281
  option_b = response[:200] # negative
189
282
  answer = "B"
190
283
 
191
- mc_prompt = f"Which is correct?\nA. {option_a}\nB. {option_b}\nAnswer:"
284
+ mc_prompt = f"Question: {prompt}\n\nWhich is correct?\nA. {option_a}\nB. {option_b}\nAnswer:"
285
+
286
+ if not hasattr(tokenizer, "apply_chat_template"):
287
+ raise ValueError(
288
+ f"Strategy {strategy.value} requires a tokenizer with apply_chat_template. "
289
+ f"Use mc_completion for base models."
290
+ )
291
+
292
+ prompt_only = tokenizer.apply_chat_template(
293
+ [{"role": "user", "content": mc_prompt}],
294
+ tokenize=False,
295
+ add_generation_prompt=True,
296
+ )
297
+ full_text = tokenizer.apply_chat_template(
298
+ [{"role": "user", "content": mc_prompt},
299
+ {"role": "assistant", "content": answer}],
300
+ tokenize=False,
301
+ add_generation_prompt=False,
302
+ )
303
+
304
+ return full_text, answer, prompt_only
305
+
306
+ elif strategy in (ExtractionStrategy.COMPLETION_LAST, ExtractionStrategy.COMPLETION_MEAN):
307
+ # Base model strategies - direct Q+A without chat template
308
+ # Format: "Q: {prompt}\nA: {response}"
309
+ prompt_only = f"Q: {prompt}\nA:"
310
+ full_text = f"Q: {prompt}\nA: {response}"
311
+ return full_text, response, prompt_only
312
+
313
+ elif strategy == ExtractionStrategy.MC_COMPLETION:
314
+ # Multiple choice for base models - no chat template
315
+ if other_response is None:
316
+ raise ValueError("MC_COMPLETION strategy requires other_response")
317
+
318
+ # Deterministic "random" based on prompt - same for both pos and neg of a pair
319
+ pos_goes_in_b = hash(prompt) % 2 == 0
192
320
 
193
- if hasattr(tokenizer, "apply_chat_template"):
194
- try:
195
- prompt_only = tokenizer.apply_chat_template(
196
- [{"role": "user", "content": mc_prompt}],
197
- tokenize=False,
198
- add_generation_prompt=True,
199
- )
200
- full_text = tokenizer.apply_chat_template(
201
- [{"role": "user", "content": mc_prompt},
202
- {"role": "assistant", "content": answer}],
203
- tokenize=False,
204
- add_generation_prompt=False,
205
- )
206
- except (ValueError, KeyError):
207
- prompt_only = mc_prompt
208
- full_text = f"{mc_prompt} {answer}"
321
+ if is_positive:
322
+ if pos_goes_in_b:
323
+ option_a = other_response[:200]
324
+ option_b = response[:200]
325
+ answer = "B"
326
+ else:
327
+ option_a = response[:200]
328
+ option_b = other_response[:200]
329
+ answer = "A"
209
330
  else:
210
- prompt_only = mc_prompt
211
- full_text = f"{mc_prompt} {answer}"
331
+ if pos_goes_in_b:
332
+ option_a = response[:200]
333
+ option_b = other_response[:200]
334
+ answer = "A"
335
+ else:
336
+ option_a = other_response[:200]
337
+ option_b = response[:200]
338
+ answer = "B"
212
339
 
340
+ mc_prompt = f"Question: {prompt}\n\nWhich is correct?\nA. {option_a}\nB. {option_b}\nAnswer:"
341
+
342
+ prompt_only = mc_prompt
343
+ full_text = f"{mc_prompt} {answer}"
213
344
  return full_text, answer, prompt_only
214
345
 
215
346
  else:
@@ -243,6 +374,7 @@ def extract_activation(
243
374
  num_answer_tokens = len(answer_tokens)
244
375
 
245
376
  if strategy == ExtractionStrategy.CHAT_LAST:
377
+ # EOT token - has seen the entire answer, best performance
246
378
  return hidden_states[-1]
247
379
 
248
380
  elif strategy == ExtractionStrategy.CHAT_FIRST:
@@ -257,11 +389,6 @@ def extract_activation(
257
389
  return answer_hidden.mean(dim=0)
258
390
  return hidden_states[-1]
259
391
 
260
- elif strategy == ExtractionStrategy.CHAT_GEN_POINT:
261
- # Last token before answer starts (decision point)
262
- gen_point_idx = max(0, seq_len - num_answer_tokens - 2)
263
- return hidden_states[gen_point_idx]
264
-
265
392
  elif strategy == ExtractionStrategy.CHAT_MAX_NORM:
266
393
  # Token with max norm in answer region
267
394
  if num_answer_tokens > 0 and seq_len > num_answer_tokens:
@@ -275,18 +402,36 @@ def extract_activation(
275
402
  # Position-weighted mean (earlier tokens weighted more)
276
403
  if num_answer_tokens > 0 and seq_len > num_answer_tokens:
277
404
  answer_hidden = hidden_states[-num_answer_tokens-1:-1]
278
- weights = torch.exp(-torch.arange(answer_hidden.shape[0], dtype=torch.float32, device=answer_hidden.device) * 0.5)
405
+ weights = torch.exp(-torch.arange(answer_hidden.shape[0], dtype=answer_hidden.dtype, device=answer_hidden.device) * 0.5)
279
406
  weights = weights / weights.sum()
280
407
  return (answer_hidden * weights.unsqueeze(1)).sum(dim=0)
281
408
  return hidden_states[-1]
282
409
 
283
- elif strategy in (ExtractionStrategy.ROLE_PLAY, ExtractionStrategy.MC_BALANCED):
284
- # Both use last token
410
+ elif strategy == ExtractionStrategy.ROLE_PLAY:
411
+ # EOT token - slightly better than answer word (65% vs 64%)
285
412
  return hidden_states[-1]
286
413
 
287
- else:
288
- # Default fallback
414
+ elif strategy == ExtractionStrategy.MC_BALANCED:
415
+ # Answer token (A/B) - better than EOT (64% vs 56%)
416
+ return hidden_states[-2]
417
+
418
+ elif strategy == ExtractionStrategy.COMPLETION_LAST:
419
+ # Last token for base model completion
289
420
  return hidden_states[-1]
421
+
422
+ elif strategy == ExtractionStrategy.COMPLETION_MEAN:
423
+ # Mean of answer tokens for base model completion
424
+ if num_answer_tokens > 0 and seq_len > num_answer_tokens:
425
+ answer_hidden = hidden_states[-num_answer_tokens:]
426
+ return answer_hidden.mean(dim=0)
427
+ return hidden_states[-1]
428
+
429
+ elif strategy == ExtractionStrategy.MC_COMPLETION:
430
+ # A/B token for base model MC (last token is the answer)
431
+ return hidden_states[-1]
432
+
433
+ else:
434
+ raise ValueError(f"Unknown extraction strategy: {strategy}")
290
435
 
291
436
 
292
437
  def add_extraction_strategy_args(parser: argparse.ArgumentParser) -> None:
@@ -306,3 +451,30 @@ def add_extraction_strategy_args(parser: argparse.ArgumentParser) -> None:
306
451
  choices=ExtractionStrategy.list_all(),
307
452
  help=f"Extraction strategy for activations. Options: {', '.join(ExtractionStrategy.list_all())}. Default: {ExtractionStrategy.default().value}",
308
453
  )
454
+
455
+
456
+ def get_strategy_for_model(tokenizer, prefer_mc: bool = False) -> ExtractionStrategy:
457
+ """
458
+ Get the best extraction strategy for a given tokenizer.
459
+
460
+ Automatically detects if tokenizer has chat template and returns
461
+ the appropriate strategy.
462
+
463
+ Args:
464
+ tokenizer: The tokenizer to check
465
+ prefer_mc: If True, prefer multiple choice strategies
466
+
467
+ Returns:
468
+ ExtractionStrategy appropriate for the tokenizer
469
+
470
+ Example:
471
+ >>> from transformers import AutoTokenizer
472
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
473
+ >>> strategy = get_strategy_for_model(tokenizer)
474
+ >>> print(strategy) # completion_last (base model)
475
+
476
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
477
+ >>> strategy = get_strategy_for_model(tokenizer)
478
+ >>> print(strategy) # chat_last (instruct model)
479
+ """
480
+ return ExtractionStrategy.for_tokenizer(tokenizer, prefer_mc=prefer_mc)
@@ -14,6 +14,7 @@ import numpy as np
14
14
 
15
15
  from torch.nn.modules.loss import _Loss
16
16
  from wisent.core.errors import DuplicateNameError, InvalidRangeError, UnknownTypeError
17
+ from wisent.core.utils.device import preferred_dtype
17
18
 
18
19
  __all__ = [
19
20
  "ClassifierTrainConfig",
@@ -164,13 +165,13 @@ class BaseClassifier(ABC):
164
165
  self,
165
166
  threshold: float = 0.5,
166
167
  device: str | None = None,
167
- dtype: torch.dtype = torch.float32,
168
+ dtype: torch.dtype | None = None,
168
169
  ) -> None:
169
170
  if not 0.0 <= threshold <= 1.0:
170
171
  raise InvalidRangeError(param_name="threshold", actual=threshold, min_val=0.0, max_val=1.0)
171
172
  self.threshold = threshold
172
173
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
173
- self.dtype = torch.float32 if self.device == "mps" else dtype
174
+ self.dtype = dtype if dtype is not None else preferred_dtype(self.device)
174
175
  self.model = None
175
176
 
176
177
  @abstractmethod
@@ -22,5 +22,6 @@ from .inference_config_cli import execute_inference_config
22
22
  from .optimization_cache import execute_optimization_cache
23
23
  from .optimize_weights import execute_optimize_weights
24
24
  from .optimize import execute_optimize
25
+ from .geometry_search import execute_geometry_search
25
26
 
26
- __all__ = ['execute_tasks', 'execute_generate_pairs_from_task', 'execute_generate_pairs', 'execute_diagnose_pairs', 'execute_get_activations', 'execute_diagnose_vectors', 'execute_create_steering_vector', 'execute_generate_vector_from_task', 'execute_generate_vector_from_synthetic', 'execute_optimize_classification', 'execute_optimize_steering', 'execute_optimize_sample_size', 'execute_generate_responses', 'execute_evaluate_responses', 'execute_multi_steer', 'execute_agent', 'execute_modify_weights', 'execute_evaluate_refusal', 'execute_inference_config', 'execute_optimization_cache', 'execute_optimize_weights', 'execute_optimize']
27
+ __all__ = ['execute_tasks', 'execute_generate_pairs_from_task', 'execute_generate_pairs', 'execute_diagnose_pairs', 'execute_get_activations', 'execute_diagnose_vectors', 'execute_create_steering_vector', 'execute_generate_vector_from_task', 'execute_generate_vector_from_synthetic', 'execute_optimize_classification', 'execute_optimize_steering', 'execute_optimize_sample_size', 'execute_generate_responses', 'execute_evaluate_responses', 'execute_multi_steer', 'execute_agent', 'execute_modify_weights', 'execute_evaluate_refusal', 'execute_inference_config', 'execute_optimization_cache', 'execute_optimize_weights', 'execute_optimize', 'execute_geometry_search']
@@ -19,7 +19,7 @@ def _map_token_aggregation(aggregation_str: str):
19
19
 
20
20
  def _map_prompt_strategy(strategy_str: str):
21
21
  """Map string prompt strategy to ExtractionStrategy."""
22
-
22
+ from wisent.core.activations.extraction_strategy import ExtractionStrategy
23
23
 
24
24
  mapping = {
25
25
  "chat_template": ExtractionStrategy.CHAT_LAST,
@@ -111,9 +111,8 @@ def apply_steering_and_evaluate(
111
111
 
112
112
  updated_pair = collector.collect(
113
113
  pair, strategy=aggregation_strategy,
114
- return_full_sequence=return_full_sequence,
115
- normalize_layers=normalize_layers,
116
- prompt_strategy=prompt_construction_strategy
114
+ layers=target_layers,
115
+ normalize=normalize_layers
117
116
  )
118
117
  enriched_pairs.append(updated_pair)
119
118
 
@@ -174,9 +173,8 @@ def apply_steering_and_evaluate(
174
173
 
175
174
  steered_evaluated_pair = collector.collect(
176
175
  steered_dummy_pair, strategy=aggregation_strategy,
177
- return_full_sequence=return_full_sequence,
178
- normalize_layers=normalize_layers,
179
- prompt_strategy=prompt_construction_strategy
176
+ layers=target_layers,
177
+ normalize=normalize_layers
180
178
  )
181
179
 
182
180
  steered_quality = 0.0
@@ -1,8 +1,20 @@
1
1
  """Train classifier on contrastive pairs for agent."""
2
2
 
3
3
  import numpy as np
4
+ import torch
4
5
  from wisent.core.classifiers.classifiers.core.atoms import ClassifierTrainReport
5
6
  from wisent.core.errors import UnknownTypeError
7
+ from wisent.core.utils.device import preferred_dtype
8
+
9
+
10
+ def _torch_dtype_to_numpy(torch_dtype: torch.dtype):
11
+ """Convert torch dtype to numpy dtype."""
12
+ mapping = {
13
+ torch.float32: np.float32,
14
+ torch.float16: np.float16,
15
+ torch.bfloat16: np.float32, # numpy doesn't support bfloat16, use float32
16
+ }
17
+ return mapping.get(torch_dtype, np.float32)
6
18
 
7
19
 
8
20
  def _map_token_aggregation(aggregation_str: str):
@@ -21,7 +33,7 @@ def _map_token_aggregation(aggregation_str: str):
21
33
 
22
34
  def _map_prompt_strategy(strategy_str: str):
23
35
  """Map string prompt strategy to ExtractionStrategy."""
24
-
36
+ from wisent.core.activations.extraction_strategy import ExtractionStrategy
25
37
 
26
38
  mapping = {
27
39
  "chat_template": ExtractionStrategy.CHAT_LAST,
@@ -97,7 +109,7 @@ def train_classifier_on_pairs(
97
109
  prompt_construction_strategy = _map_prompt_strategy(prompt_strategy)
98
110
 
99
111
  # Collect activations for all pairs
100
- collector = ActivationCollector(model=model, store_device="cpu")
112
+ collector = ActivationCollector(model=model)
101
113
  target_layers = [str(target_layer)]
102
114
  layer_key = target_layers[0]
103
115
 
@@ -108,9 +120,8 @@ def train_classifier_on_pairs(
108
120
 
109
121
  updated_pair = collector.collect(
110
122
  pair, strategy=aggregation_strategy,
111
- return_full_sequence=return_full_sequence,
112
- normalize_layers=normalize_layers,
113
- prompt_strategy=prompt_construction_strategy
123
+ layers=[str(target_layer)],
124
+ normalize=normalize_layers
114
125
  )
115
126
  enriched_training_pairs.append(updated_pair)
116
127
 
@@ -133,8 +144,9 @@ def train_classifier_on_pairs(
133
144
  X_list.append(neg_act.cpu().numpy())
134
145
  y_list.append(0.0)
135
146
 
136
- X_train = np.array(X_list, dtype=np.float32)
137
- y_train = np.array(y_list, dtype=np.float32)
147
+ np_dtype = _torch_dtype_to_numpy(preferred_dtype())
148
+ X_train = np.array(X_list, dtype=np_dtype)
149
+ y_train = np.array(y_list, dtype=np_dtype)
138
150
 
139
151
  print(f" Training data: {X_train.shape[0]} samples, {X_train.shape[1]} features")
140
152
 
@@ -31,6 +31,7 @@ def execute_check_linearity(args):
31
31
  from wisent.core.models.wisent_model import WisentModel
32
32
  from wisent.core.contrastive_pairs.core.pair import ContrastivePair
33
33
  from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse
34
+ from wisent.core.activations.extraction_strategy import ExtractionStrategy
34
35
 
35
36
  # Build ContrastivePair objects
36
37
  pairs = []
@@ -72,6 +73,10 @@ def execute_check_linearity(args):
72
73
  if args.layers:
73
74
  config.layers_to_test = [int(l) for l in args.layers.split(',')]
74
75
 
76
+ if args.extraction_strategy:
77
+ config.extraction_strategies = [ExtractionStrategy(args.extraction_strategy)]
78
+ print(f"Using extraction strategy: {args.extraction_strategy}")
79
+
75
80
  # Run check
76
81
  print("\nRunning linearity check...")
77
82
  result = check_linearity(pairs, model, config)
@@ -110,12 +115,39 @@ def execute_check_linearity(args):
110
115
 
111
116
  sorted_results = sorted(result.all_results, key=lambda x: x['linear_score'], reverse=True)
112
117
 
113
- print(f"{'Linear':<8} {'d':<8} {'Layer':<6} {'Prompt':<25} {'Aggregation':<15} {'Norm'}")
114
- print("-" * 80)
118
+ print(f"{'Linear':<8} {'d':<8} {'Layer':<6} {'Strategy':<20} {'Structure':<12} {'Norm'}")
119
+ print("-" * 70)
115
120
 
116
121
  for r in sorted_results[:20]:
117
122
  print(f"{r['linear_score']:<8.3f} {r['cohens_d']:<8.2f} {r['layer']:<6} "
118
- f"{r['prompt_strategy']:<25} {r['aggregation']:<15} {r['normalize']}")
123
+ f"{r['extraction_strategy']:<20} {r['best_structure']:<12} {r['normalize']}")
124
+
125
+ # Show best result for each structure type
126
+ if sorted_results and 'all_structure_scores' in sorted_results[0]:
127
+ print(f"\n{'='*60}")
128
+ print("BEST SCORE PER STRUCTURE TYPE")
129
+ print(f"{'='*60}")
130
+
131
+ # Collect best score for each structure across all configs
132
+ best_per_structure = {}
133
+ for r in result.all_results:
134
+ if 'all_structure_scores' not in r:
135
+ continue
136
+ for struct_name, data in r['all_structure_scores'].items():
137
+ score = data['score']
138
+ if struct_name not in best_per_structure or score > best_per_structure[struct_name]['score']:
139
+ best_per_structure[struct_name] = {
140
+ 'score': score,
141
+ 'confidence': data['confidence'],
142
+ 'layer': r['layer'],
143
+ 'strategy': r['extraction_strategy'],
144
+ }
145
+
146
+ print(f"{'Structure':<12} {'Score':<8} {'Conf':<8} {'Layer':<6} {'Strategy'}")
147
+ print("-" * 55)
148
+ sorted_structs = sorted(best_per_structure.items(), key=lambda x: x[1]['score'], reverse=True)
149
+ for name, data in sorted_structs:
150
+ print(f"{name:<12} {data['score']:<8.3f} {data['confidence']:<8.3f} {data['layer']:<6} {data['strategy']}")
119
151
 
120
152
  # Exit code based on verdict
121
153
  if result.verdict.value == "linear":