wisent 0.7.701__py3-none-any.whl → 0.7.1045__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (391)
  1. wisent/__init__.py +1 -1
  2. wisent/comparison/__init__.py +1 -0
  3. wisent/comparison/detect_bos_features.py +275 -0
  4. wisent/comparison/fgaa.py +465 -0
  5. wisent/comparison/lora.py +669 -0
  6. wisent/comparison/lora_dpo.py +592 -0
  7. wisent/comparison/main.py +444 -0
  8. wisent/comparison/ours.py +76 -0
  9. wisent/comparison/sae.py +304 -0
  10. wisent/comparison/utils.py +381 -0
  11. wisent/core/activations/activation_cache.py +393 -0
  12. wisent/core/activations/activations.py +3 -3
  13. wisent/core/activations/activations_collector.py +12 -7
  14. wisent/core/activations/classifier_inference_strategy.py +12 -11
  15. wisent/core/activations/extraction_strategy.py +260 -84
  16. wisent/core/classifiers/classifiers/core/atoms.py +3 -2
  17. wisent/core/cli/__init__.py +2 -1
  18. wisent/core/cli/agent/train_classifier.py +16 -3
  19. wisent/core/cli/check_linearity.py +35 -3
  20. wisent/core/cli/cluster_benchmarks.py +4 -6
  21. wisent/core/cli/create_steering_vector.py +6 -4
  22. wisent/core/cli/diagnose_vectors.py +7 -4
  23. wisent/core/cli/estimate_unified_goodness_time.py +6 -4
  24. wisent/core/cli/generate_pairs_from_task.py +9 -56
  25. wisent/core/cli/generate_vector_from_task.py +11 -20
  26. wisent/core/cli/geometry_search.py +137 -0
  27. wisent/core/cli/get_activations.py +2 -2
  28. wisent/core/cli/method_optimizer.py +4 -3
  29. wisent/core/cli/modify_weights.py +3 -2
  30. wisent/core/cli/optimize_sample_size.py +1 -1
  31. wisent/core/cli/optimize_steering.py +14 -16
  32. wisent/core/cli/optimize_weights.py +2 -1
  33. wisent/core/cli/preview_pairs.py +203 -0
  34. wisent/core/cli/steering_method_trainer.py +3 -3
  35. wisent/core/cli/tasks.py +19 -76
  36. wisent/core/cli/train_unified_goodness.py +3 -3
  37. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
  38. wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
  39. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
  40. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
  41. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
  42. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
  43. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
  44. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
  45. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
  46. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
  47. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
  48. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
  49. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
  50. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
  51. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
  52. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
  53. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
  54. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
  55. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
  56. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
  57. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
  58. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
  59. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
  60. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
  61. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
  62. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
  63. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
  64. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
  65. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
  66. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
  67. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
  68. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
  69. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
  70. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
  71. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
  72. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
  73. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
  74. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
  75. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
  76. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
  77. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
  78. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
  79. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
  80. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
  81. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
  82. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
  83. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
  84. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
  85. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
  86. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
  87. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
  88. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
  89. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
  90. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
  91. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
  92. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
  93. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +22 -5
  94. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
  95. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
  96. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +10 -3
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
  102. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
  103. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
  104. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
  105. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
  106. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
  107. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
  108. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
  109. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
  110. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
  111. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
  112. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
  113. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
  114. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
  115. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
  116. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
  117. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
  118. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
  119. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
  120. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
  121. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
  122. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
  123. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
  124. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
  125. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
  126. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
  127. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
  128. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
  129. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
  130. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
  131. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
  132. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
  133. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
  134. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
  135. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
  136. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
  137. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
  138. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
  139. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
  140. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
  141. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
  142. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
  143. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
  144. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
  145. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
  146. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
  147. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
  148. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
  149. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
  150. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
  151. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
  152. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
  153. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
  154. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
  155. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
  156. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
  157. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
  158. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
  159. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
  160. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
  161. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
  162. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
  163. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
  164. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
  165. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
  166. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
  167. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
  168. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
  169. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
  170. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
  171. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
  172. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
  173. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
  174. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
  175. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
  176. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
  177. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
  178. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
  179. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
  180. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
  181. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
  182. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
  183. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
  184. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
  185. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
  186. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
  187. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
  188. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
  189. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
  190. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
  191. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
  192. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
  193. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
  194. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
  195. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
  196. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
  197. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
  198. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
  199. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
  200. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
  201. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
  202. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
  203. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
  204. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
  205. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
  206. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
  207. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
  208. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
  209. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
  210. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
  211. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
  212. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
  213. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
  214. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
  215. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
  216. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
  217. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
  218. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
  219. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
  220. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
  221. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
  222. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
  223. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
  224. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
  225. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
  226. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
  227. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
  228. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
  229. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
  230. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
  231. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
  232. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
  233. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
  234. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
  235. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
  236. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
  237. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
  238. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
  239. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
  240. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
  241. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
  242. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
  243. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
  244. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
  245. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
  246. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
  247. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
  248. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
  249. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
  250. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
  251. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
  252. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
  253. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
  254. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
  255. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
  256. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
  257. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
  258. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
  259. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
  260. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +9 -4
  261. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
  262. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
  263. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
  264. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
  265. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
  266. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
  267. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
  268. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
  269. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
  270. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
  271. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
  272. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
  273. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
  274. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
  275. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
  276. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
  277. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
  278. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
  279. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
  280. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
  281. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
  282. wisent/core/data_loaders/loaders/lm_loader.py +12 -1
  283. wisent/core/geometry_runner.py +995 -0
  284. wisent/core/geometry_search_space.py +237 -0
  285. wisent/core/hyperparameter_optimizer.py +1 -1
  286. wisent/core/main.py +3 -0
  287. wisent/core/models/core/atoms.py +5 -3
  288. wisent/core/models/wisent_model.py +1 -1
  289. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
  290. wisent/core/parser_arguments/check_linearity_parser.py +12 -2
  291. wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
  292. wisent/core/parser_arguments/generate_vector_from_task_parser.py +6 -13
  293. wisent/core/parser_arguments/geometry_search_parser.py +61 -0
  294. wisent/core/parser_arguments/get_activations_parser.py +5 -14
  295. wisent/core/parser_arguments/main_parser.py +8 -0
  296. wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
  297. wisent/core/steering.py +5 -3
  298. wisent/core/steering_methods/methods/hyperplane.py +2 -1
  299. wisent/core/synthetic/generators/nonsense_generator.py +30 -18
  300. wisent/core/trainers/steering_trainer.py +2 -2
  301. wisent/core/utils/device.py +27 -27
  302. wisent/core/utils/layer_combinations.py +70 -0
  303. wisent/examples/__init__.py +1 -0
  304. wisent/examples/scripts/__init__.py +1 -0
  305. wisent/examples/scripts/count_all_benchmarks.py +121 -0
  306. wisent/examples/scripts/discover_directions.py +469 -0
  307. wisent/examples/scripts/extract_benchmark_info.py +71 -0
  308. wisent/examples/scripts/search_all_short_names.py +31 -0
  309. wisent/examples/scripts/test_all_benchmarks.py +138 -0
  310. wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
  311. wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
  312. wisent/examples/scripts/test_nonsense_baseline.py +261 -0
  313. wisent/examples/scripts/test_one_benchmark.py +324 -0
  314. wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
  315. wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
  316. wisent/parameters/lm_eval/category_directions.json +137 -0
  317. wisent/parameters/lm_eval/repair_plan.json +282 -0
  318. wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
  319. wisent/parameters/lm_eval/working_benchmarks.json +206 -0
  320. wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
  321. wisent/tests/test_detector_accuracy.py +1 -1
  322. wisent/tests/visualize_geometry.py +1 -1
  323. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
  324. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/RECORD +328 -358
  325. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
  326. wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
  327. wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
  328. wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
  329. wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
  330. wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
  331. wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
  332. wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
  333. wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
  334. wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
  335. wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
  336. wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
  337. wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
  338. wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
  339. wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
  340. wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
  341. wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
  342. wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
  343. wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
  344. wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
  345. wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
  346. wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
  347. wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
  348. wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
  349. wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
  350. wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
  351. wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
  352. wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
  353. wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
  354. wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
  355. wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
  356. wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
  357. wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
  358. wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
  359. wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
  360. wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
  361. wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
  362. wisent/examples/scripts/1/test_cola_pairs.json +0 -8
  363. wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
  364. wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
  365. wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
  366. wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
  367. wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
  368. wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
  369. wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
  370. wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
  371. wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
  372. wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
  373. wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
  374. wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
  375. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
  376. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
  377. wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
  378. wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
  379. wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
  380. wisent/examples/scripts/2/test_atis_pairs.json +0 -8
  381. wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
  382. wisent/examples/scripts/2/test_babi_pairs.json +0 -8
  383. wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
  384. wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
  385. wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
  386. wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
  387. wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
  388. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
  389. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
  390. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
  391. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,203 @@
