wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.2.dist-info/METADATA +67 -0
- wisent-0.5.2.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1111 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Steering optimization module for improving benchmark performance.
|
|
3
|
+
|
|
4
|
+
This module handles training and optimizing different steering methods that can
|
|
5
|
+
improve model performance on benchmarks by steering internal activations.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
import traceback
|
|
10
|
+
from abc import ABC, abstractmethod
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
from tqdm import tqdm
|
|
16
|
+
|
|
17
|
+
from wisent.core.activations.core import ActivationAggregationStrategy
|
|
18
|
+
from wisent.core.classifier.classifier import Classifier
|
|
19
|
+
from wisent.core.contrastive_pairs.contrastive_pair import ContrastivePair
|
|
20
|
+
from wisent.core.contrastive_pairs.contrastive_pair_set import ContrastivePairSet
|
|
21
|
+
from wisent.core.optuna.classifier import (
|
|
22
|
+
CacheConfig,
|
|
23
|
+
ClassifierCache,
|
|
24
|
+
ClassifierOptimizationConfig,
|
|
25
|
+
GenerationConfig,
|
|
26
|
+
OptunaClassifierOptimizer,
|
|
27
|
+
)
|
|
28
|
+
from wisent.core.optuna.steering import data_utils, metrics
|
|
29
|
+
from wisent.core.response import Response
|
|
30
|
+
from wisent.core.steering_methods.dac import DAC
|
|
31
|
+
from wisent.core.task_interface import get_task
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class SteeringMethodConfig(ABC):
    """Base configuration for steering methods.

    The list-valued fields default to ``None`` and are replaced with a fresh
    list in ``__post_init__``, so two instances never share the same default
    list object.
    """

    # Name of the steering method this configuration belongs to.
    method_name: str = "base"
    # Candidate layer indices; ``None`` is normalised to an empty list.
    layers: Optional[List[int]] = None
    # Candidate steering strengths; ``None`` is normalised to ``[1.0]``.
    strengths: Optional[List[float]] = None

    def __post_init__(self):
        """Replace ``None`` list fields with their default values."""
        if self.layers is None:
            self.layers = []
        if self.strengths is None:
            self.strengths = [1.0]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
class DACConfig(SteeringMethodConfig):
    """Configuration for DAC (Dynamic Activation Composition) steering method.

    Each list field holds candidate values for one DAC hyperparameter. A field
    left as ``None`` is replaced with a single-element default list in
    ``__post_init__``.
    """

    method_name: str = "dac"
    # Candidate entropy thresholds; ``None`` is normalised to ``[1.0]``.
    entropy_thresholds: Optional[List[float]] = None
    # Candidate ptop values; ``None`` is normalised to ``[0.4]``.
    ptop_values: Optional[List[float]] = None
    # Candidate max_alpha values; ``None`` is normalised to ``[2.0]``.
    max_alpha_values: Optional[List[float]] = None

    def __post_init__(self):
        """Apply the base-class defaults, then fill in the DAC-specific ones."""
        super().__post_init__()
        if self.entropy_thresholds is None:
            self.entropy_thresholds = [1.0]
        if self.ptop_values is None:
            self.ptop_values = [0.4]
        if self.max_alpha_values is None:
            self.max_alpha_values = [2.0]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@dataclass
