wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: the registry flags this version of wisent as possibly problematic.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.1.dist-info/METADATA +67 -0
- wisent-0.5.1.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,460 @@
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from typing import Mapping

if TYPE_CHECKING:
    from wisent_guard.core.activations.core.atoms import RawActivationMap


__all__ = [
    "SteeringVector",
    "SteeringPlan",
    "HookHandleGroup",
    "TopLogits",
    "GenerationStats",
]


@dataclass(slots=True)
class SteeringVector:
    """
    Single steering vector added to a layer's residual stream (output).

    arguments:
        vector: tensor whose last dim == hidden_size. Shape may be [H], [1,H], [1,1,H] or [B,T,H].
        scale: scalar coefficient (alpha) multiplied before adding.
        normalize: L2-normalize the vector (safe + epsilon) before applying 'scale'.
        layer_description: human-readable description of the steering vector. Like "toxic", "biased", etc.

    example:
        >>> sv = SteeringVector(
        ...     torch.randn(4096),
        ...     scale=0.8,
        ...     normalize=True,
        ...     layer_description="toxic"
        ... )
    """
    vector: torch.Tensor
    scale: float = 1.0
    normalize: bool = False
    layer_description: str = ""

    def materialize(self, like: torch.Tensor) -> torch.Tensor:
        """
        Broadcast + cast the vector so it's addable to 'like' ([B, T, H]).
        Returns a tensor on like.device and like.dtype.

        returns:
            Broadcast + cast the vector so it's addable to 'like' ([B, T, H]).

        raises:
            ValueError: if the vector shape is incompatible.
        """
        v = self.vector
        if self.normalize and torch.is_floating_point(v):
            denom = torch.linalg.vector_norm(v.float(), dim=-1, keepdim=True).clamp_min(1e-12)
            v = v / denom

        if v.dim() == 1:  # [H] -> [1,1,H]
            v = v.view(1, 1, -1)
        elif v.dim() == 2:  # [1,H] -> [1,1,H] or [B,H] -> [1,B,H] (still broadcastable)
            v = v.view(1, *v.shape)
        elif v.dim() == 3:  # [B,T,H] fine
            pass
        else:
            raise ValueError(
                f"Unsupported steering vector shape {tuple(v.shape)}; "
                f"expected [H], [1,H], [1,1,H], or [B,T,H]."
            )

        return v.to(dtype=like.dtype, device=like.device) * float(self.scale)
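Note that materialize() only broadcasts, casts, and scales the vector; adding it to the hidden states is left to the caller. A minimal usage sketch, not part of the diff: it assumes this hunk is the new wisent/core/models/core/atoms.py, the only file in the list above gaining exactly 460 lines, and a toy hidden size of 8.

import torch
from wisent.core.models.core.atoms import SteeringVector  # assumed module path

hidden = torch.zeros(2, 5, 8)                    # [B, T, H] residual-stream activations
sv = SteeringVector(torch.ones(8), scale=0.5, normalize=True)

delta = sv.materialize(like=hidden)              # [1, 1, 8], cast to hidden's dtype/device
steered = hidden + delta                         # broadcasts over batch and sequence dims
print(steered.shape)                             # torch.Size([2, 5, 8])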
@dataclass(slots=True)
class SteeringPlan:
    """
    Plan for applying steering vectors to multiple layers. It supports linear
    combinations of multiple steering layers (for the same llm layer).

    attributes:
        layers:
            dict of layer_name -> SteeringVector to apply at that layer.
        layers_description:
            descriptions corresponding to each RawActivationMap, for example
            "toxic", "biased", etc. These are used to build combined
            per-layer descriptions.
    """
    layers: dict[str, SteeringVector] = field(default_factory=dict)
    layers_description: list[str] = field(default_factory=list)

    @classmethod
    def from_raw(
        cls,
        raw: Sequence[RawActivationMap] | RawActivationMap | None,
        layers_description: list[str] | None = None,
        scale: float = 1.0,
        normalize: bool = False,
        weights: Sequence[float] | None = None,
        expected_hidden_size: int | None = None,
    ) -> SteeringPlan:
        """
        Build a SteeringPlan by merging one or more RawActivationMap(s).
        Each RawActivationMap is: layer_name (str) -> torch.Tensor (or None to skip).
        Each RawActivationMap corresponds to one description in layers_description.
        The final steering vector at each layer is a weighted sum of the
        contributions from each RawActivationMap.

        arguments:
            raw:
                One or more RawActivationMap(s) to combine
            layers_description:
                Descriptions corresponding to each RawActivationMap, for example
                "toxic", "biased", etc. These are used to build combined
                per-layer descriptions.
            scale:
                Scalar coefficient (alpha) applied to all steering vectors.
            normalize:
                Whether to L2-normalize each steering vector before applying 'scale'.
            weights:
                Optional weights for each RawActivationMap when combining.
                If None, uniform weights are used. Length must match number of maps.
            expected_hidden_size:
                If provided, validate that all steering vectors have this hidden size.
        """
        maps = cls._coerce_sequence(raw)
        if layers_description is None:
            layers_description = [f"steering_{i}" for i in range(len(maps))]

        if len(layers_description) != len(maps):
            raise ValueError("layers_description length must match number of maps.")

        if not maps:
            plan = cls(layers={}, layers_description=layers_description)
            if expected_hidden_size is not None:
                plan.validate_hidden_size(expected_hidden_size)
            return plan

        w = cls._normalize_weights(len(maps), weights)
        conv = cls._convert_maps(maps)
        order = cls._collect_layer_order(conv)

        out_layers = cls._build_layers(
            layer_order=order,
            converted_maps=conv,
            weights=w,
            layers_description=layers_description,
            scale=scale,
            normalize=normalize,
        )

        plan = cls(layers=out_layers, layers_description=list(layers_description))

        if expected_hidden_size is not None:
            plan.validate_hidden_size(expected_hidden_size)
        return plan

    def validate_hidden_size(self, hidden_size: int) -> None:
        """
        Ensure all steering vectors have the specified hidden size.

        arguments:
            hidden_size: expected hidden size (last dim of steering vectors).

        raises:
            ValueError: if any steering vector has a mismatched hidden size.
        """
        for layer, sv in self.layers.items():
            if sv.vector.shape[-1] != hidden_size:
                raise ValueError(
                    f"Layer {layer} steering last dim {sv.vector.shape[-1]} "
                    f"!= hidden_size {hidden_size}"
                )

    def is_empty(self) -> bool:
        """True if there are no layers."""
        return not self.layers

    @staticmethod
    def _as_tensor(x: torch.Tensor | float | int) -> torch.Tensor:
        return x if isinstance(x, torch.Tensor) else torch.as_tensor(x)

    @staticmethod
    def _normalize_weights(n: int, weights: Sequence[float] | None) -> torch.Tensor:
        """
        Return a length-n float32 tensor of weights that sums to 1.
        If weights is None, use uniform weights. Raises on length mismatch or zero-sum.

        arguments:
            n:
                number of activation maps (must be non-negative).
            weights:
                optional sequence of weights (length n) to normalize. If None, uniform weights are used.

        returns:
            A torch.Tensor of shape (n,) with float32 weights summing to 1.

        raises:
            ValueError: if n < 0, or if weights length != n, or if weights sum to 0.

        example:
            >>> SteeringPlan._normalize_weights(3, [0.2, 0.3, 0.5])
            tensor([0.2000, 0.3000, 0.5000])
            >>> SteeringPlan._normalize_weights(2, None)
            tensor([0.5000, 0.5000])
            >>> SteeringPlan._normalize_weights(0, None)
            tensor([])
        """
        if n < 0:
            raise ValueError("n must be non-negative.")
        if n == 0:
            return torch.empty(0, dtype=torch.float32)
        if weights is None:
            return torch.full((n,), 1.0 / n, dtype=torch.float32)

        w = torch.as_tensor(weights, dtype=torch.float32)
        if w.numel() != n:
            raise ValueError(f"Length mismatch: {n} activation maps but {w.numel()} weights.")
        s = float(w.sum())
        if abs(s) < 1e-12:
            raise ValueError("Weights sum to 0; cannot normalize.")
        return w / s

    @staticmethod
    def _coerce_sequence(
        raw: Sequence[RawActivationMap] | RawActivationMap | None,
    ) -> list[RawActivationMap]:
        """
        Normalize input into a list[RawActivationMap].

        arguments:
            raw: A raw activation map or a sequence of them.

        returns:
            A list of RawActivationMap.

        raises:
            TypeError: if raw is not a Mapping or a sequence of them.

        """
        if raw is None:
            return []
        if isinstance(raw, Mapping):
            return [raw]
        if isinstance(raw, Sequence) and not isinstance(raw, (str, bytes)):
            return [r or {} for r in raw]
        raise TypeError(
            "raw must be a Mapping[str, Tensor|None], a sequence of them, or None."
        )

    @classmethod
    def _convert_maps(cls, maps: list[RawActivationMap]) -> list[dict[str, torch.Tensor]]:
        """
        Convert values to tensors and drop None entries early.

        arguments:
            maps: list of RawActivationMap to convert.

        returns:
            A list of dicts mapping layer names to torch.Tensors.

        raises:
            None
        """
        out: list[dict[str, torch.Tensor]] = []
        for mapping in maps:
            conv: dict[str, torch.Tensor] = {}
            for k, v in mapping.items():
                if v is None:
                    continue
                conv[str(k)] = cls._as_tensor(v)
            out.append(conv)
        return out

    @staticmethod
    def _collect_layer_order(converted_maps: list[dict[str, torch.Tensor]]) -> list[str]:
        """
        First-seen layer order across all maps.

        arguments:
            converted_maps: list of dicts mapping layer names to torch.Tensors.

        returns:
            A list of layer names in first-seen order.

        example:
            >>> maps = [
            ...     {"layer1": torch.randn(4), "layer2": torch.randn(4)},
            ...     {"layer2": torch.randn(4), "layer3": torch.randn(4)},
            ...     {"layer1": torch.randn(4), "layer4": torch.randn(4)},
            ... ]
            >>> SteeringPlan._collect_layer_order(maps)
            ['layer1', 'layer2', 'layer3', 'layer4']
        """
        return list(dict.fromkeys(k for m in converted_maps for k in m.keys()))

    @staticmethod
    def _combine_for_layer(
        layer: str,
        converted_maps: list[dict[str, torch.Tensor]],
        weights: torch.Tensor,
        layers_description: Sequence[str],
    ) -> tuple[torch.Tensor | None, str]:
        """
        Combine weighted vectors for a single layer and build a combined description.

        arguments:
            layer:
                the layer name to combine.
            converted_maps:
                list of dicts mapping layer names to torch.Tensors.
            weights:
                tensor of shape (len(converted_maps),) with float32 weights summing to 1.
            layers_description:
                descriptions corresponding to each converted_map.

        returns:
            A tuple containing the combined tensor (or None) and the description string.

        raises:
            ValueError: if hidden sizes mismatch across maps for this layer.

        example:
            >>> maps = [
            ...     {"layer1": torch.tensor([1.0, 2.0]), "layer2": torch.tensor([3.0, 4.0])},
            ...     {"layer2": torch.tensor([5.0, 6.0]), "layer3": torch.tensor([7.0, 8.0])},
            ... ]
            >>> weights = torch.tensor([0.4, 0.6])
            >>> descs = ["toxic", "biased"]
            >>> combined, desc = SteeringPlan._combine_for_layer("layer2", maps, weights, descs)
            >>> print(combined)  # tensor([4.2, 5.2])
            >>> print(desc)  # "toxic + biased"
        """
        combined: torch.Tensor | None = None
        hidden_size: int | None = None
        desc_parts: list[str] = []

        for i, m in enumerate(converted_maps):
            v = m.get(layer)
            if v is None:
                continue

            last_dim = v.shape[-1]
            if hidden_size is None:
                hidden_size = last_dim
            elif last_dim != hidden_size:
                raise ValueError(
                    f"Layer {layer} has mismatched hidden sizes across maps: "
                    f"{hidden_size} vs {last_dim}."
                )

            scaled_v = v * float(weights[i])
            if combined is None:
                combined = scaled_v.clone()
            else:
                combined.add_(scaled_v)

            desc = layers_description[i]
            if desc not in desc_parts:
                desc_parts.append(desc)

        return combined, " + ".join(desc_parts)

    @classmethod
    def _build_layers(
        cls,
        layer_order: list[str],
        converted_maps: list[dict[str, torch.Tensor]],
        weights: torch.Tensor,
        layers_description: Sequence[str],
        scale: float,
        normalize: bool,
    ) -> dict[str, SteeringVector]:
        """
        Iterate over layer_order, combine per-layer contributions, and
        construct SteeringVector objects.

        arguments:
            layer_order:
                list of layer names in first-seen order.
            converted_maps:
                list of dicts mapping layer names to torch.Tensors.
            weights:
                tensor of shape (len(converted_maps),) with float32 weights summing to 1.
            layers_description:
                descriptions corresponding to each converted_map.
        """
        out: dict[str, SteeringVector] = {}
        for layer in layer_order:
            combined, desc = cls._combine_for_layer(
                layer=layer,
                converted_maps=converted_maps,
                weights=weights,
                layers_description=layers_description,
            )
            if combined is None:
                continue
            out[layer] = SteeringVector(
                vector=combined,
                scale=scale,
                normalize=normalize,
                layer_description=desc,
            )
        return out
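A sketch of how from_raw() merges two raw activation maps into per-layer SteeringVector objects, illustrative only, under the same assumed import path as above; the layer names are hypothetical:

import torch
from wisent.core.models.core.atoms import SteeringPlan  # assumed module path

raw_toxic = {"model.layers.10": torch.randn(16), "model.layers.11": torch.randn(16)}
raw_biased = {"model.layers.10": torch.randn(16)}

plan = SteeringPlan.from_raw(
    [raw_toxic, raw_biased],
    layers_description=["toxic", "biased"],
    scale=0.8,
    weights=[0.7, 0.3],              # normalized internally to sum to 1
    expected_hidden_size=16,         # validates every combined vector's last dim
)
print(plan.layers["model.layers.10"].layer_description)  # toxic + biased
print(plan.layers["model.layers.11"].layer_description)  # toxic
print(plan.is_empty())                                   # False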
class HookHandleGroup:
    """
    Manage a set of torch hooks to ensure clean detach.
    """
    def __init__(self) -> None:
        self._handles: list[torch.utils.hooks.RemovableHandle] = []

    def add(self, handle: torch.utils.hooks.RemovableHandle) -> None:
        self._handles.append(handle)

    def remove_all(self) -> None:
        while self._handles:
            h = self._handles.pop()
            try:
                h.remove()
            except Exception:
                pass
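HookHandleGroup is a thin guard around torch's RemovableHandle objects. A sketch of one plausible pairing with forward hooks that add materialized steering vectors (illustrative only; 'model' and 'plan' are hypothetical, carried over from the previous sketch, and the hook signature is the standard torch.nn.Module.register_forward_hook one):

def make_steering_hook(sv: SteeringVector):
    def hook(module, inputs, output):
        # Decoder blocks may return a tensor or a (hidden_states, ...) tuple.
        hidden = output[0] if isinstance(output, tuple) else output
        steered = hidden + sv.materialize(like=hidden)
        return (steered, *output[1:]) if isinstance(output, tuple) else steered
    return hook

handles = HookHandleGroup()
try:
    for name, module in model.named_modules():
        if name in plan.layers:
            handles.add(module.register_forward_hook(make_steering_hook(plan.layers[name])))
    # ... run the forward pass / generation while the hooks are attached ...
finally:
    handles.remove_all()  # hooks are detached even if generation raises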
@dataclass(slots=True)
class TopLogits:
    """
    Info for a generated step.

    attributes:
        token_id:
            chosen token id at this step.
        logit:
            raw logit for that token.
        prob:
            softmax probability for that token.
        topk_ids/topk_probs:
            optional top-k for analysis/visualization.
    """
    token_id: int
    logit: float
    prob: float
    topk_ids: list[int] | None = None
    topk_probs: list[float] | None = None
@dataclass(slots=True)
class GenerationStats:
    """
    Per-sequence stats for a generation call.

    attributes:
        tokens:
            the generated token ids (excluding the prompt).
        per_step:
            optional list of TopLogits, one per generated step.
    """
    tokens: list[int]
    per_step: list[TopLogits] | None = None
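TopLogits and GenerationStats are plain records; a sketch of filling them from a single decoding step (illustrative, not from the wheel, same assumed import path):

import torch
from wisent.core.models.core.atoms import TopLogits, GenerationStats  # assumed module path

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])   # fake next-token logits for one step
probs = torch.softmax(logits, dim=-1)
token_id = int(torch.argmax(probs))
top = torch.topk(probs, k=2)

step = TopLogits(
    token_id=token_id,
    logit=float(logits[token_id]),
    prob=float(probs[token_id]),
    topk_ids=top.indices.tolist(),
    topk_probs=top.values.tolist(),
)
stats = GenerationStats(tokens=[token_id], per_step=[step])
print(stats.tokens, stats.per_step[0].prob)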