wisent 0.7.901__py3-none-any.whl → 0.7.1116__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/comparison/__init__.py +1 -0
- wisent/comparison/detect_bos_features.py +275 -0
- wisent/comparison/fgaa.py +465 -0
- wisent/comparison/lora.py +663 -0
- wisent/comparison/lora_dpo.py +604 -0
- wisent/comparison/main.py +444 -0
- wisent/comparison/ours.py +76 -0
- wisent/comparison/reft.py +690 -0
- wisent/comparison/sae.py +304 -0
- wisent/comparison/utils.py +381 -0
- wisent/core/activations/activations_collector.py +3 -2
- wisent/core/activations/extraction_strategy.py +8 -4
- wisent/core/cli/agent/apply_steering.py +7 -5
- wisent/core/cli/agent/train_classifier.py +4 -3
- wisent/core/cli/generate_vector_from_task.py +11 -20
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
- wisent/core/parser_arguments/get_activations_parser.py +5 -14
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/METADATA +5 -1
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/RECORD +28 -91
- wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
- wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
- wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
- wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
- wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
- wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
- wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
- wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
- wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
- wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
- wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
- wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
- wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
- wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
- wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
- wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cola_pairs.json +0 -8
- wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
- wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
- wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
- wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
- wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
- wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
- wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
- wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
- wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
- wisent/examples/scripts/2/test_atis_pairs.json +0 -8
- wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babi_pairs.json +0 -8
- wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
- wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
- wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
- wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/generate_paper_data.py +0 -384
- wisent/examples/scripts/intervention_validation.py +0 -626
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
- wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
- wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
- wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
- wisent/examples/scripts/threshold_analysis.py +0 -434
- wisent/examples/scripts/visualization_gallery.py +0 -582
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/WHEEL +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,690 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ReFT (Representation Fine-Tuning) method for comparison experiments.
|
|
3
|
+
|
|
4
|
+
Trains a LoReFT intervention on benchmark tasks using supervised fine-tuning (SFT)
|
|
5
|
+
on positive responses from contrastive pairs.
|
|
6
|
+
|
|
7
|
+
LoReFT operates on hidden representations rather than weights, making it
|
|
8
|
+
10-50x more parameter-efficient than LoRA.
|
|
9
|
+
|
|
10
|
+
Based on: "ReFT: Representation Finetuning for Language Models" (arXiv:2404.03592)
|
|
11
|
+
Uses pyreft library from Stanford NLP.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import gc
|
|
17
|
+
import json
|
|
18
|
+
import tempfile
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import TYPE_CHECKING
|
|
21
|
+
|
|
22
|
+
import torch
|
|
23
|
+
from datasets import Dataset
|
|
24
|
+
|
|
25
|
+
from wisent.comparison.utils import (
|
|
26
|
+
generate_contrastive_pairs,
|
|
27
|
+
create_test_only_task,
|
|
28
|
+
extract_accuracy,
|
|
29
|
+
run_lm_eval_evaluation,
|
|
30
|
+
run_ll_evaluation,
|
|
31
|
+
load_model_and_tokenizer,
|
|
32
|
+
apply_steering_to_model,
|
|
33
|
+
remove_steering,
|
|
34
|
+
)
|
|
35
|
+
from wisent.core.utils.device import preferred_dtype
|
|
36
|
+
|
|
37
|
+
if TYPE_CHECKING:
|
|
38
|
+
from wisent.core.models.wisent_model import WisentModel
|
|
39
|
+
|
|
40
|
+
__all__ = ["train_reft_adapter", "evaluate_reft", "apply_reft_to_model", "remove_reft"]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Default intervention layers per model (middle layer)
# Keys are lowercase substrings that get_default_layer() searches for in the
# lowercased model name; entries are checked in insertion order and the first
# match wins.
DEFAULT_INTERVENTION_LAYERS = {
    "gemma": 21,  # gemma-2-9b has 42 layers
    "llama": 16,  # llama-3.1-8b has 32 layers
    "mistral": 16,  # mistral-7b has 32 layers
    "phi": 16,
    "default": 12,  # fallback returned by get_default_layer() when no key matches
}
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_default_layer(model_name: str) -> int:
    """Return the default intervention layer for ``model_name``.

    The lowercased model name is scanned for each key of
    ``DEFAULT_INTERVENTION_LAYERS`` (in insertion order); the layer of the
    first key found as a substring is returned. If no key matches, the
    ``"default"`` entry is used.
    """
    lowered = model_name.lower()
    candidates = (
        layer_idx
        for arch_key, layer_idx in DEFAULT_INTERVENTION_LAYERS.items()
        if arch_key in lowered
    )
    return next(candidates, DEFAULT_INTERVENTION_LAYERS["default"])
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def prepare_reft_dataset(
    pairs: list[dict],
    tokenizer,
    max_length: int = 512,
) -> tuple[list[str], list[str]]:
    """
    Prepare dataset for ReFT training from contrastive pairs.

    Uses only positive responses for training.

    Args:
        pairs: List of contrastive pairs. Each item must provide
            ``pair["prompt"]`` and
            ``pair["positive_response"]["model_response"]``.
        tokenizer: Tokenizer for formatting. If it exposes a callable
            ``apply_chat_template`` and a non-empty ``chat_template``, prompts
            are rendered as a single-turn user message with the generation
            prompt appended; otherwise the raw prompt plus a newline is used.
        max_length: Maximum sequence length. Accepted for interface
            compatibility; this function does not truncate anything itself.

    Returns:
        Tuple of (prompts, responses) lists, aligned by index.

    Raises:
        KeyError: If a pair lacks one of the required keys.
    """
    # Decide the formatting mode once; it does not depend on the pair.
    # getattr() keeps tokenizers that lack either attribute on the plain-text
    # path instead of raising AttributeError.
    apply_template = getattr(tokenizer, "apply_chat_template", None)
    use_chat_format = callable(apply_template) and bool(
        getattr(tokenizer, "chat_template", None)
    )

    prompts: list[str] = []
    responses: list[str] = []

    for pair in pairs:
        prompt = pair["prompt"]
        positive_response = pair["positive_response"]["model_response"]

        if use_chat_format:
            # For chat models, format as conversation
            formatted_prompt = apply_template(
                [{"role": "user", "content": prompt}],
                tokenize=False,
                add_generation_prompt=True,
            )
        else:
            # Simple format for base models
            formatted_prompt = f"{prompt}\n"

        prompts.append(formatted_prompt)
        responses.append(positive_response)

    return prompts, responses
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _parse_reft_layer_spec(spec: str | None) -> list[int]:
    """Parse a comma-separated layer spec such as ``"8, 16"`` into ints, skipping empty tokens."""
    if not spec:
        return []
    tokens = (part.strip() for part in spec.split(","))
    return [int(token) for token in tokens if token]


