mi-crow 0.1.1.post12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amber/__init__.py +15 -0
- amber/datasets/__init__.py +11 -0
- amber/datasets/base_dataset.py +640 -0
- amber/datasets/classification_dataset.py +566 -0
- amber/datasets/loading_strategy.py +29 -0
- amber/datasets/text_dataset.py +488 -0
- amber/hooks/__init__.py +20 -0
- amber/hooks/controller.py +171 -0
- amber/hooks/detector.py +95 -0
- amber/hooks/hook.py +218 -0
- amber/hooks/implementations/__init__.py +0 -0
- amber/hooks/implementations/function_controller.py +93 -0
- amber/hooks/implementations/layer_activation_detector.py +96 -0
- amber/hooks/implementations/model_input_detector.py +250 -0
- amber/hooks/implementations/model_output_detector.py +132 -0
- amber/hooks/utils.py +76 -0
- amber/language_model/__init__.py +0 -0
- amber/language_model/activations.py +479 -0
- amber/language_model/context.py +33 -0
- amber/language_model/contracts.py +13 -0
- amber/language_model/hook_metadata.py +38 -0
- amber/language_model/inference.py +525 -0
- amber/language_model/initialization.py +126 -0
- amber/language_model/language_model.py +390 -0
- amber/language_model/layers.py +460 -0
- amber/language_model/persistence.py +177 -0
- amber/language_model/tokenizer.py +203 -0
- amber/language_model/utils.py +97 -0
- amber/mechanistic/__init__.py +0 -0
- amber/mechanistic/sae/__init__.py +0 -0
- amber/mechanistic/sae/autoencoder_context.py +40 -0
- amber/mechanistic/sae/concepts/__init__.py +0 -0
- amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
- amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
- amber/mechanistic/sae/concepts/concept_models.py +9 -0
- amber/mechanistic/sae/concepts/input_tracker.py +68 -0
- amber/mechanistic/sae/modules/__init__.py +5 -0
- amber/mechanistic/sae/modules/l1_sae.py +409 -0
- amber/mechanistic/sae/modules/topk_sae.py +459 -0
- amber/mechanistic/sae/sae.py +166 -0
- amber/mechanistic/sae/sae_trainer.py +604 -0
- amber/mechanistic/sae/training/wandb_logger.py +222 -0
- amber/store/__init__.py +5 -0
- amber/store/local_store.py +437 -0
- amber/store/store.py +276 -0
- amber/store/store_dataloader.py +124 -0
- amber/utils.py +46 -0
- mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
- mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
- mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
- mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Dict, Set, List, Optional
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from amber.hooks.detector import Detector
|
|
7
|
+
from amber.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from torch import nn
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ModelInputDetector(Detector):
    """
    Detector hook that captures and saves tokenized inputs from a model forward pass.

    This detector is designed to be attached to the root model module and captures:
    - Tokenized inputs (``input_ids``) from the model's forward pass
    - Attention masks (optional) that exclude both padding and special tokens

    Uses a PRE_FORWARD hook to capture inputs before they are processed.
    Useful for saving tokenized inputs for analysis or training.

    Captured tensors are detached, moved to CPU and stored in ``self.tensor_metadata``
    under ``'input_ids'`` / ``'attention_mask'``; their shapes are stored in
    ``self.metadata`` under ``'input_ids_shape'`` / ``'attention_mask_shape'``.
    """

    def __init__(
        self,
        layer_signature: str | int | None = None,
        hook_id: str | None = None,
        save_input_ids: bool = True,
        save_attention_mask: bool = False,
        special_token_ids: Optional[List[int] | Set[int]] = None
    ):
        """
        Initialize the model input detector.

        Args:
            layer_signature: Layer to capture from (typically the root model, can be None)
            hook_id: Unique identifier for this hook
            save_input_ids: Whether to save the input_ids tensor
            save_attention_mask: Whether to save the attention_mask tensor
                (excludes padding and special tokens)
            special_token_ids: Optional list/set of special token IDs. If None, they are
                taken from the LanguageModel context at capture time.
        """
        super().__init__(
            hook_type=HookType.PRE_FORWARD,
            hook_id=hook_id,
            store=None,
            layer_signature=layer_signature
        )
        self.save_input_ids = save_input_ids
        self.save_attention_mask = save_attention_mask
        # Normalized to a set; None means "fall back to self.context" (see _get_special_token_ids).
        self.special_token_ids = set(special_token_ids) if special_token_ids is not None else None

    def _extract_input_ids(self, input: HOOK_FUNCTION_INPUT) -> torch.Tensor | None:
        """
        Extract input_ids from the positional hook input.

        Handles:
        - Tuple with a dict as first element (uses its ``'input_ids'`` entry)
        - Tuple with a tensor as first element (the tensor itself is returned)

        Args:
            input: Positional input to the model forward pass

        Returns:
            input_ids tensor, or None if it cannot be found
        """
        if not input:
            return None

        first_item = input[0]
        if isinstance(first_item, dict):
            return first_item.get('input_ids')
        if isinstance(first_item, torch.Tensor):
            return first_item
        return None

    def _extract_attention_mask(self, input: HOOK_FUNCTION_INPUT) -> torch.Tensor | None:
        """
        Extract attention_mask from the positional hook input.

        Only a dict in the first position is inspected (its ``'attention_mask'`` entry).

        Args:
            input: Positional input to the model forward pass

        Returns:
            attention_mask tensor, or None if it cannot be found
        """
        if not input:
            return None

        first_item = input[0]
        if isinstance(first_item, dict):
            return first_item.get('attention_mask')
        return None

    def _get_special_token_ids(self, module: torch.nn.Module | None) -> Set[int]:
        """
        Get special token IDs from the user-provided set or from the LanguageModel context.

        Priority order:
        1. self.special_token_ids (user-provided during initialization)
        2. self.context.special_token_ids (from LanguageModel initialization)

        Args:
            module: The PyTorch module being hooked (unused, kept for API compatibility)

        Returns:
            Special token IDs, or an empty set if none are available
        """
        if self.special_token_ids is not None:
            return self.special_token_ids

        if self.context is not None and self.context.special_token_ids is not None:
            return self.context.special_token_ids

        return set()

    def _create_combined_attention_mask(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None,
        module: torch.nn.Module | None
    ) -> torch.Tensor:
        """
        Create a boolean mask that excludes both padding and special tokens.

        Args:
            input_ids: Input token IDs tensor (any rank, e.g. batch_size x sequence_length)
            attention_mask: Original attention mask from the tokenizer (None if not provided)
            module: The PyTorch module being hooked (forwarded to _get_special_token_ids)

        Returns:
            Boolean tensor with the same shape as input_ids: True for regular tokens,
            False for padding/special tokens.
        """
        if attention_mask is None:
            mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            mask = attention_mask.bool()

        special_token_ids = self._get_special_token_ids(module)
        if special_token_ids:
            special_ids_tensor = torch.tensor(
                list(special_token_ids), device=input_ids.device, dtype=input_ids.dtype
            )
            # (..., 1) == (k,) broadcasts to (..., k); reducing over the last dim gives a
            # result with exactly the rank of input_ids. A fixed (1, 1, k) view would add
            # an extra leading dimension for 1-D inputs.
            is_special = (input_ids.unsqueeze(-1) == special_ids_tensor).any(dim=-1)
            mask = mask & ~is_special

        return mask

    def _store_captures(
        self,
        input_ids: torch.Tensor | None,
        attention_mask: torch.Tensor | None,
        module: torch.nn.Module | None
    ) -> None:
        """
        Store input_ids and/or the combined attention mask according to the save_* flags.

        Does nothing when input_ids is None; previously captured values are left in place.
        """
        if input_ids is None:
            return

        if self.save_input_ids:
            self.tensor_metadata['input_ids'] = input_ids.detach().to("cpu")
            self.metadata['input_ids_shape'] = tuple(input_ids.shape)

        if self.save_attention_mask:
            combined_mask = self._create_combined_attention_mask(input_ids, attention_mask, module)
            self.tensor_metadata['attention_mask'] = combined_mask.detach().to("cpu")
            self.metadata['attention_mask_shape'] = tuple(combined_mask.shape)

    def set_inputs_from_encodings(self, encodings: Dict[str, torch.Tensor], module: Optional[torch.nn.Module] = None) -> None:
        """
        Manually set inputs from an encodings dictionary.

        This is useful when the model is called with keyword arguments only: PyTorch
        forward pre-hooks registered the default way receive only positional arguments,
        so process_activations sees no input_ids in that case.

        Args:
            encodings: Dictionary of encoded inputs (e.g., from lm.forwards() or lm.tokenize()).
                Nothing is stored if it has no ``'input_ids'`` entry.
            module: Optional hooked module, forwarded to special-token lookup
                (currently unused by it).

        Raises:
            RuntimeError: If tensor extraction or storage fails
        """
        try:
            self._store_captures(
                encodings.get('input_ids'),
                encodings.get('attention_mask'),
                module,
            )
        except Exception as e:
            raise RuntimeError(
                f"Error setting inputs from encodings in ModelInputDetector {self.id}: {e}"
            ) from e

    def process_activations(
        self,
        module: torch.nn.Module,
        input: HOOK_FUNCTION_INPUT,
        output: HOOK_FUNCTION_OUTPUT
    ) -> None:
        """
        Extract and store tokenized inputs.

        Note: For HuggingFace models called with **kwargs, the input tuple may be empty.
        In that case nothing is stored; use set_inputs_from_encodings() to set inputs
        from the encodings dictionary returned by lm.forwards().

        Args:
            module: The PyTorch module being hooked (typically the root model)
            input: Tuple of input tensors/dicts to the module
            output: Output from the module (None for PRE_FORWARD hooks)

        Raises:
            RuntimeError: If tensor extraction or storage fails
        """
        try:
            input_ids = self._extract_input_ids(input)
            # The original mask is only needed when a combined mask will be built.
            attention_mask = self._extract_attention_mask(input) if self.save_attention_mask else None
            self._store_captures(input_ids, attention_mask, module)
        except Exception as e:
            raise RuntimeError(
                f"Error extracting inputs in ModelInputDetector {self.id}: {e}"
            ) from e

    def get_captured_input_ids(self) -> torch.Tensor | None:
        """Get the captured input_ids from the current batch (None if not captured)."""
        return self.tensor_metadata.get('input_ids')

    def get_captured_attention_mask(self) -> torch.Tensor | None:
        """Get the captured attention_mask from the current batch (excludes padding and special tokens)."""
        return self.tensor_metadata.get('attention_mask')

    def clear_captured(self) -> None:
        """Clear all captured inputs (tensors and their recorded shapes) for the current batch."""
        for key in ('input_ids', 'attention_mask'):
            self.tensor_metadata.pop(key, None)
            self.metadata.pop(f'{key}_shape', None)
