wisent 0.7.701__py3-none-any.whl → 0.7.1045__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (391)
  1. wisent/__init__.py +1 -1
  2. wisent/comparison/__init__.py +1 -0
  3. wisent/comparison/detect_bos_features.py +275 -0
  4. wisent/comparison/fgaa.py +465 -0
  5. wisent/comparison/lora.py +669 -0
  6. wisent/comparison/lora_dpo.py +592 -0
  7. wisent/comparison/main.py +444 -0
  8. wisent/comparison/ours.py +76 -0
  9. wisent/comparison/sae.py +304 -0
  10. wisent/comparison/utils.py +381 -0
  11. wisent/core/activations/activation_cache.py +393 -0
  12. wisent/core/activations/activations.py +3 -3
  13. wisent/core/activations/activations_collector.py +12 -7
  14. wisent/core/activations/classifier_inference_strategy.py +12 -11
  15. wisent/core/activations/extraction_strategy.py +260 -84
  16. wisent/core/classifiers/classifiers/core/atoms.py +3 -2
  17. wisent/core/cli/__init__.py +2 -1
  18. wisent/core/cli/agent/train_classifier.py +16 -3
  19. wisent/core/cli/check_linearity.py +35 -3
  20. wisent/core/cli/cluster_benchmarks.py +4 -6
  21. wisent/core/cli/create_steering_vector.py +6 -4
  22. wisent/core/cli/diagnose_vectors.py +7 -4
  23. wisent/core/cli/estimate_unified_goodness_time.py +6 -4
  24. wisent/core/cli/generate_pairs_from_task.py +9 -56
  25. wisent/core/cli/generate_vector_from_task.py +11 -20
  26. wisent/core/cli/geometry_search.py +137 -0
  27. wisent/core/cli/get_activations.py +2 -2
  28. wisent/core/cli/method_optimizer.py +4 -3
  29. wisent/core/cli/modify_weights.py +3 -2
  30. wisent/core/cli/optimize_sample_size.py +1 -1
  31. wisent/core/cli/optimize_steering.py +14 -16
  32. wisent/core/cli/optimize_weights.py +2 -1
  33. wisent/core/cli/preview_pairs.py +203 -0
  34. wisent/core/cli/steering_method_trainer.py +3 -3
  35. wisent/core/cli/tasks.py +19 -76
  36. wisent/core/cli/train_unified_goodness.py +3 -3
  37. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
  38. wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
  39. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
  40. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
  41. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
  42. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
  43. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
  44. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
  45. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
  46. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
  47. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
  48. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
  49. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
  50. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
  51. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
  52. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
  53. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
  54. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
  55. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
  56. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
  57. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
  58. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
  59. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
  60. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
  61. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
  62. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
  63. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
  64. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
  65. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
  66. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
  67. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
  68. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
  69. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
  70. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
  71. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
  72. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
  73. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
  74. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
  75. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
  76. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
  77. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
  78. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
  79. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
  80. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
  81. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
  82. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
  83. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
  84. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
  85. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
  86. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
  87. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
  88. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
  89. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
  90. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
  91. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
  92. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
  93. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +22 -5
  94. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
  95. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
  96. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +10 -3
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
  102. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
  103. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
  104. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
  105. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
  106. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
  107. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
  108. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
  109. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
  110. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
  111. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
  112. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
  113. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
  114. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
  115. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
  116. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
  117. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
  118. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
  119. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
  120. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
  121. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
  122. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
  123. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
  124. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
  125. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
  126. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
  127. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
  128. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
  129. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
  130. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
  131. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
  132. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
  133. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
  134. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
  135. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
  136. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
  137. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
  138. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
  139. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
  140. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
  141. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
  142. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
  143. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
  144. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
  145. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
  146. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
  147. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
  148. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
  149. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
  150. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
  151. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
  152. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
  153. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
  154. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
  155. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
  156. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
  157. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
  158. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
  159. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
  160. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
  161. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
  162. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
  163. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
  164. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
  165. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
  166. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
  167. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
  168. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
  169. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
  170. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
  171. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
  172. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
  173. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
  174. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
  175. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
  176. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
  177. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
  178. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
  179. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
  180. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
  181. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
  182. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
  183. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
  184. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
  185. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
  186. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
  187. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
  188. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
  189. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
  190. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
  191. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
  192. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
  193. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
  194. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
  195. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
  196. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
  197. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
  198. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
  199. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
  200. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
  201. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
  202. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
  203. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
  204. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
  205. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
  206. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
  207. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
  208. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
  209. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
  210. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
  211. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
  212. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
  213. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
  214. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
  215. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
  216. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
  217. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
  218. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
  219. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
  220. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
  221. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
  222. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
  223. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
  224. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
  225. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
  226. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
  227. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
  228. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
  229. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
  230. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
  231. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
  232. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
  233. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
  234. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
  235. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
  236. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
  237. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
  238. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
  239. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
  240. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
  241. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
  242. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
  243. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
  244. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
  245. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
  246. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
  247. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
  248. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
  249. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
  250. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
  251. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
  252. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
  253. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
  254. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
  255. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
  256. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
  257. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
  258. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
  259. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
  260. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +9 -4
  261. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
  262. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
  263. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
  264. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
  265. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
  266. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
  267. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
  268. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
  269. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
  270. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
  271. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
  272. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
  273. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
  274. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
  275. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
  276. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
  277. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
  278. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
  279. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
  280. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
  281. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
  282. wisent/core/data_loaders/loaders/lm_loader.py +12 -1
  283. wisent/core/geometry_runner.py +995 -0
  284. wisent/core/geometry_search_space.py +237 -0
  285. wisent/core/hyperparameter_optimizer.py +1 -1
  286. wisent/core/main.py +3 -0
  287. wisent/core/models/core/atoms.py +5 -3
  288. wisent/core/models/wisent_model.py +1 -1
  289. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
  290. wisent/core/parser_arguments/check_linearity_parser.py +12 -2
  291. wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
  292. wisent/core/parser_arguments/generate_vector_from_task_parser.py +6 -13
  293. wisent/core/parser_arguments/geometry_search_parser.py +61 -0
  294. wisent/core/parser_arguments/get_activations_parser.py +5 -14
  295. wisent/core/parser_arguments/main_parser.py +8 -0
  296. wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
  297. wisent/core/steering.py +5 -3
  298. wisent/core/steering_methods/methods/hyperplane.py +2 -1
  299. wisent/core/synthetic/generators/nonsense_generator.py +30 -18
  300. wisent/core/trainers/steering_trainer.py +2 -2
  301. wisent/core/utils/device.py +27 -27
  302. wisent/core/utils/layer_combinations.py +70 -0
  303. wisent/examples/__init__.py +1 -0
  304. wisent/examples/scripts/__init__.py +1 -0
  305. wisent/examples/scripts/count_all_benchmarks.py +121 -0
  306. wisent/examples/scripts/discover_directions.py +469 -0
  307. wisent/examples/scripts/extract_benchmark_info.py +71 -0
  308. wisent/examples/scripts/search_all_short_names.py +31 -0
  309. wisent/examples/scripts/test_all_benchmarks.py +138 -0
  310. wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
  311. wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
  312. wisent/examples/scripts/test_nonsense_baseline.py +261 -0
  313. wisent/examples/scripts/test_one_benchmark.py +324 -0
  314. wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
  315. wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
  316. wisent/parameters/lm_eval/category_directions.json +137 -0
  317. wisent/parameters/lm_eval/repair_plan.json +282 -0
  318. wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
  319. wisent/parameters/lm_eval/working_benchmarks.json +206 -0
  320. wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
  321. wisent/tests/test_detector_accuracy.py +1 -1
  322. wisent/tests/visualize_geometry.py +1 -1
  323. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
  324. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/RECORD +328 -358
  325. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
  326. wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
  327. wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
  328. wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
  329. wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
  330. wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
  331. wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
  332. wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
  333. wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
  334. wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
  335. wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
  336. wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
  337. wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
  338. wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
  339. wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
  340. wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
  341. wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
  342. wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
  343. wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
  344. wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
  345. wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
  346. wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
  347. wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
  348. wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
  349. wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
  350. wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
  351. wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
  352. wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
  353. wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
  354. wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
  355. wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
  356. wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
  357. wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
  358. wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
  359. wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
  360. wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
  361. wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
  362. wisent/examples/scripts/1/test_cola_pairs.json +0 -8
  363. wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
  364. wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
  365. wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
  366. wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
  367. wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
  368. wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
  369. wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
  370. wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
  371. wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
  372. wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
  373. wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
  374. wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
  375. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
  376. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
  377. wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
  378. wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
  379. wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
  380. wisent/examples/scripts/2/test_atis_pairs.json +0 -8
  381. wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
  382. wisent/examples/scripts/2/test_babi_pairs.json +0 -8
  383. wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
  384. wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
  385. wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
  386. wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
  387. wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
  388. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
  389. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
  390. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
  391. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,995 @@
