wisent 0.7.901__py3-none-any.whl → 0.7.1045__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/comparison/__init__.py +1 -0
- wisent/comparison/detect_bos_features.py +275 -0
- wisent/comparison/fgaa.py +465 -0
- wisent/comparison/lora.py +669 -0
- wisent/comparison/lora_dpo.py +592 -0
- wisent/comparison/main.py +444 -0
- wisent/comparison/ours.py +76 -0
- wisent/comparison/sae.py +304 -0
- wisent/comparison/utils.py +381 -0
- wisent/core/activations/activations_collector.py +3 -2
- wisent/core/activations/extraction_strategy.py +8 -4
- wisent/core/cli/agent/apply_steering.py +7 -5
- wisent/core/cli/agent/train_classifier.py +4 -3
- wisent/core/cli/generate_vector_from_task.py +11 -20
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
- wisent/core/parser_arguments/get_activations_parser.py +5 -14
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/RECORD +27 -91
- wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
- wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
- wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
- wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
- wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
- wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
- wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
- wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
- wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
- wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
- wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
- wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
- wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
- wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
- wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
- wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cola_pairs.json +0 -8
- wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
- wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
- wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
- wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
- wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
- wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
- wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
- wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
- wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
- wisent/examples/scripts/2/test_atis_pairs.json +0 -8
- wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babi_pairs.json +0 -8
- wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
- wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
- wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
- wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/generate_paper_data.py +0 -384
- wisent/examples/scripts/intervention_validation.py +0 -626
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
- wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
- wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
- wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
- wisent/examples/scripts/threshold_analysis.py +0 -434
- wisent/examples/scripts/visualization_gallery.py +0 -582
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
wisent/comparison/lora_dpo.py
@@ -0,0 +1,592 @@
"""
LoRA fine-tuning using DPO (Direct Preference Optimization).

Unlike SFT which trains on positive examples only, DPO trains on
preference pairs (chosen vs rejected) to directly optimize for preferences.
"""

from __future__ import annotations

import argparse
import gc
import json
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING

import torch
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model
from trl import DPOTrainer, DPOConfig

from wisent.comparison.utils import (
    generate_contrastive_pairs,
    create_test_only_task,
    extract_accuracy,
    run_lm_eval_evaluation,
    run_ll_evaluation,
    load_model_and_tokenizer,
    apply_steering_to_model,
    remove_steering,
)
from wisent.core.utils.device import preferred_dtype

if TYPE_CHECKING:
    from wisent.core.models.wisent_model import WisentModel


def create_dpo_dataset(pairs: list[dict]) -> Dataset:
    """
    Convert contrastive pairs to DPO dataset format.

    DPO expects:
    - prompt: the input prompt
    - chosen: the preferred response
    - rejected: the non-preferred response
    """
    data = {
        "prompt": [],
        "chosen": [],
        "rejected": [],
    }

    for pair in pairs:
        prompt = pair["prompt"]
        chosen = pair["positive_response"]["model_response"]
        rejected = pair["negative_response"]["model_response"]

        data["prompt"].append(prompt)
        data["chosen"].append(chosen)
        data["rejected"].append(rejected)

    return Dataset.from_dict(data)


def train_lora_dpo(
    task: str,
    model_name: str,
    output_path: str | Path,
    num_pairs: int = 50,
    device: str = "cuda:0",
    keep_intermediate: bool = False,
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    learning_rate: float = 5e-5,
    num_epochs: int = 1,
    batch_size: int = 1,
    max_length: int = 512,
    max_prompt_length: int = 256,
    beta: float = 0.1,
) -> Path:
    """
    Train a LoRA adapter using DPO on contrastive pairs from an lm-eval task.

    Args:
        task: lm-eval task name (e.g., 'boolq', 'cb')
        model_name: HuggingFace model name
        output_path: Where to save the trained adapter
        num_pairs: Number of preference pairs to use
        device: Device to run on
        keep_intermediate: Whether to keep intermediate files
        lora_r: LoRA rank
        lora_alpha: LoRA alpha
        lora_dropout: LoRA dropout
        learning_rate: Learning rate
        num_epochs: Number of training epochs
        batch_size: Training batch size
        max_length: Max total sequence length
        max_prompt_length: Max prompt length
        beta: DPO beta parameter (controls deviation from reference model)

    Returns:
        Path to saved adapter
    """
    output_path = Path(output_path)

    # Step 1: Generate contrastive pairs
    print(f"\n{'='*60}")
    print(f"Step 1: Generating {num_pairs} preference pairs from {task}")
    print(f"{'='*60}")

    pairs, pairs_file = generate_contrastive_pairs(task, num_pairs)
    print(f"Generated {len(pairs)} preference pairs")

    # Step 2: Create DPO dataset
    print(f"\n{'='*60}")
    print(f"Step 2: Creating DPO dataset")
    print(f"{'='*60}")

    dataset = create_dpo_dataset(pairs)
    print(f"Dataset size: {len(dataset)}")

    # Step 3: Load model
    print(f"\n{'='*60}")
    print(f"Step 3: Loading model {model_name}")
    print(f"{'='*60}")

    model, tokenizer = load_model_and_tokenizer(model_name, device, eval_mode=False)

    # Ensure tokenizer has padding
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # DPO typically uses left padding

    # Step 4: Configure LoRA
    print(f"\n{'='*60}")
    print(f"Step 4: Configuring LoRA (r={lora_r}, alpha={lora_alpha})")
    print(f"{'='*60}")

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        bias="none",
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Step 5: Configure DPO training
    print(f"\n{'='*60}")
    print(f"Step 5: Configuring DPO training")
    print(f"{'='*60}")

    training_output_dir = tempfile.mkdtemp(prefix="lora_dpo_training_")

    # Determine dtype
    dtype = preferred_dtype(device)

    training_args = DPOConfig(
        output_dir=training_output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=1,
        learning_rate=learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.1,
        logging_steps=10,
        save_strategy="no",
        bf16=(dtype == torch.bfloat16),
        fp16=(dtype == torch.float16),
        report_to="none",
        max_length=max_length,
        max_prompt_length=max_prompt_length,
        beta=beta,
        loss_type="sigmoid",  # Standard DPO loss
    )

    print(f"Beta: {beta}")
    print(f"Max length: {max_length}")
    print(f"Max prompt length: {max_prompt_length}")
    print(f"Learning rate: {learning_rate}")
    print(f"Epochs: {num_epochs}")
    print(f"Batch size: {batch_size}")

    # Step 6: Train with DPO
    print(f"\n{'='*60}")
    print(f"Step 6: Training with DPO")
    print(f"{'='*60}")

    trainer = DPOTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    trainer.train()

    # Step 7: Save adapter
    print(f"\n{'='*60}")
    print(f"Step 7: Saving LoRA adapter")
    print(f"{'='*60}")

    output_path.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)

    # Save metadata
    metadata = {
        "task": task,
        "model": model_name,
        "training_method": "dpo",
        "num_pairs": len(pairs),
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "learning_rate": learning_rate,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "max_length": max_length,
        "max_prompt_length": max_prompt_length,
        "beta": beta,
    }
    with open(output_path / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)

    # Cleanup
    del model, trainer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if not keep_intermediate:
        import os
        import shutil
        os.unlink(pairs_file)
        shutil.rmtree(training_output_dir, ignore_errors=True)

    print(f"\nDPO LoRA adapter saved to {output_path}")
    return output_path


def evaluate_lora_dpo(
    model_name: str,
    lora_path: str | Path,
    task: str,
    train_ratio: float = 0.8,
    device: str = "cuda:0",
    batch_size: int = 1,
    max_batch_size: int = 8,
    limit: int | None = None,
    output_dir: str | Path = None,
    # Training metadata (for output)
    num_train_pairs: int | None = None,
    num_epochs: int | None = None,
    lora_r: int | None = None,
    lora_alpha: int | None = None,
    lora_dropout: float | None = None,
    learning_rate: float | None = None,
    beta: float | None = None,
    max_length: int | None = None,
    max_prompt_length: int | None = None,
    # Steering parameters (optional)
    with_steering: bool = False,
    steering_method: str = "caa",
    steering_layers: str = "12",
    steering_num_pairs: int = 50,
    steering_scales: list[float] | None = None,
    extraction_strategy: str = "mc_completion",
) -> dict:
    """
    Evaluate a trained DPO LoRA adapter.

    Compares base model vs DPO-LoRA model accuracy.
    Optionally also evaluates DPO-LoRA + steering at multiple scales.
    """
    from wisent.core.models.wisent_model import WisentModel
    from wisent.comparison.lora import apply_lora_to_model, remove_lora

    lora_path = Path(lora_path)

    if steering_scales is None:
        steering_scales = [1.0, 2.0, 4.0]

    # Create test task
    print(f"\n{'='*60}")
    print(f"Creating test task for: {task}")
    print(f"{'='*60}")

    task_dict = create_test_only_task(task, train_ratio=train_ratio)

    # Load model
    print(f"\n{'='*60}")
    print(f"Loading model: {model_name}")
    print(f"{'='*60}")
    wisent_model = WisentModel(model_name=model_name, device=device)

    # Base evaluation
    print(f"\n{'='*60}")
    print(f"Running BASE evaluation")
    print(f"{'='*60}")

    base_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
    base_acc_lm_eval = extract_accuracy(base_results, task)
    print(f"Base accuracy (lm-eval): {base_acc_lm_eval:.4f}")

    base_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
    print(f"Base accuracy (LL): {base_acc_ll:.4f}")

    # Apply DPO LoRA
    print(f"\n{'='*60}")
    print(f"Applying DPO LoRA adapter from: {lora_path}")
    print(f"{'='*60}")
    apply_lora_to_model(wisent_model, lora_path)

    # LoRA evaluation
    print(f"\n{'='*60}")
    print(f"Running DPO-LORA evaluation")
    print(f"{'='*60}")

    lora_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
    lora_acc_lm_eval = extract_accuracy(lora_results, task)
    print(f"DPO-LoRA accuracy (lm-eval): {lora_acc_lm_eval:.4f}")

    lora_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
    print(f"DPO-LoRA accuracy (LL): {lora_acc_ll:.4f}")

    # Results dict
    results = {
        "task": task,
        "model": model_name,
        "training_method": "dpo",
        "lora_path": str(lora_path),
        # Training config
        "num_train_pairs": num_train_pairs,
        "num_epochs": num_epochs,
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "learning_rate": learning_rate,
        "beta": beta,
        "max_length": max_length,
        "max_prompt_length": max_prompt_length,
        # Eval config
        "train_ratio": train_ratio,
        "eval_limit": limit,
        # Results
        "base_accuracy_lm_eval": base_acc_lm_eval,
        "base_accuracy_ll": base_acc_ll,
        "lora_accuracy_lm_eval": lora_acc_lm_eval,
        "lora_accuracy_ll": lora_acc_ll,
        "lora_diff_lm_eval": lora_acc_lm_eval - base_acc_lm_eval,
        "lora_diff_ll": lora_acc_ll - base_acc_ll,
    }

    # DPO-LoRA + Steering evaluation (if enabled)
    if with_steering:
        from wisent.core.trainers.steering_trainer import WisentSteeringTrainer
        from wisent.core.steering_methods import get_steering_method
        from wisent.core.activations.extraction_strategy import ExtractionStrategy
        from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
        from wisent.core.contrastive_pairs.core.pair import ContrastivePair
        from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse

        # Generate contrastive pairs for steering
        print(f"\n{'='*60}")
        print(f"Generating {steering_num_pairs} contrastive pairs for steering")
        print(f"{'='*60}")
        pairs_data, pairs_file = generate_contrastive_pairs(task, steering_num_pairs)

        # Convert to ContrastivePairSet
        pairs = []
        for p in pairs_data:
            pair = ContrastivePair(
                prompt=p["prompt"],
                positive_response=PositiveResponse(model_response=p["positive_response"]["model_response"]),
                negative_response=NegativeResponse(model_response=p["negative_response"]["model_response"]),
            )
            pairs.append(pair)
        pair_set = ContrastivePairSet(pairs=pairs, name=f"{task}_dpo_lora_steering")
        print(f"Created {len(pair_set)} contrastive pairs")

        # Generate steering vector on DPO-LoRA model
        print(f"\n{'='*60}")
        print(f"Generating {steering_method.upper()} steering vector on DPO-LoRA model")
        print(f"Layers: {steering_layers}")
        print(f"{'='*60}")

        steering_method_obj = get_steering_method(steering_method, device=device)
        strategy = ExtractionStrategy(extraction_strategy)

        trainer = WisentSteeringTrainer(
            model=wisent_model,
            pair_set=pair_set,
            steering_method=steering_method_obj,
        )

        result = trainer.run(
            layers_spec=steering_layers,
            strategy=strategy,
            accept_low_quality_vector=True,
        )

        # Convert to dict format for apply_steering_to_model
        steering_vectors = {}
        for layer_name, tensor in result.steered_vectors.to_dict().items():
            if tensor is not None:
                steering_vectors[layer_name] = tensor.cpu().float().tolist()

        steering_data = {
            "steering_vectors": steering_vectors,
            "layers": list(steering_vectors.keys()),
        }

        # Cleanup temp file
        import os
        os.unlink(pairs_file)

        # Add steering info to results
        results["steering"] = {
            "method": steering_method,
            "layers": list(steering_vectors.keys()),
            "num_pairs": steering_num_pairs,
            "extraction_strategy": extraction_strategy,
            "scales": {},
        }

        # Evaluate at each scale
        for scale in steering_scales:
            print(f"\n{'='*60}")
            print(f"Evaluating DPO-LoRA+{steering_method.upper()} at scale={scale}")
            print(f"{'='*60}")

            apply_steering_to_model(wisent_model, steering_data, scale=scale)

            steer_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
            steer_acc_lm_eval = extract_accuracy(steer_results, task)
            print(f"DPO-LoRA+{steering_method.upper()} accuracy (lm-eval): {steer_acc_lm_eval:.4f}")

            steer_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
            print(f"DPO-LoRA+{steering_method.upper()} accuracy (LL): {steer_acc_ll:.4f}")

            remove_steering(wisent_model)

            results["steering"]["scales"][str(scale)] = {
                "accuracy_lm_eval": steer_acc_lm_eval,
                "accuracy_ll": steer_acc_ll,
                "diff_from_base_lm_eval": steer_acc_lm_eval - base_acc_lm_eval,
                "diff_from_base_ll": steer_acc_ll - base_acc_ll,
                "diff_from_lora_lm_eval": steer_acc_lm_eval - lora_acc_lm_eval,
                "diff_from_lora_ll": steer_acc_ll - lora_acc_ll,
            }

    # Cleanup
    remove_lora(wisent_model)
    del wisent_model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Print summary
    print(f"\n{'='*70}")
    print(f"RESULTS SUMMARY")
    print(f"{'='*70}")
    print(f"Task: {task}")
    print(f"Model: {model_name}")
    print(f"Training: DPO")
    print(f"{'-'*70}")
    print(f"{'Method':<25} {'lm-eval acc':<15} {'LL acc':<15} {'Diff (lm-eval)':<15}")
    print(f"{'-'*70}")
    print(f"{'Base':<25} {base_acc_lm_eval:<15.4f} {base_acc_ll:<15.4f} {'':<15}")
    print(f"{'DPO-LoRA':<25} {lora_acc_lm_eval:<15.4f} {lora_acc_ll:<15.4f} {lora_acc_lm_eval - base_acc_lm_eval:+.4f}")

    if with_steering:
        for scale, res in results["steering"]["scales"].items():
            label = f"DPO-LoRA+{steering_method.upper()}@{scale}"
            print(f"{label:<25} {res['accuracy_lm_eval']:<15.4f} {res['accuracy_ll']:<15.4f} {res['diff_from_base_lm_eval']:+.4f}")

    print(f"{'='*70}")

    # Save results
    if output_dir:
        output_dir = Path(output_dir)
        model_dir_name = model_name.replace("/", "_")
        output_dir = output_dir / model_dir_name
        output_dir.mkdir(parents=True, exist_ok=True)
        results_file = output_dir / f"{task}_lora_dpo_eval_results.json"
        with open(results_file, "w") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to: {results_file}")

    return results


def main():
    parser = argparse.ArgumentParser(description="Train and evaluate LoRA adapter using DPO")
    parser.add_argument("--model", required=True, help="HuggingFace model name")
    parser.add_argument("--task", default="boolq", help="lm-eval task name")
    parser.add_argument("--output-dir", default="/home/ubuntu/output", help="Output directory")
    parser.add_argument("--num-pairs", type=int, default=50, help="Number of preference pairs")
    parser.add_argument("--device", default="cuda:0", help="Device")
    parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
    parser.add_argument("--learning-rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--num-epochs", type=int, default=1, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=1, help="Training batch size")
    parser.add_argument("--max-length", type=int, default=512, help="Max total sequence length")
    parser.add_argument("--max-prompt-length", type=int, default=256, help="Max prompt length")
    parser.add_argument("--beta", type=float, default=0.1, help="DPO beta (controls KL penalty)")
    parser.add_argument("--keep-intermediate", action="store_true", help="Keep intermediate files")
    # Eval args
    parser.add_argument("--train-ratio", type=float, default=0.8, help="Train/test split ratio")
    parser.add_argument("--eval-batch-size", default="auto", help="Eval batch size")
    parser.add_argument("--eval-max-batch-size", type=int, default=64, help="Max eval batch size")
    parser.add_argument("--eval-limit", type=int, default=None, help="Limit eval examples")
    parser.add_argument("--skip-eval", action="store_true", help="Skip evaluation after training")
    # DPO-LoRA + Steering args
    parser.add_argument("--with-steering", action="store_true", help="Also evaluate DPO-LoRA + steering")
    parser.add_argument("--steering-method", default="caa", choices=["caa", "fgaa"], help="Steering method")
    parser.add_argument("--steering-layers", default="12", help="Layers for steering vector")
    parser.add_argument("--steering-num-pairs", type=int, default=50, help="Number of pairs for steering")
    parser.add_argument("--steering-scales", default="1.0,2.0,4.0", help="Comma-separated steering scales")
    parser.add_argument("--extraction-strategy", default="mc_balanced", help="Extraction strategy for steering")

    args = parser.parse_args()

    output_path = Path(args.output_dir) / f"{args.task}_lora_dpo_adapter"

    # Train
    train_lora_dpo(
        task=args.task,
        model_name=args.model,
        output_path=output_path,
        num_pairs=args.num_pairs,
        device=args.device,
        keep_intermediate=args.keep_intermediate,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        max_length=args.max_length,
        max_prompt_length=args.max_prompt_length,
        beta=args.beta,
    )

    # Evaluate
    if not args.skip_eval:
        eval_batch_size = args.eval_batch_size
        if eval_batch_size != "auto":
            eval_batch_size = int(eval_batch_size)

        # Parse steering scales
        steering_scales = [float(s.strip()) for s in args.steering_scales.split(",")]

        evaluate_lora_dpo(
            model_name=args.model,
            lora_path=output_path,
            task=args.task,
            train_ratio=args.train_ratio,
            device=args.device,
            batch_size=eval_batch_size,
            max_batch_size=args.eval_max_batch_size,
            limit=args.eval_limit,
            output_dir=args.output_dir,
            # Training metadata
            num_train_pairs=args.num_pairs,
            num_epochs=args.num_epochs,
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            learning_rate=args.learning_rate,
            beta=args.beta,
            max_length=args.max_length,
            max_prompt_length=args.max_prompt_length,
            # Steering parameters
            with_steering=args.with_steering,
            steering_method=args.steering_method,
            steering_layers=args.steering_layers,
            steering_num_pairs=args.steering_num_pairs,
            steering_scales=steering_scales,
            extraction_strategy=args.extraction_strategy,
        )


if __name__ == "__main__":
    main()
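For readers unfamiliar with DPO, the `beta` parameter and `loss_type="sigmoid"` passed to DPOConfig above refer to the standard sigmoid DPO objective. The following standalone sketch (not part of the package) shows what that loss computes from per-example log-probabilities under the trained policy and the frozen reference model:

import torch
import torch.nn.functional as F

def dpo_sigmoid_loss(
    policy_chosen_logps: torch.Tensor,    # log p_theta(chosen | prompt), shape (batch,)
    policy_rejected_logps: torch.Tensor,  # log p_theta(rejected | prompt)
    ref_chosen_logps: torch.Tensor,       # same quantities under the frozen reference model
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    # Implicit rewards are log-probability margins over the reference model, scaled by beta.
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    # -log(sigmoid(margin)) pushes chosen above rejected; beta controls how strongly
    # the policy is penalized for drifting away from the reference model.
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()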
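Based on the function signatures and argparse interface above, an end-to-end run could look like the sketch below. The model name and output paths are placeholders, not values shipped with the package:

from wisent.comparison.lora_dpo import train_lora_dpo, evaluate_lora_dpo

# Train a DPO LoRA adapter on boolq preference pairs (placeholder model name).
adapter_path = train_lora_dpo(
    task="boolq",
    model_name="meta-llama/Llama-3.2-1B-Instruct",
    output_path="./output/boolq_lora_dpo_adapter",
    num_pairs=50,
    device="cuda:0",
    beta=0.1,
)

# Compare base vs DPO-LoRA accuracy, optionally adding CAA steering at several scales.
results = evaluate_lora_dpo(
    model_name="meta-llama/Llama-3.2-1B-Instruct",
    lora_path=adapter_path,
    task="boolq",
    output_dir="./output",
    with_steering=True,
    steering_scales=[1.0, 2.0, 4.0],
)

# The same pipeline is exposed as a CLI via main(), e.g.:
#   python -m wisent.comparison.lora_dpo --model <hf-model> --task boolq --with-steering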