|
|
250
|
+
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from amber.hooks.detector import Detector
|
|
7
|
+
from amber.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from torch import nn
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ModelOutputDetector(Detector):
    """
    Detector that records what the hooked (typically root) model module returns.

    Registered as a FORWARD hook, so it runs after the module has produced its
    output. Depending on the constructor flags it stores:
    - the output logits (``output.logits``, otherwise the first element of a
      non-empty tuple/list output if that is a tensor, otherwise the output
      itself when it is a plain tensor);
    - the ``last_hidden_state`` attribute of the output, when present.

    Captures are detached, copied to CPU and kept in ``self.tensor_metadata``
    (keys ``'output_logits'`` / ``'output_hidden_state'``); their shapes go into
    ``self.metadata`` under the same keys with a ``'_shape'`` suffix.
    """

    def __init__(
        self,
        layer_signature: str | int | None = None,
        hook_id: str | None = None,
        save_output_logits: bool = True,
        save_output_hidden_state: bool = False
    ):
        """
        Configure which parts of the model output get captured.

        Args:
            layer_signature: Layer to attach to (usually the root model; may be None)
            hook_id: Unique identifier for this hook
            save_output_logits: Capture logits whenever they can be located in the output
            save_output_hidden_state: Capture ``last_hidden_state`` whenever the output has it
        """
        super().__init__(
            hook_type=HookType.FORWARD,
            hook_id=hook_id,
            store=None,
            layer_signature=layer_signature,
        )
        self.save_output_logits = save_output_logits
        self.save_output_hidden_state = save_output_hidden_state

    def _extract_output_tensor(self, output: HOOK_FUNCTION_OUTPUT) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """
        Locate logits and last_hidden_state inside a module output.

        Args:
            output: Whatever the hooked module returned

        Returns:
            A ``(logits, last_hidden_state)`` pair; either entry may be None.
        """
        if output is None:
            return None, None

        # Attribute-style outputs (e.g. HuggingFace ModelOutput objects).
        logits = getattr(output, "logits", None)
        hidden_state = getattr(output, "last_hidden_state", None)

        # Positional conventions are consulted only when no logits value was found.
        if logits is None:
            if isinstance(output, torch.Tensor):
                logits = output
            elif isinstance(output, (tuple, list)) and len(output) > 0:
                head = output[0]
                if isinstance(head, torch.Tensor):
                    logits = head

        return logits, hidden_state

    def process_activations(
        self,
        module: torch.nn.Module,
        input: HOOK_FUNCTION_INPUT,
        output: HOOK_FUNCTION_OUTPUT
    ) -> None:
        """
        Store the requested pieces of the module output.

        Args:
            module: The hooked PyTorch module (typically the root model)
            input: Positional inputs the module received
            output: Value returned by the module

        Raises:
            RuntimeError: Wraps any failure during extraction or storage
        """
        try:
            logits, hidden_state = self._extract_output_tensor(output)
            # (storage key, tensor, enabled flag) — processed in this order.
            captures = (
                ('output_logits', logits, self.save_output_logits),
                ('output_hidden_state', hidden_state, self.save_output_hidden_state),
            )
            for key, tensor, enabled in captures:
                if not enabled or tensor is None:
                    continue
                self.tensor_metadata[key] = tensor.detach().to("cpu")
                self.metadata[f'{key}_shape'] = tuple(tensor.shape)
        except Exception as e:
            raise RuntimeError(
                f"Error extracting outputs in ModelOutputDetector {self.id}: {e}"
            ) from e

    def get_captured_output_logits(self) -> torch.Tensor | None:
        """Return the logits captured for the current batch, or None."""
        return self.tensor_metadata.get('output_logits')

    def get_captured_output_hidden_state(self) -> torch.Tensor | None:
        """Return the last hidden state captured for the current batch, or None."""
        return self.tensor_metadata.get('output_hidden_state')

    def clear_captured(self) -> None:
        """Drop captured output tensors and their recorded shapes."""
        for key in ('output_logits', 'output_hidden_state'):
            self.tensor_metadata.pop(key, None)
            self.metadata.pop(key + '_shape', None)