class SteeringResult:
    """Results from training and evaluating a steering method configuration.

    The first five fields are required. The remaining three are optional and
    stay ``None`` when the caller does not supply them.
    """

    # Name of the steering method that produced this result.
    method_name: str
    # Index of the model layer the method was applied to.
    layer: int
    # Hyperparameter values used for this run.
    hyperparameters: Dict[str, Any]
    # Benchmark metric name -> value for the steered model.
    benchmark_metrics: Dict[str, float]
    # Whether training of the steering method succeeded.
    training_success: bool
    # Optional extra information reported by training.
    training_stats: Optional[Dict[str, Any]] = None
    # Optional metrics for the unsteered baseline.
    baseline_metrics: Optional[Dict[str, float]] = None
    # Optional comparison between steered and baseline results.
    comparative_metrics: Optional[Dict[str, Any]] = None
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class SteeringMethodTrainer(ABC):
    """Interface every steering-method trainer has to implement.

    A concrete trainer knows how to build a method instance from a set of
    hyperparameters, how to train it on samples, and how to generate
    predictions with the trained method applied to a model.
    """

    @abstractmethod
    def create_method_instance(
        self,
        hyperparams: Dict[str, Any],
        device: str,
    ) -> Any:
        """Build a steering-method object from ``hyperparams`` for ``device``."""

    @abstractmethod
    def train_method(
        self,
        method_instance: Any,
        train_samples: List[Dict],
        layer: int,
        model,
        tokenizer,
        device: str,
        task_name: str = "gsm8k",
        max_new_tokens: int = 200,
    ) -> Tuple[bool, Dict[str, Any]]:
        """Fit ``method_instance`` on ``train_samples`` for the given ``layer``.

        Returns:
            A ``(success, info)`` pair, where ``info`` is a dictionary
            describing the training run.
        """

    @abstractmethod
    def apply_steering_and_evaluate(
        self,
        method_instance: Any,
        evaluation_samples: List[Dict],
        layer: int,
        strength: float,
        model,
        tokenizer,
        device: str,
        batch_size: int,
        max_length: int,
        task_name: str = "gsm8k",
        max_new_tokens: int = 200,
    ) -> Tuple[List[str], List[str]]:
        """Generate predictions for ``evaluation_samples`` with steering applied.

        Returns:
            Two lists of strings; the DAC implementation returns
            ``(predictions, ground_truths)`` in that order.
        """
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class DACTrainer(SteeringMethodTrainer):
    """Trainer for DAC (Dynamic Activation Composition) steering method."""

    def create_method_instance(self, hyperparams: Dict[str, Any], device: str) -> DAC:
        """Create DAC instance with specified hyperparameters.

        Keys missing from ``hyperparams`` fall back to ``entropy_threshold=1.0``,
        ``ptop=0.4`` and ``max_alpha=2.0``. Dynamic control is always enabled.
        """
        return DAC(
            device=device,
            dynamic_control=True,
            entropy_threshold=hyperparams.get("entropy_threshold", 1.0),
            ptop=hyperparams.get("ptop", 0.4),
            max_alpha=hyperparams.get("max_alpha", 2.0),
        )

    def train_method(
        self,
        dac_instance: DAC,
        train_samples: List[Dict],
        layer: int,
        model,
        tokenizer,
        device: str,
        task_name: str = "gsm8k",
        max_new_tokens: int = 200,
    ) -> Tuple[bool, Dict[str, Any]]:
        """Train DAC on training data to create steering vectors.

        Args:
            dac_instance: DAC object to train.
            train_samples: Raw task samples from which contrastive pairs are extracted.
            layer: Index of the model layer used for activations and training.
            model: Language model whose activations are read.
            tokenizer: Tokenizer matching ``model``.
            device: Device the tokenized inputs are moved to.
            task_name: Task whose extractor builds the contrastive pairs.
            max_new_tokens: Accepted for interface compatibility; not used here.

        Returns:
            ``(success, info)``. On any failure ``success`` is ``False`` and
            ``info`` is ``{"error": <message>}``; otherwise ``info`` is the
            dictionary returned by ``dac_instance.train``.
        """
        try:
            # Set model reference for KL computation
            dac_instance.set_model_reference(model)

            # Extract contrastive pairs from training data using task's extractor
            contrastive_pairs = data_utils.get_task_contrastive_pairs(train_samples, task_name)

            if not contrastive_pairs:
                logger.warning(f"No contrastive pairs extracted from {task_name} training data")
                return False, {"error": "No contrastive pairs"}

            # Convert to ContrastivePairSet format
            pair_set = self._create_pair_set_from_extracted_pairs(contrastive_pairs, layer, model, tokenizer, device)

            # Train DAC
            training_result = dac_instance.train(pair_set, layer)

            success = training_result.get("success", False)
            logger.debug(f"DAC training on layer {layer}: {'Success' if success else 'Failed'}")

            return success, training_result

        except Exception as e:
            # Deliberate best-effort: a failed configuration is reported to the
            # caller as an unsuccessful result instead of aborting the whole run.
            logger.error(f"DAC training failed on layer {layer}: {e}")
            return False, {"error": str(e)}

    def apply_steering_and_evaluate(
        self,
        dac_instance: DAC,
        evaluation_samples: List[Dict],
        layer: int,
        strength: float,
        model,
        tokenizer,
        device: str,
        batch_size: int,
        max_length: int,
        task_name: str = "gsm8k",
        max_new_tokens: int = 200,
    ) -> Tuple[List[str], List[str]]:
        """Apply DAC steering and generate predictions using task extractor.

        Samples for which the extractor returns no QA pair are skipped with a
        warning. Generation samples with ``temperature=0.7``, so predictions
        are not deterministic.

        Args:
            dac_instance: Trained DAC object providing ``apply_steering``.
            evaluation_samples: Raw task samples to evaluate on.
            layer: Index of the layer whose output is steered.
            strength: Steering strength passed to ``apply_steering``.
            model: Language model used for generation.
            tokenizer: Tokenizer matching ``model``.
            device: Device the tokenized batch is moved to.
            batch_size: Number of questions generated per batch.
            max_length: Truncation length for the tokenized prompts.
            task_name: Task whose extractor formats questions and answers.
            max_new_tokens: Maximum number of tokens to generate per question.

        Returns:
            ``(predictions, ground_truths)``, aligned by index.

        Raises:
            ValueError: If the model exposes neither ``model.layers`` nor
                ``transformer.h``.
        """
        # Get the task and its extractor
        task = get_task(task_name)
        extractor = task.get_extractor()

        # Pre-extract all questions and answers
        questions = []
        ground_truths = []
        for sample in evaluation_samples:
            qa_pair = extractor.extract_qa_pair(sample, task)
            if not qa_pair:
                logger.warning(f"Skipping sample - extractor couldn't extract QA pair: {sample.keys()}")
                continue
            questions.append(qa_pair["formatted_question"])
            ground_truths.append(qa_pair["correct_answer"])

        layer_module = self._resolve_layer_module(model, layer, "DAC steering")

        predictions = []
        for start in tqdm(range(0, len(questions), batch_size), desc="Generating predictions with steering"):
            batch_questions = questions[start : start + batch_size]

            # Un-padded length of every prompt. The batch below is truncated to
            # max_length, so the lengths are clamped to match; otherwise the
            # "last token" index of a long prompt would fall outside the
            # sequence and the steering assignment would silently do nothing.
            actual_lengths = []
            for question in batch_questions:
                length = tokenizer(question, return_tensors="pt")["input_ids"].shape[1]
                if max_length is not None:
                    length = min(length, max_length)
                actual_lengths.append(length)

            hook = self._make_steering_hook(dac_instance, actual_lengths, strength)
            handle = layer_module.register_forward_hook(hook)

            try:
                # Tokenize batch with padding for generation
                inputs = tokenizer(
                    batch_questions, return_tensors="pt", padding=True, truncation=True, max_length=max_length
                ).to(device)

                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        do_sample=True,
                        temperature=0.7,
                        pad_token_id=tokenizer.eos_token_id,
                        use_cache=False,  # Disable cache to avoid cache_position errors
                    )

                # Decode each response and strip the leading prompt text
                for output, question in zip(outputs, batch_questions):
                    response = tokenizer.decode(output, skip_special_tokens=True)
                    predictions.append(response[len(question) :].strip())

            finally:
                # Always detach the hook so it cannot leak into later batches
                handle.remove()

        return predictions, ground_truths

    @staticmethod
    def _resolve_layer_module(model, layer_index: int, purpose: str):
        """Return the layer module at ``layer_index`` for the supported architectures.

        ``purpose`` only completes the error message raised for an unsupported model.
        """
        if hasattr(model, "model") and hasattr(model.model, "layers"):
            # LLaMA-style models
            return model.model.layers[layer_index]
        if hasattr(model, "transformer") and hasattr(model.transformer, "h"):
            # GPT2-style models
            return model.transformer.h[layer_index]
        raise ValueError(f"Unsupported model architecture for {purpose}")

    @staticmethod
    def _make_steering_hook(dac_instance, actual_lengths: List[int], strength: float):
        """Build a forward hook that steers the last prompt token of each batch row."""

        def steering_hook(module, hook_inputs, output):
            # The layer output is either a tuple whose first item is the hidden
            # states, or the hidden-state tensor itself; support both.
            bare_tensor = isinstance(output, torch.Tensor)
            hidden_states = output if bare_tensor else output[0]

            for row, actual_length in enumerate(actual_lengths):
                if row >= hidden_states.shape[0]:  # Safety check for batch size
                    break
                last = actual_length - 1
                last_token = hidden_states[row : row + 1, last : last + 1, :]
                steered = dac_instance.apply_steering(last_token, strength=strength)
                hidden_states[row : row + 1, last : last + 1, :] = steered

            if bare_tensor:
                return hidden_states
            return (hidden_states,) + output[1:]

        return steering_hook

    def _create_pair_set_from_extracted_pairs(
        self, extracted_pairs: List[Dict], layer_index: int, model, tokenizer, device: str
    ) -> ContrastivePairSet:
        """Convert extracted pairs to ContrastivePairSet format with proper activation extraction.

        Each entry must provide ``question``, ``correct_answer`` and
        ``incorrect_answer``. A pair that fails for any reason is logged and
        skipped, so the returned set may hold fewer pairs than were passed in.
        """
        pair_set = ContrastivePairSet(name="dac_training", task_type="mathematical_reasoning")

        logger.info(f"Creating {len(extracted_pairs)} contrastive pairs for layer {layer_index}")

        for pair_data in tqdm(extracted_pairs, desc="Creating contrastive pairs"):
            try:
                question = pair_data["question"]
                correct_answer = pair_data["correct_answer"]
                incorrect_answer = pair_data["incorrect_answer"]

                # Extract activations for correct and incorrect responses
                correct_activations = self._extract_activations_for_text(
                    f"{question} {correct_answer}", layer_index, model, tokenizer, device
                )
                incorrect_activations = self._extract_activations_for_text(
                    f"{question} {incorrect_answer}", layer_index, model, tokenizer, device
                )

                # Create Response objects
                positive_response = Response(text=correct_answer, activations=correct_activations)
                negative_response = Response(text=incorrect_answer, activations=incorrect_activations)

                # Create ContrastivePair
                contrastive_pair = ContrastivePair(
                    prompt=question, positive_response=positive_response, negative_response=negative_response
                )

                pair_set.pairs.append(contrastive_pair)

            except Exception as e:
                # Best-effort: one malformed pair must not abort the whole set
                logger.warning(f"Failed to create contrastive pair: {e}")
                continue

        logger.info(f"Successfully created ContrastivePairSet with {len(pair_set.pairs)} pairs")
        return pair_set

    def _extract_activations_for_text(self, text: str, layer_index: int, model, tokenizer, device: str) -> torch.Tensor:
        """Extract activations from a specific layer for given text.

        The text is truncated to 128 tokens, run through the model once, and
        the hidden state at the last sequence position of the requested layer
        is returned on the CPU with the batch dimension squeezed out.

        Raises:
            ValueError: If the model architecture is not supported.
            RuntimeError: If the forward pass never invoked the layer hook.
        """
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128).to(device)

        activations = []

        def hook(module, hook_inputs, output):
            # Accept both tuple outputs and a bare hidden-state tensor
            hidden_states = output if isinstance(output, torch.Tensor) else output[0]
            # Extract the last token's activations
            activations.append(hidden_states[:, -1, :].detach().cpu())

        layer_module = self._resolve_layer_module(model, layer_index, "activation extraction")
        handle = layer_module.register_forward_hook(hook)

        try:
            with torch.no_grad():
                model(**inputs)
        finally:
            # Remove the hook even when the forward pass raises; otherwise it
            # would stay attached and fire on every later forward pass.
            handle.remove()

        if not activations:
            raise RuntimeError(f"Forward hook on layer {layer_index} captured no activations")
        return activations[0].squeeze(0)
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
class SteeringOptimizer:
|
|
353
|
+
"""
|
|
354
|
+
Optimizes steering methods for improving benchmark performance.
|
|
355
|
+
|
|
356
|
+
The steering optimization process:
|
|
357
|
+
1. Train steering methods on training data
|
|
358
|
+
2. Evaluate steering performance on validation data using benchmark metrics
|
|
359
|
+
3. Select best configuration based on benchmark performance
|
|
360
|
+
4. Test final steering method on test data
|
|
361
|
+
"""
|
|
362
|
+
|
|
363
|
+
def __init__(self, cache_config: Optional[CacheConfig] = None):
|
|
364
|
+
self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
|
|
365
|
+
self.trainers = {"dac": DACTrainer()}
|
|
366
|
+
|
|
367
|
+
# Initialize classifier cache for reusing trained classifiers
|
|
368
|
+
if cache_config is None:
|
|
369
|
+
cache_config = CacheConfig(cache_dir="./steering_classifier_cache")
|
|
370
|
+
self.classifier_cache = ClassifierCache(cache_config)
|
|
371
|
+
|
|
372
|
+
# Session-level classifier caching for current optimization run
|
|
373
|
+
self._session_classifier = None # Best classifier for current session
|
|
374
|
+
self._session_classifier_metadata = {} # Layer, model_type, performance, etc.
|
|
375
|
+
self._session_cache_key = None # Track current session
|
|
376
|
+
|
|
377
|
+
def register_trainer(self, method_name: str, trainer: SteeringMethodTrainer):
|
|
378
|
+
"""Register a new steering method trainer."""
|
|
379
|
+
self.trainers[method_name] = trainer
|
|
380
|
+
self.logger.info(f"Registered trainer for steering method: {method_name}")
|
|
381
|
+
|
|
382
|
+
def optimize_steering_hyperparameters(
    self,
    config: SteeringMethodConfig,
    classifier_optimization_config: ClassifierOptimizationConfig,
    train_samples: List[Dict],
    validation_samples: List[Dict],
    model,
    tokenizer,
    device: str,
    batch_size: int = 32,
    max_length: int = 512,
    task_name: str = "gsm8k",
    max_new_tokens: int = 200,
) -> Tuple[Dict[str, Any], List[SteeringResult]]:
    """
    Optimize hyperparameters for a steering method using grid search.

    A classifier is loaded or trained once, baseline (unsteered) predictions
    are collected once, and then every (layer, strength, hyperparameters)
    combination is trained and evaluated. The configuration with the highest
    steered accuracy wins.

    Args:
        config: Steering method configuration with hyperparameter ranges
        classifier_optimization_config: Configuration for classifier optimization
        train_samples: Training samples for method training
        validation_samples: Validation samples for evaluation
        model: Language model
        tokenizer: Model tokenizer
        device: Device to run on
        batch_size: Batch size for processing
        max_length: Maximum sequence length
        task_name: Task name for evaluation
        max_new_tokens: Maximum tokens to generate

    Returns:
        Tuple of (best_config, all_results). ``best_config`` is a dict with
        at least "method", "layer", "strength", "benchmark_metrics",
        "baseline_metrics" and "method_instance"; when no configuration
        succeeded it holds default values and ``method_instance`` is None.
        ``all_results`` has one SteeringResult per tried configuration,
        including failed ones.

    Raises:
        ValueError: If no trainer is registered for ``config.method_name`` or
            no classifier could be loaded or trained.
    """
    method_name = config.method_name

    if method_name not in self.trainers:
        raise ValueError(f"No trainer registered for method: {method_name}")

    trainer = self.trainers[method_name]

    # Load best classifier once at the start of optimization
    self.logger.info("Loading/training classifier for evaluation...")
    contrastive_pairs = data_utils.get_task_contrastive_pairs(train_samples, task_name)

    # task_name is forwarded as a fallback: the loader only takes it from the
    # config when the config actually carries one, and raises otherwise.
    classifier = self.load_or_find_best_classifier(
        model=model,
        optimization_config=classifier_optimization_config,
        task_name=task_name,
        contrastive_pairs=contrastive_pairs,
    )

    if classifier is None:
        raise ValueError(
            f"Could not load or train classifier for {classifier_optimization_config.model_name}/{task_name}"
        )

    self.logger.info(f"Using classifier: {self._session_classifier_metadata}")

    # Collect baseline predictions once for all trials
    self.logger.info("Collecting baseline predictions for comparison...")
    baseline_predictions, ground_truths = self.collect_baseline_predictions(
        validation_samples, model, tokenizer, classifier, device, batch_size, max_length, task_name, max_new_tokens
    )

    # Calculate baseline metrics with integrated classifier scoring
    def classifier_scorer(predictions, description):
        return self.score_predictions_with_classifier(
            predictions, model, tokenizer, device, max_length, description
        )

    baseline_benchmark_metrics = metrics.evaluate_benchmark_performance(
        baseline_predictions, ground_truths, task_name, classifier_scorer=classifier_scorer
    )
    self.logger.info(f"Baseline performance: {baseline_benchmark_metrics}")

    def failed_result(layer, strength, hyperparams, training_stats):
        # Every failed trial is recorded the same way, whether training
        # reported failure or an exception was raised, so downstream
        # consumers always find baseline and comparative metrics.
        return SteeringResult(
            method_name=method_name,
            layer=layer,
            hyperparameters={**hyperparams, "strength": strength},
            benchmark_metrics={"accuracy": 0.0},
            baseline_metrics=baseline_benchmark_metrics,
            comparative_metrics={"accuracy_delta": 0.0, "improvement_rate": 0.0},
            training_success=False,
            training_stats=training_stats,
        )

    # Generate all hyperparameter combinations
    hyperparameter_combinations = self._generate_hyperparameter_combinations(config)

    self.logger.info(f"Starting {method_name} optimization with {len(hyperparameter_combinations)} configurations")

    best_config = None
    best_score = -1
    all_results = []

    for i, (layer, strength, hyperparams) in enumerate(
        tqdm(hyperparameter_combinations, desc="Optimizing steering hyperparameters")
    ):
        self.logger.debug(
            f"Testing {method_name} config {i + 1}/{len(hyperparameter_combinations)}: "
            f"layer={layer}, strength={strength}, hyperparams={hyperparams}"
        )

        try:
            # Create method instance
            method_instance = trainer.create_method_instance(hyperparams, device)

            # Train the method
            training_success, training_stats = trainer.train_method(
                method_instance, train_samples, layer, model, tokenizer, device, task_name, max_new_tokens
            )

            if not training_success:
                self.logger.warning(f"Training failed for config {i + 1}")
                all_results.append(failed_result(layer, strength, hyperparams, training_stats))
                continue

            # Evaluate on validation data with steering. The ground truths
            # returned here are not used: scoring below relies on the ones
            # collected with the baseline predictions.
            steered_predictions, _steered_ground_truths = trainer.apply_steering_and_evaluate(
                method_instance,
                validation_samples,
                layer,
                strength,
                model,
                tokenizer,
                device,
                batch_size,
                max_length,
                task_name,
                max_new_tokens,
            )

            # Compare baseline vs steered predictions using enhanced metrics
            enhanced_metrics = self.compare_predictions(
                baseline_predictions,
                steered_predictions,
                ground_truths,
                model,
                tokenizer,
                device,
                max_length,
                task_name,
            )

            # Extract steered metrics for compatibility
            benchmark_metrics = enhanced_metrics["steered"]
            baseline_metrics_for_result = enhanced_metrics["baseline"]
            comparative_metrics = enhanced_metrics["improvement"]

            all_results.append(
                SteeringResult(
                    method_name=method_name,
                    layer=layer,
                    hyperparameters={**hyperparams, "strength": strength},
                    benchmark_metrics=benchmark_metrics,
                    baseline_metrics=baseline_metrics_for_result,
                    comparative_metrics=comparative_metrics,
                    training_success=True,
                    training_stats=training_stats,
                )
            )

            # Selection criterion: steered accuracy itself, not the delta
            steered_accuracy = benchmark_metrics.get("accuracy", 0.0)
            baseline_accuracy = baseline_metrics_for_result.get("accuracy", 0.0)
            improvement_delta = steered_accuracy - baseline_accuracy

            if steered_accuracy > best_score:
                best_score = steered_accuracy
                best_config = {
                    "method": method_name,
                    "layer": layer,
                    "strength": strength,
                    **hyperparams,
                    "benchmark_metrics": benchmark_metrics,
                    "baseline_metrics": baseline_metrics_for_result,
                    "method_instance": method_instance,
                }

            self.logger.debug(
                f"Config {i + 1} - Baseline: {baseline_accuracy:.3f}, "
                f"Steered: {steered_accuracy:.3f}, Delta: {improvement_delta:+.3f}"
            )

        except Exception as e:
            # Best effort: one broken configuration must not abort the search.
            self.logger.error(f"Failed to evaluate config {i + 1}: {e}")
            all_results.append(failed_result(layer, strength, hyperparams, {"error": str(e)}))
            continue

    if best_config is None:
        self.logger.warning("No successful steering configuration found")
        # Return a default configuration
        best_config = {
            "method": method_name,
            "layer": config.layers[0] if config.layers else 0,
            "strength": config.strengths[0] if config.strengths else 1.0,
            "benchmark_metrics": {"accuracy": 0.0},
            "baseline_metrics": baseline_benchmark_metrics,
            "method_instance": None,
        }
    else:
        steered_acc = best_config["benchmark_metrics"]["accuracy"]
        baseline_acc = best_config.get("baseline_metrics", {}).get("accuracy", 0.0)
        improvement = steered_acc - baseline_acc

        self.logger.info(
            f"Best {method_name} config (optimized for steered accuracy): "
            f"layer={best_config['layer']}, steered={steered_acc:.3f} "
            f"(baseline={baseline_acc:.3f}, Δ={improvement:+.3f})"
        )

    return best_config, all_results
