mi-crow 0.1.1.post12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amber/__init__.py +15 -0
- amber/datasets/__init__.py +11 -0
- amber/datasets/base_dataset.py +640 -0
- amber/datasets/classification_dataset.py +566 -0
- amber/datasets/loading_strategy.py +29 -0
- amber/datasets/text_dataset.py +488 -0
- amber/hooks/__init__.py +20 -0
- amber/hooks/controller.py +171 -0
- amber/hooks/detector.py +95 -0
- amber/hooks/hook.py +218 -0
- amber/hooks/implementations/__init__.py +0 -0
- amber/hooks/implementations/function_controller.py +93 -0
- amber/hooks/implementations/layer_activation_detector.py +96 -0
- amber/hooks/implementations/model_input_detector.py +250 -0
- amber/hooks/implementations/model_output_detector.py +132 -0
- amber/hooks/utils.py +76 -0
- amber/language_model/__init__.py +0 -0
- amber/language_model/activations.py +479 -0
- amber/language_model/context.py +33 -0
- amber/language_model/contracts.py +13 -0
- amber/language_model/hook_metadata.py +38 -0
- amber/language_model/inference.py +525 -0
- amber/language_model/initialization.py +126 -0
- amber/language_model/language_model.py +390 -0
- amber/language_model/layers.py +460 -0
- amber/language_model/persistence.py +177 -0
- amber/language_model/tokenizer.py +203 -0
- amber/language_model/utils.py +97 -0
- amber/mechanistic/__init__.py +0 -0
- amber/mechanistic/sae/__init__.py +0 -0
- amber/mechanistic/sae/autoencoder_context.py +40 -0
- amber/mechanistic/sae/concepts/__init__.py +0 -0
- amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
- amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
- amber/mechanistic/sae/concepts/concept_models.py +9 -0
- amber/mechanistic/sae/concepts/input_tracker.py +68 -0
- amber/mechanistic/sae/modules/__init__.py +5 -0
- amber/mechanistic/sae/modules/l1_sae.py +409 -0
- amber/mechanistic/sae/modules/topk_sae.py +459 -0
- amber/mechanistic/sae/sae.py +166 -0
- amber/mechanistic/sae/sae_trainer.py +604 -0
- amber/mechanistic/sae/training/wandb_logger.py +222 -0
- amber/store/__init__.py +5 -0
- amber/store/local_store.py +437 -0
- amber/store/store.py +276 -0
- amber/store/store_dataloader.py +124 -0
- amber/utils.py +46 -0
- mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
- mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
- mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
- mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
amber/hooks/detector.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from typing import Any, TYPE_CHECKING, Dict
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from amber.hooks.hook import Hook, HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
|
|
9
|
+
from amber.store.store import Store
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
pass
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Detector(Hook):
    """
    Abstract base class for observation-only hooks that collect metadata at inference.

    A detector never alters the tensors flowing through the model; it records
    information across batches and may persist it via an optional Store.
    Concrete subclasses implement ``process_activations``.
    """

    def __init__(
        self,
        hook_type: HookType | str = HookType.FORWARD,
        hook_id: str | None = None,
        store: Store | None = None,
        layer_signature: str | int | None = None
    ):
        """
        Create a detector hook.

        Args:
            hook_type: Kind of hook (HookType.FORWARD or HookType.PRE_FORWARD)
            hook_id: Unique identifier (auto-generated by Hook when omitted)
            store: Optional Store used to persist collected metadata
            layer_signature: Layer to attach to (optional, for compatibility)
        """
        super().__init__(layer_signature=layer_signature, hook_type=hook_type, hook_id=hook_id)
        self.store = store
        # Plain (non-tensor) metadata accumulated by subclasses.
        self.metadata: Dict[str, Any] = {}
        # Tensor-valued metadata, kept separate so it can be moved/saved explicitly.
        self.tensor_metadata: Dict[str, torch.Tensor] = {}

    def _hook_fn(
        self,
        module: torch.nn.Module,
        input: HOOK_FUNCTION_INPUT,
        output: HOOK_FUNCTION_OUTPUT
    ) -> None:
        """
        Dispatch to ``process_activations`` when the detector is enabled.

        Any exception raised by the subclass is re-raised as a RuntimeError
        that identifies the failing detector.

        Args:
            module: The PyTorch module being hooked
            input: Tuple of input tensors to the module
            output: Output tensor(s) from the module

        Raises:
            RuntimeError: If process_activations raises an exception
        """
        if self._enabled:
            try:
                self.process_activations(module, input, output)
            except Exception as e:
                raise RuntimeError(
                    f"Error in detector {self.id} process_activations: {e}"
                ) from e
        return None

    @abc.abstractmethod
    def process_activations(
        self,
        module: torch.nn.Module,
        input: HOOK_FUNCTION_INPUT,
        output: HOOK_FUNCTION_OUTPUT
    ) -> None:
        """
        Inspect the activations of the hooked layer.

        Detector-specific logic lives here (e.g. tracking top activations,
        computing statistics).

        Args:
            module: The PyTorch module being hooked
            input: Tuple of input tensors to the module
            output: Output tensor(s) from the module

        Raises:
            Exception: Subclasses may raise for invalid inputs or processing errors
        """
        raise NotImplementedError("process_activations must be implemented by subclasses")
