wisent 0.7.701__py3-none-any.whl → 0.7.1045__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (391)
  1. wisent/__init__.py +1 -1
  2. wisent/comparison/__init__.py +1 -0
  3. wisent/comparison/detect_bos_features.py +275 -0
  4. wisent/comparison/fgaa.py +465 -0
  5. wisent/comparison/lora.py +669 -0
  6. wisent/comparison/lora_dpo.py +592 -0
  7. wisent/comparison/main.py +444 -0
  8. wisent/comparison/ours.py +76 -0
  9. wisent/comparison/sae.py +304 -0
  10. wisent/comparison/utils.py +381 -0
  11. wisent/core/activations/activation_cache.py +393 -0
  12. wisent/core/activations/activations.py +3 -3
  13. wisent/core/activations/activations_collector.py +12 -7
  14. wisent/core/activations/classifier_inference_strategy.py +12 -11
  15. wisent/core/activations/extraction_strategy.py +260 -84
  16. wisent/core/classifiers/classifiers/core/atoms.py +3 -2
  17. wisent/core/cli/__init__.py +2 -1
  18. wisent/core/cli/agent/train_classifier.py +16 -3
  19. wisent/core/cli/check_linearity.py +35 -3
  20. wisent/core/cli/cluster_benchmarks.py +4 -6
  21. wisent/core/cli/create_steering_vector.py +6 -4
  22. wisent/core/cli/diagnose_vectors.py +7 -4
  23. wisent/core/cli/estimate_unified_goodness_time.py +6 -4
  24. wisent/core/cli/generate_pairs_from_task.py +9 -56
  25. wisent/core/cli/generate_vector_from_task.py +11 -20
  26. wisent/core/cli/geometry_search.py +137 -0
  27. wisent/core/cli/get_activations.py +2 -2
  28. wisent/core/cli/method_optimizer.py +4 -3
  29. wisent/core/cli/modify_weights.py +3 -2
  30. wisent/core/cli/optimize_sample_size.py +1 -1
  31. wisent/core/cli/optimize_steering.py +14 -16
  32. wisent/core/cli/optimize_weights.py +2 -1
  33. wisent/core/cli/preview_pairs.py +203 -0
  34. wisent/core/cli/steering_method_trainer.py +3 -3
  35. wisent/core/cli/tasks.py +19 -76
  36. wisent/core/cli/train_unified_goodness.py +3 -3
  37. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
  38. wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
  39. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
  40. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
  41. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
  42. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
  43. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
  44. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
  45. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
  46. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
  47. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
  48. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
  49. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
  50. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
  51. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
  52. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
  53. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
  54. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
  55. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
  56. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
  57. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
  58. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
  59. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
  60. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
  61. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
  62. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
  63. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
  64. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
  65. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
  66. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
  67. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
  68. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
  69. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
  70. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
  71. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
  72. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
  73. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
  74. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
  75. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
  76. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
  77. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
  78. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
  79. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
  80. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
  81. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
  82. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
  83. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
  84. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
  85. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
  86. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
  87. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
  88. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
  89. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
  90. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
  91. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
  92. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
  93. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +22 -5
  94. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
  95. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
  96. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +10 -3
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
  102. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
  103. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
  104. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
  105. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
  106. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
  107. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
  108. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
  109. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
  110. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
  111. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
  112. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
  113. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
  114. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
  115. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
  116. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
  117. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
  118. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
  119. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
  120. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
  121. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
  122. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
  123. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
  124. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
  125. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
  126. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
  127. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
  128. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
  129. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
  130. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
  131. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
  132. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
  133. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
  134. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
  135. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
  136. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
  137. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
  138. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
  139. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
  140. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
  141. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
  142. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
  143. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
  144. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
  145. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
  146. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
  147. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
  148. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
  149. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
  150. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
  151. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
  152. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
  153. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
  154. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
  155. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
  156. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
  157. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
  158. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
  159. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
  160. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
  161. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
  162. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
  163. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
  164. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
  165. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
  166. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
  167. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
  168. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
  169. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
  170. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
  171. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
  172. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
  173. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
  174. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
  175. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
  176. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
  177. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
  178. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
  179. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
  180. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
  181. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
  182. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
  183. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
  184. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
  185. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
  186. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
  187. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
  188. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
  189. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
  190. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
  191. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
  192. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
  193. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
  194. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
  195. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
  196. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
  197. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
  198. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
  199. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
  200. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
  201. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
  202. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
  203. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
  204. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
  205. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
  206. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
  207. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
  208. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
  209. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
  210. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
  211. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
  212. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
  213. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
  214. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
  215. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
  216. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
  217. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
  218. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
  219. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
  220. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
  221. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
  222. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
  223. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
  224. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
  225. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
  226. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
  227. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
  228. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
  229. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
  230. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
  231. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
  232. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
  233. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
  234. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
  235. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
  236. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
  237. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
  238. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
  239. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
  240. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
  241. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
  242. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
  243. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
  244. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
  245. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
  246. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
  247. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
  248. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
  249. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
  250. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
  251. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
  252. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
  253. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
  254. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
  255. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
  256. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
  257. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
  258. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
  259. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
  260. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +9 -4
  261. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
  262. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
  263. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
  264. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
  265. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
  266. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
  267. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
  268. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
  269. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
  270. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
  271. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
  272. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
  273. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
  274. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
  275. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
  276. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
  277. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
  278. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
  279. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
  280. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
  281. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
  282. wisent/core/data_loaders/loaders/lm_loader.py +12 -1
  283. wisent/core/geometry_runner.py +995 -0
  284. wisent/core/geometry_search_space.py +237 -0
  285. wisent/core/hyperparameter_optimizer.py +1 -1
  286. wisent/core/main.py +3 -0
  287. wisent/core/models/core/atoms.py +5 -3
  288. wisent/core/models/wisent_model.py +1 -1
  289. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
  290. wisent/core/parser_arguments/check_linearity_parser.py +12 -2
  291. wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
  292. wisent/core/parser_arguments/generate_vector_from_task_parser.py +6 -13
  293. wisent/core/parser_arguments/geometry_search_parser.py +61 -0
  294. wisent/core/parser_arguments/get_activations_parser.py +5 -14
  295. wisent/core/parser_arguments/main_parser.py +8 -0
  296. wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
  297. wisent/core/steering.py +5 -3
  298. wisent/core/steering_methods/methods/hyperplane.py +2 -1
  299. wisent/core/synthetic/generators/nonsense_generator.py +30 -18
  300. wisent/core/trainers/steering_trainer.py +2 -2
  301. wisent/core/utils/device.py +27 -27
  302. wisent/core/utils/layer_combinations.py +70 -0
  303. wisent/examples/__init__.py +1 -0
  304. wisent/examples/scripts/__init__.py +1 -0
  305. wisent/examples/scripts/count_all_benchmarks.py +121 -0
  306. wisent/examples/scripts/discover_directions.py +469 -0
  307. wisent/examples/scripts/extract_benchmark_info.py +71 -0
  308. wisent/examples/scripts/search_all_short_names.py +31 -0
  309. wisent/examples/scripts/test_all_benchmarks.py +138 -0
  310. wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
  311. wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
  312. wisent/examples/scripts/test_nonsense_baseline.py +261 -0
  313. wisent/examples/scripts/test_one_benchmark.py +324 -0
  314. wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
  315. wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
  316. wisent/parameters/lm_eval/category_directions.json +137 -0
  317. wisent/parameters/lm_eval/repair_plan.json +282 -0
  318. wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
  319. wisent/parameters/lm_eval/working_benchmarks.json +206 -0
  320. wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
  321. wisent/tests/test_detector_accuracy.py +1 -1
  322. wisent/tests/visualize_geometry.py +1 -1
  323. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
  324. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/RECORD +328 -358
  325. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
  326. wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
  327. wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
  328. wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
  329. wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
  330. wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
  331. wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
  332. wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
  333. wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
  334. wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
  335. wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
  336. wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
  337. wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
  338. wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
  339. wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
  340. wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
  341. wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
  342. wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
  343. wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
  344. wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
  345. wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
  346. wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
  347. wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
  348. wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
  349. wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
  350. wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
  351. wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
  352. wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
  353. wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
  354. wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
  355. wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
  356. wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
  357. wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
  358. wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
  359. wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
  360. wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
  361. wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
  362. wisent/examples/scripts/1/test_cola_pairs.json +0 -8
  363. wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
  364. wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
  365. wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
  366. wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
  367. wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
  368. wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
  369. wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
  370. wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
  371. wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
  372. wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
  373. wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
  374. wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
  375. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
  376. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
  377. wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
  378. wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
  379. wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
  380. wisent/examples/scripts/2/test_atis_pairs.json +0 -8
  381. wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
  382. wisent/examples/scripts/2/test_babi_pairs.json +0 -8
  383. wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
  384. wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
  385. wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
  386. wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
  387. wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
  388. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
  389. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
  390. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
  391. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