|
|
594
|
+
|
|
595
|
+
def _generate_hyperparameter_combinations(
    self, config: SteeringMethodConfig
) -> List[Tuple[int, float, Dict[str, Any]]]:
    """Build the grid of (layer, strength, hyperparams) tuples to try.

    For a DACConfig the grid is the product of layers, strengths, entropy
    thresholds, ptop values and max-alpha values; for any other config it is
    layers x strengths with an empty hyperparameter dict.
    """
    if not isinstance(config, DACConfig):
        # Generic handling for other steering methods
        return [(layer, strength, {}) for layer in config.layers for strength in config.strengths]

    # DAC: the last-listed loop varies fastest, matching nested-loop order
    return [
        (
            layer,
            strength,
            {
                "entropy_threshold": entropy_threshold,
                "ptop": ptop,
                "max_alpha": max_alpha,
            },
        )
        for layer in config.layers
        for strength in config.strengths
        for entropy_threshold in config.entropy_thresholds
        for ptop in config.ptop_values
        for max_alpha in config.max_alpha_values
    ]
|
|
621
|
+
|
|
622
|
+
def collect_baseline_predictions(
    self,
    evaluation_samples: List[Dict],
    model,
    tokenizer,
    classifier: Classifier,
    device: str,
    batch_size: int,
    max_length: int,
    task_name: str,
    max_new_tokens: int = 200,
) -> Tuple[List[str], List[str]]:
    """
    Collect unsteered model predictions for baseline comparison.

    Samples the task extractor cannot turn into a QA pair are skipped (with a
    warning), so the returned lists can be shorter than ``evaluation_samples``.

    Args:
        evaluation_samples: Samples to evaluate
        model: Language model
        tokenizer: Model tokenizer
        classifier: Trained classifier for evaluation (not used by this
            method; kept for interface compatibility)
        device: Device to run on
        batch_size: Batch size for processing
        max_length: Maximum sequence length
        task_name: Task name for evaluation
        max_new_tokens: Maximum tokens to generate

    Returns:
        Tuple of (predictions, ground_truths)
    """
    # Get the task and its extractor
    task = get_task(task_name)
    extractor = task.get_extractor()

    # Pre-extract all questions and answers
    questions = []
    ground_truths = []
    for sample in evaluation_samples:
        qa_pair = extractor.extract_qa_pair(sample, task)
        if not qa_pair:
            self.logger.warning(f"Skipping sample - extractor couldn't extract QA pair: {sample.keys()}")
            continue
        questions.append(qa_pair["formatted_question"])
        ground_truths.append(qa_pair["correct_answer"])

    predictions = []

    # Process in batches without steering
    for start in tqdm(range(0, len(questions), batch_size), desc="Generating baseline predictions"):
        batch_questions = questions[start : start + batch_size]

        # Tokenize batch with padding for generation
        inputs = tokenizer(
            batch_questions, return_tensors="pt", padding=True, truncation=True, max_length=max_length
        ).to(device)

        # Every row of the batch shares this (padded) prompt length, and the
        # generated tokens are appended after it.
        # NOTE(review): assumes a decoder-only model whose generate() output
        # starts with the prompt tokens - confirm for other architectures.
        prompt_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=False,  # Disable cache to avoid cache_position errors
            )

        # Decode only the newly generated tokens. Cutting the decoded string
        # by len(question) breaks when the prompt was truncated to max_length
        # or when decoding does not reproduce the prompt text exactly.
        for output in outputs:
            prediction = tokenizer.decode(output[prompt_length:], skip_special_tokens=True).strip()
            predictions.append(prediction)

    return predictions, ground_truths
