wisent 0.7.701__py3-none-any.whl → 0.7.1045__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (391)
  1. wisent/__init__.py +1 -1
  2. wisent/comparison/__init__.py +1 -0
  3. wisent/comparison/detect_bos_features.py +275 -0
  4. wisent/comparison/fgaa.py +465 -0
  5. wisent/comparison/lora.py +669 -0
  6. wisent/comparison/lora_dpo.py +592 -0
  7. wisent/comparison/main.py +444 -0
  8. wisent/comparison/ours.py +76 -0
  9. wisent/comparison/sae.py +304 -0
  10. wisent/comparison/utils.py +381 -0
  11. wisent/core/activations/activation_cache.py +393 -0
  12. wisent/core/activations/activations.py +3 -3
  13. wisent/core/activations/activations_collector.py +12 -7
  14. wisent/core/activations/classifier_inference_strategy.py +12 -11
  15. wisent/core/activations/extraction_strategy.py +260 -84
  16. wisent/core/classifiers/classifiers/core/atoms.py +3 -2
  17. wisent/core/cli/__init__.py +2 -1
  18. wisent/core/cli/agent/train_classifier.py +16 -3
  19. wisent/core/cli/check_linearity.py +35 -3
  20. wisent/core/cli/cluster_benchmarks.py +4 -6
  21. wisent/core/cli/create_steering_vector.py +6 -4
  22. wisent/core/cli/diagnose_vectors.py +7 -4
  23. wisent/core/cli/estimate_unified_goodness_time.py +6 -4
  24. wisent/core/cli/generate_pairs_from_task.py +9 -56
  25. wisent/core/cli/generate_vector_from_task.py +11 -20
  26. wisent/core/cli/geometry_search.py +137 -0
  27. wisent/core/cli/get_activations.py +2 -2
  28. wisent/core/cli/method_optimizer.py +4 -3
  29. wisent/core/cli/modify_weights.py +3 -2
  30. wisent/core/cli/optimize_sample_size.py +1 -1
  31. wisent/core/cli/optimize_steering.py +14 -16
  32. wisent/core/cli/optimize_weights.py +2 -1
  33. wisent/core/cli/preview_pairs.py +203 -0
  34. wisent/core/cli/steering_method_trainer.py +3 -3
  35. wisent/core/cli/tasks.py +19 -76
  36. wisent/core/cli/train_unified_goodness.py +3 -3
  37. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
  38. wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
  39. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
  40. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
  41. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
  42. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
  43. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
  44. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
  45. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
  46. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
  47. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
  48. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
  49. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
  50. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
  51. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
  52. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
  53. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
  54. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
  55. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
  56. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
  57. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
  58. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
  59. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
  60. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
  61. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
  62. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
  63. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
  64. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
  65. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
  66. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
  67. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
  68. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
  69. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
  70. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
  71. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
  72. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
  73. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
  74. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
  75. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
  76. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
  77. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
  78. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
  79. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
  80. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
  81. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
  82. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
  83. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
  84. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
  85. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
  86. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
  87. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
  88. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
  89. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
  90. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
  91. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
  92. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
  93. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +22 -5
  94. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
  95. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
  96. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +10 -3
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
  102. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
  103. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
  104. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
  105. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
  106. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
  107. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
  108. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
  109. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
  110. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
  111. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
  112. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
  113. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
  114. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
  115. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
  116. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
  117. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
  118. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
  119. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
  120. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
  121. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
  122. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
  123. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
  124. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
  125. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
  126. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
  127. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
  128. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
  129. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
  130. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
  131. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
  132. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
  133. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
  134. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
  135. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
  136. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
  137. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
  138. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
  139. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
  140. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
  141. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
  142. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
  143. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
  144. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
  145. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
  146. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
  147. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
  148. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
  149. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
  150. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
  151. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
  152. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
  153. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
  154. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
  155. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
  156. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
  157. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
  158. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
  159. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
  160. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
  161. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
  162. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
  163. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
  164. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
  165. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
  166. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
  167. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
  168. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
  169. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
  170. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
  171. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
  172. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
  173. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
  174. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
  175. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
  176. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
  177. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
  178. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
  179. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
  180. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
  181. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
  182. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
  183. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
  184. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
  185. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
  186. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
  187. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
  188. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
  189. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
  190. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
  191. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
  192. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
  193. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
  194. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
  195. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
  196. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
  197. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
  198. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
  199. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
  200. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
  201. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
  202. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
  203. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
  204. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
  205. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
  206. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
  207. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
  208. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
  209. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
  210. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
  211. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
  212. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
  213. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
  214. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
  215. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
  216. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
  217. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
  218. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
  219. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
  220. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
  221. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
  222. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
  223. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
  224. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
  225. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
  226. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
  227. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
  228. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
  229. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
  230. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
  231. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
  232. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
  233. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
  234. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
  235. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
  236. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
  237. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
  238. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
  239. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
  240. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
  241. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
  242. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
  243. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
  244. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
  245. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
  246. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
  247. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
  248. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
  249. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
  250. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
  251. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
  252. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
  253. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
  254. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
  255. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
  256. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
  257. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
  258. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
  259. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
  260. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +9 -4
  261. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
  262. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
  263. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
  264. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
  265. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
  266. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
  267. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
  268. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
  269. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
  270. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
  271. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
  272. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
  273. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
  274. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
  275. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
  276. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
  277. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
  278. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
  279. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
  280. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
  281. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
  282. wisent/core/data_loaders/loaders/lm_loader.py +12 -1
  283. wisent/core/geometry_runner.py +995 -0
  284. wisent/core/geometry_search_space.py +237 -0
  285. wisent/core/hyperparameter_optimizer.py +1 -1
  286. wisent/core/main.py +3 -0
  287. wisent/core/models/core/atoms.py +5 -3
  288. wisent/core/models/wisent_model.py +1 -1
  289. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
  290. wisent/core/parser_arguments/check_linearity_parser.py +12 -2
  291. wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
  292. wisent/core/parser_arguments/generate_vector_from_task_parser.py +6 -13
  293. wisent/core/parser_arguments/geometry_search_parser.py +61 -0
  294. wisent/core/parser_arguments/get_activations_parser.py +5 -14
  295. wisent/core/parser_arguments/main_parser.py +8 -0
  296. wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
  297. wisent/core/steering.py +5 -3
  298. wisent/core/steering_methods/methods/hyperplane.py +2 -1
  299. wisent/core/synthetic/generators/nonsense_generator.py +30 -18
  300. wisent/core/trainers/steering_trainer.py +2 -2
  301. wisent/core/utils/device.py +27 -27
  302. wisent/core/utils/layer_combinations.py +70 -0
  303. wisent/examples/__init__.py +1 -0
  304. wisent/examples/scripts/__init__.py +1 -0
  305. wisent/examples/scripts/count_all_benchmarks.py +121 -0
  306. wisent/examples/scripts/discover_directions.py +469 -0
  307. wisent/examples/scripts/extract_benchmark_info.py +71 -0
  308. wisent/examples/scripts/search_all_short_names.py +31 -0
  309. wisent/examples/scripts/test_all_benchmarks.py +138 -0
  310. wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
  311. wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
  312. wisent/examples/scripts/test_nonsense_baseline.py +261 -0
  313. wisent/examples/scripts/test_one_benchmark.py +324 -0
  314. wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
  315. wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
  316. wisent/parameters/lm_eval/category_directions.json +137 -0
  317. wisent/parameters/lm_eval/repair_plan.json +282 -0
  318. wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
  319. wisent/parameters/lm_eval/working_benchmarks.json +206 -0
  320. wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
  321. wisent/tests/test_detector_accuracy.py +1 -1
  322. wisent/tests/visualize_geometry.py +1 -1
  323. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
  324. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/RECORD +328 -358
  325. wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
  326. wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
  327. wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
  328. wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
  329. wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
  330. wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
  331. wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
  332. wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
  333. wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
  334. wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
  335. wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
  336. wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
  337. wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
  338. wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
  339. wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
  340. wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
  341. wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
  342. wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
  343. wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
  344. wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
  345. wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
  346. wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
  347. wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
  348. wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
  349. wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
  350. wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
  351. wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
  352. wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
  353. wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
  354. wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
  355. wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
  356. wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
  357. wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
  358. wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
  359. wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
  360. wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
  361. wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
  362. wisent/examples/scripts/1/test_cola_pairs.json +0 -8
  363. wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
  364. wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
  365. wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
  366. wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
  367. wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
  368. wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
  369. wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
  370. wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
  371. wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
  372. wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
  373. wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
  374. wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
  375. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
  376. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
  377. wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
  378. wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
  379. wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
  380. wisent/examples/scripts/2/test_atis_pairs.json +0 -8
  381. wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
  382. wisent/examples/scripts/2/test_babi_pairs.json +0 -8
  383. wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
  384. wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
  385. wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
  386. wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
  387. wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
  388. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
  389. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
  390. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
  391. {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,592 @@
1
+ """
2
+ LoRA fine-tuning using DPO (Direct Preference Optimization).
3
+
4
+ Unlike SFT which trains on positive examples only, DPO trains on
5
+ preference pairs (chosen vs rejected) to directly optimize for preferences.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import argparse
11
+ import gc
12
+ import json
13
+ import tempfile
14
+ from pathlib import Path
15
+ from typing import TYPE_CHECKING
16
+
17
+ import torch
18
+ from datasets import Dataset
19
+ from peft import LoraConfig, TaskType, get_peft_model
20
+ from trl import DPOTrainer, DPOConfig
21
+
22
+ from wisent.comparison.utils import (
23
+ generate_contrastive_pairs,
24
+ create_test_only_task,
25
+ extract_accuracy,
26
+ run_lm_eval_evaluation,
27
+ run_ll_evaluation,
28
+ load_model_and_tokenizer,
29
+ apply_steering_to_model,
30
+ remove_steering,
31
+ )
32
+ from wisent.core.utils.device import preferred_dtype
33
+
34
+ if TYPE_CHECKING:
35
+ from wisent.core.models.wisent_model import WisentModel
36
+
37
+
38
def create_dpo_dataset(pairs: list[dict]) -> Dataset:
    """
    Turn contrastive pairs into a Hugging Face ``Dataset`` laid out for DPO.

    One row is produced per input pair, preserving input order, with columns:
      - ``prompt``:   ``pair["prompt"]``
      - ``chosen``:   ``pair["positive_response"]["model_response"]`` (preferred answer)
      - ``rejected``: ``pair["negative_response"]["model_response"]`` (dispreferred answer)
    """
    # Collect (prompt, chosen, rejected) per pair in a single pass over `pairs`.
    rows = [
        (
            item["prompt"],
            item["positive_response"]["model_response"],
            item["negative_response"]["model_response"],
        )
        for item in pairs
    ]

    columns = {
        "prompt": [row[0] for row in rows],
        "chosen": [row[1] for row in rows],
        "rejected": [row[2] for row in rows],
    }
    return Dataset.from_dict(columns)
63
+
64
+
65
def train_lora_dpo(
    task: str,
    model_name: str,
    output_path: str | Path,
    num_pairs: int = 50,
    device: str = "cuda:0",
    keep_intermediate: bool = False,
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    learning_rate: float = 5e-5,
    num_epochs: int = 1,
    batch_size: int = 1,
    max_length: int = 512,
    max_prompt_length: int = 256,
    beta: float = 0.1,
) -> Path:
    """
    Train a LoRA adapter using DPO on contrastive pairs from an lm-eval task.

    Args:
        task: lm-eval task name (e.g., 'boolq', 'cb')
        model_name: HuggingFace model name
        output_path: Where to save the trained adapter
        num_pairs: Number of preference pairs to request
        device: Device to run on
        keep_intermediate: Whether to keep intermediate files (the generated
            pairs file and the temporary trainer output directory)
        lora_r: LoRA rank
        lora_alpha: LoRA alpha
        lora_dropout: LoRA dropout
        learning_rate: Learning rate
        num_epochs: Number of training epochs
        batch_size: Training batch size
        max_length: Max total sequence length
        max_prompt_length: Max prompt length
        beta: DPO beta parameter (controls deviation from reference model)

    Returns:
        Path to saved adapter. The directory contains the adapter weights,
        the tokenizer and a ``metadata.json`` describing the run.

    Raises:
        ValueError: if no preference pairs were generated for ``task``.

    Intermediate files are removed (unless ``keep_intermediate``) and model
    memory is released even when training fails part-way.
    """
    import os
    import shutil
    from contextlib import suppress

    output_path = Path(output_path)

    # Step 1: Generate contrastive pairs
    print(f"\n{'='*60}")
    print(f"Step 1: Generating {num_pairs} preference pairs from {task}")
    print(f"{'='*60}")

    pairs, pairs_file = generate_contrastive_pairs(task, num_pairs)
    print(f"Generated {len(pairs)} preference pairs")

    # Everything below can fail (OOM, bad config, download errors, ...).
    # Track the resources we create so the finally-block can always release them.
    training_output_dir: str | None = None
    model = None
    trainer = None
    try:
        if not pairs:
            raise ValueError(
                f"No preference pairs were generated for task {task!r}; cannot run DPO training"
            )

        # Step 2: Create DPO dataset
        print(f"\n{'='*60}")
        print("Step 2: Creating DPO dataset")
        print(f"{'='*60}")

        dataset = create_dpo_dataset(pairs)
        print(f"Dataset size: {len(dataset)}")

        # Step 3: Load model
        print(f"\n{'='*60}")
        print(f"Step 3: Loading model {model_name}")
        print(f"{'='*60}")

        model, tokenizer = load_model_and_tokenizer(model_name, device, eval_mode=False)

        # Ensure tokenizer has padding
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"  # DPO typically uses left padding

        # Step 4: Configure LoRA
        print(f"\n{'='*60}")
        print(f"Step 4: Configuring LoRA (r={lora_r}, alpha={lora_alpha})")
        print(f"{'='*60}")

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            bias="none",
        )

        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

        # Step 5: Configure DPO training
        print(f"\n{'='*60}")
        print("Step 5: Configuring DPO training")
        print(f"{'='*60}")

        training_output_dir = tempfile.mkdtemp(prefix="lora_dpo_training_")

        # Mixed-precision flags follow the dtype preferred for this device.
        dtype = preferred_dtype(device)

        training_args = DPOConfig(
            output_dir=training_output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=1,
            learning_rate=learning_rate,
            weight_decay=0.01,
            warmup_ratio=0.1,
            logging_steps=10,
            save_strategy="no",
            bf16=(dtype == torch.bfloat16),
            fp16=(dtype == torch.float16),
            report_to="none",
            max_length=max_length,
            max_prompt_length=max_prompt_length,
            beta=beta,
            loss_type="sigmoid",  # Standard DPO loss
        )

        print(f"Beta: {beta}")
        print(f"Max length: {max_length}")
        print(f"Max prompt length: {max_prompt_length}")
        print(f"Learning rate: {learning_rate}")
        print(f"Epochs: {num_epochs}")
        print(f"Batch size: {batch_size}")

        # Step 6: Train with DPO
        print(f"\n{'='*60}")
        print("Step 6: Training with DPO")
        print(f"{'='*60}")

        # No explicit ref_model is passed; TRL derives the reference from the PEFT model.
        trainer = DPOTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            processing_class=tokenizer,
        )

        trainer.train()

        # Step 7: Save adapter
        print(f"\n{'='*60}")
        print("Step 7: Saving LoRA adapter")
        print(f"{'='*60}")

        output_path.mkdir(parents=True, exist_ok=True)
        model.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)

        metadata = {
            "task": task,
            "model": model_name,
            "training_method": "dpo",
            "num_pairs": len(pairs),
            "lora_r": lora_r,
            "lora_alpha": lora_alpha,
            "lora_dropout": lora_dropout,
            "learning_rate": learning_rate,
            "num_epochs": num_epochs,
            "batch_size": batch_size,
            "max_length": max_length,
            "max_prompt_length": max_prompt_length,
            "beta": beta,
        }
        with open(output_path / "metadata.json", "w") as f:
            json.dump(metadata, f, indent=2)
    finally:
        # Drop our references so the garbage collector can free model/trainer memory.
        del model, trainer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if not keep_intermediate:
            with suppress(FileNotFoundError):
                os.unlink(pairs_file)
            if training_output_dir is not None:
                shutil.rmtree(training_output_dir, ignore_errors=True)

    print(f"\nDPO LoRA adapter saved to {output_path}")
    return output_path
244
+
245
+
246
def _evaluate_accuracies(
    wisent_model: "WisentModel",
    task_dict,
    task: str,
    batch_size,
    max_batch_size: int,
    limit: int | None,
    label: str,
) -> tuple:
    """Run the lm-eval harness and the log-likelihood evaluation on the current model state.

    Prints both accuracies prefixed with ``label``. Returns
    ``(lm_eval_accuracy, ll_accuracy)`` as produced by ``extract_accuracy`` and
    ``run_ll_evaluation``.
    """
    lm_eval_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
    acc_lm_eval = extract_accuracy(lm_eval_results, task)
    print(f"{label} accuracy (lm-eval): {acc_lm_eval:.4f}")

    acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
    print(f"{label} accuracy (LL): {acc_ll:.4f}")
    return acc_lm_eval, acc_ll


def _build_steering_data(
    wisent_model: "WisentModel",
    task: str,
    steering_method: str,
    steering_layers: str,
    steering_num_pairs: int,
    extraction_strategy: str,
    device: str,
) -> dict:
    """Train a steering vector on the (DPO-LoRA) model and package it for ``apply_steering_to_model``.

    Returns a dict with ``"steering_vectors"`` (layer name -> list of floats)
    and ``"layers"`` (the layer names that produced a vector). The temporary
    pairs file is removed even if vector training fails.
    """
    import os
    from contextlib import suppress

    from wisent.core.trainers.steering_trainer import WisentSteeringTrainer
    from wisent.core.steering_methods import get_steering_method
    from wisent.core.activations.extraction_strategy import ExtractionStrategy
    from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
    from wisent.core.contrastive_pairs.core.pair import ContrastivePair
    from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse

    # Generate contrastive pairs for steering
    print(f"\n{'='*60}")
    print(f"Generating {steering_num_pairs} contrastive pairs for steering")
    print(f"{'='*60}")
    pairs_data, pairs_file = generate_contrastive_pairs(task, steering_num_pairs)

    try:
        pairs = [
            ContrastivePair(
                prompt=p["prompt"],
                positive_response=PositiveResponse(model_response=p["positive_response"]["model_response"]),
                negative_response=NegativeResponse(model_response=p["negative_response"]["model_response"]),
            )
            for p in pairs_data
        ]
        pair_set = ContrastivePairSet(pairs=pairs, name=f"{task}_dpo_lora_steering")
        print(f"Created {len(pair_set)} contrastive pairs")

        # Generate steering vector on DPO-LoRA model
        print(f"\n{'='*60}")
        print(f"Generating {steering_method.upper()} steering vector on DPO-LoRA model")
        print(f"Layers: {steering_layers}")
        print(f"{'='*60}")

        steering_method_obj = get_steering_method(steering_method, device=device)
        strategy = ExtractionStrategy(extraction_strategy)

        trainer = WisentSteeringTrainer(
            model=wisent_model,
            pair_set=pair_set,
            steering_method=steering_method_obj,
        )

        result = trainer.run(
            layers_spec=steering_layers,
            strategy=strategy,
            accept_low_quality_vector=True,
        )
    finally:
        # The pairs file is only needed to build the vector.
        with suppress(FileNotFoundError):
            os.unlink(pairs_file)

    # Layers whose vector is None are skipped.
    steering_vectors = {
        layer_name: tensor.cpu().float().tolist()
        for layer_name, tensor in result.steered_vectors.to_dict().items()
        if tensor is not None
    }
    if not steering_vectors:
        print("WARNING: no steering vectors were produced; steered results will equal DPO-LoRA results")

    return {
        "steering_vectors": steering_vectors,
        "layers": list(steering_vectors.keys()),
    }


def evaluate_lora_dpo(
    model_name: str,
    lora_path: str | Path,
    task: str,
    train_ratio: float = 0.8,
    device: str = "cuda:0",
    batch_size: int = 1,
    max_batch_size: int = 8,
    limit: int | None = None,
    output_dir: str | Path | None = None,
    # Training metadata (for output)
    num_train_pairs: int | None = None,
    num_epochs: int | None = None,
    lora_r: int | None = None,
    lora_alpha: int | None = None,
    lora_dropout: float | None = None,
    learning_rate: float | None = None,
    beta: float | None = None,
    max_length: int | None = None,
    max_prompt_length: int | None = None,
    # Steering parameters (optional)
    with_steering: bool = False,
    steering_method: str = "caa",
    steering_layers: str = "12",
    steering_num_pairs: int = 50,
    steering_scales: list[float] | None = None,
    extraction_strategy: str = "mc_completion",
) -> dict:
    """
    Evaluate a trained DPO LoRA adapter.

    Compares base model vs DPO-LoRA model accuracy (lm-eval and log-likelihood).
    Optionally also evaluates DPO-LoRA + steering at each scale in
    ``steering_scales`` (default ``[1.0, 2.0, 4.0]``).

    The training-metadata arguments are only copied into the results dict.
    When ``output_dir`` is given, results are written to
    ``<output_dir>/<model_name with '/' -> '_'>/<task>_lora_dpo_eval_results.json``.

    Returns:
        The results dict (also printed as a summary table).

    Steering hooks and the LoRA adapter are removed, and model memory is
    released, even if an evaluation step raises.
    """
    from wisent.core.models.wisent_model import WisentModel
    from wisent.comparison.lora import apply_lora_to_model, remove_lora

    lora_path = Path(lora_path)

    if steering_scales is None:
        steering_scales = [1.0, 2.0, 4.0]

    # Create test task
    print(f"\n{'='*60}")
    print(f"Creating test task for: {task}")
    print(f"{'='*60}")

    task_dict = create_test_only_task(task, train_ratio=train_ratio)

    # Load model
    print(f"\n{'='*60}")
    print(f"Loading model: {model_name}")
    print(f"{'='*60}")
    wisent_model = WisentModel(model_name=model_name, device=device)

    lora_applied = False
    try:
        # Base evaluation
        print(f"\n{'='*60}")
        print("Running BASE evaluation")
        print(f"{'='*60}")
        base_acc_lm_eval, base_acc_ll = _evaluate_accuracies(
            wisent_model, task_dict, task, batch_size, max_batch_size, limit, "Base"
        )

        # Apply DPO LoRA
        print(f"\n{'='*60}")
        print(f"Applying DPO LoRA adapter from: {lora_path}")
        print(f"{'='*60}")
        apply_lora_to_model(wisent_model, lora_path)
        lora_applied = True

        # LoRA evaluation
        print(f"\n{'='*60}")
        print("Running DPO-LORA evaluation")
        print(f"{'='*60}")
        lora_acc_lm_eval, lora_acc_ll = _evaluate_accuracies(
            wisent_model, task_dict, task, batch_size, max_batch_size, limit, "DPO-LoRA"
        )

        results = {
            "task": task,
            "model": model_name,
            "training_method": "dpo",
            "lora_path": str(lora_path),
            # Training config
            "num_train_pairs": num_train_pairs,
            "num_epochs": num_epochs,
            "lora_r": lora_r,
            "lora_alpha": lora_alpha,
            "lora_dropout": lora_dropout,
            "learning_rate": learning_rate,
            "beta": beta,
            "max_length": max_length,
            "max_prompt_length": max_prompt_length,
            # Eval config
            "train_ratio": train_ratio,
            "eval_limit": limit,
            # Results
            "base_accuracy_lm_eval": base_acc_lm_eval,
            "base_accuracy_ll": base_acc_ll,
            "lora_accuracy_lm_eval": lora_acc_lm_eval,
            "lora_accuracy_ll": lora_acc_ll,
            "lora_diff_lm_eval": lora_acc_lm_eval - base_acc_lm_eval,
            "lora_diff_ll": lora_acc_ll - base_acc_ll,
        }

        # DPO-LoRA + Steering evaluation (if enabled)
        if with_steering:
            steering_data = _build_steering_data(
                wisent_model,
                task,
                steering_method,
                steering_layers,
                steering_num_pairs,
                extraction_strategy,
                device,
            )

            results["steering"] = {
                "method": steering_method,
                "layers": list(steering_data["layers"]),
                "num_pairs": steering_num_pairs,
                "extraction_strategy": extraction_strategy,
                "scales": {},
            }

            for scale in steering_scales:
                print(f"\n{'='*60}")
                print(f"Evaluating DPO-LoRA+{steering_method.upper()} at scale={scale}")
                print(f"{'='*60}")

                apply_steering_to_model(wisent_model, steering_data, scale=scale)
                try:
                    steer_acc_lm_eval, steer_acc_ll = _evaluate_accuracies(
                        wisent_model,
                        task_dict,
                        task,
                        batch_size,
                        max_batch_size,
                        limit,
                        f"DPO-LoRA+{steering_method.upper()}",
                    )
                finally:
                    # Never leave steering hooks attached, even if evaluation failed.
                    remove_steering(wisent_model)

                results["steering"]["scales"][str(scale)] = {
                    "accuracy_lm_eval": steer_acc_lm_eval,
                    "accuracy_ll": steer_acc_ll,
                    "diff_from_base_lm_eval": steer_acc_lm_eval - base_acc_lm_eval,
                    "diff_from_base_ll": steer_acc_ll - base_acc_ll,
                    "diff_from_lora_lm_eval": steer_acc_lm_eval - lora_acc_lm_eval,
                    "diff_from_lora_ll": steer_acc_ll - lora_acc_ll,
                }
    finally:
        if lora_applied:
            remove_lora(wisent_model)
        del wisent_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Print summary
    print(f"\n{'='*70}")
    print("RESULTS SUMMARY")
    print(f"{'='*70}")
    print(f"Task: {task}")
    print(f"Model: {model_name}")
    print("Training: DPO")
    print(f"{'-'*70}")
    print(f"{'Method':<25} {'lm-eval acc':<15} {'LL acc':<15} {'Diff (lm-eval)':<15}")
    print(f"{'-'*70}")
    print(f"{'Base':<25} {base_acc_lm_eval:<15.4f} {base_acc_ll:<15.4f} {'':<15}")
    print(f"{'DPO-LoRA':<25} {lora_acc_lm_eval:<15.4f} {lora_acc_ll:<15.4f} {lora_acc_lm_eval - base_acc_lm_eval:+.4f}")

    if with_steering:
        for scale, res in results["steering"]["scales"].items():
            label = f"DPO-LoRA+{steering_method.upper()}@{scale}"
            print(f"{label:<25} {res['accuracy_lm_eval']:<15.4f} {res['accuracy_ll']:<15.4f} {res['diff_from_base_lm_eval']:+.4f}")

    print(f"{'='*70}")

    # Save results
    if output_dir:
        output_dir = Path(output_dir) / model_name.replace("/", "_")
        output_dir.mkdir(parents=True, exist_ok=True)
        results_file = output_dir / f"{task}_lora_dpo_eval_results.json"
        with open(results_file, "w") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to: {results_file}")

    return results
496
+
497
+
498
def main():
    """Command-line entry point: train a DPO LoRA adapter, then (unless --skip-eval) evaluate it.

    Evaluation-only options (--eval-batch-size, --steering-scales) are parsed
    and validated right after argument parsing. A malformed value is therefore
    reported immediately through argparse, instead of crashing after a
    potentially long training run.
    """
    parser = argparse.ArgumentParser(description="Train and evaluate LoRA adapter using DPO")
    parser.add_argument("--model", required=True, help="HuggingFace model name")
    parser.add_argument("--task", default="boolq", help="lm-eval task name")
    # NOTE(review): machine-specific default path; consider a relative default.
    parser.add_argument("--output-dir", default="/home/ubuntu/output", help="Output directory")
    parser.add_argument("--num-pairs", type=int, default=50, help="Number of preference pairs")
    parser.add_argument("--device", default="cuda:0", help="Device")
    parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
    parser.add_argument("--learning-rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--num-epochs", type=int, default=1, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=1, help="Training batch size")
    parser.add_argument("--max-length", type=int, default=512, help="Max total sequence length")
    parser.add_argument("--max-prompt-length", type=int, default=256, help="Max prompt length")
    parser.add_argument("--beta", type=float, default=0.1, help="DPO beta (controls KL penalty)")
    parser.add_argument("--keep-intermediate", action="store_true", help="Keep intermediate files")
    # Eval args
    parser.add_argument("--train-ratio", type=float, default=0.8, help="Train/test split ratio")
    parser.add_argument("--eval-batch-size", default="auto", help="Eval batch size")
    parser.add_argument("--eval-max-batch-size", type=int, default=64, help="Max eval batch size")
    parser.add_argument("--eval-limit", type=int, default=None, help="Limit eval examples")
    parser.add_argument("--skip-eval", action="store_true", help="Skip evaluation after training")
    # DPO-LoRA + Steering args
    parser.add_argument("--with-steering", action="store_true", help="Also evaluate DPO-LoRA + steering")
    parser.add_argument("--steering-method", default="caa", choices=["caa", "fgaa"], help="Steering method")
    parser.add_argument("--steering-layers", default="12", help="Layers for steering vector")
    parser.add_argument("--steering-num-pairs", type=int, default=50, help="Number of pairs for steering")
    parser.add_argument("--steering-scales", default="1.0,2.0,4.0", help="Comma-separated steering scales")
    parser.add_argument("--extraction-strategy", default="mc_balanced", help="Extraction strategy for steering")

    args = parser.parse_args()

    # Validate evaluation options before training so bad input fails fast.
    eval_batch_size = args.eval_batch_size
    steering_scales: list[float] = []
    if not args.skip_eval:
        if eval_batch_size != "auto":
            try:
                eval_batch_size = int(eval_batch_size)
            except ValueError:
                parser.error(
                    f"--eval-batch-size must be 'auto' or an integer, got {args.eval_batch_size!r}"
                )
        try:
            # Blank entries (e.g. a trailing comma) are ignored.
            steering_scales = [float(s) for s in args.steering_scales.split(",") if s.strip()]
        except ValueError:
            parser.error(
                f"--steering-scales must be comma-separated numbers, got {args.steering_scales!r}"
            )
        if not steering_scales:
            parser.error("--steering-scales must contain at least one value")

    output_path = Path(args.output_dir) / f"{args.task}_lora_dpo_adapter"

    # Train
    adapter_path = train_lora_dpo(
        task=args.task,
        model_name=args.model,
        output_path=output_path,
        num_pairs=args.num_pairs,
        device=args.device,
        keep_intermediate=args.keep_intermediate,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        max_length=args.max_length,
        max_prompt_length=args.max_prompt_length,
        beta=args.beta,
    )

    if args.skip_eval:
        return

    # Evaluate
    evaluate_lora_dpo(
        model_name=args.model,
        lora_path=adapter_path,
        task=args.task,
        train_ratio=args.train_ratio,
        device=args.device,
        batch_size=eval_batch_size,
        max_batch_size=args.eval_max_batch_size,
        limit=args.eval_limit,
        output_dir=args.output_dir,
        # Training metadata
        num_train_pairs=args.num_pairs,
        num_epochs=args.num_epochs,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        learning_rate=args.learning_rate,
        beta=args.beta,
        max_length=args.max_length,
        max_prompt_length=args.max_prompt_length,
        # Steering parameters
        with_steering=args.with_steering,
        steering_method=args.steering_method,
        steering_layers=args.steering_layers,
        steering_num_pairs=args.steering_num_pairs,
        steering_scales=steering_scales,
        extraction_strategy=args.extraction_strategy,
    )
589
+
590
+
591
if __name__ == "__main__":
    # Script entry point: run the train-then-evaluate CLI defined by main().
    main()