pysteer-adaptation 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- activation_manager/ActivationExtractor.py +407 -0
- activation_manager/SteeredModelWrapper.py +1327 -0
- activation_manager/VectorMediator.py +161 -0
- activation_manager/__init__.py +3 -0
- enums/ApplyFromModeEnum.py +9 -0
- enums/MediatorGroupRepresentativeEnum.py +9 -0
- enums/ModelFamilyEnum.py +12 -0
- enums/ModelTypeEnum.py +9 -0
- enums/TaskTypeEnum.py +8 -0
- enums/__init__.py +15 -0
- executor.py +775 -0
- prompt_generator/BasePromptGenerator.py +63 -0
- prompt_generator/Gemma3PromptGenerator.py +22 -0
- prompt_generator/Llama3Point1PromptGenerator.py +23 -0
- prompt_generator/MistralV0Point3PromptGenerator.py +21 -0
- prompt_generator/OLMo2PromptGenerator.py +23 -0
- prompt_generator/Qwen2Point5PromptGenerator.py +23 -0
- prompt_generator/__init__.py +17 -0
- pysteer/__init__.py +15 -0
- pysteer_adaptation-0.1.1.dist-info/METADATA +283 -0
- pysteer_adaptation-0.1.1.dist-info/RECORD +50 -0
- pysteer_adaptation-0.1.1.dist-info/WHEEL +5 -0
- pysteer_adaptation-0.1.1.dist-info/licenses/LICENSE.txt +373 -0
- pysteer_adaptation-0.1.1.dist-info/top_level.txt +9 -0
- steering_engine/__init__.py +65 -0
- steering_engine/components.py +191 -0
- steering_engine/defaults.py +401 -0
- steering_engine/domain.py +165 -0
- steering_engine/executor_services.py +1074 -0
- steering_engine/registry.py +151 -0
- steering_engine/runtime.py +47 -0
- steering_strategy/ActsSteeringStrategy.py +440 -0
- steering_strategy/AdaptiveActivationSteeringStrategy.py +222 -0
- steering_strategy/AngularSteeringStrategy.py +122 -0
- steering_strategy/BaseSteeringStrategy.py +37 -0
- steering_strategy/GeneralSteeringStrategy.py +146 -0
- steering_strategy/MbsSteeringStrategy.py +34 -0
- steering_strategy/__init__.py +3 -0
- utils/ModelUtils.py +374 -0
- utils/StringUtils.py +28 -0
- utils/__init__.py +17 -0
- vector_update_strategy/ActsVectorMediator.py +607 -0
- vector_update_strategy/AdaptiveActivationVectorMediator.py +223 -0
- vector_update_strategy/AngularVectorMediator.py +178 -0
- vector_update_strategy/BaseVectorUpdateStrategy.py +38 -0
- vector_update_strategy/CmdVectorMediator.py +294 -0
- vector_update_strategy/ColdKernelGradientMediator.py +96 -0
- vector_update_strategy/CpcaVectorMediator.py +214 -0
- vector_update_strategy/MbsCmdVectorMediator.py +118 -0
- vector_update_strategy/__init__.py +3 -0
|
@@ -0,0 +1,407 @@
|
|
|
1
|
+
"""Forward-hook based activation collection for transformer residual streams."""
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import logging
|
|
5
|
+
import warnings
|
|
6
|
+
from typing import Any, Dict, Iterable, List, Optional, Union
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
from torch import Tensor
|
|
11
|
+
from torch.utils.hooks import RemovableHandle
|
|
12
|
+
from utils.ModelUtils import ModelUtils
|
|
13
|
+
|
|
14
|
+
# Module-level logger; ActivationExtractor emits its shape/phase/layer-count warnings through it.
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ActivationExtractor:
    """Collect residual-stream activations from selected transformer layers.

    The extractor registers forward hooks only while attached. Captured tensors
    are normalized to ``[batch, sequence, hidden]`` layout, optionally offloaded
    to CPU, and concatenated across prefill/decode chunks by :meth:`finalize`.

    Each hooked layer keeps a small phase state (``"prefill"`` / ``"decode"``)
    so that a new long forward arriving without :meth:`clear` is detected and
    that layer's accumulation is reset instead of mixing two generations.
    """

    def __init__(
        self,
        model: Any,
        layers_to_extract: Union[Iterable[int], int],
        *,
        offload_to_cpu: bool = True,
        decode_chunk_max: int = 1,
    ):
        """Initialize the extractor.

        Args:
            model: PyTorch model whose transformer blocks should be hooked.
            layers_to_extract: Layer index or indexes to collect.
            offload_to_cpu: Move captured activations to CPU before storing.
            decode_chunk_max: Maximum sequence length treated as one decode
                chunk after prefill.

        Raises:
            TypeError: If ``model`` is not a ``torch.nn.Module``.
            ValueError: If ``layers_to_extract`` is empty.
            IndexError: If a requested layer is outside the discovered stack.
        """
        if not isinstance(model, nn.Module):
            raise TypeError("ActivationExtractor expects 'model' to be a torch.nn.Module")

        self._model = model
        self.layers_to_extract = layers_to_extract
        self.handles: List[RemovableHandle] = []
        self._offload_to_cpu = bool(offload_to_cpu)

        base = ModelUtils.unwrap_model(self._model)
        self._hidden_size = ModelUtils.get_hidden_size(base.config)
        self._layers = list(ModelUtils.find_transformer_layers(self._model))

        self._hook_errors: List[tuple[int, Exception]] = []
        self._chunks: Dict[int, List[Tensor]] = {}
        self._activations: Dict[int, Tensor] = {}

        # Per-layer bookkeeping used to distinguish prefill passes from decode steps.
        self._seen_seq_len: Dict[int, int] = {}
        self._phase: Dict[int, str] = {}
        self._prefill_len: Dict[int, int] = {}
        self._decode_chunk_max: int = max(1, int(decode_chunk_max))

        # CUDA devices that had device->host copies issued with non_blocking=True
        # and not yet synchronized; finalize() waits on them before reading chunks.
        self._pending_sync_devices: set[torch.device] = set()

        for idx in self.layers_to_extract:
            if idx < 0 or idx >= len(self._layers):
                raise IndexError(
                    f"Layer index {idx} out of range for model with {len(self._layers)} layers."
                )

    def _reset_layer_state(self, idx: int) -> None:
        """Drop captured data and phase bookkeeping for layer ``idx``."""
        self._chunks.pop(idx, None)
        self._activations.pop(idx, None)
        self._seen_seq_len.pop(idx, None)
        self._phase.pop(idx, None)
        self._prefill_len.pop(idx, None)

    @property
    def model(self) -> Any:
        """The model whose layers are hooked (fixed at construction)."""
        return self._model

    @model.setter
    def model(self, value: Any) -> None:
        """Reject reassignment; the hooked model cannot change after construction."""
        raise AttributeError("ActivationExtractor.model is immutable after initialization.")

    @property
    def layers_to_extract(self) -> List[int]:
        """Sorted, de-duplicated layer indexes to capture."""
        return self._layers_to_extract

    @layers_to_extract.setter
    def layers_to_extract(self, value: Union[Iterable[int], int]) -> None:
        """Normalize ``value`` to a sorted unique list of ints.

        Raises:
            ValueError: If no layer is given.
        """
        if isinstance(value, int):
            value = [value]
        self._layers_to_extract = sorted({int(i) for i in value})
        if not self._layers_to_extract:
            raise ValueError("At least one layer must be specified for extraction")

    @property
    def handles(self) -> List[RemovableHandle]:
        """Currently registered forward-hook handles."""
        return self._handles

    @handles.setter
    def handles(self, value: List[RemovableHandle]) -> None:
        """Replace the list of tracked hook handles."""
        self._handles = value

    @property
    def activations(self) -> Dict[int, Tensor]:
        """Finalized activations keyed by layer index (live internal mapping)."""
        return self._activations

    def _make_hook(self, idx: int):
        """Build the forward hook that records activations for layer ``idx``.

        Args:
            idx: Index into the discovered transformer layer list.

        Returns:
            A callable suitable for ``nn.Module.register_forward_hook``.
        """

        @torch.inference_mode()
        def hook(_module, inputs, output):
            """Record the hidden states emitted by layer ``idx``.

            The hook returns ``None``, so the layer's output is left unchanged.
            On any failure the layer state is reset, the error is recorded for
            :meth:`assert_ok`, and the exception is re-raised.
            """
            try:
                if hasattr(output, "last_hidden_state"):
                    t = output.last_hidden_state
                elif isinstance(output, (tuple, list)) and len(output) > 0:
                    t = output[0]
                else:
                    t = output

                if not torch.is_tensor(t):
                    raise RuntimeError(
                        f"ActivationExtractor: layer {idx} hook output is not a tensor "
                        f"(type={type(t).__name__})."
                    )

                t = ModelUtils.ensure_bsh(
                    t,
                    self._hidden_size,
                    from_layout="BSH",
                )

                if not (torch.is_tensor(t) and t.dim() == 3):
                    raise RuntimeError(
                        f"ActivationExtractor: layer {idx} expected 3D tensor (B,S,H), "
                        f"got {None if not torch.is_tensor(t) else tuple(t.shape)}."
                    )

                s = int(t.shape[1])

                # A long sequence arriving after data was already captured means a
                # new prefill started without clear(); start over for this layer.
                already_have = bool(self._chunks.get(idx)) or (idx in self._activations)
                if already_have and s > self._decode_chunk_max:
                    msg = (
                        f"ActivationExtractor: detected seq_len={s} (> decode_chunk_max={self._decode_chunk_max}) "
                        f"on layer {idx} while previous activations are still accumulated. "
                        "This often means a new prefill happened without clear(). "
                        "Resetting accumulation for this layer to avoid mixing."
                    )
                    warnings.warn(msg, RuntimeWarning, stacklevel=2)
                    logger.warning("%s", msg)
                    self._reset_layer_state(idx)

                phase = self._phase.get(idx)
                if phase is None:
                    if s > 1:
                        self._phase[idx] = "prefill"
                        self._prefill_len[idx] = s
                    else:
                        self._phase[idx] = "decode"
                    self._seen_seq_len[idx] = s
                else:
                    if phase == "prefill":
                        prefill_s = self._prefill_len.get(idx, self._seen_seq_len.get(idx, s))

                        if s == prefill_s:
                            self._seen_seq_len[idx] = s
                        elif s <= self._decode_chunk_max:
                            self._phase[idx] = "decode"
                            self._seen_seq_len[idx] = s
                        else:
                            raise RuntimeError(
                                f"ActivationExtractor: layer {idx} saw unexpected seq_len change during prefill: "
                                f"prefill_seq_len={prefill_s} new_seq_len={s}. "
                                "Call clear() between independent forwards/generations."
                            )
                    else:
                        if s <= self._decode_chunk_max:
                            self._seen_seq_len[idx] = s
                        else:
                            msg = (
                                f"ActivationExtractor: layer {idx} saw seq_len={s} during decode phase; "
                                f"expected <= {self._decode_chunk_max}. "
                                "This likely indicates an accidental new prefill without clear(). "
                                "Resetting this layer state and treating this as a new prefill."
                            )
                            warnings.warn(msg, RuntimeWarning, stacklevel=2)
                            logger.warning("%s", msg)
                            self._reset_layer_state(idx)
                            self._phase[idx] = "prefill" if s > 1 else "decode"
                            if s > 1:
                                self._prefill_len[idx] = s
                            self._seen_seq_len[idx] = s

                t = torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0).detach()

                if self._offload_to_cpu:
                    if t.device.type == "cuda":
                        # non_blocking D2H copies land in pinned memory asynchronously;
                        # remember the source device so finalize() can synchronize
                        # before the CPU tensor is read.
                        self._pending_sync_devices.add(t.device)
                        t = t.to("cpu", non_blocking=True)
                    else:
                        t = t.to("cpu")

                chunks = self._chunks.setdefault(idx, [])
                prev = chunks[-1] if chunks else self._activations.get(idx, None)

                # Batch and hidden dims must match for the later concat along dim=1.
                if (
                    prev is not None
                    and (
                        (not torch.is_tensor(prev))
                        or prev.dim() != 3
                        or t.shape[0] != prev.shape[0]
                        or t.shape[2] != prev.shape[2]
                    )
                ):
                    msg = (
                        f"ActivationExtractor: shape mismatch on layer {idx}: "
                        f"prev={None if prev is None else tuple(prev.shape)} new={tuple(t.shape)}."
                    )
                    logger.warning("%s Resetting accumulation.", msg)
                    self._reset_layer_state(idx)
                    raise RuntimeError(msg)

                chunks.append(t)

            except Exception as e:
                self._reset_layer_state(idx)
                self._hook_errors.append((idx, e))
                raise

        return hook

    def assert_ok(self) -> None:
        """Raise the first stored hook error, if any.

        Raises:
            RuntimeError: Wrapping the first exception recorded by a hook.
        """
        if self._hook_errors:
            idx, e = self._hook_errors[0]
            raise RuntimeError(f"Activation hook failed on layer {idx}: {e}") from e

    def _sync_pending_copies(self) -> None:
        """Wait for outstanding non-blocking device->host copies to complete."""
        if not self._pending_sync_devices:
            return
        for device in self._pending_sync_devices:
            torch.cuda.synchronize(device)
        self._pending_sync_devices.clear()

    def finalize(self, *, clear_chunks: bool = True) -> Dict[int, Tensor]:
        """Concatenate captured chunks into one tensor per layer.

        Previously finalized activations for a layer are kept as the leading
        piece, so repeated calls keep extending along the sequence dimension.

        Args:
            clear_chunks: Clear chunk buffers after finalization.

        Returns:
            A snapshot mapping from layer index to captured activation tensor.
            The mapping is a copy, so later ``clear()``/``remove()`` calls do
            not empty it.

        Raises:
            RuntimeError: If chunks for a layer have incompatible shapes. The
                layer's previously finalized activation is preserved; its
                pending chunks are discarded when ``clear_chunks`` is true.
        """
        # Chunks copied with non_blocking=True must not be read before the copy ends.
        self._sync_pending_copies()

        for idx, chunks in list(self._chunks.items()):
            if not chunks:
                continue

            # Read (don't pop) the existing activation so an error below cannot lose it.
            pieces: List[Tensor] = list(chunks)
            if idx in self._activations:
                pieces.insert(0, self._activations[idx])

            try:
                b0 = int(pieces[0].shape[0])
                h0 = int(pieces[0].shape[2])
            except Exception as exc:
                if clear_chunks:
                    self._chunks.pop(idx, None)
                raise RuntimeError(
                    f"ActivationExtractor.finalize: bad tensor shape at layer {idx}."
                ) from exc

            bad_shapes = [
                tuple(x.shape) if torch.is_tensor(x) else None
                for x in pieces
                if not (
                    torch.is_tensor(x)
                    and x.dim() == 3
                    and int(x.shape[0]) == b0
                    and int(x.shape[2]) == h0
                )
            ]
            if bad_shapes:
                if clear_chunks:
                    self._chunks.pop(idx, None)
                raise RuntimeError(
                    f"ActivationExtractor.finalize: incompatible chunks at layer {idx}: {bad_shapes}"
                )

            self._activations[idx] = torch.cat(pieces, dim=1) if len(pieces) > 1 else pieces[0]

            # Chunks are now folded into the activation; never keep them around
            # or the next finalize() would count them twice.
            if clear_chunks:
                self._chunks.pop(idx, None)
            else:
                self._chunks[idx] = []

        if clear_chunks:
            self._chunks.clear()

        return dict(self._activations)

    def attach(self, *, clear_activations: bool = True) -> None:
        """Attach hooks to the configured model layers.

        Args:
            clear_activations: Clear previous activations before attaching.

        Raises:
            IndexError: If layer discovery changed and a requested layer is now
                out of range.
            RuntimeError: If any hook cannot be registered. Hooks registered
                before the failure are removed and accumulated state is cleared.
        """
        self._detach_hooks_only()
        if clear_activations:
            self.clear()

        # Layers are re-discovered because the model may have been wrapped/modified.
        self._layers = list(ModelUtils.find_transformer_layers(self._model))
        self._warn_on_layer_count_mismatch()

        n = len(self._layers)
        bad = [i for i in self.layers_to_extract if i < 0 or i >= n]
        if bad:
            raise IndexError(f"Layer index(es) {bad} out of range for model with {n} layers.")

        new_handles: List[RemovableHandle] = []
        try:
            for idx in self.layers_to_extract:
                new_handles.append(self._layers[idx].register_forward_hook(self._make_hook(idx)))
        except Exception as e:
            # Handles registered before the failure are not tracked in self.handles
            # yet, so they must be removed here or they would stay live on the model.
            for h in new_handles:
                with contextlib.suppress(Exception):
                    h.remove()
            self.remove()
            raise RuntimeError(f"Failed to attach activation hooks: {e}") from e
        self.handles = new_handles

    def _warn_on_layer_count_mismatch(self) -> None:
        """Warn when the discovered layer count disagrees with the model config.

        Purely diagnostic: any failure while reading the config is logged at
        debug level and never prevents attaching.
        """
        try:
            base = ModelUtils.unwrap_model(self._model)
            cfg = getattr(base, "config", None)
            expected = None
            for k in ("num_hidden_layers", "n_layer", "num_layers", "n_layers"):
                if cfg is not None and hasattr(cfg, k):
                    expected = int(getattr(cfg, k))
                    break
            if expected is not None and 0 < expected != len(self._layers):
                logger.warning(
                    "ActivationExtractor: find_transformer_layers() found %d layers, "
                    "but config expects %d. Hooks may target wrong modules (pre/post norm mismatch risk).",
                    len(self._layers), expected
                )
        except Exception:
            logger.debug("ActivationExtractor: layer-count sanity check skipped.", exc_info=True)

    def clear(self) -> None:
        """Clear accumulated tensors, phase bookkeeping, and error state."""
        self._chunks.clear()
        self._activations.clear()
        self._hook_errors.clear()
        self._seen_seq_len.clear()
        self._phase.clear()
        self._prefill_len.clear()
        # Discarded chunks will never be read, so there is nothing left to wait on.
        self._pending_sync_devices.clear()

    def _detach_hooks_only(self) -> None:
        """Remove registered hooks without touching captured data."""
        for h in list(self.handles):
            with contextlib.suppress(Exception):
                h.remove()
        self.handles.clear()

    def remove(self) -> None:
        """Remove hooks and clear accumulated state."""
        self._detach_hooks_only()
        self.clear()

    def close(self) -> None:
        """Close the object by releasing managed hooks."""
        self.remove()

    def __enter__(self):
        """Attach hooks (clearing previous state) and return ``self``."""
        self.attach(clear_activations=True)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove hooks and clear state; exceptions are not suppressed."""
        self.remove()
        return False

    @contextlib.contextmanager
    def temporarily_disabled(self):
        """Temporarily detach hooks and re-attach them afterward.

        Captured data is preserved across the disabled window.
        """
        was_on = bool(self.handles)
        if was_on:
            self._detach_hooks_only()
        try:
            yield
        finally:
            if was_on:
                self.attach(clear_activations=False)
|