@@ -4,15 +4,19 @@ Unified extraction strategies for activation collection.
4
4
  These strategies combine prompt construction and token extraction into a single
5
5
  unified approach, based on empirical testing of what actually works.
6
6
 
7
- The strategies are:
7
+ CHAT STRATEGIES (require chat template - for instruct models):
8
8
  - chat_mean: Chat template prompt, mean of answer tokens
9
9
  - chat_first: Chat template prompt, first answer token
10
10
  - chat_last: Chat template prompt, last token
11
- - chat_gen_point: Chat template prompt, token before answer (generation decision point)
12
11
  - chat_max_norm: Chat template prompt, token with max norm in answer
13
12
  - chat_weighted: Chat template prompt, position-weighted mean (earlier tokens weighted more)
14
13
  - role_play: "Behave like person who answers Q with A" format, last token
15
14
  - mc_balanced: Multiple choice with balanced A/B assignment, last token
15
+
16
+ BASE MODEL STRATEGIES (no chat template - for base models like gemma-2b, gemma-9b):
17
+ - completion_last: Direct Q+A completion, last token
18
+ - completion_mean: Direct Q+A completion, mean of answer tokens
19
+ - mc_completion: Multiple choice without chat template, A/B token
16
20
  """
17
21
 
18
22
  from enum import Enum
@@ -35,10 +39,7 @@ class ExtractionStrategy(str, Enum):
35
39
  """Chat template prompt with Q+A, extract first answer token."""
36
40
 
37
41
  CHAT_LAST = "chat_last"
38
- """Chat template prompt with Q+A, extract last token."""
39
-
40
- CHAT_GEN_POINT = "chat_gen_point"
41
- """Chat template prompt with Q+A, extract token before answer starts (decision point)."""
42
+ """Chat template prompt with Q+A, extract EOT token (has seen full answer)."""
42
43
 
43
44
  CHAT_MAX_NORM = "chat_max_norm"
44
45
  """Chat template prompt with Q+A, extract token with max norm in answer region."""
@@ -47,22 +48,34 @@ class ExtractionStrategy(str, Enum):
47
48
  """Chat template prompt with Q+A, position-weighted mean (earlier tokens weighted more)."""
48
49
 
49
50
  ROLE_PLAY = "role_play"
50
- """'Behave like person who answers Q with A' format, extract last token."""
51
+ """'Behave like person who answers Q with A' format, extract EOT token."""
51
52
 
52
53
  MC_BALANCED = "mc_balanced"
53
- """Multiple choice format with balanced A/B assignment, extract last token."""
54
+ """Multiple choice format with balanced A/B assignment, extract the A/B choice token."""
55
+
56
+ # Base model strategies (no chat template required)
57
+ COMPLETION_LAST = "completion_last"
58
+ """Direct Q+A completion without chat template, extract last token. For base models."""
59
+
60
+ COMPLETION_MEAN = "completion_mean"
61
+ """Direct Q+A completion without chat template, extract mean of answer tokens. For base models."""
62
+
63
+ MC_COMPLETION = "mc_completion"
64
+ """Multiple choice without chat template, extract A/B token. For base models."""
54
65
 
55
66
  @property
56
67
  def description(self) -> str:
57
68
  descriptions = {
58
69
  ExtractionStrategy.CHAT_MEAN: "Chat template with mean of answer tokens",
59
70
  ExtractionStrategy.CHAT_FIRST: "Chat template with first answer token",
60
- ExtractionStrategy.CHAT_LAST: "Chat template with last token",
61
- ExtractionStrategy.CHAT_GEN_POINT: "Chat template with generation decision point",
71
+ ExtractionStrategy.CHAT_LAST: "Chat template with EOT token",
62
72
  ExtractionStrategy.CHAT_MAX_NORM: "Chat template with max-norm answer token",
63
73
  ExtractionStrategy.CHAT_WEIGHTED: "Chat template with position-weighted mean",
64
- ExtractionStrategy.ROLE_PLAY: "Role-playing format with last token",
65
- ExtractionStrategy.MC_BALANCED: "Balanced multiple choice with last token",
74
+ ExtractionStrategy.ROLE_PLAY: "Role-playing format with EOT token",
75
+ ExtractionStrategy.MC_BALANCED: "Balanced multiple choice with A/B token",
76
+ ExtractionStrategy.COMPLETION_LAST: "Direct completion with last token (base models)",
77
+ ExtractionStrategy.COMPLETION_MEAN: "Direct completion with mean of answer tokens (base models)",
78
+ ExtractionStrategy.MC_COMPLETION: "Multiple choice completion with A/B token (base models)",
66
79
  }
67
80
  return descriptions.get(self, "Unknown strategy")
68
81
 
@@ -75,6 +88,81 @@ class ExtractionStrategy(str, Enum):
75
88
  def list_all(cls) -> list[str]:
76
89
  """List all strategy names."""
77
90
  return [s.value for s in cls]
91
+
92
+ @classmethod
93
+ def for_tokenizer(cls, tokenizer, prefer_mc: bool = False) -> "ExtractionStrategy":
94
+ """
95
+ Select the appropriate strategy based on whether tokenizer supports chat template.
96
+
97
+ Args:
98
+ tokenizer: The tokenizer to check
99
+ prefer_mc: If True, prefer multiple choice strategies (mc_balanced/mc_completion)
100
+
101
+ Returns:
102
+ Appropriate strategy for the tokenizer type
103
+ """
104
+ has_chat = (hasattr(tokenizer, "apply_chat_template") and callable(getattr(tokenizer, "apply_chat_template"))
105
+ and hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None)
106
+
107
+ if has_chat:
108
+ return cls.MC_BALANCED if prefer_mc else cls.CHAT_LAST
109
+ else:
110
+ return cls.MC_COMPLETION if prefer_mc else cls.COMPLETION_LAST
111
+
112
+ @classmethod
113
+ def is_base_model_strategy(cls, strategy: "ExtractionStrategy") -> bool:
114
+ """Check if a strategy is designed for base models (no chat template)."""
115
+ return strategy in (cls.COMPLETION_LAST, cls.COMPLETION_MEAN, cls.MC_COMPLETION)
116
+
117
+ @classmethod
118
+ def get_equivalent_for_model_type(cls, strategy: "ExtractionStrategy", tokenizer) -> "ExtractionStrategy":
119
+ """
120
+ Get the equivalent strategy for the given tokenizer type.
121
+
122
+ If strategy requires chat template but tokenizer doesn't have it,
123
+ returns the base model equivalent. And vice versa.
124
+
125
+ Args:
126
+ strategy: The requested strategy
127
+ tokenizer: The tokenizer to check
128
+
129
+ Returns:
130
+ The appropriate strategy for the tokenizer
131
+ """
132
+ has_chat = (hasattr(tokenizer, "apply_chat_template") and callable(getattr(tokenizer, "apply_chat_template"))
133
+ and hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None)
134
+ is_base_strategy = cls.is_base_model_strategy(strategy)
135
+
136
+ if has_chat and is_base_strategy:
137
+ # Tokenizer has chat but strategy is for base model - upgrade to chat version
138
+ mapping = {
139
+ cls.COMPLETION_LAST: cls.CHAT_LAST,
140
+ cls.COMPLETION_MEAN: cls.CHAT_MEAN,
141
+ cls.MC_COMPLETION: cls.MC_BALANCED,
142
+ }
143
+ return mapping.get(strategy, strategy)
144
+
145
+ elif not has_chat and not is_base_strategy:
146
+ # Tokenizer is base model but strategy requires chat - downgrade to base version
147
+ mapping = {
148
+ cls.CHAT_LAST: cls.COMPLETION_LAST,
149
+ cls.CHAT_FIRST: cls.COMPLETION_LAST,
150
+ cls.CHAT_MEAN: cls.COMPLETION_MEAN,
151
+ cls.CHAT_MAX_NORM: cls.COMPLETION_LAST,
152
+ cls.CHAT_WEIGHTED: cls.COMPLETION_MEAN,
153
+ cls.ROLE_PLAY: cls.COMPLETION_LAST,
154
+ cls.MC_BALANCED: cls.MC_COMPLETION,
155
+ }
156
+ return mapping.get(strategy, cls.COMPLETION_LAST)
157
+
158
+ return strategy
159
+
160
+
161
+ def tokenizer_has_chat_template(tokenizer) -> bool:
162
+ """Check if tokenizer supports chat template."""
163
+ has_method = hasattr(tokenizer, "apply_chat_template") and callable(getattr(tokenizer, "apply_chat_template"))
164
+ has_template = hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None
165
+ return has_method and has_template
78
166
 
79
167
 
80
168
  # Random tokens for role_play strategy (deterministic based on prompt hash)
@@ -88,6 +176,7 @@ def build_extraction_texts(
88
176
  tokenizer,
89
177
  other_response: Optional[str] = None,
90
178
  is_positive: bool = True,
179
+ auto_convert_strategy: bool = True,
91
180
  ) -> Tuple[str, str, Optional[str]]:
92
181
  """
