wisent 0.7.901__py3-none-any.whl → 0.7.1116__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. wisent/__init__.py +1 -1
  2. wisent/comparison/__init__.py +1 -0
  3. wisent/comparison/detect_bos_features.py +275 -0
  4. wisent/comparison/fgaa.py +465 -0
  5. wisent/comparison/lora.py +663 -0
  6. wisent/comparison/lora_dpo.py +604 -0
  7. wisent/comparison/main.py +444 -0
  8. wisent/comparison/ours.py +76 -0
  9. wisent/comparison/reft.py +690 -0
  10. wisent/comparison/sae.py +304 -0
  11. wisent/comparison/utils.py +381 -0
  12. wisent/core/activations/activations_collector.py +3 -2
  13. wisent/core/activations/extraction_strategy.py +8 -4
  14. wisent/core/cli/agent/apply_steering.py +7 -5
  15. wisent/core/cli/agent/train_classifier.py +4 -3
  16. wisent/core/cli/generate_vector_from_task.py +11 -20
  17. wisent/core/cli/get_activations.py +1 -1
  18. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
  19. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
  20. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
  21. wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
  22. wisent/core/parser_arguments/get_activations_parser.py +5 -14
  23. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/METADATA +5 -1
  24. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/RECORD +28 -91
  25. wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
  26. wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
  27. wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
  28. wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
  29. wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
  30. wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
  31. wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
  32. wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
  33. wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
  34. wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
  35. wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
  36. wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
  37. wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
  38. wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
  39. wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
  40. wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
  41. wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
  42. wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
  43. wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
  44. wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
  45. wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
  46. wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
  47. wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
  48. wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
  49. wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
  50. wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
  51. wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
  52. wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
  53. wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
  54. wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
  55. wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
  56. wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
  57. wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
  58. wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
  59. wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
  60. wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
  61. wisent/examples/scripts/1/test_cola_pairs.json +0 -8
  62. wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
  63. wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
  64. wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
  65. wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
  66. wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
  67. wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
  68. wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
  69. wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
  70. wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
  71. wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
  72. wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
  73. wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
  74. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
  75. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
  76. wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
  77. wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
  78. wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
  79. wisent/examples/scripts/2/test_atis_pairs.json +0 -8
  80. wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
  81. wisent/examples/scripts/2/test_babi_pairs.json +0 -8
  82. wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
  83. wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
  84. wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
  85. wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
  86. wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
  87. wisent/examples/scripts/generate_paper_data.py +0 -384
  88. wisent/examples/scripts/intervention_validation.py +0 -626
  89. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
  90. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
  91. wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
  92. wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
  93. wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
  94. wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
  95. wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
  96. wisent/examples/scripts/threshold_analysis.py +0 -434
  97. wisent/examples/scripts/visualization_gallery.py +0 -582
  98. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/WHEEL +0 -0
  99. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/entry_points.txt +0 -0
  100. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/licenses/LICENSE +0 -0
  101. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/top_level.txt +0 -0
wisent/comparison/lora.py (new file)
@@ -0,0 +1,663 @@
+ """
+ LoRA fine-tuning method for comparison experiments.
+
+ Trains a LoRA adapter on benchmark tasks using supervised fine-tuning (SFT)
+ on positive responses from contrastive pairs.
+
+ Optionally evaluates LoRA + steering by generating a steering vector on the
+ LoRA model and combining both methods.
+ """
+
+ from __future__ import annotations
+
+ import gc
+ import json
+ import os
+ import shutil
+ import tempfile
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ import torch
+ from datasets import Dataset
+ from peft import LoraConfig, PeftModel, TaskType, get_peft_model
+ from trl import SFTTrainer, SFTConfig
+
+ from wisent.comparison.utils import (
+     generate_contrastive_pairs,
+     create_test_only_task,
+     extract_accuracy,
+     run_lm_eval_evaluation,
+     run_ll_evaluation,
+     load_model_and_tokenizer,
+     apply_steering_to_model,
+     remove_steering,
+ )
+ from wisent.core.utils.device import preferred_dtype
+
+ if TYPE_CHECKING:
+     from wisent.core.models.wisent_model import WisentModel
+
+ __all__ = ["train_lora_adapter", "evaluate_lora", "apply_lora_to_model", "remove_lora"]
+
+
+ # Default LoRA configurations per model architecture
+ LORA_TARGET_MODULES = {
+     "gemma": ["q_proj", "k_proj", "v_proj", "o_proj"],
+     "llama": ["q_proj", "k_proj", "v_proj", "o_proj"],
+     "mistral": ["q_proj", "k_proj", "v_proj", "o_proj"],
+     "phi": ["q_proj", "k_proj", "v_proj", "dense"],
+     "gpt_neo": ["q_proj", "v_proj"],
+     "gpt2": ["c_attn"],
+     "default": "all-linear",
+ }
+
+
+ def get_target_modules(model_name: str) -> str | list[str]:
+     """Get LoRA target modules based on model architecture."""
+     model_name_lower = model_name.lower()
+
+     for arch, modules in LORA_TARGET_MODULES.items():
+         if arch in model_name_lower:
+             return modules
+
+     return LORA_TARGET_MODULES["default"]
+
+
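Resolution is a substring match on the lowercased checkpoint name, with PEFT's "all-linear" shorthand as the fallback. A quick sanity sketch, assuming the module is importable as wisent.comparison.lora (the model names are illustrative, not from the package):

from wisent.comparison.lora import get_target_modules

# illustrative checkpoint names; anything missing a key falls back to the default
assert get_target_modules("meta-llama/Llama-3.1-8B-Instruct") == ["q_proj", "k_proj", "v_proj", "o_proj"]
assert get_target_modules("openai-community/gpt2") == ["c_attn"]
assert get_target_modules("Qwen/Qwen2.5-7B") == "all-linear"  # no key matches, so the default applies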
65
+ def prepare_sft_dataset(pairs: list[dict], tokenizer) -> Dataset:
66
+ """
67
+ Prepare dataset for SFT from contrastive pairs.
68
+
69
+ Uses only positive responses for training.
70
+ - Chat models: returns conversational format, SFTTrainer auto-applies chat template
71
+ - Base models: returns text format with simple formatting
72
+
73
+ Args:
74
+ pairs: List of contrastive pairs
75
+ tokenizer: Tokenizer to check for chat template support
76
+
77
+ Returns:
78
+ HuggingFace Dataset ready for SFTTrainer
79
+ """
80
+ formatted_examples = []
81
+
82
+ for pair in pairs:
83
+ prompt = pair["prompt"]
84
+ response = pair["positive_response"]["model_response"]
85
+
86
+ if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
87
+ # Chat model: use conversational format, trainer applies template
88
+ formatted_examples.append({
89
+ "messages": [
90
+ {"role": "user", "content": prompt},
91
+ {"role": "assistant", "content": response},
92
+ ]
93
+ })
94
+ else:
95
+ # Base model: use simple text format
96
+ formatted_examples.append({
97
+ "text": f"{prompt}\n{response}"
98
+ })
99
+
100
+ return Dataset.from_list(formatted_examples)
101
+
102
+
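To make the two record shapes concrete, a minimal sketch: the pair dict below is a hand-written stand-in for generate_contrastive_pairs output, and gpt2 is used because its tokenizer has no chat template, which exercises the text branch; a chat-tuned tokenizer would instead yield a "messages" record.

from transformers import AutoTokenizer
from wisent.comparison.lora import prepare_sft_dataset

pair = {
    "prompt": "Is the sky blue? Answer yes or no.",
    "positive_response": {"model_response": "yes"},
    "negative_response": {"model_response": "no"},  # present in pairs but unused by SFT
}

tok = AutoTokenizer.from_pretrained("gpt2")  # chat_template is None -> text branch
print(prepare_sft_dataset([pair], tok)[0])
# {'text': 'Is the sky blue? Answer yes or no.\nyes'}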
+ def train_lora_adapter(
+     task: str,
+     model_name: str,
+     output_path: str | Path,
+     trait_label: str = "correctness",
+     num_pairs: int = 50,
+     device: str = "cuda:0",
+     keep_intermediate: bool = False,
+     # LoRA-specific parameters
+     lora_r: int = 16,
+     lora_alpha: int = 32,
+     lora_dropout: float = 0.05,
+     learning_rate: float = 2e-4,
+     num_epochs: int = 3,
+     batch_size: int = 2,
+     max_length: int = 512,
+ ) -> Path:
+     """
+     Train a LoRA adapter using SFT on positive responses.
+
+     Args:
+         task: lm-eval task name (e.g., 'boolq', 'cb')
+         model_name: HuggingFace model name
+         output_path: Where to save the LoRA adapter
+         trait_label: Label for the trait being trained
+         num_pairs: Number of training examples to use
+         device: Device to train on
+         keep_intermediate: Whether to keep intermediate files
+         lora_r: LoRA rank
+         lora_alpha: LoRA alpha scaling factor
+         lora_dropout: LoRA dropout
+         learning_rate: Training learning rate
+         num_epochs: Number of training epochs
+         batch_size: Training batch size
+         max_length: Maximum sequence length
+
+     Returns:
+         Path to the saved LoRA adapter directory
+     """
+     output_path = Path(output_path)
+
+     # Step 1: Generate contrastive pairs
+     print(f"Step 1: Generating training data from task: {task}")
+     pairs, pairs_file = generate_contrastive_pairs(task, num_pairs)
+     print(f"  Loaded {len(pairs)} training examples")
+
+     # Step 2: Load model and tokenizer
+     print(f"\nStep 2: Loading model {model_name}...")
+     model, tokenizer = load_model_and_tokenizer(model_name, device, eval_mode=False)
+
+     # Step 3: Configure LoRA
+     print(f"\nStep 3: Configuring LoRA (r={lora_r}, alpha={lora_alpha})...")
+
+     target_modules = get_target_modules(model_name)
+     print(f"  Target modules: {target_modules}")
+
+     lora_config = LoraConfig(
+         r=lora_r,
+         lora_alpha=lora_alpha,
+         target_modules=target_modules,
+         lora_dropout=lora_dropout,
+         bias="none",
+         task_type=TaskType.CAUSAL_LM,
+     )
+
+     model = get_peft_model(model, lora_config)
+     model.print_trainable_parameters()
+
+     # Step 4: Prepare dataset
+     print("\nStep 4: Preparing SFT dataset...")
+     train_dataset = prepare_sft_dataset(pairs, tokenizer)
+     print(f"  Dataset size: {len(train_dataset)} examples")
+
+     # Step 5: Training
+     print("\nStep 5: Training LoRA adapter...")
+
+     # Create temporary directory for training outputs
+     training_output_dir = tempfile.mkdtemp(prefix="lora_training_")
+
+     # Use device-optimized dtype (bfloat16 on CUDA, float16 on MPS, float32 on CPU)
+     dtype = preferred_dtype(device)
+
+     training_args = SFTConfig(
+         output_dir=training_output_dir,
+         num_train_epochs=num_epochs,
+         per_device_train_batch_size=batch_size,
+         gradient_accumulation_steps=1,
+         learning_rate=learning_rate,
+         weight_decay=0.01,
+         warmup_ratio=0.1,
+         logging_steps=10,
+         save_strategy="no",  # don't save checkpoints
+         bf16=(dtype == torch.bfloat16),
+         fp16=(dtype == torch.float16),
+         report_to="none",  # disable wandb/tensorboard
+         max_seq_length=max_length,
+     )
+
+     trainer = SFTTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         processing_class=tokenizer,
+     )
+
+     trainer.train()
+
+     # Step 6: Save LoRA adapter
+     print(f"\nStep 6: Saving LoRA adapter to {output_path}...")
+     output_path.mkdir(parents=True, exist_ok=True)
+     model.save_pretrained(output_path)
+     tokenizer.save_pretrained(output_path)
+
+     # Save metadata
+     metadata = {
+         "method": "lora",
+         "model": model_name,
+         "task": task,
+         "trait_label": trait_label,
+         "num_pairs": len(pairs),
+         "lora_config": {
+             "r": lora_r,
+             "alpha": lora_alpha,
+             "dropout": lora_dropout,
+             "target_modules": target_modules if isinstance(target_modules, list) else [target_modules],
+         },
+         "training_config": {
+             "learning_rate": learning_rate,
+             "num_epochs": num_epochs,
+             "batch_size": batch_size,
+             "max_length": max_length,
+         },
+     }
+
+     with open(output_path / "metadata.json", "w") as f:
+         json.dump(metadata, f, indent=2)
+
+     # Cleanup
+     del model, trainer
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
+     if not keep_intermediate:
+         os.unlink(pairs_file)
+         shutil.rmtree(training_output_dir, ignore_errors=True)
+
+     print(f"\nLoRA adapter saved to {output_path}")
+     return output_path
+
+
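A minimal invocation sketch for training alone, assuming a local CUDA device; the checkpoint name and output path are illustrative, and the remaining parameters take the defaults from the signature above.

from wisent.comparison.lora import train_lora_adapter

adapter_path = train_lora_adapter(
    task="boolq",
    model_name="meta-llama/Llama-3.2-1B-Instruct",  # illustrative checkpoint
    output_path="/tmp/boolq_lora_adapter",          # illustrative path
    num_pairs=50,
    device="cuda:0",
)
# adapter_path now holds the adapter weights, tokenizer files, and metadata.json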
+ def apply_lora_to_model(wisent_model: "WisentModel", lora_path: str | Path) -> None:
+     """
+     Apply a trained LoRA adapter to a WisentModel.
+
+     Args:
+         wisent_model: WisentModel instance
+         lora_path: Path to the saved LoRA adapter
+     """
+     lora_path = Path(lora_path)
+
+     # Check if model already has adapters
+     if hasattr(wisent_model.hf_model, "peft_config"):
+         # Model already has PEFT, just load new adapter
+         wisent_model.hf_model.load_adapter(str(lora_path), adapter_name="steering")
+         wisent_model.hf_model.set_adapter("steering")
+     else:
+         # Wrap model with PEFT
+         wisent_model.hf_model = PeftModel.from_pretrained(
+             wisent_model.hf_model,
+             str(lora_path),
+             adapter_name="steering",
+         )
+
+     print(f"LoRA adapter loaded from {lora_path}")
+
+
+ def remove_lora(wisent_model: "WisentModel") -> None:
+     """
+     Remove/disable LoRA adapter from a WisentModel.
+
+     Args:
+         wisent_model: WisentModel instance with LoRA applied
+     """
+     if hasattr(wisent_model.hf_model, "disable_adapters"):
+         try:
+             wisent_model.hf_model.disable_adapters()
+             print("LoRA adapter disabled")
+         except ValueError:
+             # No adapter was loaded
+             pass
+     elif hasattr(wisent_model.hf_model, "base_model"):
+         # Unwrap the model
+         wisent_model.hf_model = wisent_model.hf_model.base_model.model
+         print("LoRA adapter removed")
+
+
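The two helpers are designed as a pair; a round-trip sketch, assuming the adapter directory produced by train_lora_adapter exists (checkpoint name and path are illustrative):

from wisent.core.models.wisent_model import WisentModel
from wisent.comparison.lora import apply_lora_to_model, remove_lora

wm = WisentModel(model_name="meta-llama/Llama-3.2-1B-Instruct", device="cuda:0")  # illustrative
apply_lora_to_model(wm, "/tmp/boolq_lora_adapter")
# ... run adapted generation or evaluation here ...
remove_lora(wm)  # subsequent runs see base-model behavior again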
+ def evaluate_lora(
+     model_name: str,
+     lora_path: str | Path,
+     task: str,
+     train_ratio: float = 0.8,
+     device: str = "cuda:0",
+     batch_size: int | str = 1,
+     max_batch_size: int = 8,
+     limit: int | None = None,
+     output_dir: str | Path | None = None,
+     # Training metadata (for output)
+     num_train_pairs: int | None = None,
+     num_epochs: int | None = None,
+     lora_r: int | None = None,
+     lora_alpha: int | None = None,
+     lora_dropout: float | None = None,
+     learning_rate: float | None = None,
+     # Steering parameters (optional)
+     with_steering: bool = False,
+     steering_method: str = "caa",
+     steering_layers: str = "12",
+     steering_num_pairs: int = 50,
+     steering_scales: list[float] | None = None,
+     extraction_strategy: str = "mc_completion",
+ ) -> dict:
+     """
+     Evaluate a trained LoRA adapter, comparing base vs LoRA performance.
+
+     Optionally also evaluates LoRA + steering at multiple scales.
+     All results are saved to a single output file.
+
+     Args:
+         model_name: HuggingFace model name
+         lora_path: Path to trained LoRA adapter
+         task: lm-eval task name
+         train_ratio: Train/test split ratio
+         device: Device to run on
+         batch_size: Batch size for evaluation (int or "auto")
+         max_batch_size: Max batch size when batch_size is "auto"
+         limit: Limit number of eval examples
+         output_dir: Where to save results
+         num_train_pairs, num_epochs, lora_r, lora_alpha, lora_dropout, learning_rate:
+             Training metadata echoed into the saved results
+         with_steering: Whether to also evaluate LoRA + steering
+         steering_method: Steering method (caa or fgaa)
+         steering_layers: Layers for steering vector
+         steering_num_pairs: Number of pairs for steering generation
+         steering_scales: List of steering scales to evaluate
+         extraction_strategy: Strategy for activation extraction
+
+     Returns:
+         Dict with evaluation results
+     """
+     from wisent.core.models.wisent_model import WisentModel
+
+     lora_path = Path(lora_path)
+
+     if steering_scales is None:
+         steering_scales = [1.0, 2.0, 4.0]
+
+     # Create test task
+     print(f"\n{'='*60}")
+     print(f"Creating test task for: {task}")
+     print(f"{'='*60}")
+
+     task_dict = create_test_only_task(task, train_ratio=train_ratio)
+
+     # Load model
+     print(f"\n{'='*60}")
+     print(f"Loading model: {model_name}")
+     print(f"{'='*60}")
+     wisent_model = WisentModel(model_name=model_name, device=device)
+
+     # BASE evaluation
+     print(f"\n{'='*60}")
+     print("Running BASE evaluation (no LoRA)")
+     print(f"{'='*60}")
+
+     base_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
+     base_acc_lm_eval = extract_accuracy(base_results, task)
+     print(f"Base accuracy (lm-eval): {base_acc_lm_eval:.4f}")
+
+     base_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
+     print(f"Base accuracy (LL): {base_acc_ll:.4f}")
+
+     # Apply LoRA
+     print(f"\n{'='*60}")
+     print(f"Applying LoRA adapter from: {lora_path}")
+     print(f"{'='*60}")
+     apply_lora_to_model(wisent_model, lora_path)
+
+     # LORA evaluation
+     print(f"\n{'='*60}")
+     print("Running LORA evaluation")
+     print(f"{'='*60}")
+
+     lora_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
+     lora_acc_lm_eval = extract_accuracy(lora_results, task)
+     print(f"LoRA accuracy (lm-eval): {lora_acc_lm_eval:.4f}")
+
+     lora_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
+     print(f"LoRA accuracy (LL): {lora_acc_ll:.4f}")
+
+     # Results dict
+     results = {
+         "task": task,
+         "model": model_name,
+         "lora_path": str(lora_path),
+         # Training config
+         "num_train_pairs": num_train_pairs,
+         "num_epochs": num_epochs,
+         "lora_r": lora_r,
+         "lora_alpha": lora_alpha,
+         "lora_dropout": lora_dropout,
+         "learning_rate": learning_rate,
+         # Eval config
+         "train_ratio": train_ratio,
+         "eval_limit": limit,
+         # Results
+         "base_accuracy_lm_eval": base_acc_lm_eval,
+         "base_accuracy_ll": base_acc_ll,
+         "lora_accuracy_lm_eval": lora_acc_lm_eval,
+         "lora_accuracy_ll": lora_acc_ll,
+         "lora_diff_lm_eval": lora_acc_lm_eval - base_acc_lm_eval,
+         "lora_diff_ll": lora_acc_ll - base_acc_ll,
+     }
+
+     # LoRA + Steering evaluation (if enabled)
+     if with_steering:
+         from wisent.core.trainers.steering_trainer import WisentSteeringTrainer
+         from wisent.core.steering_methods import get_steering_method
+         from wisent.core.activations.extraction_strategy import ExtractionStrategy
+         from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
+         from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+         from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse
+
+         # Generate contrastive pairs for steering
+         print(f"\n{'='*60}")
+         print(f"Generating {steering_num_pairs} contrastive pairs for steering")
+         print(f"{'='*60}")
+         pairs_data, pairs_file = generate_contrastive_pairs(task, steering_num_pairs)
+
+         # Convert to ContrastivePairSet
+         pairs = []
+         for p in pairs_data:
+             pair = ContrastivePair(
+                 prompt=p["prompt"],
+                 positive_response=PositiveResponse(model_response=p["positive_response"]["model_response"]),
+                 negative_response=NegativeResponse(model_response=p["negative_response"]["model_response"]),
+             )
+             pairs.append(pair)
+         pair_set = ContrastivePairSet(pairs=pairs, name=f"{task}_lora_steering")
+         print(f"Created {len(pair_set)} contrastive pairs")
+
+         # Generate steering vector on LoRA model
+         print(f"\n{'='*60}")
+         print(f"Generating {steering_method.upper()} steering vector on LoRA model")
+         print(f"Layers: {steering_layers}")
+         print(f"{'='*60}")
+
+         steering_method_obj = get_steering_method(steering_method, device=device)
+         strategy = ExtractionStrategy(extraction_strategy)
+
+         trainer = WisentSteeringTrainer(
+             model=wisent_model,
+             pair_set=pair_set,
+             steering_method=steering_method_obj,
+         )
+
+         result = trainer.run(
+             layers_spec=steering_layers,
+             strategy=strategy,
+             accept_low_quality_vector=True,
+         )
+
+         # Convert to dict format for apply_steering_to_model
+         steering_vectors = {}
+         for layer_name, tensor in result.steered_vectors.to_dict().items():
+             if tensor is not None:
+                 steering_vectors[layer_name] = tensor.cpu().float().tolist()
+
+         steering_data = {
+             "steering_vectors": steering_vectors,
+             "layers": list(steering_vectors.keys()),
+         }
+
+         # Cleanup temp file
+         os.unlink(pairs_file)
+
+         # Add steering info to results
+         results["steering"] = {
+             "method": steering_method,
+             "layers": list(steering_vectors.keys()),
+             "num_pairs": steering_num_pairs,
+             "extraction_strategy": extraction_strategy,
+             "scales": {},
+         }
+
+         # Evaluate at each scale
+         for scale in steering_scales:
+             print(f"\n{'='*60}")
+             print(f"Evaluating LoRA+{steering_method.upper()} at scale={scale}")
+             print(f"{'='*60}")
+
+             apply_steering_to_model(wisent_model, steering_data, scale=scale)
+
+             steer_results = run_lm_eval_evaluation(wisent_model, task_dict, task, batch_size, max_batch_size, limit)
+             steer_acc_lm_eval = extract_accuracy(steer_results, task)
+             print(f"LoRA+{steering_method.upper()} accuracy (lm-eval): {steer_acc_lm_eval:.4f}")
+
+             steer_acc_ll = run_ll_evaluation(wisent_model, task_dict, task, limit)
+             print(f"LoRA+{steering_method.upper()} accuracy (LL): {steer_acc_ll:.4f}")
+
+             remove_steering(wisent_model)
+
+             results["steering"]["scales"][str(scale)] = {
+                 "accuracy_lm_eval": steer_acc_lm_eval,
+                 "accuracy_ll": steer_acc_ll,
+                 "diff_from_base_lm_eval": steer_acc_lm_eval - base_acc_lm_eval,
+                 "diff_from_base_ll": steer_acc_ll - base_acc_ll,
+                 "diff_from_lora_lm_eval": steer_acc_lm_eval - lora_acc_lm_eval,
+                 "diff_from_lora_ll": steer_acc_ll - lora_acc_ll,
+             }
+
+     # Cleanup
+     remove_lora(wisent_model)
+     del wisent_model
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     # Print summary
+     print(f"\n{'='*70}")
+     print("RESULTS SUMMARY")
+     print(f"{'='*70}")
+     print(f"Task: {task}")
+     print(f"Model: {model_name}")
+     print(f"LoRA: {lora_path}")
+     print(f"{'-'*70}")
+     print(f"{'Method':<25} {'lm-eval acc':<15} {'LL acc':<15} {'Diff (lm-eval)':<15}")
+     print(f"{'-'*70}")
+     print(f"{'Base':<25} {base_acc_lm_eval:<15.4f} {base_acc_ll:<15.4f} {'':<15}")
+     print(f"{'LoRA':<25} {lora_acc_lm_eval:<15.4f} {lora_acc_ll:<15.4f} {lora_acc_lm_eval - base_acc_lm_eval:+.4f}")
+
+     if with_steering:
+         for scale, res in results["steering"]["scales"].items():
+             label = f"LoRA+{steering_method.upper()}@{scale}"
+             print(f"{label:<25} {res['accuracy_lm_eval']:<15.4f} {res['accuracy_ll']:<15.4f} {res['diff_from_base_lm_eval']:+.4f}")
+
+     print(f"{'='*70}")
+
+     # Save results
+     if output_dir:
+         output_dir = Path(output_dir)
+         model_dir_name = model_name.replace("/", "_")
+         output_dir = output_dir / model_dir_name
+         output_dir.mkdir(parents=True, exist_ok=True)
+         results_file = output_dir / f"{task}_lora_eval_results.json"
+         with open(results_file, "w") as f:
+             json.dump(results, f, indent=2)
+         print(f"\nResults saved to: {results_file}")
+
+     return results
+
+
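The returned dict mirrors the saved JSON, so downstream scripts can read it directly; a sketch of picking the best steering scale using the keys constructed above (model name and paths are illustrative, and with_steering=True requires the steering components to be available):

from wisent.comparison.lora import evaluate_lora

results = evaluate_lora(
    model_name="meta-llama/Llama-3.2-1B-Instruct",  # illustrative
    lora_path="/tmp/boolq_lora_adapter",
    task="boolq",
    with_steering=True,
    output_dir="/tmp/results",
)
best_scale, best = max(
    results["steering"]["scales"].items(),
    key=lambda kv: kv[1]["accuracy_lm_eval"],
)
print(f"best scale {best_scale}: {best['accuracy_lm_eval']:.4f} "
      f"({best['diff_from_lora_lm_eval']:+.4f} vs LoRA alone)")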
+ def main():
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Train and evaluate LoRA adapter on benchmark task")
+     parser.add_argument("--model", required=True, help="HuggingFace model name")
+     parser.add_argument("--task", default="boolq", help="lm-eval task name")
+     parser.add_argument("--output-dir", default="/home/ubuntu/output", help="Output directory")
+     parser.add_argument("--num-pairs", type=int, default=50, help="Number of training examples")
+     parser.add_argument("--device", default="cuda:0", help="Device")
+     parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
+     parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha")
+     parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
+     parser.add_argument("--learning-rate", type=float, default=2e-4, help="Learning rate")
+     parser.add_argument("--num-epochs", type=int, default=3, help="Number of epochs")
+     parser.add_argument("--batch-size", type=int, default=2, help="Training batch size")
+     parser.add_argument("--max-length", type=int, default=512, help="Max sequence length")
+     parser.add_argument("--keep-intermediate", action="store_true", help="Keep intermediate files")
+     # Eval args
+     parser.add_argument("--train-ratio", type=float, default=0.8, help="Train/test split ratio")
+     parser.add_argument("--eval-batch-size", default="auto", help="Eval batch size (int or 'auto')")
+     parser.add_argument("--eval-max-batch-size", type=int, default=64, help="Max eval batch size for auto")
+     parser.add_argument("--eval-limit", type=int, default=None, help="Limit eval examples")
+     parser.add_argument("--skip-eval", action="store_true", help="Skip evaluation after training")
+     # LoRA + Steering args
+     parser.add_argument("--with-steering", action="store_true", help="Also evaluate LoRA + steering")
+     parser.add_argument("--steering-method", default="caa", choices=["caa", "fgaa"], help="Steering method")
+     parser.add_argument("--steering-layers", default="12", help="Layers for steering vector")
+     parser.add_argument("--steering-num-pairs", type=int, default=50, help="Number of pairs for steering")
+     parser.add_argument("--steering-scales", default="1.0,2.0,4.0", help="Comma-separated steering scales")
+     parser.add_argument("--extraction-strategy", default="mc_balanced", help="Extraction strategy for steering")
+
+     args = parser.parse_args()
+
+     output_path = Path(args.output_dir) / f"{args.task}_lora_adapter"
+
+     # Train
+     train_lora_adapter(
+         task=args.task,
+         model_name=args.model,
+         output_path=output_path,
+         num_pairs=args.num_pairs,
+         device=args.device,
+         keep_intermediate=args.keep_intermediate,
+         lora_r=args.lora_r,
+         lora_alpha=args.lora_alpha,
+         lora_dropout=args.lora_dropout,
+         learning_rate=args.learning_rate,
+         num_epochs=args.num_epochs,
+         batch_size=args.batch_size,
+         max_length=args.max_length,
+     )
+
+     # Evaluate base vs LoRA (and optionally LoRA + steering)
+     if not args.skip_eval:
+         # Parse eval batch size (can be "auto" or int)
+         eval_batch_size = args.eval_batch_size
+         if eval_batch_size != "auto":
+             eval_batch_size = int(eval_batch_size)
+
+         # Parse steering scales
+         steering_scales = [float(s.strip()) for s in args.steering_scales.split(",")]
+
+         evaluate_lora(
+             model_name=args.model,
+             lora_path=output_path,
+             task=args.task,
+             train_ratio=args.train_ratio,
+             device=args.device,
+             batch_size=eval_batch_size,
+             max_batch_size=args.eval_max_batch_size,
+             limit=args.eval_limit,
+             output_dir=args.output_dir,
+             # Training metadata
+             num_train_pairs=args.num_pairs,
+             num_epochs=args.num_epochs,
+             lora_r=args.lora_r,
+             lora_alpha=args.lora_alpha,
+             lora_dropout=args.lora_dropout,
+             learning_rate=args.learning_rate,
+             # Steering parameters
+             with_steering=args.with_steering,
+             steering_method=args.steering_method,
+             steering_layers=args.steering_layers,
+             steering_num_pairs=args.steering_num_pairs,
+             steering_scales=steering_scales,
+             extraction_strategy=args.extraction_strategy,
+         )
+
+
+ if __name__ == "__main__":
+     main()
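Assuming the wheel is installed, the whole pipeline can also be run as a module, for example: python -m wisent.comparison.lora --model <hf-model> --task boolq --with-steering (add --skip-eval to stop after training).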