wisent 0.7.901__py3-none-any.whl → 0.7.1116__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. wisent/__init__.py +1 -1
  2. wisent/comparison/__init__.py +1 -0
  3. wisent/comparison/detect_bos_features.py +275 -0
  4. wisent/comparison/fgaa.py +465 -0
  5. wisent/comparison/lora.py +663 -0
  6. wisent/comparison/lora_dpo.py +604 -0
  7. wisent/comparison/main.py +444 -0
  8. wisent/comparison/ours.py +76 -0
  9. wisent/comparison/reft.py +690 -0
  10. wisent/comparison/sae.py +304 -0
  11. wisent/comparison/utils.py +381 -0
  12. wisent/core/activations/activations_collector.py +3 -2
  13. wisent/core/activations/extraction_strategy.py +8 -4
  14. wisent/core/cli/agent/apply_steering.py +7 -5
  15. wisent/core/cli/agent/train_classifier.py +4 -3
  16. wisent/core/cli/generate_vector_from_task.py +11 -20
  17. wisent/core/cli/get_activations.py +1 -1
  18. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
  19. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
  20. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
  21. wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
  22. wisent/core/parser_arguments/get_activations_parser.py +5 -14
  23. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/METADATA +5 -1
  24. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/RECORD +28 -91
  25. wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
  26. wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
  27. wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
  28. wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
  29. wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
  30. wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
  31. wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
  32. wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
  33. wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
  34. wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
  35. wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
  36. wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
  37. wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
  38. wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
  39. wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
  40. wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
  41. wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
  42. wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
  43. wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
  44. wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
  45. wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
  46. wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
  47. wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
  48. wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
  49. wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
  50. wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
  51. wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
  52. wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
  53. wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
  54. wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
  55. wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
  56. wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
  57. wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
  58. wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
  59. wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
  60. wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
  61. wisent/examples/scripts/1/test_cola_pairs.json +0 -8
  62. wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
  63. wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
  64. wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
  65. wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
  66. wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
  67. wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
  68. wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
  69. wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
  70. wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
  71. wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
  72. wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
  73. wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
  74. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
  75. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
  76. wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
  77. wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
  78. wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
  79. wisent/examples/scripts/2/test_atis_pairs.json +0 -8
  80. wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
  81. wisent/examples/scripts/2/test_babi_pairs.json +0 -8
  82. wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
  83. wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
  84. wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
  85. wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
  86. wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
  87. wisent/examples/scripts/generate_paper_data.py +0 -384
  88. wisent/examples/scripts/intervention_validation.py +0 -626
  89. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
  90. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
  91. wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
  92. wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
  93. wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
  94. wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
  95. wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
  96. wisent/examples/scripts/threshold_analysis.py +0 -434
  97. wisent/examples/scripts/visualization_gallery.py +0 -582
  98. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/WHEEL +0 -0
  99. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/entry_points.txt +0 -0
  100. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/licenses/LICENSE +0 -0
  101. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/top_level.txt +0 -0
wisent/comparison/lora_dpo.py
@@ -0,0 +1,604 @@
+ """
+ LoRA fine-tuning using DPO (Direct Preference Optimization).
+
+ Unlike SFT which trains on positive examples only, DPO trains on
+ preference pairs (chosen vs rejected) to directly optimize for preferences.
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import gc
+ import json
+ import tempfile
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ import torch
+ from datasets import Dataset
+ from peft import LoraConfig, TaskType, get_peft_model
+ from trl import DPOTrainer, DPOConfig
+
+ from wisent.comparison.utils import (
+     generate_contrastive_pairs,
+     create_test_only_task,
+     extract_accuracy,
+     run_lm_eval_evaluation,
+     run_ll_evaluation,
+     load_model_and_tokenizer,
+     apply_steering_to_model,
+     remove_steering,
+ )
+ from wisent.core.utils.device import preferred_dtype
+
+ if TYPE_CHECKING:
+     from wisent.core.models.wisent_model import WisentModel
+
+
+ def create_dpo_dataset(pairs: list[dict], tokenizer) -> Dataset:
+     """
+     Convert contrastive pairs to DPO dataset format.
+
+     - Chat models: returns conversational format, DPOTrainer auto-applies chat template
+     - Base models: returns text format with simple formatting
+
+     Args:
+         pairs: List of contrastive pairs
+         tokenizer: Tokenizer to check for chat template support
+
+     Returns:
+         HuggingFace Dataset ready for DPOTrainer
+     """
+     data = {
+         "prompt": [],
+         "chosen": [],
+         "rejected": [],
+     }
+
+     for pair in pairs:
+         prompt = pair["prompt"]
+         chosen = pair["positive_response"]["model_response"]
+         rejected = pair["negative_response"]["model_response"]
+
+         if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
+             # Chat model: use conversational format, trainer applies template
+             data["prompt"].append([{"role": "user", "content": prompt}])
+             data["chosen"].append([{"role": "assistant", "content": chosen}])
+             data["rejected"].append([{"role": "assistant", "content": rejected}])
+         else:
+             # Base model: use simple text format
+             data["prompt"].append(f"{prompt}\n")
+             data["chosen"].append(chosen)
+             data["rejected"].append(rejected)
+
+     return Dataset.from_dict(data)
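As a quick illustration of the two row layouts this function produces (not part of the module), a single pair in the format consumed above maps to either a plain-text or a conversational row:

    from types import SimpleNamespace

    pair = {
        "prompt": "Is the sky blue?",
        "positive_response": {"model_response": "Yes."},
        "negative_response": {"model_response": "No."},
    }

    # Stand-in tokenizer with no chat template, so the base-model branch runs.
    ds = create_dpo_dataset([pair], SimpleNamespace(chat_template=None))
    print(ds[0])
    # {'prompt': 'Is the sky blue?\n', 'chosen': 'Yes.', 'rejected': 'No.'}
    # With a chat-template tokenizer, each field instead becomes a list of
    # {"role": ..., "content": ...} messages that DPOTrainer templates itself.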
+
+
+ def train_lora_dpo(
+     task: str,
+     model_name: str,
+     output_path: str | Path,
+     num_pairs: int = 50,
+     device: str = "cuda:0",
+     keep_intermediate: bool = False,
+     lora_r: int = 16,
+     lora_alpha: int = 32,
+     lora_dropout: float = 0.05,
+     learning_rate: float = 5e-5,
+     num_epochs: int = 1,
+     batch_size: int = 1,
+     max_length: int = 512,
+     max_prompt_length: int = 256,
+     beta: float = 0.1,
+ ) -> Path:
+     """
+     Train a LoRA adapter using DPO on contrastive pairs from an lm-eval task.
+
+     Args:
+         task: lm-eval task name (e.g., 'boolq', 'cb')
+         model_name: HuggingFace model name
+         output_path: Where to save the trained adapter
+         num_pairs: Number of preference pairs to use
+         device: Device to run on
+         keep_intermediate: Whether to keep intermediate files
+         lora_r: LoRA rank
+         lora_alpha: LoRA alpha
+         lora_dropout: LoRA dropout
+         learning_rate: Learning rate
+         num_epochs: Number of training epochs
+         batch_size: Training batch size
+         max_length: Max total sequence length
+         max_prompt_length: Max prompt length
+         beta: DPO beta parameter (controls deviation from reference model)
+
+     Returns:
+         Path to saved adapter
+     """
+     output_path = Path(output_path)
+
+     # Step 1: Generate contrastive pairs
+     print(f"\n{'='*60}")
+     print(f"Step 1: Generating {num_pairs} preference pairs from {task}")
+     print(f"{'='*60}")
+
+     pairs, pairs_file = generate_contrastive_pairs(task, num_pairs)
+     print(f"Generated {len(pairs)} preference pairs")
+
+     # Step 2: Load model
+     print(f"\n{'='*60}")
+     print(f"Step 2: Loading model {model_name}")
+     print(f"{'='*60}")
+
+     model, tokenizer = load_model_and_tokenizer(model_name, device, eval_mode=False)
+
+     # Ensure tokenizer has padding
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.padding_side = "left"  # DPO typically uses left padding
+
+     # Step 3: Create DPO dataset
+     print(f"\n{'='*60}")
+     print(f"Step 3: Creating DPO dataset")
+     print(f"{'='*60}")
+
+     dataset = create_dpo_dataset(pairs, tokenizer)
+     print(f"Dataset size: {len(dataset)}")
+
+     # Step 4: Configure LoRA
+     print(f"\n{'='*60}")
+     print(f"Step 4: Configuring LoRA (r={lora_r}, alpha={lora_alpha})")
+     print(f"{'='*60}")
+
+     lora_config = LoraConfig(
+         task_type=TaskType.CAUSAL_LM,
+         r=lora_r,
+         lora_alpha=lora_alpha,
+         lora_dropout=lora_dropout,
+         target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
+         bias="none",
+     )
+
+     model = get_peft_model(model, lora_config)
+     model.print_trainable_parameters()
+
+     # Step 5: Configure DPO training
+     print(f"\n{'='*60}")
+     print(f"Step 5: Configuring DPO training")
+     print(f"{'='*60}")
+
+     training_output_dir = tempfile.mkdtemp(prefix="lora_dpo_training_")
+
+     # Determine dtype
+     dtype = preferred_dtype(device)
+
+     training_args = DPOConfig(
+         output_dir=training_output_dir,
+         num_train_epochs=num_epochs,
+         per_device_train_batch_size=batch_size,
+         gradient_accumulation_steps=1,
+         learning_rate=learning_rate,
+         weight_decay=0.01,
+         warmup_ratio=0.1,
+         logging_steps=10,
+         save_strategy="no",
+         bf16=(dtype == torch.bfloat16),
+         fp16=(dtype == torch.float16),
+         report_to="none",
+         max_length=max_length,
+         max_prompt_length=max_prompt_length,
+         beta=beta,
+         loss_type="sigmoid",  # Standard DPO loss
+     )
+
+     print(f"Beta: {beta}")
+     print(f"Max length: {max_length}")
+     print(f"Max prompt length: {max_prompt_length}")
+     print(f"Learning rate: {learning_rate}")
+     print(f"Epochs: {num_epochs}")
+     print(f"Batch size: {batch_size}")
+
+     # Step 6: Train with DPO
+     print(f"\n{'='*60}")
+     print(f"Step 6: Training with DPO")
+     print(f"{'='*60}")
+
+     trainer = DPOTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=dataset,
+         processing_class=tokenizer,
+     )
+
+     trainer.train()
+
+     # Step 7: Save adapter
+     print(f"\n{'='*60}")
+     print(f"Step 7: Saving LoRA adapter")
+     print(f"{'='*60}")
+
+     output_path.mkdir(parents=True, exist_ok=True)
+     model.save_pretrained(output_path)
+     tokenizer.save_pretrained(output_path)
+
+     # Save metadata
+     metadata = {
+         "task": task,
+         "model": model_name,
+         "training_method": "dpo",
+         "num_pairs": len(pairs),
+         "lora_r": lora_r,
+         "lora_alpha": lora_alpha,
+         "lora_dropout": lora_dropout,
+         "learning_rate": learning_rate,
+         "num_epochs": num_epochs,
+         "batch_size": batch_size,
+         "max_length": max_length,
+         "max_prompt_length": max_prompt_length,
+         "beta": beta,
+     }
+     with open(output_path / "metadata.json", "w") as f:
+         json.dump(metadata, f, indent=2)
+
+     # Cleanup
+     del model, trainer
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     if not keep_intermediate:
+         import os
+         import shutil
+         os.unlink(pairs_file)
+         shutil.rmtree(training_output_dir, ignore_errors=True)
+
+     print(f"\nDPO LoRA adapter saved to {output_path}")
+     return output_path
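A hedged usage sketch of the training entry point above; the model name and output path are illustrative, not taken from this release:

    adapter_path = train_lora_dpo(
        task="boolq",
        model_name="meta-llama/Llama-3.2-1B-Instruct",  # illustrative choice
        output_path="./boolq_lora_dpo_adapter",
        num_pairs=50,
        device="cuda:0",
        beta=0.1,
    )
    # The adapter directory also receives a metadata.json recording the task,
    # base model, and the DPO/LoRA hyperparameters used for training.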
+
+
+ def evaluate_lora_dpo(
+     model_name: str,
+     lora_path: str | Path,
+     task: str,
+     train_ratio: float = 0.8,
+     device: str = "cuda:0",
+     batch_size: int = 1,
+     max_batch_size: int = 8,
+     limit: int | None = None,
+     output_dir: str | Path = None,
+     # Training metadata (for output)
+     num_train_pairs: int | None = None,
+     num_epochs: int | None = None,
+     lora_r: int | None = None,
+     lora_alpha: int | None = None,
+     lora_dropout: float | None = None,
+     learning_rate: float | None = None,
+     beta: float | None = None,
+     max_length: int | None = None,
+     max_prompt_length: int | None = None,
+     # Steering parameters (optional)
+     with_steering: bool = False,
+     steering_method: str = "caa",
+     steering_layers: str = "12",
+     steering_num_pairs: int = 50,
+     steering_scales: list[float] | None = None,
+     extraction_strategy: str = "mc_completion",
+ ) -> dict:
+     """
+     Evaluate a trained DPO LoRA adapter.
+
+     Compares base model vs DPO-LoRA model accuracy.
+     Optionally also evaluates DPO-LoRA + steering at multiple scales.
+     """
+     from wisent.core.models.wisent_model import WisentModel
+     from wisent.comparison.lora import apply_lora_to_model, remove_lora
+
+     lora_path = Path(lora_path)
+
+     if steering_scales is None:
+         steering_scales = [1.0, 2.0, 4.0]
+
+     # Create test task
+     print(f"\n{'='*60}")
+     print(f"Creating test task for: {task}")
+     print(f"{'='*60}")
+
+     task_dict = create_test_only_task(task, train_ratio=train_ratio)
+
+     # Load model
+     print(f"\n{'='*60}")
+     print(f"Loading model: {model_name}")
+     print(f"{'='*60}")
+     wisent_model = WisentModel(model_name=model_name, device=device)
+
+     # Base evaluation
+     print(f"\n{'='*60}")
+     print(f"Running BASE evaluation")
+     print(f"{'='*60}")
+
+     base_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
+     base_acc_lm_eval = extract_accuracy(base_results, task)
+     print(f"Base accuracy (lm-eval): {base_acc_lm_eval:.4f}")
+
+     base_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
+     print(f"Base accuracy (LL): {base_acc_ll:.4f}")
+
+     # Apply DPO LoRA
+     print(f"\n{'='*60}")
+     print(f"Applying DPO LoRA adapter from: {lora_path}")
+     print(f"{'='*60}")
+     apply_lora_to_model(wisent_model, lora_path)
+
+     # LoRA evaluation
+     print(f"\n{'='*60}")
+     print(f"Running DPO-LORA evaluation")
+     print(f"{'='*60}")
+
+     lora_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
+     lora_acc_lm_eval = extract_accuracy(lora_results, task)
+     print(f"DPO-LoRA accuracy (lm-eval): {lora_acc_lm_eval:.4f}")
+
+     lora_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
+     print(f"DPO-LoRA accuracy (LL): {lora_acc_ll:.4f}")
+
+     # Results dict
+     results = {
+         "task": task,
+         "model": model_name,
+         "training_method": "dpo",
+         "lora_path": str(lora_path),
+         # Training config
+         "num_train_pairs": num_train_pairs,
+         "num_epochs": num_epochs,
+         "lora_r": lora_r,
+         "lora_alpha": lora_alpha,
+         "lora_dropout": lora_dropout,
+         "learning_rate": learning_rate,
+         "beta": beta,
+         "max_length": max_length,
+         "max_prompt_length": max_prompt_length,
+         # Eval config
+         "train_ratio": train_ratio,
+         "eval_limit": limit,
+         # Results
+         "base_accuracy_lm_eval": base_acc_lm_eval,
+         "base_accuracy_ll": base_acc_ll,
+         "lora_accuracy_lm_eval": lora_acc_lm_eval,
+         "lora_accuracy_ll": lora_acc_ll,
+         "lora_diff_lm_eval": lora_acc_lm_eval - base_acc_lm_eval,
+         "lora_diff_ll": lora_acc_ll - base_acc_ll,
+     }
+
+     # DPO-LoRA + Steering evaluation (if enabled)
+     if with_steering:
+         from wisent.core.trainers.steering_trainer import WisentSteeringTrainer
+         from wisent.core.steering_methods import get_steering_method
+         from wisent.core.activations.extraction_strategy import ExtractionStrategy
+         from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
+         from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+         from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse
+
+         # Generate contrastive pairs for steering
+         print(f"\n{'='*60}")
+         print(f"Generating {steering_num_pairs} contrastive pairs for steering")
+         print(f"{'='*60}")
+         pairs_data, pairs_file = generate_contrastive_pairs(task, steering_num_pairs)
+
+         # Convert to ContrastivePairSet
+         pairs = []
+         for p in pairs_data:
+             pair = ContrastivePair(
+                 prompt=p["prompt"],
+                 positive_response=PositiveResponse(model_response=p["positive_response"]["model_response"]),
+                 negative_response=NegativeResponse(model_response=p["negative_response"]["model_response"]),
+             )
+             pairs.append(pair)
+         pair_set = ContrastivePairSet(pairs=pairs, name=f"{task}_dpo_lora_steering")
+         print(f"Created {len(pair_set)} contrastive pairs")
+
+         # Generate steering vector on DPO-LoRA model
+         print(f"\n{'='*60}")
+         print(f"Generating {steering_method.upper()} steering vector on DPO-LoRA model")
+         print(f"Layers: {steering_layers}")
+         print(f"{'='*60}")
+
+         steering_method_obj = get_steering_method(steering_method, device=device)
+         strategy = ExtractionStrategy(extraction_strategy)
+
+         trainer = WisentSteeringTrainer(
+             model=wisent_model,
+             pair_set=pair_set,
+             steering_method=steering_method_obj,
+         )
+
+         result = trainer.run(
+             layers_spec=steering_layers,
+             strategy=strategy,
+             accept_low_quality_vector=True,
+         )
+
+         # Convert to dict format for apply_steering_to_model
+         steering_vectors = {}
+         for layer_name, tensor in result.steered_vectors.to_dict().items():
+             if tensor is not None:
+                 steering_vectors[layer_name] = tensor.cpu().float().tolist()
+
+         steering_data = {
+             "steering_vectors": steering_vectors,
+             "layers": list(steering_vectors.keys()),
+         }
+
+         # Cleanup temp file
+         import os
+         os.unlink(pairs_file)
+
+         # Add steering info to results
+         results["steering"] = {
+             "method": steering_method,
+             "layers": list(steering_vectors.keys()),
+             "num_pairs": steering_num_pairs,
+             "extraction_strategy": extraction_strategy,
+             "scales": {},
+         }
+
+         # Evaluate at each scale
+         for scale in steering_scales:
+             print(f"\n{'='*60}")
+             print(f"Evaluating DPO-LoRA+{steering_method.upper()} at scale={scale}")
+             print(f"{'='*60}")
+
+             apply_steering_to_model(wisent_model, steering_data, scale=scale)
+
+             steer_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
+             steer_acc_lm_eval = extract_accuracy(steer_results, task)
+             print(f"DPO-LoRA+{steering_method.upper()} accuracy (lm-eval): {steer_acc_lm_eval:.4f}")
+
+             steer_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
+             print(f"DPO-LoRA+{steering_method.upper()} accuracy (LL): {steer_acc_ll:.4f}")
+
+             remove_steering(wisent_model)
+
+             results["steering"]["scales"][str(scale)] = {
+                 "accuracy_lm_eval": steer_acc_lm_eval,
+                 "accuracy_ll": steer_acc_ll,
+                 "diff_from_base_lm_eval": steer_acc_lm_eval - base_acc_lm_eval,
+                 "diff_from_base_ll": steer_acc_ll - base_acc_ll,
+                 "diff_from_lora_lm_eval": steer_acc_lm_eval - lora_acc_lm_eval,
+                 "diff_from_lora_ll": steer_acc_ll - lora_acc_ll,
+             }
+
+     # Cleanup
+     remove_lora(wisent_model)
+     del wisent_model
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     # Print summary
+     print(f"\n{'='*70}")
+     print(f"RESULTS SUMMARY")
+     print(f"{'='*70}")
+     print(f"Task: {task}")
+     print(f"Model: {model_name}")
+     print(f"Training: DPO")
+     print(f"{'-'*70}")
+     print(f"{'Method':<25} {'lm-eval acc':<15} {'LL acc':<15} {'Diff (lm-eval)':<15}")
+     print(f"{'-'*70}")
+     print(f"{'Base':<25} {base_acc_lm_eval:<15.4f} {base_acc_ll:<15.4f} {'':<15}")
+     print(f"{'DPO-LoRA':<25} {lora_acc_lm_eval:<15.4f} {lora_acc_ll:<15.4f} {lora_acc_lm_eval - base_acc_lm_eval:+.4f}")
+
+     if with_steering:
+         for scale, res in results["steering"]["scales"].items():
+             label = f"DPO-LoRA+{steering_method.upper()}@{scale}"
+             print(f"{label:<25} {res['accuracy_lm_eval']:<15.4f} {res['accuracy_ll']:<15.4f} {res['diff_from_base_lm_eval']:+.4f}")
+
+     print(f"{'='*70}")
+
+     # Save results
+     if output_dir:
+         output_dir = Path(output_dir)
+         model_dir_name = model_name.replace("/", "_")
+         output_dir = output_dir / model_dir_name
+         output_dir.mkdir(parents=True, exist_ok=True)
+         results_file = output_dir / f"{task}_lora_dpo_eval_results.json"
+         with open(results_file, "w") as f:
+             json.dump(results, f, indent=2)
+         print(f"\nResults saved to: {results_file}")
+
+     return results
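A matching sketch for the evaluation side, again with illustrative names; passing with_steering=True adds the steering-on-top-of-LoRA sweep over the given scales:

    results = evaluate_lora_dpo(
        model_name="meta-llama/Llama-3.2-1B-Instruct",  # illustrative choice
        lora_path="./boolq_lora_dpo_adapter",
        task="boolq",
        limit=200,
        output_dir="./results",
        with_steering=True,
        steering_method="caa",
        steering_scales=[1.0, 2.0],
    )
    print(results["lora_diff_lm_eval"])  # accuracy gain of DPO-LoRA over the base model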
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Train and evaluate LoRA adapter using DPO")
+     parser.add_argument("--model", required=True, help="HuggingFace model name")
+     parser.add_argument("--task", default="boolq", help="lm-eval task name")
+     parser.add_argument("--output-dir", default="/home/ubuntu/output", help="Output directory")
+     parser.add_argument("--num-pairs", type=int, default=50, help="Number of preference pairs")
+     parser.add_argument("--device", default="cuda:0", help="Device")
+     parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
+     parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha")
+     parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
+     parser.add_argument("--learning-rate", type=float, default=5e-5, help="Learning rate")
+     parser.add_argument("--num-epochs", type=int, default=1, help="Number of epochs")
+     parser.add_argument("--batch-size", type=int, default=1, help="Training batch size")
+     parser.add_argument("--max-length", type=int, default=512, help="Max total sequence length")
+     parser.add_argument("--max-prompt-length", type=int, default=256, help="Max prompt length")
+     parser.add_argument("--beta", type=float, default=0.1, help="DPO beta (controls KL penalty)")
+     parser.add_argument("--keep-intermediate", action="store_true", help="Keep intermediate files")
+     # Eval args
+     parser.add_argument("--train-ratio", type=float, default=0.8, help="Train/test split ratio")
+     parser.add_argument("--eval-batch-size", default="auto", help="Eval batch size")
+     parser.add_argument("--eval-max-batch-size", type=int, default=64, help="Max eval batch size")
+     parser.add_argument("--eval-limit", type=int, default=None, help="Limit eval examples")
+     parser.add_argument("--skip-eval", action="store_true", help="Skip evaluation after training")
+     # DPO-LoRA + Steering args
+     parser.add_argument("--with-steering", action="store_true", help="Also evaluate DPO-LoRA + steering")
+     parser.add_argument("--steering-method", default="caa", choices=["caa", "fgaa"], help="Steering method")
+     parser.add_argument("--steering-layers", default="12", help="Layers for steering vector")
+     parser.add_argument("--steering-num-pairs", type=int, default=50, help="Number of pairs for steering")
+     parser.add_argument("--steering-scales", default="1.0,2.0,4.0", help="Comma-separated steering scales")
+     parser.add_argument("--extraction-strategy", default="mc_balanced", help="Extraction strategy for steering")
+
+     args = parser.parse_args()
+
+     output_path = Path(args.output_dir) / f"{args.task}_lora_dpo_adapter"
+
+     # Train
+     train_lora_dpo(
+         task=args.task,
+         model_name=args.model,
+         output_path=output_path,
+         num_pairs=args.num_pairs,
+         device=args.device,
+         keep_intermediate=args.keep_intermediate,
+         lora_r=args.lora_r,
+         lora_alpha=args.lora_alpha,
+         lora_dropout=args.lora_dropout,
+         learning_rate=args.learning_rate,
+         num_epochs=args.num_epochs,
+         batch_size=args.batch_size,
+         max_length=args.max_length,
+         max_prompt_length=args.max_prompt_length,
+         beta=args.beta,
+     )
+
+     # Evaluate
+     if not args.skip_eval:
+         eval_batch_size = args.eval_batch_size
+         if eval_batch_size != "auto":
+             eval_batch_size = int(eval_batch_size)
+
+         # Parse steering scales
+         steering_scales = [float(s.strip()) for s in args.steering_scales.split(",")]
+
+         evaluate_lora_dpo(
+             model_name=args.model,
+             lora_path=output_path,
+             task=args.task,
+             train_ratio=args.train_ratio,
+             device=args.device,
+             batch_size=eval_batch_size,
+             max_batch_size=args.eval_max_batch_size,
+             limit=args.eval_limit,
+             output_dir=args.output_dir,
+             # Training metadata
+             num_train_pairs=args.num_pairs,
+             num_epochs=args.num_epochs,
+             lora_r=args.lora_r,
+             lora_alpha=args.lora_alpha,
+             lora_dropout=args.lora_dropout,
+             learning_rate=args.learning_rate,
+             beta=args.beta,
+             max_length=args.max_length,
+             max_prompt_length=args.max_prompt_length,
+             # Steering parameters
+             with_steering=args.with_steering,
+             steering_method=args.steering_method,
+             steering_layers=args.steering_layers,
+             steering_num_pairs=args.steering_num_pairs,
+             steering_scales=steering_scales,
+             extraction_strategy=args.extraction_strategy,
+         )
+
+
+ if __name__ == "__main__":
+     main()
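For completeness, an example command-line invocation of the script above (assuming the installed package is run as a module; the model name is illustrative):

    python -m wisent.comparison.lora_dpo \
        --model meta-llama/Llama-3.2-1B-Instruct \
        --task boolq \
        --num-pairs 50 \
        --output-dir ./output \
        --with-steering \
        --steering-scales 1.0,2.0,4.0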