mi-crow 0.1.2__py3-none-any.whl → 1.0.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,68 +1,59 @@
  from __future__ import annotations
 
  from typing import TYPE_CHECKING
+
  import torch
 
  from mi_crow.hooks.detector import Detector
- from mi_crow.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
+ from mi_crow.hooks.hook import HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT, HookType
  from mi_crow.hooks.utils import extract_tensor_from_output
 
  if TYPE_CHECKING:
-     from torch import nn
+     pass
 
 
  class LayerActivationDetector(Detector):
      """
      Detector hook that captures and saves activations during inference.
-
+
      This detector extracts activations from layer outputs and stores them
      for later use (e.g., saving to disk, further analysis).
      """
 
-     def __init__(
-         self,
-         layer_signature: str | int,
-         hook_id: str | None = None
-     ):
+     def __init__(self, layer_signature: str | int, hook_id: str | None = None, target_dtype: torch.dtype | None = None):
          """
          Initialize the activation saver detector.
-
+
          Args:
              layer_signature: Layer to capture activations from
              hook_id: Unique identifier for this hook
-
+             target_dtype: Optional dtype to convert activations to before storing
+
          Raises:
              ValueError: If layer_signature is None
          """
          if layer_signature is None:
              raise ValueError("layer_signature cannot be None for LayerActivationDetector")
 
-         super().__init__(
-             hook_type=HookType.FORWARD,
-             hook_id=hook_id,
-             store=None,
-             layer_signature=layer_signature
-         )
+         super().__init__(hook_type=HookType.FORWARD, hook_id=hook_id, store=None, layer_signature=layer_signature)
+         self.target_dtype = target_dtype
 
      def process_activations(
-         self,
-         module: torch.nn.Module,
-         input: HOOK_FUNCTION_INPUT,
-         output: HOOK_FUNCTION_OUTPUT
+         self, module: torch.nn.Module, input: HOOK_FUNCTION_INPUT, output: HOOK_FUNCTION_OUTPUT
      ) -> None:
          """
          Extract and store activations from output.
-
+
          Handles various output types:
          - Plain tensors
          - Tuples/lists of tensors (takes first tensor)
          - Objects with last_hidden_state attribute (e.g., HuggingFace outputs)
-
+
          Args:
              module: The PyTorch module being hooked
              input: Tuple of input tensors to the module
              output: Output tensor(s) from the module
-
+
          Raises:
              RuntimeError: If tensor extraction or storage fails
          """
@@ -70,27 +61,32 @@ class LayerActivationDetector(Detector):
              tensor = extract_tensor_from_output(output)
 
              if tensor is not None:
-                 tensor_cpu = tensor.detach().to("cpu")
-                 # Store current batch's tensor (overwrites previous)
-                 self.tensor_metadata['activations'] = tensor_cpu
-                 # Store activations shape to metadata
-                 self.metadata['activations_shape'] = tuple(tensor_cpu.shape)
+                 if tensor.is_cuda:
+                     tensor_cpu = tensor.detach().to("cpu", non_blocking=True)
+                 else:
+                     tensor_cpu = tensor.detach()
+
+                 if self.target_dtype is not None:
+                     tensor_cpu = tensor_cpu.to(self.target_dtype)
+
+                 self.tensor_metadata["activations"] = tensor_cpu
+                 self.metadata["activations_shape"] = tuple(tensor_cpu.shape)
          except Exception as e:
+             layer_sig = str(self.layer_signature) if self.layer_signature is not None else "unknown"
              raise RuntimeError(
-                 f"Error extracting activations in LayerActivationDetector {self.id}: {e}"
+                 f"Error extracting activations in LayerActivationDetector {self.id} (layer={layer_sig}): {e}"
              ) from e
 
      def get_captured(self) -> torch.Tensor | None:
          """
          Get the captured activations from the current batch.
-
+
          Returns:
              The captured activation tensor from the current batch or None if no activations captured yet
          """
-         return self.tensor_metadata.get('activations')
+         return self.tensor_metadata.get("activations")
 
      def clear_captured(self) -> None:
          """Clear captured activations for current batch."""
-         self.tensor_metadata.pop('activations', None)
-         self.metadata.pop('activations_shape', None)
-
+         self.tensor_metadata.pop("activations", None)
+         self.metadata.pop("activations_shape", None)
@@ -1,38 +1,39 @@
  from __future__ import annotations
 
- from typing import TYPE_CHECKING, Dict, Set, List, Optional
+ from typing import TYPE_CHECKING, Dict, List, Optional, Set
+
  import torch
 
  from mi_crow.hooks.detector import Detector
