pysteer-adaptation 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50):
  1. activation_manager/ActivationExtractor.py +407 -0
  2. activation_manager/SteeredModelWrapper.py +1327 -0
  3. activation_manager/VectorMediator.py +161 -0
  4. activation_manager/__init__.py +3 -0
  5. enums/ApplyFromModeEnum.py +9 -0
  6. enums/MediatorGroupRepresentativeEnum.py +9 -0
  7. enums/ModelFamilyEnum.py +12 -0
  8. enums/ModelTypeEnum.py +9 -0
  9. enums/TaskTypeEnum.py +8 -0
  10. enums/__init__.py +15 -0
  11. executor.py +775 -0
  12. prompt_generator/BasePromptGenerator.py +63 -0
  13. prompt_generator/Gemma3PromptGenerator.py +22 -0
  14. prompt_generator/Llama3Point1PromptGenerator.py +23 -0
  15. prompt_generator/MistralV0Point3PromptGenerator.py +21 -0
  16. prompt_generator/OLMo2PromptGenerator.py +23 -0
  17. prompt_generator/Qwen2Point5PromptGenerator.py +23 -0
  18. prompt_generator/__init__.py +17 -0
  19. pysteer/__init__.py +15 -0
  20. pysteer_adaptation-0.1.1.dist-info/METADATA +283 -0
  21. pysteer_adaptation-0.1.1.dist-info/RECORD +50 -0
  22. pysteer_adaptation-0.1.1.dist-info/WHEEL +5 -0
  23. pysteer_adaptation-0.1.1.dist-info/licenses/LICENSE.txt +373 -0
  24. pysteer_adaptation-0.1.1.dist-info/top_level.txt +9 -0
  25. steering_engine/__init__.py +65 -0
  26. steering_engine/components.py +191 -0
  27. steering_engine/defaults.py +401 -0
  28. steering_engine/domain.py +165 -0
  29. steering_engine/executor_services.py +1074 -0
  30. steering_engine/registry.py +151 -0
  31. steering_engine/runtime.py +47 -0
  32. steering_strategy/ActsSteeringStrategy.py +440 -0
  33. steering_strategy/AdaptiveActivationSteeringStrategy.py +222 -0
  34. steering_strategy/AngularSteeringStrategy.py +122 -0
  35. steering_strategy/BaseSteeringStrategy.py +37 -0
  36. steering_strategy/GeneralSteeringStrategy.py +146 -0
  37. steering_strategy/MbsSteeringStrategy.py +34 -0
  38. steering_strategy/__init__.py +3 -0
  39. utils/ModelUtils.py +374 -0
  40. utils/StringUtils.py +28 -0
  41. utils/__init__.py +17 -0
  42. vector_update_strategy/ActsVectorMediator.py +607 -0
  43. vector_update_strategy/AdaptiveActivationVectorMediator.py +223 -0
  44. vector_update_strategy/AngularVectorMediator.py +178 -0
  45. vector_update_strategy/BaseVectorUpdateStrategy.py +38 -0
  46. vector_update_strategy/CmdVectorMediator.py +294 -0
  47. vector_update_strategy/ColdKernelGradientMediator.py +96 -0
  48. vector_update_strategy/CpcaVectorMediator.py +214 -0
  49. vector_update_strategy/MbsCmdVectorMediator.py +118 -0
  50. vector_update_strategy/__init__.py +3 -0
