mi-crow 0.1.2__py3-none-any.whl → 1.0.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mi_crow/hooks/utils.py CHANGED
@@ -74,3 +74,77 @@ def extract_tensor_from_output(output: HOOK_FUNCTION_OUTPUT) -> torch.Tensor | None:
 
     return None
 
+
+def apply_modification_to_output(
+    output: HOOK_FUNCTION_OUTPUT,
+    modified_tensor: torch.Tensor,
+    target_device: torch.device | None = None
+) -> None:
+    """
+    Apply a modified tensor to an output object in-place.
+
+    Handles various output formats:
+    - Plain tensors: modifies the tensor directly (in-place)
+    - Tuples/lists of tensors: replaces first tensor
+    - Objects with last_hidden_state attribute: sets last_hidden_state
+
+    If target_device is provided, output tensors are moved to target_device first,
+    ensuring consistency with the desired device (e.g., context.device).
+    Otherwise, modified_tensor is moved to match output's current device.
+
+    Args:
+        output: Output object to modify
+        modified_tensor: Modified tensor to apply
+        target_device: Optional target device. If provided, output tensors are moved
+            to this device before applying modification. If None, uses output's current device.
+    """
+    if output is None:
+        return
+
+    if isinstance(output, torch.Tensor):
+        if target_device is not None:
+            if output.device != target_device:
+                output = output.to(target_device)
+            if modified_tensor.device != target_device:
+                modified_tensor = modified_tensor.to(target_device)
+        else:
+            if modified_tensor.device != output.device:
+                modified_tensor = modified_tensor.to(output.device)
+        output.data.copy_(modified_tensor.data)
+        return
+
+    if isinstance(output, (tuple, list)):
+        for i, item in enumerate(output):
+            if isinstance(item, torch.Tensor):
+                if target_device is not None:
+                    if item.device != target_device:
+                        item = item.to(target_device)
+                        if isinstance(output, list):
+                            output[i] = item
+                    if modified_tensor.device != target_device or modified_tensor.dtype != item.dtype:
+                        modified_tensor = modified_tensor.to(device=target_device, dtype=item.dtype)
+                else:
+                    if modified_tensor.device != item.device or modified_tensor.dtype != item.dtype:
+                        modified_tensor = modified_tensor.to(device=item.device, dtype=item.dtype)
+                if isinstance(output, tuple):
+                    item.data.copy_(modified_tensor.data)
+                else:
+                    output[i] = modified_tensor
+                break
+        return
+
+    if hasattr(output, "last_hidden_state"):
+        original_tensor = output.last_hidden_state
+        if isinstance(original_tensor, torch.Tensor):
+            if target_device is not None:
+                if original_tensor.device != target_device:
+                    output.last_hidden_state = original_tensor.to(target_device)
+                    original_tensor = output.last_hidden_state
+                if modified_tensor.device != target_device:
+                    modified_tensor = modified_tensor.to(target_device)
+            else:
+                if modified_tensor.device != original_tensor.device:
+                    modified_tensor = modified_tensor.to(original_tensor.device)
+            output.last_hidden_state = modified_tensor
+        return
+
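The new helper can be exercised on its own. A minimal usage sketch (not part of the package), assuming a transformers-style output object; DummyOutput is hypothetical, while the import path follows from the file above:

import torch
from mi_crow.hooks.utils import apply_modification_to_output

class DummyOutput:
    # Hypothetical stand-in for an object carrying a last_hidden_state attribute
    def __init__(self, last_hidden_state: torch.Tensor) -> None:
        self.last_hidden_state = last_hidden_state

out = DummyOutput(torch.zeros(1, 4, 8))
steered = torch.ones(1, 4, 8)

# Takes the last_hidden_state branch: the attribute is replaced with `steered`,
# moved to target_device first if it lives elsewhere.
apply_modification_to_output(out, steered, target_device=torch.device("cpu"))
assert torch.equal(out.last_hidden_state, steered)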
@@ -1,4 +1,5 @@
 import datetime
+import gc
 from typing import TYPE_CHECKING, Any, Dict, Sequence
 
 import torch
@@ -10,7 +11,6 @@ from mi_crow.hooks.implementations.layer_activation_detector import LayerActivationDetector
 from mi_crow.hooks.implementations.model_input_detector import ModelInputDetector
 from mi_crow.store.store import Store
 from mi_crow.utils import get_logger
-from mi_crow.language_model.utils import get_device_from_model
 
 if TYPE_CHECKING:
     from mi_crow.language_model.context import LanguageModelContext
@@ -30,18 +30,25 @@ class LanguageModelActivations:
         """
         self.context = context
 
-    def _setup_detector(self, layer_signature: str | int, hook_id_suffix: str) -> tuple[LayerActivationDetector, str]:
+    def _setup_detector(
+        self, layer_signature: str | int, hook_id_suffix: str, dtype: torch.dtype | None = None
+    ) -> tuple[LayerActivationDetector, str]:
         """
         Create and register an activation detector.
 
         Args:
             layer_signature: Layer to attach detector to
             hook_id_suffix: Suffix for hook ID
+            dtype: Optional dtype for activations
 
         Returns:
             Tuple of (detector, hook_id)
         """
-        detector = LayerActivationDetector(layer_signature=layer_signature, hook_id=f"detector_{hook_id_suffix}")
+        detector = LayerActivationDetector(
+            layer_signature=layer_signature,
+            hook_id=f"detector_{hook_id_suffix}",
+            target_dtype=dtype,
+        )
 
         hook_id = self.context.language_model.layers.register_hook(layer_signature, detector, HookType.FORWARD)
 
@@ -71,24 +78,115 @@ class LanguageModelActivations:
         """
         attention_mask_layer_sig = "attention_masks"
         root_model = self.context.model
-
-        # Add layer signature to registry for root model
+
         if attention_mask_layer_sig not in self.context.language_model.layers.name_to_layer:
             self.context.language_model.layers.name_to_layer[attention_mask_layer_sig] = root_model
-
+
         detector = ModelInputDetector(
             layer_signature=attention_mask_layer_sig,
             hook_id=f"attention_mask_detector_{run_name}",
             save_input_ids=False,
             save_attention_mask=True,
         )
-
+
         hook_id = self.context.language_model.layers.register_hook(
             attention_mask_layer_sig, detector, HookType.PRE_FORWARD
         )
-
+
         return detector, hook_id
 
+    def _setup_activation_hooks(
+        self,
+        layer_sig_list: list[str],
+        run_name: str,
+        save_attention_mask: bool,
+        dtype: torch.dtype | None = None,
+    ) -> tuple[list[str], str | None]:
+        """
+        Setup activation hooks for saving.
+
+        Args:
+            layer_sig_list: List of layer signatures to hook
+            run_name: Run name for hook IDs
+            save_attention_mask: Whether to setup attention mask detector
+            dtype: Optional dtype for activations
+
+        Returns:
+            Tuple of (hook_ids list, attention_mask_hook_id or None)
+        """
+        hook_ids: list[str] = []
+        for sig in layer_sig_list:
+            _, hook_id = self._setup_detector(sig, f"save_{run_name}_{sig}", dtype=dtype)
+            hook_ids.append(hook_id)
+
+        attention_mask_hook_id: str | None = None
+        if save_attention_mask:
+            _, attention_mask_hook_id = self._setup_attention_mask_detector(run_name)
+
+        return hook_ids, attention_mask_hook_id
+
+    def _teardown_activation_hooks(
+        self,
+        hook_ids: list[str],
+        attention_mask_hook_id: str | None,
+    ) -> None:
+        """
+        Teardown activation hooks.
+
+        Args:
+            hook_ids: List of hook IDs to cleanup
+            attention_mask_hook_id: Optional attention mask hook ID to cleanup
+        """
+        for hook_id in hook_ids:
+            self._cleanup_detector(hook_id)
+        if attention_mask_hook_id is not None:
+            self._cleanup_detector(attention_mask_hook_id)
+
+    def _validate_save_prerequisites(self) -> tuple[nn.Module, Store]:
+        """
+        Validate prerequisites for saving activations.
+
+        Returns:
+            Tuple of (model, store)
+
+        Raises:
+            ValueError: If model or store is not initialized
+        """
+        model: nn.Module | None = self.context.model
+        if model is None:
+            raise ValueError("Model must be initialized before running")
+
+        store = self.context.store
+        if store is None:
+            raise ValueError("Store must be provided or set on the language model")
+
+        return model, store
+
+    def _prepare_save_metadata(
+        self,
+        layer_signature: str | int | list[str | int],
+        dataset: BaseDataset | None,
+        run_name: str | None,
+        options: Dict[str, Any],
+    ) -> tuple[str, Dict[str, Any], list[str]]:
+        """
+        Prepare metadata for activation saving.
+
+        Args:
+            layer_signature: Layer signature(s) to save
+            dataset: Optional dataset
+            run_name: Optional run name
+            options: Options dictionary
+
+        Returns:
+            Tuple of (run_name, metadata, layer_sig_list)
+        """
+        _, layer_sig_list = self._normalize_layer_signatures(layer_signature)
+        run_name, meta = self._prepare_run_metadata(
+            layer_signature, dataset=dataset, run_name=run_name, options=options
+        )
+        return run_name, meta, layer_sig_list
+
     def _normalize_layer_signatures(
         self, layer_signatures: str | int | list[str | int] | None
     ) -> tuple[str | None, list[str]]:
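Taken together, the helpers extracted in this hunk give both public save paths a single hook lifecycle. A condensed sketch of the pattern (method names come from this diff; the surrounding body is illustrative):

hook_ids, attention_mask_hook_id = self._setup_activation_hooks(
    layer_sig_list, run_name, save_attention_mask, dtype=dtype
)
try:
    with torch.inference_mode():
        ...  # run batches; detectors capture activations as the model runs
finally:
    # Hooks come off even if a batch raises, so no detector leaks across runs.
    self._teardown_activation_hooks(hook_ids, attention_mask_hook_id)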
@@ -171,6 +269,7 @@ class LanguageModelActivations:
             verbose: Whether to log
         """
         from mi_crow.language_model.inference import InferenceEngine
+
         InferenceEngine._save_run_metadata(store, run_name, meta, verbose)
 
     def _process_batch(
@@ -184,6 +283,7 @@ class LanguageModelActivations:
         dtype: torch.dtype | None,
         verbose: bool,
         save_in_batches: bool = True,
+        stop_after_layer: str | int | None = None,
     ) -> None:
         """Process a single batch of texts.
 
@@ -196,6 +296,7 @@ class LanguageModelActivations:
             autocast_dtype: Optional dtype for autocast
             dtype: Optional dtype to convert activations to
             verbose: Whether to log progress
+            stop_after_layer: Optional layer signature to stop after (name or index)
         """
         if not texts:
             return
@@ -209,31 +310,48 @@ class LanguageModelActivations:
             tok_kwargs=tok_kwargs,
             autocast=autocast,
             autocast_dtype=autocast_dtype,
+            stop_after_layer=stop_after_layer,
         )
 
-        if dtype is not None:
-            self._convert_activations_to_dtype(dtype)
-
         self.context.language_model.save_detector_metadata(
             run_name,
             batch_index,
             unified=not save_in_batches,
         )
 
+        # Synchronize CUDA to ensure async CPU transfers from detector hooks complete
+        # Only synchronize if CUDA is actually available and initialized
+        try:
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+        except (AssertionError, RuntimeError):
+            # CUDA not available or not initialized (e.g., in test environment)
+            pass
+
+        gc.collect()
+        if torch.cuda.is_available():
+            try:
+                torch.cuda.empty_cache()
+            except (AssertionError, RuntimeError):
+                # CUDA not available or not initialized
+                pass
+
         if verbose:
             logger.info(f"Saved batch {batch_index} for run={run_name}")
 
     def _convert_activations_to_dtype(self, dtype: torch.dtype) -> None:
         """
-        Convert captured activations to specified dtype.
+        Convert all captured activations in detectors to the specified dtype.
 
         Args:
-            dtype: Target dtype
+            dtype: Target dtype to convert activations to
         """
         detectors = self.context.language_model.layers.get_detectors()
         for detector in detectors:
-            if "activations" in detector.tensor_metadata:
-                detector.tensor_metadata["activations"] = detector.tensor_metadata["activations"].to(dtype)
+            if hasattr(detector, "tensor_metadata") and "activations" in detector.tensor_metadata:
+                tensor = detector.tensor_metadata["activations"]
+                if tensor.dtype != dtype:
+                    detector.tensor_metadata["activations"] = tensor.to(dtype)
 
     def _manage_cuda_cache(
         self, batch_counter: int, free_cuda_cache_every: int | None, device_type: str, verbose: bool
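The guarded synchronize/empty_cache sequence added to _process_batch also reads as a standalone idiom. A self-contained sketch (the helper name is hypothetical; the guards mirror the hunk above):

import gc
import torch

def flush_cuda() -> None:
    # Hypothetical helper mirroring the guarded cleanup in _process_batch.
    try:
        if torch.cuda.is_available():
            # Wait for async device-to-host copies queued by forward hooks.
            torch.cuda.synchronize()
    except (AssertionError, RuntimeError):
        pass  # CUDA reported available but not usable (e.g., stubbed in tests)
    gc.collect()  # drop Python references before releasing cached blocks
    if torch.cuda.is_available():
        try:
            torch.cuda.empty_cache()
        except (AssertionError, RuntimeError):
            pass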
@@ -264,10 +382,11 @@ class LanguageModelActivations:
         max_length: int | None = None,
         autocast: bool = True,
         autocast_dtype: torch.dtype | None = None,
-        free_cuda_cache_every: int | None = 0,
+        free_cuda_cache_every: int | None = None,
         verbose: bool = False,
         save_in_batches: bool = True,
         save_attention_mask: bool = False,
+        stop_after_last_layer: bool = True,
     ) -> str:
         """
         Save activations from a dataset.
@@ -281,9 +400,11 @@ class LanguageModelActivations:
             max_length: Optional max length for tokenization
             autocast: Whether to use autocast
             autocast_dtype: Optional dtype for autocast
-            free_cuda_cache_every: Clear CUDA cache every N batches (0 or None to disable)
+            free_cuda_cache_every: Clear CUDA cache every N batches (None to auto-detect, 0 to disable)
             verbose: Whether to log progress
             save_attention_mask: Whether to also save attention masks (automatically attaches ModelInputDetector)
+            stop_after_last_layer: Whether to stop model forward pass after the last requested layer
+                to save memory and time. Defaults to True.
 
         Returns:
             Run name used for saving
@@ -291,29 +412,22 @@ class LanguageModelActivations:
         Raises:
             ValueError: If model or store is not initialized
         """
-        model: nn.Module | None = self.context.model
-        if model is None:
-            raise ValueError("Model must be initialized before running")
+        model, store = self._validate_save_prerequisites()
 
-        _, layer_sig_list = self._normalize_layer_signatures(layer_signature)
-
-        store = self.context.store
-        if store is None:
-            raise ValueError("Store must be provided or set on the language model")
-
-
-        device = get_device_from_model(model)
+        device = torch.device(self.context.device)
         device_type = str(device.type)
 
+        if free_cuda_cache_every is None:
+            free_cuda_cache_every = 5 if device_type == "cuda" else 0
+
         options = {
             "dtype": str(dtype) if dtype is not None else None,
             "max_length": max_length,
             "batch_size": int(batch_size),
+            "stop_after_last_layer": stop_after_last_layer,
         }
 
-        run_name, meta = self._prepare_run_metadata(
-            layer_signature, dataset=dataset, run_name=run_name, options=options
-        )
+        run_name, meta, layer_sig_list = self._prepare_save_metadata(layer_signature, dataset, run_name, options)
 
         if verbose:
             logger.info(
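The new None default for free_cuda_cache_every now resolves per device type. The rule, isolated as a pure function (hypothetical name; the constants are from the hunk above):

def resolve_cache_interval(free_cuda_cache_every: int | None, device_type: str) -> int:
    if free_cuda_cache_every is None:
        # Auto-detect: clear every 5 batches on CUDA, never elsewhere.
        return 5 if device_type == "cuda" else 0
    return free_cuda_cache_every  # explicit value wins; 0 disables clearing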
@@ -323,17 +437,13 @@ class LanguageModelActivations:
 
         self._save_run_metadata(store, run_name, meta, verbose)
 
-        hook_ids: list[str] = []
-        for sig in layer_sig_list:
-            _, hook_id = self._setup_detector(sig, f"save_{run_name}_{sig}")
-            hook_ids.append(hook_id)
-
-        # Setup attention mask detector if requested
-        attention_mask_hook_id: str | None = None
-        if save_attention_mask:
-            _, attention_mask_hook_id = self._setup_attention_mask_detector(run_name)
+        hook_ids, attention_mask_hook_id = self._setup_activation_hooks(
+            layer_sig_list, run_name, save_attention_mask, dtype=dtype
+        )
 
         batch_counter = 0
+        # Stop after last hooked layer if requested
+        stop_after = layer_sig_list[-1] if (layer_sig_list and stop_after_last_layer) else None
 
         try:
             with torch.inference_mode():
@@ -349,19 +459,18 @@ class LanguageModelActivations:
                         dtype,
                         verbose,
                         save_in_batches=save_in_batches,
+                        stop_after_layer=stop_after,
                     )
                     batch_counter += 1
+
                     self._manage_cuda_cache(batch_counter, free_cuda_cache_every, device_type, verbose)
         finally:
-            for hook_id in hook_ids:
-                self._cleanup_detector(hook_id)
-            if attention_mask_hook_id is not None:
-                self._cleanup_detector(attention_mask_hook_id)
+            self._teardown_activation_hooks(hook_ids, attention_mask_hook_id)
             if verbose:
                 logger.info(f"Completed save_activations_dataset: run={run_name}, batches_saved={batch_counter}")
-
+
         return run_name
-
+
     def save_activations(
         self,
         texts: Sequence[str],
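A hedged usage sketch of the dataset entry point: it assumes an initialized LanguageModel exposing this class as lm.activations and a BaseDataset instance ds (that attribute name and the call convention are assumptions; the keyword names themselves appear in this diff):

run_name = lm.activations.save_activations_dataset(
    layer_signature=["model.layers.5", "model.layers.11"],  # example signatures
    dataset=ds,
    batch_size=8,
    dtype=torch.float16,           # store activations in half precision
    save_attention_mask=True,      # also attaches a ModelInputDetector
    stop_after_last_layer=True,    # skip forward compute past the last hooked layer
    verbose=True,
)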
@@ -377,6 +486,7 @@ class LanguageModelActivations:
         verbose: bool = False,
         save_in_batches: bool = True,
         save_attention_mask: bool = False,
+        stop_after_last_layer: bool = True,
     ) -> str:
         """
         Save activations from a list of texts.
@@ -393,6 +503,8 @@ class LanguageModelActivations:
             free_cuda_cache_every: Clear CUDA cache every N batches (0 or None to disable)
             verbose: Whether to log progress
             save_attention_mask: Whether to also save attention masks (automatically attaches ModelInputDetector)
+            stop_after_last_layer: Whether to stop model forward pass after the last requested layer
+                to save memory and time. Defaults to True.
 
         Returns:
             Run name used for saving
@@ -400,20 +512,12 @@ class LanguageModelActivations:
         Raises:
             ValueError: If model or store is not initialized
         """
-        model: nn.Module | None = self.context.model
-        if model is None:
-            raise ValueError("Model must be initialized before running")
-
-        _, layer_sig_list = self._normalize_layer_signatures(layer_signature)
-
-        store = self.context.store
-        if store is None:
-            raise ValueError("Store must be provided or set on the language model")
-
         if not texts:
             raise ValueError("Texts list cannot be empty")
 
-        device = get_device_from_model(model)
+        model, store = self._validate_save_prerequisites()
+
+        device = torch.device(self.context.device)
         device_type = str(device.type)
 
         if batch_size is None:
@@ -423,11 +527,10 @@ class LanguageModelActivations:
             "dtype": str(dtype) if dtype is not None else None,
             "max_length": max_length,
             "batch_size": int(batch_size),
+            "stop_after_last_layer": stop_after_last_layer,
         }
 
-        run_name, meta = self._prepare_run_metadata(
-            layer_signature, dataset=None, run_name=run_name, options=options
-        )
+        run_name, meta, layer_sig_list = self._prepare_save_metadata(layer_signature, None, run_name, options)
 
         if verbose:
             logger.info(
@@ -437,24 +540,20 @@ class LanguageModelActivations:
         self._save_run_metadata(store, run_name, meta, verbose)
 
-        hook_ids: list[str] = []
-        for sig in layer_sig_list:
-            _, hook_id = self._setup_detector(sig, f"save_{run_name}_{sig}")
-            hook_ids.append(hook_id)
-
-        # Setup attention mask detector if requested
-        attention_mask_hook_id: str | None = None
-        if save_attention_mask:
-            _, attention_mask_hook_id = self._setup_attention_mask_detector(run_name)
+        hook_ids, attention_mask_hook_id = self._setup_activation_hooks(
+            layer_sig_list, run_name, save_attention_mask, dtype=dtype
+        )
 
         batch_counter = 0
+        # Stop after last hooked layer if requested
+        stop_after = layer_sig_list[-1] if (layer_sig_list and stop_after_last_layer) else None
 
         try:
             with torch.inference_mode():
                 for i in range(0, len(texts), batch_size):
-                    batch_texts = texts[i:i + batch_size]
+                    batch_texts = texts[i : i + batch_size]
                     batch_index = i // batch_size
-
+
                     self._process_batch(
                         batch_texts,
                         run_name,
@@ -465,15 +564,13 @@ class LanguageModelActivations:
                         dtype,
                         verbose,
                         save_in_batches=save_in_batches,
+                        stop_after_layer=stop_after,
                     )
                     batch_counter += 1
                     self._manage_cuda_cache(batch_counter, free_cuda_cache_every, device_type, verbose)
         finally:
-            for hook_id in hook_ids:
-                self._cleanup_detector(hook_id)
-            if attention_mask_hook_id is not None:
-                self._cleanup_detector(attention_mask_hook_id)
+            self._teardown_activation_hooks(hook_ids, attention_mask_hook_id)
             if verbose:
                 logger.info(f"Completed save_activations: run={run_name}, batches_saved={batch_counter}")
-
-        return run_name
+
+        return run_name
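And the text-list variant, under the same assumptions about lm.activations (texts as the first positional argument matches the signature above):

texts = ["The quick brown fox", "jumps over the lazy dog"]
run_name = lm.activations.save_activations(
    texts,
    layer_signature="model.layers.5",
    batch_size=2,                # texts are sliced as texts[i : i + batch_size]
    stop_after_last_layer=True,  # recorded in run metadata under options
)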
@@ -0,0 +1,119 @@
+"""Centralized device management utilities for LanguageModel operations.
+
+This module provides shared device handling logic to ensure consistent
+device management across the codebase.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+import torch
+
+if TYPE_CHECKING:
+    from mi_crow.language_model.language_model import LanguageModel
+
+logger = logging.getLogger(__name__)
+
+
+def normalize_device(device: str | torch.device | None) -> str:
+    """
+    Normalize and validate device specification.
+
+    Ensures the device is available and normalizes generic device strings.
+    - None → "cpu"
+    - "cuda" → "cuda:0" (if available)
+    - Validates CUDA/MPS availability
+
+    Args:
+        device: Device specification as string, torch.device, or None
+
+    Returns:
+        Normalized device string such as "cpu", "cuda:0", or "mps"
+
+    Raises:
+        ValueError: If requested device is not available
+    """
+    if device is None:
+        return "cpu"
+
+    if isinstance(device, torch.device):
+        device_str = str(device)
+    else:
+        device_str = str(device)
+
+    if device_str.startswith("cuda"):
+        if not torch.cuda.is_available():
+            raise ValueError(
+                "Requested device 'cuda' but CUDA is not available. "
+                "Install a CUDA-enabled PyTorch build or use device='cpu'."
+            )
+        if device_str == "cuda":
+            device_str = "cuda:0"
+
+    if device_str == "mps":
+        mps_backend = getattr(torch.backends, "mps", None)
+        mps_available = bool(mps_backend and mps_backend.is_available())
+        if not mps_available:
+            raise ValueError(
+                "Requested device 'mps' but MPS is not available. "
+                "Ensure PyTorch is built with MPS support or use device='cpu'."
+            )
+
+    return device_str
+
+
+def ensure_context_device(lm: LanguageModel) -> torch.device:
+    """
+    Ensure LanguageModel has valid context.device and return it.
+
+    Args:
+        lm: LanguageModel instance
+
+    Returns:
+        torch.device from context
+
+    Raises:
+        ValueError: If context.device is not properly set
+    """
+    if not hasattr(lm, "context") or not hasattr(lm.context, "device") or lm.context.device is None:
+        raise ValueError(
+            "LanguageModel must have context.device set. "
+            "Ensure LanguageModel is properly initialized with a device."
+        )
+    return torch.device(lm.context.device)
+
+
+def sync_model_to_context_device(lm: LanguageModel) -> None:
+    """
+    Ensure model is on the device specified by context.device.
+
+    Moves the model if there's a mismatch between current location
+    and context.device. This is the primary device synchronization
+    function that should be called before any model operations.
+
+    Args:
+        lm: LanguageModel instance with context.device set
+
+    Raises:
+        ValueError: If context.device is not set
+        RuntimeError: If model cannot be moved to target device
+    """
+    from mi_crow.language_model.utils import get_device_from_model
+
+    target_device = ensure_context_device(lm)
+    model_device = get_device_from_model(lm.context.model)
+
+    if model_device != target_device:
+        try:
+            lm.context.model = lm.context.model.to(target_device)
+            logger.debug(
+                "Moved model from %s to %s to match context.device",
+                model_device,
+                target_device,
+            )
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to move model from {model_device} to {target_device}: {e}"
+            ) from e
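The behavior of normalize_device in a few hedged cases (the diff does not name the new file's path, so the import is left implicit and the function is assumed to be in scope):

import torch

assert normalize_device(None) == "cpu"
assert normalize_device(torch.device("cpu")) == "cpu"
# On a CUDA-enabled build: normalize_device("cuda") == "cuda:0"
# Without CUDA:            normalize_device("cuda") raises ValueError
# Without MPS support:     normalize_device("mps") raises ValueError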