- from mi_crow.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
+ from mi_crow.hooks.hook import HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT, HookType
 
  if TYPE_CHECKING:
-     from torch import nn
+     pass
 
 
  class ModelInputDetector(Detector):
      """
      Detector hook that captures and saves tokenized inputs from model forward pass.
-
+
      This detector is designed to be attached to the root model module and captures:
      - Tokenized inputs (input_ids) from the model's forward pass
      - Attention masks (optional) that exclude both padding and special tokens
-
+
      Uses PRE_FORWARD hook to capture inputs before they are processed.
      Useful for saving tokenized inputs for analysis or training.
      """
 
      def __init__(
-         self,
-         layer_signature: str | int | None = None,
-         hook_id: str | None = None,
-         save_input_ids: bool = True,
-         save_attention_mask: bool = False,
-         special_token_ids: Optional[List[int] | Set[int]] = None
+         self,
+         layer_signature: str | int | None = None,
+         hook_id: str | None = None,
+         save_input_ids: bool = True,
+         save_attention_mask: bool = False,
+         special_token_ids: Optional[List[int] | Set[int]] = None,
      ):
          """
          Initialize the model input detector.
-
+
          Args:
              layer_signature: Layer to capture from (typically the root model, can be None)
              hook_id: Unique identifier for this hook
@@ -40,12 +41,7 @@ class ModelInputDetector(Detector):
              save_attention_mask: Whether to save attention_mask tensor (excludes padding and special tokens)
              special_token_ids: Optional list/set of special token IDs. If None, will extract from LanguageModel context.
          """
-         super().__init__(
-             hook_type=HookType.PRE_FORWARD,
-             hook_id=hook_id,
-             store=None,
-             layer_signature=layer_signature
-         )
+         super().__init__(hook_type=HookType.PRE_FORWARD, hook_id=hook_id, store=None, layer_signature=layer_signature)
          self.save_input_ids = save_input_ids
          self.save_attention_mask = save_attention_mask
          self.special_token_ids = set(special_token_ids) if special_token_ids is not None else None
@@ -53,90 +49,87 @@ class ModelInputDetector(Detector):
      def _extract_input_ids(self, input: HOOK_FUNCTION_INPUT) -> torch.Tensor | None:
          """
          Extract input_ids from model input.
-
+
          Handles various input formats:
          - Dict with 'input_ids' key (most common for HuggingFace models)
          - Tuple with dict as first element
          - Tuple with tensor as first element
-
+
          Args:
              input: Input to the model forward pass
-
+
          Returns:
              input_ids tensor or None if not found
          """
          if not input or len(input) == 0:
              return None
-
+
          first_item = input[0]
-
+
          if isinstance(first_item, dict):
-             if 'input_ids' in first_item:
-                 return first_item['input_ids']
+             if "input_ids" in first_item:
+                 return first_item["input_ids"]
              return None
-
+
          if isinstance(first_item, torch.Tensor):
              return first_item
-
+
          return None
 
      def _extract_attention_mask(self, input: HOOK_FUNCTION_INPUT) -> torch.Tensor | None:
          """
          Extract attention_mask from model input.
-
+
          Args:
              input: Input to the model forward pass
-
+
          Returns:
              attention_mask tensor or None if not found
          """
          if not input or len(input) == 0:
              return None
-
+
          first_item = input[0]
-
+
          if isinstance(first_item, dict):
-             if 'attention_mask' in first_item:
-                 return first_item['attention_mask']
-
+             if "attention_mask" in first_item:
+                 return first_item["attention_mask"]
+
          return None
 
      def _get_special_token_ids(self, module: torch.nn.Module) -> Set[int]:
          """
          Get special token IDs from user-provided list or from LanguageModel context.
-
+
          Priority order:
          1. self.special_token_ids (user-provided during initialization)
          2. self.context.special_token_ids (from LanguageModel initialization)
-
+
          Args:
              module: The PyTorch module being hooked (unused, kept for API compatibility)
-
+
          Returns:
              Set of special token IDs, or empty set if none available
          """
          if self.special_token_ids is not None:
              return self.special_token_ids
-
+
          if self.context is not None and self.context.special_token_ids is not None:
              return self.context.special_token_ids