|
amber/hooks/hook.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
import uuid
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Callable, TypeAlias, Sequence, Optional, TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch import nn, Tensor
|
|
10
|
+
from torch.types import _TensorOrTensors
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from amber.language_model.context import LanguageModelContext
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class HookType(str, Enum):
    """Type of hook to register on a layer."""
    # Runs after the module's forward pass (register_forward_hook).
    FORWARD = "forward"
    # Runs before the module's forward pass (register_forward_pre_hook).
    PRE_FORWARD = "pre_forward"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Positional tensor arguments handed to a hooked module.
HOOK_FUNCTION_INPUT: TypeAlias = Sequence[Tensor]
# Output seen by a forward hook; None for pre-forward hooks.
HOOK_FUNCTION_OUTPUT: TypeAlias = _TensorOrTensors | None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class HookError(Exception):
    """Raised when a registered hook fails during execution."""

    def __init__(self, hook_id: str, hook_type: str, original_error: Exception):
        """
        Build the error, preserving the failing hook's identity and the cause.

        Args:
            hook_id: Unique identifier of the hook that raised the error
            hook_type: Type of hook (e.g., "forward", "pre_forward")
            original_error: The original exception that was raised
        """
        self.hook_id = hook_id
        self.hook_type = hook_type
        self.original_error = original_error
        super().__init__(
            f"Hook {hook_id} (type={hook_type}) raised exception: {original_error}"
        )
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class Hook(abc.ABC):
    """
    Abstract base for hooks attachable to language model layers.

    A Hook wraps detection/control logic behind a PyTorch-compatible callable
    (see get_torch_hook) and adds unique identification, enable/disable
    switching, and an optional LanguageModelContext.
    """

    def __init__(
        self,
        layer_signature: str | int | None = None,
        hook_type: HookType | str = HookType.FORWARD,
        hook_id: str | None = None
    ):
        """
        Initialize a hook.

        Args:
            layer_signature: Layer name or index to attach hook to
            hook_type: HookType.FORWARD or HookType.PRE_FORWARD (or their string values)
            hook_id: Unique identifier; a UUID4 string is generated when omitted

        Raises:
            ValueError: If hook_type is neither a HookType nor a valid string value
        """
        self.layer_signature = layer_signature
        self.hook_type = self._normalize_hook_type(hook_type)
        self.id = str(uuid.uuid4()) if hook_id is None else hook_id
        self._enabled = True
        self._torch_hook_handle = None
        self._context: Optional["LanguageModelContext"] = None

    def _normalize_hook_type(self, hook_type: HookType | str) -> HookType:
        """Coerce *hook_type* into a HookType member.

        Args:
            hook_type: HookType enum or its string value

        Returns:
            The corresponding HookType member

        Raises:
            ValueError: If the string is not a valid HookType value, or the
                argument is neither a HookType nor a string
        """
        if isinstance(hook_type, HookType):
            return hook_type

        if not isinstance(hook_type, str):
            raise ValueError(
                f"hook_type must be HookType enum or string, got: {type(hook_type)}"
            )

        try:
            return HookType(hook_type)
        except ValueError:
            valid_values = [ht.value for ht in HookType]
            raise ValueError(
                f"Invalid hook_type string '{hook_type}'. "
                f"Must be one of: {valid_values}"
            ) from None

    def _create_pre_forward_wrapper(self) -> Callable:
        """Build the callable to register via register_forward_pre_hook.

        Returns:
            Wrapper that is a no-op while disabled and may return replacement inputs
        """
        def pre_forward_wrapper(module: nn.Module, input: HOOK_FUNCTION_INPUT) -> None | HOOK_FUNCTION_INPUT:
            if not self._enabled:
                return None
            try:
                # _hook_fn may return replacement inputs; None keeps the originals.
                return self._hook_fn(module, input, None)
            except Exception as e:
                raise HookError(self.id, self.hook_type.value, e) from e

        return pre_forward_wrapper

    def _create_forward_wrapper(self) -> Callable:
        """Build the callable to register via register_forward_hook.

        Returns:
            Wrapper that is a no-op while disabled and always returns None
        """
        def forward_wrapper(module: nn.Module, input: HOOK_FUNCTION_INPUT, output: HOOK_FUNCTION_OUTPUT) -> None:
            if not self._enabled:
                return None
            try:
                self._hook_fn(module, input, output)
            except Exception as e:
                raise HookError(self.id, self.hook_type.value, e) from e
            return None

        return forward_wrapper

    @property
    def enabled(self) -> bool:
        """Whether this hook currently executes when triggered."""
        return self._enabled

    def enable(self) -> None:
        """Turn this hook on."""
        self._enabled = True

    def disable(self) -> None:
        """Turn this hook off; its wrappers become no-ops."""
        self._enabled = False

    @property
    def context(self) -> Optional["LanguageModelContext"]:
        """The LanguageModelContext associated with this hook, if any."""
        return self._context

    def set_context(self, context: "LanguageModelContext") -> None:
        """Attach a LanguageModelContext to this hook.

        Args:
            context: The LanguageModelContext instance
        """
        self._context = context

    def _is_both_controller_and_detector(self) -> bool:
        """
        Report whether this instance inherits from both Controller and Detector.

        Checks class names along the MRO instead of importing the classes,
        which avoids circular imports.

        Returns:
            True when both 'Controller' and 'Detector' appear in the MRO
        """
        names = {cls.__name__ for cls in type(self).__mro__}
        return 'Controller' in names and 'Detector' in names

    def get_torch_hook(self) -> Callable:
        """
        Return a PyTorch-compatible hook callable matching this hook's type.

        The wrapper consults the enabled flag before delegating to _hook_fn.

        Returns:
            A callable suitable for register_forward_hook or
            register_forward_pre_hook, depending on hook_type.
        """
        if self.hook_type == HookType.PRE_FORWARD:
            return self._create_pre_forward_wrapper()
        return self._create_forward_wrapper()

    @abc.abstractmethod
    def _hook_fn(
        self,
        module: torch.nn.Module,
        input: HOOK_FUNCTION_INPUT,
        output: HOOK_FUNCTION_OUTPUT
    ) -> None | HOOK_FUNCTION_INPUT:
        """
        Hook body supplied by subclasses.

        Args:
            module: The PyTorch module being hooked
            input: Tuple of input tensors to the module
            output: Output tensor(s) from the module (None for pre_forward hooks)

        Returns:
            For pre_forward hooks: modified inputs (tuple) or None to keep original
            For forward hooks: None (forward hooks cannot modify output in PyTorch)

        Raises:
            Exception: Any error; the registered wrapper converts it into a HookError
        """
        raise NotImplementedError("_hook_fn must be implemented by subclasses")