@@ -0,0 +1,407 @@
1
+ """Forward-hook based activation collection for transformer residual streams."""
2
+
3
+ import contextlib
4
+ import logging
5
+ import warnings
6
+ from typing import Any, Dict, Iterable, List, Optional, Union
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch import Tensor
11
+ from torch.utils.hooks import RemovableHandle
12
+ from utils.ModelUtils import ModelUtils
13
+
14
# Module-scoped logger, named after this module so callers can tune its level.
logger = logging.getLogger(name=__name__)
15
+
16
+
17
class ActivationExtractor:
    """Collect residual-stream activations from selected transformer layers.

    The extractor registers forward hooks only while attached. Captured tensors
    are normalized to ``[batch, sequence, hidden]`` layout, optionally offloaded
    to CPU, and concatenated across prefill/decode chunks by :meth:`finalize`.

    When offloading from a CUDA device, the device-to-host copy is issued with
    ``non_blocking=True`` and therefore completes asynchronously; :meth:`finalize`
    synchronizes every such source device before the copied tensors are read.
    """

    def __init__(
        self,
        model: Any,
        layers_to_extract: Union[Iterable[int], int],
        *,
        offload_to_cpu: bool = True,
        decode_chunk_max: int = 1,
    ):
        """Initialize the extractor.

        Args:
            model: PyTorch model whose transformer blocks should be hooked.
            layers_to_extract: Layer index or indexes to collect. Duplicates are
                dropped and the indexes are stored sorted.
            offload_to_cpu: Move captured activations to CPU before storing.
            decode_chunk_max: Maximum sequence length treated as one decode
                chunk after prefill (values below 1 are clamped to 1).

        Raises:
            TypeError: If ``model`` is not a ``torch.nn.Module``.
            ValueError: If ``layers_to_extract`` is empty.
            IndexError: If a requested layer is outside the discovered stack.
        """
        if not isinstance(model, nn.Module):
            raise TypeError("ActivationExtractor expects 'model' to be a torch.nn.Module")

        self._model = model
        self.layers_to_extract = layers_to_extract
        self.handles: List[RemovableHandle] = []
        self._offload_to_cpu = bool(offload_to_cpu)

        base = ModelUtils.unwrap_model(self._model)
        self._hidden_size = ModelUtils.get_hidden_size(base.config)
        self._layers = list(ModelUtils.find_transformer_layers(self._model))

        self._hook_errors: List[tuple[int, Exception]] = []
        # Per-layer tensors captured since the last finalize().
        self._chunks: Dict[int, List[Tensor]] = {}
        # Per-layer concatenated result produced by finalize().
        self._activations: Dict[int, Tensor] = {}

        # Per-layer bookkeeping used to distinguish prefill from decode chunks.
        self._seen_seq_len: Dict[int, int] = {}
        self._phase: Dict[int, str] = {}
        self._prefill_len: Dict[int, int] = {}
        self._decode_chunk_max: int = max(1, int(decode_chunk_max))

        # CUDA devices that have non-blocking device-to-host copies in flight;
        # finalize() waits on them before reading the destination tensors.
        self._pending_sync_devices: set[torch.device] = set()

        for idx in self.layers_to_extract:
            if idx < 0 or idx >= len(self._layers):
                raise IndexError(
                    f"Layer index {idx} out of range for model with {len(self._layers)} layers."
                )

    def _reset_layer_state(self, idx: int) -> None:
        """Drop every captured tensor and all phase bookkeeping for layer ``idx``."""
        self._chunks.pop(idx, None)
        self._activations.pop(idx, None)
        self._seen_seq_len.pop(idx, None)
        self._phase.pop(idx, None)
        self._prefill_len.pop(idx, None)

    @property
    def model(self) -> Any:
        """The model passed at construction."""
        return self._model

    @model.setter
    def model(self, value: Any) -> None:
        """Reject reassignment; the model is fixed at construction.

        Raises:
            AttributeError: Always.
        """
        raise AttributeError("ActivationExtractor.model is immutable after initialization.")

    @property
    def layers_to_extract(self) -> List[int]:
        """Sorted, de-duplicated layer indexes that are hooked."""
        return self._layers_to_extract

    @layers_to_extract.setter
    def layers_to_extract(self, value: Union[Iterable[int], int]) -> None:
        """Normalize ``value`` to a sorted list of unique ints.

        Raises:
            ValueError: If no layer index is given.
        """
        if isinstance(value, int):
            value = [value]
        self._layers_to_extract = sorted({int(i) for i in value})
        if not self._layers_to_extract:
            raise ValueError("At least one layer must be specified for extraction")

    @property
    def handles(self) -> List[RemovableHandle]:
        """Handles of the currently registered forward hooks."""
        return self._handles

    @handles.setter
    def handles(self, value: List[RemovableHandle]) -> None:
        """Replace the list of tracked hook handles."""
        self._handles = value

    @property
    def activations(self) -> Dict[int, Tensor]:
        """Mapping of layer index to the tensor produced by the last finalize()."""
        return self._activations

    def _make_hook(self, idx: int):
        """Build the forward hook that captures the output of layer ``idx``.

        The hook extracts the hidden-state tensor from the module output, updates
        this layer's prefill/decode bookkeeping, optionally moves the tensor to
        CPU, and appends it to the layer's chunk list. On any failure it resets
        the layer state, records the error for :meth:`assert_ok`, and re-raises.
        """

        @torch.inference_mode()
        def hook(_module, inputs, output):
            """Record the layer output; returns ``None`` so the output is unchanged."""
            try:
                if hasattr(output, "last_hidden_state"):
                    t = output.last_hidden_state
                elif isinstance(output, (tuple, list)) and len(output) > 0:
                    t = output[0]
                else:
                    t = output

                if not torch.is_tensor(t):
                    raise RuntimeError(
                        f"ActivationExtractor: layer {idx} hook output is not a tensor "
                        f"(type={type(t).__name__})."
                    )

                t = ModelUtils.ensure_bsh(
                    t,
                    self._hidden_size,
                    from_layout="BSH",
                )

                if not (torch.is_tensor(t) and t.dim() == 3):
                    raise RuntimeError(
                        f"ActivationExtractor: layer {idx} expected 3D tensor (B,S,H), "
                        f"got {None if not torch.is_tensor(t) else tuple(t.shape)}."
                    )

                s = int(t.shape[1])

                # A long chunk arriving after data was already captured looks like
                # a fresh prefill that was not preceded by clear(); start over.
                already_have = bool(self._chunks.get(idx)) or (idx in self._activations)
                if already_have and s > self._decode_chunk_max:
                    msg = (
                        f"ActivationExtractor: detected seq_len={s} (> decode_chunk_max={self._decode_chunk_max}) "
                        f"on layer {idx} while previous activations are still accumulated. "
                        "This often means a new prefill happened without clear(). "
                        "Resetting accumulation for this layer to avoid mixing."
                    )
                    warnings.warn(msg, RuntimeWarning, stacklevel=2)
                    logger.warning("%s", msg)
                    self._reset_layer_state(idx)

                phase = self._phase.get(idx)
                if phase is None:
                    if s > 1:
                        self._phase[idx] = "prefill"
                        self._prefill_len[idx] = s
                    else:
                        self._phase[idx] = "decode"
                    self._seen_seq_len[idx] = s
                else:
                    if phase == "prefill":
                        prefill_s = self._prefill_len.get(idx, self._seen_seq_len.get(idx, s))

                        if s == prefill_s:
                            self._seen_seq_len[idx] = s
                        elif s <= self._decode_chunk_max:
                            self._phase[idx] = "decode"
                            self._seen_seq_len[idx] = s
                        else:
                            raise RuntimeError(
                                f"ActivationExtractor: layer {idx} saw unexpected seq_len change during prefill: "
                                f"prefill_seq_len={prefill_s} new_seq_len={s}. "
                                "Call clear() between independent forwards/generations."
                            )
                    else:
                        if s <= self._decode_chunk_max:
                            self._seen_seq_len[idx] = s
                        else:
                            msg = (
                                f"ActivationExtractor: layer {idx} saw seq_len={s} during decode phase; "
                                f"expected <= {self._decode_chunk_max}. "
                                "This likely indicates an accidental new prefill without clear(). "
                                "Resetting this layer state and treating this as a new prefill."
                            )
                            warnings.warn(msg, RuntimeWarning, stacklevel=2)
                            logger.warning("%s", msg)
                            self._reset_layer_state(idx)
                            self._phase[idx] = "prefill" if s > 1 else "decode"
                            if s > 1:
                                self._prefill_len[idx] = s
                            self._seen_seq_len[idx] = s

                t = torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0).detach()

                if self._offload_to_cpu:
                    if t.device.type == "cuda":
                        src_device = t.device
                        t = t.to("cpu", non_blocking=True)
                        # The copy may still be in flight when this hook returns;
                        # finalize() synchronizes this device before reading ``t``.
                        self._pending_sync_devices.add(src_device)
                    else:
                        t = t.to("cpu")

                chunks = self._chunks.setdefault(idx, [])
                prev = chunks[-1] if chunks else self._activations.get(idx, None)

                # New chunks must agree with earlier ones on batch and hidden size
                # so they can later be concatenated along the sequence axis.
                if (
                    prev is not None
                    and (
                        (not torch.is_tensor(prev))
                        or prev.dim() != 3
                        or t.shape[0] != prev.shape[0]
                        or t.shape[2] != prev.shape[2]
                    )
                ):
                    msg = (
                        f"ActivationExtractor: shape mismatch on layer {idx}: "
                        f"prev={None if prev is None else tuple(prev.shape)} new={tuple(t.shape)}."
                    )
                    logger.warning("%s Resetting accumulation.", msg)
                    self._reset_layer_state(idx)
                    raise RuntimeError(msg)

                chunks.append(t)

            except Exception as e:
                self._reset_layer_state(idx)
                self._hook_errors.append((idx, e))
                raise

        return hook

    def assert_ok(self) -> None:
        """Raise the first stored hook error, if any.

        Raises:
            RuntimeError: Wrapping the first recorded hook exception.
        """
        if self._hook_errors:
            idx, e = self._hook_errors[0]
            raise RuntimeError(f"Activation hook failed on layer {idx}: {e}") from e

    def _sync_pending_copies(self) -> None:
        """Block until every outstanding non-blocking device-to-host copy has finished."""
        while self._pending_sync_devices:
            torch.cuda.synchronize(self._pending_sync_devices.pop())

    def finalize(self, *, clear_chunks: bool = True) -> Dict[int, Tensor]:
        """Concatenate captured chunks into one tensor per layer.

        Chunks are appended (along the sequence axis) to any tensor produced by
        a previous call. On error, the previously finalized tensor for that
        layer is left untouched.

        Args:
            clear_chunks: Remove chunk buffers after finalization; when False
                each finalized layer keeps an empty chunk list.

        Returns:
            Mapping from layer index to captured activation tensor.

        Raises:
            RuntimeError: If chunks for a layer have incompatible shapes.
        """
        # Non-blocking CUDA->CPU copies must complete before their data is read.
        self._sync_pending_copies()

        for idx, chunks in list(self._chunks.items()):
            if not chunks:
                continue

            if idx in self._activations:
                pieces: List[Tensor] = [self._activations[idx], *chunks]
            else:
                pieces = list(chunks)

            try:
                b0 = int(pieces[0].shape[0])
                h0 = int(pieces[0].shape[2])
            except Exception as err:
                if clear_chunks:
                    self._chunks.pop(idx, None)
                raise RuntimeError(
                    f"ActivationExtractor.finalize: bad tensor shape at layer {idx}."
                ) from err

            bad_shapes = [
                tuple(x.shape) if torch.is_tensor(x) else None
                for x in pieces
                if not (
                    torch.is_tensor(x)
                    and x.dim() == 3
                    and int(x.shape[0]) == b0
                    and int(x.shape[2]) == h0
                )
            ]
            if bad_shapes:
                raise RuntimeError(
                    f"ActivationExtractor.finalize: incompatible chunks at layer {idx}: {bad_shapes}"
                )

            # Only overwrite the stored result once all pieces validated.
            self._activations[idx] = torch.cat(pieces, dim=1) if len(pieces) > 1 else pieces[0]

            if clear_chunks:
                self._chunks.pop(idx, None)
            else:
                # Consumed chunks must not be concatenated again next time.
                self._chunks[idx] = []

        if clear_chunks:
            self._chunks.clear()

        return self._activations

    def attach(self, *, clear_activations: bool = True) -> None:
        """Attach hooks to the configured model layers.

        Args:
            clear_activations: Clear previous activations before attaching.

        Raises:
            IndexError: If layer discovery changed and a requested layer is now
                out of range.
            RuntimeError: If any hook cannot be registered; hooks registered
                before the failure are removed again.
        """
        self._detach_hooks_only()
        if clear_activations:
            self.clear()

        self._layers = list(ModelUtils.find_transformer_layers(self._model))
        try:
            base = ModelUtils.unwrap_model(self._model)
            cfg = getattr(base, "config", None)
            expected = None
            for k in ("num_hidden_layers", "n_layer", "num_layers", "n_layers"):
                if cfg is not None and hasattr(cfg, k):
                    expected = int(getattr(cfg, k))
                    break
            if expected is not None and 0 < expected != len(self._layers):
                logger.warning(
                    "ActivationExtractor: find_transformer_layers() found %d layers, "
                    "but config expects %d. Hooks may target wrong modules (pre/post norm mismatch risk).",
                    len(self._layers), expected
                )
        except Exception:
            # Best-effort sanity check only; it must never prevent attaching.
            logger.debug("ActivationExtractor: layer-count sanity check skipped.", exc_info=True)

        n = len(self._layers)
        bad = [i for i in self.layers_to_extract if i < 0 or i >= n]
        if bad:
            raise IndexError(f"Layer index(es) {bad} out of range for model with {n} layers.")

        new_handles: List[RemovableHandle] = []
        try:
            for idx in self.layers_to_extract:
                block = self._layers[idx]
                new_handles.append(block.register_forward_hook(self._make_hook(idx)))
        except Exception as e:
            # Hooks registered before the failure are not yet in self.handles,
            # so remove() alone would leave them attached to the model.
            for h in new_handles:
                h.remove()
            self.remove()
            raise RuntimeError(f"Failed to attach activation hooks: {e}") from e
        self.handles = new_handles

    def clear(self) -> None:
        """Clear accumulated tensors, phase bookkeeping, and error state."""
        self._chunks.clear()
        self._activations.clear()
        self._hook_errors.clear()
        self._seen_seq_len.clear()
        self._phase.clear()
        self._prefill_len.clear()

    def _detach_hooks_only(self) -> None:
        """Remove all registered hooks without touching captured data."""
        for h in self.handles.copy():
            try:
                h.remove()
            except Exception:
                # Best-effort: a handle that fails to remove is simply dropped.
                pass
        self.handles.clear()

    def remove(self) -> None:
        """Remove hooks and clear accumulated state."""
        self._detach_hooks_only()
        self.clear()

    def close(self) -> None:
        """Release managed hooks and accumulated state (alias of :meth:`remove`)."""
        self.remove()

    def __enter__(self):
        """Attach hooks with a clean state and return the extractor."""
        self.attach(clear_activations=True)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove hooks and state; exceptions are not suppressed."""
        self.remove()
        return False

    @contextlib.contextmanager
    def temporarily_disabled(self):
        """Detach hooks for the duration of the block, then re-attach them.

        Captured data is kept: re-attaching uses ``clear_activations=False``.
        """
        was_on = bool(self.handles)
        if was_on:
            self._detach_hooks_only()
        try:
            yield
        finally:
            if was_on:
                self.attach(clear_activations=False)