wisent 0.7.901-py3-none-any.whl → 0.7.1045-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. wisent/__init__.py +1 -1
  2. wisent/comparison/__init__.py +1 -0
  3. wisent/comparison/detect_bos_features.py +275 -0
  4. wisent/comparison/fgaa.py +465 -0
  5. wisent/comparison/lora.py +669 -0
  6. wisent/comparison/lora_dpo.py +592 -0
  7. wisent/comparison/main.py +444 -0
  8. wisent/comparison/ours.py +76 -0
  9. wisent/comparison/sae.py +304 -0
  10. wisent/comparison/utils.py +381 -0
  11. wisent/core/activations/activations_collector.py +3 -2
  12. wisent/core/activations/extraction_strategy.py +8 -4
  13. wisent/core/cli/agent/apply_steering.py +7 -5
  14. wisent/core/cli/agent/train_classifier.py +4 -3
  15. wisent/core/cli/generate_vector_from_task.py +11 -20
  16. wisent/core/cli/get_activations.py +1 -1
  17. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
  18. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
  19. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
  20. wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
  21. wisent/core/parser_arguments/get_activations_parser.py +5 -14
  22. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
  23. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/RECORD +27 -91
  24. wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
  25. wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
  26. wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
  27. wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
  28. wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
  29. wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
  30. wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
  31. wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
  32. wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
  33. wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
  34. wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
  35. wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
  36. wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
  37. wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
  38. wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
  39. wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
  40. wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
  41. wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
  42. wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
  43. wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
  44. wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
  45. wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
  46. wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
  47. wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
  48. wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
  49. wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
  50. wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
  51. wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
  52. wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
  53. wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
  54. wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
  55. wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
  56. wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
  57. wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
  58. wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
  59. wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
  60. wisent/examples/scripts/1/test_cola_pairs.json +0 -8
  61. wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
  62. wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
  63. wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
  64. wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
  65. wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
  66. wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
  67. wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
  68. wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
  69. wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
  70. wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
  71. wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
  72. wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
  73. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
  74. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
  75. wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
  76. wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
  77. wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
  78. wisent/examples/scripts/2/test_atis_pairs.json +0 -8
  79. wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
  80. wisent/examples/scripts/2/test_babi_pairs.json +0 -8
  81. wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
  82. wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
  83. wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
  84. wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
  85. wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
  86. wisent/examples/scripts/generate_paper_data.py +0 -384
  87. wisent/examples/scripts/intervention_validation.py +0 -626
  88. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
  89. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
  90. wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
  91. wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
  92. wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
  93. wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
  94. wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
  95. wisent/examples/scripts/threshold_analysis.py +0 -434
  96. wisent/examples/scripts/visualization_gallery.py +0 -582
  97. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
  98. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
  99. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
  100. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
wisent/comparison/lora_dpo.py
@@ -0,0 +1,592 @@
+"""
+LoRA fine-tuning using DPO (Direct Preference Optimization).
+
+Unlike SFT which trains on positive examples only, DPO trains on
+preference pairs (chosen vs rejected) to directly optimize for preferences.
+"""
+
+from __future__ import annotations
+
+import argparse
+import gc
+import json
+import tempfile
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import torch
+from datasets import Dataset
+from peft import LoraConfig, TaskType, get_peft_model
+from trl import DPOTrainer, DPOConfig
+
+from wisent.comparison.utils import (
+    generate_contrastive_pairs,
+    create_test_only_task,
+    extract_accuracy,
+    run_lm_eval_evaluation,
+    run_ll_evaluation,
+    load_model_and_tokenizer,
+    apply_steering_to_model,
+    remove_steering,
+)
+from wisent.core.utils.device import preferred_dtype
+
+if TYPE_CHECKING:
+    from wisent.core.models.wisent_model import WisentModel
+
+
+def create_dpo_dataset(pairs: list[dict]) -> Dataset:
+    """
+    Convert contrastive pairs to DPO dataset format.
+
+    DPO expects:
+    - prompt: the input prompt
+    - chosen: the preferred response
+    - rejected: the non-preferred response
+    """
+    data = {
+        "prompt": [],
+        "chosen": [],
+        "rejected": [],
+    }
+
+    for pair in pairs:
+        prompt = pair["prompt"]
+        chosen = pair["positive_response"]["model_response"]
+        rejected = pair["negative_response"]["model_response"]
+
+        data["prompt"].append(prompt)
+        data["chosen"].append(chosen)
+        data["rejected"].append(rejected)
+
+    return Dataset.from_dict(data)
+
+
+def train_lora_dpo(
+    task: str,
+    model_name: str,
+    output_path: str | Path,
+    num_pairs: int = 50,
+    device: str = "cuda:0",
+    keep_intermediate: bool = False,
+    lora_r: int = 16,
+    lora_alpha: int = 32,
+    lora_dropout: float = 0.05,
+    learning_rate: float = 5e-5,
+    num_epochs: int = 1,
+    batch_size: int = 1,
+    max_length: int = 512,
+    max_prompt_length: int = 256,
+    beta: float = 0.1,
+) -> Path:
+    """
+    Train a LoRA adapter using DPO on contrastive pairs from an lm-eval task.
+
+    Args:
+        task: lm-eval task name (e.g., 'boolq', 'cb')
+        model_name: HuggingFace model name
+        output_path: Where to save the trained adapter
+        num_pairs: Number of preference pairs to use
+        device: Device to run on
+        keep_intermediate: Whether to keep intermediate files
+        lora_r: LoRA rank
+        lora_alpha: LoRA alpha
+        lora_dropout: LoRA dropout
+        learning_rate: Learning rate
+        num_epochs: Number of training epochs
+        batch_size: Training batch size
+        max_length: Max total sequence length
+        max_prompt_length: Max prompt length
+        beta: DPO beta parameter (controls deviation from reference model)
+
+    Returns:
+        Path to saved adapter
+    """
+    output_path = Path(output_path)
+
+    # Step 1: Generate contrastive pairs
+    print(f"\n{'='*60}")
+    print(f"Step 1: Generating {num_pairs} preference pairs from {task}")
+    print(f"{'='*60}")
+
+    pairs, pairs_file = generate_contrastive_pairs(task, num_pairs)
+    print(f"Generated {len(pairs)} preference pairs")
+
+    # Step 2: Create DPO dataset
+    print(f"\n{'='*60}")
+    print(f"Step 2: Creating DPO dataset")
+    print(f"{'='*60}")
+
+    dataset = create_dpo_dataset(pairs)
+    print(f"Dataset size: {len(dataset)}")
+
+    # Step 3: Load model
+    print(f"\n{'='*60}")
+    print(f"Step 3: Loading model {model_name}")
+    print(f"{'='*60}")
+
+    model, tokenizer = load_model_and_tokenizer(model_name, device, eval_mode=False)
+
+    # Ensure tokenizer has padding
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.padding_side = "left"  # DPO typically uses left padding
+
+    # Step 4: Configure LoRA
+    print(f"\n{'='*60}")
+    print(f"Step 4: Configuring LoRA (r={lora_r}, alpha={lora_alpha})")
+    print(f"{'='*60}")
+
+    lora_config = LoraConfig(
+        task_type=TaskType.CAUSAL_LM,
+        r=lora_r,
+        lora_alpha=lora_alpha,
+        lora_dropout=lora_dropout,
+        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
+        bias="none",
+    )
+
+    model = get_peft_model(model, lora_config)
+    model.print_trainable_parameters()
+
+    # Step 5: Configure DPO training
+    print(f"\n{'='*60}")
+    print(f"Step 5: Configuring DPO training")
+    print(f"{'='*60}")
+
+    training_output_dir = tempfile.mkdtemp(prefix="lora_dpo_training_")
+
+    # Determine dtype
+    dtype = preferred_dtype(device)
+
+    training_args = DPOConfig(
+        output_dir=training_output_dir,
+        num_train_epochs=num_epochs,
+        per_device_train_batch_size=batch_size,
+        gradient_accumulation_steps=1,
+        learning_rate=learning_rate,
+        weight_decay=0.01,
+        warmup_ratio=0.1,
+        logging_steps=10,
+        save_strategy="no",
+        bf16=(dtype == torch.bfloat16),
+        fp16=(dtype == torch.float16),
+        report_to="none",
+        max_length=max_length,
+        max_prompt_length=max_prompt_length,
+        beta=beta,
+        loss_type="sigmoid",  # Standard DPO loss
+    )
+
+    print(f"Beta: {beta}")
+    print(f"Max length: {max_length}")
+    print(f"Max prompt length: {max_prompt_length}")
+    print(f"Learning rate: {learning_rate}")
+    print(f"Epochs: {num_epochs}")
+    print(f"Batch size: {batch_size}")
+
+    # Step 6: Train with DPO
+    print(f"\n{'='*60}")
+    print(f"Step 6: Training with DPO")
+    print(f"{'='*60}")
+
+    trainer = DPOTrainer(
+        model=model,
+        args=training_args,
+        train_dataset=dataset,
+        processing_class=tokenizer,
+    )
+
+    trainer.train()
+
+    # Step 7: Save adapter
+    print(f"\n{'='*60}")
+    print(f"Step 7: Saving LoRA adapter")
+    print(f"{'='*60}")
+
+    output_path.mkdir(parents=True, exist_ok=True)
+    model.save_pretrained(output_path)
+    tokenizer.save_pretrained(output_path)
+
+    # Save metadata
+    metadata = {
+        "task": task,
+        "model": model_name,
+        "training_method": "dpo",
+        "num_pairs": len(pairs),
+        "lora_r": lora_r,
+        "lora_alpha": lora_alpha,
+        "lora_dropout": lora_dropout,
+        "learning_rate": learning_rate,
+        "num_epochs": num_epochs,
+        "batch_size": batch_size,
+        "max_length": max_length,
+        "max_prompt_length": max_prompt_length,
+        "beta": beta,
+    }
+    with open(output_path / "metadata.json", "w") as f:
+        json.dump(metadata, f, indent=2)
+
+    # Cleanup
+    del model, trainer
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    if not keep_intermediate:
+        import os
+        import shutil
+        os.unlink(pairs_file)
+        shutil.rmtree(training_output_dir, ignore_errors=True)
+
+    print(f"\nDPO LoRA adapter saved to {output_path}")
+    return output_path
+
+
+def evaluate_lora_dpo(
+    model_name: str,
+    lora_path: str | Path,
+    task: str,
+    train_ratio: float = 0.8,
+    device: str = "cuda:0",
+    batch_size: int = 1,
+    max_batch_size: int = 8,
+    limit: int | None = None,
+    output_dir: str | Path = None,
+    # Training metadata (for output)
+    num_train_pairs: int | None = None,
+    num_epochs: int | None = None,
+    lora_r: int | None = None,
+    lora_alpha: int | None = None,
+    lora_dropout: float | None = None,
+    learning_rate: float | None = None,
+    beta: float | None = None,
+    max_length: int | None = None,
+    max_prompt_length: int | None = None,
+    # Steering parameters (optional)
+    with_steering: bool = False,
+    steering_method: str = "caa",
+    steering_layers: str = "12",
+    steering_num_pairs: int = 50,
+    steering_scales: list[float] | None = None,
+    extraction_strategy: str = "mc_completion",
+) -> dict:
+    """
+    Evaluate a trained DPO LoRA adapter.
+
+    Compares base model vs DPO-LoRA model accuracy.
+    Optionally also evaluates DPO-LoRA + steering at multiple scales.
+    """
+    from wisent.core.models.wisent_model import WisentModel
+    from wisent.comparison.lora import apply_lora_to_model, remove_lora
+
+    lora_path = Path(lora_path)
+
+    if steering_scales is None:
+        steering_scales = [1.0, 2.0, 4.0]
+
+    # Create test task
+    print(f"\n{'='*60}")
+    print(f"Creating test task for: {task}")
+    print(f"{'='*60}")
+
+    task_dict = create_test_only_task(task, train_ratio=train_ratio)
+
+    # Load model
+    print(f"\n{'='*60}")
+    print(f"Loading model: {model_name}")
+    print(f"{'='*60}")
+    wisent_model = WisentModel(model_name=model_name, device=device)
+
+    # Base evaluation
+    print(f"\n{'='*60}")
+    print(f"Running BASE evaluation")
+    print(f"{'='*60}")
+
+    base_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
+    base_acc_lm_eval = extract_accuracy(base_results, task)
+    print(f"Base accuracy (lm-eval): {base_acc_lm_eval:.4f}")
+
+    base_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
+    print(f"Base accuracy (LL): {base_acc_ll:.4f}")
+
+    # Apply DPO LoRA
+    print(f"\n{'='*60}")
+    print(f"Applying DPO LoRA adapter from: {lora_path}")
+    print(f"{'='*60}")
+    apply_lora_to_model(wisent_model, lora_path)
+
+    # LoRA evaluation
+    print(f"\n{'='*60}")
+    print(f"Running DPO-LORA evaluation")
+    print(f"{'='*60}")
+
+    lora_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
+    lora_acc_lm_eval = extract_accuracy(lora_results, task)
+    print(f"DPO-LoRA accuracy (lm-eval): {lora_acc_lm_eval:.4f}")
+
+    lora_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
+    print(f"DPO-LoRA accuracy (LL): {lora_acc_ll:.4f}")
+
+    # Results dict
+    results = {
+        "task": task,
+        "model": model_name,
+        "training_method": "dpo",
+        "lora_path": str(lora_path),
+        # Training config
+        "num_train_pairs": num_train_pairs,
+        "num_epochs": num_epochs,
+        "lora_r": lora_r,
+        "lora_alpha": lora_alpha,
+        "lora_dropout": lora_dropout,
+        "learning_rate": learning_rate,
+        "beta": beta,
+        "max_length": max_length,
+        "max_prompt_length": max_prompt_length,
+        # Eval config
+        "train_ratio": train_ratio,
+        "eval_limit": limit,
+        # Results
+        "base_accuracy_lm_eval": base_acc_lm_eval,
+        "base_accuracy_ll": base_acc_ll,
+        "lora_accuracy_lm_eval": lora_acc_lm_eval,
+        "lora_accuracy_ll": lora_acc_ll,
+        "lora_diff_lm_eval": lora_acc_lm_eval - base_acc_lm_eval,
+        "lora_diff_ll": lora_acc_ll - base_acc_ll,
+    }
+
+    # DPO-LoRA + Steering evaluation (if enabled)
+    if with_steering:
+        from wisent.core.trainers.steering_trainer import WisentSteeringTrainer
+        from wisent.core.steering_methods import get_steering_method
+        from wisent.core.activations.extraction_strategy import ExtractionStrategy
+        from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
+        from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+        from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse
+
+        # Generate contrastive pairs for steering
+        print(f"\n{'='*60}")
+        print(f"Generating {steering_num_pairs} contrastive pairs for steering")
+        print(f"{'='*60}")
+        pairs_data, pairs_file = generate_contrastive_pairs(task, steering_num_pairs)
+
+        # Convert to ContrastivePairSet
+        pairs = []
+        for p in pairs_data:
+            pair = ContrastivePair(
+                prompt=p["prompt"],
+                positive_response=PositiveResponse(model_response=p["positive_response"]["model_response"]),
+                negative_response=NegativeResponse(model_response=p["negative_response"]["model_response"]),
+            )
+            pairs.append(pair)
+        pair_set = ContrastivePairSet(pairs=pairs, name=f"{task}_dpo_lora_steering")
+        print(f"Created {len(pair_set)} contrastive pairs")
+
+        # Generate steering vector on DPO-LoRA model
+        print(f"\n{'='*60}")
+        print(f"Generating {steering_method.upper()} steering vector on DPO-LoRA model")
+        print(f"Layers: {steering_layers}")
+        print(f"{'='*60}")
+
+        steering_method_obj = get_steering_method(steering_method, device=device)
+        strategy = ExtractionStrategy(extraction_strategy)
+
+        trainer = WisentSteeringTrainer(
+            model=wisent_model,
+            pair_set=pair_set,
+            steering_method=steering_method_obj,
+        )
+
+        result = trainer.run(
+            layers_spec=steering_layers,
+            strategy=strategy,
+            accept_low_quality_vector=True,
+        )
+
+        # Convert to dict format for apply_steering_to_model
+        steering_vectors = {}
+        for layer_name, tensor in result.steered_vectors.to_dict().items():
+            if tensor is not None:
+                steering_vectors[layer_name] = tensor.cpu().float().tolist()
+
+        steering_data = {
+            "steering_vectors": steering_vectors,
+            "layers": list(steering_vectors.keys()),
+        }
+
+        # Cleanup temp file
+        import os
+        os.unlink(pairs_file)
+
+        # Add steering info to results
+        results["steering"] = {
+            "method": steering_method,
+            "layers": list(steering_vectors.keys()),
+            "num_pairs": steering_num_pairs,
+            "extraction_strategy": extraction_strategy,
+            "scales": {},
+        }
+
+        # Evaluate at each scale
+        for scale in steering_scales:
+            print(f"\n{'='*60}")
+            print(f"Evaluating DPO-LoRA+{steering_method.upper()} at scale={scale}")
+            print(f"{'='*60}")
+
+            apply_steering_to_model(wisent_model, steering_data, scale=scale)
+
+            steer_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
+            steer_acc_lm_eval = extract_accuracy(steer_results, task)
+            print(f"DPO-LoRA+{steering_method.upper()} accuracy (lm-eval): {steer_acc_lm_eval:.4f}")
+
+            steer_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
+            print(f"DPO-LoRA+{steering_method.upper()} accuracy (LL): {steer_acc_ll:.4f}")
+
+            remove_steering(wisent_model)
+
+            results["steering"]["scales"][str(scale)] = {
+                "accuracy_lm_eval": steer_acc_lm_eval,
+                "accuracy_ll": steer_acc_ll,
+                "diff_from_base_lm_eval": steer_acc_lm_eval - base_acc_lm_eval,
+                "diff_from_base_ll": steer_acc_ll - base_acc_ll,
+                "diff_from_lora_lm_eval": steer_acc_lm_eval - lora_acc_lm_eval,
+                "diff_from_lora_ll": steer_acc_ll - lora_acc_ll,
+            }
+
+    # Cleanup
+    remove_lora(wisent_model)
+    del wisent_model
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    # Print summary
+    print(f"\n{'='*70}")
+    print(f"RESULTS SUMMARY")
+    print(f"{'='*70}")
+    print(f"Task: {task}")
+    print(f"Model: {model_name}")
+    print(f"Training: DPO")
+    print(f"{'-'*70}")
+    print(f"{'Method':<25} {'lm-eval acc':<15} {'LL acc':<15} {'Diff (lm-eval)':<15}")
+    print(f"{'-'*70}")
+    print(f"{'Base':<25} {base_acc_lm_eval:<15.4f} {base_acc_ll:<15.4f} {'':<15}")
+    print(f"{'DPO-LoRA':<25} {lora_acc_lm_eval:<15.4f} {lora_acc_ll:<15.4f} {lora_acc_lm_eval - base_acc_lm_eval:+.4f}")
+
+    if with_steering:
+        for scale, res in results["steering"]["scales"].items():
+            label = f"DPO-LoRA+{steering_method.upper()}@{scale}"
+            print(f"{label:<25} {res['accuracy_lm_eval']:<15.4f} {res['accuracy_ll']:<15.4f} {res['diff_from_base_lm_eval']:+.4f}")
+
+    print(f"{'='*70}")
+
+    # Save results
+    if output_dir:
+        output_dir = Path(output_dir)
+        model_dir_name = model_name.replace("/", "_")
+        output_dir = output_dir / model_dir_name
+        output_dir.mkdir(parents=True, exist_ok=True)
+        results_file = output_dir / f"{task}_lora_dpo_eval_results.json"
+        with open(results_file, "w") as f:
+            json.dump(results, f, indent=2)
+        print(f"\nResults saved to: {results_file}")
+
+    return results
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Train and evaluate LoRA adapter using DPO")
+    parser.add_argument("--model", required=True, help="HuggingFace model name")
+    parser.add_argument("--task", default="boolq", help="lm-eval task name")
+    parser.add_argument("--output-dir", default="/home/ubuntu/output", help="Output directory")
+    parser.add_argument("--num-pairs", type=int, default=50, help="Number of preference pairs")
+    parser.add_argument("--device", default="cuda:0", help="Device")
+    parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
+    parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha")
+    parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
+    parser.add_argument("--learning-rate", type=float, default=5e-5, help="Learning rate")
+    parser.add_argument("--num-epochs", type=int, default=1, help="Number of epochs")
+    parser.add_argument("--batch-size", type=int, default=1, help="Training batch size")
+    parser.add_argument("--max-length", type=int, default=512, help="Max total sequence length")
+    parser.add_argument("--max-prompt-length", type=int, default=256, help="Max prompt length")
+    parser.add_argument("--beta", type=float, default=0.1, help="DPO beta (controls KL penalty)")
+    parser.add_argument("--keep-intermediate", action="store_true", help="Keep intermediate files")
+    # Eval args
+    parser.add_argument("--train-ratio", type=float, default=0.8, help="Train/test split ratio")
+    parser.add_argument("--eval-batch-size", default="auto", help="Eval batch size")
+    parser.add_argument("--eval-max-batch-size", type=int, default=64, help="Max eval batch size")
+    parser.add_argument("--eval-limit", type=int, default=None, help="Limit eval examples")
+    parser.add_argument("--skip-eval", action="store_true", help="Skip evaluation after training")
+    # DPO-LoRA + Steering args
+    parser.add_argument("--with-steering", action="store_true", help="Also evaluate DPO-LoRA + steering")
+    parser.add_argument("--steering-method", default="caa", choices=["caa", "fgaa"], help="Steering method")
+    parser.add_argument("--steering-layers", default="12", help="Layers for steering vector")
+    parser.add_argument("--steering-num-pairs", type=int, default=50, help="Number of pairs for steering")
+    parser.add_argument("--steering-scales", default="1.0,2.0,4.0", help="Comma-separated steering scales")
+    parser.add_argument("--extraction-strategy", default="mc_balanced", help="Extraction strategy for steering")

+    args = parser.parse_args()
+
+    output_path = Path(args.output_dir) / f"{args.task}_lora_dpo_adapter"
+
+    # Train
+    train_lora_dpo(
+        task=args.task,
+        model_name=args.model,
+        output_path=output_path,
+        num_pairs=args.num_pairs,
+        device=args.device,
+        keep_intermediate=args.keep_intermediate,
+        lora_r=args.lora_r,
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
+        learning_rate=args.learning_rate,
+        num_epochs=args.num_epochs,
+        batch_size=args.batch_size,
+        max_length=args.max_length,
+        max_prompt_length=args.max_prompt_length,
+        beta=args.beta,
+    )
+
+    # Evaluate
+    if not args.skip_eval:
+        eval_batch_size = args.eval_batch_size
+        if eval_batch_size != "auto":
+            eval_batch_size = int(eval_batch_size)
+
+        # Parse steering scales
+        steering_scales = [float(s.strip()) for s in args.steering_scales.split(",")]
+
+        evaluate_lora_dpo(
+            model_name=args.model,
+            lora_path=output_path,
+            task=args.task,
+            train_ratio=args.train_ratio,
+            device=args.device,
+            batch_size=eval_batch_size,
+            max_batch_size=args.eval_max_batch_size,
+            limit=args.eval_limit,
+            output_dir=args.output_dir,
+            # Training metadata
+            num_train_pairs=args.num_pairs,
+            num_epochs=args.num_epochs,
+            lora_r=args.lora_r,
+            lora_alpha=args.lora_alpha,
+            lora_dropout=args.lora_dropout,
+            learning_rate=args.learning_rate,
+            beta=args.beta,
+            max_length=args.max_length,
+            max_prompt_length=args.max_prompt_length,
+            # Steering parameters
+            with_steering=args.with_steering,
+            steering_method=args.steering_method,
+            steering_layers=args.steering_layers,
+            steering_num_pairs=args.steering_num_pairs,
+            steering_scales=steering_scales,
+            extraction_strategy=args.extraction_strategy,
+        )
+
+
+if __name__ == "__main__":
+    main()
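
For orientation, the pairs consumed by create_dpo_dataset come from generate_contrastive_pairs and carry a prompt plus a positive and a negative model response; the function flattens them into the prompt/chosen/rejected columns that trl's DPOTrainer expects. A minimal sketch of that mapping, using invented strings purely for illustration:

from wisent.comparison.lora_dpo import create_dpo_dataset

# Hypothetical pair in the shape used above; the field names match the diff,
# the text itself is made up for illustration.
pair = {
    "prompt": "Question: Is water wet?\nAnswer:",
    "positive_response": {"model_response": " yes"},
    "negative_response": {"model_response": " no"},
}

ds = create_dpo_dataset([pair])
print(ds[0])
# -> {'prompt': 'Question: Is water wet?\nAnswer:', 'chosen': ' yes', 'rejected': ' no'}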
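
The module is also wired up as a script: main() parses the flags listed above, trains with train_lora_dpo, and then calls evaluate_lora_dpo unless --skip-eval is given. A rough programmatic equivalent of that flow is sketched below; the model name, output paths, and eval limit are placeholders chosen for illustration, not values taken from the package:

from wisent.comparison.lora_dpo import train_lora_dpo, evaluate_lora_dpo

# Placeholder model and paths; any HuggingFace causal LM and a writable directory should work.
adapter_path = train_lora_dpo(
    task="boolq",
    model_name="meta-llama/Llama-3.2-1B-Instruct",
    output_path="./output/boolq_lora_dpo_adapter",
    num_pairs=50,
    device="cuda:0",
    beta=0.1,
)

results = evaluate_lora_dpo(
    model_name="meta-llama/Llama-3.2-1B-Instruct",
    lora_path=adapter_path,
    task="boolq",
    device="cuda:0",
    limit=200,          # mirrors --eval-limit; caps the number of evaluation examples
    output_dir="./output",
)
print(results["lora_diff_lm_eval"])  # accuracy change from the base model (lm-eval harness)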