wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.1.dist-info/METADATA +67 -0
- wisent-0.5.1.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
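The bulk of this release is the new Optuna-based steering optimization pipeline (wisent/core/optuna/steering/optuna_pipeline.py, +1738 lines), whose diff follows. For orientation, here is a minimal usage sketch assembled from the OptimizationConfig defaults and the run_optimization() entry point visible in that diff; the import path is an assumption based on the file layout above (the module body itself imports from a wisent_guard namespace), so treat it as illustrative rather than documented API.

# Illustrative sketch only - the import path below is assumed from the package
# layout listed above, not taken from documentation.
from wisent.core.optuna.steering.optuna_pipeline import OptimizationConfig, OptimizationPipeline

config = OptimizationConfig(
    model_name="realtreetune/rho-1b-sft-GSM8K",  # defaults shown in the diff below
    train_dataset="gsm8k",
    val_dataset="gsm8k",
    test_dataset="gsm8k",
    layer_search_range=(15, 20),
    n_trials=50,
)

pipeline = OptimizationPipeline(config)
results = pipeline.run_optimization()  # train -> tune on val via Optuna -> evaluate once on test
print(results["accuracy_improvement"])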
|
@@ -0,0 +1,1738 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Dataset-Agnostic Optimization Pipeline with Optuna
|
|
3
|
+
|
|
4
|
+
This script builds a reproducible pipeline that:
|
|
5
|
+
1. Trains probes and learns steering vectors on the training split
|
|
6
|
+
2. Selects the best layer, probe type, steering method, and hyperparameters on validation split via Optuna
|
|
7
|
+
3. Evaluates once on the test split with the single best configuration determined on validation
|
|
8
|
+
|
|
9
|
+
Key features:
|
|
10
|
+
- Optuna-based hyperparameter optimization with pruners
|
|
11
|
+
- Activation caching for efficiency
|
|
12
|
+
- Configurable datasets for train/val/test splits
|
|
13
|
+
- Steering evaluation with model re-forwarding
|
|
14
|
+
- Reproducibility bundle generation
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import hashlib
|
|
18
|
+
import json
|
|
19
|
+
import logging
|
|
20
|
+
import os
|
|
21
|
+
import pickle
|
|
22
|
+
from dataclasses import asdict, dataclass, field
|
|
23
|
+
from datetime import datetime
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
from typing import Any, Optional
|
|
26
|
+
|
|
27
|
+
import numpy as np
|
|
28
|
+
import optuna
|
|
29
|
+
import torch
|
|
30
|
+
from optuna.pruners import MedianPruner, SuccessiveHalvingPruner
|
|
31
|
+
from optuna.samplers import TPESampler
|
|
32
|
+
from safetensors.torch import save_file as safetensors_save
|
|
33
|
+
from tqdm import tqdm
|
|
34
|
+
|
|
35
|
+
# Optional WandB integration
|
|
36
|
+
try:
|
|
37
|
+
import wandb
|
|
38
|
+
|
|
39
|
+
WANDB_AVAILABLE = True
|
|
40
|
+
except ImportError:
|
|
41
|
+
WANDB_AVAILABLE = False
|
|
42
|
+
from wisent_guard.core.contrastive_pairs.contrastive_pair import ContrastivePair
|
|
43
|
+
from wisent_guard.core.contrastive_pairs.contrastive_pair_set import ContrastivePairSet
|
|
44
|
+
from wisent_guard.core.optuna.steering import data_utils, metrics
|
|
45
|
+
from wisent_guard.core.response import Response
|
|
46
|
+
from wisent_guard.core.steering_methods.dac import DAC
|
|
47
|
+
from wisent_guard.core.task_interface import get_task
|
|
48
|
+
from wisent_guard.core.utils.device import empty_device_cache, preferred_dtype, resolve_default_device, resolve_device
|
|
49
|
+
|
|
50
|
+
logger = logging.getLogger(__name__)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
|
|
54
|
+
class OptimizationConfig:
|
|
55
|
+
"""Configuration for dataset-agnostic optimization pipeline."""
|
|
56
|
+
|
|
57
|
+
model_name: str = "realtreetune/rho-1b-sft-GSM8K"
|
|
58
|
+
device: str = field(default_factory=resolve_default_device)
|
|
59
|
+
|
|
60
|
+
train_dataset: str = "gsm8k"
|
|
61
|
+
val_dataset: str = "gsm8k"
|
|
62
|
+
test_dataset: str = "gsm8k"
|
|
63
|
+
|
|
64
|
+
# Training configuration
|
|
65
|
+
train_limit: int = 50 # How many training samples to load
|
|
66
|
+
contrastive_pairs_limit: int = 20 # How many contrastive pairs to extract for steering training
|
|
67
|
+
|
|
68
|
+
# Evaluation configuration
|
|
69
|
+
val_limit: int = 50 # How many validation samples to load
|
|
70
|
+
test_limit: int = 100 # How many test samples to load
|
|
71
|
+
|
|
72
|
+
layer_search_range: tuple[int, int] = (15, 20)
|
|
73
|
+
probe_type: str = "logistic_regression" # Fixed probe type
|
|
74
|
+
steering_methods: list[str] = field(default_factory=lambda: ["dac", "caa"]) # TODO add more
|
|
75
|
+
|
|
76
|
+
# Optuna study configuration
|
|
77
|
+
study_name: str = "optimization_pipeline"
|
|
78
|
+
db_url: str = field(
|
|
79
|
+
default_factory=lambda: f"sqlite:///{os.path.dirname(os.path.dirname(__file__))}/optuna_studies.db"
|
|
80
|
+
)
|
|
81
|
+
n_trials: int = 50
|
|
82
|
+
n_startup_trials: int = 10 # Random exploration before TPE kicks in
|
|
83
|
+
sampler: str = "TPE"
|
|
84
|
+
pruner: str = "MedianPruner"
|
|
85
|
+
|
|
86
|
+
# WandB configuration
|
|
87
|
+
wandb_project: str = "wisent-guard-optimization"
|
|
88
|
+
use_wandb: bool = False # TODO
|
|
89
|
+
|
|
90
|
+
batch_size: int = 8
|
|
91
|
+
max_length: int = 512
|
|
92
|
+
max_new_tokens: int = 256
|
|
93
|
+
seed: int = 42
|
|
94
|
+
|
|
95
|
+
temperature: float = 0.0
|
|
96
|
+
do_sample: bool = False
|
|
97
|
+
|
|
98
|
+
output_dir: str = "outputs/optimization_pipeline"
|
|
99
|
+
cache_dir: str = "cache/optimization_pipeline"
|
|
100
|
+
|
|
101
|
+
max_layers_to_search: int = 6
|
|
102
|
+
early_stopping_patience: int = 10
|
|
103
|
+
|
|
104
|
+
def to_dict(self) -> dict[str, Any]:
|
|
105
|
+
"""Convert to dictionary for serialization."""
|
|
106
|
+
return asdict(self)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class ActivationCache:
|
|
110
|
+
"""Efficient activation caching system with proper cache keys."""
|
|
111
|
+
|
|
112
|
+
def __init__(self, cache_dir: str):
|
|
113
|
+
self.cache_dir = Path(cache_dir)
|
|
114
|
+
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
|
115
|
+
self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
|
|
116
|
+
|
|
117
|
+
def _generate_cache_key(
|
|
118
|
+
self, split: str, layer_id: int, tokenization_config: dict[str, Any], prompt_variant: str = "default"
|
|
119
|
+
) -> str:
|
|
120
|
+
"""Generate unique cache key for activations."""
|
|
121
|
+
config_str = json.dumps(tokenization_config, sort_keys=True)
|
|
122
|
+
key_data = f"{split}_{layer_id}_{config_str}_{prompt_variant}"
|
|
123
|
+
return hashlib.md5(key_data.encode()).hexdigest()
|
|
124
|
+
|
|
125
|
+
def _get_cache_path(self, cache_key: str) -> Path:
|
|
126
|
+
"""Get cache file path for key."""
|
|
127
|
+
return self.cache_dir / f"activations_{cache_key}.pkl"
|
|
128
|
+
|
|
129
|
+
def has_cached_activations(
|
|
130
|
+
self, split: str, layer_id: int, tokenization_config: dict[str, Any], prompt_variant: str = "default"
|
|
131
|
+
) -> bool:
|
|
132
|
+
"""Check if activations are cached."""
|
|
133
|
+
cache_key = self._generate_cache_key(split, layer_id, tokenization_config, prompt_variant)
|
|
134
|
+
return self._get_cache_path(cache_key).exists()
|
|
135
|
+
|
|
136
|
+
def save_activations(
|
|
137
|
+
self,
|
|
138
|
+
activations: np.ndarray,
|
|
139
|
+
labels: np.ndarray,
|
|
140
|
+
split: str,
|
|
141
|
+
layer_id: int,
|
|
142
|
+
tokenization_config: dict[str, Any],
|
|
143
|
+
prompt_variant: str = "default",
|
|
144
|
+
):
|
|
145
|
+
"""Save activations to cache."""
|
|
146
|
+
cache_key = self._generate_cache_key(split, layer_id, tokenization_config, prompt_variant)
|
|
147
|
+
cache_path = self._get_cache_path(cache_key)
|
|
148
|
+
|
|
149
|
+
cache_data = {
|
|
150
|
+
"activations": activations,
|
|
151
|
+
"labels": labels,
|
|
152
|
+
"metadata": {
|
|
153
|
+
"split": split,
|
|
154
|
+
"layer_id": layer_id,
|
|
155
|
+
"tokenization_config": tokenization_config,
|
|
156
|
+
"prompt_variant": prompt_variant,
|
|
157
|
+
"timestamp": datetime.now().isoformat(),
|
|
158
|
+
"shape": activations.shape,
|
|
159
|
+
},
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
with open(cache_path, "wb") as f:
|
|
163
|
+
pickle.dump(cache_data, f)
|
|
164
|
+
|
|
165
|
+
self.logger.info(f"Cached activations for {split} layer {layer_id}: {activations.shape}")
|
|
166
|
+
|
|
167
|
+
def load_activations(
|
|
168
|
+
self, split: str, layer_id: int, tokenization_config: dict[str, Any], prompt_variant: str = "default"
|
|
169
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
170
|
+
"""Load activations from cache."""
|
|
171
|
+
cache_key = self._generate_cache_key(split, layer_id, tokenization_config, prompt_variant)
|
|
172
|
+
cache_path = self._get_cache_path(cache_key)
|
|
173
|
+
|
|
174
|
+
if not cache_path.exists():
|
|
175
|
+
raise FileNotFoundError(f"No cached activations found for key: {cache_key}")
|
|
176
|
+
|
|
177
|
+
with open(cache_path, "rb") as f:
|
|
178
|
+
cache_data = pickle.load(f)
|
|
179
|
+
|
|
180
|
+
self.logger.info(f"Loaded cached activations for {split} layer {layer_id}: {cache_data['activations'].shape}")
|
|
181
|
+
return cache_data["activations"], cache_data["labels"]
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class OptimizationPipeline:
|
|
185
|
+
"""Main optimization pipeline using Optuna for hyperparameter search."""
|
|
186
|
+
|
|
187
|
+
def __init__(self, config: OptimizationConfig):
|
|
188
|
+
self.config = config
|
|
189
|
+
self.device = resolve_device(config.device)
|
|
190
|
+
self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
|
|
191
|
+
|
|
192
|
+
# Setup output directories
|
|
193
|
+
self.output_dir = Path(config.output_dir)
|
|
194
|
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
|
195
|
+
|
|
196
|
+
# Initialize cache
|
|
197
|
+
self.cache = ActivationCache(config.cache_dir)
|
|
198
|
+
|
|
199
|
+
# Initialize WandB if configured
|
|
200
|
+
self.wandb_run = None
|
|
201
|
+
if config.use_wandb:
|
|
202
|
+
if not WANDB_AVAILABLE:
|
|
203
|
+
raise ImportError(
|
|
204
|
+
"WandB integration enabled but wandb is not installed. Install with: pip install wandb"
|
|
205
|
+
)
|
|
206
|
+
self._init_wandb()
|
|
207
|
+
|
|
208
|
+
self.model = None
|
|
209
|
+
self.tokenizer = None
|
|
210
|
+
self.train_samples = None
|
|
211
|
+
self.val_samples = None
|
|
212
|
+
self.test_samples = None
|
|
213
|
+
# Store task documents for BigCode evaluation
|
|
214
|
+
self.train_task_docs = None
|
|
215
|
+
self.val_task_docs = None
|
|
216
|
+
self.test_task_docs = None
|
|
217
|
+
self.tokenization_config = {
|
|
218
|
+
"max_length": config.max_length,
|
|
219
|
+
"padding": True,
|
|
220
|
+
"truncation": True,
|
|
221
|
+
"return_tensors": "pt",
|
|
222
|
+
}
|
|
223
|
+
|
|
224
|
+
@property
|
|
225
|
+
def is_coding_task(self) -> bool:
|
|
226
|
+
"""Check if the current task requires code execution evaluation."""
|
|
227
|
+
from ...parameters.task_config import CODING_TASKS
|
|
228
|
+
from ..bigcode_integration import is_bigcode_task
|
|
229
|
+
|
|
230
|
+
val_dataset = getattr(self.config, "val_dataset", None)
|
|
231
|
+
if not val_dataset:
|
|
232
|
+
return False
|
|
233
|
+
|
|
234
|
+
return val_dataset.lower() in CODING_TASKS or is_bigcode_task(val_dataset)
|
|
235
|
+
|
|
236
|
+
def run_optimization(self) -> dict[str, Any]:
|
|
237
|
+
"""Run the complete optimization pipeline."""
|
|
238
|
+
self.logger.info("=" * 80)
|
|
239
|
+
self.logger.info("🚀 STARTING OPTIMIZATION PIPELINE WITH OPTUNA")
|
|
240
|
+
self.logger.info("=" * 80)
|
|
241
|
+
|
|
242
|
+
# Create timestamped run directory
|
|
243
|
+
self.run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
244
|
+
self.run_dir = self.output_dir / f"run_{self.run_timestamp}"
|
|
245
|
+
self.run_dir.mkdir(parents=True, exist_ok=True)
|
|
246
|
+
self.logger.info(f"📁 Run directory: {self.run_dir}")
|
|
247
|
+
|
|
248
|
+
self._setup_experiment()
|
|
249
|
+
study = self._create_optuna_study()
|
|
250
|
+
study.optimize(self._objective_function, n_trials=self.config.n_trials)
|
|
251
|
+
best_trial = study.best_trial
|
|
252
|
+
final_results = self._final_evaluation(best_trial)
|
|
253
|
+
self._save_reproducibility_bundle(study, final_results)
|
|
254
|
+
|
|
255
|
+
# Log final results to WandB
|
|
256
|
+
self._log_final_results_to_wandb(study, final_results)
|
|
257
|
+
|
|
258
|
+
self.logger.info("✅ Optimization completed successfully!")
|
|
259
|
+
return final_results
|
|
260
|
+
|
|
261
|
+
def _setup_experiment(self):
|
|
262
|
+
"""Setup model, tokenizer, and load datasets."""
|
|
263
|
+
self.logger.info("📊 Setting up experiment...")
|
|
264
|
+
|
|
265
|
+
# Load model and tokenizer with memory optimizations
|
|
266
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
267
|
+
|
|
268
|
+
self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
|
|
269
|
+
|
|
270
|
+
# Load model with memory optimizations (same as comprehensive evaluation)
|
|
271
|
+
self.model = AutoModelForCausalLM.from_pretrained(
|
|
272
|
+
self.config.model_name,
|
|
273
|
+
torch_dtype=preferred_dtype(self.device.type),
|
|
274
|
+
low_cpu_mem_usage=True,
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
self.model.to(self.device)
|
|
278
|
+
self.model.eval() # Set to evaluation mode
|
|
279
|
+
|
|
280
|
+
if self.tokenizer.pad_token is None:
|
|
281
|
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
|
282
|
+
|
|
283
|
+
# Set left padding for decoder-only models (same as comprehensive evaluation)
|
|
284
|
+
self.tokenizer.padding_side = "left"
|
|
285
|
+
|
|
286
|
+
# Load datasets
|
|
287
|
+
self.train_samples = data_utils.load_dataset_samples(self.config.train_dataset, self.config.train_limit)
|
|
288
|
+
self.val_samples = data_utils.load_dataset_samples(self.config.val_dataset, self.config.val_limit)
|
|
289
|
+
self.test_samples = data_utils.load_dataset_samples(self.config.test_dataset, self.config.test_limit)
|
|
290
|
+
|
|
291
|
+
# Store task documents for BigCode evaluation (coding tasks)
|
|
292
|
+
self.train_task_docs = self.train_samples
|
|
293
|
+
self.val_task_docs = self.val_samples
|
|
294
|
+
self.test_task_docs = self.test_samples
|
|
295
|
+
|
|
296
|
+
self.logger.info(
|
|
297
|
+
f"Loaded {len(self.train_samples)} train, {len(self.val_samples)} val, {len(self.test_samples)} test samples"
|
|
298
|
+
)
|
|
299
|
+
|
|
300
|
+
# Pre-cache activations for all layers on all splits
|
|
301
|
+
self._precache_activations()
|
|
302
|
+
|
|
303
|
+
def _precache_activations(self):
|
|
304
|
+
"""Pre-cache activations for all layers and splits to improve efficiency."""
|
|
305
|
+
self.logger.info("🔄 Pre-caching activations for efficiency...")
|
|
306
|
+
|
|
307
|
+
layer_range = range(self.config.layer_search_range[0], self.config.layer_search_range[1] + 1)
|
|
308
|
+
|
|
309
|
+
splits_data = [("train", self.train_samples), ("val", self.val_samples), ("test", self.test_samples)]
|
|
310
|
+
|
|
311
|
+
for split_name, samples in splits_data:
|
|
312
|
+
for layer_id in layer_range:
|
|
313
|
+
if not self.cache.has_cached_activations(split_name, layer_id, self.tokenization_config):
|
|
314
|
+
self.logger.info(f"Caching activations for {split_name} split, layer {layer_id}")
|
|
315
|
+
|
|
316
|
+
dataset_name = {
|
|
317
|
+
"train": self.config.train_dataset,
|
|
318
|
+
"val": self.config.val_dataset,
|
|
319
|
+
"test": self.config.test_dataset,
|
|
320
|
+
}[split_name]
|
|
321
|
+
|
|
322
|
+
activations, labels = self._create_probe_data(samples, layer_id, dataset_name)
|
|
323
|
+
|
|
324
|
+
self.cache.save_activations(activations, labels, split_name, layer_id, self.tokenization_config)
|
|
325
|
+
else:
|
|
326
|
+
self.logger.info(f"Activations already cached for {split_name} split, layer {layer_id}")
|
|
327
|
+
|
|
328
|
+
def _create_probe_data(
|
|
329
|
+
self, samples: list[dict], layer_id: int, dataset_name: str
|
|
330
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
331
|
+
"""Create contrastive probe training data for a specific layer."""
|
|
332
|
+
self.logger.info(f"Creating probe data from {len(samples)} samples for {dataset_name} on layer {layer_id}")
|
|
333
|
+
|
|
334
|
+
# Get task for the specified dataset
|
|
335
|
+
task = get_task(dataset_name)
|
|
336
|
+
extractor = task.get_extractor()
|
|
337
|
+
self.logger.debug(f"Using task: {task.__class__.__name__}, extractor: {extractor.__class__.__name__}")
|
|
338
|
+
|
|
339
|
+
texts = []
|
|
340
|
+
labels = []
|
|
341
|
+
success_count = 0
|
|
342
|
+
fail_count = 0
|
|
343
|
+
|
|
344
|
+
for i, sample in enumerate(samples):
|
|
345
|
+
try:
|
|
346
|
+
# Extract QA pair
|
|
347
|
+
contrastive_pair = extractor.extract_contrastive_pair(sample, task)
|
|
348
|
+
|
|
349
|
+
# Skip samples where contrastive pair extraction failed
|
|
350
|
+
if not contrastive_pair:
|
|
351
|
+
self.logger.debug(f"Sample {i + 1}: No contrastive pair extracted from keys: {list(sample.keys())}")
|
|
352
|
+
fail_count += 1
|
|
353
|
+
continue
|
|
354
|
+
|
|
355
|
+
success_count += 1
|
|
356
|
+
self.logger.debug(f"Sample {i + 1}: Successfully extracted contrastive pair")
|
|
357
|
+
|
|
358
|
+
except Exception as e:
|
|
359
|
+
self.logger.error(f"Sample {i + 1}: Exception during contrastive pair extraction: {e}")
|
|
360
|
+
fail_count += 1
|
|
361
|
+
continue
|
|
362
|
+
|
|
363
|
+
question = contrastive_pair["question"]
|
|
364
|
+
correct_answer = contrastive_pair["correct_answer"]
|
|
365
|
+
incorrect_answer = contrastive_pair["incorrect_answer"]
|
|
366
|
+
|
|
367
|
+
# Log contrastive pair details
|
|
368
|
+
self.logger.debug(f"Contrastive pair - Question: ...{question[-50:]}")
|
|
369
|
+
self.logger.debug(f"Contrastive pair - Correct: {correct_answer}, Incorrect: {incorrect_answer}")
|
|
370
|
+
|
|
371
|
+
correct_text = f"{question} {correct_answer}"
|
|
372
|
+
texts.append(correct_text)
|
|
373
|
+
labels.append(1)
|
|
374
|
+
|
|
375
|
+
incorrect_text = f"{question} {incorrect_answer}"
|
|
376
|
+
texts.append(incorrect_text)
|
|
377
|
+
labels.append(0)
|
|
378
|
+
|
|
379
|
+
self.logger.info(
|
|
380
|
+
f"Probe data creation: {success_count} successful, {fail_count} failed. Generated {len(texts)} texts."
|
|
381
|
+
)
|
|
382
|
+
|
|
383
|
+
if len(texts) == 0:
|
|
384
|
+
self.logger.error("No texts generated for activation extraction! All contrastive pair extractions failed.")
|
|
385
|
+
return np.array([]), np.array([])
|
|
386
|
+
|
|
387
|
+
activations = data_utils.extract_activations_with_hook(
|
|
388
|
+
self.model, self.tokenizer, texts, layer_id, self.config.batch_size, self.config.max_length, self.device
|
|
389
|
+
)
|
|
390
|
+
|
|
391
|
+
return activations, np.array(labels)
|
|
392
|
+
|
|
393
|
+
def _create_optuna_study(self) -> optuna.Study:
|
|
394
|
+
"""Create Optuna study with SQLite persistence and specified sampler/pruner."""
|
|
395
|
+
self.logger.info("📋 Creating Optuna study with SQLite persistence...")
|
|
396
|
+
self.logger.info(f"Database: {self.config.db_url}")
|
|
397
|
+
self.logger.info(f"Study name: {self.config.study_name}")
|
|
398
|
+
self.logger.info(f"🎲 Warmup: {self.config.n_startup_trials} random trials before TPE sampling")
|
|
399
|
+
|
|
400
|
+
# Setup sampler
|
|
401
|
+
if self.config.sampler == "TPE":
|
|
402
|
+
sampler = TPESampler(seed=self.config.seed, n_startup_trials=self.config.n_startup_trials)
|
|
403
|
+
elif self.config.sampler == "Random":
|
|
404
|
+
sampler = optuna.samplers.RandomSampler(seed=self.config.seed)
|
|
405
|
+
else:
|
|
406
|
+
sampler = TPESampler(seed=self.config.seed, n_startup_trials=self.config.n_startup_trials)
|
|
407
|
+
|
|
408
|
+
# Setup pruner
|
|
409
|
+
if self.config.pruner == "MedianPruner":
|
|
410
|
+
pruner = MedianPruner(n_startup_trials=5, n_warmup_steps=10)
|
|
411
|
+
elif self.config.pruner == "SuccessiveHalvingPruner":
|
|
412
|
+
pruner = SuccessiveHalvingPruner()
|
|
413
|
+
else:
|
|
414
|
+
pruner = MedianPruner(n_startup_trials=5, n_warmup_steps=10)
|
|
415
|
+
|
|
416
|
+
# Create study with SQLite storage
|
|
417
|
+
study = optuna.create_study(
|
|
418
|
+
study_name=self.config.study_name,
|
|
419
|
+
storage=self.config.db_url,
|
|
420
|
+
direction="maximize", # Maximize validation accuracy
|
|
421
|
+
sampler=sampler,
|
|
422
|
+
pruner=pruner,
|
|
423
|
+
load_if_exists=True, # Continue existing study if it exists
|
|
424
|
+
)
|
|
425
|
+
|
|
426
|
+
self.logger.info(f"Study created/loaded with {len(study.trials)} existing trials")
|
|
427
|
+
|
|
428
|
+
return study
|
|
429
|
+
|
|
430
|
+
def _save_steering_vector_dual_format(self, steering_instance, pt_path: Path, safetensors_path: Path) -> bool:
|
|
431
|
+
"""Save steering vector in both .pt and safetensors formats."""
|
|
432
|
+
# Save in original .pt format first (preserves all metadata)
|
|
433
|
+
if not steering_instance.save_steering_vector(str(pt_path)):
|
|
434
|
+
self.logger.warning("Failed to save steering vector - method may not be trained")
|
|
435
|
+
return False
|
|
436
|
+
|
|
437
|
+
self.logger.info(f"💾 Saved best steering vector to: {pt_path.name}")
|
|
438
|
+
|
|
439
|
+
# Also save in safetensors format for HuggingFace compatibility
|
|
440
|
+
try:
|
|
441
|
+
# Load the .pt file and extract steering vector
|
|
442
|
+
data = torch.load(str(pt_path), map_location="cpu", weights_only=False)
|
|
443
|
+
if isinstance(data, dict) and "steering_vector" in data:
|
|
444
|
+
# Save just the steering vector in safetensors format
|
|
445
|
+
safetensors_save({"steering_vector": data["steering_vector"]}, str(safetensors_path))
|
|
446
|
+
self.logger.info(f"💾 Also saved as safetensors: {safetensors_path.name}")
|
|
447
|
+
return True
|
|
448
|
+
self.logger.warning("Unexpected .pt file structure, safetensors conversion skipped")
|
|
449
|
+
return True # .pt save was successful
|
|
450
|
+
except Exception as e:
|
|
451
|
+
self.logger.warning(f"Could not create safetensors version: {e}")
|
|
452
|
+
return True # .pt save was successful
|
|
453
|
+
|
|
454
|
+
def _objective_function(self, trial: optuna.Trial) -> float:
|
|
455
|
+
"""Optuna objective function for hyperparameter optimization."""
|
|
456
|
+
try:
|
|
457
|
+
# Sample hyperparameters
|
|
458
|
+
layer_id = trial.suggest_int(
|
|
459
|
+
"layer_id", self.config.layer_search_range[0], self.config.layer_search_range[1]
|
|
460
|
+
)
|
|
461
|
+
|
|
462
|
+
# Fixed probe type and regularization
|
|
463
|
+
probe_type = self.config.probe_type # Always logistic_regression
|
|
464
|
+
probe_c = 1.0 # Default regularization strength
|
|
465
|
+
|
|
466
|
+
steering_method = trial.suggest_categorical("steering_method", self.config.steering_methods)
|
|
467
|
+
|
|
468
|
+
if steering_method == "dac":
|
|
469
|
+
steering_alpha = trial.suggest_float("steering_alpha", 0.1, 5.0)
|
|
470
|
+
entropy_threshold = trial.suggest_float("entropy_threshold", 0.5, 2.0)
|
|
471
|
+
ptop = trial.suggest_float("ptop", 0.2, 0.8)
|
|
472
|
+
max_alpha = trial.suggest_float("max_alpha", 1.0, 5.0)
|
|
473
|
+
elif steering_method == "caa":
|
|
474
|
+
steering_alpha = trial.suggest_float("steering_alpha", 0.1, 5.0)
|
|
475
|
+
|
|
476
|
+
probe_score = self._train_and_evaluate_probe(trial, layer_id, probe_type, probe_c)
|
|
477
|
+
|
|
478
|
+
# Don't prune based on probe score - focus optimization on steering parameters
|
|
479
|
+
|
|
480
|
+
# Build clean hyperparameters dictionary
|
|
481
|
+
if steering_method == "dac":
|
|
482
|
+
hyperparams = {
|
|
483
|
+
"steering_alpha": steering_alpha,
|
|
484
|
+
"entropy_threshold": entropy_threshold,
|
|
485
|
+
"ptop": ptop,
|
|
486
|
+
"max_alpha": max_alpha,
|
|
487
|
+
}
|
|
488
|
+
elif steering_method == "caa":
|
|
489
|
+
hyperparams = {
|
|
490
|
+
"steering_alpha": steering_alpha,
|
|
491
|
+
}
|
|
492
|
+
else:
|
|
493
|
+
raise ValueError(f"Unsupported steering method: {steering_method}")
|
|
494
|
+
|
|
495
|
+
steering_method_instance = self._train_steering_method(trial, steering_method, layer_id, hyperparams)
|
|
496
|
+
|
|
497
|
+
validation_accuracy = self._evaluate_steering_on_validation(
|
|
498
|
+
steering_method_instance, steering_method, layer_id, hyperparams, trial.number, trial
|
|
499
|
+
)
|
|
500
|
+
|
|
501
|
+
trial.report(validation_accuracy, step=1)
|
|
502
|
+
|
|
503
|
+
# Log to WandB
|
|
504
|
+
metrics = {"validation_accuracy": validation_accuracy, "probe_score": probe_score}
|
|
505
|
+
self._log_trial_to_wandb(trial, metrics)
|
|
506
|
+
|
|
507
|
+
return validation_accuracy
|
|
508
|
+
|
|
509
|
+
except Exception as e:
|
|
510
|
+
self.logger.error(f"Trial failed: {e}")
|
|
511
|
+
return 0.0
|
|
512
|
+
|
|
513
|
+
def _train_and_evaluate_probe(self, trial: optuna.Trial, layer_id: int, probe_type: str, probe_c: float) -> float:
|
|
514
|
+
"""Train probe on training data and evaluate on validation data using cached activations."""
|
|
515
|
+
# Load cached training activations
|
|
516
|
+
X_train, y_train = self.cache.load_activations("train", layer_id, self.tokenization_config)
|
|
517
|
+
|
|
518
|
+
# Train probe
|
|
519
|
+
if probe_type == "logistic_regression":
|
|
520
|
+
from sklearn.linear_model import LogisticRegression
|
|
521
|
+
|
|
522
|
+
probe = LogisticRegression(C=probe_c, random_state=self.config.seed, max_iter=1000)
|
|
523
|
+
probe.fit(X_train, y_train)
|
|
524
|
+
else:
|
|
525
|
+
raise ValueError(f"Unsupported probe type: {probe_type}")
|
|
526
|
+
|
|
527
|
+
# Evaluate on validation data using cached activations
|
|
528
|
+
X_val, y_val = self.cache.load_activations("val", layer_id, self.tokenization_config)
|
|
529
|
+
|
|
530
|
+
from sklearn.metrics import roc_auc_score
|
|
531
|
+
|
|
532
|
+
y_pred_proba = probe.predict_proba(X_val)[:, 1]
|
|
533
|
+
return roc_auc_score(y_val, y_pred_proba) if len(np.unique(y_val)) > 1 else 0.5
|
|
534
|
+
|
|
535
|
+
# Don't store the probe object - it can't be JSON serialized
|
|
536
|
+
# The probe will be retrained in the final evaluation if needed
|
|
537
|
+
|
|
538
|
+
def _train_steering_method(
|
|
539
|
+
self, trial: optuna.Trial, method_name: str, layer_id: int, hyperparams: dict[str, Any]
|
|
540
|
+
) -> Any:
|
|
541
|
+
"""Train steering method on training data."""
|
|
542
|
+
# Use contrastive_pairs_limit with bounds checking
|
|
543
|
+
contrastive_limit = min(self.config.contrastive_pairs_limit, len(self.train_samples))
|
|
544
|
+
contrastive_pairs = self._create_contrastive_pairs(
|
|
545
|
+
self.train_samples, layer_id, self.config.train_dataset, limit=contrastive_limit
|
|
546
|
+
)
|
|
547
|
+
|
|
548
|
+
if method_name == "dac":
|
|
549
|
+
# Create DAC instance
|
|
550
|
+
dac = DAC(
|
|
551
|
+
entropy_threshold=hyperparams["entropy_threshold"],
|
|
552
|
+
ptop=hyperparams["ptop"],
|
|
553
|
+
max_alpha=hyperparams["max_alpha"],
|
|
554
|
+
)
|
|
555
|
+
|
|
556
|
+
# Train DAC
|
|
557
|
+
dac.train(contrastive_pairs, layer_id)
|
|
558
|
+
return dac
|
|
559
|
+
|
|
560
|
+
if method_name == "caa":
|
|
561
|
+
# Create CAA instance
|
|
562
|
+
from wisent_guard.core.steering_methods.caa import CAA
|
|
563
|
+
|
|
564
|
+
caa = CAA(device=self.device)
|
|
565
|
+
|
|
566
|
+
# Train CAA
|
|
567
|
+
caa.train(contrastive_pairs, layer_id)
|
|
568
|
+
return caa
|
|
569
|
+
|
|
570
|
+
raise ValueError(f"Unsupported steering method: {method_name}")
|
|
571
|
+
|
|
572
|
+
def _create_contrastive_pairs(
|
|
573
|
+
self, samples: list[dict], layer_id: int, dataset_name: str, limit: Optional[int] = None
|
|
574
|
+
) -> ContrastivePairSet:
|
|
575
|
+
"""Create contrastive pairs with activations for steering training."""
|
|
576
|
+
contrastive_pairs = []
|
|
577
|
+
task = get_task(dataset_name)
|
|
578
|
+
extractor = task.get_extractor()
|
|
579
|
+
|
|
580
|
+
samples_to_use = samples[:limit] if limit else samples
|
|
581
|
+
|
|
582
|
+
for sample in samples_to_use:
|
|
583
|
+
contrastive_pair = extractor.extract_contrastive_pair(sample, task)
|
|
584
|
+
if contrastive_pair:
|
|
585
|
+
# Log contrastive pair details
|
|
586
|
+
self.logger.debug(f"Creating contrastive pair - Question: ...{contrastive_pair['question'][-50:]}")
|
|
587
|
+
self.logger.debug(
|
|
588
|
+
f"Creating contrastive pair - Correct: {contrastive_pair['correct_answer']}, Incorrect: {contrastive_pair['incorrect_answer']}"
|
|
589
|
+
)
|
|
590
|
+
|
|
591
|
+
positive_response = Response(text=contrastive_pair["correct_answer"], label=1)
|
|
592
|
+
negative_response = Response(text=contrastive_pair["incorrect_answer"], label=0)
|
|
593
|
+
|
|
594
|
+
pair = ContrastivePair(
|
|
595
|
+
prompt=contrastive_pair["question"],
|
|
596
|
+
positive_response=positive_response,
|
|
597
|
+
negative_response=negative_response,
|
|
598
|
+
)
|
|
599
|
+
contrastive_pairs.append(pair)
|
|
600
|
+
|
|
601
|
+
pair_set = ContrastivePairSet(name=f"{dataset_name}_training", pairs=contrastive_pairs)
|
|
602
|
+
|
|
603
|
+
# Extract activations for all pairs in batches
|
|
604
|
+
if pair_set.pairs:
|
|
605
|
+
all_texts = []
|
|
606
|
+
text_to_pair_mapping = []
|
|
607
|
+
|
|
608
|
+
for pair_idx, pair in enumerate(pair_set.pairs):
|
|
609
|
+
pos_text = f"{pair.prompt} {pair.positive_response.text}"
|
|
610
|
+
neg_text = f"{pair.prompt} {pair.negative_response.text}"
|
|
611
|
+
|
|
612
|
+
all_texts.extend([pos_text, neg_text])
|
|
613
|
+
text_to_pair_mapping.extend([(pair_idx, "positive"), (pair_idx, "negative")])
|
|
614
|
+
|
|
615
|
+
all_activations = self._extract_batch_activations(all_texts, layer_id)
|
|
616
|
+
|
|
617
|
+
for text_idx, (pair_idx, response_type) in enumerate(text_to_pair_mapping):
|
|
618
|
+
activation = all_activations[text_idx]
|
|
619
|
+
|
|
620
|
+
if response_type == "positive":
|
|
621
|
+
pair_set.pairs[pair_idx].positive_response.activations = activation
|
|
622
|
+
else:
|
|
623
|
+
pair_set.pairs[pair_idx].negative_response.activations = activation
|
|
624
|
+
|
|
625
|
+
return pair_set
|
|
626
|
+
|
|
627
|
+
def _extract_batch_activations(self, texts: list[str], layer_id: int) -> list[torch.Tensor]:
|
|
628
|
+
"""Extract activations for multiple texts in batches."""
|
|
629
|
+
if not texts:
|
|
630
|
+
return []
|
|
631
|
+
|
|
632
|
+
all_activations = []
|
|
633
|
+
batch_size = self.config.batch_size
|
|
634
|
+
|
|
635
|
+
for i in range(0, len(texts), batch_size):
|
|
636
|
+
batch_texts = texts[i : i + batch_size]
|
|
637
|
+
|
|
638
|
+
inputs = self.tokenizer(
|
|
639
|
+
batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=self.config.max_length
|
|
640
|
+
).to(self.device)
|
|
641
|
+
|
|
642
|
+
batch_activations = []
|
|
643
|
+
|
|
644
|
+
def batch_hook_fn(module, input, output):
|
|
645
|
+
with torch.no_grad():
|
|
646
|
+
hidden_states = output[0] if isinstance(output, tuple) else output
|
|
647
|
+
|
|
648
|
+
last_token_acts = hidden_states[:, -1, :].detach().clone()
|
|
649
|
+
batch_activations.append(last_token_acts)
|
|
650
|
+
|
|
651
|
+
if hasattr(self.model, "transformer"):
|
|
652
|
+
target_layer = self.model.transformer.h[layer_id]
|
|
653
|
+
elif hasattr(self.model, "model"):
|
|
654
|
+
target_layer = self.model.model.layers[layer_id]
|
|
655
|
+
else:
|
|
656
|
+
raise ValueError("Unknown model architecture")
|
|
657
|
+
|
|
658
|
+
handle = target_layer.register_forward_hook(batch_hook_fn)
|
|
659
|
+
|
|
660
|
+
try:
|
|
661
|
+
with torch.no_grad():
|
|
662
|
+
_ = self.model(**inputs)
|
|
663
|
+
finally:
|
|
664
|
+
handle.remove()
|
|
665
|
+
empty_device_cache(self.device.type)
|
|
666
|
+
|
|
667
|
+
if batch_activations:
|
|
668
|
+
batch_tensor = batch_activations[0]
|
|
669
|
+
for j in range(batch_tensor.shape[0]):
|
|
670
|
+
all_activations.append(batch_tensor[j].unsqueeze(0))
|
|
671
|
+
|
|
672
|
+
return all_activations
|
|
673
|
+
|
|
674
|
+
def _extract_single_activation(self, text: str, layer_id: int) -> torch.Tensor:
|
|
675
|
+
"""Extract activation for a single text."""
|
|
676
|
+
activations = self._extract_batch_activations([text], layer_id)
|
|
677
|
+
return activations[0] if activations else torch.zeros(1, self.model.config.hidden_size, device=self.device)
|
|
678
|
+
|
|
679
|
+
def _evaluate_steering_on_validation(
|
|
680
|
+
self,
|
|
681
|
+
steering_instance: Any,
|
|
682
|
+
method_name: str,
|
|
683
|
+
layer_id: int,
|
|
684
|
+
hyperparams: dict[str, Any],
|
|
685
|
+
trial_number: int = 0,
|
|
686
|
+
trial=None,
|
|
687
|
+
) -> float:
|
|
688
|
+
"""Evaluate steering method on validation data by re-running forward passes."""
|
|
689
|
+
if steering_instance is None:
|
|
690
|
+
return 0.0
|
|
691
|
+
|
|
692
|
+
# Generate predictions with steering applied
|
|
693
|
+
predictions = []
|
|
694
|
+
ground_truths = []
|
|
695
|
+
task_docs = [] # Preserve original task documents for BigCode evaluation
|
|
696
|
+
|
|
697
|
+
task = get_task(self.config.val_dataset)
|
|
698
|
+
extractor = task.get_extractor()
|
|
699
|
+
|
|
700
|
+
# Collect all questions for batched processing (use ALL validation samples)
|
|
701
|
+
questions = []
|
|
702
|
+
ground_truths = []
|
|
703
|
+
valid_samples = [] # Keep track of samples that produce valid QA pairs
|
|
704
|
+
|
|
705
|
+
for sample in tqdm(
|
|
706
|
+
self.val_samples, desc="Extracting validation QA pairs", leave=False
|
|
707
|
+
): # Use all validation samples for reliable evaluation
|
|
708
|
+
qa_pair = extractor.extract_qa_pair(sample, task)
|
|
709
|
+
if not qa_pair:
|
|
710
|
+
continue
|
|
711
|
+
|
|
712
|
+
question = qa_pair["formatted_question"]
|
|
713
|
+
ground_truth = qa_pair["correct_answer"]
|
|
714
|
+
questions.append(question)
|
|
715
|
+
ground_truths.append(ground_truth)
|
|
716
|
+
valid_samples.append(sample) # Store the original sample
|
|
717
|
+
|
|
718
|
+
# Generate predictions using batched approach
|
|
719
|
+
if questions:
|
|
720
|
+
if steering_instance is None:
|
|
721
|
+
predictions = self._generate_baseline_batched(questions)
|
|
722
|
+
else:
|
|
723
|
+
# Extract the appropriate strength parameter based on method
|
|
724
|
+
if method_name == "dac":
|
|
725
|
+
# DAC uses steering_alpha as base strength multiplier
|
|
726
|
+
strength = hyperparams.get("steering_alpha", 1.0)
|
|
727
|
+
else:
|
|
728
|
+
# CAA and other methods use steering_alpha directly
|
|
729
|
+
strength = hyperparams["steering_alpha"]
|
|
730
|
+
|
|
731
|
+
predictions = self._generate_with_steering_batched(steering_instance, questions, strength, layer_id)
|
|
732
|
+
|
|
733
|
+
# Log sample predictions for debugging
|
|
734
|
+
for i, (pred, gt) in enumerate(zip(predictions[:3], ground_truths[:3])):
|
|
735
|
+
self.logger.debug(f"{method_name.upper()} Sample {i} - Model: ...{pred[-50:] if pred else 'None'}")
|
|
736
|
+
self.logger.debug(f"{method_name.upper()} Sample {i} - Ground truth: {gt}")
|
|
737
|
+
else:
|
|
738
|
+
predictions = []
|
|
739
|
+
|
|
740
|
+
if not predictions:
|
|
741
|
+
return 0.0
|
|
742
|
+
|
|
743
|
+
# Save detailed validation results to JSON
|
|
744
|
+
self._save_detailed_validation_results(
|
|
745
|
+
questions,
|
|
746
|
+
ground_truths,
|
|
747
|
+
predictions,
|
|
748
|
+
trial_number,
|
|
749
|
+
trial=trial,
|
|
750
|
+
steering_method=method_name,
|
|
751
|
+
layer_id=layer_id,
|
|
752
|
+
hyperparams=hyperparams,
|
|
753
|
+
)
|
|
754
|
+
|
|
755
|
+
# Prepare task docs for BigCode evaluation (if coding task)
|
|
756
|
+
task_docs = valid_samples[: len(predictions)] if valid_samples else []
|
|
757
|
+
|
|
758
|
+
# Evaluate benchmark performance (with task docs for coding tasks)
|
|
759
|
+
benchmark_metrics = metrics.evaluate_benchmark_performance(
|
|
760
|
+
predictions, ground_truths, self.config.val_dataset, task_docs=task_docs
|
|
761
|
+
)
|
|
762
|
+
|
|
763
|
+
return benchmark_metrics.get("accuracy", 0.0)
|
|
764
|
+
|
|
765
|
+
def _generate_with_dac_steering(self, dac: DAC, question: str, alpha: float, layer_id: int) -> str:
|
|
766
|
+
"""Generate response with DAC steering applied."""
|
|
767
|
+
# Use the general steering method which calls DAC's apply_steering
|
|
768
|
+
return self._generate_with_steering(dac, question, alpha, layer_id)
|
|
769
|
+
|
|
770
|
+
def _generate_with_caa_steering(self, caa, question: str, alpha: float, layer_id: int) -> str:
|
|
771
|
+
"""Generate response with CAA steering applied."""
|
|
772
|
+
if not hasattr(caa, "steering_vector") or caa.steering_vector is None:
|
|
773
|
+
return self._generate_baseline(question)
|
|
774
|
+
|
|
775
|
+
return self._generate_with_steering_hook(question, caa.steering_vector, layer_id, alpha)
|
|
776
|
+
|
|
777
|
+
def _generate_with_steering_hook(
|
|
778
|
+
self, question: str, steering_vector: torch.Tensor, layer_id: int, alpha: float
|
|
779
|
+
) -> str:
|
|
780
|
+
"""Generate response with steering vector applied via hook (re-runs forward pass)."""
|
|
781
|
+
inputs = self.tokenizer(question, return_tensors="pt").to(self.device)
|
|
782
|
+
|
|
783
|
+
def steering_hook(module, input, output):
|
|
784
|
+
"""Hook that applies steering vector during forward pass."""
|
|
785
|
+
if isinstance(output, tuple):
|
|
786
|
+
hidden_states = output[0]
|
|
787
|
+
# Apply steering to the last token
|
|
788
|
+
hidden_states[:, -1, :] += alpha * steering_vector.to(hidden_states.device)
|
|
789
|
+
return (hidden_states, *output[1:])
|
|
790
|
+
hidden_states = output
|
|
791
|
+
hidden_states[:, -1, :] += alpha * steering_vector.to(hidden_states.device)
|
|
792
|
+
return hidden_states
|
|
793
|
+
|
|
794
|
+
# Register hook on target layer
|
|
795
|
+
if hasattr(self.model, "transformer"):
|
|
796
|
+
target_layer = self.model.transformer.h[layer_id]
|
|
797
|
+
elif hasattr(self.model, "model"):
|
|
798
|
+
target_layer = self.model.model.layers[layer_id]
|
|
799
|
+
else:
|
|
800
|
+
raise ValueError("Unknown model architecture")
|
|
801
|
+
|
|
802
|
+
handle = target_layer.register_forward_hook(steering_hook)
|
|
803
|
+
|
|
804
|
+
try:
|
|
805
|
+
with torch.no_grad():
|
|
806
|
+
outputs = self.model.generate(
|
|
807
|
+
**inputs,
|
|
808
|
+
max_new_tokens=self.config.max_new_tokens,
|
|
809
|
+
do_sample=self.config.do_sample,
|
|
810
|
+
temperature=self.config.temperature if self.config.do_sample else 1.0,
|
|
811
|
+
pad_token_id=self.tokenizer.eos_token_id,
|
|
812
|
+
eos_token_id=self.tokenizer.eos_token_id,
|
|
813
|
+
)
|
|
814
|
+
finally:
|
|
815
|
+
handle.remove()
|
|
816
|
+
|
|
817
|
+
response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1] :], skip_special_tokens=True)
|
|
818
|
+
return response.strip()
|
|
819
|
+
|
|
820
|
+
def _generate_baseline(self, question: str) -> str:
|
|
821
|
+
"""Generate baseline response without steering."""
|
|
822
|
+
inputs = self.tokenizer(question, return_tensors="pt").to(self.device)
|
|
823
|
+
|
|
824
|
+
with torch.no_grad():
|
|
825
|
+
outputs = self.model.generate(
|
|
826
|
+
**inputs,
|
|
827
|
+
max_new_tokens=self.config.max_new_tokens,
|
|
828
|
+
do_sample=self.config.do_sample,
|
|
829
|
+
temperature=self.config.temperature if self.config.do_sample else 1.0,
|
|
830
|
+
pad_token_id=self.tokenizer.eos_token_id,
|
|
831
|
+
eos_token_id=self.tokenizer.eos_token_id,
|
|
832
|
+
)
|
|
833
|
+
|
|
834
|
+
response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1] :], skip_special_tokens=True)
|
|
835
|
+
return response.strip()
|
|
836
|
+
|
|
837
|
+
def _generate_baseline_batched(self, questions: list[str]) -> list[str]: # TODO
|
|
838
|
+
"""Generate baseline responses in batches without steering."""
|
|
839
|
+
if not questions:
|
|
840
|
+
return []
|
|
841
|
+
|
|
842
|
+
batch_size = self.config.batch_size
|
|
843
|
+
all_responses = []
|
|
844
|
+
|
|
845
|
+
# Process questions in batches
|
|
846
|
+
for i in tqdm(range(0, len(questions), batch_size), desc="Generating baseline predictions", leave=False):
|
|
847
|
+
batch_questions = questions[i : i + batch_size]
|
|
848
|
+
|
|
849
|
+
# Batch tokenization with padding
|
|
850
|
+
inputs = self.tokenizer(
|
|
851
|
+
batch_questions, return_tensors="pt", padding=True, truncation=True, max_length=self.config.max_length
|
|
852
|
+
).to(self.device)
|
|
853
|
+
|
|
854
|
+
with torch.no_grad():
|
|
855
|
+
outputs = self.model.generate(
|
|
856
|
+
**inputs,
|
|
857
|
+
max_new_tokens=self.config.max_new_tokens,
|
|
858
|
+
do_sample=self.config.do_sample,
|
|
859
|
+
temperature=self.config.temperature if self.config.do_sample else 1.0,
|
|
860
|
+
pad_token_id=self.tokenizer.eos_token_id,
|
|
861
|
+
eos_token_id=self.tokenizer.eos_token_id,
|
|
862
|
+
)
|
|
863
|
+
|
|
864
|
+
# Decode responses
|
|
865
|
+
batch_responses = []
|
|
866
|
+
for j, output in enumerate(outputs):
|
|
867
|
+
input_length = inputs.input_ids[j].shape[0]
|
|
868
|
+
response = self.tokenizer.decode(output[input_length:], skip_special_tokens=True)
|
|
869
|
+
batch_responses.append(response.strip())
|
|
870
|
+
|
|
871
|
+
all_responses.extend(batch_responses)
|
|
872
|
+
|
|
873
|
+
return all_responses
|
|
874
|
+
|
|
875
|
+
def _generate_with_steering_batched(
|
|
876
|
+
self, steering_instance: Any, questions: list[str], alpha: float, layer_id: int
|
|
877
|
+
) -> list[str]:
|
|
878
|
+
"""Generate responses with steering applied in batches using apply_steering()."""
|
|
879
|
+
if not questions:
|
|
880
|
+
return []
|
|
881
|
+
|
|
882
|
+
batch_size = self.config.batch_size
|
|
883
|
+
all_responses = []
|
|
884
|
+
|
|
885
|
+
# Process questions in batches
|
|
886
|
+
for i in tqdm(range(0, len(questions), batch_size), desc="Generating steered predictions", leave=False):
|
|
887
|
+
batch_questions = questions[i : i + batch_size]
|
|
888
|
+
|
|
889
|
+
# Batch tokenization with padding
|
|
890
|
+
inputs = self.tokenizer(
|
|
891
|
+
batch_questions, return_tensors="pt", padding=True, truncation=True, max_length=self.config.max_length
|
|
892
|
+
).to(self.device)
|
|
893
|
+
|
|
894
|
+
def steering_hook(module, input, output):
|
|
895
|
+
"""Hook that applies steering using the steering method's apply_steering()."""
|
|
896
|
+
hidden_states = output[0] if isinstance(output, tuple) else output
|
|
897
|
+
|
|
898
|
+
# Apply steering using the method's apply_steering() function
|
|
899
|
+
steered = steering_instance.apply_steering(hidden_states, strength=alpha)
|
|
900
|
+
|
|
901
|
+
if isinstance(output, tuple):
|
|
902
|
+
return (steered, *output[1:])
|
|
903
|
+
return steered
|
|
904
|
+
|
|
905
|
+
# Register hook on target layer
|
|
906
|
+
if hasattr(self.model, "transformer"):
|
|
907
|
+
if layer_id >= len(self.model.transformer.h):
|
|
908
|
+
raise ValueError(f"layer_id {layer_id} exceeds model layers")
|
|
909
|
+
target_layer = self.model.transformer.h[layer_id]
|
|
910
|
+
elif hasattr(self.model, "model"):
|
|
911
|
+
if layer_id >= len(self.model.model.layers):
|
|
912
|
+
raise ValueError(f"layer_id {layer_id} exceeds model layers")
|
|
913
|
+
target_layer = self.model.model.layers[layer_id]
|
|
914
|
+
else:
|
|
915
|
+
raise ValueError("Unknown model architecture")
|
|
916
|
+
|
|
917
|
+
handle = target_layer.register_forward_hook(steering_hook)
|
|
918
|
+
|
|
919
|
+
try:
|
|
920
|
+
with torch.no_grad():
|
|
921
|
+
outputs = self.model.generate(
|
|
922
|
+
**inputs,
|
|
923
|
+
max_new_tokens=self.config.max_new_tokens,
|
|
924
|
+
do_sample=self.config.do_sample,
|
|
925
|
+
temperature=self.config.temperature if self.config.do_sample else 1.0,
|
|
926
|
+
pad_token_id=self.tokenizer.eos_token_id,
|
|
927
|
+
eos_token_id=self.tokenizer.eos_token_id,
|
|
928
|
+
)
|
|
929
|
+
|
|
930
|
+
# Decode responses
|
|
931
|
+
batch_responses = []
|
|
932
|
+
for j, output in enumerate(outputs):
|
|
933
|
+
input_length = inputs.input_ids[j].shape[0]
|
|
934
|
+
response = self.tokenizer.decode(output[input_length:], skip_special_tokens=True)
|
|
935
|
+
batch_responses.append(response.strip())
|
|
936
|
+
|
|
937
|
+
all_responses.extend(batch_responses)
|
|
938
|
+
|
|
939
|
+
finally:
|
|
940
|
+
# Always remove the hook
|
|
941
|
+
handle.remove()
|
|
942
|
+
|
|
943
|
+
return all_responses
|
|
944
|
+
+    def _final_evaluation(self, best_trial: optuna.Trial) -> dict[str, Any]:
+        """Run final evaluation on test split with best configuration."""
+        self.logger.info("🏆 Running final evaluation with best configuration...")
+
+        # Extract best hyperparameters
+        # Handle both real trials and FixedTrial objects
+        if hasattr(best_trial, "params") and best_trial.params:
+            best_params = best_trial.params
+        elif hasattr(best_trial, "_params"):
+            best_params = best_trial._params
+        else:
+            # Fallback - this shouldn't happen
+            raise ValueError("Cannot access trial parameters")
+        layer_id = best_params["layer_id"]
+
+        self.logger.info(f"Best configuration: {best_params}")
+
+        # Re-train best probe and steering method on training data
+        from sklearn.linear_model import LogisticRegression
+
+        # Train best probe with fixed probe_c
+        X_train, y_train = self.cache.load_activations("train", layer_id, self.tokenization_config)
+        probe = LogisticRegression(C=1.0, random_state=self.config.seed, max_iter=1000)  # Fixed probe_c
+        probe.fit(X_train, y_train)
+
+        # Train best steering method
+        steering_method = best_params.get("steering_method", "caa")  # Default to CAA if missing
+        steering_instance = self._train_steering_method(best_trial, steering_method, layer_id, best_params)
+
+        # Save the best steering vector in both formats
+        if steering_instance and hasattr(steering_instance, "save_steering_vector"):
+            pt_path = self.run_dir / "best_steering_vector.pt"
+            safetensors_path = self.run_dir / "best_steering_vector.safetensors"
+            self._save_steering_vector_dual_format(steering_instance, pt_path, safetensors_path)
+
+        # Generate baseline predictions (no steering)
+        self.logger.info("Generating baseline predictions...")
+        baseline_predictions, test_ground_truths, test_questions, test_task_docs = self._generate_test_predictions(
+            None, None, layer_id, 0.0
+        )
+
+        # Generate steered predictions
+        self.logger.info("Generating steered predictions...")
+
+        # Extract the appropriate strength parameter based on method and available parameters
+        method_name = best_params.get("steering_method", "caa")  # Default to CAA if missing
+        if method_name == "dac":
+            # DAC can use base_strength or steering_alpha, with fallback to 1.0
+            strength = best_params.get("base_strength", best_params.get("steering_alpha", 1.0))
+        else:
+            # CAA and other methods use steering_alpha
+            strength = best_params["steering_alpha"]
+
+        steered_predictions, _, _, _ = self._generate_test_predictions(
+            steering_instance, method_name, layer_id, strength
+        )
+
+        # Save detailed test results to JSON
+        if test_questions and test_ground_truths and baseline_predictions and steered_predictions:
+            self._save_detailed_test_results(
+                test_questions,
+                test_ground_truths,
+                baseline_predictions,
+                steered_predictions,
+                best_trial=best_trial,
+                best_params=best_params,
+                layer_id=layer_id,
+                steering_method=method_name,
+            )
+
+        # Calculate benchmark metrics (with real task docs for coding tasks)
+        baseline_benchmark_metrics = metrics.evaluate_benchmark_performance(
+            baseline_predictions, test_ground_truths, self.config.test_dataset, task_docs=test_task_docs
+        )
+        steered_benchmark_metrics = metrics.evaluate_benchmark_performance(
+            steered_predictions, test_ground_truths, self.config.test_dataset, task_docs=test_task_docs
+        )
+
+        # Evaluate probe on test data
+        X_test, y_test = self.cache.load_activations("test", layer_id, self.tokenization_config)
+        test_probe_metrics = self._evaluate_probe_metrics(probe, X_test, y_test)
+
+        # Calculate improvement
+        accuracy_improvement = steered_benchmark_metrics.get("accuracy", 0.0) - baseline_benchmark_metrics.get(
+            "accuracy", 0.0
+        )
+
+        final_results = {
+            "best_trial_params": best_params,
+            "best_validation_score": getattr(best_trial, "value", None),
+            "baseline_benchmark_metrics": baseline_benchmark_metrics,
+            "steered_benchmark_metrics": steered_benchmark_metrics,
+            "accuracy_improvement": accuracy_improvement,
+            "test_probe_metrics": test_probe_metrics,
+            "config": self.config.to_dict(),
+            "num_test_samples": len(test_ground_truths),
+        }
+
+        # Log final results
+        self.logger.info("=" * 60)
+        self.logger.info("🏆 FINAL TEST RESULTS")
+        self.logger.info("=" * 60)
+        self.logger.info(f"Baseline accuracy: {baseline_benchmark_metrics.get('accuracy', 0.0):.4f}")
+        self.logger.info(f"Steered accuracy: {steered_benchmark_metrics.get('accuracy', 0.0):.4f}")
+        self.logger.info(f"Improvement: {accuracy_improvement:+.4f}")
+        self.logger.info(f"Probe AUC: {test_probe_metrics.get('auc', 0.5):.4f}")
+        self.logger.info(f"Test samples: {len(test_ground_truths)}")
+        self.logger.info("=" * 60)
+
+        return final_results
+
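`_train_steering_method` and `_save_steering_vector_dual_format` are defined elsewhere in this module and are not part of this hunk. The general idea of the dual-format save, the same tensor written as both a pickle-based `.pt` file and a safetensors file, can be sketched as follows; this is an assumption about the shape of that helper, not the package's actual implementation.

    import torch
    from safetensors.torch import save_file

    def save_vector_dual_format(vector: torch.Tensor, pt_path: str, safetensors_path: str) -> None:
        # .pt keeps PyTorch's pickle flexibility; .safetensors is a safer, framework-neutral container.
        torch.save({"steering_vector": vector}, pt_path)
        save_file({"steering_vector": vector.contiguous()}, safetensors_path)

    save_vector_dual_format(torch.randn(4096), "best_steering_vector.pt", "best_steering_vector.safetensors")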
+    def _generate_test_predictions(
+        self, steering_instance: Any, method_name: str, layer_id: int, alpha: float
+    ) -> tuple[list[str], list[str], list[str], list[dict]]:
+        """Generate predictions on test data using batched generation."""
+        # Collect all questions and ground truths for batching
+        questions = []
+        ground_truths = []
+        valid_samples = []  # Keep track of samples that produce valid QA pairs
+
+        task = get_task(self.config.test_dataset)
+        extractor = task.get_extractor()
+
+        for sample in self.test_samples:
+            qa_pair = extractor.extract_qa_pair(sample, task)
+            if not qa_pair:
+                continue
+
+            question = qa_pair["formatted_question"]
+            ground_truth = qa_pair["correct_answer"]
+            questions.append(question)
+            ground_truths.append(ground_truth)
+            valid_samples.append(sample)  # Store the original sample
+
+        # Process all questions with appropriate batched method
+        if questions:
+            try:
+                if steering_instance is None:
+                    # Baseline generation - use batched method
+                    predictions = self._generate_baseline_batched(questions)
+                else:
+                    # Use unified batched generation with apply_steering()
+                    predictions = self._generate_with_steering_batched(steering_instance, questions, alpha, layer_id)
+
+                # Log sample predictions for debugging
+                for i, (pred, gt) in enumerate(zip(predictions[:3], ground_truths[:3])):
+                    self.logger.debug(f"Test Sample {i} - Model: ...{pred[-50:] if pred else 'None'}")
+                    self.logger.debug(f"Test Sample {i} - Ground truth: {gt}")
+
+            except Exception as e:
+                self.logger.warning(f"Batched generation failed for test: {e}")
+                predictions = ["Error"] * len(questions)
+        else:
+            predictions = []
+
+        return predictions, ground_truths, questions, valid_samples
+
+    def _evaluate_probe_metrics(self, probe, X_test: np.ndarray, y_test: np.ndarray) -> dict[str, float]:
+        """Evaluate probe metrics."""
+        from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
+
+        y_pred = probe.predict(X_test)
+        y_pred_proba = probe.predict_proba(X_test)[:, 1]
+
+        return {
+            "accuracy": accuracy_score(y_test, y_pred),
+            "precision": precision_score(y_test, y_pred, zero_division=0),
+            "recall": recall_score(y_test, y_pred, zero_division=0),
+            "f1": f1_score(y_test, y_pred, zero_division=0),
+            "auc": roc_auc_score(y_test, y_pred_proba) if len(np.unique(y_test)) > 1 else 0.5,
+        }
+
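The probe metrics above are plain scikit-learn calls; the fallback to 0.5 covers the case where the test split contains a single class, for which ROC AUC is undefined. A self-contained usage sketch on synthetic data (illustrative only, not package code):

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score, roc_auc_score

    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 16))
    y = (X[:, 0] + 0.5 * rng.normal(size=200) > 0).astype(int)

    probe = LogisticRegression(C=1.0, max_iter=1000).fit(X[:150], y[:150])
    y_pred = probe.predict(X[150:])
    y_proba = probe.predict_proba(X[150:])[:, 1]

    print("accuracy:", accuracy_score(y[150:], y_pred))
    # AUC only when both classes are present, mirroring the 0.5 fallback above.
    print("auc:", roc_auc_score(y[150:], y_proba) if len(np.unique(y[150:])) > 1 else 0.5)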
+    def _create_experiment_metadata(
+        self, trial=None, steering_method: str = None, layer_id: int = None, hyperparams: dict = None
+    ):
+        """Create comprehensive experiment metadata for detailed results."""
+        import platform
+        from datetime import datetime
+
+        metadata = {
+            "trial_info": {
+                "trial_number": trial.number if trial else None,
+                "trial_params": dict(trial.params) if trial else {},
+                "trial_state": str(getattr(trial, "state", "RUNNING")) if trial else None,
+            },
+            "model_config": {
+                "model_name": self.config.model_name,
+                "device": self.config.device,
+                "is_coding_task": self.is_coding_task,
+            },
+            "dataset_config": {
+                "train_dataset": self.config.train_dataset,
+                "val_dataset": self.config.val_dataset,
+                "test_dataset": self.config.test_dataset,
+                "train_limit": self.config.train_limit,
+                "val_limit": self.config.val_limit,
+                "test_limit": self.config.test_limit,
+                "contrastive_pairs_limit": self.config.contrastive_pairs_limit,
+            },
+            "steering_config": {
+                "steering_method": steering_method,
+                "layer_id": layer_id,
+                "hyperparams": hyperparams or {},
+                "layer_search_range": self.config.layer_search_range,
+                "probe_type": self.config.probe_type,
+                "available_steering_methods": self.config.steering_methods,
+            },
+            "optimization_config": {
+                "study_name": self.config.study_name,
+                "sampler": self.config.sampler,
+                "pruner": self.config.pruner,
+                "n_trials": self.config.n_trials,
+                "n_startup_trials": self.config.n_startup_trials,
+            },
+            "generation_config": {
+                "batch_size": self.config.batch_size,
+                "max_length": self.config.max_length,
+                "max_new_tokens": self.config.max_new_tokens,
+                "temperature": self.config.temperature,
+                "do_sample": self.config.do_sample,
+            },
+            "run_info": {
+                "timestamp": datetime.now().isoformat(),
+                "run_dir": str(self.run_dir),
+                "output_dir": self.config.output_dir,
+                "cache_dir": self.config.cache_dir,
+                "platform": platform.platform(),
+                "python_version": platform.python_version(),
+            },
+            "wandb_config": {
+                "use_wandb": self.config.use_wandb,
+                "wandb_project": self.config.wandb_project,
+            }
+            if hasattr(self.config, "use_wandb")
+            else {},
+        }
+
+        return metadata
+
+    def _save_detailed_validation_results(
+        self,
+        questions: list[str],
+        ground_truths: list[str],
+        predictions: list[str],
+        trial_number: int,
+        trial=None,
+        steering_method: str = None,
+        layer_id: int = None,
+        hyperparams: dict = None,
+    ):
+        """Save detailed validation results to JSON file with experiment metadata."""
+        detailed_results = []
+
+        # For coding tasks, use the same BigCode evaluation as accuracy calculation
+        if self.is_coding_task:
+            # Use evaluate_benchmark_performance to get consistent BigCode evaluation
+            eval_results = metrics.evaluate_benchmark_performance(
+                predictions, ground_truths, task_name=self.config.val_dataset, task_docs=self.val_task_docs
+            )
+
+            # Extract individual correctness from evaluation details
+            eval_details = eval_results.get("evaluation_details", [])
+
+            for i, (question, correct_answer, model_answer) in enumerate(zip(questions, ground_truths, predictions)):
+                # Get correctness from BigCode evaluation if available
+                is_correct = eval_details[i]["is_correct"] if i < len(eval_details) else False
+
+                detailed_results.append(
+                    {
+                        "row": i,
+                        "question": question,
+                        "correct_answer": correct_answer,
+                        "model_answer": model_answer,
+                        "is_correct": is_correct,
+                        "evaluation_method": eval_results.get("evaluation_method", "unknown"),
+                        "extracted_code": eval_details[i].get("prediction", model_answer)
+                        if i < len(eval_details)
+                        else model_answer,
+                        "execution_error": eval_details[i].get("execution_error") if i < len(eval_details) else None,
+                    }
+                )
+        else:
+            # For non-coding tasks, process each result
+            for i, (question, correct_answer, model_answer) in enumerate(zip(questions, ground_truths, predictions)):
+                # Use standard evaluation via metrics module
+                is_correct = metrics.evaluate_response_correctness(
+                    model_answer, correct_answer, self.config.val_dataset
+                )
+
+                result_entry = {
+                    "row": i,
+                    "question": question,
+                    "correct_answer": correct_answer,
+                    "model_answer": model_answer,
+                    "is_correct": is_correct,
+                    "evaluation_method": "string_comparison",
+                }
+
+                # Add MC-specific fields if this is a multiple choice task
+                if self._should_use_multiple_choice_evaluation():
+                    # Extract MC diagnostics directly without custom evaluation
+                    import re
+
+                    # Extract available answers from question (A. choice, B. choice, etc.)
+                    available_answers = []
+                    choice_pattern = r"([A-E])\.\s+(.+?)(?=\n[A-E]\.|$)"
+                    matches = re.findall(choice_pattern, question, re.MULTILINE | re.DOTALL)
+                    for letter, choice_text in matches:
+                        available_answers.append(f"{letter}. {choice_text.strip()}")
+
+                    # Extract model's selected letter from model answer
+                    model_selected_letter = "?"
+                    model_letter_match = re.search(r"\b([A-E])\b", model_answer.upper())
+                    if model_letter_match:
+                        model_selected_letter = model_letter_match.group(1)
+
+                    result_entry["available_answers"] = available_answers
+                    result_entry["correct_choice_letter"] = correct_answer
+                    result_entry["model_selected_letter"] = model_selected_letter
+
+                detailed_results.append(result_entry)
+
+        # Create experiment metadata
+        experiment_metadata = self._create_experiment_metadata(trial, steering_method, layer_id, hyperparams)
+
+        # Create final results structure with metadata
+        final_results = {"experiment_metadata": experiment_metadata, "results": detailed_results}
+
+        # Save to JSON file
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        filename = f"validation_detailed_results_trial_{trial_number:03d}_{timestamp}.json"
+        filepath = self.run_dir / filename
+
+        with open(filepath, "w", encoding="utf-8") as f:
+            json.dump(final_results, f, indent=2, ensure_ascii=False)
+
+        self.logger.info(f"💾 Saved detailed validation results to: {filename}")
+
+    def _save_detailed_test_results(
+        self,
+        questions: list[str],
+        ground_truths: list[str],
+        baseline_predictions: list[str],
+        steered_predictions: list[str],
+        best_trial=None,
+        best_params: dict = None,
+        layer_id: int = None,
+        steering_method: str = None,
+    ):
+        """Save detailed test results to JSON file with both baseline and steered answers and experiment metadata."""
+        detailed_results = []
+
+        # For coding tasks, use BigCode evaluation consistently
+        if self.is_coding_task:
+            # Evaluate baseline predictions with BigCode
+            baseline_eval_results = metrics.evaluate_benchmark_performance(
+                baseline_predictions, ground_truths, task_name=self.config.test_dataset, task_docs=self.test_task_docs
+            )
+
+            # Evaluate steered predictions with BigCode
+            steered_eval_results = metrics.evaluate_benchmark_performance(
+                steered_predictions, ground_truths, task_name=self.config.test_dataset, task_docs=self.test_task_docs
+            )
+
+            baseline_details = baseline_eval_results.get("evaluation_details", [])
+            steered_details = steered_eval_results.get("evaluation_details", [])
+
+            for i, (question, correct_answer, baseline_answer, steered_answer) in enumerate(
+                zip(questions, ground_truths, baseline_predictions, steered_predictions)
+            ):
+                # Get correctness from BigCode evaluation
+                is_baseline_correct = baseline_details[i]["is_correct"] if i < len(baseline_details) else False
+                is_correct = steered_details[i]["is_correct"] if i < len(steered_details) else False
+
+                detailed_results.append(
+                    {
+                        "row": i,
+                        "question": question,
+                        "correct_answer": correct_answer,
+                        "baseline_model_answer": baseline_answer,
+                        "model_answer": steered_answer,
+                        "is_baseline_correct": is_baseline_correct,
+                        "is_correct": is_correct,
+                        "evaluation_method": steered_eval_results.get("evaluation_method", "bigcode_execution"),
+                        "baseline_extracted_code": baseline_details[i].get("prediction", baseline_answer)
+                        if i < len(baseline_details)
+                        else baseline_answer,
+                        "steered_extracted_code": steered_details[i].get("prediction", steered_answer)
+                        if i < len(steered_details)
+                        else steered_answer,
+                        "baseline_execution_error": baseline_details[i].get("execution_error")
+                        if i < len(baseline_details)
+                        else None,
+                        "steered_execution_error": steered_details[i].get("execution_error")
+                        if i < len(steered_details)
+                        else None,
+                    }
+                )
+        else:
+            # For non-coding tasks, process each result
+            for i, (question, correct_answer, baseline_answer, steered_answer) in enumerate(
+                zip(questions, ground_truths, baseline_predictions, steered_predictions)
+            ):
+                # Use standard evaluation for both baseline and steered answers
+                is_baseline_correct = metrics.evaluate_response_correctness(
+                    baseline_answer, correct_answer, self.config.test_dataset
+                )
+                is_correct = metrics.evaluate_response_correctness(
+                    steered_answer, correct_answer, self.config.test_dataset
+                )
+
+                result_entry = {
+                    "row": i,
+                    "question": question,
+                    "correct_answer": correct_answer,
+                    "baseline_model_answer": baseline_answer,
+                    "model_answer": steered_answer,
+                    "is_baseline_correct": is_baseline_correct,
+                    "is_correct": is_correct,
+                    "evaluation_method": "string_comparison",
+                }
+
+                # Add MC-specific fields if this is a multiple choice task
+                if self._should_use_multiple_choice_evaluation():
+                    # Extract MC diagnostics directly without custom evaluation
+                    import re
+
+                    # Extract available answers from question (A. choice, B. choice, etc.)
+                    available_answers = []
+                    choice_pattern = r"([A-E])\.\s+(.+?)(?=\n[A-E]\.|$)"
+                    matches = re.findall(choice_pattern, question, re.MULTILINE | re.DOTALL)
+                    for letter, choice_text in matches:
+                        available_answers.append(f"{letter}. {choice_text.strip()}")
+
+                    # Extract steered model's selected letter
+                    steered_selected_letter = "?"
+                    steered_letter_match = re.search(r"\b([A-E])\b", steered_answer.upper())
+                    if steered_letter_match:
+                        steered_selected_letter = steered_letter_match.group(1)
+
+                    # Extract baseline model's selected letter
+                    baseline_selected_letter = "?"
+                    baseline_letter_match = re.search(r"\b([A-E])\b", baseline_answer.upper())
+                    if baseline_letter_match:
+                        baseline_selected_letter = baseline_letter_match.group(1)
+
+                    result_entry["available_answers"] = available_answers
+                    result_entry["correct_choice_letter"] = correct_answer
+                    result_entry["model_selected_letter"] = steered_selected_letter
+                    result_entry["baseline_model_selected_letter"] = baseline_selected_letter
+
+                detailed_results.append(result_entry)
+
+        # Create experiment metadata for test results
+        experiment_metadata = self._create_experiment_metadata(
+            trial=best_trial,
+            steering_method=steering_method or best_params.get("steering_method") if best_params else None,
+            layer_id=layer_id,
+            hyperparams=best_params,
+        )
+
+        # Create final results structure with metadata
+        final_results = {"experiment_metadata": experiment_metadata, "results": detailed_results}
+
+        # Save to JSON file
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        filename = f"test_detailed_results_{timestamp}.json"
+        filepath = self.run_dir / filename
+
+        with open(filepath, "w", encoding="utf-8") as f:
+            json.dump(final_results, f, indent=2, ensure_ascii=False)
+
+        self.logger.info(f"💾 Saved detailed test results to: {filename}")
+        return filename
+
+    def _save_reproducibility_bundle(self, study: optuna.Study, final_results: dict[str, Any]):
+        """Save complete reproducibility bundle."""
+
+        # Save Optuna study
+        study_path = self.run_dir / f"optuna_study_{self.run_timestamp}.db"
+        study.study_name = str(study_path)
+
+        # Save configuration
+        config_path = self.run_dir / f"config_{self.run_timestamp}.json"
+        with open(config_path, "w") as f:
+            json.dump(self.config.to_dict(), f, indent=2)
+
+        # Save final results
+        results_path = self.run_dir / f"final_results_{self.run_timestamp}.json"
+        with open(results_path, "w") as f:
+            json.dump(final_results, f, indent=2, default=str)
+
+        # Save best configuration
+        best_config = {
+            "best_params": study.best_trial.params,
+            "best_value": study.best_trial.value,
+            "model_name": self.config.model_name,
+            "random_seed": self.config.seed,
+            "commit_hash": self._get_git_commit_hash(),
+            "timestamp": self.run_timestamp,
+        }
+
+        best_config_path = self.run_dir / f"best_configuration_{self.run_timestamp}.json"
+        with open(best_config_path, "w") as f:
+            json.dump(best_config, f, indent=2)
+
+        # Save study trials summary
+        trials_df = study.trials_dataframe()
+        trials_path = self.run_dir / f"study_trials_{self.run_timestamp}.csv"
+        trials_df.to_csv(trials_path, index=False)
+
+        self.logger.info(f"💾 Reproducibility bundle saved to: {self.run_dir}")
+        self.logger.info(f"📊 Study database: {study_path}")
+        self.logger.info(f"⚙️ Configuration: {config_path}")
+        self.logger.info(f"🏆 Results: {results_path}")
+        self.logger.info(f"🎯 Best config: {best_config_path}")
+
+        # Log steering vector if it exists (prefer safetensors format)
+        safetensors_path = self.run_dir / "best_steering_vector.safetensors"
+        pt_path = self.run_dir / "best_steering_vector.pt"
+
+        if safetensors_path.exists():
+            self.logger.info(f"🧭 Steering vector: {safetensors_path.name}")
+        elif pt_path.exists():
+            self.logger.info(f"🧭 Steering vector: {pt_path.name}")
+
+    def _get_git_commit_hash(self) -> Optional[str]:
+        """Get current git commit hash for reproducibility."""
+        try:
+            import subprocess
+
+            result = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True)
+            if result.returncode == 0:
+                return result.stdout.strip()
+        except:
+            pass
+        return None
+
+    def evaluate_only(self, best_params: dict[str, Any]) -> dict[str, Any]:
+        """Run evaluation only with provided parameters.
+
+        Args:
+            best_params: Dictionary of hyperparameters to use for evaluation
+
+        Returns:
+            Dictionary containing evaluation results
+        """
+        self.logger.info("🔬 Running evaluation-only mode with provided parameters")
+        self.logger.info(f"Parameters: {best_params}")
+
+        # Setup experiment if not already done
+        if self.model is None:
+            self._setup_experiment()
+
+        # Create timestamped run directory for evaluation-only mode
+        if not hasattr(self, "run_dir"):
+            self.run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            self.run_dir = self.output_dir / f"evaluate_only_{self.run_timestamp}"
+            self.run_dir.mkdir(parents=True, exist_ok=True)
+            self.logger.info(f"📁 Evaluation directory: {self.run_dir}")
+
+        # Create a complete mock trial with all expected parameters
+        from optuna.trial import FixedTrial
+
+        # Ensure we have all required parameters for _final_evaluation
+        complete_params = {
+            "layer_id": best_params.get("layer_id", 15),
+            "probe_type": best_params.get("probe_type", "logistic_regression"),
+            "probe_c": best_params.get("probe_c", 1.0),
+            "steering_method": best_params.get("steering_method", "caa"),
+            "steering_alpha": best_params.get("steering_alpha", 0.5),
+        }
+
+        # Add method-specific parameters if needed
+        if complete_params["steering_method"] == "dac":
+            complete_params.update(
+                {
+                    "entropy_threshold": best_params.get("entropy_threshold", 1.5),
+                    "ptop": best_params.get("ptop", 0.5),
+                    "max_alpha": best_params.get("max_alpha", 2.0),
+                }
+            )
+
+        fixed_trial = FixedTrial(complete_params)
+
+        # Fix FixedTrial params access issue
+        if not hasattr(fixed_trial, "params"):
+            fixed_trial.params = complete_params
+
+        # Run final evaluation
+        return self._final_evaluation(fixed_trial)
+
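`evaluate_only` reuses `_final_evaluation` by wrapping the supplied hyperparameters in an Optuna `FixedTrial`, which replays code written against the trial API with fixed values instead of sampling new ones. A minimal illustration of that mechanism (not package code); note that `FixedTrial.params` is only populated as parameters are actually suggested, which appears to be why `_final_evaluation` also falls back to `_params`:

    import optuna
    from optuna.trial import FixedTrial

    def objective(trial: optuna.Trial) -> float:
        layer_id = trial.suggest_int("layer_id", 0, 31)
        alpha = trial.suggest_float("steering_alpha", 0.0, 2.0)
        return layer_id * alpha  # stand-in for a real evaluation

    # Replays the objective with fixed values instead of sampling them.
    print(objective(FixedTrial({"layer_id": 15, "steering_alpha": 0.5})))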
+    @classmethod
+    def from_saved_study(
+        cls, study_path: str, config_path: Optional[str] = None, override_config: Optional[dict[str, Any]] = None
+    ):
+        """Create pipeline from saved study and optionally saved config.
+
+        Args:
+            study_path: Path to the SQLite study database
+            config_path: Optional path to saved configuration JSON
+            override_config: Optional dict of config values to override
+
+        Returns:
+            Tuple of (pipeline, study) ready for evaluation
+        """
+        # Load config if provided
+        if config_path:
+            with open(config_path) as f:
+                config_dict = json.load(f)
+            # Apply any overrides
+            if override_config:
+                config_dict.update(override_config)
+            config = OptimizationConfig(**config_dict)
+        else:
+            # Create minimal config with overrides
+            config = OptimizationConfig(**(override_config or {}))
+
+        # Load study
+        from pathlib import Path
+
+        study_name = Path(study_path).stem
+        study = optuna.load_study(study_name=study_name, storage=f"sqlite:///{study_path}")
+
+        pipeline = cls(config)
+        return pipeline, study
+
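`from_saved_study` follows standard Optuna persistence: the study is addressed by name inside a SQLite storage URL. A short sketch of the loading side, assuming the study was created under the file stem used here (paths are illustrative):

    import optuna
    from pathlib import Path

    study_path = "runs/optuna_study_20250101_120000.db"  # illustrative path
    study = optuna.load_study(
        study_name=Path(study_path).stem,
        storage=f"sqlite:///{study_path}",
    )
    print(study.best_trial.params, study.best_trial.value)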
+    def evaluate_on_dataset(
+        self, best_params: dict[str, Any], dataset_name: str, dataset_limit: Optional[int] = None
+    ) -> dict[str, Any]:
+        """Evaluate best parameters on a different dataset.
+
+        Args:
+            best_params: Dictionary of hyperparameters to use
+            dataset_name: Name of dataset to evaluate on
+            dataset_limit: Optional limit on number of samples
+
+        Returns:
+            Dictionary containing evaluation results on the new dataset
+        """
+        # Temporarily override dataset configuration
+        original_test_dataset = self.config.test_dataset
+        original_test_limit = self.config.test_limit
+
+        self.config.test_dataset = dataset_name
+        self.config.test_limit = dataset_limit or self.config.test_limit
+
+        self.logger.info(f"📊 Evaluating on {dataset_name} with {self.config.test_limit} samples")
+
+        # Reload test samples for new dataset
+        from . import data_utils
+
+        self.test_samples = data_utils.load_dataset_samples(self.config.test_dataset, self.config.test_limit)
+
+        # Run evaluation
+        results = self.evaluate_only(best_params)
+
+        # Restore original config
+        self.config.test_dataset = original_test_dataset
+        self.config.test_limit = original_test_limit
+
+        return results
+
+    def cleanup_memory(self):
+        """Clean up GPU/MPS memory."""
+        if hasattr(self, "model") and self.model is not None:
+            del self.model
+            self.model = None
+        if hasattr(self, "tokenizer") and self.tokenizer is not None:
+            del self.tokenizer
+            self.tokenizer = None
+
+        # Finish WandB run
+        if self.wandb_run is not None:
+            wandb.finish()
+            self.wandb_run = None
+
+        # Clean up device memory
+        empty_device_cache(self.device.type)
+
+        import gc
+
+        gc.collect()
+
+    def _init_wandb(self):
+        """Initialize WandB for experiment tracking."""
+        try:
+            self.wandb_run = wandb.init(
+                project=self.config.wandb_project,
+                name=f"{self.config.study_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
+                config=self.config.to_dict(),
+                tags=["optuna", "steering", "optimization"],
+                reinit=True,
+            )
+            self.logger.info(f"WandB initialized: {wandb.run.url}")
+        except Exception as e:
+            # Don't silently disable - user explicitly requested WandB
+            raise RuntimeError(
+                f"Failed to initialize WandB: {e}\n"
+                f"Possible solutions:\n"
+                f"1. Run 'wandb login' to authenticate\n"
+                f"2. Check your internet connection\n"
+                f"3. Verify project name: {self.config.wandb_project}\n"
+                f"4. Set use_wandb=False to disable WandB"
+            ) from e
+
+    def _log_trial_to_wandb(self, trial: optuna.Trial, metrics: dict[str, float]):
+        """Log trial results to WandB."""
+        if not self.config.use_wandb or self.wandb_run is None:
+            return
+
+        try:
+            # Log trial parameters and metrics
+            log_data = {f"trial/{k}": v for k, v in trial.params.items()}
+            log_data.update({f"metrics/{k}": v for k, v in metrics.items()})
+            log_data["trial/number"] = trial.number
+
+            wandb.log(log_data)
+        except Exception as e:
+            self.logger.warning(f"Failed to log trial to WandB: {e}")
+
+    def _log_final_results_to_wandb(self, study: optuna.Study, final_results: dict[str, Any]):
+        """Log final optimization results to WandB."""
+        if not self.config.use_wandb or self.wandb_run is None:
+            return
+
+        try:
+            # Log best trial results
+            best_params = {f"best/{k}": v for k, v in study.best_params.items()}
+            best_metrics = {
+                "best/validation_accuracy": study.best_value,
+                "best/baseline_accuracy": final_results["baseline_benchmark_metrics"]["accuracy"],
+                "best/steered_accuracy": final_results["steered_benchmark_metrics"]["accuracy"],
+                "best/accuracy_improvement": final_results["accuracy_improvement"],
+                "study/n_trials": len(study.trials),
+                "study/n_complete_trials": len(
+                    [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
+                ),
+            }
+
+            wandb.log({**best_params, **best_metrics})
+
+            # Log optimization history
+            trial_values = [t.value for t in study.trials if t.value is not None]
+            if trial_values:
+                wandb.log(
+                    {
+                        "optimization/best_value_so_far": max(trial_values),
+                        "optimization/mean_trial_value": np.mean(trial_values),
+                        "optimization/std_trial_value": np.std(trial_values),
+                    }
+                )
+
+        except Exception as e:
+            self.logger.warning(f"Failed to log final results to WandB: {e}")
+
+    def _should_use_multiple_choice_evaluation(self) -> bool:
+        """Determine if we should use multiple choice evaluation for this dataset."""
+        # Use multiple choice evaluation for TruthfulQA and other MC tasks
+        return self.config.test_dataset.lower() in ["truthfulqa_mc1", "truthfulqa", "mmlu"]
+
+
+def main():
+    """Main entry point for optimization pipeline."""
+    # Setup logging
+    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+
+    # Create configuration
+    config = OptimizationConfig(
+        train_limit=100,
+        contrastive_pairs_limit=30,  # Bounded by train_limit
+        val_limit=50,
+        test_limit=50,
+        n_trials=20,
+        layer_search_range=(10, 15),
+    )
+
+    # Run optimization
+    pipeline = OptimizationPipeline(config)
+    try:
+        results = pipeline.run_optimization()
+
+        print("🎉 Optimization completed!")
+        print(f"Best validation score: {results['best_validation_score']:.4f}")
+        print(f"Test accuracy: {results['steered_benchmark_metrics']['accuracy']:.4f}")
+        print(f"Accuracy improvement: {results['accuracy_improvement']:+.4f}")
+
+    finally:
+        # Clean up memory
+        pipeline.cleanup_memory()
+
+
+if __name__ == "__main__":
+    main()
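Taken together, the evaluation-only entry points in this file suggest a workflow along the following lines. This is a sketch under the assumption that the referenced run artifacts exist; the paths are illustrative, and it relies on `OptimizationPipeline` and `OptimizationConfig` defined earlier in the module.

    # Re-evaluate a finished run on another dataset without re-running the search.
    pipeline, study = OptimizationPipeline.from_saved_study(
        study_path="runs/optuna_study_20250101_120000.db",  # illustrative path
        config_path="runs/config_20250101_120000.json",     # illustrative path
    )
    try:
        results = pipeline.evaluate_on_dataset(study.best_trial.params, "truthfulqa_mc1", dataset_limit=50)
        print(f"accuracy improvement: {results['accuracy_improvement']:+.4f}")
    finally:
        pipeline.cleanup_memory()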