wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.1.dist-info/METADATA +67 -0
- wisent-0.5.1.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
wisent/core/multi_steering.py
@@ -0,0 +1,316 @@
+"""Multi-steering functionality for combining multiple steering vectors."""
+
+import sys
+import torch
+from typing import List, Tuple, Optional, Dict, Any
+from pathlib import Path
+
+from .layer import Layer
+from .model import Model
+from .steering_methods.caa import CAA
+from .steering_methods.dac import DAC
+from .utils.device import resolve_default_device
+
+
+class MultiSteeringError(Exception):
+    """Exception raised for multi-steering errors."""
+    pass
+
+
+class MultiSteering:
+    """Handles multi-steering vector combination and application."""
+
+    def __init__(self, device: str | None = None, method: str = "CAA"):
+        """Initialize multi-steering handler.
+
+        Args:
+            device: Device to use for computations (cpu/cuda/mps)
+            method: Steering method to use for combination ("CAA" or "DAC")
+        """
+        self.device = device or resolve_default_device()
+        self.method = method
+        self.loaded_vectors = []
+        self.weights = []
+        self.combined_vector = None
+        self.layer = None
+
+    def load_vectors(self, vector_specs: List[str]) -> None:
+        """Load and validate steering vectors from file paths.
+
+        Args:
+            vector_specs: List of "path:weight" specifications
+
+        Raises:
+            MultiSteeringError: If vectors cannot be loaded or are incompatible
+        """
+        if not vector_specs:
+            raise MultiSteeringError("No vectors specified")
+
+        self.loaded_vectors = []
+        self.weights = []
+        layers_found = set()
+
+        for spec in vector_specs:
+            parts = spec.split(":")
+            if len(parts) != 2:
+                raise MultiSteeringError(f"Invalid vector specification: {spec}. Expected format: path:weight")
+
+            vector_path = parts[0]
+            try:
+                weight = float(parts[1])
+            except ValueError:
+                raise MultiSteeringError(f"Invalid weight in {spec}. Must be a number.")
+
+            if not Path(vector_path).exists():
+                raise MultiSteeringError(f"Vector file not found: {vector_path}")
+
+            print(f"Loading vector from {vector_path} with weight {weight}")
+
+            try:
+                vector_data = torch.load(vector_path, map_location=self.device)
+            except Exception as e:
+                raise MultiSteeringError(f"Failed to load vector from {vector_path}: {e}")
+
+            # Extract metadata from loaded vector
+            if isinstance(vector_data, dict):
+                layer = vector_data.get("layer_index", None)
+                steering_vector = vector_data.get("steering_vector", None)
+
+                if steering_vector is None:
+                    raise MultiSteeringError(f"No steering vector found in {vector_path}")
+
+                if layer is not None:
+                    layers_found.add(layer)
+
+                self.loaded_vectors.append(vector_data)
+                self.weights.append(weight)
+
+                print(f"  Loaded vector from layer {layer}")
+            else:
+                raise MultiSteeringError(f"Invalid vector format in {vector_path}")
+
+        # Validate compatibility
+        if len(layers_found) > 1:
+            raise MultiSteeringError(f"Vectors from different layers cannot be combined: {layers_found}")
+
+        if not layers_found:
+            raise MultiSteeringError("No layer information found in vectors")
+
+        self.layer = Layer(list(layers_found)[0])
+
+        print(f"\nUsing {self.method} method for vector combination")
+        print(f"Target layer: {self.layer.index}")
+
+    def combine_vectors(self, normalize: bool = True) -> torch.Tensor:
+        """Combine loaded vectors using appropriate method.
+
+        Args:
+            normalize: Whether to normalize the combined vector
+
+        Returns:
+            Combined steering vector as tensor
+
+        Raises:
+            MultiSteeringError: If combination fails
+        """
+        if not self.loaded_vectors:
+            raise MultiSteeringError("No vectors loaded")
+
+        print(f"\nCombining {len(self.loaded_vectors)} vectors using {self.method}")
+
+        if self.method == "CAA":
+            # Create a CAA instance and use its proper combination method
+            caa = CAA(device=self.device)
+
+            # Set up behavior vectors dictionary
+            caa.behavior_vectors = {}
+            for i, (vector_data, weight) in enumerate(zip(self.loaded_vectors, self.weights)):
+                steering_vector = vector_data["steering_vector"]
+
+                if not isinstance(steering_vector, torch.Tensor):
+                    steering_vector = torch.tensor(steering_vector, device=self.device)
+                else:
+                    steering_vector = steering_vector.to(self.device)
+
+                # Store with unique names
+                behavior_name = f"vector_{i}"
+                caa.behavior_vectors[behavior_name] = steering_vector
+
+            # Create weights dictionary
+            behavior_weights = {f"vector_{i}": weight for i, weight in enumerate(self.weights)}
+
+            # Use CAA's combine_behaviors method with normalization
+            self.combined_vector = caa.combine_behaviors(behavior_weights, normalize_result=normalize)
+
+        else:  # DAC or mixed methods
+            # For DAC, use its combine_steering_vectors method
+            vectors = []
+            for vector_data in self.loaded_vectors:
+                steering_vector = vector_data["steering_vector"]
+
+                if not isinstance(steering_vector, torch.Tensor):
+                    steering_vector = torch.tensor(steering_vector, device=self.device)
+                else:
+                    steering_vector = steering_vector.to(self.device)
+
+                vectors.append(steering_vector)
+
+            # Use DAC's static method for combination
+            self.combined_vector = DAC.combine_steering_vectors(
+                vectors, self.weights, normalize_weights=normalize
+            )
+
+        print(f"  Combined vector shape: {self.combined_vector.shape}")
+        print(f"  Combined vector norm: {torch.norm(self.combined_vector).item():.4f}")
+
+        return self.combined_vector
+
+    def apply_steering(self, model: Model, prompt: str, max_new_tokens: int = 100,
+                       temperature: float = 0.7, top_p: float = 0.9) -> str:
+        """Apply the combined steering vector to generate text.
+
+        Args:
+            model: Model to use for generation
+            prompt: Input prompt
+            max_new_tokens: Maximum tokens to generate
+            temperature: Sampling temperature
+            top_p: Top-p sampling parameter
+
+        Returns:
+            Generated text
+
+        Raises:
+            MultiSteeringError: If steering fails
+        """
+        if self.combined_vector is None:
+            raise MultiSteeringError("No combined vector available. Call combine_vectors() first.")
+
+        if self.layer is None:
+            raise MultiSteeringError("No layer information available")
+
+        print(f"\nApplying combined steering vector at layer {self.layer.index}")
+        print(f"Prompt: {prompt}")
+        print("=" * 50)
+
+        # Create appropriate steering method instance
+        if self.method == "CAA":
+            steering_method = CAA(device=self.device)
+            steering_method.steering_vector = self.combined_vector
+            steering_method.layer_index = self.layer.index
+            steering_method.is_trained = True
+        else:
+            # Use DAC for other methods
+            steering_method = DAC(device=self.device)
+            steering_method.steering_vector = self.combined_vector
+            steering_method.layer_index = self.layer.index
+            steering_method.is_trained = True
+
+        # Set up steering hook
+        hooks = []
+
+        def steering_hook(module, input, output):
+            if isinstance(output, tuple):
+                hidden_states = output[0]
+            else:
+                hidden_states = output
+
+            # Apply steering using the method's apply_steering
+            steered = steering_method.apply_steering(hidden_states, strength=1.0)
+
+            if isinstance(output, tuple):
+                return (steered,) + output[1:]
+            return steered
+
+        # Find the target layer module
+        if hasattr(model.hf_model, "model") and hasattr(model.hf_model.model, "layers"):
+            layer_module = model.hf_model.model.layers[self.layer.index]
+        elif hasattr(model.hf_model, "transformer") and hasattr(model.hf_model.transformer, "h"):
+            layer_module = model.hf_model.transformer.h[self.layer.index]
+        else:
+            raise MultiSteeringError("Could not find model layers")
+
+        # Register hook
+        handle = layer_module.register_forward_hook(steering_hook)
+        hooks.append(handle)
+
+        try:
+            # Generate with steering
+            output, _ = model.generate(
+                prompt=prompt,
+                layer_index=self.layer.index,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                top_p=top_p,
+            )
+
+            return output
+
+        except Exception as e:
+            raise MultiSteeringError(f"Failed to apply steering: {e}")
+        finally:
+            # Clean up hooks
+            for hook in hooks:
+                hook.remove()
+
+
+def run_multi_steer(
+    vector_specs: List[str],
+    model_name: str,
+    prompt: str,
+    method: str = "CAA",
+    layer: Optional[int] = None,
+    max_new_tokens: int = 100,
+    temperature: float = 0.7,
+    top_p: float = 0.9,
+    device: str | None = None,
+    verbose: bool = True,
+) -> str:
+    """Convenience function to run multi-steering.
+
+    Args:
+        vector_specs: List of "path:weight" specifications
+        model_name: Name of model to load
+        prompt: Input prompt
+        method: Steering method to use ("CAA" or "DAC")
+        layer: Target layer (will be inferred from vectors if not specified)
+        max_new_tokens: Maximum tokens to generate
+        temperature: Sampling temperature
+        top_p: Top-p sampling parameter
+        device: Device to use
+        verbose: Whether to print progress
+
+    Returns:
+        Generated text
+    """
+    # Initialize model
+    if verbose:
+        print(f"\nLoading model: {model_name}")
+
+    chosen_device = device or resolve_default_device()
+    model = Model(model_name, device=chosen_device)
+
+    # Initialize multi-steering with specified method
+    multi_steer = MultiSteering(device=chosen_device, method=method)
+
+    # Load vectors
+    multi_steer.load_vectors(vector_specs)
+
+    # Override layer if specified
+    if layer is not None:
+        multi_steer.layer = Layer(layer)
+        if verbose:
+            print(f"Overriding layer to: {layer}")
+
+    # Combine vectors with normalization
+    multi_steer.combine_vectors(normalize=True)
+
+    # Apply steering
+    output = multi_steer.apply_steering(
+        model=model,
+        prompt=prompt,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p
+    )
+
+    return output
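The run_multi_steer helper above is the one-call entry point for the new multi-steering API added in 0.5.1. The sketch below is based only on the signatures shown in this hunk; the vector paths, weights, and model name are hypothetical placeholders, and each .pt file is assumed to contain the {"layer_index": ..., "steering_vector": ...} dictionary that load_vectors() expects.

# Hypothetical example: combine two saved steering vectors and generate with them.
# File paths, weights, and the model name are placeholders, not shipped artifacts.
from wisent.core.multi_steering import MultiSteeringError, run_multi_steer

try:
    completion = run_multi_steer(
        vector_specs=[
            "vectors/helpful.pt:1.0",   # "path:weight", parsed by MultiSteering.load_vectors()
            "vectors/concise.pt:0.5",
        ],
        model_name="meta-llama/Llama-3.2-1B",  # any causal LM accepted by wisent's Model wrapper
        prompt="Explain contrastive activation steering in two sentences.",
        method="CAA",            # "CAA" combines via CAA.combine_behaviors; anything else falls back to DAC
        max_new_tokens=100,
        temperature=0.7,
        top_p=0.9,
    )
    print(completion)
except MultiSteeringError as err:
    # Raised for malformed specs, missing files, mixed layers, or generation failures.
    print(f"multi-steering failed: {err}")

Note that load_vectors() requires every vector to come from the same layer; the layer argument only overrides the target layer after loading.
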
wisent/core/optuna/__init__.py
@@ -0,0 +1,57 @@
+"""
+Optuna-based Optimization Framework for Wisent Guard
+
+This module provides Optuna-based hyperparameter optimization for both steering and classifier systems:
+
+STEERING OPTIMIZATION:
+1. Hyperparameter Optimization: Optuna-driven search for best steering parameters
+2. Evaluation Pipeline: Comprehensive evaluation on multiple datasets
+3. Reproducibility: Complete experiment tracking and reproduction
+
+CLASSIFIER OPTIMIZATION:
+1. Activation Pre-generation: Efficient caching of model activations
+2. Model Training: Optimized logistic regression and MLP classifiers
+3. Intelligent Caching: Avoid retraining identical configurations
+4. Cross-validation: Robust performance evaluation
+
+Key components:
+- Steering: OptimizationPipeline, OptimizationConfig, metrics
+- Classifier: OptunaClassifierOptimizer, GenerationConfig, CacheConfig
+"""
+
+# Steering optimization components
+# Classifier optimization components
+from wisent_guard.core.optuna.classifier import (
+    ActivationGenerator,
+    CacheConfig,
+    ClassifierCache,
+    ClassifierOptimizationConfig as ClassifierOptimizationConfig,
+    GenerationConfig,
+    OptimizationResult,
+    OptunaClassifierOptimizer,
+)
+from wisent_guard.core.optuna.steering.metrics import (
+    calculate_comprehensive_metrics,
+    evaluate_benchmark_performance,
+    evaluate_probe_performance,
+    generate_performance_summary,
+)
+from wisent_guard.core.optuna.steering.optuna_pipeline import OptimizationConfig, OptimizationPipeline
+
+__all__ = [
+    # Steering optimization
+    "OptimizationConfig",
+    "OptimizationPipeline",
+    "calculate_comprehensive_metrics",
+    "evaluate_benchmark_performance",
+    "evaluate_probe_performance",
+    "generate_performance_summary",
+    # Classifier optimization
+    "OptunaClassifierOptimizer",
+    "ClassifierOptimizationConfig",
+    "GenerationConfig",
+    "CacheConfig",
+    "ActivationGenerator",
+    "ClassifierCache",
+    "OptimizationResult",
+]
wisent/core/optuna/classifier/__init__.py
@@ -0,0 +1,25 @@
+"""
+Optuna-based classifier optimization module.
+
+This module provides modern, efficient classifier optimization using Optuna with
+intelligent caching and pre-generation of activations for maximum performance.
+"""
+
+from .activation_generator import ActivationData, ActivationGenerator, GenerationConfig
+from .classifier_cache import CacheConfig, CacheMetadata, ClassifierCache
+from .optuna_classifier_optimizer import ClassifierOptimizationConfig, OptimizationResult, OptunaClassifierOptimizer
+
+__all__ = [
+    # Activation generation
+    "ActivationGenerator",
+    "GenerationConfig",
+    "ActivationData",
+    # Classifier caching
+    "ClassifierCache",
+    "CacheConfig",
+    "CacheMetadata",
+    # Optuna optimization
+    "OptunaClassifierOptimizer",
+    "ClassifierOptimizationConfig",
+    "OptimizationResult",
+]