93
182
  Build the full text for activation extraction based on strategy.
@@ -97,8 +186,9 @@ def build_extraction_texts(
97
186
  prompt: The user prompt/question
98
187
  response: The response to extract activations for
99
188
  tokenizer: The tokenizer (needs apply_chat_template for chat strategies)
100
- other_response: For mc_balanced, the other response option
101
- is_positive: For mc_balanced, whether 'response' is the positive option
189
+ other_response: For mc_balanced/mc_completion, the other response option
190
+ is_positive: For mc_balanced/mc_completion, whether 'response' is the positive option
191
+ auto_convert_strategy: If True, automatically convert strategy to match tokenizer type
102
192
 
103
193
  Returns:
104
194
  Tuple of (full_text, answer_text, prompt_only_text)
@@ -106,31 +196,40 @@ def build_extraction_texts(
106
196
  - answer_text: The answer portion (for strategies that need it)
107
197
  - prompt_only_text: Prompt without answer (for boundary detection)
108
198
  """
199
+ # Auto-convert strategy if needed
200
+ if auto_convert_strategy:
201
+ original_strategy = strategy
202
+ strategy = ExtractionStrategy.get_equivalent_for_model_type(strategy, tokenizer)
203
+ if strategy != original_strategy:
204
+ import warnings
205
+ warnings.warn(
206
+ f"Strategy {original_strategy.value} not compatible with tokenizer, "
207
+ f"using {strategy.value} instead.",
208
+ UserWarning
209
+ )
109
210
 
110
211
  if strategy in (ExtractionStrategy.CHAT_MEAN, ExtractionStrategy.CHAT_FIRST,
111
- ExtractionStrategy.CHAT_LAST, ExtractionStrategy.CHAT_GEN_POINT,
112
- ExtractionStrategy.CHAT_MAX_NORM, ExtractionStrategy.CHAT_WEIGHTED):
212
+ ExtractionStrategy.CHAT_LAST, ExtractionStrategy.CHAT_MAX_NORM,
213
+ ExtractionStrategy.CHAT_WEIGHTED):
113
214
  # All chat_* strategies use the same prompt construction
114
- if hasattr(tokenizer, "apply_chat_template"):
115
- try:
116
- prompt_only = tokenizer.apply_chat_template(
117
- [{"role": "user", "content": prompt}],
118
- tokenize=False,
119
- add_generation_prompt=True,
120
- )
121
- full_text = tokenizer.apply_chat_template(
122
- [{"role": "user", "content": prompt},
123
- {"role": "assistant", "content": response}],
124
- tokenize=False,
125
- add_generation_prompt=False,
126
- )
127
- except (ValueError, KeyError):
128
- # Fallback for models without chat template
129
- prompt_only = prompt
130
- full_text = f"{prompt} {response}"
131
- else:
132
- prompt_only = prompt
133
- full_text = f"{prompt} {response}"
215
+ if not hasattr(tokenizer, "apply_chat_template"):
216
+ raise ValueError(
217
+ f"Strategy {strategy.value} requires a tokenizer with apply_chat_template. "
218
+ f"Base models without chat templates are not supported for chat_* strategies. "
219
+ f"Use completion_last, completion_mean, or mc_completion instead."
220
+ )
221
+
222
+ prompt_only = tokenizer.apply_chat_template(
223
+ [{"role": "user", "content": prompt}],
224
+ tokenize=False,
225
+ add_generation_prompt=True,
226
+ )
227
+ full_text = tokenizer.apply_chat_template(
228
+ [{"role": "user", "content": prompt},
229
+ {"role": "assistant", "content": response}],
230
+ tokenize=False,
231
+ add_generation_prompt=False,
232
+ )
134
233
 
135
234
  return full_text, response, prompt_only
136
235
 
@@ -139,25 +238,23 @@ def build_extraction_texts(
139
238
  random_token = ROLE_PLAY_TOKENS[hash(prompt) % len(ROLE_PLAY_TOKENS)]
140
239
  instruction = f"Behave like a person that would answer {prompt} with {response}"
141
240
 
142
- if hasattr(tokenizer, "apply_chat_template"):
143
- try:
144
- prompt_only = tokenizer.apply_chat_template(
145
- [{"role": "user", "content": instruction}],
146
- tokenize=False,
147
- add_generation_prompt=True,
148
- )
149
- full_text = tokenizer.apply_chat_template(
150
- [{"role": "user", "content": instruction},
151
- {"role": "assistant", "content": random_token}],
152
- tokenize=False,
153
- add_generation_prompt=False,
154
- )
155
- except (ValueError, KeyError):
156
- prompt_only = instruction
157
- full_text = f"{instruction} {random_token}"
158
- else:
159
- prompt_only = instruction
160
- full_text = f"{instruction} {random_token}"
241
+ if not hasattr(tokenizer, "apply_chat_template"):
242
+ raise ValueError(
243
+ f"Strategy {strategy.value} requires a tokenizer with apply_chat_template. "
244
+ f"Use completion_last or mc_completion for base models."
245
+ )
246
+
247
+ prompt_only = tokenizer.apply_chat_template(
248
+ [{"role": "user", "content": instruction}],
249
+ tokenize=False,
250
+ add_generation_prompt=True,
251
+ )
252
+ full_text = tokenizer.apply_chat_template(
253
+ [{"role": "user", "content": instruction},
254
+ {"role": "assistant", "content": random_token}],
255
+ tokenize=False,
256
+ add_generation_prompt=False,
257
+ )
161
258
 
162
259
  return full_text, random_token, prompt_only
163
260
 
@@ -188,28 +285,66 @@ def build_extraction_texts(
188
285
  option_b = response[:200] # negative
189
286
  answer = "B"
190
287
 
191
- mc_prompt = f"Which is correct?\nA. {option_a}\nB. {option_b}\nAnswer:"
288
+ mc_prompt = f"Question: {prompt}\n\nWhich is correct?\nA. {option_a}\nB. {option_b}\nAnswer:"
289
+
290
+ if not hasattr(tokenizer, "apply_chat_template"):
291
+ raise ValueError(
292
+ f"Strategy {strategy.value} requires a tokenizer with apply_chat_template. "
293
+ f"Use mc_completion for base models."
294
+ )
295
+
296
+ prompt_only = tokenizer.apply_chat_template(
297
+ [{"role": "user", "content": mc_prompt}],
298
+ tokenize=False,
299
+ add_generation_prompt=True,
300
+ )
301
+ full_text = tokenizer.apply_chat_template(
302
+ [{"role": "user", "content": mc_prompt},
303
+ {"role": "assistant", "content": answer}],
304
+ tokenize=False,
305
+ add_generation_prompt=False,
306
+ )
307
+
308
+ return full_text, answer, prompt_only
309
+
310
+ elif strategy in (ExtractionStrategy.COMPLETION_LAST, ExtractionStrategy.COMPLETION_MEAN):
311
+ # Base model strategies - direct Q+A without chat template
312
+ # Format: "Q: {prompt}\nA: {response}"
313
+ prompt_only = f"Q: {prompt}\nA:"
314
+ full_text = f"Q: {prompt}\nA: {response}"
315
+ return full_text, response, prompt_only
316
+
317
+ elif strategy == ExtractionStrategy.MC_COMPLETION:
318
+ # Multiple choice for base models - no chat template
319
+ if other_response is None:
320
+ raise ValueError("MC_COMPLETION strategy requires other_response")
321
+
322
+ # Deterministic "random" based on prompt - same for both pos and neg of a pair
323
+ pos_goes_in_b = hash(prompt) % 2 == 0
192
324
 
193
- if hasattr(tokenizer, "apply_chat_template"):
194
- try:
195
- prompt_only = tokenizer.apply_chat_template(
196
- [{"role": "user", "content": mc_prompt}],
197
- tokenize=False,
198
- add_generation_prompt=True,
199
- )
200
- full_text = tokenizer.apply_chat_template(
201
- [{"role": "user", "content": mc_prompt},
202
- {"role": "assistant", "content": answer}],
203
- tokenize=False,
204
- add_generation_prompt=False,
205
- )
206
- except (ValueError, KeyError):
207
- prompt_only = mc_prompt
208
- full_text = f"{mc_prompt} {answer}"
325
+ if is_positive:
326
+ if pos_goes_in_b:
327
+ option_a = other_response[:200]
328
+ option_b = response[:200]
329
+ answer = "B"
330
+ else:
331
+ option_a = response[:200]
332
+ option_b = other_response[:200]
333
+ answer = "A"
209
334
  else:
210
- prompt_only = mc_prompt
211
- full_text = f"{mc_prompt} {answer}"
335
+ if pos_goes_in_b:
336
+ option_a = response[:200]
337
+ option_b = other_response[:200]
338
+ answer = "A"
339
+ else:
340
+ option_a = other_response[:200]
341
+ option_b = response[:200]
342
+ answer = "B"
212
343
 
344
+ mc_prompt = f"Question: {prompt}\n\nWhich is correct?\nA. {option_a}\nB. {option_b}\nAnswer:"
345
+
346
+ prompt_only = mc_prompt
347
+ full_text = f"{mc_prompt} {answer}"
213
348
  return full_text, answer, prompt_only
214
349
 
215
350
  else:
@@ -243,6 +378,7 @@ def extract_activation(
243
378
  num_answer_tokens = len(answer_tokens)
244
379
 
245
380
  if strategy == ExtractionStrategy.CHAT_LAST:
381
+ # EOT token - has seen the entire answer, best performance
246
382
  return hidden_states[-1]
247
383
 
248
384
  elif strategy == ExtractionStrategy.CHAT_FIRST:
@@ -257,11 +393,6 @@ def extract_activation(
257
393
  return answer_hidden.mean(dim=0)
258
394
  return hidden_states[-1]
259
395
 
260
- elif strategy == ExtractionStrategy.CHAT_GEN_POINT:
261
- # Last token before answer starts (decision point)
262
- gen_point_idx = max(0, seq_len - num_answer_tokens - 2)
263
- return hidden_states[gen_point_idx]
264
-
265
396
  elif strategy == ExtractionStrategy.CHAT_MAX_NORM:
266
397
  # Token with max norm in answer region
267
398
  if num_answer_tokens > 0 and seq_len > num_answer_tokens:
@@ -275,18 +406,36 @@ def extract_activation(
275
406
  # Position-weighted mean (earlier tokens weighted more)
276
407
  if num_answer_tokens > 0 and seq_len > num_answer_tokens:
277
408
  answer_hidden = hidden_states[-num_answer_tokens-1:-1]
278
- weights = torch.exp(-torch.arange(answer_hidden.shape[0], dtype=torch.float32, device=answer_hidden.device) * 0.5)
409
+ weights = torch.exp(-torch.arange(answer_hidden.shape[0], dtype=answer_hidden.dtype, device=answer_hidden.device) * 0.5)
279
410
  weights = weights / weights.sum()
280
411
  return (answer_hidden * weights.unsqueeze(1)).sum(dim=0)
281
412
  return hidden_states[-1]
282
413
 
283
- elif strategy in (ExtractionStrategy.ROLE_PLAY, ExtractionStrategy.MC_BALANCED):
284
- # Both use last token
414
+ elif strategy == ExtractionStrategy.ROLE_PLAY:
415
+ # EOT token - slightly better than answer word (65% vs 64%)
285
416
  return hidden_states[-1]
286
417
 
287
- else:
288
- # Default fallback
418
+ elif strategy == ExtractionStrategy.MC_BALANCED:
419
+ # Answer token (A/B) - better than EOT (64% vs 56%)
420
+ return hidden_states[-2]
421
+
422
+ elif strategy == ExtractionStrategy.COMPLETION_LAST:
423
+ # Last token for base model completion
289
424
  return hidden_states[-1]
425
+
426
+ elif strategy == ExtractionStrategy.COMPLETION_MEAN:
427
+ # Mean of answer tokens for base model completion
428
+ if num_answer_tokens > 0 and seq_len > num_answer_tokens:
429
+ answer_hidden = hidden_states[-num_answer_tokens:]
430
+ return answer_hidden.mean(dim=0)
431
+ return hidden_states[-1]
432
+
433
+ elif strategy == ExtractionStrategy.MC_COMPLETION:
434
+ # A/B token for base model MC (last token is the answer)
435
+ return hidden_states[-1]
436
+
437
+ else:
438
+ raise ValueError(f"Unknown extraction strategy: {strategy}")
290
439
 
291
440
 
292
441
  def add_extraction_strategy_args(parser: argparse.ArgumentParser) -> None:
@@ -306,3 +455,30 @@ def add_extraction_strategy_args(parser: argparse.ArgumentParser) -> None:
306
455
  choices=ExtractionStrategy.list_all(),
307
456
  help=f"Extraction strategy for activations. Options: {', '.join(ExtractionStrategy.list_all())}. Default: {ExtractionStrategy.default().value}",
308
457
  )
458
+
459
+
460
+ def get_strategy_for_model(tokenizer, prefer_mc: bool = False) -> ExtractionStrategy:
461
+ """
462
+ Get the best extraction strategy for a given tokenizer.
463
+
464
+ Automatically detects if tokenizer has chat template and returns
465
+ the appropriate strategy.
466
+
467
+ Args:
468
+ tokenizer: The tokenizer to check
469
+ prefer_mc: If True, prefer multiple choice strategies
470
+
471
+ Returns:
472
+ ExtractionStrategy appropriate for the tokenizer
473
+
474
+ Example:
475
+ >>> from transformers import AutoTokenizer
476
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
477
+ >>> strategy = get_strategy_for_model(tokenizer)
478
+ >>> print(strategy) # completion_last (base model)
479
+
480
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
481
+ >>> strategy = get_strategy_for_model(tokenizer)
482
+ >>> print(strategy) # chat_last (instruct model)
483
+ """
484
+ return ExtractionStrategy.for_tokenizer(tokenizer, prefer_mc=prefer_mc)
@@ -14,6 +14,7 @@ import numpy as np
14
14
 
15
15
  from torch.nn.modules.loss import _Loss
16
16
  from wisent.core.errors import DuplicateNameError, InvalidRangeError, UnknownTypeError
17
+ from wisent.core.utils.device import preferred_dtype
17
18
 
18
19
  __all__ = [
19
20
  "ClassifierTrainConfig",
@@ -164,13 +165,13 @@ class BaseClassifier(ABC):
164
165
  self,
165
166
  threshold: float = 0.5,
166
167
  device: str | None = None,
167
- dtype: torch.dtype = torch.float32,
168
+ dtype: torch.dtype | None = None,
168
169
  ) -> None:
169
170
  if not 0.0 <= threshold <= 1.0:
170
171
  raise InvalidRangeError(param_name="threshold", actual=threshold, min_val=0.0, max_val=1.0)
171
172
  self.threshold = threshold
172
173
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
173
- self.dtype = torch.float32 if self.device == "mps" else dtype
174
+ self.dtype = dtype if dtype is not None else preferred_dtype(self.device)
174
175
  self.model = None
175
176
 
176
177
  @abstractmethod
@@ -22,5 +22,6 @@ from .inference_config_cli import execute_inference_config
22
22
  from .optimization_cache import execute_optimization_cache
23
23
  from .optimize_weights import execute_optimize_weights
24
24
  from .optimize import execute_optimize
25
+ from .geometry_search import execute_geometry_search
25
26
 
26
- __all__ = ['execute_tasks', 'execute_generate_pairs_from_task', 'execute_generate_pairs', 'execute_diagnose_pairs', 'execute_get_activations', 'execute_diagnose_vectors', 'execute_create_steering_vector', 'execute_generate_vector_from_task', 'execute_generate_vector_from_synthetic', 'execute_optimize_classification', 'execute_optimize_steering', 'execute_optimize_sample_size', 'execute_generate_responses', 'execute_evaluate_responses', 'execute_multi_steer', 'execute_agent', 'execute_modify_weights', 'execute_evaluate_refusal', 'execute_inference_config', 'execute_optimization_cache', 'execute_optimize_weights', 'execute_optimize']
27
+ __all__ = ['execute_tasks', 'execute_generate_pairs_from_task', 'execute_generate_pairs', 'execute_diagnose_pairs', 'execute_get_activations', 'execute_diagnose_vectors', 'execute_create_steering_vector', 'execute_generate_vector_from_task', 'execute_generate_vector_from_synthetic', 'execute_optimize_classification', 'execute_optimize_steering', 'execute_optimize_sample_size', 'execute_generate_responses', 'execute_evaluate_responses', 'execute_multi_steer', 'execute_agent', 'execute_modify_weights', 'execute_evaluate_refusal', 'execute_inference_config', 'execute_optimization_cache', 'execute_optimize_weights', 'execute_optimize', 'execute_geometry_search']
@@ -1,8 +1,20 @@
1
1
  """Train classifier on contrastive pairs for agent."""
2
2
 
3
3
  import numpy as np
4
+ import torch
4
5
  from wisent.core.classifiers.classifiers.core.atoms import ClassifierTrainReport
5
6
  from wisent.core.errors import UnknownTypeError
7
+ from wisent.core.utils.device import preferred_dtype
8
+
9
+
10
+ def _torch_dtype_to_numpy(torch_dtype: torch.dtype):
11
+ """Convert torch dtype to numpy dtype."""
12
+ mapping = {
13
+ torch.float32: np.float32,
14
+ torch.float16: np.float16,
15
+ torch.bfloat16: np.float32, # numpy doesn't support bfloat16, use float32
16
+ }
17
+ return mapping.get(torch_dtype, np.float32)
6
18
 
7
19
 
8
20
  def _map_token_aggregation(aggregation_str: str):
@@ -97,7 +109,7 @@ def train_classifier_on_pairs(
97
109
  prompt_construction_strategy = _map_prompt_strategy(prompt_strategy)
98
110
 
99
111
  # Collect activations for all pairs
100
- collector = ActivationCollector(model=model, store_device="cpu")
112
+ collector = ActivationCollector(model=model)
101
113
  target_layers = [str(target_layer)]
102
114
  layer_key = target_layers[0]
103
115
 
@@ -133,8 +145,9 @@ def train_classifier_on_pairs(
133
145
  X_list.append(neg_act.cpu().numpy())
134
146
  y_list.append(0.0)
135
147
 
136
- X_train = np.array(X_list, dtype=np.float32)
137
- y_train = np.array(y_list, dtype=np.float32)
148
+ np_dtype = _torch_dtype_to_numpy(preferred_dtype())
149
+ X_train = np.array(X_list, dtype=np_dtype)
150
+ y_train = np.array(y_list, dtype=np_dtype)
138
151
 
139
152
  print(f" Training data: {X_train.shape[0]} samples, {X_train.shape[1]} features")
140
153
 
@@ -31,6 +31,7 @@ def execute_check_linearity(args):
31
31
  from wisent.core.models.wisent_model import WisentModel
32
32
  from wisent.core.contrastive_pairs.core.pair import ContrastivePair
33
33
  from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse
34
+ from wisent.core.activations.extraction_strategy import ExtractionStrategy
34
35
 
35
36
  # Build ContrastivePair objects
36
37
  pairs = []
@@ -72,6 +73,10 @@ def execute_check_linearity(args):
72
73
  if args.layers:
73
74
  config.layers_to_test = [int(l) for l in args.layers.split(',')]
74
75
 
76
+ if args.extraction_strategy:
77
+ config.extraction_strategies = [ExtractionStrategy(args.extraction_strategy)]
78
+ print(f"Using extraction strategy: {args.extraction_strategy}")
79
+
75
80
  # Run check
76
81
  print("\nRunning linearity check...")
77
82
  result = check_linearity(pairs, model, config)
@@ -110,12 +115,39 @@ def execute_check_linearity(args):
110
115
 
111
116
  sorted_results = sorted(result.all_results, key=lambda x: x['linear_score'], reverse=True)
112
117
 
113
- print(f"{'Linear':<8} {'d':<8} {'Layer':<6} {'Prompt':<25} {'Aggregation':<15} {'Norm'}")
114
- print("-" * 80)
118
+ print(f"{'Linear':<8} {'d':<8} {'Layer':<6} {'Strategy':<20} {'Structure':<12} {'Norm'}")
119
+ print("-" * 70)
115
120
 
116
121
  for r in sorted_results[:20]:
117
122
  print(f"{r['linear_score']:<8.3f} {r['cohens_d']:<8.2f} {r['layer']:<6} "
118
- f"{r['prompt_strategy']:<25} {r['aggregation']:<15} {r['normalize']}")
123
+ f"{r['extraction_strategy']:<20} {r['best_structure']:<12} {r['normalize']}")
124
+
125
+ # Show best result for each structure type
126
+ if sorted_results and 'all_structure_scores' in sorted_results[0]:
127
+ print(f"\n{'='*60}")
128
+ print("BEST SCORE PER STRUCTURE TYPE")
129
+ print(f"{'='*60}")
130
+
131
+ # Collect best score for each structure across all configs
132
+ best_per_structure = {}
133
+ for r in result.all_results:
134
+ if 'all_structure_scores' not in r:
135
+ continue
136
+ for struct_name, data in r['all_structure_scores'].items():
137
+ score = data['score']
138
+ if struct_name not in best_per_structure or score > best_per_structure[struct_name]['score']:
139
+ best_per_structure[struct_name] = {
140
+ 'score': score,
141
+ 'confidence': data['confidence'],
142
+ 'layer': r['layer'],
143
+ 'strategy': r['extraction_strategy'],
144
+ }
145
+
146
+ print(f"{'Structure':<12} {'Score':<8} {'Conf':<8} {'Layer':<6} {'Strategy'}")
147
+ print("-" * 55)
148
+ sorted_structs = sorted(best_per_structure.items(), key=lambda x: x[1]['score'], reverse=True)
149
+ for name, data in sorted_structs:
150
+ print(f"{name:<12} {data['score']:<8.3f} {data['confidence']:<8.3f} {data['layer']:<6} {data['strategy']}")
119
151
 
120
152
  # Exit code based on verdict
121
153
  if result.verdict.value == "linear":
@@ -33,7 +33,6 @@ STRATEGIES = [
33
33
  "chat_mean",
34
34
  "chat_first",
35
35
  "chat_last",
36
- "chat_gen_point",
37
36
  "chat_max_norm",
38
37
  "chat_weighted",
39
38
  "role_play",
@@ -134,9 +133,9 @@ def get_weighted_mean_answer_act(model, tokenizer, text: str, answer: str, layer
134
133
  hidden = outputs.hidden_states[layer][0]
135
134
  if num_answer_tokens > 0 and num_answer_tokens < hidden.shape[0]:
136
135
  answer_hidden = hidden[-num_answer_tokens-1:-1, :]
137
- weights = torch.exp(-torch.arange(answer_hidden.shape[0], dtype=torch.float32) * 0.5)
136
+ weights = torch.exp(-torch.arange(answer_hidden.shape[0], dtype=answer_hidden.dtype, device=answer_hidden.device) * 0.5)
138
137
  weights = weights / weights.sum()
139
- weighted_mean = (answer_hidden * weights.unsqueeze(1).to(answer_hidden.device)).sum(dim=0)
138
+ weighted_mean = (answer_hidden * weights.unsqueeze(1)).sum(dim=0)
140
139
  return weighted_mean.cpu().float()
141
140
  return hidden[-1].cpu().float()
142
141
 
@@ -156,8 +155,6 @@ def get_activation(model, tokenizer, prompt: str, response: str, layer: int, dev
156
155
  return get_first_answer_token_act(model, tokenizer, text, response, layer, device)
157
156
  elif strategy == "chat_last":
158
157
  return get_last_token_act(model, tokenizer, text, layer, device)
159
- elif strategy == "chat_gen_point":
160
- return get_generation_point_act(model, tokenizer, text, response, layer, device)
161
158
  elif strategy == "chat_max_norm":
162
159
  return get_max_norm_answer_act(model, tokenizer, text, response, layer, device)
163
160
  elif strategy == "chat_weighted":
@@ -348,7 +345,8 @@ def execute_cluster_benchmarks(args):
348
345
 
349
346
  logger.info(f"Loading {model}...")
350
347
  tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
351
- dtype = torch.bfloat16 if device == 'cuda' else torch.float16
348
+ from wisent.core.utils.device import device_optimized_dtype
349
+ dtype = device_optimized_dtype(device)
352
350
  llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=dtype, device_map=device, trust_remote_code=True)
353
351
 
354
352
  layers = get_layers_to_test(llm)