-
+
          return set()
 
      def _create_combined_attention_mask(
-         self,
-         input_ids: torch.Tensor,
-         attention_mask: torch.Tensor | None,
-         module: torch.nn.Module
+         self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None, module: torch.nn.Module
      ) -> torch.Tensor:
          """
          Create a combined attention mask that excludes both padding and special tokens.
-
+
          Args:
              input_ids: Input token IDs tensor (batch_size × sequence_length)
              attention_mask: Original attention mask from tokenizer (None if not provided)
              module: The PyTorch module being hooked
-
+
          Returns:
              Binary mask tensor with same shape as input_ids (1 for regular tokens, 0 for padding/special tokens)
          """
@@ -144,72 +137,71 @@ class ModelInputDetector(Detector):
              attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
          else:
              attention_mask = attention_mask.bool()
-
+
          special_token_ids = self._get_special_token_ids(module)
-
+
          if special_token_ids:
              special_ids_tensor = torch.tensor(list(special_token_ids), device=input_ids.device, dtype=input_ids.dtype)
              expanded_input = input_ids.unsqueeze(-1)
              expanded_special = special_ids_tensor.unsqueeze(0).unsqueeze(0)
              is_special = (expanded_input == expanded_special).any(dim=-1)
              attention_mask = attention_mask & ~is_special
-
+
          return attention_mask.to(torch.bool)
 
-     def set_inputs_from_encodings(self, encodings: Dict[str, torch.Tensor], module: Optional[torch.nn.Module] = None) -> None:
+     def set_inputs_from_encodings(
+         self, encodings: Dict[str, torch.Tensor], module: Optional[torch.nn.Module] = None
+     ) -> None:
          """
          Manually set inputs from encodings dictionary.
-
+
          This is useful when the model is called with keyword arguments,
          as PyTorch's pre_forward hook doesn't receive kwargs.
-
+
          Args:
-             encodings: Dictionary of encoded inputs (e.g., from lm.forwards() or lm.tokenize())
+             encodings: Dictionary of encoded inputs (e.g., from lm.inference.execute_inference() or lm.tokenize())
              module: Optional module for extracting special token IDs. If None, will use DummyModule.
-
+
          Raises:
              RuntimeError: If tensor extraction or storage fails
          """
          try:
-             if self.save_input_ids and 'input_ids' in encodings:
-                 input_ids = encodings['input_ids']
-                 self.tensor_metadata['input_ids'] = input_ids.detach().to("cpu")
-                 self.metadata['input_ids_shape'] = tuple(input_ids.shape)
-
-             if self.save_attention_mask and 'input_ids' in encodings:
-                 input_ids = encodings['input_ids']
+             if self.save_input_ids and "input_ids" in encodings:
+                 input_ids = encodings["input_ids"]
+                 self.tensor_metadata["input_ids"] = input_ids.detach().to("cpu")
+                 self.metadata["input_ids_shape"] = tuple(input_ids.shape)
+
+             if self.save_attention_mask and "input_ids" in encodings:
+                 input_ids = encodings["input_ids"]
                  if module is None:
+
                      class DummyModule:
                          pass
+
                      module = DummyModule()
-
-                 original_attention_mask = encodings.get('attention_mask')
+
+                 original_attention_mask = encodings.get("attention_mask")
                  combined_mask = self._create_combined_attention_mask(input_ids, original_attention_mask, module)
-                 self.tensor_metadata['attention_mask'] = combined_mask.detach().to("cpu")
-                 self.metadata['attention_mask_shape'] = tuple(combined_mask.shape)
+                 self.tensor_metadata["attention_mask"] = combined_mask.detach().to("cpu")
+                 self.metadata["attention_mask_shape"] = tuple(combined_mask.shape)
          except Exception as e:
-             raise RuntimeError(
-                 f"Error setting inputs from encodings in ModelInputDetector {self.id}: {e}"
-             ) from e
+             raise RuntimeError(f"Error setting inputs from encodings in ModelInputDetector {self.id}: {e}") from e
 
      def process_activations(
-         self,
-         module: torch.nn.Module,
-         input: HOOK_FUNCTION_INPUT,
-         output: HOOK_FUNCTION_OUTPUT
+         self, module: torch.nn.Module, input: HOOK_FUNCTION_INPUT, output: HOOK_FUNCTION_OUTPUT
      ) -> None:
          """
          Extract and store tokenized inputs.
-
+
          Note: For HuggingFace models called with **kwargs, the input tuple may be empty.
          In such cases, use set_inputs_from_encodings() to manually set inputs from
-         the encodings dictionary returned by lm.forwards().
-
+         the encodings dictionary returned by lm.inference.execute_inference().
+
          Args:
              module: The PyTorch module being hooked (typically the root model)
              input: Tuple of input tensors/dicts to the module
              output: Output from the module (None for PRE_FORWARD hooks)
-
+
          Raises:
              RuntimeError: If tensor extraction or storage fails
          """