def train_reft_adapter(
    task: str,
    model_name: str,
    output_path: str | Path,
    trait_label: str = "correctness",
    num_pairs: int = 50,
    device: str = "cuda:0",
    keep_intermediate: bool = False,
    # ReFT-specific parameters
    low_rank_dimension: int = 4,
    intervention_layers: str | None = None,
    learning_rate: float = 5e-4,
    num_epochs: int = 3,
    batch_size: int = 2,
    max_length: int = 512,
) -> Path:
    """
    Train a LoReFT intervention using SFT on positive responses.

    Args:
        task: lm-eval task name (e.g., 'boolq', 'cb')
        model_name: HuggingFace model name
        output_path: Where to save the ReFT intervention
        trait_label: Label for the trait being trained (only recorded in
            ``metadata.json``)
        num_pairs: Number of training examples to use
        device: Device to train on
        keep_intermediate: Whether to keep intermediate files (the generated
            pairs file and the temporary trainer output directory)
        low_rank_dimension: Rank for LoReFT (default: 4, very small!)
        intervention_layers: Comma-separated layers; ``None`` or an empty
            string selects the per-architecture default layer
        learning_rate: Training learning rate
        num_epochs: Number of training epochs
        batch_size: Training batch size
        max_length: Maximum sequence length

    Returns:
        Path to the saved ReFT intervention directory

    Raises:
        ValueError: If ``intervention_layers`` contains a non-integer token.
    """
    import contextlib
    import os
    import shutil

    import transformers
    import pyreft

    output_path = Path(output_path)

    # Validate the layer spec up front so a typo fails before any temp file is
    # created or the model is loaded.
    layer_indices = _parse_reft_layer_spec(intervention_layers) or [
        get_default_layer(model_name)
    ]

    # Step 1: Generate contrastive pairs
    print(f"Step 1: Generating training data from task: {task}")
    pairs, pairs_file = generate_contrastive_pairs(task, num_pairs)
    print(f" Loaded {len(pairs)} training examples")

    # Pre-bind everything the cleanup touches so the ``finally`` block is
    # valid no matter where a failure happens.
    model = reft_model = trainer = None
    training_output_dir = None

    try:
        # Step 2: Load model and tokenizer
        print(f"\nStep 2: Loading model {model_name}...")
        dtype = preferred_dtype(device)

        model = transformers.AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map=device,
            trust_remote_code=True,
        )
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Step 3: Report the intervention layers parsed above
        print(f"\nStep 3: Configuring LoReFT (rank={low_rank_dimension}, layers={layer_indices})...")

        # Step 4: Create ReFT config and model
        # Get hidden size from model config
        hidden_size = model.config.hidden_size

        # One LoReFT intervention on the block output of each selected layer
        representations = [
            {
                "layer": layer_idx,
                "component": "block_output",
                "low_rank_dimension": low_rank_dimension,
                "intervention": pyreft.LoreftIntervention(
                    embed_dim=hidden_size,
                    low_rank_dimension=low_rank_dimension,
                ),
            }
            for layer_idx in layer_indices
        ]

        reft_config = pyreft.ReftConfig(representations=representations)
        reft_model = pyreft.get_reft_model(model, reft_config)
        reft_model.set_device(device)
        reft_model.print_trainable_parameters()

        # Step 5: Prepare dataset
        print("\nStep 5: Preparing ReFT dataset...")
        prompts, responses = prepare_reft_dataset(pairs, tokenizer, max_length=max_length)
        print(f" Dataset size: {len(prompts)} examples")

        # Create data module for ReFT training
        # ReFT expects data in specific format with intervention positions
        data_module = pyreft.make_last_position_supervised_data_module(
            tokenizer=tokenizer,
            model=model,
            inputs=prompts,
            outputs=responses,
            max_length=max_length,
        )

        # Step 6: Training
        print("\nStep 6: Training LoReFT intervention...")

        training_output_dir = tempfile.mkdtemp(prefix="reft_training_")

        training_args = transformers.TrainingArguments(
            output_dir=training_output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=1,
            learning_rate=learning_rate,
            weight_decay=0.01,
            warmup_ratio=0.1,
            logging_steps=10,
            save_strategy="no",
            # Mixed-precision flags follow the dtype the model was loaded in
            bf16=(dtype == torch.bfloat16),
            fp16=(dtype == torch.float16),
            report_to="none",
        )

        trainer = pyreft.ReftTrainerForCausalLM(
            model=reft_model,
            tokenizer=tokenizer,
            args=training_args,
            **data_module,
        )

        trainer.train()

        # Step 7: Save ReFT intervention
        print(f"\nStep 7: Saving ReFT intervention to {output_path}...")
        output_path.mkdir(parents=True, exist_ok=True)
        reft_model.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)

        # Save metadata
        metadata = {
            "method": "reft",
            "model": model_name,
            "task": task,
            "trait_label": trait_label,
            "num_pairs": len(pairs),
            "reft_config": {
                "low_rank_dimension": low_rank_dimension,
                "intervention_layers": layer_indices,
                "component": "block_output",
            },
            "training_config": {
                "learning_rate": learning_rate,
                "num_epochs": num_epochs,
                "batch_size": batch_size,
                "max_length": max_length,
            },
        }

        with open(output_path / "metadata.json", "w") as f:
            json.dump(metadata, f, indent=2)
    finally:
        # Cleanup runs on success and on failure so a crashed run does not
        # keep GPU memory or temp files around.
        del reft_model, trainer, model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        if not keep_intermediate:
            # The pairs file may already be gone; that must not mask the
            # real error (if any) propagating out of the try block.
            with contextlib.suppress(FileNotFoundError):
                os.unlink(pairs_file)
            if training_output_dir is not None:
                shutil.rmtree(training_output_dir, ignore_errors=True)

    print(f"\nReFT intervention saved to {output_path}")
    return output_path
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def apply_reft_to_model(wisent_model: "WisentModel", reft_path: str | Path) -> None:
    """
    Apply a trained ReFT intervention to a WisentModel.

    The current ``hf_model`` is stashed on ``wisent_model._original_model`` and
    replaced by the loaded ReFT model; ``remove_reft`` reverses this. Calling
    this function again while an intervention is already applied re-wraps the
    stashed base model, so the base model is never lost and interventions are
    not stacked on top of each other.

    Args:
        wisent_model: WisentModel instance
        reft_path: Path to the saved ReFT intervention
    """
    import pyreft

    reft_path = Path(reft_path)

    # If ReFT is already applied, hf_model is a ReftModel; always wrap the
    # untouched base model instead so remove_reft() can still restore it.
    base_model = getattr(wisent_model, "_original_model", wisent_model.hf_model)

    # Load ReFT model wrapping the base model
    reft_model = pyreft.ReftModel.load(
        str(reft_path),
        base_model,
    )
    reft_model.set_device(wisent_model.device)

    # Store original model and replace with ReFT model
    wisent_model._original_model = base_model
    wisent_model.hf_model = reft_model

    print(f"ReFT intervention loaded from {reft_path}")
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def remove_reft(wisent_model: "WisentModel") -> None:
    """
    Undo ``apply_reft_to_model`` on a WisentModel.

    If the model carries a stashed ``_original_model``, it is put back as
    ``hf_model`` and the stash attribute is deleted. Otherwise nothing is
    changed and a notice is printed.

    Args:
        wisent_model: WisentModel instance with ReFT applied
    """
    try:
        stashed_model = wisent_model._original_model
    except AttributeError:
        # Nothing was applied (or it was already removed).
        print("No ReFT intervention to remove")
        return

    wisent_model.hf_model = stashed_model
    del wisent_model._original_model
    print("ReFT intervention removed")
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def _print_eval_banner(*lines: str, width: int = 60) -> None:
    """Print ``lines`` framed by two '=' rules, preceded by a blank line."""
    rule = "=" * width
    print(f"\n{rule}")
    for line in lines:
        print(line)
    print(rule)