1
+ """Preview contrastive pairs from benchmarks with different extraction strategies."""
2
+
3
+ import sys
4
+ import json
5
+ import argparse
6
+ from typing import Optional
7
+
8
+
9
def _truncate(text, max_chars):
    """Return *text* cut to *max_chars* characters, with '...' appended when it was longer."""
    if len(text) > max_chars:
        return text[:max_chars] + "..."
    return text


class _PreviewTokenizer:
    """Stand-in tokenizer so extraction texts can be built without loading a real model.

    Messages are wrapped in a simple ``<|user|>`` / ``<|assistant|>`` markup and
    "tokenization" is a whitespace split, which is enough to show how each
    strategy lays out the text. It does not reproduce any specific model's
    chat template.
    """

    def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False):
        # ``tokenize`` and ``add_generation_prompt`` are accepted so the call
        # signature matches real tokenizers; they are ignored here.
        if len(messages) == 1:
            return f"<|user|>\n{messages[0]['content']}\n<|assistant|>\n"
        if len(messages) == 2:
            return (
                f"<|user|>\n{messages[0]['content']}\n<|assistant|>\n"
                f"{messages[1]['content']}<|end|>"
            )
        return str(messages)

    def __call__(self, text, add_special_tokens=False):
        return {"input_ids": text.split()}