|
|
File without changes
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Callable, TYPE_CHECKING
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from amber.hooks.controller import Controller
|
|
7
|
+
from amber.hooks.hook import HookType
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from torch import nn
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class FunctionController(Controller):
    """
    A controller that applies a user-provided function to an activation tensor.

    The callable is applied to the single tensor at the hooked position:
    the layer output for forward hooks, or the input for pre-forward hooks.
    Values that are not a plain torch.Tensor are passed through unchanged.

    Example:
        >>> # Scale activations by 2
        >>> controller = FunctionController(
        ...     layer_signature="layer_0",
        ...     function=lambda x: x * 2.0
        ... )
    """

    def __init__(
        self,
        layer_signature: str | int,
        function: Callable[[torch.Tensor], torch.Tensor],
        hook_type: HookType | str = HookType.FORWARD,
        hook_id: str | None = None,
    ):
        """
        Initialize a function controller.

        Args:
            layer_signature: Layer to attach to
            function: Callable taking a torch.Tensor and returning a torch.Tensor
            hook_type: Type of hook (HookType.FORWARD or HookType.PRE_FORWARD)
            hook_id: Unique identifier

        Raises:
            ValueError: If function is None or not callable
        """
        if function is None:
            raise ValueError("function cannot be None")
        if not callable(function):
            raise ValueError(f"function must be callable, got: {type(function)}")

        super().__init__(hook_type=hook_type, hook_id=hook_id, layer_signature=layer_signature)
        self.function = function

    def modify_activations(
        self,
        module: "nn.Module",
        inputs: torch.Tensor | None,
        output: torch.Tensor | None
    ) -> torch.Tensor | None:
        """
        Apply the user-provided function to the relevant activation tensor.

        Args:
            module: The PyTorch module being hooked
            inputs: Input tensor (None for forward hooks)
            output: Output tensor (None for pre_forward hooks)

        Returns:
            The transformed tensor, or the original value unchanged if it is
            not a torch.Tensor (including None)

        Raises:
            RuntimeError: If the function raises, or returns a non-Tensor
        """
        # Forward hooks transform the output; pre-forward hooks transform the input.
        target = inputs if self.hook_type != HookType.FORWARD else output

        if not isinstance(target, torch.Tensor):
            return target

        try:
            result = self.function(target)
            if isinstance(result, torch.Tensor):
                return result
            # Raised inside the try so it is wrapped as a RuntimeError below,
            # matching the error contract for any other function failure.
            raise TypeError(
                f"Function must return a torch.Tensor, got: {type(result)}"
            )
        except Exception as e:
            raise RuntimeError(
                f"Error applying function in FunctionController {self.id}: {e}"
            ) from e
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from amber.hooks.detector import Detector
|
|
7
|
+
from amber.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
|
|
8
|
+
from amber.hooks.utils import extract_tensor_from_output
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from torch import nn
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class LayerActivationDetector(Detector):
    """
    Detector hook that snapshots a layer's output activations each batch.

    The extracted tensor is detached, moved to CPU, and kept in
    tensor_metadata for later use (saving to disk, further analysis).
    """

    def __init__(
        self,
        layer_signature: str | int,
        hook_id: str | None = None
    ):
        """
        Initialize the activation saver detector.

        Args:
            layer_signature: Layer to capture activations from
            hook_id: Unique identifier for this hook

        Raises:
            ValueError: If layer_signature is None
        """
        if layer_signature is None:
            raise ValueError("layer_signature cannot be None for LayerActivationDetector")

        super().__init__(
            hook_type=HookType.FORWARD,
            hook_id=hook_id,
            store=None,
            layer_signature=layer_signature
        )

    def process_activations(
        self,
        module: torch.nn.Module,
        input: HOOK_FUNCTION_INPUT,
        output: HOOK_FUNCTION_OUTPUT
    ) -> None:
        """
        Extract and store activations from the layer output.

        Delegates output handling (plain tensors, tuples/lists, objects with
        last_hidden_state) to extract_tensor_from_output.

        Args:
            module: The PyTorch module being hooked
            input: Tuple of input tensors to the module
            output: Output tensor(s) from the module

        Raises:
            RuntimeError: If tensor extraction or storage fails
        """
        try:
            tensor = extract_tensor_from_output(output)
            if tensor is None:
                return
            # Detach from the graph and move off-device before retaining.
            snapshot = tensor.detach().to("cpu")
            # Only the current batch is kept; each call replaces the previous one.
            self.tensor_metadata['activations'] = snapshot
            self.metadata['activations_shape'] = tuple(snapshot.shape)
        except Exception as e:
            raise RuntimeError(
                f"Error extracting activations in LayerActivationDetector {self.id}: {e}"
            ) from e

    def get_captured(self) -> torch.Tensor | None:
        """
        Get the activations captured for the current batch.

        Returns:
            The captured activation tensor, or None if nothing was captured yet
        """
        return self.tensor_metadata.get('activations')

    def clear_captured(self) -> None:
        """Drop the captured activations and their recorded shape."""
        self.tensor_metadata.pop('activations', None)
        self.metadata.pop('activations_shape', None)
|
|
96
|
+
|