1
+ """
2
+ Geometry search runner.
3
+
4
+ Runs geometry tests across the search space using cached activations.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import random
11
+ import time
12
+ from dataclasses import dataclass, field
13
+ from pathlib import Path
14
+ from typing import Dict, List, Optional, Any, Tuple
15
+ import torch
16
+
17
+ import numpy as np
18
+
19
+ from wisent.core.geometry_search_space import GeometrySearchSpace, GeometrySearchConfig
20
+ from wisent.core.activations.extraction_strategy import ExtractionStrategy
21
+ from wisent.core.activations.activation_cache import (
22
+ ActivationCache,
23
+ CachedActivations,
24
+ collect_and_cache_activations,
25
+ )
26
+ from wisent.core.utils.layer_combinations import get_layer_combinations
27
+
28
+
29
def compute_signal_strength(
    pos_activations: torch.Tensor,
    neg_activations: torch.Tensor,
    n_folds: int = 5,
) -> float:
    """
    Estimate how much class signal the activations carry, via MLP cross-validation.

    A small MLP (one hidden layer of 16 units) is cross-validated on the pooled
    positive/negative activations. Because the classifier is nonlinear, a score
    clearly above chance means *some* generalizable signal exists, linear or not;
    random data scores around 0.5 (callers in this module treat > 0.6 as signal).

    Args:
        pos_activations: [N, hidden_dim] positive-class activations.
        neg_activations: [M, hidden_dim] negative-class activations.
        n_folds: Requested number of CV folds (capped at the smaller class size).

    Returns:
        Mean cross-validation accuracy. 0.5 is also returned when either class
        has fewer than 5 samples, fewer than 2 folds are possible, or anything
        fails (including scikit-learn being unavailable).
    """
    try:
        from sklearn.model_selection import cross_val_score
        from sklearn.neural_network import MLPClassifier

        n_pos, n_neg = len(pos_activations), len(neg_activations)
        if min(n_pos, n_neg) < 5:
            return 0.5  # too few samples per class to say anything

        features = torch.cat([pos_activations, neg_activations], dim=0).float().cpu().numpy()
        labels = np.array([1] * n_pos + [0] * n_neg)

        folds = min(n_folds, n_pos, n_neg)
        if folds < 2:
            return 0.5

        model = MLPClassifier(hidden_layer_sizes=(16,), max_iter=500, random_state=42)
        fold_scores = cross_val_score(model, features, labels, cv=folds, scoring='accuracy')
        return float(fold_scores.mean())
    except Exception:
        # Best-effort metric: any failure is reported as "no signal".
        return 0.5
74
+
75
+
76
def compute_knn_accuracy(
    pos_activations: torch.Tensor,
    neg_activations: torch.Tensor,
    k: int = 10,
    n_folds: int = 5,
) -> float:
    """
    Cross-validated k-nearest-neighbour accuracy on positive vs. negative activations.

    k-NN makes no linearity assumption, so this probes local separability.

    Args:
        pos_activations: [N, hidden_dim] positive-class activations.
        neg_activations: [M, hidden_dim] negative-class activations.
        k: Number of neighbours used by the classifier.
        n_folds: Requested number of CV folds (capped at the smaller class size).

    Returns:
        Mean CV accuracy, or 0.5 when either class has at most ``k`` samples,
        fewer than 2 folds are possible, or anything fails.
    """
    try:
        from sklearn.model_selection import cross_val_score
        from sklearn.neighbors import KNeighborsClassifier

        n_pos, n_neg = len(pos_activations), len(neg_activations)
        # Each class needs more than k samples for the neighbourhood to be meaningful.
        if min(n_pos, n_neg) < k + 1:
            return 0.5

        features = torch.cat([pos_activations, neg_activations], dim=0).float().cpu().numpy()
        labels = np.array([1] * n_pos + [0] * n_neg)

        folds = min(n_folds, n_pos, n_neg)
        if folds < 2:
            return 0.5

        model = KNeighborsClassifier(n_neighbors=k)
        fold_scores = cross_val_score(model, features, labels, cv=folds, scoring='accuracy')
        return float(fold_scores.mean())
    except Exception:
        return 0.5
118
+
119
+
120
def compute_mmd_rbf(
    pos_activations: torch.Tensor,
    neg_activations: torch.Tensor,
) -> float:
    """
    Compute the (biased) Maximum Mean Discrepancy between the two classes with an RBF kernel.

    MMD measures a distribution difference without assuming linearity; larger
    values mean more separable distributions. The kernel bandwidth uses the
    median heuristic: gamma = 1 / (2 * median_pairwise_distance**2).

    The pairwise squared distances over the pooled data are computed once and
    reused both for the median heuristic and for all three kernel blocks
    (K_pp, K_nn, K_pn), instead of recomputing distances per block.

    Args:
        pos_activations: [N, hidden_dim] positive-class activations.
        neg_activations: [M, hidden_dim] negative-class activations.

    Returns:
        MMD value clipped at 0 (0 = identical distributions). Returns 0.0 when
        either class is empty, when all points coincide, or on any error.
    """
    try:
        from scipy.spatial.distance import cdist

        pos = pos_activations.float().cpu().numpy()
        neg = neg_activations.float().cpu().numpy()
        m, n = len(pos), len(neg)
        if m == 0 or n == 0:
            return 0.0

        all_data = np.vstack([pos, neg])
        sq_dists = cdist(all_data, all_data, "sqeuclidean")

        nonzero = sq_dists[sq_dists > 0]
        if nonzero.size == 0:
            # Every point coincides: the two samples are indistinguishable.
            return 0.0
        median_dist = np.median(np.sqrt(nonzero))
        gamma = 1.0 / (2 * median_dist ** 2 + 1e-10)

        # RBF kernel exp(-gamma * ||x - y||^2) on the pooled data; slice out the blocks.
        kernel = np.exp(-gamma * sq_dists)
        k_pp = kernel[:m, :m]
        k_nn = kernel[m:, m:]
        k_pn = kernel[:m, m:]

        mmd = k_pp.mean() + k_nn.mean() - 2 * k_pn.mean()
        return float(max(0.0, mmd))
    except Exception:
        return 0.0
163
+
164
+
165
def estimate_local_intrinsic_dim(X: np.ndarray, k: int = 10) -> float:
    """
    Estimate intrinsic dimensionality with the Levina & Bickel (2004) MLE.

    For every point, the distances to its ``k`` nearest neighbours
    T_1 <= ... <= T_k give a per-point estimate
    -(k - 1) / sum_j log(T_j / T_k), capped at the ambient dimension; the
    median over points is returned.

    Args:
        X: [N, D] data matrix.
        k: Number of neighbours per point.

    Returns:
        Median per-point estimate, or D when there are at most ``k`` points or
        no point yields a valid estimate.
    """
    from scipy.spatial.distance import cdist

    ambient_dim = X.shape[1]
    if len(X) < k + 1:
        return float(ambient_dim)

    pairwise = cdist(X, X, 'euclidean')
    np.fill_diagonal(pairwise, np.inf)  # exclude each point from its own neighbourhood
    knn_dists = np.sort(pairwise, axis=1)[:, :k]

    estimates = []
    for row in knn_dists:
        radius = row[k - 1]
        if radius < 1e-10:
            continue  # degenerate neighbourhood (duplicates)
        log_ratios = np.log(row[:k - 1] / radius + 1e-10)
        total = log_ratios.sum()
        if log_ratios.size > 0 and total < 0:
            estimates.append(min(-(k - 1) / total, ambient_dim))

    if not estimates:
        return float(ambient_dim)
    return float(np.median(estimates))
198
+
199
+
200
def compute_local_intrinsic_dims(
    pos_activations: torch.Tensor,
    neg_activations: torch.Tensor,
    k: int = 10,
) -> tuple:
    """
    Estimate the local intrinsic dimension of each class separately.

    A large mismatch between the two classes hints at different geometric
    structure.

    Args:
        pos_activations: [N, hidden_dim] positive-class activations.
        neg_activations: [M, hidden_dim] negative-class activations.
        k: Neighbours passed to ``estimate_local_intrinsic_dim``.

    Returns:
        ``(dim_pos, dim_neg, dim_pos / (dim_neg + 1e-10))``, or
        ``(0.0, 0.0, 1.0)`` if anything fails.
    """
    try:
        pos_np, neg_np = (t.float().cpu().numpy() for t in (pos_activations, neg_activations))
        dim_pos = estimate_local_intrinsic_dim(pos_np, k)
        dim_neg = estimate_local_intrinsic_dim(neg_np, k)
        return dim_pos, dim_neg, dim_pos / (dim_neg + 1e-10)
    except Exception:
        return 0.0, 0.0, 1.0
229
+
230
+
231
def compute_fisher_per_dimension(
    pos_activations: torch.Tensor,
    neg_activations: torch.Tensor,
) -> dict:
    """
    Compute a per-dimension Fisher ratio and summarize how it is distributed.

    For each feature dimension d:
        fisher_d = (mean_pos_d - mean_neg_d)^2 / ((var_pos_d + var_neg_d) / 2)
    with fisher_d = 0 where the pooled within-class variance is <= 1e-10.
    The computation is vectorized over dimensions and done in float64.

    Args:
        pos_activations: [N, hidden_dim] positive-class activations.
        neg_activations: [M, hidden_dim] negative-class activations.

    Returns:
        Dict with keys:
            ``fisher_max``: largest per-dimension ratio.
            ``fisher_gini``: Gini coefficient of the ratios (concentration).
            ``fisher_top10_ratio``: share of the total ratio held by the top 10 dims.
            ``num_dims_fisher_above_1``: number of dims with ratio > 1.
        All values are zero if anything fails.
    """
    try:
        pos = pos_activations.double().cpu().numpy()
        neg = neg_activations.double().cpu().numpy()

        between_var = (pos.mean(axis=0) - neg.mean(axis=0)) ** 2
        within_var = (pos.var(axis=0) + neg.var(axis=0)) / 2

        fishers = np.zeros_like(within_var)
        valid = within_var > 1e-10  # skip (near-)constant dimensions
        fishers[valid] = between_var[valid] / within_var[valid]

        ordered = np.sort(fishers)  # ascending; ratios are non-negative
        total = fishers.sum()

        if total > 1e-10:
            n = ordered.size
            ranks = np.arange(1, n + 1)
            fisher_gini = float(2 * np.sum(ranks * ordered) / (n * total) - (n + 1) / n)
        else:
            fisher_gini = 0.0

        return {
            "fisher_max": float(fishers.max()),
            "fisher_gini": fisher_gini,
            "fisher_top10_ratio": float(ordered[-10:].sum() / (total + 1e-10)),
            "num_dims_fisher_above_1": int((fishers > 1.0).sum()),
        }
    except Exception:
        return {
            "fisher_max": 0.0,
            "fisher_gini": 0.0,
            "fisher_top10_ratio": 0.0,
            "num_dims_fisher_above_1": 0,
        }
300
+
301
+
302
def compute_density_ratio(
    pos_activations: torch.Tensor,
    neg_activations: torch.Tensor,
) -> float:
    """
    Ratio of the mean intra-class pairwise distance (positive / negative).

    A value far from 1 means one class is much more spread out than the other.

    Args:
        pos_activations: [N, hidden_dim] positive-class activations.
        neg_activations: [M, hidden_dim] negative-class activations.

    Returns:
        mean_pos_distance / mean_neg_distance. Returns 1.0 when either class has
        fewer than two samples, when the negative mean distance is ~0, or on
        any error.
    """
    try:
        from scipy.spatial.distance import cdist

        def mean_offdiag_distance(points: np.ndarray) -> float:
            # Average over all ordered pairs i != j (diagonal masked out as NaN).
            dist = cdist(points, points, 'euclidean')
            np.fill_diagonal(dist, np.nan)
            return np.nanmean(dist)

        pos = pos_activations.float().cpu().numpy()
        neg = neg_activations.float().cpu().numpy()
        if min(len(pos), len(neg)) < 2:
            return 1.0

        avg_pos = mean_offdiag_distance(pos)
        avg_neg = mean_offdiag_distance(neg)
        return 1.0 if avg_neg < 1e-10 else float(avg_pos / avg_neg)
    except Exception:
        return 1.0
342
+
343
+
344
def compute_linear_probe_accuracy(
    pos_activations: torch.Tensor,
    neg_activations: torch.Tensor,
    n_folds: int = 5,
) -> float:
    """
    Cross-validated logistic-regression accuracy on positive vs. negative activations.

    Read together with ``compute_signal_strength``: a high MLP score with a
    low linear score suggests a nonlinear signal; both high suggests a linear
    signal that CAA-style steering can exploit.

    Args:
        pos_activations: [N, hidden_dim] positive-class activations.
        neg_activations: [M, hidden_dim] negative-class activations.
        n_folds: Requested number of CV folds (capped at the smaller class size).

    Returns:
        Mean CV accuracy (0.5 = no linear signal). Also 0.5 when either class
        has fewer than 5 samples, fewer than 2 folds are possible, or anything
        fails.
    """
    try:
        from sklearn.linear_model import LogisticRegression
        from sklearn.model_selection import cross_val_score

        n_pos, n_neg = len(pos_activations), len(neg_activations)
        if min(n_pos, n_neg) < 5:
            return 0.5

        features = torch.cat([pos_activations, neg_activations], dim=0).float().cpu().numpy()
        labels = np.array([1] * n_pos + [0] * n_neg)

        folds = min(n_folds, n_pos, n_neg)
        if folds < 2:
            return 0.5

        model = LogisticRegression(max_iter=1000, solver='lbfgs')
        fold_scores = cross_val_score(model, features, labels, cv=folds, scoring='accuracy')
        return float(fold_scores.mean())
    except Exception:
        return 0.5
385
+
386
+
387
@dataclass
class GeometryTestResult:
    """Outcome of analysing one (benchmark, strategy, layer combination)."""

    benchmark: str
    strategy: str
    layers: List[int]

    # Step 1 - is there any signal at all? (MLP cross-validation accuracy)
    signal_strength: float  # ~0.5 means no signal; > 0.6 is treated as signal
    has_signal: bool  # signal_strength > 0.6

    # Step 2 - is the signal linear? (logistic-regression CV accuracy)
    linear_probe_accuracy: float  # high = linear, low = nonlinear
    is_linear: bool  # has signal, probe > 0.6, and probe close to signal_strength

    # Nonlinear signal metrics
    knn_accuracy_k5: float  # k-NN CV accuracy, k=5
    knn_accuracy_k10: float  # k-NN CV accuracy, k=10
    knn_accuracy_k20: float  # k-NN CV accuracy, k=20
    mmd_rbf: float  # Maximum Mean Discrepancy with an RBF kernel
    local_dim_pos: float  # local intrinsic dimension, positive class
    local_dim_neg: float  # local intrinsic dimension, negative class
    local_dim_ratio: float  # local_dim_pos / local_dim_neg
    fisher_max: float  # largest per-dimension Fisher ratio
    fisher_gini: float  # Gini coefficient of Fisher ratios (concentration)
    fisher_top10_ratio: float  # share of total Fisher ratio in the top 10 dims
    num_dims_fisher_above_1: int  # dims whose Fisher ratio exceeds 1
    density_ratio: float  # ratio of mean intra-class distances

    # Step 3 - geometry (only meaningful when has_signal is True)
    best_structure: str  # 'linear', 'cone', 'cluster', 'manifold', 'sparse', 'bimodal', 'orthogonal'
    best_score: float

    # Score per candidate structure
    linear_score: float
    cone_score: float
    orthogonal_score: float
    manifold_score: float
    sparse_score: float
    cluster_score: float
    bimodal_score: float

    # Linear-structure details
    cohens_d: float  # separation quality
    variance_explained: float  # by the primary direction
    within_class_consistency: float

    # Cone-structure details
    raw_mean_cosine_similarity: float  # between difference vectors
    positive_correlation_fraction: float  # fraction lying in the same half-space

    # Orthogonal-structure details
    near_zero_fraction: float  # fraction of near-zero correlations

    # Manifold-structure details
    pca_top2_variance: float  # variance captured by the top 2 PCs
    local_nonlinearity: float  # curvature measure

    # Sparse-structure details
    gini_coefficient: float  # inequality of activations
    active_fraction: float  # fraction of active neurons
    top_10_contribution: float  # contribution of the top 10 neurons

    # Cluster-structure details
    best_silhouette: float  # clustering quality
    best_k: int  # chosen number of clusters

    # Final recommendation
    recommended_method: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialize into a nested, JSON-friendly dict grouped by analysis step."""

        def pick(*names: str) -> Dict[str, Any]:
            return {name: getattr(self, name) for name in names}

        payload: Dict[str, Any] = pick(
            "benchmark",
            "strategy",
            "layers",
            "signal_strength",
            "has_signal",
            "linear_probe_accuracy",
            "is_linear",
        )
        payload["nonlinear_metrics"] = pick(
            "knn_accuracy_k5",
            "knn_accuracy_k10",
            "knn_accuracy_k20",
            "mmd_rbf",
            "local_dim_pos",
            "local_dim_neg",
            "local_dim_ratio",
            "fisher_max",
            "fisher_gini",
            "fisher_top10_ratio",
            "num_dims_fisher_above_1",
            "density_ratio",
        )
        payload["best_structure"] = self.best_structure
        payload["best_score"] = self.best_score
        payload["structure_scores"] = {
            name: getattr(self, f"{name}_score")
            for name in ("linear", "cone", "orthogonal", "manifold", "sparse", "cluster", "bimodal")
        }
        payload["linear_details"] = pick("cohens_d", "variance_explained", "within_class_consistency")
        payload["cone_details"] = pick("raw_mean_cosine_similarity", "positive_correlation_fraction")
        payload["orthogonal_details"] = pick("near_zero_fraction")
        payload["manifold_details"] = pick("pca_top2_variance", "local_nonlinearity")
        payload["sparse_details"] = pick("gini_coefficient", "active_fraction", "top_10_contribution")
        payload["cluster_details"] = pick("best_silhouette", "best_k")
        payload["recommended_method"] = self.recommended_method
        return payload
