wisent 0.7.701__py3-none-any.whl → 0.7.1045__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (391)
  1. wisent/__init__.py +1 -1
  2. wisent/comparison/__init__.py +1 -0
  3. wisent/comparison/detect_bos_features.py +275 -0
  4. wisent/comparison/fgaa.py +465 -0
  5. wisent/comparison/lora.py +669 -0
  6. wisent/comparison/lora_dpo.py +592 -0
  7. wisent/comparison/main.py +444 -0
  8. wisent/comparison/ours.py +76 -0
  9. wisent/comparison/sae.py +304 -0
  10. wisent/comparison/utils.py +381 -0
  11. wisent/core/activations/activation_cache.py +393 -0
  12. wisent/core/activations/activations.py +3 -3
  13. wisent/core/activations/activations_collector.py +12 -7
  14. wisent/core/activations/classifier_inference_strategy.py +12 -11
  15. wisent/core/activations/extraction_strategy.py +260 -84
  16. wisent/core/classifiers/classifiers/core/atoms.py +3 -2
  17. wisent/core/cli/__init__.py +2 -1
  18. wisent/core/cli/agent/train_classifier.py +16 -3
  19. wisent/core/cli/check_linearity.py +35 -3
  20. wisent/core/cli/cluster_benchmarks.py +4 -6
  21. wisent/core/cli/create_steering_vector.py +6 -4
  22. wisent/core/cli/diagnose_vectors.py +7 -4
  23. wisent/core/cli/estimate_unified_goodness_time.py +6 -4
  24. wisent/core/cli/generate_pairs_from_task.py +9 -56
  25. wisent/core/cli/generate_vector_from_task.py +11 -20
  26. wisent/core/cli/geometry_search.py +137 -0
  27. wisent/core/cli/get_activations.py +2 -2
  28. wisent/core/cli/method_optimizer.py +4 -3
  29. wisent/core/cli/modify_weights.py +3 -2
  30. wisent/core/cli/optimize_sample_size.py +1 -1
  31. wisent/core/cli/optimize_steering.py +14 -16
  32. wisent/core/cli/optimize_weights.py +2 -1
  33. wisent/core/cli/preview_pairs.py +203 -0
  34. wisent/core/cli/steering_method_trainer.py +3 -3
  35. wisent/core/cli/tasks.py +19 -76
  36. wisent/core/cli/train_unified_goodness.py +3 -3
  37. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
  38. wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
  39. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
  40. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
  41. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
  42. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
  43. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
  44. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
  45. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
  46. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
  47. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
  48. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
  49. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
  50. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
  51. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
  52. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
  53. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
  54. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
  55. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
  56. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
  57. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
  58. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
  59. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
  60. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
  61. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
  62. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
  63. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
  64. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
  65. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
  66. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
  67. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
  68. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
  69. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
  70. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
  71. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
  72. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
  73. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
  74. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
  75. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
  76. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
  77. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
  78. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
  79. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
  80. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
  81. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
  82. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
  83. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
  84. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
  85. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
  86. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
  87. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
  88. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
  89. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
  90. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
  91. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
  92. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
  93. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +22 -5
  94. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
  95. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
  96. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +10 -3
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
  102. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
  103. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
  104. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
  105. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
  106. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
  107. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
  108. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
  109. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
  110. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
  111. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
  112. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
  113. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
  114. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
  115. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
  116. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
  117. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
  118. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
  119. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
  120. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
  121. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
  122. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
  123. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
  124. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
  125. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
  126. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
  127. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
  128. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
  129. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
  130. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
  131. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
  132. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
  133. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
  134. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
  135. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
  136. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
  137. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
  138. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
  139. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
  140. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
  141. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
  142. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
  143. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
  144. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
  145. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
  146. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
  147. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
  148. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
  149. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
  150. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
  151. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
  152. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
  153. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
  154. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
  155. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
  156. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
  157. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
  158. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
  159. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
  160. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
  161. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
  162. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
  163. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
  164. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
  165. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
  166. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
  167. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
  168. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
  169. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
  170. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
  171. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
  172. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
  173. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
  174. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
  175. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
  176. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
  177. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
  178. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
  179. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
  180. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
  181. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
  182. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
  183. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
  184. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
  185. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
  186. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
  187. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
  188. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
  189. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
  190. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
  191. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
  192. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
  193. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
  194. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
  195. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
  196. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
  197. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
  198. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
  199. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
  200. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
  201. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
  202. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
  203. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
  204. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
  205. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
  206. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
  207. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
  208. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
  209. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
  210. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
  211. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
  212. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
  213. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
  214. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
  215. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
  216. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
  217. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
  218. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
  219. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
  220. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
  221. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
  222. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
  223. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
  224. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
  225. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
  226. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
  227. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
  228. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
  229. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
  230. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
  231. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
  232. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
  233. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
  234. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
  235. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
  236. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
  237. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
  238. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
  239. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
  240. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
  241. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
  242. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
  243. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
  244. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
  245. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
  246. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
  247. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
  248. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
  249. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
  250. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
  251. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
  252. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
  253. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
  254. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
  255. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
  256. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
  257. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
  258. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
  259. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
  260. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +9 -4
  261. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
  262. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
  263. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
  264. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
  265. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
  266. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
  267. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
  268. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
  269. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
  270. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
  271. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
  272. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
  273. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
  274. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
  275. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
  276. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
  277. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
  278. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
  279. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
  280. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
  281. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
  282. wisent/core/data_loaders/loaders/lm_loader.py +12 -1
  283. wisent/core/geometry_runner.py +995 -0
  284. wisent/core/geometry_search_space.py +237 -0
  285. wisent/core/hyperparameter_optimizer.py +1 -1
  286. wisent/core/main.py +3 -0
  287. wisent/core/models/core/atoms.py +5 -3
  288. wisent/core/models/wisent_model.py +1 -1
  289. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
  290. wisent/core/parser_arguments/check_linearity_parser.py +12 -2
  291. wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
  292. wisent/core/parser_arguments/generate_vector_from_task_parser.py +6 -13
  293. wisent/core/parser_arguments/geometry_search_parser.py +61 -0
  294. wisent/core/parser_arguments/get_activations_parser.py +5 -14
  295. wisent/core/parser_arguments/main_parser.py +8 -0
  296. wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
  297. wisent/core/steering.py +5 -3
  298. wisent/core/steering_methods/methods/hyperplane.py +2 -1
  299. wisent/core/synthetic/generators/nonsense_generator.py +30 -18
  300. wisent/core/trainers/steering_trainer.py +2 -2
  301. wisent/core/utils/device.py +27 -27
  302. wisent/core/utils/layer_combinations.py +70 -0
  303. wisent/examples/__init__.py +1 -0
  304. wisent/examples/scripts/__init__.py +1 -0
  305. wisent/examples/scripts/count_all_benchmarks.py +121 -0
  306. wisent/examples/scripts/discover_directions.py +469 -0
  307. wisent/examples/scripts/extract_benchmark_info.py +71 -0
  308. wisent/examples/scripts/search_all_short_names.py +31 -0
  309. wisent/examples/scripts/test_all_benchmarks.py +138 -0
  310. wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
  311. wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
  312. wisent/examples/scripts/test_nonsense_baseline.py +261 -0
  313. wisent/examples/scripts/test_one_benchmark.py +324 -0
  314. wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
  315. wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
  316. wisent/parameters/lm_eval/category_directions.json +137 -0
  317. wisent/parameters/lm_eval/repair_plan.json +282 -0
  318. wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
  319. wisent/parameters/lm_eval/working_benchmarks.json +206 -0
  320. wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
  321. wisent/tests/test_detector_accuracy.py +1 -1
  322. wisent/tests/visualize_geometry.py +1 -1
  323. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
  324. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/RECORD +328 -358
  325. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
  326. wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
  327. wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
  328. wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
  329. wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
  330. wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
  331. wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
  332. wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
  333. wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
  334. wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
  335. wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
  336. wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
  337. wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
  338. wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
  339. wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
  340. wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
  341. wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
  342. wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
  343. wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
  344. wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
  345. wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
  346. wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
  347. wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
  348. wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
  349. wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
  350. wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
  351. wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
  352. wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
  353. wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
  354. wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
  355. wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
  356. wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
  357. wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
  358. wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
  359. wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
  360. wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
  361. wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
  362. wisent/examples/scripts/1/test_cola_pairs.json +0 -8
  363. wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
  364. wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
  365. wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
  366. wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
  367. wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
  368. wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
  369. wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
  370. wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
  371. wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
  372. wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
  373. wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
  374. wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
  375. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
  376. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
  377. wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
  378. wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
  379. wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
  380. wisent/examples/scripts/2/test_atis_pairs.json +0 -8
  381. wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
  382. wisent/examples/scripts/2/test_babi_pairs.json +0 -8
  383. wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
  384. wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
  385. wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
  386. wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
  387. wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
  388. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
  389. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
  390. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
  391. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,393 @@
