wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of wisent might be problematic.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.2.dist-info/METADATA +67 -0
  215. wisent-0.5.2.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
wisent/core/models/core/atoms.py
@@ -0,0 +1,460 @@
+ from __future__ import annotations
+
+ from collections.abc import Mapping, Sequence
+ from dataclasses import dataclass, field
+ from typing import TYPE_CHECKING
+
+ import torch
+
+ if TYPE_CHECKING:
+     from wisent.core.activations.core.atoms import RawActivationMap
+
+
+ __all__ = [
+     "SteeringVector",
+     "SteeringPlan",
+     "HookHandleGroup",
+     "TopLogits",
+     "GenerationStats",
+ ]
+
+
+ @dataclass(slots=True)
+ class SteeringVector:
+     """
+     Single steering vector added to a layer's residual stream (output).
+
+     arguments:
+         vector: tensor whose last dim == hidden_size. Shape may be [H], [1,H], [1,1,H], or [B,T,H].
+         scale: scalar coefficient (alpha) applied before adding.
+         normalize: L2-normalize the vector (with an epsilon for safety) before applying 'scale'.
+         layer_description: human-readable description of the steering vector, e.g. "toxic" or "biased".
+
+     example:
+         >>> sv = SteeringVector(
+         ...     torch.randn(4096),
+         ...     scale=0.8,
+         ...     normalize=True,
+         ...     layer_description="toxic",
+         ... )
+     """
+     vector: torch.Tensor
+     scale: float = 1.0
+     normalize: bool = False
+     layer_description: str = ""
+
+     def materialize(self, like: torch.Tensor) -> torch.Tensor:
+         """
+         Broadcast and cast the vector so it is addable to 'like' ([B, T, H]).
+
+         returns:
+             The steering vector on like.device and like.dtype, scaled by 'scale'.
+
+         raises:
+             ValueError: if the vector shape is incompatible.
+         """
+         v = self.vector
+         if self.normalize and torch.is_floating_point(v):
+             denom = torch.linalg.vector_norm(v.float(), dim=-1, keepdim=True).clamp_min(1e-12)
+             v = v / denom
+
+         if v.dim() == 1:    # [H] -> [1,1,H]
+             v = v.view(1, 1, -1)
+         elif v.dim() == 2:  # [1,H] -> [1,1,H] or [B,H] -> [1,B,H] (still broadcastable)
+             v = v.view(1, *v.shape)
+         elif v.dim() == 3:  # [B,T,H] is already addable
+             pass
+         else:
+             raise ValueError(
+                 f"Unsupported steering vector shape {tuple(v.shape)}; "
+                 f"expected [H], [1,H], [1,1,H], or [B,T,H]."
+             )
+
+         return v.to(dtype=like.dtype, device=like.device) * float(self.scale)
+
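To make the broadcasting contract concrete, here is a minimal usage sketch (the shapes and the 'hidden_states' name are illustrative, not part of the API):

    sv = SteeringVector(torch.randn(4096), scale=0.8, normalize=True)

    # Stand-in residual-stream activations: batch=2, seq=5, hidden=4096.
    hidden_states = torch.zeros(2, 5, 4096, dtype=torch.float16)

    delta = sv.materialize(like=hidden_states)  # [1, 1, 4096], cast to float16
    steered = hidden_states + delta             # broadcasts over batch and sequence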
+ @dataclass(slots=True)
+ class SteeringPlan:
+     """
+     Plan for applying steering vectors to multiple layers. Supports linear
+     combinations of multiple steering vectors targeting the same LLM layer.
+
+     attributes:
+         layers:
+             dict of layer_name -> SteeringVector to apply at that layer.
+         layers_description:
+             descriptions corresponding to each RawActivationMap, for example
+             "toxic", "biased", etc. These are used to build combined
+             per-layer descriptions.
+     """
+     layers: dict[str, SteeringVector] = field(default_factory=dict)
+     layers_description: list[str] = field(default_factory=list)
+
+     @classmethod
+     def from_raw(
+         cls,
+         raw: Sequence[RawActivationMap] | RawActivationMap | None,
+         layers_description: list[str] | None = None,
+         scale: float = 1.0,
+         normalize: bool = False,
+         weights: Sequence[float] | None = None,
+         expected_hidden_size: int | None = None,
+     ) -> SteeringPlan:
+         """
+         Build a SteeringPlan by merging one or more RawActivationMap(s).
+         Each RawActivationMap maps layer_name (str) -> torch.Tensor (or None to skip)
+         and corresponds to one description in layers_description. The final
+         steering vector at each layer is a weighted sum of the contributions
+         from each RawActivationMap.
+
+         arguments:
+             raw:
+                 One or more RawActivationMap(s) to combine.
+             layers_description:
+                 Descriptions corresponding to each RawActivationMap, for example
+                 "toxic", "biased", etc. These are used to build combined
+                 per-layer descriptions.
+             scale:
+                 Scalar coefficient (alpha) applied to all steering vectors.
+             normalize:
+                 Whether to L2-normalize each steering vector before applying 'scale'.
+             weights:
+                 Optional weights for each RawActivationMap when combining.
+                 If None, uniform weights are used. Length must match number of maps.
+             expected_hidden_size:
+                 If provided, validate that all steering vectors have this hidden size.
+         """
+         maps = cls._coerce_sequence(raw)
+         if layers_description is None:
+             layers_description = [f"steering_{i}" for i in range(len(maps))]
+
+         if len(layers_description) != len(maps):
+             raise ValueError("layers_description length must match number of maps.")
+
+         if not maps:
+             plan = cls(layers={}, layers_description=layers_description)
+             if expected_hidden_size is not None:
+                 plan.validate_hidden_size(expected_hidden_size)
+             return plan
+
+         w = cls._normalize_weights(len(maps), weights)
+         conv = cls._convert_maps(maps)
+         order = cls._collect_layer_order(conv)
+
+         out_layers = cls._build_layers(
+             layer_order=order,
+             converted_maps=conv,
+             weights=w,
+             layers_description=layers_description,
+             scale=scale,
+             normalize=normalize,
+         )
+
+         plan = cls(layers=out_layers, layers_description=list(layers_description))
+
+         if expected_hidden_size is not None:
+             plan.validate_hidden_size(expected_hidden_size)
+         return plan
+
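As a sketch, merging two single-layer activation maps into one plan could look like this (the layer name and hidden size are made up for illustration):

    raw_toxic = {"model.layers.10": torch.randn(4096)}
    raw_biased = {"model.layers.10": torch.randn(4096)}

    plan = SteeringPlan.from_raw(
        [raw_toxic, raw_biased],
        layers_description=["toxic", "biased"],
        weights=[0.7, 0.3],
        normalize=True,
        expected_hidden_size=4096,
    )
    # plan.layers["model.layers.10"] holds 0.7 * toxic + 0.3 * biased,
    # with layer_description "toxic + biased".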
+     def validate_hidden_size(self, hidden_size: int) -> None:
+         """
+         Ensure all steering vectors have the specified hidden size.
+
+         arguments:
+             hidden_size: expected hidden size (last dim of steering vectors).
+
+         raises:
+             ValueError: if any steering vector has a mismatched hidden size.
+         """
+         for layer, sv in self.layers.items():
+             if sv.vector.shape[-1] != hidden_size:
+                 raise ValueError(
+                     f"Layer {layer} steering last dim {sv.vector.shape[-1]} "
+                     f"!= hidden_size {hidden_size}"
+                 )
+
+     def is_empty(self) -> bool:
+         """True if there are no layers."""
+         return not self.layers
+
+     @staticmethod
+     def _as_tensor(x: torch.Tensor | float | int) -> torch.Tensor:
+         return x if isinstance(x, torch.Tensor) else torch.as_tensor(x)
+
+     @staticmethod
+     def _normalize_weights(n: int, weights: Sequence[float] | None) -> torch.Tensor:
+         """
+         Return a length-n float32 tensor of weights that sums to 1.
+         If weights is None, use uniform weights. Raises on length mismatch or zero-sum.
+
+         arguments:
+             n:
+                 number of activation maps (must be non-negative).
+             weights:
+                 optional sequence of weights (length n) to normalize. If None, uniform weights are used.
+
+         returns:
+             A torch.Tensor of shape (n,) with float32 weights summing to 1.
+
+         raises:
+             ValueError: if n < 0, or if weights length != n, or if weights sum to 0.
+
+         example:
+             >>> SteeringPlan._normalize_weights(3, [0.2, 0.3, 0.5])
+             tensor([0.2000, 0.3000, 0.5000])
+             >>> SteeringPlan._normalize_weights(2, None)
+             tensor([0.5000, 0.5000])
+             >>> SteeringPlan._normalize_weights(0, None)
+             tensor([])
+         """
+         if n < 0:
+             raise ValueError("n must be non-negative.")
+         if n == 0:
+             return torch.empty(0, dtype=torch.float32)
+         if weights is None:
+             return torch.full((n,), 1.0 / n, dtype=torch.float32)
+
+         w = torch.as_tensor(weights, dtype=torch.float32)
+         if w.numel() != n:
+             raise ValueError(f"Length mismatch: {n} activation maps but {w.numel()} weights.")
+         s = float(w.sum())
+         if abs(s) < 1e-12:
+             raise ValueError("Weights sum to 0; cannot normalize.")
+         return w / s
+
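A quick sketch of the validation helpers above (the layer name is hypothetical and 8 is an arbitrary hidden size):

    plan = SteeringPlan(layers={"layer1": SteeringVector(torch.randn(8))})
    plan.is_empty()               # False
    plan.validate_hidden_size(8)  # passes silently
    # plan.validate_hidden_size(16) would raise ValueError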
+     @staticmethod
+     def _coerce_sequence(
+         raw: Sequence[RawActivationMap] | RawActivationMap | None,
+     ) -> list[RawActivationMap]:
+         """
+         Normalize input into a list[RawActivationMap].
+
+         arguments:
+             raw: A raw activation map or a sequence of them.
+
+         returns:
+             A list of RawActivationMap.
+
+         raises:
+             TypeError: if raw is not a Mapping or a sequence of them.
+         """
+         if raw is None:
+             return []
+         if isinstance(raw, Mapping):
+             return [raw]
+         if isinstance(raw, Sequence) and not isinstance(raw, (str, bytes)):
+             return [r or {} for r in raw]
+         raise TypeError(
+             "raw must be a Mapping[str, Tensor|None], a sequence of them, or None."
+         )
+
+     @classmethod
+     def _convert_maps(cls, maps: list[RawActivationMap]) -> list[dict[str, torch.Tensor]]:
+         """
+         Convert values to tensors and drop None entries early.
+
+         arguments:
+             maps: list of RawActivationMap to convert.
+
+         returns:
+             A list of dicts mapping layer names to torch.Tensors.
+         """
+         out: list[dict[str, torch.Tensor]] = []
+         for mapping in maps:
+             conv: dict[str, torch.Tensor] = {}
+             for k, v in mapping.items():
+                 if v is None:
+                     continue
+                 conv[str(k)] = cls._as_tensor(v)
+             out.append(conv)
+         return out
+
+     @staticmethod
+     def _collect_layer_order(converted_maps: list[dict[str, torch.Tensor]]) -> list[str]:
+         """
+         First-seen layer order across all maps.
+
+         arguments:
+             converted_maps: list of dicts mapping layer names to torch.Tensors.
+
+         returns:
+             A list of layer names in first-seen order.
+
+         example:
+             >>> maps = [
+             ...     {"layer1": torch.randn(4), "layer2": torch.randn(4)},
+             ...     {"layer2": torch.randn(4), "layer3": torch.randn(4)},
+             ...     {"layer1": torch.randn(4), "layer4": torch.randn(4)},
+             ... ]
+             >>> SteeringPlan._collect_layer_order(maps)
+             ['layer1', 'layer2', 'layer3', 'layer4']
+         """
+         return list(dict.fromkeys(k for m in converted_maps for k in m.keys()))
+
+     @staticmethod
+     def _combine_for_layer(
+         layer: str,
+         converted_maps: list[dict[str, torch.Tensor]],
+         weights: torch.Tensor,
+         layers_description: Sequence[str],
+     ) -> tuple[torch.Tensor | None, str]:
+         """
+         Combine weighted vectors for a single layer and build a combined description.
+
+         arguments:
+             layer:
+                 the layer name to combine.
+             converted_maps:
+                 list of dicts mapping layer names to torch.Tensors.
+             weights:
+                 tensor of shape (len(converted_maps),) with float32 weights summing to 1.
+             layers_description:
+                 descriptions corresponding to each converted_map.
+
+         returns:
+             A tuple containing the combined tensor (or None) and the description string.
+
+         raises:
+             ValueError: if hidden sizes mismatch across maps for this layer.
+
+         example:
+             >>> maps = [
+             ...     {"layer1": torch.tensor([1.0, 2.0]), "layer2": torch.tensor([3.0, 4.0])},
+             ...     {"layer2": torch.tensor([5.0, 6.0]), "layer3": torch.tensor([7.0, 8.0])},
+             ... ]
+             >>> weights = torch.tensor([0.4, 0.6])
+             >>> descs = ["toxic", "biased"]
+             >>> combined, desc = SteeringPlan._combine_for_layer("layer2", maps, weights, descs)
+             >>> combined
+             tensor([4.2000, 5.2000])
+             >>> desc
+             'toxic + biased'
+         """
+         combined: torch.Tensor | None = None
+         hidden_size: int | None = None
+         desc_parts: list[str] = []
+
+         for i, m in enumerate(converted_maps):
+             v = m.get(layer)
+             if v is None:
+                 continue
+
+             last_dim = v.shape[-1]
+             if hidden_size is None:
+                 hidden_size = last_dim
+             elif last_dim != hidden_size:
+                 raise ValueError(
+                     f"Layer {layer} has mismatched hidden sizes across maps: "
+                     f"{hidden_size} vs {last_dim}."
+                 )
+
+             scaled_v = v * float(weights[i])
+             if combined is None:
+                 combined = scaled_v.clone()
+             else:
+                 combined.add_(scaled_v)
+
+             desc = layers_description[i]
+             if desc not in desc_parts:
+                 desc_parts.append(desc)
+
+         return combined, " + ".join(desc_parts)
+
+     @classmethod
+     def _build_layers(
+         cls,
+         layer_order: list[str],
+         converted_maps: list[dict[str, torch.Tensor]],
+         weights: torch.Tensor,
+         layers_description: Sequence[str],
+         scale: float,
+         normalize: bool,
+     ) -> dict[str, SteeringVector]:
+         """
+         Iterate over layer_order, combine per-layer contributions, and
+         construct SteeringVector objects.
+
+         arguments:
+             layer_order:
+                 list of layer names in first-seen order.
+             converted_maps:
+                 list of dicts mapping layer names to torch.Tensors.
+             weights:
+                 tensor of shape (len(converted_maps),) with float32 weights summing to 1.
+             layers_description:
+                 descriptions corresponding to each converted_map.
+             scale:
+                 scalar coefficient applied to each resulting SteeringVector.
+             normalize:
+                 whether each SteeringVector is L2-normalized before scaling.
+
+         returns:
+             A dict mapping layer names to combined SteeringVector objects.
+         """
+         out: dict[str, SteeringVector] = {}
+         for layer in layer_order:
+             combined, desc = cls._combine_for_layer(
+                 layer=layer,
+                 converted_maps=converted_maps,
+                 weights=weights,
+                 layers_description=layers_description,
+             )
+             if combined is None:
+                 continue
+             out[layer] = SteeringVector(
+                 vector=combined,
+                 scale=scale,
+                 normalize=normalize,
+                 layer_description=desc,
+             )
+         return out
+
+
+ class HookHandleGroup:
+     """
+     Manage a set of torch hooks to ensure clean detach.
+     """
+
+     def __init__(self) -> None:
+         self._handles: list[torch.utils.hooks.RemovableHandle] = []
+
+     def add(self, handle: torch.utils.hooks.RemovableHandle) -> None:
+         self._handles.append(handle)
+
+     def remove_all(self) -> None:
+         while self._handles:
+             h = self._handles.pop()
+             try:
+                 h.remove()
+             except Exception:
+                 pass
+
+
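How these pieces fit together is easiest to see with forward hooks. A minimal sketch, assuming the plan's layer names match entries in model.named_modules() (the tuple-output handling mirrors typical decoder layers; none of this is pinned down by the file itself):

    def attach_steering(model: torch.nn.Module, plan: SteeringPlan) -> HookHandleGroup:
        handles = HookHandleGroup()
        modules = dict(model.named_modules())
        for layer_name, sv in plan.layers.items():
            def hook(_module, _inputs, output, _sv=sv):
                # Decoder layers often return a tuple whose first item is hidden states.
                hs = output[0] if isinstance(output, tuple) else output
                hs = hs + _sv.materialize(like=hs)
                return (hs, *output[1:]) if isinstance(output, tuple) else hs
            handles.add(modules[layer_name].register_forward_hook(hook))
        return handles

    # The caller is responsible for cleanup:
    # handles = attach_steering(model, plan); ...; handles.remove_all()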
+ @dataclass(slots=True)
+ class TopLogits:
+     """
+     Info for a generated step.
+
+     attributes:
+         token_id:
+             chosen token id at this step.
+         logit:
+             raw logit for that token.
+         prob:
+             softmax probability for that token.
+         topk_ids/topk_probs:
+             optional top-k for analysis/visualization.
+     """
+     token_id: int
+     logit: float
+     prob: float
+     topk_ids: list[int] | None = None
+     topk_probs: list[float] | None = None
+
+
+ @dataclass(slots=True)
+ class GenerationStats:
+     """
+     Per-sequence stats for a generation call.
+
+     attributes:
+         tokens:
+             the generated token ids (excluding the prompt).
+         per_step:
+             optional list of TopLogits, one per generated step.
+     """
+     tokens: list[int]
+     per_step: list[TopLogits] | None = None
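Finally, a hedged sketch of how a decoding loop might populate these records (purely illustrative; the file does not prescribe how TopLogits is filled):

    def record_step(logits: torch.Tensor, k: int = 5) -> TopLogits:
        # logits: [vocab] scores for one generation step.
        probs = torch.softmax(logits.float(), dim=-1)
        token_id = int(torch.argmax(probs))
        topk_probs, topk_ids = torch.topk(probs, k)
        return TopLogits(
            token_id=token_id,
            logit=float(logits[token_id]),
            prob=float(probs[token_id]),
            topk_ids=topk_ids.tolist(),
            topk_probs=topk_probs.tolist(),
        )

    # steps = [record_step(logits_t) for each generated position]
    # stats = GenerationStats(tokens=[s.token_id for s in steps], per_step=steps)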