wisent 0.7.701__py3-none-any.whl → 0.7.1045__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (391)
  1. wisent/__init__.py +1 -1
  2. wisent/comparison/__init__.py +1 -0
  3. wisent/comparison/detect_bos_features.py +275 -0
  4. wisent/comparison/fgaa.py +465 -0
  5. wisent/comparison/lora.py +669 -0
  6. wisent/comparison/lora_dpo.py +592 -0
  7. wisent/comparison/main.py +444 -0
  8. wisent/comparison/ours.py +76 -0
  9. wisent/comparison/sae.py +304 -0
  10. wisent/comparison/utils.py +381 -0
  11. wisent/core/activations/activation_cache.py +393 -0
  12. wisent/core/activations/activations.py +3 -3
  13. wisent/core/activations/activations_collector.py +12 -7
  14. wisent/core/activations/classifier_inference_strategy.py +12 -11
  15. wisent/core/activations/extraction_strategy.py +260 -84
  16. wisent/core/classifiers/classifiers/core/atoms.py +3 -2
  17. wisent/core/cli/__init__.py +2 -1
  18. wisent/core/cli/agent/train_classifier.py +16 -3
  19. wisent/core/cli/check_linearity.py +35 -3
  20. wisent/core/cli/cluster_benchmarks.py +4 -6
  21. wisent/core/cli/create_steering_vector.py +6 -4
  22. wisent/core/cli/diagnose_vectors.py +7 -4
  23. wisent/core/cli/estimate_unified_goodness_time.py +6 -4
  24. wisent/core/cli/generate_pairs_from_task.py +9 -56
  25. wisent/core/cli/generate_vector_from_task.py +11 -20
  26. wisent/core/cli/geometry_search.py +137 -0
  27. wisent/core/cli/get_activations.py +2 -2
  28. wisent/core/cli/method_optimizer.py +4 -3
  29. wisent/core/cli/modify_weights.py +3 -2
  30. wisent/core/cli/optimize_sample_size.py +1 -1
  31. wisent/core/cli/optimize_steering.py +14 -16
  32. wisent/core/cli/optimize_weights.py +2 -1
  33. wisent/core/cli/preview_pairs.py +203 -0
  34. wisent/core/cli/steering_method_trainer.py +3 -3
  35. wisent/core/cli/tasks.py +19 -76
  36. wisent/core/cli/train_unified_goodness.py +3 -3
  37. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
  38. wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
  39. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
  40. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
  41. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
  42. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
  43. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
  44. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
  45. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
  46. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
  47. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
  48. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
  49. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
  50. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
  51. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
  52. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
  53. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
  54. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
  55. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
  56. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
  57. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
  58. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
  59. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
  60. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
  61. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
  62. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
  63. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
  64. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
  65. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
  66. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
  67. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
  68. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
  69. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
  70. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
  71. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
  72. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
  73. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
  74. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
  75. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
  76. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
  77. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
  78. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
  79. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
  80. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
  81. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
  82. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
  83. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
  84. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
  85. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
  86. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
  87. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
  88. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
  89. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
  90. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
  91. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
  92. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
  93. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +22 -5
  94. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
  95. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
  96. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +10 -3
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
  102. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
  103. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
  104. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
  105. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
  106. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
  107. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
  108. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
  109. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
  110. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
  111. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
  112. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
  113. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
  114. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
  115. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
  116. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
  117. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
  118. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
  119. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
  120. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
  121. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
  122. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
  123. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
  124. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
  125. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
  126. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
  127. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
  128. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
  129. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
  130. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
  131. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
  132. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
  133. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
  134. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
  135. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
  136. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
  137. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
  138. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
  139. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
  140. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
  141. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
  142. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
  143. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
  144. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
  145. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
  146. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
  147. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
  148. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
  149. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
  150. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
  151. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
  152. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
  153. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
  154. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
  155. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
  156. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
  157. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
  158. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
  159. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
  160. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
  161. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
  162. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
  163. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
  164. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
  165. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
  166. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
  167. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
  168. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
  169. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
  170. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
  171. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
  172. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
  173. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
  174. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
  175. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
  176. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
  177. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
  178. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
  179. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
  180. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
  181. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
  182. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
  183. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
  184. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
  185. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
  186. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
  187. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
  188. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
  189. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
  190. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
  191. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
  192. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
  193. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
  194. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
  195. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
  196. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
  197. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
  198. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
  199. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
  200. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
  201. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
  202. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
  203. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
  204. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
  205. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
  206. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
  207. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
  208. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
  209. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
  210. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
  211. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
  212. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
  213. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
  214. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
  215. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
  216. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
  217. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
  218. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
  219. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
  220. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
  221. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
  222. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
  223. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
  224. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
  225. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
  226. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
  227. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
  228. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
  229. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
  230. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
  231. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
  232. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
  233. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
  234. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
  235. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
  236. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
  237. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
  238. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
  239. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
  240. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
  241. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
  242. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
  243. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
  244. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
  245. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
  246. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
  247. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
  248. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
  249. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
  250. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
  251. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
  252. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
  253. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
  254. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
  255. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
  256. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
  257. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
  258. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
  259. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
  260. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +9 -4
  261. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
  262. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
  263. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
  264. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
  265. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
  266. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
  267. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
  268. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
  269. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
  270. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
  271. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
  272. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
  273. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
  274. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
  275. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
  276. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
  277. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
  278. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
  279. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
  280. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
  281. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
  282. wisent/core/data_loaders/loaders/lm_loader.py +12 -1
  283. wisent/core/geometry_runner.py +995 -0
  284. wisent/core/geometry_search_space.py +237 -0
  285. wisent/core/hyperparameter_optimizer.py +1 -1
  286. wisent/core/main.py +3 -0
  287. wisent/core/models/core/atoms.py +5 -3
  288. wisent/core/models/wisent_model.py +1 -1
  289. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
  290. wisent/core/parser_arguments/check_linearity_parser.py +12 -2
  291. wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
  292. wisent/core/parser_arguments/generate_vector_from_task_parser.py +6 -13
  293. wisent/core/parser_arguments/geometry_search_parser.py +61 -0
  294. wisent/core/parser_arguments/get_activations_parser.py +5 -14
  295. wisent/core/parser_arguments/main_parser.py +8 -0
  296. wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
  297. wisent/core/steering.py +5 -3
  298. wisent/core/steering_methods/methods/hyperplane.py +2 -1
  299. wisent/core/synthetic/generators/nonsense_generator.py +30 -18
  300. wisent/core/trainers/steering_trainer.py +2 -2
  301. wisent/core/utils/device.py +27 -27
  302. wisent/core/utils/layer_combinations.py +70 -0
  303. wisent/examples/__init__.py +1 -0
  304. wisent/examples/scripts/__init__.py +1 -0
  305. wisent/examples/scripts/count_all_benchmarks.py +121 -0
  306. wisent/examples/scripts/discover_directions.py +469 -0
  307. wisent/examples/scripts/extract_benchmark_info.py +71 -0
  308. wisent/examples/scripts/search_all_short_names.py +31 -0
  309. wisent/examples/scripts/test_all_benchmarks.py +138 -0
  310. wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
  311. wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
  312. wisent/examples/scripts/test_nonsense_baseline.py +261 -0
  313. wisent/examples/scripts/test_one_benchmark.py +324 -0
  314. wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
  315. wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
  316. wisent/parameters/lm_eval/category_directions.json +137 -0
  317. wisent/parameters/lm_eval/repair_plan.json +282 -0
  318. wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
  319. wisent/parameters/lm_eval/working_benchmarks.json +206 -0
  320. wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
  321. wisent/tests/test_detector_accuracy.py +1 -1
  322. wisent/tests/visualize_geometry.py +1 -1
  323. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
  324. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/RECORD +328 -358
  325. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
  326. wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
  327. wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
  328. wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
  329. wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
  330. wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
  331. wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
  332. wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
  333. wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
  334. wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
  335. wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
  336. wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
  337. wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
  338. wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
  339. wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
  340. wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
  341. wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
  342. wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
  343. wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
  344. wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
  345. wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
  346. wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
  347. wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
  348. wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
  349. wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
  350. wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
  351. wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
  352. wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
  353. wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
  354. wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
  355. wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
  356. wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
  357. wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
  358. wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
  359. wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
  360. wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
  361. wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
  362. wisent/examples/scripts/1/test_cola_pairs.json +0 -8
  363. wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
  364. wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
  365. wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
  366. wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
  367. wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
  368. wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
  369. wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
  370. wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
  371. wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
  372. wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
  373. wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
  374. wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
  375. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
  376. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
  377. wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
  378. wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
  379. wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
  380. wisent/examples/scripts/2/test_atis_pairs.json +0 -8
  381. wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
  382. wisent/examples/scripts/2/test_babi_pairs.json +0 -8
  383. wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
  384. wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
  385. wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
  386. wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
  387. wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
  388. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
  389. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
  390. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
  391. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
wisent/comparison/lora.py
@@ -0,0 +1,669 @@
+"""
+LoRA fine-tuning method for comparison experiments.
+
+Trains a LoRA adapter on benchmark tasks using supervised fine-tuning (SFT)
+on positive responses from contrastive pairs.
+
+Optionally evaluates LoRA + steering by generating a steering vector on the
+LoRA model and combining both methods.
+"""
+
+from __future__ import annotations
+
+import gc
+import json
+import tempfile
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import torch
+from datasets import Dataset
+from peft import LoraConfig, TaskType, get_peft_model
+from trl import SFTTrainer, SFTConfig
+
+from wisent.comparison.utils import (
+    generate_contrastive_pairs,
+    create_test_only_task,
+    extract_accuracy,
+    run_lm_eval_evaluation,
+    run_ll_evaluation,
+    load_model_and_tokenizer,
+    apply_steering_to_model,
+    remove_steering,
+)
+from wisent.core.utils.device import preferred_dtype
+
+if TYPE_CHECKING:
+    from wisent.core.models.wisent_model import WisentModel
+
+__all__ = ["train_lora_adapter", "evaluate_lora", "apply_lora_to_model", "remove_lora"]
+
+
+# Default LoRA configurations per model architecture
+LORA_TARGET_MODULES = {
+    "gemma": ["q_proj", "k_proj", "v_proj", "o_proj"],
+    "llama": ["q_proj", "k_proj", "v_proj", "o_proj"],
+    "mistral": ["q_proj", "k_proj", "v_proj", "o_proj"],
+    "phi": ["q_proj", "k_proj", "v_proj", "dense"],
+    "gpt_neo": ["q_proj", "v_proj"],
+    "gpt2": ["c_attn"],
+    "default": "all-linear",
+}
+
+
+def get_target_modules(model_name: str) -> str | list[str]:
+    """Get LoRA target modules based on model architecture."""
+    model_name_lower = model_name.lower()
+
+    for arch, modules in LORA_TARGET_MODULES.items():
+        if arch in model_name_lower:
+            return modules
+
+    return LORA_TARGET_MODULES["default"]
+
+
+def prepare_sft_dataset(
+    pairs: list[dict],
+    tokenizer,
+    max_length: int = 512,
+) -> Dataset:
+    """
+    Prepare dataset for SFT from contrastive pairs.
+
+    Uses only positive responses for training.
+
+    Args:
+        pairs: List of contrastive pairs
+        tokenizer: Tokenizer for formatting
+        max_length: Maximum sequence length
+
+    Returns:
+        HuggingFace Dataset ready for SFTTrainer
+    """
+    formatted_examples = []
+
+    for pair in pairs:
+        prompt = pair["prompt"]
+        positive_response = pair["positive_response"]["model_response"]
+
+        # Format as chat if tokenizer supports it, otherwise simple format
+        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
+            messages = [
+                {"role": "user", "content": prompt},
+                {"role": "assistant", "content": positive_response},
+            ]
+            text = tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=False,
+            )
+        else:
+            # Simple format for base models
+            text = f"Q: {prompt}\nA: {positive_response}"
+
+        formatted_examples.append({"text": text})
+
+    return Dataset.from_list(formatted_examples)
+
+
+def train_lora_adapter(
+    task: str,
+    model_name: str,
+    output_path: str | Path,
+    trait_label: str = "correctness",
+    num_pairs: int = 50,
+    device: str = "cuda:0",
+    keep_intermediate: bool = False,
+    # LoRA-specific parameters
+    lora_r: int = 16,
+    lora_alpha: int = 32,
+    lora_dropout: float = 0.05,
+    learning_rate: float = 2e-4,
+    num_epochs: int = 3,
+    batch_size: int = 2,
+    max_length: int = 512,
+) -> Path:
+    """
+    Train a LoRA adapter using SFT on positive responses.
+
+    Args:
+        task: lm-eval task name (e.g., 'boolq', 'cb')
+        model_name: HuggingFace model name
+        output_path: Where to save the LoRA adapter
+        trait_label: Label for the trait being trained
+        num_pairs: Number of training examples to use
+        device: Device to train on
+        keep_intermediate: Whether to keep intermediate files
+        lora_r: LoRA rank
+        lora_alpha: LoRA alpha scaling factor
+        lora_dropout: LoRA dropout
+        learning_rate: Training learning rate
+        num_epochs: Number of training epochs
+        batch_size: Training batch size
+        max_length: Maximum sequence length
+
+    Returns:
+        Path to the saved LoRA adapter directory
+    """
+    import gc
+
+    output_path = Path(output_path)
+
+    # Step 1: Generate contrastive pairs
+    print(f"Step 1: Generating training data from task: {task}")
+    pairs, pairs_file = generate_contrastive_pairs(task, num_pairs)
+    print(f" Loaded {len(pairs)} training examples")
+
+    # Step 2: Load model and tokenizer
+    print(f"\nStep 2: Loading model {model_name}...")
+    model, tokenizer = load_model_and_tokenizer(model_name, device, eval_mode=False)
+
+    # Step 3: Configure LoRA
+    print(f"\nStep 3: Configuring LoRA (r={lora_r}, alpha={lora_alpha})...")
+
+    target_modules = get_target_modules(model_name)
+    print(f" Target modules: {target_modules}")
+
+    lora_config = LoraConfig(
+        r=lora_r,
+        lora_alpha=lora_alpha,
+        target_modules=target_modules,
+        lora_dropout=lora_dropout,
+        bias="none",
+        task_type=TaskType.CAUSAL_LM,
+    )
+
+    model = get_peft_model(model, lora_config)
+    model.print_trainable_parameters()
+
+    # Step 4: Prepare dataset
+    print(f"\nStep 4: Preparing SFT dataset...")
+    train_dataset = prepare_sft_dataset(pairs, tokenizer, max_length=max_length)
+    print(f" Dataset size: {len(train_dataset)} examples")
+
+    # Step 5: Training
+    print(f"\nStep 5: Training LoRA adapter...")
+
+    # Create temporary directory for training outputs
+    training_output_dir = tempfile.mkdtemp(prefix="lora_training_")
+
+    # Use device-optimized dtype (bfloat16 on CUDA, float16 on MPS, float32 on CPU)
+    dtype = preferred_dtype(device)
+
+    training_args = SFTConfig(
+        output_dir=training_output_dir,
+        num_train_epochs=num_epochs,
+        per_device_train_batch_size=batch_size,
+        gradient_accumulation_steps=1,
+        learning_rate=learning_rate,
+        weight_decay=0.01,
+        warmup_ratio=0.1,
+        logging_steps=10,
+        save_strategy="no",  # Don't save checkpoints
+        bf16=(dtype == torch.bfloat16),
+        fp16=(dtype == torch.float16),
+        report_to="none",  # Disable wandb/tensorboard
+        dataset_text_field="text",  # Field containing the text to train on
+    )
+
+    trainer = SFTTrainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        processing_class=tokenizer,
+    )
+
+    trainer.train()
+
+    # Step 6: Save LoRA adapter
+    print(f"\nStep 6: Saving LoRA adapter to {output_path}...")
+    output_path.mkdir(parents=True, exist_ok=True)
+    model.save_pretrained(output_path)
+    tokenizer.save_pretrained(output_path)
+
+    # Save metadata
+    metadata = {
+        "method": "lora",
+        "model": model_name,
+        "task": task,
+        "trait_label": trait_label,
+        "num_pairs": len(pairs),
+        "lora_config": {
+            "r": lora_r,
+            "alpha": lora_alpha,
+            "dropout": lora_dropout,
+            "target_modules": target_modules if isinstance(target_modules, list) else [target_modules],
+        },
+        "training_config": {
+            "learning_rate": learning_rate,
+            "num_epochs": num_epochs,
+            "batch_size": batch_size,
+            "max_length": max_length,
+        },
+    }
+
+    with open(output_path / "metadata.json", "w") as f:
+        json.dump(metadata, f, indent=2)
+
+    # Cleanup
+    del model, trainer
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    if not keep_intermediate:
+        import os
+        os.unlink(pairs_file)
+        import shutil
+        shutil.rmtree(training_output_dir, ignore_errors=True)
+
+    print(f"\nLoRA adapter saved to {output_path}")
+    return output_path
+
+
+def apply_lora_to_model(wisent_model: "WisentModel", lora_path: str | Path) -> None:
+    """
+    Apply a trained LoRA adapter to a WisentModel.
+
+    Args:
+        wisent_model: WisentModel instance
+        lora_path: Path to the saved LoRA adapter
+    """
+    from peft import PeftModel
+
+    lora_path = Path(lora_path)
+
+    # Check if model already has adapters
+    if hasattr(wisent_model.hf_model, 'peft_config'):
+        # Model already has PEFT, just load new adapter
+        wisent_model.hf_model.load_adapter(str(lora_path), adapter_name="steering")
+        wisent_model.hf_model.set_adapter("steering")
+    else:
+        # Wrap model with PEFT
+        wisent_model.hf_model = PeftModel.from_pretrained(
+            wisent_model.hf_model,
+            str(lora_path),
+            adapter_name="steering",
+        )
+
+    print(f"LoRA adapter loaded from {lora_path}")
+
+
+def remove_lora(wisent_model: "WisentModel") -> None:
+    """
+    Remove/disable LoRA adapter from a WisentModel.
+
+    Args:
+        wisent_model: WisentModel instance with LoRA applied
+    """
+    if hasattr(wisent_model.hf_model, 'disable_adapters'):
+        try:
+            wisent_model.hf_model.disable_adapters()
+            print("LoRA adapter disabled")
+        except ValueError:
+            # No adapter was loaded
+            pass
+    elif hasattr(wisent_model.hf_model, 'base_model'):
+        # Unwrap the model
+        wisent_model.hf_model = wisent_model.hf_model.base_model.model
+        print("LoRA adapter removed")
+
+
+def evaluate_lora(
+    model_name: str,
+    lora_path: str | Path,
+    task: str,
+    train_ratio: float = 0.8,
+    device: str = "cuda:0",
+    batch_size: int = 1,
+    max_batch_size: int = 8,
+    limit: int | None = None,
+    output_dir: str | Path = None,
+    # Training metadata (for output)
+    num_train_pairs: int | None = None,
+    num_epochs: int | None = None,
+    lora_r: int | None = None,
+    lora_alpha: int | None = None,
+    lora_dropout: float | None = None,
+    learning_rate: float | None = None,
+    # Steering parameters (optional)
+    with_steering: bool = False,
+    steering_method: str = "caa",
+    steering_layers: str = "12",
+    steering_num_pairs: int = 50,
+    steering_scales: list[float] | None = None,
+    extraction_strategy: str = "mc_completion",
+) -> dict:
+    """
+    Evaluate a trained LoRA adapter comparing base vs LoRA performance.
+
+    Optionally also evaluates LoRA + steering at multiple scales.
+    All results are saved to a single output file.
+
+    Args:
+        model_name: HuggingFace model name
+        lora_path: Path to trained LoRA adapter
+        task: lm-eval task name
+        train_ratio: Train/test split ratio
+        device: Device to run on
+        batch_size: Batch size for evaluation
+        max_batch_size: Max batch size
+        limit: Limit number of eval examples
+        output_dir: Where to save results
+        with_steering: Whether to also evaluate LoRA + steering
+        steering_method: Steering method (caa or fgaa)
+        steering_layers: Layers for steering vector
+        steering_num_pairs: Number of pairs for steering generation
+        steering_scales: List of steering scales to evaluate
+        extraction_strategy: Strategy for activation extraction
+
+    Returns:
+        Dict with evaluation results
+    """
+    import gc
+
+    from wisent.core.models.wisent_model import WisentModel
+
+    lora_path = Path(lora_path)
+
+    if steering_scales is None:
+        steering_scales = [1.0, 2.0, 4.0]
+
+    # Create test task
+    print(f"\n{'='*60}")
+    print(f"Creating test task for: {task}")
+    print(f"{'='*60}")
+
+    task_dict = create_test_only_task(task, train_ratio=train_ratio)
+
+    # Load model
+    print(f"\n{'='*60}")
+    print(f"Loading model: {model_name}")
+    print(f"{'='*60}")
+    wisent_model = WisentModel(model_name=model_name, device=device)
+
+    # BASE evaluation
+    print(f"\n{'='*60}")
+    print(f"Running BASE evaluation (no LoRA)")
+    print(f"{'='*60}")
+
+    base_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
+    base_acc_lm_eval = extract_accuracy(base_results, task)
+    print(f"Base accuracy (lm-eval): {base_acc_lm_eval:.4f}")
+
+    base_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
+    print(f"Base accuracy (LL): {base_acc_ll:.4f}")
+
+    # Apply LoRA
+    print(f"\n{'='*60}")
+    print(f"Applying LoRA adapter from: {lora_path}")
+    print(f"{'='*60}")
+    apply_lora_to_model(wisent_model, lora_path)
+
+    # LORA evaluation
+    print(f"\n{'='*60}")
+    print(f"Running LORA evaluation")
+    print(f"{'='*60}")
+
+    lora_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
+    lora_acc_lm_eval = extract_accuracy(lora_results, task)
+    print(f"LoRA accuracy (lm-eval): {lora_acc_lm_eval:.4f}")
+
+    lora_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
+    print(f"LoRA accuracy (LL): {lora_acc_ll:.4f}")
+
+    # Results dict
+    results = {
+        "task": task,
+        "model": model_name,
+        "lora_path": str(lora_path),
+        # Training config
+        "num_train_pairs": num_train_pairs,
+        "num_epochs": num_epochs,
+        "lora_r": lora_r,
+        "lora_alpha": lora_alpha,
+        "lora_dropout": lora_dropout,
+        "learning_rate": learning_rate,
+        # Eval config
+        "train_ratio": train_ratio,
+        "eval_limit": limit,
+        # Results
+        "base_accuracy_lm_eval": base_acc_lm_eval,
+        "base_accuracy_ll": base_acc_ll,
+        "lora_accuracy_lm_eval": lora_acc_lm_eval,
+        "lora_accuracy_ll": lora_acc_ll,
+        "lora_diff_lm_eval": lora_acc_lm_eval - base_acc_lm_eval,
+        "lora_diff_ll": lora_acc_ll - base_acc_ll,
+    }
+
+    # LoRA + Steering evaluation (if enabled)
+    if with_steering:
+        from wisent.core.trainers.steering_trainer import WisentSteeringTrainer
+        from wisent.core.steering_methods import get_steering_method
+        from wisent.core.activations.extraction_strategy import ExtractionStrategy
+        from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
+        from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+        from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse
+
+        # Generate contrastive pairs for steering
+        print(f"\n{'='*60}")
+        print(f"Generating {steering_num_pairs} contrastive pairs for steering")
+        print(f"{'='*60}")
+        pairs_data, pairs_file = generate_contrastive_pairs(task, steering_num_pairs)
+
+        # Convert to ContrastivePairSet
+        pairs = []
+        for p in pairs_data:
+            pair = ContrastivePair(
+                prompt=p["prompt"],
+                positive_response=PositiveResponse(model_response=p["positive_response"]["model_response"]),
+                negative_response=NegativeResponse(model_response=p["negative_response"]["model_response"]),
+            )
+            pairs.append(pair)
+        pair_set = ContrastivePairSet(pairs=pairs, name=f"{task}_lora_steering")
+        print(f"Created {len(pair_set)} contrastive pairs")
+
+        # Generate steering vector on LoRA model
+        print(f"\n{'='*60}")
+        print(f"Generating {steering_method.upper()} steering vector on LoRA model")
+        print(f"Layers: {steering_layers}")
+        print(f"{'='*60}")
+
+        steering_method_obj = get_steering_method(steering_method, device=device)
+        strategy = ExtractionStrategy(extraction_strategy)
+
+        trainer = WisentSteeringTrainer(
+            model=wisent_model,
+            pair_set=pair_set,
+            steering_method=steering_method_obj,
+        )
+
+        result = trainer.run(
+            layers_spec=steering_layers,
+            strategy=strategy,
+            accept_low_quality_vector=True,
+        )
+
+        # Convert to dict format for apply_steering_to_model
+        steering_vectors = {}
+        for layer_name, tensor in result.steered_vectors.to_dict().items():
+            if tensor is not None:
+                steering_vectors[layer_name] = tensor.cpu().float().tolist()
+
+        steering_data = {
+            "steering_vectors": steering_vectors,
+            "layers": list(steering_vectors.keys()),
+        }
+
+        # Cleanup temp file
+        import os
+        os.unlink(pairs_file)
+
+        # Add steering info to results
+        results["steering"] = {
+            "method": steering_method,
+            "layers": list(steering_vectors.keys()),
+            "num_pairs": steering_num_pairs,
+            "extraction_strategy": extraction_strategy,
+            "scales": {},
+        }
+
+        # Evaluate at each scale
+        for scale in steering_scales:
+            print(f"\n{'='*60}")
+            print(f"Evaluating LoRA+{steering_method.upper()} at scale={scale}")
+            print(f"{'='*60}")
+
+            apply_steering_to_model(wisent_model, steering_data, scale=scale)
+
+            steer_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
+            steer_acc_lm_eval = extract_accuracy(steer_results, task)
+            print(f"LoRA+{steering_method.upper()} accuracy (lm-eval): {steer_acc_lm_eval:.4f}")
+
+            steer_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
+            print(f"LoRA+{steering_method.upper()} accuracy (LL): {steer_acc_ll:.4f}")
+
+            remove_steering(wisent_model)
+
+            results["steering"]["scales"][str(scale)] = {
+                "accuracy_lm_eval": steer_acc_lm_eval,
+                "accuracy_ll": steer_acc_ll,
+                "diff_from_base_lm_eval": steer_acc_lm_eval - base_acc_lm_eval,
+                "diff_from_base_ll": steer_acc_ll - base_acc_ll,
+                "diff_from_lora_lm_eval": steer_acc_lm_eval - lora_acc_lm_eval,
+                "diff_from_lora_ll": steer_acc_ll - lora_acc_ll,
+            }
+
+    # Cleanup
+    remove_lora(wisent_model)
+    del wisent_model
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    # Print summary
+    print(f"\n{'='*70}")
+    print(f"RESULTS SUMMARY")
+    print(f"{'='*70}")
+    print(f"Task: {task}")
+    print(f"Model: {model_name}")
+    print(f"LoRA: {lora_path}")
+    print(f"{'-'*70}")
+    print(f"{'Method':<25} {'lm-eval acc':<15} {'LL acc':<15} {'Diff (lm-eval)':<15}")
+    print(f"{'-'*70}")
+    print(f"{'Base':<25} {base_acc_lm_eval:<15.4f} {base_acc_ll:<15.4f} {'':<15}")
+    print(f"{'LoRA':<25} {lora_acc_lm_eval:<15.4f} {lora_acc_ll:<15.4f} {lora_acc_lm_eval - base_acc_lm_eval:+.4f}")
+
+    if with_steering:
+        for scale, res in results["steering"]["scales"].items():
+            label = f"LoRA+{steering_method.upper()}@{scale}"
+            print(f"{label:<25} {res['accuracy_lm_eval']:<15.4f} {res['accuracy_ll']:<15.4f} {res['diff_from_base_lm_eval']:+.4f}")
+
+    print(f"{'='*70}")
+
+    # Save results
+    if output_dir:
+        output_dir = Path(output_dir)
+        model_dir_name = model_name.replace("/", "_")
+        output_dir = output_dir / model_dir_name
+        output_dir.mkdir(parents=True, exist_ok=True)
+        results_file = output_dir / f"{task}_lora_eval_results.json"
+        with open(results_file, "w") as f:
+            json.dump(results, f, indent=2)
+        print(f"\nResults saved to: {results_file}")
+
+    return results
+
+
+def main():
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Train and evaluate LoRA adapter on benchmark task")
+    parser.add_argument("--model", required=True, help="HuggingFace model name")
+    parser.add_argument("--task", default="boolq", help="lm-eval task name")
+    parser.add_argument("--output-dir", default="/home/ubuntu/output", help="Output directory")
+    parser.add_argument("--num-pairs", type=int, default=50, help="Number of training examples")
+    parser.add_argument("--device", default="cuda:0", help="Device")
+    parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
+    parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha")
+    parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
+    parser.add_argument("--learning-rate", type=float, default=2e-4, help="Learning rate")
+    parser.add_argument("--num-epochs", type=int, default=3, help="Number of epochs")
+    parser.add_argument("--batch-size", type=int, default=2, help="Training batch size")
+    parser.add_argument("--max-length", type=int, default=512, help="Max sequence length")
+    parser.add_argument("--keep-intermediate", action="store_true", help="Keep intermediate files")
+    # Eval args
+    parser.add_argument("--train-ratio", type=float, default=0.8, help="Train/test split ratio")
+    parser.add_argument("--eval-batch-size", default="auto", help="Eval batch size (int or 'auto')")
+    parser.add_argument("--eval-max-batch-size", type=int, default=64, help="Max eval batch size for auto")
+    parser.add_argument("--eval-limit", type=int, default=None, help="Limit eval examples")
+    parser.add_argument("--skip-eval", action="store_true", help="Skip evaluation after training")
+    # LoRA + Steering args
+    parser.add_argument("--with-steering", action="store_true", help="Also evaluate LoRA + steering")
+    parser.add_argument("--steering-method", default="caa", choices=["caa", "fgaa"], help="Steering method")
+    parser.add_argument("--steering-layers", default="12", help="Layers for steering vector")
+    parser.add_argument("--steering-num-pairs", type=int, default=50, help="Number of pairs for steering")
+    parser.add_argument("--steering-scales", default="1.0,2.0,4.0", help="Comma-separated steering scales")
+    parser.add_argument("--extraction-strategy", default="mc_balanced", help="Extraction strategy for steering")
+
+    args = parser.parse_args()
+
+    output_path = Path(args.output_dir) / f"{args.task}_lora_adapter"
+
+    # Train
+    train_lora_adapter(
+        task=args.task,
+        model_name=args.model,
+        output_path=output_path,
+        num_pairs=args.num_pairs,
+        device=args.device,
+        keep_intermediate=args.keep_intermediate,
+        lora_r=args.lora_r,
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
+        learning_rate=args.learning_rate,
+        num_epochs=args.num_epochs,
+        batch_size=args.batch_size,
+        max_length=args.max_length,
+    )
+
+    # Evaluate base vs LoRA (and optionally LoRA + steering)
+    if not args.skip_eval:
+        # Parse eval batch size (can be "auto" or int)
+        eval_batch_size = args.eval_batch_size
+        if eval_batch_size != "auto":
+            eval_batch_size = int(eval_batch_size)
+
+        # Parse steering scales
+        steering_scales = [float(s.strip()) for s in args.steering_scales.split(",")]
+
+        evaluate_lora(
+            model_name=args.model,
+            lora_path=output_path,
+            task=args.task,
+            train_ratio=args.train_ratio,
+            device=args.device,
+            batch_size=eval_batch_size,
+            max_batch_size=args.eval_max_batch_size,
+            limit=args.eval_limit,
+            output_dir=args.output_dir,
+            # Training metadata
+            num_train_pairs=args.num_pairs,
+            num_epochs=args.num_epochs,
+            lora_r=args.lora_r,
+            lora_alpha=args.lora_alpha,
+            lora_dropout=args.lora_dropout,
+            learning_rate=args.learning_rate,
+            # Steering parameters
+            with_steering=args.with_steering,
+            steering_method=args.steering_method,
+            steering_layers=args.steering_layers,
+            steering_num_pairs=args.steering_num_pairs,
+            steering_scales=steering_scales,
+            extraction_strategy=args.extraction_strategy,
+        )


+if __name__ == "__main__":
+    main()
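
For orientation, the snippet below is a minimal sketch of how the new module can be driven end to end from Python, assuming the 0.7.1045 wheel is installed and a CUDA device is available. The model id "gpt2" and the "output/" paths are illustrative placeholders, not values from the package; the keyword arguments follow the signatures shown in the diff above.

from pathlib import Path

from wisent.comparison.lora import train_lora_adapter, evaluate_lora

# Train a rank-16 adapter on boolq positives (defaults mirror the CLI flags above).
adapter_dir = train_lora_adapter(
    task="boolq",
    model_name="gpt2",                              # illustrative HuggingFace model id
    output_path=Path("output/boolq_lora_adapter"),  # illustrative output location
    num_pairs=50,
    device="cuda:0",
)

# Compare base vs. LoRA accuracy, optionally layering CAA steering on top.
results = evaluate_lora(
    model_name="gpt2",
    lora_path=adapter_dir,
    task="boolq",
    output_dir="output",
    with_steering=True,
    steering_scales=[1.0, 2.0],
)
print(results["lora_diff_lm_eval"])

The same flow should also be reachable from the shell through the module's own entry point, e.g. `python -m wisent.comparison.lora --model gpt2 --task boolq --with-steering`, since the file guards `main()` with `if __name__ == "__main__":`.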