def evaluate_reft(
    model_name: str,
    reft_path: str | Path,
    task: str,
    train_ratio: float = 0.8,
    device: str = "cuda:0",
    batch_size: int | str = 1,
    max_batch_size: int = 8,
    limit: int | None = None,
    output_dir: str | Path | None = None,
    # Training metadata (for output)
    num_train_pairs: int | None = None,
    num_epochs: int | None = None,
    low_rank_dimension: int | None = None,
    intervention_layers: list[int] | None = None,
    learning_rate: float | None = None,
    # Steering parameters (optional)
    with_steering: bool = False,
    steering_method: str = "caa",
    steering_layers: str = "12",
    steering_num_pairs: int = 50,
    steering_scales: list[float] | None = None,
    extraction_strategy: str = "mc_completion",
) -> dict:
    """
    Evaluate a trained ReFT intervention comparing base vs ReFT performance.

    Optionally also evaluates ReFT + steering at multiple scales.

    Args:
        model_name: HuggingFace model name
        reft_path: Path to trained ReFT intervention
        task: lm-eval task name
        train_ratio: Train/test split ratio
        device: Device to run on
        batch_size: Batch size for evaluation; forwarded unchanged to
            ``run_lm_eval_evaluation`` (the CLI may pass the string "auto")
        max_batch_size: Max batch size
        limit: Limit number of eval examples
        output_dir: Where to save results; nothing is written when falsy
        num_train_pairs: Training metadata, copied verbatim into the results
        num_epochs: Training metadata, copied verbatim into the results
        low_rank_dimension: Training metadata, copied verbatim into the results
        intervention_layers: Training metadata, copied verbatim into the results
        learning_rate: Training metadata, copied verbatim into the results
        with_steering: Whether to also evaluate ReFT + steering
        steering_method: Steering method (caa or fgaa)
        steering_layers: Layers for steering vector
        steering_num_pairs: Number of pairs for steering generation
        steering_scales: List of steering scales to evaluate
            (``None`` means ``[1.0, 2.0, 4.0]``)
        extraction_strategy: Strategy for activation extraction

    Returns:
        Dict with evaluation results
    """
    import contextlib
    import os

    # Imported eagerly so a missing pyreft install fails before the model is
    # loaded and the base evaluation is run.
    import pyreft  # noqa: F401

    from wisent.core.models.wisent_model import WisentModel

    reft_path = Path(reft_path)

    if steering_scales is None:
        steering_scales = [1.0, 2.0, 4.0]

    # Create test task
    _print_eval_banner(f"Creating test task for: {task}")
    task_dict = create_test_only_task(task, train_ratio=train_ratio)

    # Load model
    _print_eval_banner(f"Loading model: {model_name}")
    wisent_model = WisentModel(model_name=model_name, device=device)

    # BASE evaluation
    _print_eval_banner("Running BASE evaluation (no ReFT)")

    base_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
    base_acc_lm_eval = extract_accuracy(base_results, task)
    print(f"Base accuracy (lm-eval): {base_acc_lm_eval:.4f}")

    base_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
    print(f"Base accuracy (LL): {base_acc_ll:.4f}")

    # Apply ReFT
    _print_eval_banner(f"Applying ReFT intervention from: {reft_path}")
    apply_reft_to_model(wisent_model, reft_path)

    # REFT evaluation
    _print_eval_banner("Running REFT evaluation")

    reft_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
    reft_acc_lm_eval = extract_accuracy(reft_results, task)
    print(f"ReFT accuracy (lm-eval): {reft_acc_lm_eval:.4f}")

    reft_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
    print(f"ReFT accuracy (LL): {reft_acc_ll:.4f}")

    # Results dict
    results = {
        "task": task,
        "model": model_name,
        "reft_path": str(reft_path),
        # Training config
        "num_train_pairs": num_train_pairs,
        "num_epochs": num_epochs,
        "low_rank_dimension": low_rank_dimension,
        "intervention_layers": intervention_layers,
        "learning_rate": learning_rate,
        # Eval config
        "train_ratio": train_ratio,
        "eval_limit": limit,
        # Results
        "base_accuracy_lm_eval": base_acc_lm_eval,
        "base_accuracy_ll": base_acc_ll,
        "reft_accuracy_lm_eval": reft_acc_lm_eval,
        "reft_accuracy_ll": reft_acc_ll,
        "reft_diff_lm_eval": reft_acc_lm_eval - base_acc_lm_eval,
        "reft_diff_ll": reft_acc_ll - base_acc_ll,
    }

    # ReFT + Steering evaluation (if enabled)
    if with_steering:
        from wisent.core.trainers.steering_trainer import WisentSteeringTrainer
        from wisent.core.steering_methods import get_steering_method
        from wisent.core.activations.extraction_strategy import ExtractionStrategy
        from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
        from wisent.core.contrastive_pairs.core.pair import ContrastivePair
        from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse

        # Generate contrastive pairs for steering
        _print_eval_banner(f"Generating {steering_num_pairs} contrastive pairs for steering")
        pairs_data, pairs_file = generate_contrastive_pairs(task, steering_num_pairs)

        try:
            # Convert to ContrastivePairSet
            pairs = [
                ContrastivePair(
                    prompt=p["prompt"],
                    positive_response=PositiveResponse(model_response=p["positive_response"]["model_response"]),
                    negative_response=NegativeResponse(model_response=p["negative_response"]["model_response"]),
                )
                for p in pairs_data
            ]
            pair_set = ContrastivePairSet(pairs=pairs, name=f"{task}_reft_steering")
            print(f"Created {len(pair_set)} contrastive pairs")

            # Generate steering vector on ReFT model
            _print_eval_banner(
                f"Generating {steering_method.upper()} steering vector on ReFT model",
                f"Layers: {steering_layers}",
            )

            steering_method_obj = get_steering_method(steering_method, device=device)
            strategy = ExtractionStrategy(extraction_strategy)

            trainer = WisentSteeringTrainer(
                model=wisent_model,
                pair_set=pair_set,
                steering_method=steering_method_obj,
            )

            result = trainer.run(
                layers_spec=steering_layers,
                strategy=strategy,
                accept_low_quality_vector=True,
            )
        finally:
            # Cleanup temp file, even if vector generation failed; a missing
            # file must not mask the real error.
            with contextlib.suppress(FileNotFoundError):
                os.unlink(pairs_file)

        # Convert to dict format for apply_steering_to_model
        steering_vectors = {
            layer_name: tensor.cpu().float().tolist()
            for layer_name, tensor in result.steered_vectors.to_dict().items()
            if tensor is not None
        }

        steering_data = {
            "steering_vectors": steering_vectors,
            "layers": list(steering_vectors.keys()),
        }

        # Add steering info to results
        results["steering"] = {
            "method": steering_method,
            "layers": list(steering_vectors.keys()),
            "num_pairs": steering_num_pairs,
            "extraction_strategy": extraction_strategy,
            "scales": {},
        }

        # Evaluate at each scale
        for scale in steering_scales:
            _print_eval_banner(f"Evaluating ReFT+{steering_method.upper()} at scale={scale}")

            apply_steering_to_model(wisent_model, steering_data, scale=scale)
            try:
                steer_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
                steer_acc_lm_eval = extract_accuracy(steer_results, task)
                print(f"ReFT+{steering_method.upper()} accuracy (lm-eval): {steer_acc_lm_eval:.4f}")

                steer_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
                print(f"ReFT+{steering_method.upper()} accuracy (LL): {steer_acc_ll:.4f}")
            finally:
                # Always detach steering so a failed scale does not leave the
                # model steered.
                remove_steering(wisent_model)

            results["steering"]["scales"][str(scale)] = {
                "accuracy_lm_eval": steer_acc_lm_eval,
                "accuracy_ll": steer_acc_ll,
                "diff_from_base_lm_eval": steer_acc_lm_eval - base_acc_lm_eval,
                "diff_from_base_ll": steer_acc_ll - base_acc_ll,
                "diff_from_reft_lm_eval": steer_acc_lm_eval - reft_acc_lm_eval,
                "diff_from_reft_ll": steer_acc_ll - reft_acc_ll,
            }

    # Cleanup
    remove_reft(wisent_model)
    del wisent_model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Print summary
    print(f"\n{'='*70}")
    print("RESULTS SUMMARY")
    print(f"{'='*70}")
    print(f"Task: {task}")
    print(f"Model: {model_name}")
    print(f"ReFT: {reft_path}")
    print(f"{'-'*70}")
    print(f"{'Method':<25} {'lm-eval acc':<15} {'LL acc':<15} {'Diff (lm-eval)':<15}")
    print(f"{'-'*70}")
    print(f"{'Base':<25} {base_acc_lm_eval:<15.4f} {base_acc_ll:<15.4f} {'':<15}")
    print(f"{'ReFT':<25} {reft_acc_lm_eval:<15.4f} {reft_acc_ll:<15.4f} {reft_acc_lm_eval - base_acc_lm_eval:+.4f}")

    if with_steering:
        for scale, res in results["steering"]["scales"].items():
            label = f"ReFT+{steering_method.upper()}@{scale}"
            print(f"{label:<25} {res['accuracy_lm_eval']:<15.4f} {res['accuracy_ll']:<15.4f} {res['diff_from_base_lm_eval']:+.4f}")

    print(f"{'='*70}")

    # Save results
    if output_dir:
        output_dir = Path(output_dir)
        # Model names like "org/name" would otherwise create nested dirs
        model_dir_name = model_name.replace("/", "_")
        output_dir = output_dir / model_dir_name
        output_dir.mkdir(parents=True, exist_ok=True)
        results_file = output_dir / f"{task}_reft_eval_results.json"
        with open(results_file, "w") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to: {results_file}")

    return results
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
def _split_csv_arg(raw: str) -> list[str]:
    """Split a comma-separated CLI value into stripped, non-empty tokens."""
    return [token for token in (part.strip() for part in raw.split(",")) if token]