1
+ """
2
+ Activation cache for geometry search.
3
+
4
+ Caches activations for ALL layers once per (benchmark, strategy) pair.
5
+ Layer combinations are then tested from cache without re-extraction.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import hashlib
11
+ import json
12
+ import os
13
+ from dataclasses import dataclass, field
14
+ from pathlib import Path
15
+ from typing import Dict, List, Optional, Tuple, Any
16
+ import torch
17
+
18
+ from wisent.core.activations.core.atoms import LayerActivations, RawActivationMap
19
+ from wisent.core.activations.extraction_strategy import ExtractionStrategy
20
+ from wisent.core.contrastive_pairs.core.pair import ContrastivePair
21
+ from wisent.core.utils.device import resolve_default_device
22
+
23
+
24
@dataclass
class CachedActivations:
    """
    Cached activations for a single (benchmark, strategy) pair.

    Stores, for every contrastive pair added through ``add_pair``, the
    positive and negative activation tensor of each layer that was present,
    so layer combinations can be examined without re-running the model.
    """
    benchmark: str
    strategy: "ExtractionStrategy"
    model_name: str
    num_layers: int

    # One (positive_activations, negative_activations) entry per pair.
    # Each side is a dict: layer_name -> tensor.
    pair_activations: List[Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]] = field(default_factory=list)

    # Metadata
    num_pairs: int = 0
    hidden_size: int = 0

    def add_pair(self, positive: "LayerActivations", negative: "LayerActivations") -> None:
        """
        Add activations for a contrastive pair.

        Tensors are cloned so later changes to the inputs do not affect the
        cache. Layers whose value is ``None`` are dropped, which means a layer
        may be present on only one side of a pair.

        Args:
            positive: Mapping layer_name -> tensor (or None) for the positive response
            negative: Mapping layer_name -> tensor (or None) for the negative response
        """
        pos_dict = {k: v.clone() for k, v in positive.items() if v is not None}
        neg_dict = {k: v.clone() for k, v in negative.items() if v is not None}
        self.pair_activations.append((pos_dict, neg_dict))
        self.num_pairs = len(self.pair_activations)

        # Infer hidden size from the last dimension of the first tensor seen.
        if self.hidden_size == 0 and pos_dict:
            first_tensor = next(iter(pos_dict.values()))
            self.hidden_size = first_tensor.shape[-1]

    def get_layer_subset(self, layers: List[int]) -> "CachedActivations":
        """
        Get a new CachedActivations with only the specified layers.

        The returned object shares tensor storage with this one (tensors are
        not copied).

        Args:
            layers: List of layer indices; each is matched against stored
                layer names via ``str(index)``

        Returns:
            New CachedActivations with only the specified layers and
            ``num_layers`` set to ``len(layers)``
        """
        # A set gives O(1) membership tests inside the per-pair filters.
        wanted = {str(layer) for layer in layers}

        new_pairs = [
            (
                {k: v for k, v in pos_dict.items() if k in wanted},
                {k: v for k, v in neg_dict.items() if k in wanted},
            )
            for pos_dict, neg_dict in self.pair_activations
        ]

        return CachedActivations(
            benchmark=self.benchmark,
            strategy=self.strategy,
            model_name=self.model_name,
            num_layers=len(layers),
            pair_activations=new_pairs,
            num_pairs=len(new_pairs),
            hidden_size=self.hidden_size,
        )

    def get_available_layers(self) -> List[str]:
        """Get the layer names present on the positive side of the first pair."""
        if not self.pair_activations:
            return []
        return list(self.pair_activations[0][0].keys())

    def get_positive_activations(self, layer: "int | str") -> torch.Tensor:
        """
        Get stacked positive activations for a single layer.

        Pairs whose positive side lacks the layer are skipped.

        Args:
            layer: Layer index (int) or layer name (str)

        Returns:
            Tensor stacked along a new leading dimension, one row per pair
            that has the layer

        Raises:
            KeyError: If no pair has the layer on its positive side
        """
        layer_name = str(layer)
        tensors = [pos[layer_name] for pos, _ in self.pair_activations if layer_name in pos]
        if not tensors:
            raise KeyError(f"Layer {layer_name} not found. Available: {self.get_available_layers()}")
        return torch.stack(tensors, dim=0)

    def get_negative_activations(self, layer: "int | str") -> torch.Tensor:
        """
        Get stacked negative activations for a single layer.

        Pairs whose negative side lacks the layer are skipped.

        Args:
            layer: Layer index (int) or layer name (str)

        Returns:
            Tensor stacked along a new leading dimension, one row per pair
            that has the layer

        Raises:
            KeyError: If no pair has the layer on its negative side
        """
        layer_name = str(layer)
        tensors = [neg[layer_name] for _, neg in self.pair_activations if layer_name in neg]
        if not tensors:
            raise KeyError(f"Layer {layer_name} not found. Available: {self.get_available_layers()}")
        return torch.stack(tensors, dim=0)

    def get_diff_activations(self, layer: "int | str") -> torch.Tensor:
        """
        Get positive - negative activation differences for a layer.

        Only pairs that have the layer on BOTH sides contribute, so every row
        is the difference of one pair's own positive and negative tensors.
        Subtracting the independently filtered positive and negative stacks
        would pair up tensors from different pairs whenever a layer is missing
        on one side only.

        Args:
            layer: Layer index (int) or layer name (str)

        Returns:
            Tensor stacked along a new leading dimension, one row per pair
            that has the layer on both sides

        Raises:
            KeyError: If no pair has the layer on both sides
        """
        layer_name = str(layer)
        diffs = [
            pos[layer_name] - neg[layer_name]
            for pos, neg in self.pair_activations
            if layer_name in pos and layer_name in neg
        ]
        if not diffs:
            raise KeyError(f"Layer {layer_name} not found. Available: {self.get_available_layers()}")
        return torch.stack(diffs, dim=0)

    def get_all_layers_diff(self) -> Dict[str, torch.Tensor]:
        """
        Get activation differences for all layers.

        Layer names are taken from the positive side of the first pair; for
        each, only pairs that have the layer on both sides contribute.

        Returns:
            Dict mapping layer_name -> stacked (positive - negative) tensor
        """
        result: Dict[str, torch.Tensor] = {}
        if not self.pair_activations:
            return result

        # Get layer names from first pair
        layer_names = list(self.pair_activations[0][0].keys())
        for layer_name in layer_names:
            pos_tensors = []
            neg_tensors = []
            for pos, neg in self.pair_activations:
                if layer_name in pos and layer_name in neg:
                    pos_tensors.append(pos[layer_name])
                    neg_tensors.append(neg[layer_name])
            if pos_tensors:
                result[layer_name] = torch.stack(pos_tensors) - torch.stack(neg_tensors)
        return result

    def to_device(self, device: str) -> "CachedActivations":
        """
        Return a copy of this cache with every tensor moved to ``device``.

        Args:
            device: Target device, passed straight to ``Tensor.to``

        Returns:
            New CachedActivations; this instance is left unchanged
        """
        new_pairs = [
            (
                {k: v.to(device) for k, v in pos.items()},
                {k: v.to(device) for k, v in neg.items()},
            )
            for pos, neg in self.pair_activations
        ]

        return CachedActivations(
            benchmark=self.benchmark,
            strategy=self.strategy,
            model_name=self.model_name,
            num_layers=self.num_layers,
            pair_activations=new_pairs,
            num_pairs=self.num_pairs,
            hidden_size=self.hidden_size,
        )
178
+
179
+
180
class ActivationCache:
    """
    Disk-backed cache for activations.

    Saves/loads activations per (model, benchmark, strategy) tuple. Each entry
    is stored as ``<key>.pt`` (tensors plus metadata) and ``<key>.json``
    (metadata only) inside ``cache_dir``, and is optionally mirrored in an
    in-memory dict.
    """

    def __init__(self, cache_dir: str):
        """Create the cache, creating ``cache_dir`` (and parents) if needed."""
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._memory_cache: Dict[str, "CachedActivations"] = {}

    def _get_cache_key(self, model_name: str, benchmark: str, strategy: "ExtractionStrategy") -> str:
        """Generate the cache key: first 16 hex chars of an MD5 over model, benchmark and strategy value."""
        key_str = f"{model_name}_{benchmark}_{strategy.value}"
        return hashlib.md5(key_str.encode()).hexdigest()[:16]

    def _get_cache_path(self, cache_key: str) -> Path:
        """Get path for a cache file."""
        return self.cache_dir / f"{cache_key}.pt"

    def _get_metadata_path(self, cache_key: str) -> Path:
        """Get path for cache metadata."""
        return self.cache_dir / f"{cache_key}.json"

    @staticmethod
    def _write_atomically(target: Path, write) -> None:
        """
        Write ``target`` via a sibling temp file that is renamed into place.

        ``write`` is called with the temp path. If it raises, the temp file is
        removed and ``target`` is left untouched, so a failed or interrupted
        write never leaves a truncated cache file behind.
        """
        tmp_path = target.with_name(target.name + ".tmp")
        try:
            write(tmp_path)
            os.replace(tmp_path, target)
        finally:
            # After a successful replace the temp file no longer exists.
            if tmp_path.exists():
                tmp_path.unlink()

    def has(self, model_name: str, benchmark: str, strategy: "ExtractionStrategy") -> bool:
        """Check if activations are cached, in memory or on disk."""
        key = self._get_cache_key(model_name, benchmark, strategy)
        if key in self._memory_cache:
            return True
        return self._get_cache_path(key).exists()

    def get(
        self,
        model_name: str,
        benchmark: str,
        strategy: "ExtractionStrategy",
        load_to_memory: bool = True,
    ) -> Optional["CachedActivations"]:
        """
        Get cached activations if they exist.

        Args:
            model_name: Model identifier
            benchmark: Benchmark name
            strategy: Extraction strategy
            load_to_memory: If True, keep in memory cache after loading

        Returns:
            CachedActivations or None if not cached
        """
        key = self._get_cache_key(model_name, benchmark, strategy)

        # Check memory cache first
        if key in self._memory_cache:
            return self._memory_cache[key]

        # Check disk cache
        cache_path = self._get_cache_path(key)
        if not cache_path.exists():
            return None

        # Load from disk.
        # NOTE(review): weights_only=False unpickles arbitrary objects, so the
        # cache directory must only ever contain files written by this class.
        data = torch.load(cache_path, map_location=resolve_default_device(), weights_only=False)

        cached = CachedActivations(
            benchmark=data["benchmark"],
            strategy=ExtractionStrategy(data["strategy"]),
            model_name=data["model_name"],
            num_layers=data["num_layers"],
            hidden_size=data["hidden_size"],
        )
        cached.pair_activations = data["pair_activations"]
        cached.num_pairs = data["num_pairs"]

        if load_to_memory:
            self._memory_cache[key] = cached

        return cached

    def put(
        self,
        cached: "CachedActivations",
        save_to_disk: bool = True,
    ) -> None:
        """
        Store cached activations.

        The entry is always placed in the in-memory cache. When persisting,
        the tensor file is written before the metadata file, and each is
        written atomically.

        Args:
            cached: CachedActivations to store
            save_to_disk: If True, persist to disk
        """
        key = self._get_cache_key(cached.model_name, cached.benchmark, cached.strategy)

        # Store in memory
        self._memory_cache[key] = cached

        if not save_to_disk:
            return

        metadata = {
            "benchmark": cached.benchmark,
            "strategy": cached.strategy.value,
            "model_name": cached.model_name,
            "num_layers": cached.num_layers,
            "hidden_size": cached.hidden_size,
            "num_pairs": cached.num_pairs,
        }
        # The .pt payload is the metadata plus the tensors themselves.
        data = {**metadata, "pair_activations": cached.pair_activations}

        def _dump_metadata(path: Path) -> None:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(metadata, f, indent=2)

        self._write_atomically(self._get_cache_path(key), lambda path: torch.save(data, path))
        self._write_atomically(self._get_metadata_path(key), _dump_metadata)

    def clear_memory(self) -> None:
        """Clear the in-memory cache (files on disk are kept)."""
        self._memory_cache.clear()

    def list_cached(self) -> List[Dict[str, Any]]:
        """List the metadata of every ``*.json`` file in the cache directory."""
        result = []
        for meta_path in self.cache_dir.glob("*.json"):
            with open(meta_path, encoding="utf-8") as f:
                result.append(json.load(f))
        return result

    def get_cache_size_bytes(self) -> int:
        """Get total size in bytes of the ``*.pt`` files on disk."""
        return sum(path.stat().st_size for path in self.cache_dir.glob("*.pt"))
320
+
321
+
322
def collect_and_cache_activations(
    model: "WisentModel",
    pairs: List[ContrastivePair],
    benchmark: str,
    strategy: ExtractionStrategy,
    cache: Optional[ActivationCache] = None,
    cache_dir: Optional[str] = None,
    show_progress: bool = True,
) -> CachedActivations:
    """
    Return activations for every pair, reusing a cached entry when one exists.

    On a cache hit for (model.model_name, benchmark, strategy) the stored
    entry is returned and the model is not run. Otherwise each pair is passed
    through an ActivationCollector with ``layers=None`` and the result is
    stored in the cache (when one is available) before being returned.

    Args:
        model: WisentModel instance
        pairs: List of contrastive pairs
        benchmark: Benchmark name
        strategy: Extraction strategy
        cache: Optional existing cache to use
        cache_dir: Cache directory (used only if ``cache`` is not provided)
        show_progress: Print progress to stdout

    Returns:
        CachedActivations holding the collected (or previously cached) activations
    """
    from wisent.core.activations.activations_collector import ActivationCollector

    # Build a cache on demand when only a directory was supplied.
    if cache is None and cache_dir:
        cache = ActivationCache(cache_dir)

    # Fast path: an entry already exists, so skip extraction entirely.
    if cache and cache.has(model.model_name, benchmark, strategy):
        if show_progress:
            print(f"Loading cached activations for {benchmark}/{strategy.value}")
        return cache.get(model.model_name, benchmark, strategy)

    # Slow path: run the collector over every pair (layers=None requests all layers).
    collector = ActivationCollector(model=model)
    cached = CachedActivations(
        benchmark=benchmark,
        strategy=strategy,
        model_name=model.model_name,
        num_layers=model.num_layers,
    )

    for i, pair in enumerate(pairs):
        # Refresh the in-place progress line every tenth pair.
        if show_progress and not i % 10:
            print(f"Collecting activations: {i+1}/{len(pairs)}", end="\r", flush=True)

        with_activations = collector.collect(pair, strategy=strategy, layers=None)
        positive_layers = with_activations.positive_response.layers_activations
        negative_layers = with_activations.negative_response.layers_activations
        cached.add_pair(positive_layers, negative_layers)

    if show_progress:
        print(f"Collected activations: {len(pairs)}/{len(pairs)} pairs, {cached.num_layers} layers")

    if not cache:
        return cached

    # Persist the freshly collected activations.
    cache.put(cached)
    if show_progress:
        print(f"Cached to {cache.cache_dir}")
    return cached
388
+
389
+
390
+ # Type hint for WisentModel (avoid circular import)
391
+ from typing import TYPE_CHECKING
392
+ if TYPE_CHECKING:
393
+ from wisent.core.models.wisent_model import WisentModel
@@ -50,7 +50,7 @@ class Activations:
50
50
  features = tensor.mean(dim=1).squeeze(0)
51
51
  elif strategy in (ExtractionStrategy.CHAT_LAST, ExtractionStrategy.ROLE_PLAY, ExtractionStrategy.MC_BALANCED):
52
52
  features = tensor[:, -1, :].squeeze(0)
53
- elif strategy in (ExtractionStrategy.CHAT_FIRST, ExtractionStrategy.CHAT_GEN_POINT):
53
+ elif strategy == ExtractionStrategy.CHAT_FIRST:
54
54
  features = tensor[:, 0, :].squeeze(0)
55
55
  elif strategy == ExtractionStrategy.CHAT_MAX_NORM:
56
56
  norms = torch.norm(tensor, dim=2)
@@ -58,11 +58,11 @@ class Activations:
58
58
  features = tensor[0, max_idx[0], :]
59
59
  elif strategy == ExtractionStrategy.CHAT_WEIGHTED:
60
60
  seq_len = tensor.shape[1]
61
- weights = torch.exp(-torch.arange(seq_len, dtype=torch.float32, device=tensor.device) * 0.5)
61
+ weights = torch.exp(-torch.arange(seq_len, dtype=tensor.dtype, device=tensor.device) * 0.5)
62
62
  weights = weights / weights.sum()
63
63
  features = (tensor * weights.unsqueeze(0).unsqueeze(2)).sum(dim=1).squeeze(0)
64
64
  else:
65
- features = tensor.mean(dim=1).squeeze(0)
65
+ raise InvalidValueError(param="extraction_strategy", reason=f"Unknown extraction strategy: {strategy}")
66
66
 
67
67
  return features
68
68
 
@@ -25,11 +25,11 @@ class ActivationCollector:
25
25
 
26
26
  Args:
27
27
  model: WisentModel instance
28
- store_device: Device to store collected activations on (default "cpu")
29
- dtype: Optional torch.dtype to cast activations to (e.g., torch.float32)
28
+ store_device: Device to store collected activations on (default: "cpu" to avoid GPU OOM)
29
+ dtype: Optional torch.dtype to cast activations to
30
30
 
31
31
  Example:
32
- >>> collector = ActivationCollector(model=my_model, store_device="cpu", dtype=torch.float32)
32
+ >>> collector = ActivationCollector(model=my_model)
33
33
  >>> updated_pair = collector.collect(
34
34
  ... pair,
35
35
  ... strategy=ExtractionStrategy.CHAT_LAST,
@@ -37,7 +37,7 @@ class ActivationCollector:
37
37
  ... )
38
38
  >>> pos_acts = updated_pair.positive_response.layers_activations
39
39
  >>> pos_acts.summary()
40
- {'8': {'shape': (2048,), 'dtype': 'torch.float32', ...}, '12': {...}}
40
+ {'8': {'shape': (2048,), ...}, '12': {...}}
41
41
  """