@@ -217,34 +209,42 @@ class ModelInputDetector(Detector):
              if self.save_input_ids:
                  input_ids = self._extract_input_ids(input)
                  if input_ids is not None:
-                     self.tensor_metadata['input_ids'] = input_ids.detach().to("cpu")
-                     self.metadata['input_ids_shape'] = tuple(input_ids.shape)
-
+                     if input_ids.is_cuda:
+                         input_ids_cpu = input_ids.detach().to("cpu", non_blocking=True)
+                     else:
+                         input_ids_cpu = input_ids.detach()
+                     self.tensor_metadata["input_ids"] = input_ids_cpu
+                     self.metadata["input_ids_shape"] = tuple(input_ids_cpu.shape)
+
              if self.save_attention_mask:
                  input_ids = self._extract_input_ids(input)
                  if input_ids is not None:
                      original_attention_mask = self._extract_attention_mask(input)
                      combined_mask = self._create_combined_attention_mask(input_ids, original_attention_mask, module)
-                     self.tensor_metadata['attention_mask'] = combined_mask.detach().to("cpu")
-                     self.metadata['attention_mask_shape'] = tuple(combined_mask.shape)
-
+                     if combined_mask.is_cuda:
+                         combined_mask_cpu = combined_mask.detach().to("cpu", non_blocking=True)
+                     else:
+                         combined_mask_cpu = combined_mask.detach()
+                     self.tensor_metadata["attention_mask"] = combined_mask_cpu
+                     self.metadata["attention_mask_shape"] = tuple(combined_mask_cpu.shape)
+
          except Exception as e:
+             layer_sig = str(self.layer_signature) if self.layer_signature is not None else "unknown"
              raise RuntimeError(
-                 f"Error extracting inputs in ModelInputDetector {self.id}: {e}"
+                 f"Error extracting inputs in ModelInputDetector {self.id} (layer={layer_sig}): {e}"
              ) from e
 
      def get_captured_input_ids(self) -> torch.Tensor | None:
          """Get the captured input_ids from the current batch."""
-         return self.tensor_metadata.get('input_ids')
+         return self.tensor_metadata.get("input_ids")
 
      def get_captured_attention_mask(self) -> torch.Tensor | None:
          """Get the captured attention_mask from the current batch (excludes padding and special tokens)."""
-         return self.tensor_metadata.get('attention_mask')
+         return self.tensor_metadata.get("attention_mask")
 
      def clear_captured(self) -> None:
          """Clear all captured inputs for current batch."""
-         keys_to_remove = ['input_ids', 'attention_mask']
+         keys_to_remove = ["input_ids", "attention_mask"]
          for key in keys_to_remove:
              self.tensor_metadata.pop(key, None)
-             self.metadata.pop(f'{key}_shape', None)
-
+             self.metadata.pop(f"{key}_shape", None)
@@ -1,132 +1,133 @@
  from __future__ import annotations
 
  from typing import TYPE_CHECKING
+
  import torch
 
  from mi_crow.hooks.detector import Detector
- from mi_crow.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
+ from mi_crow.hooks.hook import HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT, HookType
 
  if TYPE_CHECKING:
-     from torch import nn
+     pass
 
 
  class ModelOutputDetector(Detector):
      """
      Detector hook that captures and saves model outputs.
-
+
      This detector is designed to be attached to the root model module and captures:
      - Model outputs (logits) from the model's forward pass
      - Hidden states (optional) from the model's forward pass
-
+
      Uses FORWARD hook to capture outputs after they are computed.
      Useful for saving model outputs for analysis or training.
      """
 
      def __init__(
-         self,
-         layer_signature: str | int | None = None,
-         hook_id: str | None = None,
-         save_output_logits: bool = True,
-         save_output_hidden_state: bool = False
+         self,
+         layer_signature: str | int | None = None,
+         hook_id: str | None = None,
+         save_output_logits: bool = True,
+         save_output_hidden_state: bool = False,
      ):
          """
          Initialize the model output detector.
-
+
          Args:
              layer_signature: Layer to capture from (typically the root model, can be None)
              hook_id: Unique identifier for this hook
              save_output_logits: Whether to save output logits (if available)
              save_output_hidden_state: Whether to save last_hidden_state (if available)
          """