def main():
    """CLI entry point: train a ReFT intervention, then (unless skipped) evaluate it.

    Parses the command line, trains via ``train_reft_adapter`` into
    ``<output-dir>/<task>_reft_intervention`` and, unless ``--skip-eval`` is
    given, runs ``evaluate_reft`` on the result.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Train and evaluate ReFT intervention on benchmark task")
    parser.add_argument("--model", required=True, help="HuggingFace model name")
    parser.add_argument("--task", default="boolq", help="lm-eval task name")
    # NOTE(review): this default is a machine-specific absolute path; kept
    # as-is for backward compatibility, but callers should pass --output-dir.
    parser.add_argument("--output-dir", default="/home/ubuntu/output", help="Output directory")
    parser.add_argument("--num-pairs", type=int, default=50, help="Number of training examples")
    parser.add_argument("--device", default="cuda:0", help="Device")
    parser.add_argument("--low-rank-dimension", type=int, default=4, help="LoReFT rank (default: 4)")
    parser.add_argument("--intervention-layers", default=None, help="Comma-separated intervention layers (default: auto)")
    parser.add_argument("--learning-rate", type=float, default=5e-4, help="Learning rate")
    parser.add_argument("--num-epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=2, help="Training batch size")
    parser.add_argument("--max-length", type=int, default=512, help="Max sequence length")
    parser.add_argument("--keep-intermediate", action="store_true", help="Keep intermediate files")
    # Eval args
    parser.add_argument("--train-ratio", type=float, default=0.8, help="Train/test split ratio")
    parser.add_argument("--eval-batch-size", default="auto", help="Eval batch size (int or 'auto')")
    parser.add_argument("--eval-max-batch-size", type=int, default=64, help="Max eval batch size for auto")
    parser.add_argument("--eval-limit", type=int, default=None, help="Limit eval examples")
    parser.add_argument("--skip-eval", action="store_true", help="Skip evaluation after training")
    # ReFT + Steering args
    parser.add_argument("--with-steering", action="store_true", help="Also evaluate ReFT + steering")
    parser.add_argument("--steering-method", default="caa", choices=["caa", "fgaa"], help="Steering method")
    parser.add_argument("--steering-layers", default="12", help="Layers for steering vector")
    parser.add_argument("--steering-num-pairs", type=int, default=50, help="Number of pairs for steering")
    parser.add_argument("--steering-scales", default="1.0,2.0,4.0", help="Comma-separated steering scales")
    parser.add_argument("--extraction-strategy", default="mc_completion", help="Extraction strategy for steering")

    args = parser.parse_args()

    output_path = Path(args.output_dir) / f"{args.task}_reft_intervention"

    # Parse intervention layers once, so the value used for training and the
    # value recorded in the evaluation metadata cannot disagree. An empty or
    # all-blank value means "auto" for both.
    layer_tokens = _split_csv_arg(args.intervention_layers or "")
    if layer_tokens:
        intervention_layers = [int(token) for token in layer_tokens]
        layer_spec = ",".join(layer_tokens)
    else:
        intervention_layers = [get_default_layer(args.model)]
        layer_spec = None

    # Train
    train_reft_adapter(
        task=args.task,
        model_name=args.model,
        output_path=output_path,
        num_pairs=args.num_pairs,
        device=args.device,
        keep_intermediate=args.keep_intermediate,
        low_rank_dimension=args.low_rank_dimension,
        intervention_layers=layer_spec,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        max_length=args.max_length,
    )

    # Evaluate base vs ReFT (and optionally ReFT + steering)
    if not args.skip_eval:
        # Parse eval batch size (can be "auto" or int)
        eval_batch_size = args.eval_batch_size
        if eval_batch_size != "auto":
            eval_batch_size = int(eval_batch_size)

        # Parse steering scales; an empty value falls back to evaluate_reft's
        # built-in defaults (None).
        steering_scales = [float(token) for token in _split_csv_arg(args.steering_scales)] or None

        evaluate_reft(
            model_name=args.model,
            reft_path=output_path,
            task=args.task,
            train_ratio=args.train_ratio,
            device=args.device,
            batch_size=eval_batch_size,
            max_batch_size=args.eval_max_batch_size,
            limit=args.eval_limit,
            output_dir=args.output_dir,
            # Training metadata
            num_train_pairs=args.num_pairs,
            num_epochs=args.num_epochs,
            low_rank_dimension=args.low_rank_dimension,
            intervention_layers=intervention_layers,
            learning_rate=args.learning_rate,
            # Steering parameters
            with_steering=args.with_steering,
            steering_method=args.steering_method,
            steering_layers=args.steering_layers,
            steering_num_pairs=args.steering_num_pairs,
            steering_scales=steering_scales,
            extraction_strategy=args.extraction_strategy,
        )
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
# Run the CLI only when this module is executed as a script, not on import.
if __name__ == "__main__":
    main()
|