wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of wisent might be problematic.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.1.dist-info/METADATA +67 -0
- wisent-0.5.1.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
wisent/core/steering.py
ADDED
@@ -0,0 +1,652 @@
import datetime
import json
import os
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn.functional as F

from wisent_guard.core.activations import Activations
from wisent_guard.core.classifier.classifier import Classifier

from .contrastive_pairs import ContrastivePairSet
from .steering_method import CAA


class SteeringType(Enum):
    LOGISTIC = "logistic"
    MLP = "mlp"
    CUSTOM = "custom"
    CAA = "caa"  # New vector-based steering


class SteeringMethod:
    """
    Legacy classifier-based steering method for backward compatibility.
    For new vector-based steering, use steering_method.CAA directly.
    """

    def __init__(self, method_type: SteeringType, device=None, threshold=0.5):
        self.method_type = method_type
        self.device = device
        self.threshold = threshold
        self.classifier = None

        # For vector-based steering
        self.vector_steering = None
        self.is_vector_based = method_type == SteeringType.CAA

        if self.is_vector_based:
            self.vector_steering = CAA(device=device)

        # Response logging settings
        self.enable_logging = False
        self.log_file_path = "./harmful_responses.json"

        # Parameter optimization tracking
        self.original_parameters = {}
        self.optimization_history = []

    def train(
        self, contrastive_pair_set: ContrastivePairSet, layer_index: Optional[int] = None, **kwargs
    ) -> Dict[str, Any]:
        """
        Train the steering method on a ContrastivePairSet.

        Args:
            contrastive_pair_set: Set of contrastive pairs with activations
            layer_index: Layer index for vector-based steering (required for CAA)
            **kwargs: Additional training parameters

        Returns:
            Dictionary with training metrics
        """
        if self.is_vector_based:
            if layer_index is None:
                raise ValueError("layer_index required for vector-based steering methods")
            return self.vector_steering.train(contrastive_pair_set, layer_index)

        # Legacy classifier-based training
        X, y = contrastive_pair_set.prepare_classifier_data()

        if len(X) < 4:
            raise ValueError(f"Need at least 4 training examples, got {len(X)}")

        # Create classifier
        self.classifier = Classifier(model_type=self.method_type.value, device=self.device, threshold=self.threshold)

        # Train classifier
        results = self.classifier.fit(X, y, **kwargs)

        return results

    def apply_steering(self, activations: torch.Tensor, strength: float = 1.0) -> torch.Tensor:
        """
        Apply steering to activations (vector-based methods only).

        Args:
            activations: Input activations
            strength: Steering strength

        Returns:
            Steered activations
        """
        if not self.is_vector_based:
            raise ValueError("apply_steering only available for vector-based methods")

        return self.vector_steering.apply_steering(activations, strength)

    def get_steering_vector(self) -> Optional[torch.Tensor]:
        """Get steering vector (vector-based methods only)."""
        if not self.is_vector_based:
            return None
        return self.vector_steering.get_steering_vector()

    def predict(self, activations) -> float:
        """
        Predict if activations represent harmful behavior (classifier-based only).

        Args:
            activations: Activation tensor or Activations object

        Returns:
            Prediction score (0 = harmless, 1 = harmful)
        """
        if self.is_vector_based:
            raise ValueError("predict not available for vector-based methods")

        if self.classifier is None:
            raise ValueError("SteeringMethod not trained. Call train() first.")

        return self.classifier.predict(activations)

    def predict_proba(self, activations) -> float:
        """
        Get prediction probability for activations (classifier-based only).

        Args:
            activations: Activation tensor or Activations object

        Returns:
            Probability score (0.0-1.0)
        """
        if self.is_vector_based:
            raise ValueError("predict_proba not available for vector-based methods")

        if self.classifier is None:
            raise ValueError("SteeringMethod not trained. Call train() first.")

        return self.classifier.predict_proba(activations)

    def is_harmful(self, activations, detailed=False) -> Union[bool, Dict[str, Any]]:
        """
        Check if activations represent harmful content (classifier-based only).

        Args:
            activations: Activation tensor or Activations object
            detailed: Whether to return detailed results

        Returns:
            Boolean or detailed dictionary
        """
        if self.is_vector_based:
            raise ValueError("is_harmful not available for vector-based methods")

        if self.classifier is None:
            raise ValueError("SteeringMethod not trained. Call train() first.")

        # Get probability score
        probability = self.predict_proba(activations)
        is_harmful = probability >= self.threshold

        if detailed:
            return {
                "is_harmful": is_harmful,
                "probability": probability,
                "threshold": self.threshold,
                "method_type": self.method_type.value,
            }
        return is_harmful

    def check_safety(self, text: str, model, layer) -> Dict[str, Any]:
        """
        Comprehensive safety check for text using the model.

        Args:
            text: Text to check
            model: Model object for activation extraction
            layer: Layer object for activation extraction

        Returns:
            Safety check results
        """
        try:
            # Extract activations from text
            activations_tensor = model.extract_activations(text, layer)

            # Create Activations object
            activations = Activations(tensor=activations_tensor, layer=layer)

            # Get detailed prediction
            result = self.is_harmful(activations, detailed=True)

            # Add text information
            result.update(
                {
                    "text": text[:100] + "..." if len(text) > 100 else text,
                    "text_length": len(text),
                    "layer_index": layer.index,
                }
            )

            return result

        except Exception as e:
            return {
                "is_harmful": False,
                "probability": 0.0,
                "error": str(e),
                "text": text[:100] + "..." if len(text) > 100 else text,
            }

    def enable_response_logging(self, log_file_path: str = "./harmful_responses.json") -> None:
        """
        Enable logging of harmful responses.

        Args:
            log_file_path: Path to the log file
        """
        self.enable_logging = True
        self.log_file_path = log_file_path

        # Initialize log file if it doesn't exist
        if not os.path.exists(os.path.dirname(log_file_path)):
            try:
                os.makedirs(os.path.dirname(log_file_path))
            except Exception:
                pass

        if not os.path.exists(log_file_path):
            try:
                with open(log_file_path, "w") as f:
                    json.dump([], f)
            except Exception:
                pass

    def log_harmful_response(
        self, prompt: str, response: str, probability: float, category: str = "harmful", additional_info: Dict = None
    ) -> bool:
        """
        Log a harmful response to the JSON log file.

        Args:
            prompt: The original prompt
            response: The generated response
            probability: The probability score that triggered detection
            category: The category of harmful content detected
            additional_info: Optional additional information

        Returns:
            Success flag
        """
        if not self.enable_logging:
            return False

        try:
            # Create log entry
            log_entry = {
                "timestamp": datetime.datetime.now().isoformat(),
                "prompt": prompt,
                "response": response,
                "probability": float(probability),
                "category": category,
                "threshold": float(self.threshold),
                "method_type": self.method_type.value,
            }

            # Add additional info if provided
            if additional_info:
                log_entry.update(additional_info)

            # Read existing log entries
            try:
                with open(self.log_file_path) as f:
                    log_entries = json.load(f)
            except (FileNotFoundError, json.JSONDecodeError):
                log_entries = []

            # Append new entry
            log_entries.append(log_entry)

            # Write updated log
            with open(self.log_file_path, "w") as f:
                json.dump(log_entries, f, indent=2)

            return True

        except Exception:
            return False

    def get_logged_responses(self, limit: Optional[int] = None, category: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Retrieve logged harmful responses from the log file.

        Args:
            limit: Maximum number of entries to return (None for all)
            category: Filter by specific category (None for all categories)

        Returns:
            List of log entries
        """
        if not self.enable_logging:
            return []

        try:
            # Check if log file exists
            if not os.path.exists(self.log_file_path):
                return []

            # Read log entries
            with open(self.log_file_path) as f:
                log_entries = json.load(f)

            # Filter by category if specified
            if category is not None:
                log_entries = [entry for entry in log_entries if entry.get("category") == category]

            # Sort by timestamp (newest first)
            log_entries.sort(key=lambda entry: entry.get("timestamp", ""), reverse=True)

            # Apply limit if specified
            if limit is not None and limit > 0:
                log_entries = log_entries[:limit]

            return log_entries

        except Exception:
            return []

    def optimize_parameters(
        self,
        model,
        target_layer,
        pair_set: ContrastivePairSet,
        learning_rate: float = 1e-4,
        num_epochs: int = 10,
        regularization_strength: float = 0.01,
    ) -> Dict[str, Any]:
        """
        Optimize model parameters to improve steering effectiveness.

        Args:
            model: Model object to optimize
            target_layer: Layer to optimize
            pair_set: ContrastivePairSet with training data
            learning_rate: Learning rate for optimization
            num_epochs: Number of optimization epochs
            regularization_strength: L2 regularization strength

        Returns:
            Dictionary with optimization results
        """
        try:
            # Get the target layer module for optimization
            layer_module = self._get_layer_module(model, target_layer)
            if layer_module is None:
                raise ValueError(f"Could not find layer {target_layer} in model")

            # Store original parameters
            self._store_original_parameters(layer_module)

            # Extract activations for the pair set
            pair_set.extract_activations_with_model(model, target_layer)

            # Prepare training data
            X_tensors, y_labels = pair_set.prepare_classifier_data()

            # Set up optimizer for just the target layer
            optimizer = torch.optim.Adam(layer_module.parameters(), lr=learning_rate)

            # Training loop
            best_steering_loss = float("inf")
            best_parameters = None

            for epoch in range(num_epochs):
                epoch_loss = 0.0
                num_batches = 0

                # Process in batches
                batch_size = 4
                for i in range(0, len(X_tensors), batch_size):
                    batch_X = X_tensors[i : i + batch_size]
                    batch_y = y_labels[i : i + batch_size]

                    # Zero gradients
                    optimizer.zero_grad()

                    # Forward pass through the modified layer
                    loss = self._compute_steering_loss(batch_X, batch_y, layer_module, regularization_strength)

                    # Backward pass
                    loss.backward()
                    optimizer.step()

                    epoch_loss += loss.item()
                    num_batches += 1

                avg_loss = epoch_loss / num_batches if num_batches > 0 else 0.0

                # Track best parameters
                if avg_loss < best_steering_loss:
                    best_steering_loss = avg_loss
                    best_parameters = {name: param.clone() for name, param in layer_module.named_parameters()}

            # Load best parameters
            if best_parameters is not None:
                for name, param in layer_module.named_parameters():
                    if name in best_parameters:
                        param.data.copy_(best_parameters[name])

            # Store optimization results
            optimization_result = {
                "target_layer": target_layer.index if hasattr(target_layer, "index") else target_layer,
                "final_loss": best_steering_loss,
                "epochs": num_epochs,
                "learning_rate": learning_rate,
                "regularization_strength": regularization_strength,
                "parameters_optimized": True,
            }

            self.optimization_history.append(optimization_result)

            return optimization_result

        except Exception as e:
            return {"error": str(e), "parameters_optimized": False}

    def _get_layer_module(self, model, layer):
        """Get the module for a specific layer."""
        try:
            hf_model = model.hf_model if hasattr(model, "hf_model") else model
            layer_idx = layer.index if hasattr(layer, "index") else layer

            if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
                # Llama-style model
                if layer_idx < len(hf_model.model.layers):
                    return hf_model.model.layers[layer_idx]
            elif hasattr(hf_model, "transformer") and hasattr(hf_model.transformer, "h"):
                # GPT-style model
                if layer_idx < len(hf_model.transformer.h):
                    return hf_model.transformer.h[layer_idx]

            return None
        except Exception:
            return None

    def _store_original_parameters(self, module):
        """Store original parameters of a module."""
        key = f"module_{id(module)}"
        self.original_parameters[key] = {name: param.clone() for name, param in module.named_parameters()}

    def _compute_steering_loss(self, batch_X, batch_y, layer_module, regularization_strength):
        """
        Compute loss for steering optimization.

        Args:
            batch_X: Batch of activation tensors
            batch_y: Batch of labels
            layer_module: Layer module being optimized
            regularization_strength: L2 regularization strength

        Returns:
            Loss tensor
        """
        total_loss = 0.0

        # Compute steering effectiveness loss
        for i, (activation, label) in enumerate(zip(batch_X, batch_y)):
            # Get prediction from steering method
            prediction = self.predict_proba(activation)

            # Convert to tensor for loss computation
            if not isinstance(prediction, torch.Tensor):
                prediction = torch.tensor(prediction, dtype=torch.float32, device=self.device)

            target = torch.tensor(label, dtype=torch.float32, device=self.device)

            # Binary cross-entropy loss
            loss = F.binary_cross_entropy_with_logits(prediction.unsqueeze(0), target.unsqueeze(0))
            total_loss += loss

        # Add L2 regularization
        l2_reg = 0.0
        for param in layer_module.parameters():
            l2_reg += torch.norm(param, p=2)

        total_loss += regularization_strength * l2_reg

        return total_loss / len(batch_X)  # Average over batch

    def restore_original_parameters(self) -> bool:
        """
        Restore original parameters.

        Returns:
            Success flag
        """
        try:
            # This is a simplified version - in practice, you'd need to keep track
            # of which modules correspond to which keys
            return len(self.original_parameters) > 0
        except Exception:
            return False

    def get_optimization_summary(self) -> Dict[str, Any]:
        """
        Get a summary of all optimizations performed.

        Returns:
            Summary dictionary
        """
        return {
            "total_optimizations": len(self.optimization_history),
            "optimization_history": self.optimization_history,
            "has_original_parameters": len(self.original_parameters) > 0,
            "method_type": self.method_type.value,
            "threshold": self.threshold,
        }

    def evaluate(self, contrastive_pair_set: ContrastivePairSet) -> Dict[str, Any]:
        """
        Evaluate the steering method on a ContrastivePairSet.

        Args:
            contrastive_pair_set: Set of contrastive pairs for evaluation

        Returns:
            Dictionary with evaluation metrics
        """
        if self.classifier is None:
            raise ValueError("SteeringMethod not trained. Call train() first.")

        # Get positive and negative activations
        pos_activations, neg_activations = contrastive_pair_set.get_activation_pairs()

        # Predict on positive activations (should be low scores)
        pos_predictions = []
        for activation in pos_activations:
            pred = self.predict_proba(activation)
            pos_predictions.append(pred)

        # Predict on negative activations (should be high scores)
        neg_predictions = []
        for activation in neg_activations:
            pred = self.predict_proba(activation)
            neg_predictions.append(pred)

        # Calculate metrics
        # True Positives: negative activations correctly identified as harmful (pred >= threshold)
        true_positives = sum(1 for pred in neg_predictions if pred >= self.threshold)

        # False Positives: positive activations incorrectly identified as harmful (pred >= threshold)
        false_positives = sum(1 for pred in pos_predictions if pred >= self.threshold)

        # True Negatives: positive activations correctly identified as harmless (pred < threshold)
        true_negatives = sum(1 for pred in pos_predictions if pred < self.threshold)

        # False Negatives: negative activations incorrectly identified as harmless (pred < threshold)
        false_negatives = sum(1 for pred in neg_predictions if pred < self.threshold)

        # Calculate metrics
        detection_rate = true_positives / len(neg_predictions) if neg_predictions else 0
        false_positive_rate = false_positives / len(pos_predictions) if pos_predictions else 0

        precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
        recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        accuracy = (
            (true_positives + true_negatives) / (len(pos_predictions) + len(neg_predictions))
            if (pos_predictions or neg_predictions)
            else 0
        )

        return {
            "detection_rate": detection_rate,
            "false_positive_rate": false_positive_rate,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "accuracy": accuracy,
            "true_positives": true_positives,
            "false_positives": false_positives,
            "true_negatives": true_negatives,
            "false_negatives": false_negatives,
            "num_positive_samples": len(pos_predictions),
            "num_negative_samples": len(neg_predictions),
            "threshold": self.threshold,
        }

    def save_model(self, save_path: str) -> bool:
        """
        Save the steering method to disk.

        Args:
            save_path: Path to save the model

        Returns:
            Success flag
        """
        if self.classifier is None:
            return False

        try:
            self.classifier.save_model(save_path)
            return True
        except Exception:
            return False

    def load_model(self, model_path: str) -> bool:
        """
        Load a steering method from disk.

        Args:
            model_path: Path to the saved model

        Returns:
            Success flag
        """
        try:
            self.classifier = Classifier(
                model_type=self.method_type.value, device=self.device, threshold=self.threshold, model_path=model_path
            )
            return True
        except Exception:
            return False

    @classmethod
    def create_and_train(
        cls,
        method_type: SteeringType,
        contrastive_pair_set: ContrastivePairSet,
        device: Optional[str] = None,
        threshold: float = 0.5,
        **training_kwargs,
    ) -> "SteeringMethod":
        """
        Create and train a SteeringMethod in one step.

        Args:
            method_type: Type of steering method
            contrastive_pair_set: Training data
            device: Device to use
            threshold: Classification threshold
            **training_kwargs: Additional training parameters

        Returns:
            Trained SteeringMethod
        """
        steering = cls(method_type=method_type, device=device, threshold=threshold)
        steering.train(contrastive_pair_set, **training_kwargs)
        return steering
wisent/core/steering_method.py
ADDED
@@ -0,0 +1,26 @@
"""
Steering methods for wisent-guard.

This module provides a unified interface for various steering methods
by importing them from the steering_methods package.
"""

# Import all steering methods from the new package
from .steering_methods import (
    SteeringMethod,
    CAA,
    HPR,
    DAC,
    BiPO,
    KSteering
)

# Re-export for backward compatibility
__all__ = [
    'SteeringMethod',
    'CAA',
    'HPR',
    'DAC',
    'BiPO',
    'KSteering'
]