wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.1.dist-info/METADATA +67 -0
- wisent-0.5.1.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
import inspect
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from wisent_guard.core.activations.core.atoms import LayerActivations, RawActivationMap, LayerName
|
|
11
|
+
from wisent_guard.core.contrastive_pairs.core.set import ContrastivePairSet
|
|
12
|
+
|
|
13
|
+
# Public API of this module. Every name listed here must be defined below,
# otherwise `from <module> import *` fails with AttributeError.
__all__ = [
    "BaseSteeringError",
    "BaseSteeringMethod",
    "PerLayerBaseSteeringMethod",
]
|
|
18
|
+
|
|
19
|
+
class BaseSteeringError(RuntimeError):
    """Raised when a steering method fails or is misconfigured."""


# Short alias so the exception is also importable as `SteeringError`.
SteeringError = BaseSteeringError
|
|
21
|
+
|
|
22
|
+
class BaseSteeringMethod(ABC):
    """
    Abstract base class and name-keyed registry for steering methods.

    Every concrete (non-abstract) subclass is recorded in a shared registry
    under its 'name' attribute at class-creation time, and can later be
    looked up with 'get' or enumerated with 'list_registered'.
    """

    name: str = "base"
    description: str = "Abstract steering method"
    # One registry shared by the whole hierarchy: method name -> class.
    _REGISTRY: dict[str, type[BaseSteeringMethod]] = {}

    def __init_subclass__(cls, **kwargs):
        """
        Register each concrete subclass under its 'name'.

        raises:
            TypeError if the subclass has an empty/falsy 'name'.
            ValueError if another class is already registered under 'name'.
        """
        super().__init_subclass__(**kwargs)

        # The root class and abstract intermediates are never registered.
        if cls is BaseSteeringMethod or inspect.isabstract(cls):
            return

        method_name = getattr(cls, "name", None)
        if not method_name:
            raise TypeError("BaseSteeringMethod subclasses must define `name`.")

        registry = BaseSteeringMethod._REGISTRY
        if method_name in registry:
            raise ValueError(f"Duplicate steering method: {cls.name!r}")
        registry[method_name] = cls

    def __init__(self, **kwargs: Any) -> None:
        """
        Store the configuration keyword arguments on the instance.

        arguments:
            **kwargs: arbitrary options, kept as a copy in 'self.kwargs'.
        """
        self.kwargs: dict[str, Any] = {**kwargs}

    @abstractmethod
    def train(self, pair_set: ContrastivePairSet) -> LayerActivations:
        """
        Produce per-layer vectors from the given contrastive set.

        arguments:
            pair_set: ContrastivePairSet with collected activations.

        returns:
            LayerActivations with one steering vector per layer.
        """
        ...

    @classmethod
    def list_registered(cls) -> dict[str, type[BaseSteeringMethod]]:
        """
        List all registered steering methods.

        returns:
            a new dict mapping method name to class; mutating it does not
            affect the registry.
        """
        return cls._REGISTRY.copy()

    @classmethod
    def get(cls, name: str) -> type[BaseSteeringMethod]:
        """
        Look up a registered steering method class by name.

        arguments:
            name: str name of the steering method.

        returns:
            the BaseSteeringMethod subclass registered under 'name'.

        raises:
            BaseSteeringError if name is unknown.
        """
        try:
            method_cls = cls._REGISTRY[name]
        except KeyError as missing:
            raise BaseSteeringError(f"Unknown steering method: {name!r}") from missing
        return method_cls
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class PerLayerBaseSteeringMethod(BaseSteeringMethod):
    """
    Base for steering methods that compute one vector per layer independently.

    Subclasses must implement 'train_for_layer'. 'train' groups the activations
    of a ContrastivePairSet by layer and calls 'train_for_layer' once per layer.
    """

    @abstractmethod
    def train_for_layer(self, pos_list: list[torch.Tensor], neg_list: list[torch.Tensor]) -> torch.Tensor:
        """
        Compute a vector for ONE layer from lists of positives/negatives.

        arguments:
            pos_list: list of tensors from positive examples.
            neg_list: list of tensors from negative examples.

        returns:
            torch.Tensor steering vector for the layer.
        """
        ...

    def _collect_from_set(self, pair_set: ContrastivePairSet) -> dict[LayerName, tuple[list[torch.Tensor], list[torch.Tensor]]]:
        """
        Build {layer_name: ([pos tensors...], [neg tensors...])} by iterating pairs.

        A pair contributes to a layer only when BOTH its positive and its
        negative response hold a torch.Tensor for that layer, so the two lists
        of a layer always have the same length. Pairs whose responses have no
        'layers_activations' (missing or None) are skipped entirely.

        arguments:
            pair_set: ContrastivePairSet with collected activations.

        returns:
            dict mapping layer names to tuples of (list of pos tensors, list of neg tensors).
        """
        buckets: dict[LayerName, tuple[list[torch.Tensor], list[torch.Tensor]]] = defaultdict(lambda: ([], []))
        for pair in pair_set.pairs:  # ContrastivePair
            pos_la = getattr(pair.positive_response, "layers_activations", None)
            neg_la = getattr(pair.negative_response, "layers_activations", None)

            if pos_la is None or neg_la is None:
                continue

            # Build each mapping once per pair instead of once per layer.
            # NOTE(review): assumes to_dict() is a side-effect-free accessor
            # returning a dict - confirm against LayerActivations.
            pos_map = pos_la.to_dict()
            neg_map = neg_la.to_dict()

            # A layer is kept only if both sides hold a tensor, so walking the
            # positive side's layers is enough.
            for layer, p in pos_map.items():
                n = neg_map.get(layer)
                if isinstance(p, torch.Tensor) and isinstance(n, torch.Tensor):
                    buckets[layer][0].append(p)
                    buckets[layer][1].append(n)
        return buckets

    def train(self, pair_set: ContrastivePairSet) -> LayerActivations:
        """
        Produce per-layer steering vectors from the given contrastive set.

        Layers whose positive or negative list is empty are skipped. The
        'dtype' and 'activation_aggregation_strategy' entries of 'self.kwargs'
        (None when absent) are forwarded to LayerActivations.

        arguments:
            pair_set: ContrastivePairSet with collected activations.

        returns:
            LayerActivations with one steering vector per layer.
        """
        buckets = self._collect_from_set(pair_set)

        raw: RawActivationMap = {}
        # Sort by (name length, name): shorter names come first, then
        # lexicographic order. Presumably layer names are numeric strings, in
        # which case this gives numeric order ("2" before "10") - TODO confirm.
        for layer, (pos_list, neg_list) in sorted(buckets.items(), key=lambda kv: (len(kv[0]), kv[0])):
            if not pos_list or not neg_list:
                continue
            raw[layer] = self.train_for_layer(pos_list, neg_list)

        dtype = self.kwargs.get("dtype", None)
        agg = self.kwargs.get("activation_aggregation_strategy", None)
        return LayerActivations(raw, activation_aggregation_strategy=agg, dtype=dtype)