def _build_positive_texts(strategy, pair, tokenizer):
    """Return ``(full_text, answer)`` for the pair's positive response under *strategy*.

    Shared by the console preview and the JSON export so both always show the
    same formatting for a given strategy.
    """
    from wisent.core.activations.extraction_strategy import (
        ExtractionStrategy,
        build_extraction_texts,
    )

    extra = {}
    if strategy in (ExtractionStrategy.MC_BALANCED, ExtractionStrategy.MC_COMPLETION):
        # The MC_* strategies additionally receive the negative response as
        # ``other_response``, with this text flagged as the positive one.
        extra = {
            "other_response": pair.negative_response.model_response,
            "is_positive": True,
        }
    full_text, answer, _prompt_only = build_extraction_texts(
        strategy,
        pair.prompt,
        pair.positive_response.model_response,
        tokenizer,
        auto_convert_strategy=False,
        **extra,
    )
    return full_text, answer


def execute_preview_pairs(args):
    """Preview contrastive pairs from a benchmark with different strategies applied.

    Reads from *args*:
        task_name: benchmark / lm-eval task name (HF-extractor names are
            matched case-insensitively).
        limit: number of pairs to load (falls back to 5 when falsy).
        strategies: ExtractionStrategy value strings to preview (falls back to
            chat_last, mc_balanced, completion_last when falsy).
        output: optional path; when set, the raw and formatted pairs are also
            written there as JSON.

    Exits the process with status 1 if the task cannot be loaded or if it
    resolves to several subtasks.
    """
    from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_pairs_generation import (
        lm_build_contrastive_pairs,
    )
    from wisent.core.contrastive_pairs.huggingface_pairs.hf_extractor_manifest import HF_EXTRACTORS
    from wisent.core.activations.extraction_strategy import ExtractionStrategy

    task_name = args.task_name
    limit = args.limit or 5
    strategies = args.strategies or ['chat_last', 'mc_balanced', 'completion_last']
    output_path = getattr(args, "output", None)

    print(f"\n{'='*80}")
    print(f"Preview Contrastive Pairs: {task_name}")
    print(f"{'='*80}")

    # Load pairs
    print(f"\nLoading {limit} pairs from '{task_name}'...")

    try:
        task_name_lower = task_name.lower()
        is_hf_task = task_name_lower in {k.lower() for k in HF_EXTRACTORS}

        if is_hf_task:
            pairs = lm_build_contrastive_pairs(
                task_name=task_name,
                lm_eval_task=None,
                limit=limit,
            )
        else:
            from wisent.core.data_loaders.loaders.lm_loader import LMEvalDataLoader
            loader = LMEvalDataLoader()
            task_obj = loader.load_lm_eval_task(task_name)

            if isinstance(task_obj, dict):
                # A group task: only a single-subtask group can be previewed
                # unambiguously. sys.exit raises SystemExit, which the
                # ``except Exception`` below deliberately does not catch.
                if len(task_obj) != 1:
                    keys = ", ".join(sorted(task_obj.keys()))
                    print(f"Task '{task_name}' has subtasks: {keys}")
                    print("Please specify a subtask.")
                    sys.exit(1)
                (subname, task), = task_obj.items()
                task_name = subname
            else:
                task = task_obj

            pairs = lm_build_contrastive_pairs(
                task_name=task_name,
                lm_eval_task=task,
                limit=limit,
            )

        print(f"Loaded {len(pairs)} pairs\n")

    except Exception as e:
        print(f"Error loading task: {e}")
        sys.exit(1)

    tokenizer = _PreviewTokenizer()

    # Per-pair records for the optional JSON export, filled while previewing so
    # each strategy's texts are built exactly once.
    exported_pairs = []

    for index, pair in enumerate(pairs, start=1):
        positive = pair.positive_response.model_response
        negative = pair.negative_response.model_response

        print(f"\n{'='*80}")
        print(f"PAIR {index}/{len(pairs)}")
        print(f"{'='*80}")

        print("\n--- RAW DATA (from extractor) ---")
        print(f"Prompt: {_truncate(pair.prompt, 300)}")
        print(f"Correct: {_truncate(positive, 100)}")
        print(f"Incorrect: {_truncate(negative, 100)}")

        formatted = {}
        for strategy_name in strategies:
            try:
                strategy = ExtractionStrategy(strategy_name)
            except ValueError as e:
                print(f"\n--- {strategy_name.upper()} --- (invalid strategy)")
                formatted[strategy_name] = {"error": str(e)}
                continue

            print(f"\n--- {strategy_name.upper()} ---")

            try:
                full_text, answer = _build_positive_texts(strategy, pair, tokenizer)
            except Exception as e:
                # Best-effort preview: report the failure and move on to the
                # next strategy instead of aborting the whole run.
                print(f" Error: {e}")
                formatted[strategy_name] = {"error": str(e)}
                continue

            print("Full text (positive):")
            print(f" {_truncate(full_text, 400)}")
            print(f"Answer token: {answer}")
            formatted[strategy_name] = {
                "full_text": full_text,
                "answer": answer,
            }

        exported_pairs.append({
            "raw": {
                "prompt": pair.prompt,
                "correct": positive,
                "incorrect": negative,
            },
            "formatted": formatted,
        })

    # Summary
    print(f"\n{'='*80}")
    print("SUMMARY")
    print(f"{'='*80}")
    print(f"Task: {task_name}")
    print(f"Pairs shown: {len(pairs)}")
    print(f"Strategies: {', '.join(strategies)}")
    print()

    # Save to JSON if requested
    if output_path:
        output_data = {
            "task_name": task_name,
            "num_pairs": len(pairs),
            "strategies": strategies,
            "pairs": exported_pairs,
        }
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=2)
        print(f"Saved to: {output_path}")
