mi-crow 1.0.0__py3-none-any.whl → 1.0.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mi_crow/datasets/base_dataset.py +71 -1
- mi_crow/datasets/classification_dataset.py +136 -30
- mi_crow/datasets/text_dataset.py +165 -24
- mi_crow/hooks/controller.py +12 -7
- mi_crow/hooks/implementations/layer_activation_detector.py +30 -34
- mi_crow/hooks/implementations/model_input_detector.py +87 -87
- mi_crow/hooks/implementations/model_output_detector.py +43 -42
- mi_crow/hooks/utils.py +74 -0
- mi_crow/language_model/activations.py +174 -77
- mi_crow/language_model/device_manager.py +119 -0
- mi_crow/language_model/inference.py +18 -5
- mi_crow/language_model/initialization.py +10 -6
- mi_crow/language_model/language_model.py +67 -97
- mi_crow/language_model/layers.py +16 -13
- mi_crow/language_model/persistence.py +4 -2
- mi_crow/language_model/utils.py +5 -5
- mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py +157 -95
- mi_crow/mechanistic/sae/concepts/concept_dictionary.py +12 -2
- mi_crow/mechanistic/sae/concepts/text_heap.py +161 -0
- mi_crow/mechanistic/sae/modules/topk_sae.py +29 -22
- mi_crow/mechanistic/sae/sae.py +3 -1
- mi_crow/mechanistic/sae/sae_trainer.py +362 -29
- mi_crow/store/local_store.py +11 -5
- mi_crow/store/store.py +34 -1
- {mi_crow-1.0.0.dist-info → mi_crow-1.0.0.post1.dist-info}/METADATA +2 -1
- {mi_crow-1.0.0.dist-info → mi_crow-1.0.0.post1.dist-info}/RECORD +28 -26
- {mi_crow-1.0.0.dist-info → mi_crow-1.0.0.post1.dist-info}/WHEEL +0 -0
- {mi_crow-1.0.0.dist-info → mi_crow-1.0.0.post1.dist-info}/top_level.txt +0 -0
mi_crow/hooks/implementations/layer_activation_detector.py (+30 -34)

@@ -1,68 +1,59 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
+
 import torch
 
 from mi_crow.hooks.detector import Detector
-from mi_crow.hooks.hook import
+from mi_crow.hooks.hook import HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT, HookType
 from mi_crow.hooks.utils import extract_tensor_from_output
 
 if TYPE_CHECKING:
-
+    pass
 
 
 class LayerActivationDetector(Detector):
     """
     Detector hook that captures and saves activations during inference.
-
+
     This detector extracts activations from layer outputs and stores them
     for later use (e.g., saving to disk, further analysis).
     """
 
-    def __init__(
-        self,
-        layer_signature: str | int,
-        hook_id: str | None = None
-    ):
+    def __init__(self, layer_signature: str | int, hook_id: str | None = None, target_dtype: torch.dtype | None = None):
         """
         Initialize the activation saver detector.
-
+
         Args:
             layer_signature: Layer to capture activations from
             hook_id: Unique identifier for this hook
-
+            target_dtype: Optional dtype to convert activations to before storing
+
         Raises:
             ValueError: If layer_signature is None
         """
         if layer_signature is None:
             raise ValueError("layer_signature cannot be None for LayerActivationDetector")
 
-        super().__init__(
-
-            hook_id=hook_id,
-            store=None,
-            layer_signature=layer_signature
-        )
+        super().__init__(hook_type=HookType.FORWARD, hook_id=hook_id, store=None, layer_signature=layer_signature)
+        self.target_dtype = target_dtype
 
     def process_activations(
-
-        module: torch.nn.Module,
-        input: HOOK_FUNCTION_INPUT,
-        output: HOOK_FUNCTION_OUTPUT
+        self, module: torch.nn.Module, input: HOOK_FUNCTION_INPUT, output: HOOK_FUNCTION_OUTPUT
     ) -> None:
         """
         Extract and store activations from output.
-
+
         Handles various output types:
         - Plain tensors
         - Tuples/lists of tensors (takes first tensor)
         - Objects with last_hidden_state attribute (e.g., HuggingFace outputs)
-
+
         Args:
             module: The PyTorch module being hooked
             input: Tuple of input tensors to the module
             output: Output tensor(s) from the module
-
+
         Raises:
             RuntimeError: If tensor extraction or storage fails
         """

@@ -70,27 +61,32 @@ class LayerActivationDetector(Detector):
             tensor = extract_tensor_from_output(output)
 
             if tensor is not None:
-
-
-
-
-
+                if tensor.is_cuda:
+                    tensor_cpu = tensor.detach().to("cpu", non_blocking=True)
+                else:
+                    tensor_cpu = tensor.detach()
+
+                if self.target_dtype is not None:
+                    tensor_cpu = tensor_cpu.to(self.target_dtype)
+
+                self.tensor_metadata["activations"] = tensor_cpu
+                self.metadata["activations_shape"] = tuple(tensor_cpu.shape)
         except Exception as e:
+            layer_sig = str(self.layer_signature) if self.layer_signature is not None else "unknown"
             raise RuntimeError(
-                f"Error extracting activations in LayerActivationDetector {self.id}: {e}"
+                f"Error extracting activations in LayerActivationDetector {self.id} (layer={layer_sig}): {e}"
             ) from e
 
     def get_captured(self) -> torch.Tensor | None:
         """
         Get the captured activations from the current batch.
-
+
         Returns:
             The captured activation tensor from the current batch or None if no activations captured yet
         """
-        return self.tensor_metadata.get(
+        return self.tensor_metadata.get("activations")
 
     def clear_captured(self) -> None:
         """Clear captured activations for current batch."""
-        self.tensor_metadata.pop(
-        self.metadata.pop(
-
+        self.tensor_metadata.pop("activations", None)
+        self.metadata.pop("activations_shape", None)
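For reference, the capture path this release adds to `LayerActivationDetector.process_activations` reduces to the short standalone sketch below (hook wiring elided; the tensor and dtype are illustrative stand-ins, not part of the package):

```python
import torch

# Stand-ins for what a forward hook would receive / what the caller configured.
activations = torch.randn(2, 8, 16)   # (batch, seq, hidden)
target_dtype = torch.float16          # optional downcast before storing

if activations.is_cuda:
    # Asynchronous device-to-host copy; can overlap with remaining GPU compute.
    activations_cpu = activations.detach().to("cpu", non_blocking=True)
else:
    activations_cpu = activations.detach()

if target_dtype is not None:
    activations_cpu = activations_cpu.to(target_dtype)

print(activations_cpu.dtype, tuple(activations_cpu.shape))  # torch.float16 (2, 8, 16)
```

With the new constructor parameter, callers can request the downcast directly, e.g. `LayerActivationDetector(layer_signature="transformer.h.3", target_dtype=torch.float16)` (layer name illustrative).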
mi_crow/hooks/implementations/model_input_detector.py (+87 -87)

@@ -1,38 +1,39 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Dict,
+from typing import TYPE_CHECKING, Dict, List, Optional, Set
+
 import torch
 
 from mi_crow.hooks.detector import Detector
-from mi_crow.hooks.hook import
+from mi_crow.hooks.hook import HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT, HookType
 
 if TYPE_CHECKING:
-
+    pass
 
 
 class ModelInputDetector(Detector):
     """
     Detector hook that captures and saves tokenized inputs from model forward pass.
-
+
     This detector is designed to be attached to the root model module and captures:
     - Tokenized inputs (input_ids) from the model's forward pass
     - Attention masks (optional) that exclude both padding and special tokens
-
+
     Uses PRE_FORWARD hook to capture inputs before they are processed.
     Useful for saving tokenized inputs for analysis or training.
     """
 
     def __init__(
-
-
-
-
-
-
+        self,
+        layer_signature: str | int | None = None,
+        hook_id: str | None = None,
+        save_input_ids: bool = True,
+        save_attention_mask: bool = False,
+        special_token_ids: Optional[List[int] | Set[int]] = None,
     ):
         """
         Initialize the model input detector.
-
+
         Args:
             layer_signature: Layer to capture from (typically the root model, can be None)
             hook_id: Unique identifier for this hook

@@ -40,12 +41,7 @@ class ModelInputDetector(Detector):
             save_attention_mask: Whether to save attention_mask tensor (excludes padding and special tokens)
             special_token_ids: Optional list/set of special token IDs. If None, will extract from LanguageModel context.
         """
-        super().__init__(
-            hook_type=HookType.PRE_FORWARD,
-            hook_id=hook_id,
-            store=None,
-            layer_signature=layer_signature
-        )
+        super().__init__(hook_type=HookType.PRE_FORWARD, hook_id=hook_id, store=None, layer_signature=layer_signature)
         self.save_input_ids = save_input_ids
         self.save_attention_mask = save_attention_mask
         self.special_token_ids = set(special_token_ids) if special_token_ids is not None else None

@@ -53,90 +49,87 @@ class ModelInputDetector(Detector):
     def _extract_input_ids(self, input: HOOK_FUNCTION_INPUT) -> torch.Tensor | None:
         """
         Extract input_ids from model input.
-
+
         Handles various input formats:
         - Dict with 'input_ids' key (most common for HuggingFace models)
         - Tuple with dict as first element
         - Tuple with tensor as first element
-
+
         Args:
             input: Input to the model forward pass
-
+
         Returns:
             input_ids tensor or None if not found
         """
         if not input or len(input) == 0:
             return None
-
+
         first_item = input[0]
-
+
         if isinstance(first_item, dict):
-            if
-                return first_item[
+            if "input_ids" in first_item:
+                return first_item["input_ids"]
             return None
-
+
         if isinstance(first_item, torch.Tensor):
             return first_item
-
+
         return None
 
     def _extract_attention_mask(self, input: HOOK_FUNCTION_INPUT) -> torch.Tensor | None:
         """
         Extract attention_mask from model input.
-
+
         Args:
             input: Input to the model forward pass
-
+
         Returns:
             attention_mask tensor or None if not found
         """
         if not input or len(input) == 0:
             return None
-
+
         first_item = input[0]
-
+
         if isinstance(first_item, dict):
-            if
-                return first_item[
-
+            if "attention_mask" in first_item:
+                return first_item["attention_mask"]
+
         return None
 
     def _get_special_token_ids(self, module: torch.nn.Module) -> Set[int]:
         """
         Get special token IDs from user-provided list or from LanguageModel context.
-
+
         Priority order:
         1. self.special_token_ids (user-provided during initialization)
         2. self.context.special_token_ids (from LanguageModel initialization)
-
+
         Args:
             module: The PyTorch module being hooked (unused, kept for API compatibility)
-
+
         Returns:
             Set of special token IDs, or empty set if none available
         """
         if self.special_token_ids is not None:
             return self.special_token_ids
-
+
         if self.context is not None and self.context.special_token_ids is not None:
             return self.context.special_token_ids
-
+
         return set()
 
     def _create_combined_attention_mask(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor | None,
-        module: torch.nn.Module
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None, module: torch.nn.Module
     ) -> torch.Tensor:
         """
         Create a combined attention mask that excludes both padding and special tokens.
-
+
         Args:
             input_ids: Input token IDs tensor (batch_size × sequence_length)
             attention_mask: Original attention mask from tokenizer (None if not provided)
             module: The PyTorch module being hooked
-
+
         Returns:
             Binary mask tensor with same shape as input_ids (1 for regular tokens, 0 for padding/special tokens)
         """

@@ -144,72 +137,71 @@ class ModelInputDetector(Detector):
             attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
         else:
             attention_mask = attention_mask.bool()
-
+
         special_token_ids = self._get_special_token_ids(module)
-
+
         if special_token_ids:
             special_ids_tensor = torch.tensor(list(special_token_ids), device=input_ids.device, dtype=input_ids.dtype)
             expanded_input = input_ids.unsqueeze(-1)
             expanded_special = special_ids_tensor.unsqueeze(0).unsqueeze(0)
             is_special = (expanded_input == expanded_special).any(dim=-1)
             attention_mask = attention_mask & ~is_special
-
+
         return attention_mask.to(torch.bool)
 
-    def set_inputs_from_encodings(
+    def set_inputs_from_encodings(
+        self, encodings: Dict[str, torch.Tensor], module: Optional[torch.nn.Module] = None
+    ) -> None:
         """
         Manually set inputs from encodings dictionary.
-
+
         This is useful when the model is called with keyword arguments,
         as PyTorch's pre_forward hook doesn't receive kwargs.
-
+
         Args:
-            encodings: Dictionary of encoded inputs (e.g., from lm.
+            encodings: Dictionary of encoded inputs (e.g., from lm.inference.execute_inference() or lm.tokenize())
             module: Optional module for extracting special token IDs. If None, will use DummyModule.
-
+
         Raises:
             RuntimeError: If tensor extraction or storage fails
         """
         try:
-            if self.save_input_ids and
-                input_ids = encodings[
-                self.tensor_metadata[
-                self.metadata[
-
-            if self.save_attention_mask and
-                input_ids = encodings[
+            if self.save_input_ids and "input_ids" in encodings:
+                input_ids = encodings["input_ids"]
+                self.tensor_metadata["input_ids"] = input_ids.detach().to("cpu")
+                self.metadata["input_ids_shape"] = tuple(input_ids.shape)
+
+            if self.save_attention_mask and "input_ids" in encodings:
+                input_ids = encodings["input_ids"]
                 if module is None:
+
                     class DummyModule:
                         pass
+
                     module = DummyModule()
-
-                original_attention_mask = encodings.get(
+
+                original_attention_mask = encodings.get("attention_mask")
                 combined_mask = self._create_combined_attention_mask(input_ids, original_attention_mask, module)
-                self.tensor_metadata[
-                self.metadata[
+                self.tensor_metadata["attention_mask"] = combined_mask.detach().to("cpu")
+                self.metadata["attention_mask_shape"] = tuple(combined_mask.shape)
         except Exception as e:
-            raise RuntimeError(
-                f"Error setting inputs from encodings in ModelInputDetector {self.id}: {e}"
-            ) from e
+            raise RuntimeError(f"Error setting inputs from encodings in ModelInputDetector {self.id}: {e}") from e
 
     def process_activations(
-
-        module: torch.nn.Module,
-        input: HOOK_FUNCTION_INPUT,
-        output: HOOK_FUNCTION_OUTPUT
+        self, module: torch.nn.Module, input: HOOK_FUNCTION_INPUT, output: HOOK_FUNCTION_OUTPUT
     ) -> None:
         """
         Extract and store tokenized inputs.
-
+
         Note: For HuggingFace models called with **kwargs, the input tuple may be empty.
         In such cases, use set_inputs_from_encodings() to manually set inputs from
-        the encodings dictionary returned by lm.
-
+        the encodings dictionary returned by lm.inference.execute_inference().
+
         Args:
             module: The PyTorch module being hooked (typically the root model)
             input: Tuple of input tensors/dicts to the module
             output: Output from the module (None for PRE_FORWARD hooks)
-
+
         Raises:
             RuntimeError: If tensor extraction or storage fails
         """

@@ -217,34 +209,42 @@ class ModelInputDetector(Detector):
             if self.save_input_ids:
                 input_ids = self._extract_input_ids(input)
                 if input_ids is not None:
-
-
-
+                    if input_ids.is_cuda:
+                        input_ids_cpu = input_ids.detach().to("cpu", non_blocking=True)
+                    else:
+                        input_ids_cpu = input_ids.detach()
+                    self.tensor_metadata["input_ids"] = input_ids_cpu
+                    self.metadata["input_ids_shape"] = tuple(input_ids_cpu.shape)
+
             if self.save_attention_mask:
                 input_ids = self._extract_input_ids(input)
                 if input_ids is not None:
                     original_attention_mask = self._extract_attention_mask(input)
                     combined_mask = self._create_combined_attention_mask(input_ids, original_attention_mask, module)
-
-
-
+                    if combined_mask.is_cuda:
+                        combined_mask_cpu = combined_mask.detach().to("cpu", non_blocking=True)
+                    else:
+                        combined_mask_cpu = combined_mask.detach()
+                    self.tensor_metadata["attention_mask"] = combined_mask_cpu
+                    self.metadata["attention_mask_shape"] = tuple(combined_mask_cpu.shape)
+
         except Exception as e:
+            layer_sig = str(self.layer_signature) if self.layer_signature is not None else "unknown"
             raise RuntimeError(
-                f"Error extracting inputs in ModelInputDetector {self.id}: {e}"
+                f"Error extracting inputs in ModelInputDetector {self.id} (layer={layer_sig}): {e}"
             ) from e
 
     def get_captured_input_ids(self) -> torch.Tensor | None:
         """Get the captured input_ids from the current batch."""
-        return self.tensor_metadata.get(
+        return self.tensor_metadata.get("input_ids")
 
     def get_captured_attention_mask(self) -> torch.Tensor | None:
         """Get the captured attention_mask from the current batch (excludes padding and special tokens)."""
-        return self.tensor_metadata.get(
+        return self.tensor_metadata.get("attention_mask")
 
     def clear_captured(self) -> None:
         """Clear all captured inputs for current batch."""
-        keys_to_remove = [
+        keys_to_remove = ["input_ids", "attention_mask"]
         for key in keys_to_remove:
             self.tensor_metadata.pop(key, None)
-            self.metadata.pop(f
-
+            self.metadata.pop(f"{key}_shape", None)
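The special-token masking in `_create_combined_attention_mask` is the least obvious logic here; in isolation it is a broadcast compare over a trailing dimension. A minimal sketch (token IDs invented for illustration: 0 = pad, 1 and 2 = special):

```python
import torch

input_ids = torch.tensor([[1, 17, 23, 2, 0, 0]])            # (batch=1, seq=6)
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]]).bool()  # padding already masked
special_token_ids = {1, 2}

special_ids_tensor = torch.tensor(list(special_token_ids), dtype=input_ids.dtype)
expanded_input = input_ids.unsqueeze(-1)                         # (1, 6, 1)
expanded_special = special_ids_tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, 2)
is_special = (expanded_input == expanded_special).any(dim=-1)    # (1, 6)

# AND the inverted "is special" mask into the padding mask.
combined = attention_mask & ~is_special
print(combined.int().tolist())  # [[0, 1, 1, 0, 0, 0]] -- only regular tokens remain
```

The real method additionally pins `special_ids_tensor` to `input_ids.device`, so the comparison never crosses devices.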
mi_crow/hooks/implementations/model_output_detector.py (+43 -42)

@@ -1,132 +1,133 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
+
 import torch
 
 from mi_crow.hooks.detector import Detector
-from mi_crow.hooks.hook import
+from mi_crow.hooks.hook import HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT, HookType
 
 if TYPE_CHECKING:
-
+    pass
 
 
 class ModelOutputDetector(Detector):
     """
     Detector hook that captures and saves model outputs.
-
+
     This detector is designed to be attached to the root model module and captures:
     - Model outputs (logits) from the model's forward pass
     - Hidden states (optional) from the model's forward pass
-
+
     Uses FORWARD hook to capture outputs after they are computed.
     Useful for saving model outputs for analysis or training.
     """
 
     def __init__(
-
-
-
-
-
+        self,
+        layer_signature: str | int | None = None,
+        hook_id: str | None = None,
+        save_output_logits: bool = True,
+        save_output_hidden_state: bool = False,
     ):
         """
         Initialize the model output detector.
-
+
         Args:
             layer_signature: Layer to capture from (typically the root model, can be None)
             hook_id: Unique identifier for this hook
             save_output_logits: Whether to save output logits (if available)
             save_output_hidden_state: Whether to save last_hidden_state (if available)
         """
-        super().__init__(
-            hook_type=HookType.FORWARD,
-            hook_id=hook_id,
-            store=None,
-            layer_signature=layer_signature
-        )
+        super().__init__(hook_type=HookType.FORWARD, hook_id=hook_id, store=None, layer_signature=layer_signature)
         self.save_output_logits = save_output_logits
         self.save_output_hidden_state = save_output_hidden_state
 
     def _extract_output_tensor(self, output: HOOK_FUNCTION_OUTPUT) -> tuple[torch.Tensor | None, torch.Tensor | None]:
         """
         Extract logits and last_hidden_state from model output.
-
+
         Args:
             output: Output from the model forward pass
-
+
         Returns:
             Tuple of (logits, last_hidden_state), either can be None
         """
         logits = None
         hidden_state = None
-
+
         if output is None:
             return None, None
-
+
         # Handle HuggingFace output objects
         if hasattr(output, "logits"):
             logits = output.logits
         if hasattr(output, "last_hidden_state"):
             hidden_state = output.last_hidden_state
-
+
         # Handle tuple output (logits might be first element)
         if isinstance(output, (tuple, list)) and len(output) > 0:
             first_item = output[0]
             if isinstance(first_item, torch.Tensor) and logits is None:
                 logits = first_item
-
+
         # Handle direct tensor output
         if isinstance(output, torch.Tensor) and logits is None:
             logits = output
-
+
         return logits, hidden_state
 
     def process_activations(
-
-        module: torch.nn.Module,
-        input: HOOK_FUNCTION_INPUT,
-        output: HOOK_FUNCTION_OUTPUT
+        self, module: torch.nn.Module, input: HOOK_FUNCTION_INPUT, output: HOOK_FUNCTION_OUTPUT
     ) -> None:
         """
         Extract and store model outputs.
-
+
         Args:
             module: The PyTorch module being hooked (typically the root model)
             input: Tuple of input tensors/dicts to the module
             output: Output from the module
-
+
         Raises:
             RuntimeError: If tensor extraction or storage fails
         """
         try:
             # Extract and save outputs
             logits, hidden_state = self._extract_output_tensor(output)
-
+
             if self.save_output_logits and logits is not None:
-
-
-
+                if logits.is_cuda:
+                    logits_cpu = logits.detach().to("cpu", non_blocking=True)
+                else:
+                    logits_cpu = logits.detach()
+                self.tensor_metadata["output_logits"] = logits_cpu
+                self.metadata["output_logits_shape"] = tuple(logits_cpu.shape)
+
             if self.save_output_hidden_state and hidden_state is not None:
-
-
-
+                if hidden_state.is_cuda:
+                    hidden_state_cpu = hidden_state.detach().to("cpu", non_blocking=True)
+                else:
+                    hidden_state_cpu = hidden_state.detach()
+                self.tensor_metadata["output_hidden_state"] = hidden_state_cpu
+                self.metadata["output_hidden_state_shape"] = tuple(hidden_state_cpu.shape)
+
         except Exception as e:
+            layer_sig = str(self.layer_signature) if self.layer_signature is not None else "unknown"
             raise RuntimeError(
-                f"Error extracting outputs in ModelOutputDetector {self.id}: {e}"
+                f"Error extracting outputs in ModelOutputDetector {self.id} (layer={layer_sig}): {e}"
             ) from e
 
     def get_captured_output_logits(self) -> torch.Tensor | None:
         """Get the captured output logits from the current batch."""
-        return self.tensor_metadata.get(
+        return self.tensor_metadata.get("output_logits")
 
     def get_captured_output_hidden_state(self) -> torch.Tensor | None:
         """Get the captured output hidden state from the current batch."""
-        return self.tensor_metadata.get(
+        return self.tensor_metadata.get("output_hidden_state")
 
     def clear_captured(self) -> None:
         """Clear all captured outputs for current batch."""
-        keys_to_remove = [
+        keys_to_remove = ["output_logits", "output_hidden_state"]
         for key in keys_to_remove:
             self.tensor_metadata.pop(key, None)
-            self.metadata.pop(f
-
+            self.metadata.pop(f"{key}_shape", None)