|
|
File without changes
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import List
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from wisent_guard.core.steering_methods.core.atoms import PerLayerBaseSteeringMethod
|
|
7
|
+
|
|
8
|
+
# Public API of this module.
__all__ = ["CAAMethod"]
|
|
11
|
+
|
|
12
|
+
class CAAMethod(PerLayerBaseSteeringMethod):
    """
    Contrastive Activation Additions (CAA).

    For each layer the steering vector is mean(positives) - mean(negatives).
    The vector is L2-normalized unless the 'normalize' kwarg is falsy
    (it defaults to True).
    """

    name = "caa"
    description = "Per-layer mean(pos)-mean(neg) over ContrastivePairSet."

    def train_for_layer(self, pos_list: List[torch.Tensor], neg_list: List[torch.Tensor]) -> torch.Tensor:
        """
        Train CAA vector for a single layer.

        Each activation is detached, moved to CPU, cast to float32 and
        flattened to 1-D before the per-class means are taken.

        arguments:
            pos_list: List of positive activations (torch.Tensor) for this layer.
            neg_list: List of negative activations (torch.Tensor) for this layer.

        returns:
            torch.Tensor steering vector for the layer.

        raises:
            ValueError if either list is empty.
        """
        if not (pos_list and neg_list):
            raise ValueError("Both positive and negative lists must be non-empty.")

        def _flat_mean(tensors: List[torch.Tensor]) -> torch.Tensor:
            # Stack flattened activations as rows [N, H], then average over N.
            rows = [t.detach().to("cpu").float().reshape(-1) for t in tensors]
            return torch.stack(rows, dim=0).mean(dim=0)

        pos_mean = _flat_mean(pos_list)
        neg_mean = _flat_mean(neg_list)
        direction = pos_mean - neg_mean

        if self.kwargs.get("normalize", True):
            direction = self._safe_l2_normalize(direction)
        return direction

    def _safe_l2_normalize(self, v: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
        """
        Scale a vector to unit L2 norm, flattening it to 1-D first if needed.

        arguments:
            v: tensor to normalize.
            eps: small constant added to the norm to avoid division by zero.

        returns:
            1-D torch.Tensor equal to v / (||v|| + eps).
        """
        flat = v if v.ndim == 1 else v.reshape(-1)
        return flat / (torch.linalg.norm(flat) + eps)
|