-         super().__init__(
-             hook_type=HookType.FORWARD,
-             hook_id=hook_id,
-             store=None,
-             layer_signature=layer_signature
-         )
+         super().__init__(hook_type=HookType.FORWARD, hook_id=hook_id, store=None, layer_signature=layer_signature)
          self.save_output_logits = save_output_logits
          self.save_output_hidden_state = save_output_hidden_state
 
      def _extract_output_tensor(self, output: HOOK_FUNCTION_OUTPUT) -> tuple[torch.Tensor | None, torch.Tensor | None]:
          """
          Extract logits and last_hidden_state from model output.
-
+
          Args:
              output: Output from the model forward pass
-
+
          Returns:
              Tuple of (logits, last_hidden_state), either can be None
          """
          logits = None
          hidden_state = None
-
+
          if output is None:
              return None, None
-
+
          # Handle HuggingFace output objects
          if hasattr(output, "logits"):
              logits = output.logits
          if hasattr(output, "last_hidden_state"):
              hidden_state = output.last_hidden_state
-
+
          # Handle tuple output (logits might be first element)
          if isinstance(output, (tuple, list)) and len(output) > 0:
              first_item = output[0]
              if isinstance(first_item, torch.Tensor) and logits is None:
                  logits = first_item
-
+
          # Handle direct tensor output
          if isinstance(output, torch.Tensor) and logits is None:
              logits = output
-
+
          return logits, hidden_state
 
      def process_activations(
-         self,
-         module: torch.nn.Module,
-         input: HOOK_FUNCTION_INPUT,
-         output: HOOK_FUNCTION_OUTPUT
+         self, module: torch.nn.Module, input: HOOK_FUNCTION_INPUT, output: HOOK_FUNCTION_OUTPUT
      ) -> None:
          """
          Extract and store model outputs.
-
+
          Args:
              module: The PyTorch module being hooked (typically the root model)
              input: Tuple of input tensors/dicts to the module
              output: Output from the module
-
+
          Raises:
              RuntimeError: If tensor extraction or storage fails
          """
          try:
              # Extract and save outputs
              logits, hidden_state = self._extract_output_tensor(output)
-
+
              if self.save_output_logits and logits is not None:
-                 self.tensor_metadata['output_logits'] = logits.detach().to("cpu")
-                 self.metadata['output_logits_shape'] = tuple(logits.shape)
-
+                 if logits.is_cuda:
+                     logits_cpu = logits.detach().to("cpu", non_blocking=True)
+                 else:
+                     logits_cpu = logits.detach()
+                 self.tensor_metadata["output_logits"] = logits_cpu
+                 self.metadata["output_logits_shape"] = tuple(logits_cpu.shape)
+
              if self.save_output_hidden_state and hidden_state is not None:
-                 self.tensor_metadata['output_hidden_state'] = hidden_state.detach().to("cpu")
-                 self.metadata['output_hidden_state_shape'] = tuple(hidden_state.shape)
-
+                 if hidden_state.is_cuda:
+                     hidden_state_cpu = hidden_state.detach().to("cpu", non_blocking=True)
+                 else:
+                     hidden_state_cpu = hidden_state.detach()
+                 self.tensor_metadata["output_hidden_state"] = hidden_state_cpu
+                 self.metadata["output_hidden_state_shape"] = tuple(hidden_state_cpu.shape)
+
          except Exception as e:
+             layer_sig = str(self.layer_signature) if self.layer_signature is not None else "unknown"
              raise RuntimeError(
-                 f"Error extracting outputs in ModelOutputDetector {self.id}: {e}"
+                 f"Error extracting outputs in ModelOutputDetector {self.id} (layer={layer_sig}): {e}"
              ) from e
 
      def get_captured_output_logits(self) -> torch.Tensor | None:
          """Get the captured output logits from the current batch."""
-         return self.tensor_metadata.get('output_logits')
+         return self.tensor_metadata.get("output_logits")
 
      def get_captured_output_hidden_state(self) -> torch.Tensor | None:
          """Get the captured output hidden state from the current batch."""
-         return self.tensor_metadata.get('output_hidden_state')
+         return self.tensor_metadata.get("output_hidden_state")
 
      def clear_captured(self) -> None:
          """Clear all captured outputs for current batch."""
-         keys_to_remove = ['output_logits', 'output_hidden_state']
+         keys_to_remove = ["output_logits", "output_hidden_state"]
          for key in keys_to_remove:
              self.tensor_metadata.pop(key, None)
-             self.metadata.pop(f'{key}_shape', None)
-
+             self.metadata.pop(f"{key}_shape", None)