|
|
132
|
+
|
amber/hooks/utils.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""Utility functions for hook implementations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from amber.hooks.hook import HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def extract_tensor_from_input(input: HOOK_FUNCTION_INPUT) -> torch.Tensor | None:
    """
    Return the first tensor found at the head of a hook's positional input.

    Accepted shapes of ``input``:
    - ``(tensor, ...)``: the leading tensor is returned
    - ``((a, b, ...), ...)`` or ``([a, b, ...], ...)``: the first tensor inside
      the leading tuple/list is returned
    - empty or None: nothing to extract

    Args:
        input: Positional input sequence passed to a hook

    Returns:
        The located tensor, or None when there is none
    """
    candidates = ()
    if input and len(input) > 0:
        head = input[0]
        if isinstance(head, torch.Tensor):
            return head
        if isinstance(head, (tuple, list)):
            candidates = head

    # Scan a nested sequence (if any) for its first tensor element.
    return next((item for item in candidates if isinstance(item, torch.Tensor)), None)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def extract_tensor_from_output(output: HOOK_FUNCTION_OUTPUT) -> torch.Tensor | None:
    """
    Pull a single tensor out of a module's output.

    Checked in order:
    - None or a plain tensor is returned unchanged
    - for a tuple/list, its first tensor element
    - a tensor-valued ``last_hidden_state`` attribute (e.g. HuggingFace outputs)

    Args:
        output: Value returned by the hooked module

    Returns:
        The located tensor, or None when none is found
    """
    if output is None or isinstance(output, torch.Tensor):
        return output

    if isinstance(output, (tuple, list)):
        found = next((item for item in output if isinstance(item, torch.Tensor)), None)
        if found is not None:
            return found

    # Common HuggingFace output objects expose the final hidden states here.
    candidate = getattr(output, "last_hidden_state", None)
    return candidate if isinstance(candidate, torch.Tensor) else None
|
|
76
|
+
|
|
File without changes
|