+
187
+
188
def main():
    """Command-line entry point: parse options and run the contrastive-pair preview.

    NOTE(review): ``--verbose`` is accepted for CLI compatibility, but
    execute_preview_pairs does not currently read it.
    """
    parser = argparse.ArgumentParser(
        description="Preview contrastive pairs with different strategies",
    )
    parser.add_argument(
        "task_name",
        help="Task/benchmark name (e.g., boolq, mmlu, hellaswag)",
    )

    # (flags, keyword options) for every optional argument, registered in order.
    option_specs = (
        (("--limit", "-n"),
         dict(type=int, default=5, help="Number of pairs to show (default: 5)")),
        (("--strategies", "-s"),
         dict(nargs="+",
              default=["chat_last", "mc_balanced", "completion_last"],
              help="Strategies to preview")),
        (("--output", "-o"),
         dict(help="Save to JSON file")),
        (("--verbose", "-v"),
         dict(action="store_true", help="Verbose output")),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    execute_preview_pairs(parser.parse_args())


if __name__ == "__main__":
    main()
@@ -156,7 +156,7 @@ def collect_activations_for_pair_set(
156
156
  Returns:
157
157
  Updated ContrastivePairSet with activations attached
158
158
  """
159
- collector = ActivationCollector(model=model, store_device="cpu")
159
+ collector = ActivationCollector(model=model)
160
160
 
161
161
  updated_pairs = []
162
162
  for pair in pair_set.pairs:
@@ -320,7 +320,7 @@ class UnifiedSteeringTrainer:
320
320
  @property
321
321
  def collector(self) -> ActivationCollector:
322
322
  if self._collector is None:
323
- self._collector = ActivationCollector(model=self.model, store_device="cpu")
323
+ self._collector = ActivationCollector(model=self.model)
324
324
  return self._collector
325
325
 
326
326
  def train_for_layer(
@@ -595,7 +595,7 @@ def get_optimal_steering_plan(
595
595
  method_name = config["method"]
596
596
 
597
597
  # Collect activations for the optimal layer
598
- collector = ActivationCollector(model=model, store_device="cpu")
598
+ collector = ActivationCollector(model=model)
599
599
  layer_str = str(layer)
600
600
 
601
601
  pos_acts = []
wisent/core/cli/tasks.py CHANGED
@@ -414,7 +414,7 @@ def execute_tasks(args):
414
414
  print(f"\n🧠 Extracting activations from layer {layer}...")
415
415
 
416
416
  # 5. Collect activations for all pairs
417
- collector = ActivationCollector(model=model, store_device="cpu")
417
+ collector = ActivationCollector(model=model)
418
418
 
419
419
  # Get extraction strategy from args (already an ExtractionStrategy value string)
420
420
  extraction_strategy = ExtractionStrategy(getattr(args, 'extraction_strategy', 'chat_last'))
@@ -581,13 +581,6 @@ def execute_tasks(args):
581
581
  expected = pair.positive_response.model_response
582
582
  choices = [pair.negative_response.model_response, pair.positive_response.model_response]
583
583
 
584
- # Extract test_code from pair metadata for coding tasks
585
- test_code = None
586
- starter_code = None
587
- if hasattr(pair, 'metadata') and pair.metadata:
588
- test_code = pair.metadata.get('test_code')
589
- starter_code = pair.metadata.get('starter_code')
590
-
591
584
  # Generate response from unsteered model
592
585
  messages = [{"role": "user", "content": question}]
593
586
 
@@ -597,6 +590,7 @@ def execute_tasks(args):
597
590
  )[0]
598
591
 
599
592
  # Evaluate the response using Wisent evaluator
593
+ # Pass all pair metadata to evaluator - each evaluator uses what it needs
600
594
  eval_kwargs = {
601
595
  'response': response,
602
596
  'expected': expected,
@@ -605,16 +599,16 @@ def execute_tasks(args):
605
599
  'choices': choices,
606
600
  'task_name': task_name,
607
601
  }
608
- # Add test_code for coding tasks (livecodebench, humaneval, mbpp, etc.)
609
- if test_code:
610
- eval_kwargs['test_code'] = test_code
611
- if starter_code:
612
- eval_kwargs['starter_code'] = starter_code
602
+ # Add all pair metadata to eval_kwargs (test_code, correct_answers, etc.)
603
+ if hasattr(pair, 'metadata') and pair.metadata:
604
+ for key, value in pair.metadata.items():
605
+ if value is not None and key not in eval_kwargs:
606
+ eval_kwargs[key] = value
613
607
  eval_result = evaluator.evaluate(**eval_kwargs)
614
608
 
615
609
  # Get activation for this generation
616
610
  # Use ActivationCollector to collect activations from the generated text
617
- gen_collector = ActivationCollector(model=model, store_device="cpu")
611
+ gen_collector = ActivationCollector(model=model)
618
612
  # Create a pair with the generated response
619
613
  from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse
620
614
  from wisent.core.contrastive_pairs.core.pair import ContrastivePair
@@ -631,56 +625,20 @@ def execute_tasks(args):
631
625
  # Collect activation - ActivationCollector will re-run the model with prompt+response
632
626
  # First, collect with full sequence to get token-by-token activations
633
627
  collected_full = gen_collector.collect(
634
- temp_pair, strategy=aggregation_strategy,
635
- return_full_sequence=True,
636
- normalize_layers=False,
637
- prompt_strategy=prompt_strategy
628
+ temp_pair, strategy=extraction_strategy,
638
629
  )
639
630
 
640
631
  # Access the collected activations
641
632
  import torch
642
633
  if collected_full.positive_response.layers_activations:
643
- layer_activations_full = collected_full.positive_response.layers_activations
644
- if layer_str in layer_activations_full:
645
- activation_full_seq = layer_activations_full[layer_str]
646
- if activation_full_seq is not None and isinstance(activation_full_seq, torch.Tensor):
647
- # activation_full_seq shape: (num_tokens, hidden_dim)
648
-
649
- # Apply aggregation manually to get single vector for classifier
650
- if aggregation_strategy.name == 'MEAN_POOLING':
651
- activation_agg = activation_full_seq.mean(dim=0)
652
- elif aggregation_strategy.name == 'LAST_TOKEN':
653
- activation_agg = activation_full_seq[-1]
654
- elif aggregation_strategy.name == 'FIRST_TOKEN':
655
- activation_agg = activation_full_seq[0]
656
- elif aggregation_strategy.name == 'MAX_POOLING':
657
- activation_agg = activation_full_seq.max(dim=0)[0]
658
- elif aggregation_strategy.name == 'MIN_POOLING':
659
- activation_agg = activation_full_seq.min(dim=0)[0]
660
- else:
661
- # Default to mean
662
- activation_agg = activation_full_seq.mean(dim=0)
663
-
664
- # Compute per-token classifier scores first (needed for max_score aggregation)
665
- token_scores = []
666
- for token_idx in range(activation_full_seq.shape[0]):
667
- token_act = activation_full_seq[token_idx].unsqueeze(0).float()
668
- token_proba_result = classifier.predict_proba(token_act)
669
- token_proba = token_proba_result if isinstance(token_proba_result, float) else token_proba_result[0]
670
- token_scores.append(float(token_proba))
671
-
672
- # Get classifier prediction - either from aggregated vector or min token score
673
- if use_max_token_score:
674
- # Use MINIMUM token score as response-level classification
675
- # Score = P(TRUTHFUL), so min score = most suspicious token
676
- # If ANY token has low truthfulness probability, flag the response
677
- pred_proba = min(token_scores) if token_scores else 0.5
678
- else:
679
- # Standard: classify aggregated activation vector
680
- act_tensor = activation_agg.unsqueeze(0).float()
681
- pred_proba_result = classifier.predict_proba(act_tensor)
682
- pred_proba = pred_proba_result if isinstance(pred_proba_result, float) else pred_proba_result[0]
683
-
634
+ layer_activations = collected_full.positive_response.layers_activations
635
+ if layer_str in layer_activations:
636
+ activation = layer_activations[layer_str]
637
+ if activation is not None and isinstance(activation, torch.Tensor):
638
+ # activation shape: (hidden_dim,) - already aggregated by extraction strategy
639
+ act_tensor = activation.unsqueeze(0).float()
640
+ pred_proba_result = classifier.predict_proba(act_tensor)
641
+ pred_proba = pred_proba_result if isinstance(pred_proba_result, float) else pred_proba_result[0]
684
642
  pred_label = int(pred_proba > args.detection_threshold)
685
643
 
686
644
  # Update detection stats
@@ -753,14 +711,6 @@ def execute_tasks(args):
753
711
  # Ground truth from evaluator
754
712
  ground_truth = 1 if eval_result.ground_truth == "TRUTHFUL" else 0
755
713
 
756
- # token_scores = P(TRUTHFUL) for each token
757
- # min_token_score = most suspicious token (lowest P(TRUTHFUL))
758
- # max_token_score = most confident token (highest P(TRUTHFUL))
759
- min_token_score = min(token_scores) if token_scores else 0.0
760
- min_token_idx = token_scores.index(min_token_score) if token_scores else -1
761
- max_token_score = max(token_scores) if token_scores else 0.0
762
- max_token_idx = token_scores.index(max_token_score) if token_scores else -1
763
-
764
714
  generation_results.append({
765
715
  'question': question,
766
716
  'response': response,
@@ -770,13 +720,6 @@ def execute_tasks(args):
770
720
  'classifier_pred': pred_label,
771
721
  'classifier_proba': float(pred_proba),
772
722
  'correct': pred_label == ground_truth,
773
- 'token_scores': token_scores, # Per-token P(TRUTHFUL) probabilities
774
- 'min_token_score': min_token_score, # Most suspicious token - lowest P(TRUTHFUL)
775
- 'min_token_idx': min_token_idx, # Index of most suspicious token
776
- 'max_token_score': max_token_score, # Most confident token - highest P(TRUTHFUL) (kept for backward compat)
777
- 'max_token_idx': max_token_idx, # Index of most confident token
778
- 'num_tokens': len(token_scores),
779
- 'aggregation_method': 'max_score' if use_max_token_score else args.token_aggregation,
780
723
  'quality_score': quality_score,
781
724
  'issue_detected': issue_detected,
782
725
  'detection_type': detection_type,
@@ -852,7 +795,7 @@ def execute_tasks(args):
852
795
  classifier_type=args.classifier_type,
853
796
  training_accuracy=report.final.accuracy,
854
797
  training_samples=len(X),
855
- token_aggregation=args.token_aggregation,
798
+ token_aggregation=extraction_strategy.value,
856
799
  detection_threshold=args.detection_threshold
857
800
  )
858
801
 
@@ -884,7 +827,7 @@ def execute_tasks(args):
884
827
  'task': args.task_names,
885
828
  'model': args.model,
886
829
  'layer': layer,
887
- 'aggregation': args.token_aggregation,
830
+ 'aggregation': extraction_strategy.value,
888
831
  'threshold': args.detection_threshold,
889
832
  'num_generations': len(generation_results),
890
833
  'detection_stats': detection_stats,
@@ -325,11 +325,11 @@ def execute_train_unified_goodness(args):
325
325
  'final': ExtractionStrategy.CHAT_LAST,
326
326
  'first': ExtractionStrategy.CHAT_FIRST,
327
327
  'max': ExtractionStrategy.CHAT_MAX_NORM,
328
- 'continuation': ExtractionStrategy.CHAT_GEN_POINT,
328
+ 'continuation': ExtractionStrategy.CHAT_FIRST, # First answer token
329
329
  }
330
330
  aggregation_strategy = aggregation_map.get(
331
331
  args.token_aggregation,
332
- ExtractionStrategy.CHAT_GEN_POINT
332
+ ExtractionStrategy.CHAT_LAST
333
333
  )
334
334
 
335
335
  # Map prompt strategy
@@ -353,7 +353,7 @@ def execute_train_unified_goodness(args):
353
353
  negative_activations = activations_checkpoint['negative_activations']
354
354
  print(f" ✓ Loaded activations from checkpoint ({len(positive_activations[layers[0]])} pairs)")
355
355
  else:
356
- collector = ActivationCollector(model=model, store_device="cpu")
356
+ collector = ActivationCollector(model=model)
357
357
 
358
358
  # Collect activations for all training pairs using batched processing
359
359
  positive_activations = {layer: [] for layer in layers}
@@ -95,7 +95,7 @@ def run_control_vector_diagnostics(
95
95
  )
96
96
  continue
97
97
 
98
- flat = detached.to(dtype=torch.float32, device="cpu").reshape(-1)
98
+ flat = detached.to(device="cpu").reshape(-1)
99
99
 
100
100
  if not torch.isfinite(flat).all():
101
101
  non_finite = (~torch.isfinite(flat)).sum().item()
@@ -1549,7 +1549,7 @@ def _detect_sparse_structure(
1549
1549
  sorted_abs = abs_diff.sort().values
1550
1550
  n = len(sorted_abs)
1551
1551
  cumsum = sorted_abs.cumsum(0)
1552
- gini = (2 * torch.arange(1, n + 1, dtype=torch.float32) @ sorted_abs - (n + 1) * sorted_abs.sum()) / (n * sorted_abs.sum() + 1e-10)
1552
+ gini = (2 * torch.arange(1, n + 1, dtype=sorted_abs.dtype, device=sorted_abs.device) @ sorted_abs - (n + 1) * sorted_abs.sum()) / (n * sorted_abs.sum() + 1e-10)
1553
1553
 
1554
1554
  # Sparse score: high if few dimensions are active
1555
1555
  sparse_score = 0.4 * (1 - float(l1_l2_ratio)) + 0.3 * (1 - float(active_fraction)) + 0.3 * float(gini)
@@ -1632,11 +1632,11 @@ def _compute_dip_statistic(data: torch.Tensor) -> float:
1632
1632
  return 0.0
1633
1633
 
1634
1634
  # Empirical CDF
1635
- ecdf = torch.arange(1, n + 1, dtype=torch.float32) / n
1635
+ ecdf = torch.arange(1, n + 1, dtype=sorted_data.dtype, device=sorted_data.device) / n
1636
1636
 
1637
1637
  # Greatest convex minorant and least concave majorant
1638
1638
  # Simplified: measure deviation from uniform
1639
- uniform = torch.linspace(0, 1, n)
1639
+ uniform = torch.linspace(0, 1, n, dtype=sorted_data.dtype, device=sorted_data.device)
1640
1640
 
1641
1641
  # Kolmogorov-Smirnov like statistic
1642
1642
  ks_stat = (ecdf - uniform).abs().max()
@@ -188,6 +188,12 @@ def check_linearity(
188
188
  linear_score = result.all_scores["linear"].score
189
189
  linear_details = result.all_scores["linear"].details
190
190
 
191
+ # Include all structure scores
192
+ structure_scores = {
193
+ name: {"score": score.score, "confidence": score.confidence}
194
+ for name, score in result.all_scores.items()
195
+ }
196
+
191
197
  all_results.append({
192
198
  "extraction_strategy": strategy.value,
193
199
  "normalize": normalize,
@@ -196,6 +202,7 @@ def check_linearity(
196
202
  "cohens_d": linear_details.get("cohens_d", 0),
197
203
  "variance_explained": linear_details.get("variance_explained", 0),
198
204
  "best_structure": result.best_structure.value,
205
+ "all_structure_scores": structure_scores,
199
206
  })
200
207
 
201
208
  if not all_results: