wisent 0.7.901__py3-none-any.whl → 0.7.1116__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/comparison/__init__.py +1 -0
- wisent/comparison/detect_bos_features.py +275 -0
- wisent/comparison/fgaa.py +465 -0
- wisent/comparison/lora.py +663 -0
- wisent/comparison/lora_dpo.py +604 -0
- wisent/comparison/main.py +444 -0
- wisent/comparison/ours.py +76 -0
- wisent/comparison/reft.py +690 -0
- wisent/comparison/sae.py +304 -0
- wisent/comparison/utils.py +381 -0
- wisent/core/activations/activations_collector.py +3 -2
- wisent/core/activations/extraction_strategy.py +8 -4
- wisent/core/cli/agent/apply_steering.py +7 -5
- wisent/core/cli/agent/train_classifier.py +4 -3
- wisent/core/cli/generate_vector_from_task.py +11 -20
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
- wisent/core/parser_arguments/get_activations_parser.py +5 -14
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/METADATA +5 -1
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/RECORD +28 -91
- wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
- wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
- wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
- wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
- wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
- wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
- wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
- wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
- wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
- wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
- wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
- wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
- wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
- wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
- wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
- wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cola_pairs.json +0 -8
- wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
- wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
- wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
- wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
- wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
- wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
- wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
- wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
- wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
- wisent/examples/scripts/2/test_atis_pairs.json +0 -8
- wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babi_pairs.json +0 -8
- wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
- wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
- wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
- wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/generate_paper_data.py +0 -384
- wisent/examples/scripts/intervention_validation.py +0 -626
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
- wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
- wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
- wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
- wisent/examples/scripts/threshold_analysis.py +0 -434
- wisent/examples/scripts/visualization_gallery.py +0 -582
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/WHEEL +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,663 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LoRA fine-tuning method for comparison experiments.
|
|
3
|
+
|
|
4
|
+
Trains a LoRA adapter on benchmark tasks using supervised fine-tuning (SFT)
|
|
5
|
+
on positive responses from contrastive pairs.
|
|
6
|
+
|
|
7
|
+
Optionally evaluates LoRA + steering by generating a steering vector on the
|
|
8
|
+
LoRA model and combining both methods.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import gc
|
|
14
|
+
import json
|
|
15
|
+
import tempfile
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import TYPE_CHECKING
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
from datasets import Dataset
|
|
21
|
+
from peft import LoraConfig, TaskType, get_peft_model
|
|
22
|
+
from trl import SFTTrainer, SFTConfig
|
|
23
|
+
|
|
24
|
+
from wisent.comparison.utils import (
|
|
25
|
+
generate_contrastive_pairs,
|
|
26
|
+
create_test_only_task,
|
|
27
|
+
extract_accuracy,
|
|
28
|
+
run_lm_eval_evaluation,
|
|
29
|
+
run_ll_evaluation,
|
|
30
|
+
load_model_and_tokenizer,
|
|
31
|
+
apply_steering_to_model,
|
|
32
|
+
remove_steering,
|
|
33
|
+
)
|
|
34
|
+
from wisent.core.utils.device import preferred_dtype
|
|
35
|
+
|
|
36
|
+
if TYPE_CHECKING:
|
|
37
|
+
from wisent.core.models.wisent_model import WisentModel
|
|
38
|
+
|
|
39
|
+
__all__ = ["train_lora_adapter", "evaluate_lora", "apply_lora_to_model", "remove_lora"]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Default LoRA configurations per model architecture
|
|
43
|
+
# Keys are lowercase substrings looked for in the model name; the first key
# (in insertion order) found in the name wins. "default" is the fallback.
LORA_TARGET_MODULES = {
    "gemma": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "llama": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "mistral": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "phi": ["q_proj", "k_proj", "v_proj", "dense"],
    "gpt_neo": ["q_proj", "v_proj"],
    "gpt2": ["c_attn"],
    "default": "all-linear",
}


def get_target_modules(model_name: str) -> str | list[str]:
    """Return the LoRA target modules to use for ``model_name``.

    The lowercased model name is searched for each key of
    ``LORA_TARGET_MODULES`` in insertion order, and the entry of the first
    key that occurs as a substring is returned. When nothing matches, the
    ``"default"`` entry is returned.

    Args:
        model_name: Model identifier, e.g. a HuggingFace repository name.

    Returns:
        A list of module names, or the string stored under ``"default"``.
    """
    lowered = model_name.lower()
    candidates = (
        modules
        for arch, modules in LORA_TARGET_MODULES.items()
        if arch in lowered
    )
    return next(candidates, LORA_TARGET_MODULES["default"])
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def prepare_sft_dataset(pairs: list[dict], tokenizer) -> Dataset:
    """
    Build the SFT training dataset from contrastive pairs.

    Only the positive response of each pair is used. The record layout
    depends on whether the tokenizer carries a non-empty ``chat_template``:

    - with a chat template: ``{"messages": [user turn, assistant turn]}``
      (conversational format; the chat template is left to the trainer)
    - without one: ``{"text": "<prompt>\\n<response>"}``

    Args:
        pairs: Contrastive pairs; each must provide ``"prompt"`` and
            ``["positive_response"]["model_response"]``.
        tokenizer: Tokenizer inspected for a ``chat_template`` attribute.

    Returns:
        HuggingFace Dataset built from the formatted records.
    """
    # The tokenizer does not change between pairs, so decide the format once.
    is_chat_model = bool(getattr(tokenizer, "chat_template", None))

    def _format(pair: dict) -> dict:
        question = pair["prompt"]
        answer = pair["positive_response"]["model_response"]
        if is_chat_model:
            return {
                "messages": [
                    {"role": "user", "content": question},
                    {"role": "assistant", "content": answer},
                ]
            }
        return {"text": f"{question}\n{answer}"}

    return Dataset.from_list([_format(pair) for pair in pairs])
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def train_lora_adapter(
    task: str,
    model_name: str,
    output_path: str | Path,
    trait_label: str = "correctness",
    num_pairs: int = 50,
    device: str = "cuda:0",
    keep_intermediate: bool = False,
    # LoRA-specific parameters
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    learning_rate: float = 2e-4,
    num_epochs: int = 3,
    batch_size: int = 2,
    max_length: int = 512,
) -> Path:
    """
    Train a LoRA adapter using SFT on positive responses.

    The adapter, the tokenizer and a ``metadata.json`` describing the run are
    written to ``output_path``. Unless ``keep_intermediate`` is set, the
    generated pairs file and the temporary trainer output directory are
    removed on exit, including when loading, training or saving fails.

    Args:
        task: lm-eval task name (e.g., 'boolq', 'cb')
        model_name: HuggingFace model name
        output_path: Where to save the LoRA adapter
        trait_label: Label for the trait being trained (recorded in the
            metadata only)
        num_pairs: Number of training examples to use
        device: Device to train on
        keep_intermediate: Whether to keep intermediate files
        lora_r: LoRA rank
        lora_alpha: LoRA alpha scaling factor
        lora_dropout: LoRA dropout
        learning_rate: Training learning rate
        num_epochs: Number of training epochs
        batch_size: Training batch size
        max_length: Maximum sequence length

    Returns:
        Path to the saved LoRA adapter directory
    """
    import shutil

    output_path = Path(output_path)

    # Step 1: Generate contrastive pairs
    print(f"Step 1: Generating training data from task: {task}")
    pairs, pairs_file = generate_contrastive_pairs(task, num_pairs)

    # Assigned in step 5; tracked out here so the finally block can tell
    # whether there is a directory to remove.
    training_output_dir = None
    try:
        print(f" Loaded {len(pairs)} training examples")

        # Step 2: Load model and tokenizer
        print(f"\nStep 2: Loading model {model_name}...")
        model, tokenizer = load_model_and_tokenizer(model_name, device, eval_mode=False)

        # Step 3: Configure LoRA
        print(f"\nStep 3: Configuring LoRA (r={lora_r}, alpha={lora_alpha})...")

        target_modules = get_target_modules(model_name)
        print(f" Target modules: {target_modules}")

        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
        )

        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

        # Step 4: Prepare dataset
        print("\nStep 4: Preparing SFT dataset...")
        train_dataset = prepare_sft_dataset(pairs, tokenizer)
        print(f" Dataset size: {len(train_dataset)} examples")

        # Step 5: Training
        print("\nStep 5: Training LoRA adapter...")

        # Create temporary directory for training outputs
        training_output_dir = tempfile.mkdtemp(prefix="lora_training_")

        # Use device-optimized dtype (bfloat16 on CUDA, float16 on MPS, float32 on CPU)
        dtype = preferred_dtype(device)

        training_args = SFTConfig(
            output_dir=training_output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=1,
            learning_rate=learning_rate,
            weight_decay=0.01,
            warmup_ratio=0.1,
            logging_steps=10,
            save_strategy="no",  # Don't save checkpoints
            bf16=(dtype == torch.bfloat16),
            fp16=(dtype == torch.float16),
            report_to="none",  # Disable wandb/tensorboard
            # NOTE(review): newer TRL releases appear to rename this option to
            # `max_length` - confirm against the pinned TRL version.
            max_seq_length=max_length,
        )

        trainer = SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            processing_class=tokenizer,
        )

        trainer.train()

        # Step 6: Save LoRA adapter
        print(f"\nStep 6: Saving LoRA adapter to {output_path}...")
        output_path.mkdir(parents=True, exist_ok=True)
        model.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)

        # Save metadata
        metadata = {
            "method": "lora",
            "model": model_name,
            "task": task,
            "trait_label": trait_label,
            "num_pairs": len(pairs),
            "lora_config": {
                "r": lora_r,
                "alpha": lora_alpha,
                "dropout": lora_dropout,
                # Always stored as a list, even for the "all-linear" string.
                "target_modules": target_modules if isinstance(target_modules, list) else [target_modules],
            },
            "training_config": {
                "learning_rate": learning_rate,
                "num_epochs": num_epochs,
                "batch_size": batch_size,
                "max_length": max_length,
            },
        }

        with open(output_path / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        # Drop the references to the model before asking CUDA to release memory
        del model, trainer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
    finally:
        # Runs on failure too, so an aborted run does not leave files behind.
        if not keep_intermediate:
            Path(pairs_file).unlink(missing_ok=True)
            if training_output_dir is not None:
                shutil.rmtree(training_output_dir, ignore_errors=True)

    print(f"\nLoRA adapter saved to {output_path}")
    return output_path
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def apply_lora_to_model(wisent_model: "WisentModel", lora_path: str | Path) -> None:
    """
    Load a saved LoRA adapter into a WisentModel under the name "steering".

    If ``wisent_model.hf_model`` already exposes ``peft_config``, the adapter
    is loaded into it and made the active adapter. Otherwise the model is
    wrapped with ``PeftModel.from_pretrained`` and ``wisent_model.hf_model``
    is replaced by the wrapper.

    Args:
        wisent_model: WisentModel instance
        lora_path: Path to the saved LoRA adapter
    """
    adapter_dir = Path(lora_path)
    hf_model = wisent_model.hf_model

    if not hasattr(hf_model, "peft_config"):
        # No adapters yet: wrap the model with PEFT
        from peft import PeftModel

        wisent_model.hf_model = PeftModel.from_pretrained(
            hf_model,
            str(adapter_dir),
            adapter_name="steering",
        )
    else:
        # Model already carries PEFT adapters: add this one and activate it
        hf_model.load_adapter(str(adapter_dir), adapter_name="steering")
        hf_model.set_adapter("steering")

    print(f"LoRA adapter loaded from {adapter_dir}")
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def remove_lora(wisent_model: "WisentModel") -> None:
    """
    Remove/disable LoRA adapter from a WisentModel.

    Strategies are tried in this order:

    1. If ``wisent_model.hf_model`` has a callable ``unload``, it is called
       and its return value replaces ``wisent_model.hf_model``.
    2. Otherwise, if the model has ``disable_adapters``, it is called; a
       ``ValueError`` from it is treated as "no adapter was loaded".
    3. Otherwise, if the model has ``base_model``, it is replaced by
       ``base_model.model``.

    Args:
        wisent_model: WisentModel instance with LoRA applied
    """
    hf_model = wisent_model.hf_model

    # NOTE(review): a PEFT wrapper appears to forward unknown attributes to
    # the wrapped transformers model, so it also looks like it has
    # `disable_adapters`; that forwarded call raises ValueError for adapters
    # attached through the wrapper, which made this function a silent no-op.
    # `unload` is tried first for that reason - confirm against the pinned
    # PEFT version that it returns the base model with the LoRA layers removed.
    unload = getattr(hf_model, "unload", None)
    if callable(unload):
        wisent_model.hf_model = unload()
        print("LoRA adapter removed")
        return

    if hasattr(hf_model, "disable_adapters"):
        try:
            hf_model.disable_adapters()
        except ValueError:
            # No adapter was loaded
            return
        print("LoRA adapter disabled")
    elif hasattr(hf_model, "base_model"):
        # Unwrap the model
        wisent_model.hf_model = hf_model.base_model.model
        print("LoRA adapter removed")
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def _run_both_evaluations(
    wisent_model: "WisentModel",
    task_dict: dict,
    task: str,
    batch_size: int | str,
    max_batch_size: int,
    limit: int | None,
    label: str,
) -> tuple[float, float]:
    """Run the lm-eval and the LL evaluation on the model in its current state.

    Prints both accuracies prefixed with ``label`` and returns them as
    ``(lm_eval_accuracy, ll_accuracy)``.
    """
    lm_eval_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
    acc_lm_eval = extract_accuracy(lm_eval_results, task)
    print(f"{label} accuracy (lm-eval): {acc_lm_eval:.4f}")

    acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
    print(f"{label} accuracy (LL): {acc_ll:.4f}")

    return acc_lm_eval, acc_ll


def evaluate_lora(
    model_name: str,
    lora_path: str | Path,
    task: str,
    train_ratio: float = 0.8,
    device: str = "cuda:0",
    batch_size: int | str = 1,
    max_batch_size: int = 8,
    limit: int | None = None,
    output_dir: str | Path | None = None,
    # Training metadata (for output)
    num_train_pairs: int | None = None,
    num_epochs: int | None = None,
    lora_r: int | None = None,
    lora_alpha: int | None = None,
    lora_dropout: float | None = None,
    learning_rate: float | None = None,
    # Steering parameters (optional)
    with_steering: bool = False,
    steering_method: str = "caa",
    steering_layers: str = "12",
    steering_num_pairs: int = 50,
    steering_scales: list[float] | None = None,
    extraction_strategy: str = "mc_completion",
) -> dict:
    """
    Evaluate a trained LoRA adapter comparing base vs LoRA performance.

    Optionally also evaluates LoRA + steering at multiple scales.
    All results are saved to a single output file.

    Args:
        model_name: HuggingFace model name
        lora_path: Path to trained LoRA adapter
        task: lm-eval task name
        train_ratio: Train/test split ratio
        device: Device to run on
        batch_size: Batch size for evaluation (the CLI in this module may
            also pass the string "auto")
        max_batch_size: Max batch size
        limit: Limit number of eval examples
        output_dir: Where to save results; nothing is written when falsy.
            Results go to ``<output_dir>/<model name with "/" replaced by
            "_">/<task>_lora_eval_results.json``.
        num_train_pairs: Training metadata, copied into the results as-is
        num_epochs: Training metadata, copied into the results as-is
        lora_r: Training metadata, copied into the results as-is
        lora_alpha: Training metadata, copied into the results as-is
        lora_dropout: Training metadata, copied into the results as-is
        learning_rate: Training metadata, copied into the results as-is
        with_steering: Whether to also evaluate LoRA + steering
        steering_method: Steering method (caa or fgaa)
        steering_layers: Layers for steering vector
        steering_num_pairs: Number of pairs for steering generation
        steering_scales: List of steering scales to evaluate; defaults to
            ``[1.0, 2.0, 4.0]`` when None
        extraction_strategy: Strategy for activation extraction

    Returns:
        Dict with evaluation results
    """
    from wisent.core.models.wisent_model import WisentModel

    lora_path = Path(lora_path)

    if steering_scales is None:
        steering_scales = [1.0, 2.0, 4.0]

    # Create test task
    print(f"\n{'='*60}")
    print(f"Creating test task for: {task}")
    print(f"{'='*60}")

    task_dict = create_test_only_task(task, train_ratio=train_ratio)

    # Load model
    print(f"\n{'='*60}")
    print(f"Loading model: {model_name}")
    print(f"{'='*60}")
    wisent_model = WisentModel(model_name=model_name, device=device)

    # BASE evaluation
    print(f"\n{'='*60}")
    print("Running BASE evaluation (no LoRA)")
    print(f"{'='*60}")

    base_acc_lm_eval, base_acc_ll = _run_both_evaluations(
        wisent_model, task_dict, task, batch_size, max_batch_size, limit, "Base"
    )

    # Apply LoRA
    print(f"\n{'='*60}")
    print(f"Applying LoRA adapter from: {lora_path}")
    print(f"{'='*60}")
    apply_lora_to_model(wisent_model, lora_path)

    # LORA evaluation
    print(f"\n{'='*60}")
    print("Running LORA evaluation")
    print(f"{'='*60}")

    lora_acc_lm_eval, lora_acc_ll = _run_both_evaluations(
        wisent_model, task_dict, task, batch_size, max_batch_size, limit, "LoRA"
    )

    # Results dict
    results = {
        "task": task,
        "model": model_name,
        "lora_path": str(lora_path),
        # Training config
        "num_train_pairs": num_train_pairs,
        "num_epochs": num_epochs,
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "learning_rate": learning_rate,
        # Eval config
        "train_ratio": train_ratio,
        "eval_limit": limit,
        # Results
        "base_accuracy_lm_eval": base_acc_lm_eval,
        "base_accuracy_ll": base_acc_ll,
        "lora_accuracy_lm_eval": lora_acc_lm_eval,
        "lora_accuracy_ll": lora_acc_ll,
        "lora_diff_lm_eval": lora_acc_lm_eval - base_acc_lm_eval,
        "lora_diff_ll": lora_acc_ll - base_acc_ll,
    }

    # LoRA + Steering evaluation (if enabled)
    if with_steering:
        from wisent.core.trainers.steering_trainer import WisentSteeringTrainer
        from wisent.core.steering_methods import get_steering_method
        from wisent.core.activations.extraction_strategy import ExtractionStrategy
        from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
        from wisent.core.contrastive_pairs.core.pair import ContrastivePair
        from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse

        # Generate contrastive pairs for steering
        print(f"\n{'='*60}")
        print(f"Generating {steering_num_pairs} contrastive pairs for steering")
        print(f"{'='*60}")
        pairs_data, pairs_file = generate_contrastive_pairs(task, steering_num_pairs)

        try:
            # Convert to ContrastivePairSet
            pairs = [
                ContrastivePair(
                    prompt=p["prompt"],
                    positive_response=PositiveResponse(model_response=p["positive_response"]["model_response"]),
                    negative_response=NegativeResponse(model_response=p["negative_response"]["model_response"]),
                )
                for p in pairs_data
            ]
            pair_set = ContrastivePairSet(pairs=pairs, name=f"{task}_lora_steering")
            print(f"Created {len(pair_set)} contrastive pairs")

            # Generate steering vector on LoRA model
            print(f"\n{'='*60}")
            print(f"Generating {steering_method.upper()} steering vector on LoRA model")
            print(f"Layers: {steering_layers}")
            print(f"{'='*60}")

            steering_method_obj = get_steering_method(steering_method, device=device)
            strategy = ExtractionStrategy(extraction_strategy)

            trainer = WisentSteeringTrainer(
                model=wisent_model,
                pair_set=pair_set,
                steering_method=steering_method_obj,
            )

            result = trainer.run(
                layers_spec=steering_layers,
                strategy=strategy,
                accept_low_quality_vector=True,
            )
        finally:
            # Remove the temp pairs file even when vector generation fails
            Path(pairs_file).unlink(missing_ok=True)

        # Convert to dict format for apply_steering_to_model
        steering_vectors = {
            layer_name: tensor.cpu().float().tolist()
            for layer_name, tensor in result.steered_vectors.to_dict().items()
            if tensor is not None
        }

        steering_data = {
            "steering_vectors": steering_vectors,
            "layers": list(steering_vectors.keys()),
        }

        # Add steering info to results
        results["steering"] = {
            "method": steering_method,
            "layers": list(steering_vectors.keys()),
            "num_pairs": steering_num_pairs,
            "extraction_strategy": extraction_strategy,
            "scales": {},
        }

        # Evaluate at each scale
        steering_label = f"LoRA+{steering_method.upper()}"
        for scale in steering_scales:
            print(f"\n{'='*60}")
            print(f"Evaluating {steering_label} at scale={scale}")
            print(f"{'='*60}")

            apply_steering_to_model(wisent_model, steering_data, scale=scale)
            try:
                steer_acc_lm_eval, steer_acc_ll = _run_both_evaluations(
                    wisent_model, task_dict, task, batch_size, max_batch_size, limit, steering_label
                )
            finally:
                # Never leave steering applied to the model, even if the
                # evaluation raised.
                remove_steering(wisent_model)

            results["steering"]["scales"][str(scale)] = {
                "accuracy_lm_eval": steer_acc_lm_eval,
                "accuracy_ll": steer_acc_ll,
                "diff_from_base_lm_eval": steer_acc_lm_eval - base_acc_lm_eval,
                "diff_from_base_ll": steer_acc_ll - base_acc_ll,
                "diff_from_lora_lm_eval": steer_acc_lm_eval - lora_acc_lm_eval,
                "diff_from_lora_ll": steer_acc_ll - lora_acc_ll,
            }

    # Cleanup
    remove_lora(wisent_model)
    del wisent_model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Print summary
    print(f"\n{'='*70}")
    print("RESULTS SUMMARY")
    print(f"{'='*70}")
    print(f"Task: {task}")
    print(f"Model: {model_name}")
    print(f"LoRA: {lora_path}")
    print(f"{'-'*70}")
    print(f"{'Method':<25} {'lm-eval acc':<15} {'LL acc':<15} {'Diff (lm-eval)':<15}")
    print(f"{'-'*70}")
    print(f"{'Base':<25} {base_acc_lm_eval:<15.4f} {base_acc_ll:<15.4f} {'':<15}")
    print(f"{'LoRA':<25} {lora_acc_lm_eval:<15.4f} {lora_acc_ll:<15.4f} {lora_acc_lm_eval - base_acc_lm_eval:+.4f}")

    if with_steering:
        for scale, res in results["steering"]["scales"].items():
            label = f"LoRA+{steering_method.upper()}@{scale}"
            print(f"{label:<25} {res['accuracy_lm_eval']:<15.4f} {res['accuracy_ll']:<15.4f} {res['diff_from_base_lm_eval']:+.4f}")

    print(f"{'='*70}")

    # Save results
    if output_dir:
        output_dir = Path(output_dir)
        model_dir_name = model_name.replace("/", "_")
        output_dir = output_dir / model_dir_name
        output_dir.mkdir(parents=True, exist_ok=True)
        results_file = output_dir / f"{task}_lora_eval_results.json"
        with open(results_file, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to: {results_file}")

    return results
|
|
571
|
+
|
|
572
|
+
|
|
573
|
+
def main():
    """Command-line entry point: train a LoRA adapter, then optionally evaluate it.

    The adapter is saved to ``<output-dir>/<task>_lora_adapter``. Unless
    ``--skip-eval`` is given, base vs LoRA (and, with ``--with-steering``,
    LoRA + steering) is evaluated afterwards. Evaluation options are
    validated before training starts so that a malformed value is reported
    as a usage error instead of failing after the training run.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Train and evaluate LoRA adapter on benchmark task")
    parser.add_argument("--model", required=True, help="HuggingFace model name")
    parser.add_argument("--task", default="boolq", help="lm-eval task name")
    # NOTE(review): machine-specific default path; kept for backward compatibility.
    parser.add_argument("--output-dir", default="/home/ubuntu/output", help="Output directory")
    parser.add_argument("--num-pairs", type=int, default=50, help="Number of training examples")
    parser.add_argument("--device", default="cuda:0", help="Device")
    parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
    parser.add_argument("--learning-rate", type=float, default=2e-4, help="Learning rate")
    parser.add_argument("--num-epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=2, help="Training batch size")
    parser.add_argument("--max-length", type=int, default=512, help="Max sequence length")
    parser.add_argument("--keep-intermediate", action="store_true", help="Keep intermediate files")
    # Eval args
    parser.add_argument("--train-ratio", type=float, default=0.8, help="Train/test split ratio")
    parser.add_argument("--eval-batch-size", default="auto", help="Eval batch size (int or 'auto')")
    parser.add_argument("--eval-max-batch-size", type=int, default=64, help="Max eval batch size for auto")
    parser.add_argument("--eval-limit", type=int, default=None, help="Limit eval examples")
    parser.add_argument("--skip-eval", action="store_true", help="Skip evaluation after training")
    # LoRA + Steering args
    parser.add_argument("--with-steering", action="store_true", help="Also evaluate LoRA + steering")
    parser.add_argument("--steering-method", default="caa", choices=["caa", "fgaa"], help="Steering method")
    parser.add_argument("--steering-layers", default="12", help="Layers for steering vector")
    parser.add_argument("--steering-num-pairs", type=int, default=50, help="Number of pairs for steering")
    parser.add_argument("--steering-scales", default="1.0,2.0,4.0", help="Comma-separated steering scales")
    parser.add_argument("--extraction-strategy", default="mc_balanced", help="Extraction strategy for steering")

    args = parser.parse_args()

    eval_batch_size = None
    steering_scales = None
    if not args.skip_eval:
        # Parse eval batch size (can be "auto" or int)
        eval_batch_size = args.eval_batch_size
        if eval_batch_size != "auto":
            try:
                eval_batch_size = int(eval_batch_size)
            except ValueError:
                parser.error(f"--eval-batch-size must be an integer or 'auto', got {eval_batch_size!r}")

        # Parse steering scales; empty entries (e.g. a trailing comma) are ignored
        scale_tokens = [token.strip() for token in args.steering_scales.split(",")]
        try:
            steering_scales = [float(token) for token in scale_tokens if token]
        except ValueError:
            parser.error(f"--steering-scales must be comma-separated numbers, got {args.steering_scales!r}")
        if not steering_scales:
            parser.error("--steering-scales must contain at least one number")

    output_path = Path(args.output_dir) / f"{args.task}_lora_adapter"

    # Train
    train_lora_adapter(
        task=args.task,
        model_name=args.model,
        output_path=output_path,
        num_pairs=args.num_pairs,
        device=args.device,
        keep_intermediate=args.keep_intermediate,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        max_length=args.max_length,
    )

    # Evaluate base vs LoRA (and optionally LoRA + steering)
    if not args.skip_eval:
        evaluate_lora(
            model_name=args.model,
            lora_path=output_path,
            task=args.task,
            train_ratio=args.train_ratio,
            device=args.device,
            batch_size=eval_batch_size,
            max_batch_size=args.eval_max_batch_size,
            limit=args.eval_limit,
            output_dir=args.output_dir,
            # Training metadata
            num_train_pairs=args.num_pairs,
            num_epochs=args.num_epochs,
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            learning_rate=args.learning_rate,
            # Steering parameters
            with_steering=args.with_steering,
            steering_method=args.steering_method,
            steering_layers=args.steering_layers,
            steering_num_pairs=args.steering_num_pairs,
            steering_scales=steering_scales,
            extraction_strategy=args.extraction_strategy,
        )


if __name__ == "__main__":
    main()
|