mi-crow 0.1.2__py3-none-any.whl → 1.0.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,7 +8,8 @@ from typing import Sequence, Any, Dict, List, TYPE_CHECKING
 import torch
 from torch import nn
 
-from mi_crow.language_model.utils import get_device_from_model, move_tensors_to_device, extract_logits_from_output
+from mi_crow.language_model.utils import move_tensors_to_device, extract_logits_from_output
+from mi_crow.language_model.device_manager import sync_model_to_context_device
 from mi_crow.utils import get_logger
 
 if TYPE_CHECKING:
@@ -52,12 +53,19 @@ class InferenceEngine:
         """
         if tok_kwargs is None:
             tok_kwargs = {}
-        return {
-            "padding": True,
+
+        padding_strategy = tok_kwargs.pop("padding", True)
+        if padding_strategy is True and "max_length" in tok_kwargs:
+            padding_strategy = "longest"
+
+        result = {
+            "padding": padding_strategy,
             "truncation": True,
             "return_tensors": "pt",
             **tok_kwargs,
         }
+
+        return result
 
     def _setup_trackers(self, texts: Sequence[str]) -> None:
         """
@@ -181,10 +189,15 @@ class InferenceEngine:
             raise ValueError("Tokenizer must be initialized before running inference")
 
         tok_kwargs = self._prepare_tokenizer_kwargs(tok_kwargs)
+        logger.debug(f"[DEBUG] About to tokenize {len(texts)} texts...")
         enc = self.lm.tokenize(texts, **tok_kwargs)
+        logger.debug(f"[DEBUG] Tokenization completed, shape: {enc['input_ids'].shape if isinstance(enc, dict) else 'N/A'}")
 
-        device = get_device_from_model(self.lm.model)
+        device = torch.device(self.lm.context.device)
         device_type = str(device.type)
+
+        sync_model_to_context_device(self.lm)
+
         enc = move_tensors_to_device(enc, device)
 
         self.lm.model.eval()
@@ -469,7 +482,7 @@ class InferenceEngine:
         if store is None:
             raise ValueError("Store must be provided or set on the language model")
 
-        device = get_device_from_model(model)
+        device = torch.device(self.lm.context.device)
         device_type = str(device.type)
 
         options = {
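Both call sites now trust the device recorded on the context rather than probing model parameters, then re-synchronize the model to it. The `device_manager` module itself is not included in this diff; a plausible sketch of `sync_model_to_context_device` under that assumption:

```python
import torch

def sync_model_to_context_device(lm) -> None:
    # Hypothetical implementation -- mi_crow.language_model.device_manager is
    # not shown in this diff. Move the model onto the context's device if any
    # parameter has drifted away from it.
    target = torch.device(lm.context.device)
    first_param = next(lm.model.parameters(), None)
    if first_param is not None and first_param.device != target:
        lm.model.to(target)
```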
@@ -39,6 +39,7 @@ def create_from_huggingface(
     store: Store,
     tokenizer_params: dict | None = None,
     model_params: dict | None = None,
+    device: str | torch.device | None = None,
 ) -> "LanguageModel":
     """
     Load a language model from HuggingFace Hub.
@@ -49,10 +50,11 @@ def create_from_huggingface(
         store: Store instance for persistence
         tokenizer_params: Optional tokenizer parameters
         model_params: Optional model parameters
-
+        device: Target device ("cuda", "cpu", "mps"). Model will be moved to this device
+            after loading.
     Returns:
         LanguageModel instance
-
+
     Raises:
         ValueError: If model_name is invalid
         RuntimeError: If model loading fails
@@ -67,7 +69,7 @@ def create_from_huggingface(
         tokenizer_params = {}
     if model_params is None:
         model_params = {}
-
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_params)
         model = AutoModelForCausalLM.from_pretrained(model_name, **model_params)
@@ -76,14 +78,15 @@ def create_from_huggingface(
             f"Failed to load model '{model_name}' from HuggingFace. Error: {e}"
         ) from e
 
-    return cls(model, tokenizer, store)
+    return cls(model, tokenizer, store, device=device)
 
 
 def create_from_local_torch(
     cls: type["LanguageModel"],
     model_path: str,
     tokenizer_path: str,
-    store: Store
+    store: Store,
+    device: str | torch.device | None = None,
 ) -> "LanguageModel":
     """
     Load a language model from local HuggingFace paths.
@@ -93,6 +96,7 @@ def create_from_local_torch(
         model_path: Path to the model directory or file
         tokenizer_path: Path to the tokenizer directory or file
         store: Store instance for persistence
+        device: Optional device string or torch.device (defaults to 'cpu' if None)
 
     Returns:
         LanguageModel instance
@@ -122,5 +126,5 @@ def create_from_local_torch(
             f"model_path={model_path!r}, tokenizer_path={tokenizer_path!r}. Error: {e}"
         ) from e
 
-    return cls(model, tokenizer, store)
+    return cls(model, tokenizer, store, device=device)
 
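Both factory helpers now thread `device` through to the `LanguageModel` constructor. A small helper sketch for choosing a value for that argument (`pick_device` is not part of mi-crow; the MPS check assumes a torch build that exposes `torch.backends.mps`):

```python
import torch

def pick_device(preferred: str | None = None) -> str:
    # Helper sketch (not part of mi-crow): choose a sensible default for the
    # new `device` argument on the factory helpers.
    if preferred is not None:
        return preferred
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

# e.g. lm = create_from_huggingface(LanguageModel, "gpt2", store, device=pick_device())
```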
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import gc
 from collections import defaultdict
 from pathlib import Path
 from typing import Sequence, Any, Dict, List, TYPE_CHECKING, Set, Tuple
@@ -15,6 +16,7 @@ from mi_crow.language_model.context import LanguageModelContext
 from mi_crow.language_model.inference import InferenceEngine
 from mi_crow.language_model.persistence import save_model, load_model_from_saved_file
 from mi_crow.language_model.initialization import initialize_model_id, create_from_huggingface, create_from_local_torch
+from mi_crow.language_model.device_manager import normalize_device, sync_model_to_context_device
 from mi_crow.store.store import Store
 from mi_crow.utils import get_logger
 
@@ -72,7 +74,7 @@ class LanguageModel:
 
     Provides a unified interface for working with language models, including:
     - Model initialization and configuration
-    - Inference operations
+    - Inference operations through the inference property
     - Hook management (detectors and controllers)
     - Model persistence
     - Activation tracking
@@ -84,6 +86,7 @@ class LanguageModel:
         tokenizer: PreTrainedTokenizerBase,
         store: Store,
         model_id: str | None = None,
+        device: str | torch.device | None = None,
     ):
         """
         Initialize LanguageModel.
@@ -93,6 +96,7 @@ class LanguageModel:
             tokenizer: HuggingFace tokenizer
             store: Store instance for persistence
             model_id: Optional model identifier (auto-extracted if not provided)
+            device: Optional device string or torch.device (defaults to 'cpu' if None)
         """
         self.context = LanguageModelContext(self)
         self.context.model = model
@@ -100,15 +104,17 @@ class LanguageModel:
         self.context.model_id = initialize_model_id(model, model_id)
         self.context.store = store
         self.context.special_token_ids = _extract_special_token_ids(tokenizer)
+        self.context.device = normalize_device(device)
+        sync_model_to_context_device(self)
 
         self.layers = LanguageModelLayers(self.context)
         self.lm_tokenizer = LanguageModelTokenizer(self.context)
         self.activations = LanguageModelActivations(self.context)
         self.inference = InferenceEngine(self)
-        self._inference_engine = self.inference
 
         self._input_tracker: "InputTracker | None" = None
 
+
     @property
     def model(self) -> nn.Module:
         """Get the underlying PyTorch model."""
@@ -147,86 +153,6 @@ class LanguageModel:
         """
         return self.lm_tokenizer.tokenize(texts, **kwargs)
 
-    def forwards(
-        self,
-        texts: Sequence[str],
-        tok_kwargs: Dict | None = None,
-        autocast: bool = True,
-        autocast_dtype: torch.dtype | None = None,
-        with_controllers: bool = True,
-    ) -> Tuple[Any, Any]:
-        """
-        Run forward pass on texts.
-
-        Args:
-            texts: Input texts to process
-            tok_kwargs: Optional tokenizer keyword arguments
-            autocast: Whether to use automatic mixed precision
-            autocast_dtype: Optional dtype for autocast
-            with_controllers: Whether to use controllers during inference
-
-        Returns:
-            Tuple of (model_output, encodings)
-        """
-        return self._inference_engine.execute_inference(
-            texts,
-            tok_kwargs=tok_kwargs,
-            autocast=autocast,
-            autocast_dtype=autocast_dtype,
-            with_controllers=with_controllers
-        )
-
-    def generate(
-        self,
-        texts: Sequence[str],
-        tok_kwargs: Dict | None = None,
-        autocast: bool = True,
-        autocast_dtype: torch.dtype | None = None,
-        with_controllers: bool = True,
-        skip_special_tokens: bool = True,
-    ) -> Sequence[str]:
-        """
-        Run inference and automatically decode the output with the tokenizer.
-
-        Args:
-            texts: Input texts to process
-            tok_kwargs: Optional tokenizer keyword arguments
-            autocast: Whether to use automatic mixed precision
-            autocast_dtype: Optional dtype for autocast
-            with_controllers: Whether to use controllers during inference
-            skip_special_tokens: Whether to skip special tokens when decoding
-
-        Returns:
-            Sequence of decoded text strings
-
-        Raises:
-            ValueError: If texts is empty or tokenizer is None
-        """
-        if not texts:
-            raise ValueError("Texts list cannot be empty")
-
-        if self.tokenizer is None:
-            raise ValueError("Tokenizer is required for decoding but is None")
-
-        output, enc = self._inference_engine.execute_inference(
-            texts,
-            tok_kwargs=tok_kwargs,
-            autocast=autocast,
-            autocast_dtype=autocast_dtype,
-            with_controllers=with_controllers
-        )
-
-        logits = self._inference_engine.extract_logits(output)
-        predicted_token_ids = logits.argmax(dim=-1)
-
-        decoded_texts = []
-        for i in range(predicted_token_ids.shape[0]):
-            token_ids = predicted_token_ids[i].cpu().tolist()
-            decoded_text = self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
-            decoded_texts.append(decoded_text)
-
-        return decoded_texts
-
     def get_input_tracker(self) -> "InputTracker | None":
         """
         Get the input tracker instance if it exists.
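`forwards` and `generate` disappear from `LanguageModel` in 1.0.0, along with the `_inference_engine` alias. A hedged migration sketch (assumes `lm` is an existing `LanguageModel`; `execute_inference` and `extract_logits` are the `InferenceEngine` methods referenced elsewhere in this diff):

```python
texts = ["hello world"]
# Before (0.1.2):
#   output, enc = lm.forwards(texts)
# After (1.0.0), via the inference property:
output, enc = lm.inference.execute_inference(texts)
logits = lm.inference.extract_logits(output)
```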
@@ -248,8 +174,8 @@ class LanguageModel:
         detectors_tensor_metadata: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict)
 
         for detector in detectors:
-            detectors_metadata[detector.layer_signature] = detector.metadata
-            detectors_tensor_metadata[detector.layer_signature] = detector.tensor_metadata
+            detectors_metadata[detector.layer_signature] = dict(detector.metadata)
+            detectors_tensor_metadata[detector.layer_signature] = dict(detector.tensor_metadata)
 
         return detectors_metadata, detectors_tensor_metadata
 
@@ -263,17 +189,14 @@ class LanguageModel:
         """
         detectors = self.layers.get_detectors()
         for detector in detectors:
-            # Clear generic accumulated metadata
             detector.metadata.clear()
             detector.tensor_metadata.clear()
 
-            # Allow detector implementations to provide more specialized
-            # clearing logic (e.g. ModelInputDetector, ModelOutputDetector).
             clear_captured = getattr(detector, "clear_captured", None)
             if callable(clear_captured):
                 clear_captured()
 
-    def save_detector_metadata(self, run_name: str, batch_idx: int | None, unified: bool = False) -> str:
+    def save_detector_metadata(self, run_name: str, batch_idx: int | None, unified: bool = False, clear_after_save: bool = True) -> str:
         """
         Save detector metadata to store.
 
@@ -282,6 +205,8 @@ class LanguageModel:
             batch_idx: Batch index. Ignored when ``unified`` is True.
             unified: If True, save metadata in a single detectors directory
                 for the whole run instead of per‑batch directories.
+            clear_after_save: If True, clear detector metadata after saving to free memory.
+                Defaults to True to prevent OOM errors when processing large batches.
 
         Returns:
             Path where metadata was saved
@@ -291,12 +216,36 @@ class LanguageModel:
         """
         if self.store is None:
             raise ValueError("Store must be provided or set on the language model")
+
         detectors_metadata, detectors_tensor_metadata = self.get_all_detector_metadata()
+
         if unified:
-            return self.store.put_run_detector_metadata(run_name, detectors_metadata, detectors_tensor_metadata)
-        if batch_idx is None:
-            raise ValueError("batch_idx must be provided when unified is False")
-        return self.store.put_detector_metadata(run_name, batch_idx, detectors_metadata, detectors_tensor_metadata)
+            result = self.store.put_run_detector_metadata(run_name, detectors_metadata, detectors_tensor_metadata)
+        else:
+            if batch_idx is None:
+                raise ValueError("batch_idx must be provided when unified is False")
+            result = self.store.put_detector_metadata(run_name, batch_idx, detectors_metadata, detectors_tensor_metadata)
+
+        if clear_after_save:
+            for layer_signature in list(detectors_tensor_metadata.keys()):
+                detector_tensors = detectors_tensor_metadata[layer_signature]
+                for tensor_key in list(detector_tensors.keys()):
+                    del detector_tensors[tensor_key]
+                del detectors_tensor_metadata[layer_signature]
+            detectors_metadata.clear()
+
+            detectors = self.layers.get_detectors()
+            for detector in detectors:
+                clear_captured = getattr(detector, "clear_captured", None)
+                if callable(clear_captured):
+                    clear_captured()
+                for key in list(detector.tensor_metadata.keys()):
+                    del detector.tensor_metadata[key]
+                detector.metadata.clear()
+
+            gc.collect()
+
+        return result
 
     def _ensure_input_tracker(self) -> "InputTracker":
         """
@@ -339,23 +288,36 @@ class LanguageModel:
         store: Store,
         tokenizer_params: dict = None,
         model_params: dict = None,
+        device: str | torch.device | None = None,
     ) -> "LanguageModel":
         """
         Load a language model from HuggingFace Hub.
 
+        Automatically loads model to GPU if device is "cuda" and CUDA is available.
+        This prevents OOM errors by keeping the model on GPU instead of CPU RAM.
+
         Args:
             model_name: HuggingFace model identifier
             store: Store instance for persistence
             tokenizer_params: Optional tokenizer parameters
             model_params: Optional model parameters
+            device: Target device ("cuda", "cpu", "mps"). If "cuda" and CUDA is available,
+                model will be loaded directly to GPU using device_map="auto"
+                (via the HuggingFace factory helpers).
 
         Returns:
             LanguageModel instance
         """
-        return create_from_huggingface(cls, model_name, store, tokenizer_params, model_params)
+        return create_from_huggingface(cls, model_name, store, tokenizer_params, model_params, device)
 
     @classmethod
-    def from_local_torch(cls, model_path: str, tokenizer_path: str, store: Store) -> "LanguageModel":
+    def from_local_torch(
+        cls,
+        model_path: str,
+        tokenizer_path: str,
+        store: Store,
+        device: str | torch.device | None = None,
+    ) -> "LanguageModel":
         """
         Load a language model from local HuggingFace paths.
 
@@ -363,14 +325,21 @@ class LanguageModel:
             model_path: Path to the model directory or file
             tokenizer_path: Path to the tokenizer directory or file
             store: Store instance for persistence
+            device: Optional device string or torch.device (defaults to 'cpu' if None)
 
         Returns:
             LanguageModel instance
         """
-        return create_from_local_torch(cls, model_path, tokenizer_path, store)
+        return create_from_local_torch(cls, model_path, tokenizer_path, store, device)
 
     @classmethod
-    def from_local(cls, saved_path: Path | str, store: Store, model_id: str | None = None) -> "LanguageModel":
+    def from_local(
+        cls,
+        saved_path: Path | str,
+        store: Store,
+        model_id: str | None = None,
+        device: str | torch.device | None = None,
+    ) -> "LanguageModel":
         """
         Load a language model from a saved file (created by save_model).
 
@@ -379,6 +348,7 @@ class LanguageModel:
             store: Store instance for persistence
             model_id: Optional model identifier. If not provided, will use the model_id from saved metadata.
                 If provided, will be used to load the model architecture from HuggingFace.
+            device: Optional device string or torch.device (defaults to 'cpu' if None)
 
         Returns:
             LanguageModel instance
@@ -387,4 +357,4 @@ class LanguageModel:
             FileNotFoundError: If the saved file doesn't exist
             ValueError: If the saved file format is invalid or model_id is required but not provided
         """
-        return load_model_from_saved_file(cls, saved_path, store, model_id)
+        return load_model_from_saved_file(cls, saved_path, store, model_id, device)
@@ -320,19 +320,22 @@ class LanguageModelLayers:
 
         layer_signature, hook_type, hook = self.context._hook_id_map[hook_id]
 
-        # Find and remove from registry
-        if layer_signature in self.context._hook_registry:
-            if hook_type in self.context._hook_registry[layer_signature]:
-                hooks_list = self.context._hook_registry[layer_signature][hook_type]
-                for i, (h, handle) in enumerate(hooks_list):
-                    if h.id == hook_id:
-                        # Remove PyTorch hook
-                        handle.remove()
-                        # Remove from our list
-                        hooks_list.pop(i)
-                        break
-
-        # Remove from ID map
+        if layer_signature not in self.context._hook_registry:
+            del self.context._hook_id_map[hook_id]
+            return True
+
+        hook_types = self.context._hook_registry[layer_signature]
+        if hook_type not in hook_types:
+            del self.context._hook_id_map[hook_id]
+            return True
+
+        hooks_list = hook_types[hook_type]
+        for i, (h, handle) in enumerate(hooks_list):
+            if h.id == hook_id:
+                handle.remove()
+                hooks_list.pop(i)
+                break
+
         del self.context._hook_id_map[hook_id]
         return True
 
@@ -89,7 +89,8 @@ def load_model_from_saved_file(
     cls: type["LanguageModel"],
     saved_path: Path | str,
     store: "Store",
-    model_id: str | None = None
+    model_id: str | None = None,
+    device: str | torch.device | None = None,
 ) -> "LanguageModel":
     """
     Load a language model from a saved file (created by save_model).
@@ -100,6 +101,7 @@ def load_model_from_saved_file(
         store: Store instance for persistence
         model_id: Optional model identifier. If not provided, will use the model_id from saved metadata.
             If provided, will be used to load the model architecture from HuggingFace.
+        device: Optional device string or torch.device (defaults to 'cpu' if None)
 
     Returns:
         LanguageModel instance
@@ -164,7 +166,7 @@ def load_model_from_saved_file(
         ) from e
 
     # Create LanguageModel instance
-    lm = cls(model, tokenizer, store, model_id=model_id)
+    lm = cls(model, tokenizer, store, model_id=model_id, device=device)
 
     # Note: Hooks are not automatically restored as they require hook instances
     # The hook metadata is available in metadata_dict["hooks"] if needed
@@ -44,8 +44,11 @@ def get_device_from_model(model: nn.Module) -> torch.device:
     Returns:
         Device where model parameters are located, or CPU if no parameters
     """
-    first_param = next(model.parameters(), None)
-    return first_param.device if first_param is not None else torch.device("cpu")
+    try:
+        first_param = next(model.parameters(), None)
+        return first_param.device if first_param is not None else torch.device("cpu")
+    except (TypeError, AttributeError):
+        return torch.device("cpu")
  def move_tensors_to_device(
@@ -62,9 +65,6 @@ def move_tensors_to_device(
62
65
  Returns:
63
66
  Dictionary with tensors moved to device
64
67
  """
65
- device_type = str(device.type)
66
- if device_type == "cuda":
67
- return {k: v.to(device, non_blocking=True) for k, v in tensors.items()}
68
68
  return {k: v.to(device) for k, v in tensors.items()}
69
69
 
70
70
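`get_device_from_model` now tolerates objects whose `parameters` attribute is missing or not callable, and the CUDA-only `non_blocking` fast path in `move_tensors_to_device` is gone in favor of a single uniform `.to(device)` call. A quick check of the hardened helper, reproduced from the 1.0.0 code above:

```python
import torch
from torch import nn

def get_device_from_model(model) -> torch.device:
    # Reproduced from the 1.0.0 version above: fall back to CPU when the
    # object has no usable .parameters() iterator.
    try:
        first_param = next(model.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")
    except (TypeError, AttributeError):
        return torch.device("cpu")

print(get_device_from_model(nn.Linear(4, 2)))  # cpu (until the model is moved)
print(get_device_from_model(object()))         # cpu (no .parameters attribute)
```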