wisent 0.7.1045__py3-none-any.whl → 0.7.1116__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wisent/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.7.1045"
1
+ __version__ = "0.7.1116"
2
2
 
3
3
  from wisent.core.diversity_processors import (
4
4
  OpenerPenaltyProcessor,
wisent/comparison/lora.py CHANGED
@@ -62,20 +62,17 @@ def get_target_modules(model_name: str) -> str | list[str]:
62
62
  return LORA_TARGET_MODULES["default"]
63
63
 
64
64
 
65
- def prepare_sft_dataset(
66
- pairs: list[dict],
67
- tokenizer,
68
- max_length: int = 512,
69
- ) -> Dataset:
65
+ def prepare_sft_dataset(pairs: list[dict], tokenizer) -> Dataset:
70
66
  """
71
67
  Prepare dataset for SFT from contrastive pairs.
72
68
 
73
69
  Uses only positive responses for training.
70
+ - Chat models: returns conversational format, SFTTrainer auto-applies chat template
71
+ - Base models: returns text format with simple formatting
74
72
 
75
73
  Args:
76
74
  pairs: List of contrastive pairs
77
- tokenizer: Tokenizer for formatting
78
- max_length: Maximum sequence length
75
+ tokenizer: Tokenizer to check for chat template support
79
76
 
80
77
  Returns:
81
78
  HuggingFace Dataset ready for SFTTrainer
@@ -84,24 +81,21 @@ def prepare_sft_dataset(
84
81
 
85
82
  for pair in pairs:
86
83
  prompt = pair["prompt"]
87
- positive_response = pair["positive_response"]["model_response"]
88
-
89
- # Format as chat if tokenizer supports it, otherwise simple format
90
- if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
91
- messages = [
92
- {"role": "user", "content": prompt},
93
- {"role": "assistant", "content": positive_response},
94
- ]
95
- text = tokenizer.apply_chat_template(
96
- messages,
97
- tokenize=False,
98
- add_generation_prompt=False,
99
- )
84
+ response = pair["positive_response"]["model_response"]
85
+
86
+ if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
87
+ # Chat model: use conversational format, trainer applies template
88
+ formatted_examples.append({
89
+ "messages": [
90
+ {"role": "user", "content": prompt},
91
+ {"role": "assistant", "content": response},
92
+ ]
93
+ })
100
94
  else:
101
- # Simple format for base models
102
- text = f"Q: {prompt}\nA: {positive_response}"
103
-
104
- formatted_examples.append({"text": text})
95
+ # Base model: use simple text format
96
+ formatted_examples.append({
97
+ "text": f"{prompt}\n{response}"
98
+ })
105
99
 
106
100
  return Dataset.from_list(formatted_examples)
107
101
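For readers skimming the hunk above: a minimal sketch (not part of the package) of the two row shapes the rewritten prepare_sft_dataset emits. The SimpleNamespace tokenizers are stand-ins; with a real Hugging Face tokenizer the chat_template check behaves the same way.

# Illustrative only -- mirrors the branching in prepare_sft_dataset above.
from types import SimpleNamespace

pair = {
    "prompt": "Is the sky blue?",
    "positive_response": {"model_response": "Yes, under clear daytime conditions."},
}

chat_tokenizer = SimpleNamespace(chat_template="{{ messages }}")  # chat-model stand-in
base_tokenizer = SimpleNamespace(chat_template=None)              # base-model stand-in

def format_row(pair, tokenizer):
    prompt = pair["prompt"]
    response = pair["positive_response"]["model_response"]
    if getattr(tokenizer, "chat_template", None):
        # Conversational row: the SFT trainer is expected to apply the chat template itself.
        return {"messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response},
        ]}
    # Plain-text row for base models.
    return {"text": f"{prompt}\n{response}"}

print(format_row(pair, chat_tokenizer))  # {'messages': [...]}
print(format_row(pair, base_tokenizer))  # {'text': 'Is the sky blue?\n...'}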
 
@@ -178,7 +172,7 @@ def train_lora_adapter(
178
172
 
179
173
  # Step 4: Prepare dataset
180
174
  print(f"\nStep 4: Preparing SFT dataset...")
181
- train_dataset = prepare_sft_dataset(pairs, tokenizer, max_length=max_length)
175
+ train_dataset = prepare_sft_dataset(pairs, tokenizer)
182
176
  print(f" Dataset size: {len(train_dataset)} examples")
183
177
 
184
178
  # Step 5: Training
@@ -203,7 +197,7 @@ def train_lora_adapter(
203
197
  bf16=(dtype == torch.bfloat16),
204
198
  fp16=(dtype == torch.float16),
205
199
  report_to="none", # Disable wandb/tensorboard
206
- dataset_text_field="text", # Field containing the text to train on
200
+ max_seq_length=max_length,
207
201
  )
208
202
 
209
203
  trainer = SFTTrainer(
wisent/comparison/lora_dpo.py CHANGED
@@ -35,14 +35,19 @@ if TYPE_CHECKING:
35
35
  from wisent.core.models.wisent_model import WisentModel
36
36
 
37
37
 
38
- def create_dpo_dataset(pairs: list[dict]) -> Dataset:
38
+ def create_dpo_dataset(pairs: list[dict], tokenizer) -> Dataset:
39
39
  """
40
40
  Convert contrastive pairs to DPO dataset format.
41
41
 
42
- DPO expects:
43
- - prompt: the input prompt
44
- - chosen: the preferred response
45
- - rejected: the non-preferred response
42
+ - Chat models: returns conversational format, DPOTrainer auto-applies chat template
43
+ - Base models: returns text format with simple formatting
44
+
45
+ Args:
46
+ pairs: List of contrastive pairs
47
+ tokenizer: Tokenizer to check for chat template support
48
+
49
+ Returns:
50
+ HuggingFace Dataset ready for DPOTrainer
46
51
  """
47
52
  data = {
48
53
  "prompt": [],
@@ -55,9 +60,16 @@ def create_dpo_dataset(pairs: list[dict]) -> Dataset:
55
60
  chosen = pair["positive_response"]["model_response"]
56
61
  rejected = pair["negative_response"]["model_response"]
57
62
 
58
- data["prompt"].append(prompt)
59
- data["chosen"].append(chosen)
60
- data["rejected"].append(rejected)
63
+ if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
64
+ # Chat model: use conversational format, trainer applies template
65
+ data["prompt"].append([{"role": "user", "content": prompt}])
66
+ data["chosen"].append([{"role": "assistant", "content": chosen}])
67
+ data["rejected"].append([{"role": "assistant", "content": rejected}])
68
+ else:
69
+ # Base model: use simple text format
70
+ data["prompt"].append(f"{prompt}\n")
71
+ data["chosen"].append(chosen)
72
+ data["rejected"].append(rejected)
61
73
 
62
74
  return Dataset.from_dict(data)
63
75
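Similarly, a minimal sketch (illustrative only) of the two preference-row shapes that create_dpo_dataset now produces: message lists when a chat template is present, plain strings otherwise.

# Illustrative only -- mirrors the branching in create_dpo_dataset above.
def dpo_row(prompt: str, chosen: str, rejected: str, has_chat_template: bool) -> dict:
    if has_chat_template:
        # Conversational preference row: the DPO trainer applies the chat template.
        return {
            "prompt": [{"role": "user", "content": prompt}],
            "chosen": [{"role": "assistant", "content": chosen}],
            "rejected": [{"role": "assistant", "content": rejected}],
        }
    # Plain-text preference row for base models.
    return {"prompt": f"{prompt}\n", "chosen": chosen, "rejected": rejected}

print(dpo_row("2+2=?", "4", "5", has_chat_template=True))
print(dpo_row("2+2=?", "4", "5", has_chat_template=False))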
 
@@ -112,17 +124,9 @@ def train_lora_dpo(
112
124
  pairs, pairs_file = generate_contrastive_pairs(task, num_pairs)
113
125
  print(f"Generated {len(pairs)} preference pairs")
114
126
 
115
- # Step 2: Create DPO dataset
116
- print(f"\n{'='*60}")
117
- print(f"Step 2: Creating DPO dataset")
118
- print(f"{'='*60}")
119
-
120
- dataset = create_dpo_dataset(pairs)
121
- print(f"Dataset size: {len(dataset)}")
122
-
123
- # Step 3: Load model
127
+ # Step 2: Load model
124
128
  print(f"\n{'='*60}")
125
- print(f"Step 3: Loading model {model_name}")
129
+ print(f"Step 2: Loading model {model_name}")
126
130
  print(f"{'='*60}")
127
131
 
128
132
  model, tokenizer = load_model_and_tokenizer(model_name, device, eval_mode=False)
@@ -132,6 +136,14 @@ def train_lora_dpo(
132
136
  tokenizer.pad_token = tokenizer.eos_token
133
137
  tokenizer.padding_side = "left" # DPO typically uses left padding
134
138
 
139
+ # Step 3: Create DPO dataset
140
+ print(f"\n{'='*60}")
141
+ print(f"Step 3: Creating DPO dataset")
142
+ print(f"{'='*60}")
143
+
144
+ dataset = create_dpo_dataset(pairs, tokenizer)
145
+ print(f"Dataset size: {len(dataset)}")
146
+
135
147
  # Step 4: Configure LoRA
136
148
  print(f"\n{'='*60}")
137
149
  print(f"Step 4: Configuring LoRA (r={lora_r}, alpha={lora_alpha})")
wisent/comparison/reft.py ADDED
@@ -0,0 +1,690 @@
1
+ """
2
+ ReFT (Representation Fine-Tuning) method for comparison experiments.
3
+
4
+ Trains a LoReFT intervention on benchmark tasks using supervised fine-tuning (SFT)
5
+ on positive responses from contrastive pairs.
6
+
7
+ LoReFT operates on hidden representations rather than weights, making it
8
+ 10-50x more parameter-efficient than LoRA.
9
+
10
+ Based on: "ReFT: Representation Finetuning for Language Models" (arXiv:2404.03592)
11
+ Uses pyreft library from Stanford NLP.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import gc
17
+ import json
18
+ import tempfile
19
+ from pathlib import Path
20
+ from typing import TYPE_CHECKING
21
+
22
+ import torch
23
+ from datasets import Dataset
24
+
25
+ from wisent.comparison.utils import (
26
+ generate_contrastive_pairs,
27
+ create_test_only_task,
28
+ extract_accuracy,
29
+ run_lm_eval_evaluation,
30
+ run_ll_evaluation,
31
+ load_model_and_tokenizer,
32
+ apply_steering_to_model,
33
+ remove_steering,
34
+ )
35
+ from wisent.core.utils.device import preferred_dtype
36
+
37
+ if TYPE_CHECKING:
38
+ from wisent.core.models.wisent_model import WisentModel
39
+
40
+ __all__ = ["train_reft_adapter", "evaluate_reft", "apply_reft_to_model", "remove_reft"]
41
+
42
+
43
+ # Default intervention layers per model (middle layer)
44
+ DEFAULT_INTERVENTION_LAYERS = {
45
+ "gemma": 21, # gemma-2-9b has 42 layers
46
+ "llama": 16, # llama-3.1-8b has 32 layers
47
+ "mistral": 16, # mistral-7b has 32 layers
48
+ "phi": 16,
49
+ "default": 12,
50
+ }
51
+
52
+
53
+ def get_default_layer(model_name: str) -> int:
54
+ """Get default intervention layer based on model architecture."""
55
+ model_name_lower = model_name.lower()
56
+
57
+ for arch, layer in DEFAULT_INTERVENTION_LAYERS.items():
58
+ if arch in model_name_lower:
59
+ return layer
60
+
61
+ return DEFAULT_INTERVENTION_LAYERS["default"]
62
+
63
+
64
+ def prepare_reft_dataset(
65
+ pairs: list[dict],
66
+ tokenizer,
67
+ max_length: int = 512,
68
+ ) -> tuple[list[str], list[str]]:
69
+ """
70
+ Prepare dataset for ReFT training from contrastive pairs.
71
+
72
+ Uses only positive responses for training.
73
+
74
+ Args:
75
+ pairs: List of contrastive pairs
76
+ tokenizer: Tokenizer for formatting
77
+ max_length: Maximum sequence length
78
+
79
+ Returns:
80
+ Tuple of (prompts, responses) lists
81
+ """
82
+ prompts = []
83
+ responses = []
84
+
85
+ for pair in pairs:
86
+ prompt = pair["prompt"]
87
+ positive_response = pair["positive_response"]["model_response"]
88
+
89
+ # Format as chat if tokenizer supports it
90
+ if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
91
+ # For chat models, format as conversation
92
+ messages = [{"role": "user", "content": prompt}]
93
+ formatted_prompt = tokenizer.apply_chat_template(
94
+ messages,
95
+ tokenize=False,
96
+ add_generation_prompt=True,
97
+ )
98
+ else:
99
+ # Simple format for base models
100
+ formatted_prompt = f"{prompt}\n"
101
+
102
+ prompts.append(formatted_prompt)
103
+ responses.append(positive_response)
104
+
105
+ return prompts, responses
106
+
107
+
108
+ def train_reft_adapter(
109
+ task: str,
110
+ model_name: str,
111
+ output_path: str | Path,
112
+ trait_label: str = "correctness",
113
+ num_pairs: int = 50,
114
+ device: str = "cuda:0",
115
+ keep_intermediate: bool = False,
116
+ # ReFT-specific parameters
117
+ low_rank_dimension: int = 4,
118
+ intervention_layers: str | None = None,
119
+ learning_rate: float = 5e-4,
120
+ num_epochs: int = 3,
121
+ batch_size: int = 2,
122
+ max_length: int = 512,
123
+ ) -> Path:
124
+ """
125
+ Train a LoReFT intervention using SFT on positive responses.
126
+
127
+ Args:
128
+ task: lm-eval task name (e.g., 'boolq', 'cb')
129
+ model_name: HuggingFace model name
130
+ output_path: Where to save the ReFT intervention
131
+ trait_label: Label for the trait being trained
132
+ num_pairs: Number of training examples to use
133
+ device: Device to train on
134
+ keep_intermediate: Whether to keep intermediate files
135
+ low_rank_dimension: Rank for LoReFT (default: 4, very small!)
136
+ intervention_layers: Comma-separated layers or None for default
137
+ learning_rate: Training learning rate
138
+ num_epochs: Number of training epochs
139
+ batch_size: Training batch size
140
+ max_length: Maximum sequence length
141
+
142
+ Returns:
143
+ Path to the saved ReFT intervention directory
144
+ """
145
+ import transformers
146
+ import pyreft
147
+
148
+ output_path = Path(output_path)
149
+
150
+ # Step 1: Generate contrastive pairs
151
+ print(f"Step 1: Generating training data from task: {task}")
152
+ pairs, pairs_file = generate_contrastive_pairs(task, num_pairs)
153
+ print(f" Loaded {len(pairs)} training examples")
154
+
155
+ # Step 2: Load model and tokenizer
156
+ print(f"\nStep 2: Loading model {model_name}...")
157
+ dtype = preferred_dtype(device)
158
+
159
+ model = transformers.AutoModelForCausalLM.from_pretrained(
160
+ model_name,
161
+ torch_dtype=dtype,
162
+ device_map=device,
163
+ trust_remote_code=True,
164
+ )
165
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
166
+ model_name,
167
+ trust_remote_code=True,
168
+ )
169
+ if tokenizer.pad_token is None:
170
+ tokenizer.pad_token = tokenizer.eos_token
171
+
172
+ # Step 3: Parse intervention layers
173
+ if intervention_layers is None:
174
+ layer_indices = [get_default_layer(model_name)]
175
+ else:
176
+ layer_indices = [int(l.strip()) for l in intervention_layers.split(",")]
177
+
178
+ print(f"\nStep 3: Configuring LoReFT (rank={low_rank_dimension}, layers={layer_indices})...")
179
+
180
+ # Step 4: Create ReFT config and model
181
+ # Get hidden size from model config
182
+ hidden_size = model.config.hidden_size
183
+
184
+ # Create interventions for each layer
185
+ representations = []
186
+ for layer_idx in layer_indices:
187
+ representations.append({
188
+ "layer": layer_idx,
189
+ "component": "block_output",
190
+ "low_rank_dimension": low_rank_dimension,
191
+ "intervention": pyreft.LoreftIntervention(
192
+ embed_dim=hidden_size,
193
+ low_rank_dimension=low_rank_dimension,
194
+ ),
195
+ })
196
+
197
+ reft_config = pyreft.ReftConfig(representations=representations)
198
+ reft_model = pyreft.get_reft_model(model, reft_config)
199
+ reft_model.set_device(device)
200
+ reft_model.print_trainable_parameters()
201
+
202
+ # Step 5: Prepare dataset
203
+ print(f"\nStep 5: Preparing ReFT dataset...")
204
+ prompts, responses = prepare_reft_dataset(pairs, tokenizer, max_length=max_length)
205
+ print(f" Dataset size: {len(prompts)} examples")
206
+
207
+ # Create data module for ReFT training
208
+ # ReFT expects data in specific format with intervention positions
209
+ data_module = pyreft.make_last_position_supervised_data_module(
210
+ tokenizer=tokenizer,
211
+ model=model,
212
+ inputs=prompts,
213
+ outputs=responses,
214
+ max_length=max_length,
215
+ )
216
+
217
+ # Step 6: Training
218
+ print(f"\nStep 6: Training LoReFT intervention...")
219
+
220
+ training_output_dir = tempfile.mkdtemp(prefix="reft_training_")
221
+
222
+ training_args = transformers.TrainingArguments(
223
+ output_dir=training_output_dir,
224
+ num_train_epochs=num_epochs,
225
+ per_device_train_batch_size=batch_size,
226
+ gradient_accumulation_steps=1,
227
+ learning_rate=learning_rate,
228
+ weight_decay=0.01,
229
+ warmup_ratio=0.1,
230
+ logging_steps=10,
231
+ save_strategy="no",
232
+ bf16=(dtype == torch.bfloat16),
233
+ fp16=(dtype == torch.float16),
234
+ report_to="none",
235
+ )
236
+
237
+ trainer = pyreft.ReftTrainerForCausalLM(
238
+ model=reft_model,
239
+ tokenizer=tokenizer,
240
+ args=training_args,
241
+ **data_module,
242
+ )
243
+
244
+ trainer.train()
245
+
246
+ # Step 7: Save ReFT intervention
247
+ print(f"\nStep 7: Saving ReFT intervention to {output_path}...")
248
+ output_path.mkdir(parents=True, exist_ok=True)
249
+ reft_model.save_pretrained(output_path)
250
+ tokenizer.save_pretrained(output_path)
251
+
252
+ # Save metadata
253
+ metadata = {
254
+ "method": "reft",
255
+ "model": model_name,
256
+ "task": task,
257
+ "trait_label": trait_label,
258
+ "num_pairs": len(pairs),
259
+ "reft_config": {
260
+ "low_rank_dimension": low_rank_dimension,
261
+ "intervention_layers": layer_indices,
262
+ "component": "block_output",
263
+ },
264
+ "training_config": {
265
+ "learning_rate": learning_rate,
266
+ "num_epochs": num_epochs,
267
+ "batch_size": batch_size,
268
+ "max_length": max_length,
269
+ },
270
+ }
271
+
272
+ with open(output_path / "metadata.json", "w") as f:
273
+ json.dump(metadata, f, indent=2)
274
+
275
+ # Cleanup
276
+ del reft_model, trainer, model
277
+ gc.collect()
278
+ if torch.cuda.is_available():
279
+ torch.cuda.empty_cache()
280
+ torch.cuda.synchronize()
281
+
282
+ if not keep_intermediate:
283
+ import os
284
+ os.unlink(pairs_file)
285
+ import shutil
286
+ shutil.rmtree(training_output_dir, ignore_errors=True)
287
+
288
+ print(f"\nReFT intervention saved to {output_path}")
289
+ return output_path
290
+
291
+
292
+ def apply_reft_to_model(wisent_model: "WisentModel", reft_path: str | Path) -> None:
293
+ """
294
+ Apply a trained ReFT intervention to a WisentModel.
295
+
296
+ Args:
297
+ wisent_model: WisentModel instance
298
+ reft_path: Path to the saved ReFT intervention
299
+ """
300
+ import pyreft
301
+
302
+ reft_path = Path(reft_path)
303
+
304
+ # Load ReFT model wrapping the existing model
305
+ reft_model = pyreft.ReftModel.load(
306
+ str(reft_path),
307
+ wisent_model.hf_model,
308
+ )
309
+ reft_model.set_device(wisent_model.device)
310
+
311
+ # Store original model and replace with ReFT model
312
+ wisent_model._original_model = wisent_model.hf_model
313
+ wisent_model.hf_model = reft_model
314
+
315
+ print(f"ReFT intervention loaded from {reft_path}")
316
+
317
+
318
+ def remove_reft(wisent_model: "WisentModel") -> None:
319
+ """
320
+ Remove/disable ReFT intervention from a WisentModel.
321
+
322
+ Args:
323
+ wisent_model: WisentModel instance with ReFT applied
324
+ """
325
+ if hasattr(wisent_model, '_original_model'):
326
+ wisent_model.hf_model = wisent_model._original_model
327
+ del wisent_model._original_model
328
+ print("ReFT intervention removed")
329
+ else:
330
+ print("No ReFT intervention to remove")
331
+
332
+
333
+ def evaluate_reft(
334
+ model_name: str,
335
+ reft_path: str | Path,
336
+ task: str,
337
+ train_ratio: float = 0.8,
338
+ device: str = "cuda:0",
339
+ batch_size: int = 1,
340
+ max_batch_size: int = 8,
341
+ limit: int | None = None,
342
+ output_dir: str | Path = None,
343
+ # Training metadata (for output)
344
+ num_train_pairs: int | None = None,
345
+ num_epochs: int | None = None,
346
+ low_rank_dimension: int | None = None,
347
+ intervention_layers: list[int] | None = None,
348
+ learning_rate: float | None = None,
349
+ # Steering parameters (optional)
350
+ with_steering: bool = False,
351
+ steering_method: str = "caa",
352
+ steering_layers: str = "12",
353
+ steering_num_pairs: int = 50,
354
+ steering_scales: list[float] | None = None,
355
+ extraction_strategy: str = "mc_completion",
356
+ ) -> dict:
357
+ """
358
+ Evaluate a trained ReFT intervention comparing base vs ReFT performance.
359
+
360
+ Optionally also evaluates ReFT + steering at multiple scales.
361
+
362
+ Args:
363
+ model_name: HuggingFace model name
364
+ reft_path: Path to trained ReFT intervention
365
+ task: lm-eval task name
366
+ train_ratio: Train/test split ratio
367
+ device: Device to run on
368
+ batch_size: Batch size for evaluation
369
+ max_batch_size: Max batch size
370
+ limit: Limit number of eval examples
371
+ output_dir: Where to save results
372
+ with_steering: Whether to also evaluate ReFT + steering
373
+ steering_method: Steering method (caa or fgaa)
374
+ steering_layers: Layers for steering vector
375
+ steering_num_pairs: Number of pairs for steering generation
376
+ steering_scales: List of steering scales to evaluate
377
+ extraction_strategy: Strategy for activation extraction
378
+
379
+ Returns:
380
+ Dict with evaluation results
381
+ """
382
+ import pyreft
383
+ import transformers
384
+
385
+ from wisent.core.models.wisent_model import WisentModel
386
+
387
+ reft_path = Path(reft_path)
388
+
389
+ if steering_scales is None:
390
+ steering_scales = [1.0, 2.0, 4.0]
391
+
392
+ # Create test task
393
+ print(f"\n{'='*60}")
394
+ print(f"Creating test task for: {task}")
395
+ print(f"{'='*60}")
396
+
397
+ task_dict = create_test_only_task(task, train_ratio=train_ratio)
398
+
399
+ # Load model
400
+ print(f"\n{'='*60}")
401
+ print(f"Loading model: {model_name}")
402
+ print(f"{'='*60}")
403
+ wisent_model = WisentModel(model_name=model_name, device=device)
404
+
405
+ # BASE evaluation
406
+ print(f"\n{'='*60}")
407
+ print(f"Running BASE evaluation (no ReFT)")
408
+ print(f"{'='*60}")
409
+
410
+ base_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
411
+ base_acc_lm_eval = extract_accuracy(base_results, task)
412
+ print(f"Base accuracy (lm-eval): {base_acc_lm_eval:.4f}")
413
+
414
+ base_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
415
+ print(f"Base accuracy (LL): {base_acc_ll:.4f}")
416
+
417
+ # Apply ReFT
418
+ print(f"\n{'='*60}")
419
+ print(f"Applying ReFT intervention from: {reft_path}")
420
+ print(f"{'='*60}")
421
+ apply_reft_to_model(wisent_model, reft_path)
422
+
423
+ # REFT evaluation
424
+ print(f"\n{'='*60}")
425
+ print(f"Running REFT evaluation")
426
+ print(f"{'='*60}")
427
+
428
+ reft_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
429
+ reft_acc_lm_eval = extract_accuracy(reft_results, task)
430
+ print(f"ReFT accuracy (lm-eval): {reft_acc_lm_eval:.4f}")
431
+
432
+ reft_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
433
+ print(f"ReFT accuracy (LL): {reft_acc_ll:.4f}")
434
+
435
+ # Results dict
436
+ results = {
437
+ "task": task,
438
+ "model": model_name,
439
+ "reft_path": str(reft_path),
440
+ # Training config
441
+ "num_train_pairs": num_train_pairs,
442
+ "num_epochs": num_epochs,
443
+ "low_rank_dimension": low_rank_dimension,
444
+ "intervention_layers": intervention_layers,
445
+ "learning_rate": learning_rate,
446
+ # Eval config
447
+ "train_ratio": train_ratio,
448
+ "eval_limit": limit,
449
+ # Results
450
+ "base_accuracy_lm_eval": base_acc_lm_eval,
451
+ "base_accuracy_ll": base_acc_ll,
452
+ "reft_accuracy_lm_eval": reft_acc_lm_eval,
453
+ "reft_accuracy_ll": reft_acc_ll,
454
+ "reft_diff_lm_eval": reft_acc_lm_eval - base_acc_lm_eval,
455
+ "reft_diff_ll": reft_acc_ll - base_acc_ll,
456
+ }
457
+
458
+ # ReFT + Steering evaluation (if enabled)
459
+ if with_steering:
460
+ from wisent.core.trainers.steering_trainer import WisentSteeringTrainer
461
+ from wisent.core.steering_methods import get_steering_method
462
+ from wisent.core.activations.extraction_strategy import ExtractionStrategy
463
+ from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
464
+ from wisent.core.contrastive_pairs.core.pair import ContrastivePair
465
+ from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse
466
+
467
+ # Generate contrastive pairs for steering
468
+ print(f"\n{'='*60}")
469
+ print(f"Generating {steering_num_pairs} contrastive pairs for steering")
470
+ print(f"{'='*60}")
471
+ pairs_data, pairs_file = generate_contrastive_pairs(task, steering_num_pairs)
472
+
473
+ # Convert to ContrastivePairSet
474
+ pairs = []
475
+ for p in pairs_data:
476
+ pair = ContrastivePair(
477
+ prompt=p["prompt"],
478
+ positive_response=PositiveResponse(model_response=p["positive_response"]["model_response"]),
479
+ negative_response=NegativeResponse(model_response=p["negative_response"]["model_response"]),
480
+ )
481
+ pairs.append(pair)
482
+ pair_set = ContrastivePairSet(pairs=pairs, name=f"{task}_reft_steering")
483
+ print(f"Created {len(pair_set)} contrastive pairs")
484
+
485
+ # Generate steering vector on ReFT model
486
+ print(f"\n{'='*60}")
487
+ print(f"Generating {steering_method.upper()} steering vector on ReFT model")
488
+ print(f"Layers: {steering_layers}")
489
+ print(f"{'='*60}")
490
+
491
+ steering_method_obj = get_steering_method(steering_method, device=device)
492
+ strategy = ExtractionStrategy(extraction_strategy)
493
+
494
+ trainer = WisentSteeringTrainer(
495
+ model=wisent_model,
496
+ pair_set=pair_set,
497
+ steering_method=steering_method_obj,
498
+ )
499
+
500
+ result = trainer.run(
501
+ layers_spec=steering_layers,
502
+ strategy=strategy,
503
+ accept_low_quality_vector=True,
504
+ )
505
+
506
+ # Convert to dict format for apply_steering_to_model
507
+ steering_vectors = {}
508
+ for layer_name, tensor in result.steered_vectors.to_dict().items():
509
+ if tensor is not None:
510
+ steering_vectors[layer_name] = tensor.cpu().float().tolist()
511
+
512
+ steering_data = {
513
+ "steering_vectors": steering_vectors,
514
+ "layers": list(steering_vectors.keys()),
515
+ }
516
+
517
+ # Cleanup temp file
518
+ import os
519
+ os.unlink(pairs_file)
520
+
521
+ # Add steering info to results
522
+ results["steering"] = {
523
+ "method": steering_method,
524
+ "layers": list(steering_vectors.keys()),
525
+ "num_pairs": steering_num_pairs,
526
+ "extraction_strategy": extraction_strategy,
527
+ "scales": {},
528
+ }
529
+
530
+ # Evaluate at each scale
531
+ for scale in steering_scales:
532
+ print(f"\n{'='*60}")
533
+ print(f"Evaluating ReFT+{steering_method.upper()} at scale={scale}")
534
+ print(f"{'='*60}")
535
+
536
+ apply_steering_to_model(wisent_model, steering_data, scale=scale)
537
+
538
+ steer_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
539
+ steer_acc_lm_eval = extract_accuracy(steer_results, task)
540
+ print(f"ReFT+{steering_method.upper()} accuracy (lm-eval): {steer_acc_lm_eval:.4f}")
541
+
542
+ steer_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
543
+ print(f"ReFT+{steering_method.upper()} accuracy (LL): {steer_acc_ll:.4f}")
544
+
545
+ remove_steering(wisent_model)
546
+
547
+ results["steering"]["scales"][str(scale)] = {
548
+ "accuracy_lm_eval": steer_acc_lm_eval,
549
+ "accuracy_ll": steer_acc_ll,
550
+ "diff_from_base_lm_eval": steer_acc_lm_eval - base_acc_lm_eval,
551
+ "diff_from_base_ll": steer_acc_ll - base_acc_ll,
552
+ "diff_from_reft_lm_eval": steer_acc_lm_eval - reft_acc_lm_eval,
553
+ "diff_from_reft_ll": steer_acc_ll - reft_acc_ll,
554
+ }
555
+
556
+ # Cleanup
557
+ remove_reft(wisent_model)
558
+ del wisent_model
559
+ gc.collect()
560
+ if torch.cuda.is_available():
561
+ torch.cuda.empty_cache()
562
+
563
+ # Print summary
564
+ print(f"\n{'='*70}")
565
+ print(f"RESULTS SUMMARY")
566
+ print(f"{'='*70}")
567
+ print(f"Task: {task}")
568
+ print(f"Model: {model_name}")
569
+ print(f"ReFT: {reft_path}")
570
+ print(f"{'-'*70}")
571
+ print(f"{'Method':<25} {'lm-eval acc':<15} {'LL acc':<15} {'Diff (lm-eval)':<15}")
572
+ print(f"{'-'*70}")
573
+ print(f"{'Base':<25} {base_acc_lm_eval:<15.4f} {base_acc_ll:<15.4f} {'':<15}")
574
+ print(f"{'ReFT':<25} {reft_acc_lm_eval:<15.4f} {reft_acc_ll:<15.4f} {reft_acc_lm_eval - base_acc_lm_eval:+.4f}")
575
+
576
+ if with_steering:
577
+ for scale, res in results["steering"]["scales"].items():
578
+ label = f"ReFT+{steering_method.upper()}@{scale}"
579
+ print(f"{label:<25} {res['accuracy_lm_eval']:<15.4f} {res['accuracy_ll']:<15.4f} {res['diff_from_base_lm_eval']:+.4f}")
580
+
581
+ print(f"{'='*70}")
582
+
583
+ # Save results
584
+ if output_dir:
585
+ output_dir = Path(output_dir)
586
+ model_dir_name = model_name.replace("/", "_")
587
+ output_dir = output_dir / model_dir_name
588
+ output_dir.mkdir(parents=True, exist_ok=True)
589
+ results_file = output_dir / f"{task}_reft_eval_results.json"
590
+ with open(results_file, "w") as f:
591
+ json.dump(results, f, indent=2)
592
+ print(f"\nResults saved to: {results_file}")
593
+
594
+ return results
595
+
596
+
597
+ def main():
598
+ import argparse
599
+
600
+ parser = argparse.ArgumentParser(description="Train and evaluate ReFT intervention on benchmark task")
601
+ parser.add_argument("--model", required=True, help="HuggingFace model name")
602
+ parser.add_argument("--task", default="boolq", help="lm-eval task name")
603
+ parser.add_argument("--output-dir", default="/home/ubuntu/output", help="Output directory")
604
+ parser.add_argument("--num-pairs", type=int, default=50, help="Number of training examples")
605
+ parser.add_argument("--device", default="cuda:0", help="Device")
606
+ parser.add_argument("--low-rank-dimension", type=int, default=4, help="LoReFT rank (default: 4)")
607
+ parser.add_argument("--intervention-layers", default=None, help="Comma-separated intervention layers (default: auto)")
608
+ parser.add_argument("--learning-rate", type=float, default=5e-4, help="Learning rate")
609
+ parser.add_argument("--num-epochs", type=int, default=3, help="Number of epochs")
610
+ parser.add_argument("--batch-size", type=int, default=2, help="Training batch size")
611
+ parser.add_argument("--max-length", type=int, default=512, help="Max sequence length")
612
+ parser.add_argument("--keep-intermediate", action="store_true", help="Keep intermediate files")
613
+ # Eval args
614
+ parser.add_argument("--train-ratio", type=float, default=0.8, help="Train/test split ratio")
615
+ parser.add_argument("--eval-batch-size", default="auto", help="Eval batch size (int or 'auto')")
616
+ parser.add_argument("--eval-max-batch-size", type=int, default=64, help="Max eval batch size for auto")
617
+ parser.add_argument("--eval-limit", type=int, default=None, help="Limit eval examples")
618
+ parser.add_argument("--skip-eval", action="store_true", help="Skip evaluation after training")
619
+ # ReFT + Steering args
620
+ parser.add_argument("--with-steering", action="store_true", help="Also evaluate ReFT + steering")
621
+ parser.add_argument("--steering-method", default="caa", choices=["caa", "fgaa"], help="Steering method")
622
+ parser.add_argument("--steering-layers", default="12", help="Layers for steering vector")
623
+ parser.add_argument("--steering-num-pairs", type=int, default=50, help="Number of pairs for steering")
624
+ parser.add_argument("--steering-scales", default="1.0,2.0,4.0", help="Comma-separated steering scales")
625
+ parser.add_argument("--extraction-strategy", default="mc_completion", help="Extraction strategy for steering")
626
+
627
+ args = parser.parse_args()
628
+
629
+ output_path = Path(args.output_dir) / f"{args.task}_reft_intervention"
630
+
631
+ # Parse intervention layers for metadata
632
+ if args.intervention_layers:
633
+ intervention_layers = [int(l.strip()) for l in args.intervention_layers.split(",")]
634
+ else:
635
+ intervention_layers = [get_default_layer(args.model)]
636
+
637
+ # Train
638
+ train_reft_adapter(
639
+ task=args.task,
640
+ model_name=args.model,
641
+ output_path=output_path,
642
+ num_pairs=args.num_pairs,
643
+ device=args.device,
644
+ keep_intermediate=args.keep_intermediate,
645
+ low_rank_dimension=args.low_rank_dimension,
646
+ intervention_layers=args.intervention_layers,
647
+ learning_rate=args.learning_rate,
648
+ num_epochs=args.num_epochs,
649
+ batch_size=args.batch_size,
650
+ max_length=args.max_length,
651
+ )
652
+
653
+ # Evaluate base vs ReFT (and optionally ReFT + steering)
654
+ if not args.skip_eval:
655
+ # Parse eval batch size (can be "auto" or int)
656
+ eval_batch_size = args.eval_batch_size
657
+ if eval_batch_size != "auto":
658
+ eval_batch_size = int(eval_batch_size)
659
+
660
+ # Parse steering scales
661
+ steering_scales = [float(s.strip()) for s in args.steering_scales.split(",")]
662
+
663
+ evaluate_reft(
664
+ model_name=args.model,
665
+ reft_path=output_path,
666
+ task=args.task,
667
+ train_ratio=args.train_ratio,
668
+ device=args.device,
669
+ batch_size=eval_batch_size,
670
+ max_batch_size=args.eval_max_batch_size,
671
+ limit=args.eval_limit,
672
+ output_dir=args.output_dir,
673
+ # Training metadata
674
+ num_train_pairs=args.num_pairs,
675
+ num_epochs=args.num_epochs,
676
+ low_rank_dimension=args.low_rank_dimension,
677
+ intervention_layers=intervention_layers,
678
+ learning_rate=args.learning_rate,
679
+ # Steering parameters
680
+ with_steering=args.with_steering,
681
+ steering_method=args.steering_method,
682
+ steering_layers=args.steering_layers,
683
+ steering_num_pairs=args.steering_num_pairs,
684
+ steering_scales=steering_scales,
685
+ extraction_strategy=args.extraction_strategy,
686
+ )
687
+
688
+
689
+ if __name__ == "__main__":
690
+ main()
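For orientation, a hedged usage sketch of the new wisent/comparison/reft.py module via the entry points added above. The model identifier, paths, and task are placeholders; running it assumes the optional pyreft dependency, the lm-eval task data, and a CUDA device.

# Illustrative only -- argument names are taken from train_reft_adapter and evaluate_reft above;
# the model identifier and output paths are placeholders, not recommendations.
from pathlib import Path
from wisent.comparison.reft import train_reft_adapter, evaluate_reft

out = Path("/tmp/boolq_reft_intervention")

train_reft_adapter(
    task="boolq",
    model_name="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    output_path=out,
    num_pairs=50,
    device="cuda:0",
    low_rank_dimension=4,  # LoReFT rank, module default
)

results = evaluate_reft(
    model_name="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    reft_path=out,
    task="boolq",
    device="cuda:0",
    output_dir="/tmp/reft_eval",
)
print(results["reft_diff_lm_eval"])  # accuracy change vs. the unmodified model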
{wisent-0.7.1045.dist-info → wisent-0.7.1116.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: wisent
3
- Version: 0.7.1045
3
+ Version: 0.7.1116
4
4
  Summary: Monitor and influence AI Brains
5
5
  Home-page: https://github.com/wisent-ai/wisent
6
6
  Author: Lukasz Bartoszcze and the Wisent Team
{wisent-0.7.1045.dist-info → wisent-0.7.1116.dist-info}/RECORD RENAMED
@@ -1,12 +1,13 @@
1
- wisent/__init__.py,sha256=oKtv5teWo_CHXelrihedQJtVKHc28HLqOA7-Vsehafg,1230
1
+ wisent/__init__.py,sha256=JumWUOHz2LVTGDdRwwjT7GOTdbOxeisuqgQ1tZu7KuM,1230
2
2
  wisent/cli.py,sha256=XKzGIGstr38EowHYpr821c6YuV9Eaw3I1I3NvLztTO0,3960
3
3
  wisent/comparison/__init__.py,sha256=DD_QZfE8XrEEbVTd_l6D5kjxnkOJ-BTQ-mvlu8WPmew,56
4
4
  wisent/comparison/detect_bos_features.py,sha256=T5ewM_eY1Sqic9xr30fU0nmd_ZF6Kj477G4UxNo4w5Y,9799
5
5
  wisent/comparison/fgaa.py,sha256=la1Qs8GUfKB7FGI-WgaCMc24KOVEpDnD5fNdntp3-Q4,15576
6
- wisent/comparison/lora.py,sha256=j-m1ulhu_MA3YB9N-hcitixzXsSDsskL1kIjbzF0uRo,23702
7
- wisent/comparison/lora_dpo.py,sha256=zG3MB77kAC_vIo30OjwKlmC0zqcg_lnteSqauW4usBY,20885
6
+ wisent/comparison/lora.py,sha256=-p0C2jMpQbbitLI2at8qvW08mIhK_baAT1fziY2jnbM,23609
7
+ wisent/comparison/lora_dpo.py,sha256=8mAV114g-2lN22HljSJ6RC34cjTrYZ0tjuY4fA8FmzQ,21577
8
8
  wisent/comparison/main.py,sha256=7jWBXPfvLszDHcWHdCO4hV7v_jB8B9UfjE545pMyf4w,17625
9
9
  wisent/comparison/ours.py,sha256=aMwd4v5Gx-4fLzsA5JI-qHXDvOPBKeUk--5dpbHubfU,1951
10
+ wisent/comparison/reft.py,sha256=YdIsdSAWfxWg4hX4xeQqvWqb_BLhH_t_P0uaPt2BK5k,24511
10
11
  wisent/comparison/sae.py,sha256=3wU7NLkWm3FMlWV9dCdzc5EcpxecelizNyQh65yHE10,9663
11
12
  wisent/comparison/utils.py,sha256=7bundfls_zD1WnMjrLLbyf60WuO9nsV0hs5pPt9VvzY,10679
12
13
  wisent/core/__init__.py,sha256=x1MX4vKpKP3c2FuIHcFly-UkoZwGVnRPbzcFaxr_Jdo,1340
@@ -1050,9 +1051,9 @@ wisent/tests/nosense/__init__.py,sha256=sH3x4jRPzFM3YmQkdrwJoz-BdOQ1Bh6F95G5HWyI
1050
1051
  wisent/tests/nosense/base_nosense.py,sha256=a18dBv1378nHly7OCIuk-bCcLnubss3XXDC1ex0zCK8,2633
1051
1052
  wisent/tests/nosense/math500_nosense.py,sha256=My0dHsr4OFOiTxb_VDKmGzpoMyzAtqXlHhA0oPfaG7s,2389
1052
1053
  wisent/tests/nosense/test_robustness.py,sha256=eeKji-_ls6tx7tuXqUO4BXxFRK-giJVihENAJVOvzSs,12546
1053
- wisent-0.7.1045.dist-info/licenses/LICENSE,sha256=wy0iaw8b2tyqZAfKHib3lP3PJ9o88FDCg92oUHh3sDQ,1073
1054
- wisent-0.7.1045.dist-info/METADATA,sha256=67956g1w6g1tTTNWWXl8qtGrXSIbMqvoEbFG13qldOM,2260
1055
- wisent-0.7.1045.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
1056
- wisent-0.7.1045.dist-info/entry_points.txt,sha256=BM76j3xjtIcVZGk24iDf5w18s6SuqeOpaiAxfZhpnY8,49
1057
- wisent-0.7.1045.dist-info/top_level.txt,sha256=2Ts9Iyldnb3auIN2HBBaHPknRy7nSRDm2f6RGzYgr8A,7
1058
- wisent-0.7.1045.dist-info/RECORD,,
1054
+ wisent-0.7.1116.dist-info/licenses/LICENSE,sha256=wy0iaw8b2tyqZAfKHib3lP3PJ9o88FDCg92oUHh3sDQ,1073
1055
+ wisent-0.7.1116.dist-info/METADATA,sha256=cqzQybnciKNc9kdk5CANYsUeztRR75AX-AKOArfAN8g,2260
1056
+ wisent-0.7.1116.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
1057
+ wisent-0.7.1116.dist-info/entry_points.txt,sha256=BM76j3xjtIcVZGk24iDf5w18s6SuqeOpaiAxfZhpnY8,49
1058
+ wisent-0.7.1116.dist-info/top_level.txt,sha256=2Ts9Iyldnb3auIN2HBBaHPknRy7nSRDm2f6RGzYgr8A,7
1059
+ wisent-0.7.1116.dist-info/RECORD,,