42
42
 
43
43
  model: "WisentModel"
@@ -66,8 +66,9 @@ class ActivationCollector:
66
66
  pos_text = _resp_text(pair.positive_response)
67
67
  neg_text = _resp_text(pair.negative_response)
68
68
 
69
- other_for_pos = neg_text if strategy == ExtractionStrategy.MC_BALANCED else None
70
- other_for_neg = pos_text if strategy == ExtractionStrategy.MC_BALANCED else None
69
+ needs_other = strategy in (ExtractionStrategy.MC_BALANCED, ExtractionStrategy.MC_COMPLETION)
70
+ other_for_pos = neg_text if needs_other else None
71
+ other_for_neg = pos_text if needs_other else None
71
72
 
72
73
  pos = self._collect_single(
73
74
  pair.prompt, pos_text, strategy, layers, normalize,
@@ -220,8 +221,12 @@ class ActivationCollector:
220
221
  value = h.mean(dim=0)
221
222
  elif strategy == ExtractionStrategy.CHAT_FIRST:
222
223
  value = h[0]
223
- else:
224
+ elif strategy in (ExtractionStrategy.CHAT_LAST, ExtractionStrategy.ROLE_PLAY,
225
+ ExtractionStrategy.MC_BALANCED,
226
+ ExtractionStrategy.CHAT_MAX_NORM, ExtractionStrategy.CHAT_WEIGHTED):
224
227
  value = h[-1]
228
+ else:
229
+ raise ValueError(f"Unsupported strategy for batched collection: {strategy}")
225
230
 
226
231
  collected[name] = value.to(self.store_device)
227
232
 
@@ -8,13 +8,16 @@ Based on empirical testing across 3 models (Llama-3.2-1B, Llama-2-7b, Qwen3-8B)
8
8
  and 4 tasks (truthfulqa, happy, left_wing, livecodebench):
9
9
 
10
10
  Results:
11
- - last_token: 66.3% avg accuracy (94.4% when paired with chat_last training)
12
- - all_mean: 65.9% avg accuracy
13
- - all_min: 53.5% avg accuracy
14
- - all_max: 53.3% avg accuracy
15
- - first_token: 50.0% avg accuracy (completely useless - BOS token is identical for all inputs)
11
+ - last_token: Best performer (77% with chat_last training on truthfulqa)
12
+ - all_mean: Poor (~50%) - dominated by shared prompt tokens
13
+ - all_max/all_min: Poor (~50%)
14
+ - first_token: BROKEN (50%) - BOS token is identical for all inputs
16
15
 
17
16
  Recommendation: Use LAST_TOKEN (default) - it works best with chat_last training strategy.
17
+
18
+ IMPORTANT: These strategies operate on the FULL sequence (prompt + response).
19
+ At inference time, we typically don't know where the answer starts, so we
20
+ can only use strategies that work on the whole sequence.
18
21
  """
19
22
 
20
23
  from enum import Enum
@@ -102,8 +105,7 @@ def extract_inference_activation(
102
105
  return hidden_states[torch.argmin(norms)]
103
106
 
104
107
  else:
105
- # Default fallback
106
- return hidden_states[-1]
108
+ raise ValueError(f"Unknown classifier inference strategy: {strategy}")
107
109
 
108
110
 
109
111
  def get_inference_score(
@@ -152,8 +154,7 @@ def get_inference_score(
152
154
  elif strategy == ClassifierInferenceStrategy.ALL_MIN:
153
155
  return float(np.min(all_scores))
154
156
 
155
- # Default fallback
156
- return float(classifier.predict_proba([hidden_np[-1]])[0, 1])
157
+ raise ValueError(f"Unknown classifier inference strategy: {strategy}")
157
158
 
158
159
 
159
160
  def get_recommended_inference_strategy(train_strategy) -> ClassifierInferenceStrategy:
@@ -161,8 +162,8 @@ def get_recommended_inference_strategy(train_strategy) -> ClassifierInferenceStr
161
162
  Get the recommended inference strategy for a given training strategy.
162
163
 
163
164
  Based on empirical testing:
164
- - chat_last, role_play, mc_balanced -> last_token (94.4%, 72.4%, 60.2%)
165
- - chat_mean, chat_weighted, chat_max_norm, chat_first, chat_gen_point -> all_mean
165
+ - chat_last, role_play, mc_balanced -> last_token
166
+ - chat_mean, chat_weighted, chat_max_norm, chat_first -> all_mean
166
167
 
167
168
  Args:
168
169
  train_strategy: ExtractionStrategy used for training