|
|
700
|
+
|
|
701
|
+
def _extract_activation_for_text(
    self,
    text: str,
    layer_index: int,
    aggregation_strategy: str,
    model,
    tokenizer,
    device: str,
    max_length: int = 512,
) -> torch.Tensor:
    """
    Extract activation from text at specified layer with aggregation.

    A forward hook captures the output of the selected layer during a single
    no-grad forward pass; the captured tensor is then reduced over dimension 1.

    Args:
        text: Input text to extract activation from
        layer_index: Layer index to extract from
        aggregation_strategy: Aggregation strategy, either as a string
            (e.g., "mean_pooling") or as an enum member whose ``.value`` is
            such a string
        model: Language model (must expose ``model.layers`` or ``transformer.h``)
        tokenizer: Model tokenizer
        device: Device to run on
        max_length: Maximum sequence length

    Returns:
        Aggregated activation tensor with the leading dimension squeezed out

    Raises:
        ValueError: If the model architecture is not supported or the hook
            captured nothing.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
    captured = []

    def hook(module, hook_inputs, output):
        # Layers may return a tuple; the hidden states are its first element
        hidden_states = output[0] if isinstance(output, tuple) else output
        captured.append(hidden_states.detach().cpu())

    # Handle different model architectures
    if hasattr(model, "model") and hasattr(model.model, "layers"):
        # LLaMA-style models
        layer_module = model.model.layers[layer_index]
    elif hasattr(model, "transformer") and hasattr(model.transformer, "h"):
        # GPT2-style models
        layer_module = model.transformer.h[layer_index]
    else:
        raise ValueError("Unsupported model architecture for activation extraction")

    # Register hook and run forward pass; always detach the hook afterwards
    handle = layer_module.register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(**inputs)
    finally:
        handle.remove()

    if not captured:
        raise ValueError("No activations extracted")

    # Expected layout is [1, seq_len, hidden_dim]
    activation_tensor = captured[0]

    # Accept enum members as well as plain strings: without this, a member
    # such as LAST_TOKEN matches no branch and silently becomes mean pooling.
    strategy = getattr(aggregation_strategy, "value", aggregation_strategy)

    known_strategies = ("mean_pooling", "last_token", "first_token", "max_pooling")
    if strategy not in known_strategies:
        # Fall back to the enum's own values in case they differ from the
        # literal names above.
        enum_aliases = (
            (ActivationAggregationStrategy.MEAN_POOLING.value, "mean_pooling"),
            (ActivationAggregationStrategy.LAST_TOKEN.value, "last_token"),
            (ActivationAggregationStrategy.FIRST_TOKEN.value, "first_token"),
            (ActivationAggregationStrategy.MAX_POOLING.value, "max_pooling"),
        )
        strategy = next((name for value, name in enum_aliases if strategy == value), strategy)

    # Apply aggregation strategy
    if strategy == "mean_pooling":
        aggregated = torch.mean(activation_tensor, dim=1)  # [1, hidden_dim]
    elif strategy == "last_token":
        aggregated = activation_tensor[:, -1, :]  # [1, hidden_dim]
    elif strategy == "first_token":
        aggregated = activation_tensor[:, 0, :]  # [1, hidden_dim]
    elif strategy == "max_pooling":
        aggregated = torch.max(activation_tensor, dim=1)[0]  # [1, hidden_dim]
    else:
        # Default to mean pooling if unknown
        self.logger.warning(f"Unknown aggregation strategy {aggregation_strategy}, using mean pooling")
        aggregated = torch.mean(activation_tensor, dim=1)

    return aggregated.squeeze(0)  # Return [hidden_dim] tensor
|
|
785
|
+
|
|
786
|
+
def score_predictions_with_classifier(
    self,
    predictions: List[str],
    model,
    tokenizer,
    device: str,
    max_length: int = 512,
    description: str = "predictions",
) -> List[float]:
    """
    Score predictions using the cached session classifier.

    Each prediction's activation is extracted at the classifier's layer, the
    activations are stacked per batch, and the classifier's ``predict_proba``
    supplies the score. The returned list is aligned with ``predictions``:
    position ``k`` always holds the score of ``predictions[k]``, and any
    prediction that could not be scored keeps the neutral score 0.5.

    Args:
        predictions: Text predictions to score
        model: Language model for activation extraction
        tokenizer: Model tokenizer
        device: Device to run on
        max_length: Maximum sequence length
        description: Description for logging

    Returns:
        List of classifier scores/probabilities for each prediction
    """
    if self._session_classifier is None:
        self.logger.warning("No cached classifier available for scoring")
        return [0.5] * len(predictions)  # Return neutral scores

    if not predictions:
        self.logger.debug("No predictions to score")
        return []

    # Get classifier metadata
    layer = self._session_classifier_metadata.get("layer", 12)
    aggregation = self._session_classifier_metadata.get("aggregation", "mean_pooling")

    self.logger.info(
        f"Scoring {len(predictions)} {description} with cached classifier (layer={layer}, aggregation={aggregation})"
    )

    # Start from neutral scores and overwrite by index. Appending failures
    # and successes at different times would shuffle scores within a batch.
    neutral_score = 0.5
    confidence_scores = [neutral_score] * len(predictions)

    batch_size = 8  # Smaller batch size to avoid OOM
    for start in range(0, len(predictions), batch_size):
        scored_indices = []
        batch_activations = []

        # Extract activations for each prediction in the batch
        for offset, pred_text in enumerate(predictions[start : start + batch_size]):
            try:
                activation = self._extract_activation_for_text(
                    text=pred_text,
                    layer_index=layer,
                    aggregation_strategy=aggregation,
                    model=model,
                    tokenizer=tokenizer,
                    device=device,
                    max_length=max_length,
                )
            except Exception as e:
                # Best effort: this prediction keeps its neutral score
                self.logger.debug(f"Failed to extract activation for prediction: {e}")
                continue
            scored_indices.append(start + offset)
            batch_activations.append(activation)

        if not batch_activations:
            continue

        try:
            # Stack activations and convert to numpy for the classifier
            batch_numpy = torch.stack(batch_activations).detach().cpu().numpy()

            # Get prediction probabilities from classifier
            probabilities = self._session_classifier.predict_proba(batch_numpy)

            # Probability of the positive class (column 1) when there are
            # two or more columns, otherwise the single available column
            positive_column = 1 if probabilities.shape[1] > 1 else 0
            batch_scores = probabilities[:, positive_column].tolist()
        except Exception as e:
            # Best effort: the whole batch keeps its neutral scores
            self.logger.warning(f"Failed to score batch of activations: {e}")
            continue

        for index, score in zip(scored_indices, batch_scores):
            confidence_scores[index] = score

    # Log statistics
    avg_score = sum(confidence_scores) / len(confidence_scores) if confidence_scores else 0.5
    self.logger.debug(
        f"Generated {len(confidence_scores)} classifier confidence scores for {description} (avg={avg_score:.3f})"
    )

    return confidence_scores
|
|
896
|
+
|
|
897
|
+
def compare_predictions(
    self,
    baseline_predictions: List[str],
    steered_predictions: List[str],
    ground_truths: List[str],
    model,
    tokenizer,
    device: str,
    max_length: int = 512,
    task_name: str = "gsm8k",
) -> Dict[str, Any]:
    """
    Evaluate baseline and steered predictions side by side.

    Both prediction sets are scored against the same ground truths with the
    task's benchmark metrics, with the session classifier plugged in as the
    confidence scorer.

    Args:
        baseline_predictions: Unsteered model predictions
        steered_predictions: Steered model predictions
        ground_truths: Ground truth answers
        model: Language model for classifier scoring
        tokenizer: Model tokenizer
        device: Device to run on
        max_length: Maximum sequence length
        task_name: Task name for evaluation metrics

    Returns:
        Dict with three sections: "baseline" and "steered" (accuracy, f1,
        per-item classifier scores, their average, and the predictions) and
        "improvement" (steered minus baseline deltas).
    """

    def classifier_scorer(predictions, description):
        # Adapter handed to the metrics module so it can request confidences
        return self.score_predictions_with_classifier(
            predictions, model, tokenizer, device, max_length, description
        )

    def confidences(evaluation):
        # Per-item classifier confidence, neutral (0.5) where it is missing
        details = evaluation.get("evaluation_details", [])
        return [item.get("classifier_confidence", 0.5) for item in details]

    def mean_or_zero(values):
        if not values:
            return 0.0
        return sum(values) / len(values)

    # Standard benchmark metrics with integrated classifier confidence scores
    baseline_metrics = metrics.evaluate_benchmark_performance(
        baseline_predictions, ground_truths, task_name, classifier_scorer=classifier_scorer
    )
    steered_metrics = metrics.evaluate_benchmark_performance(
        steered_predictions, ground_truths, task_name, classifier_scorer=classifier_scorer
    )

    baseline_scores = confidences(baseline_metrics)
    steered_scores = confidences(steered_metrics)

    avg_baseline_score = mean_or_zero(baseline_scores)
    avg_steered_score = mean_or_zero(steered_scores)

    baseline_section = {
        "accuracy": baseline_metrics.get("accuracy", 0.0),
        "f1": baseline_metrics.get("f1", 0.0),
        "classifier_scores": baseline_scores,
        "avg_classifier_score": avg_baseline_score,
        "predictions": baseline_predictions,
    }
    steered_section = {
        "accuracy": steered_metrics.get("accuracy", 0.0),
        "f1": steered_metrics.get("f1", 0.0),
        "classifier_scores": steered_scores,
        "avg_classifier_score": avg_steered_score,
        "predictions": steered_predictions,
    }
    improvement_section = {
        "accuracy_delta": steered_metrics.get("accuracy", 0) - baseline_metrics.get("accuracy", 0),
        "f1_delta": steered_metrics.get("f1", 0) - baseline_metrics.get("f1", 0),
        "classifier_score_delta": avg_steered_score - avg_baseline_score,
    }

    return {
        "baseline": baseline_section,
        "steered": steered_section,
        "improvement": improvement_section,
    }
|
|
975
|
+
|
|
976
|
+
def load_or_find_best_classifier(
    self,
    model,
    optimization_config: Optional[ClassifierOptimizationConfig] = None,
    model_name: Optional[str] = None,
    task_name: Optional[str] = None,
    contrastive_pairs: Optional[List] = None,
    force_reoptimize: bool = False,
) -> Optional[Classifier]:
    """
    Load or train the best classifier for current steering session.

    On first call: Run full classifier optimization and cache result for session
    On subsequent calls: Return cached classifier from current session

    The session is identified by ``f"{model_name}_{task_name}"``.

    Args:
        model: Language model (wisent_guard Model wrapper)
        optimization_config: Primary configuration source; its ``model_name``
            and ``task_name`` take precedence when they are set
        model_name: Fallback model name if optimization_config does not supply one
        task_name: Fallback task name if optimization_config does not supply one
        contrastive_pairs: Training data for classifier optimization
        force_reoptimize: Force reoptimization even if session classifier exists

    Returns:
        Best trained classifier or None if optimization failed

    Raises:
        ValueError: If neither the config nor the arguments provide both a
            model name and a task name.
    """
    # Extract configuration. Unset (None/empty) config values must not
    # clobber the explicit fallback arguments.
    limit = 100  # Default data limit
    if optimization_config is not None:
        model_name = optimization_config.model_name or model_name
        task_name = getattr(optimization_config, "task_name", None) or task_name
        limit = getattr(optimization_config, "data_limit", 100)

    if not model_name or not task_name:
        raise ValueError("model_name and task_name must be provided either via optimization_config or directly")

    # Create session cache key
    session_cache_key = f"{model_name}_{task_name}"

    # Check if we already have a classifier for this session
    if (
        not force_reoptimize
        and self._session_classifier is not None
        and self._session_cache_key == session_cache_key
    ):
        self.logger.info("Using cached classifier from current session")
        return self._session_classifier

    # First call or forced reoptimization - run classifier optimization
    self.logger.info("Running classifier optimization (first trial in session)")

    if not contrastive_pairs:
        self.logger.error("contrastive_pairs required for classifier optimization")
        return None

    try:
        # Create configuration for classifier optimization if not provided
        if optimization_config is None:
            optimization_config = ClassifierOptimizationConfig(
                model_name=model_name,
                device="auto",
                n_trials=20,  # Reasonable number for steering optimization
                model_types=["logistic", "mlp"],
                primary_metric="f1",
            )

        # Create generation config for activation pre-generation
        generation_config = GenerationConfig(
            layer_search_range=(0, 23),  # Will be auto-detected from model
            aggregation_methods=[
                ActivationAggregationStrategy.MEAN_POOLING,
                ActivationAggregationStrategy.LAST_TOKEN,
                ActivationAggregationStrategy.FIRST_TOKEN,
                ActivationAggregationStrategy.MAX_POOLING,
            ],
            cache_dir="./cache/steering_activations",
            device=optimization_config.device,
            batch_size=32,
        )

        # Create classifier optimizer
        classifier_optimizer = OptunaClassifierOptimizer(
            optimization_config=optimization_config,
            generation_config=generation_config,
            cache_config=self.classifier_cache.config,
        )

        # Run classifier optimization
        self.logger.info(f"Optimizing classifier for {model_name}/{task_name} with {len(contrastive_pairs)} pairs")
        result = classifier_optimizer.optimize(
            model=model,
            contrastive_pairs=contrastive_pairs,
            task_name=task_name,
            model_name=model_name,
            limit=limit,
        )

        if result.best_value > 0:
            # Get the best configuration and classifier
            best_config = result.get_best_config()
            best_classifier = result.best_classifier

            # Cache for current session
            self._session_classifier = best_classifier
            self._session_classifier_metadata = {
                "layer": best_config["layer"],
                "aggregation": best_config["aggregation"],
                "model_type": best_config["model_type"],
                "threshold": best_config["threshold"],
                "f1_score": result.best_value,
                "hyperparameters": best_config.get("hyperparameters", {}),
            }
            self._session_cache_key = session_cache_key

            self.logger.info(
                f"Cached best classifier for session: layer_{best_config['layer']} "
                f"{best_config['model_type']} (F1: {result.best_value:.3f})"
            )

            return best_classifier

        self.logger.warning("Classifier optimization failed - no successful trials")
        return None

    except Exception as e:
        # Best effort: report the failure (with traceback, through logging)
        # and let the caller decide what a missing classifier means.
        self.logger.exception(f"Failed to run classifier optimization: {e}")
        return None
|
|
1104
|
+
|
|
1105
|
+
def get_cache_info(self) -> Dict[str, Any]:
    """Return whatever the classifier cache reports about its contents."""
    cache = self.classifier_cache
    return cache.get_cache_info()
|
|
1108
|
+
|
|
1109
|
+
def clear_classifier_cache(self, keep_recent_hours: float = 24.0) -> int:
    """Clear old cached classifiers.

    Delegates to the classifier cache's ``clear_cache`` and returns its
    result (presumably the number of entries removed - confirm against the
    cache implementation).
    """
    cache = self.classifier_cache
    return cache.clear_cache(keep_recent_hours=keep_recent_hours)
|