524
+
525
+
526
@dataclass
class GeometrySearchResults:
    """Every GeometryTestResult from one search run, plus timings and coverage counts."""

    model_name: str
    config: GeometrySearchConfig
    results: List[GeometryTestResult] = field(default_factory=list)

    # Wall-clock timings, in seconds
    total_time_seconds: float = 0.0
    extraction_time_seconds: float = 0.0
    test_time_seconds: float = 0.0

    # Coverage counters
    benchmarks_tested: int = 0
    strategies_tested: int = 0
    layer_combos_tested: int = 0

    def add_result(self, result: GeometryTestResult) -> None:
        """Append a single test result."""
        self.results.append(result)

    def get_best_by_linear_score(self, n: int = 10) -> List[GeometryTestResult]:
        """Return up to ``n`` results ranked by ``linear_score``, highest first."""
        ranked = sorted(self.results, key=lambda res: res.linear_score, reverse=True)
        return ranked[:n]

    def get_best_by_structure(self, structure: str, n: int = 10) -> List[GeometryTestResult]:
        """Return up to ``n`` results ranked by ``<structure>_score`` (missing attribute counts as 0.0)."""
        attr_name = f"{structure}_score"
        ranked = sorted(
            self.results,
            key=lambda res: getattr(res, attr_name, 0.0),
            reverse=True,
        )
        return ranked[:n]

    def get_structure_distribution(self) -> Dict[str, int]:
        """Map each ``best_structure`` value to how many results chose it."""
        tally: Dict[str, int] = {}
        for res in self.results:
            tally[res.best_structure] = tally.get(res.best_structure, 0) + 1
        return tally

    def get_summary_by_benchmark(self) -> Dict[str, Dict[str, float]]:
        """Per-benchmark mean/max/min/count of ``linear_score``."""
        grouped: Dict[str, List[float]] = {}
        for res in self.results:
            grouped.setdefault(res.benchmark, []).append(res.linear_score)

        summary: Dict[str, Dict[str, float]] = {}
        for bench, scores in grouped.items():
            summary[bench] = {
                "mean": sum(scores) / len(scores),
                "max": max(scores),
                "min": min(scores),
                "count": len(scores),
            }
        return summary

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the whole search, including every result, to a dict."""
        payload: Dict[str, Any] = {
            "model_name": self.model_name,
            "config": self.config.to_dict(),
        }
        for attr in (
            "total_time_seconds",
            "extraction_time_seconds",
            "test_time_seconds",
            "benchmarks_tested",
            "strategies_tested",
            "layer_combos_tested",
        ):
            payload[attr] = getattr(self, attr)
        payload["results"] = [res.to_dict() for res in self.results]
        return payload

    def save(self, path: str) -> None:
        """Write ``to_dict()`` to ``path`` as indented JSON."""
        with open(path, "w") as handle:
            json.dump(self.to_dict(), handle, indent=2)
601
+
602
+
603
def _failed_geometry_result(
    cached: CachedActivations,
    layers: List[int],
    recommended_method: str,
) -> GeometryTestResult:
    """
    Build the neutral placeholder result for a layer combination that could not be analyzed.

    Accuracies are set to chance (0.5), scores and distances to 0.0, ratios to
    1.0, ``best_structure`` to ``"error"``, and the failure reason is carried in
    ``recommended_method``.
    """
    return GeometryTestResult(
        benchmark=cached.benchmark,
        strategy=cached.strategy.value,
        layers=layers,
        signal_strength=0.5,
        has_signal=False,
        linear_probe_accuracy=0.5,
        is_linear=False,
        # Nonlinear metrics
        knn_accuracy_k5=0.5,
        knn_accuracy_k10=0.5,
        knn_accuracy_k20=0.5,
        mmd_rbf=0.0,
        local_dim_pos=0.0,
        local_dim_neg=0.0,
        local_dim_ratio=1.0,
        fisher_max=0.0,
        fisher_gini=0.0,
        fisher_top10_ratio=0.0,
        num_dims_fisher_above_1=0,
        density_ratio=1.0,
        # Structure scores
        best_structure="error",
        best_score=0.0,
        linear_score=0.0,
        cone_score=0.0,
        orthogonal_score=0.0,
        manifold_score=0.0,
        sparse_score=0.0,
        cluster_score=0.0,
        bimodal_score=0.0,
        cohens_d=0.0,
        variance_explained=0.0,
        within_class_consistency=0.0,
        raw_mean_cosine_similarity=0.0,
        positive_correlation_fraction=0.0,
        near_zero_fraction=0.0,
        pca_top2_variance=0.0,
        local_nonlinearity=0.0,
        gini_coefficient=0.0,
        active_fraction=0.0,
        top_10_contribution=0.0,
        best_silhouette=0.0,
        best_k=0,
        recommended_method=recommended_method,
    )


def compute_geometry_metrics(
    cached: CachedActivations,
    layers: List[int],
) -> GeometryTestResult:
    """
    Compute geometry metrics for one layer combination from cached activations.

    Activations of the requested layers are concatenated along the feature
    axis. Then ``detect_geometry_structure()`` scores the candidate structures
    (linear, cone, cluster, manifold, sparse, bimodal, orthogonal), and the
    signal / linearity / nonlinear metrics of this module are computed.

    Args:
        cached: Cached activations containing all layers.
        layers: 0-based layer indices to analyze. The cache uses 1-based layer
            names; layers missing from the cache are skipped.

    Returns:
        A GeometryTestResult. If no requested layer is available, or any step
        raises, a placeholder result with ``best_structure="error"`` is
        returned, and ``recommended_method`` describes the failure.
    """
    from wisent.core.contrastive_pairs.diagnostics.control_vectors import (
        detect_geometry_structure,
        GeometryAnalysisConfig,
    )

    pos_acts_list = []
    neg_acts_list = []
    for layer_idx in layers:
        layer_name = str(layer_idx + 1)  # 0-based index -> 1-based cache name
        try:
            pos = cached.get_positive_activations(layer_name)  # [num_pairs, hidden_size]
            neg = cached.get_negative_activations(layer_name)  # [num_pairs, hidden_size]
        except (KeyError, IndexError):
            continue
        pos_acts_list.append(pos)
        neg_acts_list.append(neg)

    if not pos_acts_list:
        return _failed_geometry_result(cached, layers, "error: no activations")

    # Concatenate across layers -> [num_pairs, hidden_size * num_layers], and upcast
    # to float32 (bf16/float16 can cause dtype mismatches in the analysis).
    pos_activations = torch.cat(pos_acts_list, dim=-1).float()
    neg_activations = torch.cat(neg_acts_list, dim=-1).float()

    config = GeometryAnalysisConfig(
        num_components=5,
        optimization_steps=50,  # reduced for speed: many combos are tested
    )

    try:
        result = detect_geometry_structure(pos_activations, neg_activations, config)
        scores = result.all_scores

        # Step 1: any extractable signal? (MLP CV accuracy)
        signal_strength = compute_signal_strength(pos_activations, neg_activations)
        has_signal = signal_strength > 0.6

        # Step 2: linear probe. The signal counts as linear when it exists, the probe
        # beats 0.6, and the probe trails the MLP by less than 0.15.
        linear_probe_accuracy = compute_linear_probe_accuracy(pos_activations, neg_activations)
        is_linear = has_signal and linear_probe_accuracy > 0.6 and (signal_strength - linear_probe_accuracy) < 0.15

        # Step 2b: nonlinear signal metrics
        knn_k5 = compute_knn_accuracy(pos_activations, neg_activations, k=5)
        knn_k10 = compute_knn_accuracy(pos_activations, neg_activations, k=10)
        knn_k20 = compute_knn_accuracy(pos_activations, neg_activations, k=20)
        mmd = compute_mmd_rbf(pos_activations, neg_activations)
        local_dim_pos, local_dim_neg, local_dim_ratio = compute_local_intrinsic_dims(pos_activations, neg_activations)
        fisher_stats = compute_fisher_per_dimension(pos_activations, neg_activations)
        density_rat = compute_density_ratio(pos_activations, neg_activations)

        if not has_signal:
            recommendation = "NO_SIGNAL"
        elif is_linear:
            recommendation = "CAA"  # linear signal -> Contrastive Activation Addition
        else:
            recommendation = "NONLINEAR"  # nonlinear signal -> needs a different method

        def get_score(struct_name: str) -> float:
            """Score for a structure, 0.0 if the detector did not report it."""
            return scores[struct_name].score if struct_name in scores else 0.0

        def get_detail(struct_name: str, key: str, default=0.0):
            """Detail value for a structure, ``default`` if absent."""
            if struct_name in scores:
                return scores[struct_name].details.get(key, default)
            return default

        return GeometryTestResult(
            benchmark=cached.benchmark,
            strategy=cached.strategy.value,
            layers=layers,
            signal_strength=signal_strength,
            has_signal=has_signal,
            linear_probe_accuracy=linear_probe_accuracy,
            is_linear=is_linear,
            # Nonlinear metrics
            knn_accuracy_k5=knn_k5,
            knn_accuracy_k10=knn_k10,
            knn_accuracy_k20=knn_k20,
            mmd_rbf=mmd,
            local_dim_pos=local_dim_pos,
            local_dim_neg=local_dim_neg,
            local_dim_ratio=local_dim_ratio,
            fisher_max=fisher_stats["fisher_max"],
            fisher_gini=fisher_stats["fisher_gini"],
            fisher_top10_ratio=fisher_stats["fisher_top10_ratio"],
            num_dims_fisher_above_1=fisher_stats["num_dims_fisher_above_1"],
            density_ratio=density_rat,
            # Structure scores
            best_structure=result.best_structure.value,
            best_score=result.best_score,
            linear_score=get_score("linear"),
            cone_score=get_score("cone"),
            orthogonal_score=get_score("orthogonal"),
            manifold_score=get_score("manifold"),
            sparse_score=get_score("sparse"),
            cluster_score=get_score("cluster"),
            bimodal_score=get_score("bimodal"),
            # Linear details
            cohens_d=get_detail("linear", "cohens_d", 0.0),
            variance_explained=get_detail("linear", "variance_explained", 0.0),
            within_class_consistency=get_detail("linear", "within_class_consistency", 0.0),
            # Cone details
            raw_mean_cosine_similarity=get_detail("cone", "raw_mean_cosine_similarity", 0.0),
            positive_correlation_fraction=get_detail("cone", "positive_correlation_fraction", 0.0),
            # Orthogonal details
            near_zero_fraction=get_detail("orthogonal", "near_zero_fraction", 0.0),
            # Manifold details
            pca_top2_variance=get_detail("manifold", "pca_top2_variance", 0.0),
            local_nonlinearity=get_detail("manifold", "local_nonlinearity", 0.0),
            # Sparse details
            gini_coefficient=get_detail("sparse", "gini_coefficient", 0.0),
            active_fraction=get_detail("sparse", "active_fraction", 0.0),
            top_10_contribution=get_detail("sparse", "top_10_contribution", 0.0),
            # Cluster details
            best_silhouette=get_detail("cluster", "best_silhouette", 0.0),
            best_k=int(get_detail("cluster", "best_k", 2)),
            recommended_method=recommendation,
        )
    except Exception as e:
        return _failed_geometry_result(cached, layers, f"error: {e}")
837
+
838
+
839
class GeometryRunner:
    """
    Run the geometry search across a search space.

    Activations are cached for efficiency:
    1. Extract ALL layers once per (benchmark, strategy).
    2. Evaluate every layer combination from that cache.
    """

    def __init__(
        self,
        search_space: GeometrySearchSpace,
        model: "WisentModel",
        cache_dir: Optional[str] = None,
    ):
        """
        Args:
            search_space: Benchmarks, strategies and config to search over.
            model: Model whose ``model_name`` and ``num_layers`` are read, and
                which is passed to activation extraction.
            cache_dir: Activation cache directory; defaults to a per-model
                directory under /tmp.
        """
        self.search_space = search_space
        self.model = model
        self.cache_dir = cache_dir or f"/tmp/wisent_geometry_cache_{model.model_name.replace('/', '_')}"
        self.cache = ActivationCache(self.cache_dir)

    def run(
        self,
        benchmarks: Optional[List[str]] = None,
        strategies: Optional[List[ExtractionStrategy]] = None,
        max_layer_combo_size: Optional[int] = None,
        show_progress: bool = True,
    ) -> GeometrySearchResults:
        """
        Run the geometry search.

        Args:
            benchmarks: Benchmarks to test (default/empty: all from the search space).
            strategies: Strategies to test (default/empty: all from the search space).
            max_layer_combo_size: Override for the max layer combination size.
            show_progress: Print progress to stdout.

        Returns:
            GeometrySearchResults with one result per (benchmark, strategy,
            layer combination). Pairs whose activations cannot be obtained are
            skipped.
        """
        benchmarks = benchmarks or self.search_space.benchmarks
        strategies = strategies or self.search_space.strategies
        max_combo = max_layer_combo_size or self.search_space.config.max_layer_combo_size

        layer_combos = get_layer_combinations(self.model.num_layers, max_combo)

        results = GeometrySearchResults(
            model_name=self.model.model_name,
            config=self.search_space.config,
        )

        start_time = time.time()
        extraction_time = 0.0
        test_time = 0.0

        total_extractions = len(benchmarks) * len(strategies)
        extraction_count = 0

        for benchmark in benchmarks:
            for strategy in strategies:
                extraction_count += 1
                if show_progress:
                    print(f"\n[{extraction_count}/{total_extractions}] {benchmark} / {strategy.value}")

                extract_start = time.time()
                try:
                    cached = self._get_cached_activations(benchmark, strategy, show_progress)
                except Exception as e:
                    if show_progress:
                        print(f" SKIP: {e}")
                    continue
                extraction_time += time.time() - extract_start

                # Time the whole batch once; accumulating inside the loop would over-count.
                test_start = time.time()
                for combo in layer_combos:
                    results.add_result(compute_geometry_metrics(cached, combo))
                test_time += time.time() - test_start

                if show_progress:
                    print(f" Tested {len(layer_combos)} layer combos")

        # Coverage counts depend only on the final result list, so compute them once.
        results.benchmarks_tested = len({r.benchmark for r in results.results})
        results.strategies_tested = len({r.strategy for r in results.results})
        results.layer_combos_tested = len(results.results)

        results.total_time_seconds = time.time() - start_time
        results.extraction_time_seconds = extraction_time
        results.test_time_seconds = test_time

        return results

    def _get_cached_activations(
        self,
        benchmark: str,
        strategy: ExtractionStrategy,
        show_progress: bool = True,
    ) -> CachedActivations:
        """Return cached activations for (model, benchmark, strategy), extracting and caching them if absent."""
        if self.cache.has(self.model.model_name, benchmark, strategy):
            if show_progress:
                print(" Loading from cache...")
            return self.cache.get(self.model.model_name, benchmark, strategy)

        # Cache miss: build the contrastive pairs, then extract.
        if show_progress:
            print(" Loading pairs...")

        pairs = self._load_pairs(benchmark)

        if show_progress:
            print(f" Extracting activations for {len(pairs)} pairs...")

        return collect_and_cache_activations(
            model=self.model,
            pairs=pairs,
            benchmark=benchmark,
            strategy=strategy,
            cache=self.cache,
            show_progress=show_progress,
        )

    def _load_pairs(self, benchmark: str) -> List:
        """
        Load up to ``pairs_per_benchmark`` contrastive pairs for a benchmark.

        If lm-eval cannot load the task, ``task=None`` is passed to the pair
        builder. Any down-sampling uses a private RNG seeded with
        ``random_seed``, so the global ``random`` state is left untouched.
        """
        from lm_eval.tasks import TaskManager
        from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_pairs_generation import lm_build_contrastive_pairs

        config = self.search_space.config
        limit = config.pairs_per_benchmark

        tm = TaskManager()
        try:
            task_dict = tm.load_task_or_group([benchmark])
            task = list(task_dict.values())[0]
        except Exception:
            task = None

        pairs = lm_build_contrastive_pairs(benchmark, task, limit=limit)

        if len(pairs) > limit:
            # Same sample as random.seed(seed); random.sample(...), without the global side effect.
            pairs = random.Random(config.random_seed).sample(pairs, limit)

        return pairs
990
+
991
+
992
+ # Type hints
993
+ from typing import TYPE_CHECKING
994
+ if TYPE_CHECKING:
995
+ from wisent.core.models.wisent_model import WisentModel