mi-crow 0.1.2__py3-none-any.whl → 1.0.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mi_crow/datasets/base_dataset.py +71 -1
- mi_crow/datasets/classification_dataset.py +136 -30
- mi_crow/datasets/text_dataset.py +165 -24
- mi_crow/hooks/controller.py +12 -7
- mi_crow/hooks/implementations/layer_activation_detector.py +30 -34
- mi_crow/hooks/implementations/model_input_detector.py +87 -87
- mi_crow/hooks/implementations/model_output_detector.py +43 -42
- mi_crow/hooks/utils.py +74 -0
- mi_crow/language_model/activations.py +174 -77
- mi_crow/language_model/device_manager.py +119 -0
- mi_crow/language_model/inference.py +18 -5
- mi_crow/language_model/initialization.py +10 -6
- mi_crow/language_model/language_model.py +67 -97
- mi_crow/language_model/layers.py +16 -13
- mi_crow/language_model/persistence.py +4 -2
- mi_crow/language_model/utils.py +5 -5
- mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py +157 -95
- mi_crow/mechanistic/sae/concepts/concept_dictionary.py +12 -2
- mi_crow/mechanistic/sae/concepts/text_heap.py +161 -0
- mi_crow/mechanistic/sae/modules/topk_sae.py +29 -22
- mi_crow/mechanistic/sae/sae.py +3 -1
- mi_crow/mechanistic/sae/sae_trainer.py +362 -29
- mi_crow/store/local_store.py +11 -5
- mi_crow/store/store.py +34 -1
- {mi_crow-0.1.2.dist-info → mi_crow-1.0.0.post1.dist-info}/METADATA +2 -1
- {mi_crow-0.1.2.dist-info → mi_crow-1.0.0.post1.dist-info}/RECORD +28 -26
- {mi_crow-0.1.2.dist-info → mi_crow-1.0.0.post1.dist-info}/WHEEL +1 -1
- {mi_crow-0.1.2.dist-info → mi_crow-1.0.0.post1.dist-info}/top_level.txt +0 -0
mi_crow/hooks/utils.py
CHANGED
@@ -74,3 +74,77 @@ def extract_tensor_from_output(output: HOOK_FUNCTION_OUTPUT) -> torch.Tensor | None:
 
     return None
 
+
+def apply_modification_to_output(
+    output: HOOK_FUNCTION_OUTPUT,
+    modified_tensor: torch.Tensor,
+    target_device: torch.device | None = None
+) -> None:
+    """
+    Apply a modified tensor to an output object in-place.
+
+    Handles various output formats:
+    - Plain tensors: modifies the tensor directly (in-place)
+    - Tuples/lists of tensors: replaces first tensor
+    - Objects with last_hidden_state attribute: sets last_hidden_state
+
+    If target_device is provided, output tensors are moved to target_device first,
+    ensuring consistency with the desired device (e.g., context.device).
+    Otherwise, modified_tensor is moved to match output's current device.
+
+    Args:
+        output: Output object to modify
+        modified_tensor: Modified tensor to apply
+        target_device: Optional target device. If provided, output tensors are moved
+            to this device before applying modification. If None, uses output's current device.
+    """
+    if output is None:
+        return
+
+    if isinstance(output, torch.Tensor):
+        if target_device is not None:
+            if output.device != target_device:
+                output = output.to(target_device)
+            if modified_tensor.device != target_device:
+                modified_tensor = modified_tensor.to(target_device)
+        else:
+            if modified_tensor.device != output.device:
+                modified_tensor = modified_tensor.to(output.device)
+        output.data.copy_(modified_tensor.data)
+        return
+
+    if isinstance(output, (tuple, list)):
+        for i, item in enumerate(output):
+            if isinstance(item, torch.Tensor):
+                if target_device is not None:
+                    if item.device != target_device:
+                        item = item.to(target_device)
+                        if isinstance(output, list):
+                            output[i] = item
+                    if modified_tensor.device != target_device or modified_tensor.dtype != item.dtype:
+                        modified_tensor = modified_tensor.to(device=target_device, dtype=item.dtype)
+                else:
+                    if modified_tensor.device != item.device or modified_tensor.dtype != item.dtype:
+                        modified_tensor = modified_tensor.to(device=item.device, dtype=item.dtype)
+                if isinstance(output, tuple):
+                    item.data.copy_(modified_tensor.data)
+                else:
+                    output[i] = modified_tensor
+                break
+        return
+
+    if hasattr(output, "last_hidden_state"):
+        original_tensor = output.last_hidden_state
+        if isinstance(original_tensor, torch.Tensor):
+            if target_device is not None:
+                if original_tensor.device != target_device:
+                    output.last_hidden_state = original_tensor.to(target_device)
+                    original_tensor = output.last_hidden_state
+                if modified_tensor.device != target_device:
+                    modified_tensor = modified_tensor.to(target_device)
+            else:
+                if modified_tensor.device != original_tensor.device:
+                    modified_tensor = modified_tensor.to(original_tensor.device)
+            output.last_hidden_state = modified_tensor
+        return
+
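A minimal usage sketch of the new helper (assuming the package is installed; the tuple case mirrors a transformer block that returns a tuple whose first element is hidden states):

    import torch
    from mi_crow.hooks.utils import apply_modification_to_output

    # Plain tensor output: storage is overwritten in place via .data.copy_()
    out = torch.zeros(2, 3)
    apply_modification_to_output(out, torch.ones(2, 3))
    assert torch.equal(out, torch.ones(2, 3))

    # Tuple output: the first tensor found is updated in place, then the loop breaks
    hidden = torch.zeros(2, 3)
    apply_modification_to_output((hidden, None), torch.ones(2, 3))
    assert torch.equal(hidden, torch.ones(2, 3))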
mi_crow/language_model/activations.py
CHANGED
@@ -1,4 +1,5 @@
 import datetime
+import gc
 from typing import TYPE_CHECKING, Any, Dict, Sequence
 
 import torch
@@ -10,7 +11,6 @@ from mi_crow.hooks.implementations.layer_activation_detector import LayerActivationDetector
 from mi_crow.hooks.implementations.model_input_detector import ModelInputDetector
 from mi_crow.store.store import Store
 from mi_crow.utils import get_logger
-from mi_crow.language_model.utils import get_device_from_model
 
 if TYPE_CHECKING:
     from mi_crow.language_model.context import LanguageModelContext
@@ -30,18 +30,25 @@ class LanguageModelActivations:
         """
         self.context = context
 
-    def _setup_detector(self, layer_signature: str | int, hook_id_suffix: str) -> tuple[LayerActivationDetector, str]:
+    def _setup_detector(
+        self, layer_signature: str | int, hook_id_suffix: str, dtype: torch.dtype | None = None
+    ) -> tuple[LayerActivationDetector, str]:
         """
         Create and register an activation detector.
 
         Args:
             layer_signature: Layer to attach detector to
             hook_id_suffix: Suffix for hook ID
+            dtype: Optional dtype for activations
 
         Returns:
             Tuple of (detector, hook_id)
         """
-        detector = LayerActivationDetector(layer_signature=layer_signature, hook_id=f"detector_{hook_id_suffix}")
+        detector = LayerActivationDetector(
+            layer_signature=layer_signature,
+            hook_id=f"detector_{hook_id_suffix}",
+            target_dtype=dtype,
+        )
 
         hook_id = self.context.language_model.layers.register_hook(layer_signature, detector, HookType.FORWARD)
 
@@ -71,24 +78,115 @@ class LanguageModelActivations:
         """
         attention_mask_layer_sig = "attention_masks"
         root_model = self.context.model
-
-        # Add layer signature to registry for root model
+
         if attention_mask_layer_sig not in self.context.language_model.layers.name_to_layer:
             self.context.language_model.layers.name_to_layer[attention_mask_layer_sig] = root_model
-
+
         detector = ModelInputDetector(
             layer_signature=attention_mask_layer_sig,
             hook_id=f"attention_mask_detector_{run_name}",
             save_input_ids=False,
             save_attention_mask=True,
         )
-
+
         hook_id = self.context.language_model.layers.register_hook(
             attention_mask_layer_sig, detector, HookType.PRE_FORWARD
         )
-
+
         return detector, hook_id
 
+    def _setup_activation_hooks(
+        self,
+        layer_sig_list: list[str],
+        run_name: str,
+        save_attention_mask: bool,
+        dtype: torch.dtype | None = None,
+    ) -> tuple[list[str], str | None]:
+        """
+        Setup activation hooks for saving.
+
+        Args:
+            layer_sig_list: List of layer signatures to hook
+            run_name: Run name for hook IDs
+            save_attention_mask: Whether to setup attention mask detector
+            dtype: Optional dtype for activations
+
+        Returns:
+            Tuple of (hook_ids list, attention_mask_hook_id or None)
+        """
+        hook_ids: list[str] = []
+        for sig in layer_sig_list:
+            _, hook_id = self._setup_detector(sig, f"save_{run_name}_{sig}", dtype=dtype)
+            hook_ids.append(hook_id)
+
+        attention_mask_hook_id: str | None = None
+        if save_attention_mask:
+            _, attention_mask_hook_id = self._setup_attention_mask_detector(run_name)
+
+        return hook_ids, attention_mask_hook_id
+
+    def _teardown_activation_hooks(
+        self,
+        hook_ids: list[str],
+        attention_mask_hook_id: str | None,
+    ) -> None:
+        """
+        Teardown activation hooks.
+
+        Args:
+            hook_ids: List of hook IDs to cleanup
+            attention_mask_hook_id: Optional attention mask hook ID to cleanup
+        """
+        for hook_id in hook_ids:
+            self._cleanup_detector(hook_id)
+        if attention_mask_hook_id is not None:
+            self._cleanup_detector(attention_mask_hook_id)
+
+    def _validate_save_prerequisites(self) -> tuple[nn.Module, Store]:
+        """
+        Validate prerequisites for saving activations.
+
+        Returns:
+            Tuple of (model, store)
+
+        Raises:
+            ValueError: If model or store is not initialized
+        """
+        model: nn.Module | None = self.context.model
+        if model is None:
+            raise ValueError("Model must be initialized before running")
+
+        store = self.context.store
+        if store is None:
+            raise ValueError("Store must be provided or set on the language model")
+
+        return model, store
+
+    def _prepare_save_metadata(
+        self,
+        layer_signature: str | int | list[str | int],
+        dataset: BaseDataset | None,
+        run_name: str | None,
+        options: Dict[str, Any],
+    ) -> tuple[str, Dict[str, Any], list[str]]:
+        """
+        Prepare metadata for activation saving.
+
+        Args:
+            layer_signature: Layer signature(s) to save
+            dataset: Optional dataset
+            run_name: Optional run name
+            options: Options dictionary
+
+        Returns:
+            Tuple of (run_name, metadata, layer_sig_list)
+        """
+        _, layer_sig_list = self._normalize_layer_signatures(layer_signature)
+        run_name, meta = self._prepare_run_metadata(
+            layer_signature, dataset=dataset, run_name=run_name, options=options
+        )
+        return run_name, meta, layer_sig_list
+
     def _normalize_layer_signatures(
         self, layer_signatures: str | int | list[str | int] | None
     ) -> tuple[str | None, list[str]]:
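The new helpers factor the register-capture-unhook lifecycle out of the two save paths. A self-contained analogue of that lifecycle in plain PyTorch (the generic pattern, not the mi_crow API):

    import torch
    from torch import nn

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
    captured: dict[str, torch.Tensor] = {}

    def make_hook(name: str):
        def hook(module, inputs, output):
            # Detach and move to CPU, roughly what an activation detector does
            captured[name] = output.detach().to("cpu")
        return hook

    handles = [model[0].register_forward_hook(make_hook("layer0"))]
    try:
        with torch.inference_mode():
            model(torch.randn(2, 4))
    finally:
        # Mirrors _teardown_activation_hooks: hooks come off even on error
        for h in handles:
            h.remove()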
@@ -171,6 +269,7 @@ class LanguageModelActivations:
             verbose: Whether to log
         """
         from mi_crow.language_model.inference import InferenceEngine
+
         InferenceEngine._save_run_metadata(store, run_name, meta, verbose)
 
     def _process_batch(
@@ -184,6 +283,7 @@
         dtype: torch.dtype | None,
         verbose: bool,
         save_in_batches: bool = True,
+        stop_after_layer: str | int | None = None,
     ) -> None:
         """Process a single batch of texts.
 
@@ -196,6 +296,7 @@
             autocast_dtype: Optional dtype for autocast
             dtype: Optional dtype to convert activations to
             verbose: Whether to log progress
+            stop_after_layer: Optional layer signature to stop after (name or index)
         """
         if not texts:
             return
@@ -209,31 +310,48 @@
             tok_kwargs=tok_kwargs,
             autocast=autocast,
             autocast_dtype=autocast_dtype,
+            stop_after_layer=stop_after_layer,
         )
 
-        if dtype is not None:
-            self._convert_activations_to_dtype(dtype)
-
         self.context.language_model.save_detector_metadata(
             run_name,
             batch_index,
             unified=not save_in_batches,
         )
 
+        # Synchronize CUDA to ensure async CPU transfers from detector hooks complete
+        # Only synchronize if CUDA is actually available and initialized
+        try:
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+        except (AssertionError, RuntimeError):
+            # CUDA not available or not initialized (e.g., in test environment)
+            pass
+
+        gc.collect()
+        if torch.cuda.is_available():
+            try:
+                torch.cuda.empty_cache()
+            except (AssertionError, RuntimeError):
+                # CUDA not available or not initialized
+                pass
+
         if verbose:
             logger.info(f"Saved batch {batch_index} for run={run_name}")
 
     def _convert_activations_to_dtype(self, dtype: torch.dtype) -> None:
         """
-        Convert captured activations to specified dtype.
+        Convert all captured activations in detectors to the specified dtype.
 
         Args:
-            dtype: Target dtype
+            dtype: Target dtype to convert activations to
         """
         detectors = self.context.language_model.layers.get_detectors()
         for detector in detectors:
-            if "activations" in detector.tensor_metadata:
-                detector.tensor_metadata["activations"] = detector.tensor_metadata["activations"].to(dtype)
+            if hasattr(detector, "tensor_metadata") and "activations" in detector.tensor_metadata:
+                tensor = detector.tensor_metadata["activations"]
+                if tensor.dtype != dtype:
+                    detector.tensor_metadata["activations"] = tensor.to(dtype)
 
     def _manage_cuda_cache(
         self, batch_counter: int, free_cuda_cache_every: int | None, device_type: str, verbose: bool
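The guarded synchronize-and-flush added to _process_batch can be read as a standalone pattern; a sketch (the helper name is illustrative, not part of the package):

    import gc

    import torch

    def flush_device_memory() -> None:
        # Wait for pending async device-to-host copies before buffers are reused;
        # guarded so it is a no-op on CPU-only or uninitialized-CUDA machines.
        try:
            if torch.cuda.is_available():
                torch.cuda.synchronize()
        except (AssertionError, RuntimeError):
            pass
        gc.collect()
        if torch.cuda.is_available():
            try:
                torch.cuda.empty_cache()
            except (AssertionError, RuntimeError):
                pass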
@@ -264,10 +382,11 @@ class LanguageModelActivations:
         max_length: int | None = None,
         autocast: bool = True,
         autocast_dtype: torch.dtype | None = None,
-        free_cuda_cache_every: int | None =
+        free_cuda_cache_every: int | None = None,
         verbose: bool = False,
         save_in_batches: bool = True,
         save_attention_mask: bool = False,
+        stop_after_last_layer: bool = True,
     ) -> str:
         """
         Save activations from a dataset.
@@ -281,9 +400,11 @@ class LanguageModelActivations:
             max_length: Optional max length for tokenization
             autocast: Whether to use autocast
             autocast_dtype: Optional dtype for autocast
-            free_cuda_cache_every: Clear CUDA cache every N batches (
+            free_cuda_cache_every: Clear CUDA cache every N batches (None to auto-detect, 0 to disable)
             verbose: Whether to log progress
             save_attention_mask: Whether to also save attention masks (automatically attaches ModelInputDetector)
+            stop_after_last_layer: Whether to stop model forward pass after the last requested layer
+                to save memory and time. Defaults to True.
 
         Returns:
             Run name used for saving
@@ -291,29 +412,22 @@ class LanguageModelActivations:
         Raises:
             ValueError: If model or store is not initialized
         """
-        model: nn.Module | None = self.context.model
-        if model is None:
-            raise ValueError("Model must be initialized before running")
+        model, store = self._validate_save_prerequisites()
 
-
-
-        store = self.context.store
-        if store is None:
-            raise ValueError("Store must be provided or set on the language model")
-
-
-        device = get_device_from_model(model)
+        device = torch.device(self.context.device)
         device_type = str(device.type)
 
+        if free_cuda_cache_every is None:
+            free_cuda_cache_every = 5 if device_type == "cuda" else 0
+
         options = {
             "dtype": str(dtype) if dtype is not None else None,
             "max_length": max_length,
             "batch_size": int(batch_size),
+            "stop_after_last_layer": stop_after_last_layer,
         }
 
-        run_name, meta = self._prepare_run_metadata(
-            layer_signature, dataset=dataset, run_name=run_name, options=options
-        )
+        run_name, meta, layer_sig_list = self._prepare_save_metadata(layer_signature, dataset, run_name, options)
 
         if verbose:
             logger.info(
@@ -323,17 +437,13 @@ class LanguageModelActivations:
 
         self._save_run_metadata(store, run_name, meta, verbose)
 
-        hook_ids: list[str] = []
-        for sig in layer_sig_list:
-            _, hook_id = self._setup_detector(sig, f"save_{run_name}_{sig}")
-            hook_ids.append(hook_id)
-
-        # Setup attention mask detector if requested
-        attention_mask_hook_id: str | None = None
-        if save_attention_mask:
-            _, attention_mask_hook_id = self._setup_attention_mask_detector(run_name)
+        hook_ids, attention_mask_hook_id = self._setup_activation_hooks(
+            layer_sig_list, run_name, save_attention_mask, dtype=dtype
+        )
 
         batch_counter = 0
+        # Stop after last hooked layer if requested
+        stop_after = layer_sig_list[-1] if (layer_sig_list and stop_after_last_layer) else None
 
         try:
             with torch.inference_mode():
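The diff does not show how InferenceEngine honors stop_after_layer; one common implementation raises from a forward hook on the last requested layer so later layers never run. A standalone sketch of that technique (an assumption about the mechanism, not confirmed by this diff):

    import torch
    from torch import nn

    class _StopForward(Exception):
        pass

    model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 8), nn.Linear(8, 2))

    def stop_hook(module, inputs, output):
        # Abort the forward pass once this layer has produced its output
        raise _StopForward

    handle = model[1].register_forward_hook(stop_hook)
    try:
        model(torch.randn(1, 4))
    except _StopForward:
        pass  # layers after model[1] were never executed
    finally:
        handle.remove()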
@@ -349,19 +459,18 @@ class LanguageModelActivations:
                         dtype,
                         verbose,
                         save_in_batches=save_in_batches,
+                        stop_after_layer=stop_after,
                     )
                     batch_counter += 1
+
                     self._manage_cuda_cache(batch_counter, free_cuda_cache_every, device_type, verbose)
         finally:
-            for hook_id in hook_ids:
-                self._cleanup_detector(hook_id)
-            if attention_mask_hook_id is not None:
-                self._cleanup_detector(attention_mask_hook_id)
+            self._teardown_activation_hooks(hook_ids, attention_mask_hook_id)
             if verbose:
                 logger.info(f"Completed save_activations_dataset: run={run_name}, batches_saved={batch_counter}")
-
+
         return run_name
-
+
     def save_activations(
         self,
         texts: Sequence[str],
@@ -377,6 +486,7 @@ class LanguageModelActivations:
         verbose: bool = False,
         save_in_batches: bool = True,
         save_attention_mask: bool = False,
+        stop_after_last_layer: bool = True,
     ) -> str:
         """
         Save activations from a list of texts.
@@ -393,6 +503,8 @@ class LanguageModelActivations:
             free_cuda_cache_every: Clear CUDA cache every N batches (0 or None to disable)
             verbose: Whether to log progress
             save_attention_mask: Whether to also save attention masks (automatically attaches ModelInputDetector)
+            stop_after_last_layer: Whether to stop model forward pass after the last requested layer
+                to save memory and time. Defaults to True.
 
         Returns:
             Run name used for saving
@@ -400,20 +512,12 @@ class LanguageModelActivations:
         Raises:
             ValueError: If model or store is not initialized
         """
-        model: nn.Module | None = self.context.model
-        if model is None:
-            raise ValueError("Model must be initialized before running")
-
-        _, layer_sig_list = self._normalize_layer_signatures(layer_signature)
-
-        store = self.context.store
-        if store is None:
-            raise ValueError("Store must be provided or set on the language model")
-
         if not texts:
             raise ValueError("Texts list cannot be empty")
 
-        device = get_device_from_model(model)
+        model, store = self._validate_save_prerequisites()
+
+        device = torch.device(self.context.device)
         device_type = str(device.type)
 
         if batch_size is None:
@@ -423,11 +527,10 @@ class LanguageModelActivations:
             "dtype": str(dtype) if dtype is not None else None,
             "max_length": max_length,
             "batch_size": int(batch_size),
+            "stop_after_last_layer": stop_after_last_layer,
         }
 
-        run_name, meta = self._prepare_run_metadata(
-            layer_signature, dataset=None, run_name=run_name, options=options
-        )
+        run_name, meta, layer_sig_list = self._prepare_save_metadata(layer_signature, None, run_name, options)
 
         if verbose:
             logger.info(
@@ -437,24 +540,20 @@ class LanguageModelActivations:
 
         self._save_run_metadata(store, run_name, meta, verbose)
 
-        hook_ids: list[str] = []
-        for sig in layer_sig_list:
-            _, hook_id = self._setup_detector(sig, f"save_{run_name}_{sig}")
-            hook_ids.append(hook_id)
-
-        # Setup attention mask detector if requested
-        attention_mask_hook_id: str | None = None
-        if save_attention_mask:
-            _, attention_mask_hook_id = self._setup_attention_mask_detector(run_name)
+        hook_ids, attention_mask_hook_id = self._setup_activation_hooks(
+            layer_sig_list, run_name, save_attention_mask, dtype=dtype
+        )
 
         batch_counter = 0
+        # Stop after last hooked layer if requested
+        stop_after = layer_sig_list[-1] if (layer_sig_list and stop_after_last_layer) else None
 
         try:
             with torch.inference_mode():
                 for i in range(0, len(texts), batch_size):
-                    batch_texts = texts[i:i + batch_size]
+                    batch_texts = texts[i : i + batch_size]
                     batch_index = i // batch_size
-
+
                     self._process_batch(
                         batch_texts,
                         run_name,
@@ -465,15 +564,13 @@ class LanguageModelActivations:
                         dtype,
                         verbose,
                         save_in_batches=save_in_batches,
+                        stop_after_layer=stop_after,
                     )
                     batch_counter += 1
                     self._manage_cuda_cache(batch_counter, free_cuda_cache_every, device_type, verbose)
         finally:
-            for hook_id in hook_ids:
-                self._cleanup_detector(hook_id)
-            if attention_mask_hook_id is not None:
-                self._cleanup_detector(attention_mask_hook_id)
+            self._teardown_activation_hooks(hook_ids, attention_mask_hook_id)
             if verbose:
                 logger.info(f"Completed save_activations: run={run_name}, batches_saved={batch_counter}")
-
-        return run_name
+
+        return run_name
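Putting the new parameters together, a hypothetical call site (the lm.activations accessor and the layer path are assumptions for illustration; the keyword arguments are the ones added in this diff):

    # Assumes an already-initialized LanguageModel whose LanguageModelActivations
    # instance is reachable as lm.activations, and a GPT-2-style module path.
    run_name = lm.activations.save_activations(
        texts=["hello world", "mechanistic interpretability"],
        layer_signature="transformer.h.5",
        batch_size=2,
        save_attention_mask=True,      # attaches a ModelInputDetector pre-forward hook
        stop_after_last_layer=True,    # forward pass halts after the last hooked layer
    )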
mi_crow/language_model/device_manager.py
ADDED
@@ -0,0 +1,119 @@
+"""Centralized device management utilities for LanguageModel operations.
+
+This module provides shared device handling logic to ensure consistent
+device management across the codebase.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+import torch
+
+if TYPE_CHECKING:
+    from mi_crow.language_model.language_model import LanguageModel
+
+logger = logging.getLogger(__name__)
+
+
+def normalize_device(device: str | torch.device | None) -> str:
+    """
+    Normalize and validate device specification.
+
+    Ensures the device is available and normalizes generic device strings.
+    - None → "cpu"
+    - "cuda" → "cuda:0" (if available)
+    - Validates CUDA/MPS availability
+
+    Args:
+        device: Device specification as string, torch.device, or None
+
+    Returns:
+        Normalized device string such as "cpu", "cuda:0", or "mps"
+
+    Raises:
+        ValueError: If requested device is not available
+    """
+    if device is None:
+        return "cpu"
+
+    if isinstance(device, torch.device):
+        device_str = str(device)
+    else:
+        device_str = str(device)
+
+    if device_str.startswith("cuda"):
+        if not torch.cuda.is_available():
+            raise ValueError(
+                "Requested device 'cuda' but CUDA is not available. "
+                "Install a CUDA-enabled PyTorch build or use device='cpu'."
+            )
+        if device_str == "cuda":
+            device_str = "cuda:0"
+
+    if device_str == "mps":
+        mps_backend = getattr(torch.backends, "mps", None)
+        mps_available = bool(mps_backend and mps_backend.is_available())
+        if not mps_available:
+            raise ValueError(
+                "Requested device 'mps' but MPS is not available. "
+                "Ensure PyTorch is built with MPS support or use device='cpu'."
+            )
+
+    return device_str
+
+
+def ensure_context_device(lm: LanguageModel) -> torch.device:
+    """
+    Ensure LanguageModel has valid context.device and return it.
+
+    Args:
+        lm: LanguageModel instance
+
+    Returns:
+        torch.device from context
+
+    Raises:
+        ValueError: If context.device is not properly set
+    """
+    if not hasattr(lm, "context") or not hasattr(lm.context, "device") or lm.context.device is None:
+        raise ValueError(
+            "LanguageModel must have context.device set. "
+            "Ensure LanguageModel is properly initialized with a device."
+        )
+    return torch.device(lm.context.device)
+
+
+def sync_model_to_context_device(lm: LanguageModel) -> None:
+    """
+    Ensure model is on the device specified by context.device.
+
+    Moves the model if there's a mismatch between current location
+    and context.device. This is the primary device synchronization
+    function that should be called before any model operations.
+
+    Args:
+        lm: LanguageModel instance with context.device set
+
+    Raises:
+        ValueError: If context.device is not set
+        RuntimeError: If model cannot be moved to target device
+    """
+    from mi_crow.language_model.utils import get_device_from_model
+
+    target_device = ensure_context_device(lm)
+    model_device = get_device_from_model(lm.context.model)
+
+    if model_device != target_device:
+        try:
+            lm.context.model = lm.context.model.to(target_device)
+            logger.debug(
+                "Moved model from %s to %s to match context.device",
+                model_device,
+                target_device,
+            )
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to move model from {model_device} to {target_device}: {e}"
+            ) from e
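A quick sketch of the normalization rules (safe to run on a CPU-only machine):

    import torch
    from mi_crow.language_model.device_manager import normalize_device

    assert normalize_device(None) == "cpu"
    assert normalize_device(torch.device("cpu")) == "cpu"
    # On a CUDA host, a bare "cuda" is pinned to an explicit index:
    #   normalize_device("cuda") == "cuda:0"
    # Requesting "cuda" or "mps" without that backend raises ValueError.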