mi-crow 0.1.2__py3-none-any.whl → 1.0.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mi_crow/datasets/base_dataset.py +71 -1
- mi_crow/datasets/classification_dataset.py +136 -30
- mi_crow/datasets/text_dataset.py +165 -24
- mi_crow/hooks/controller.py +12 -7
- mi_crow/hooks/implementations/layer_activation_detector.py +30 -34
- mi_crow/hooks/implementations/model_input_detector.py +87 -87
- mi_crow/hooks/implementations/model_output_detector.py +43 -42
- mi_crow/hooks/utils.py +74 -0
- mi_crow/language_model/activations.py +174 -77
- mi_crow/language_model/device_manager.py +119 -0
- mi_crow/language_model/inference.py +18 -5
- mi_crow/language_model/initialization.py +10 -6
- mi_crow/language_model/language_model.py +67 -97
- mi_crow/language_model/layers.py +16 -13
- mi_crow/language_model/persistence.py +4 -2
- mi_crow/language_model/utils.py +5 -5
- mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py +157 -95
- mi_crow/mechanistic/sae/concepts/concept_dictionary.py +12 -2
- mi_crow/mechanistic/sae/concepts/text_heap.py +161 -0
- mi_crow/mechanistic/sae/modules/topk_sae.py +29 -22
- mi_crow/mechanistic/sae/sae.py +3 -1
- mi_crow/mechanistic/sae/sae_trainer.py +362 -29
- mi_crow/store/local_store.py +11 -5
- mi_crow/store/store.py +34 -1
- {mi_crow-0.1.2.dist-info → mi_crow-1.0.0.post1.dist-info}/METADATA +2 -1
- {mi_crow-0.1.2.dist-info → mi_crow-1.0.0.post1.dist-info}/RECORD +28 -26
- {mi_crow-0.1.2.dist-info → mi_crow-1.0.0.post1.dist-info}/WHEEL +1 -1
- {mi_crow-0.1.2.dist-info → mi_crow-1.0.0.post1.dist-info}/top_level.txt +0 -0
mi_crow/language_model/inference.py
CHANGED

```diff
@@ -8,7 +8,8 @@ from typing import Sequence, Any, Dict, List, TYPE_CHECKING
 import torch
 from torch import nn
 
-from mi_crow.language_model.utils import
+from mi_crow.language_model.utils import move_tensors_to_device, extract_logits_from_output
+from mi_crow.language_model.device_manager import sync_model_to_context_device
 from mi_crow.utils import get_logger
 
 if TYPE_CHECKING:
```
```diff
@@ -52,12 +53,19 @@ class InferenceEngine:
         """
         if tok_kwargs is None:
             tok_kwargs = {}
-
-
+
+        padding_strategy = tok_kwargs.pop("padding", True)
+        if padding_strategy is True and "max_length" in tok_kwargs:
+            padding_strategy = "longest"
+
+        result = {
+            "padding": padding_strategy,
             "truncation": True,
             "return_tensors": "pt",
             **tok_kwargs,
         }
+
+        return result
 
     def _setup_trackers(self, texts: Sequence[str]) -> None:
         """
```
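Note on the hunk above: in HuggingFace tokenizers, `padding=True` is shorthand for padding to the longest sequence in the batch, so popping the caller's value and rewriting `True` to the explicit `"longest"` removes any ambiguity when `max_length` is also supplied. A standalone sketch of the same normalization (the function name is ours, not part of the package):

```python
def normalize_padding(tok_kwargs: dict) -> dict:
    # Mirrors _prepare_tokenizer_kwargs above: pop the caller's padding choice
    # and make the boolean form explicit when max_length is also present.
    padding = tok_kwargs.pop("padding", True)
    if padding is True and "max_length" in tok_kwargs:
        padding = "longest"
    return {"padding": padding, "truncation": True, "return_tensors": "pt", **tok_kwargs}

print(normalize_padding({"max_length": 128}))
# {'padding': 'longest', 'truncation': True, 'return_tensors': 'pt', 'max_length': 128}
```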
```diff
@@ -181,10 +189,15 @@ class InferenceEngine:
             raise ValueError("Tokenizer must be initialized before running inference")
 
         tok_kwargs = self._prepare_tokenizer_kwargs(tok_kwargs)
+        logger.debug(f"[DEBUG] About to tokenize {len(texts)} texts...")
         enc = self.lm.tokenize(texts, **tok_kwargs)
+        logger.debug(f"[DEBUG] Tokenization completed, shape: {enc['input_ids'].shape if isinstance(enc, dict) else 'N/A'}")
 
-        device =
+        device = torch.device(self.lm.context.device)
         device_type = str(device.type)
+
+        sync_model_to_context_device(self.lm)
+
         enc = move_tensors_to_device(enc, device)
 
         self.lm.model.eval()
```
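The ordering in this hunk is the point of the change: the target device is read from a single source of truth (`self.lm.context.device`), the model is synced to it, and only then are the encodings moved, so the forward pass never sees a model/input device mismatch. A runnable toy illustration of that order (the `nn.Linear` and tensors are stand-ins, not package code):

```python
import torch
from torch import nn

model = nn.Linear(4, 4)                          # stand-in for self.lm.model
enc = {"input_ids": torch.randn(2, 4)}           # stand-in for tokenizer output

device = torch.device("cpu")                     # torch.device(self.lm.context.device)
model.to(device)                                 # sync_model_to_context_device(self.lm)
enc = {k: v.to(device) for k, v in enc.items()}  # move_tensors_to_device(enc, device)
out = model(enc["input_ids"])                    # everything co-located at forward time
```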
```diff
@@ -469,7 +482,7 @@ class InferenceEngine:
         if store is None:
             raise ValueError("Store must be provided or set on the language model")
 
-        device =
+        device = torch.device(self.lm.context.device)
         device_type = str(device.type)
 
         options = {
```
mi_crow/language_model/initialization.py
CHANGED

```diff
@@ -39,6 +39,7 @@ def create_from_huggingface(
     store: Store,
     tokenizer_params: dict | None = None,
     model_params: dict | None = None,
+    device: str | torch.device | None = None,
 ) -> "LanguageModel":
     """
     Load a language model from HuggingFace Hub.
```
```diff
@@ -49,10 +50,11 @@ def create_from_huggingface(
         store: Store instance for persistence
         tokenizer_params: Optional tokenizer parameters
         model_params: Optional model parameters
-
+        device: Target device ("cuda", "cpu", "mps"). Model will be moved to this device
+            after loading.
     Returns:
         LanguageModel instance
-
+
     Raises:
         ValueError: If model_name is invalid
         RuntimeError: If model loading fails
```
```diff
@@ -67,7 +69,7 @@ def create_from_huggingface(
         tokenizer_params = {}
     if model_params is None:
         model_params = {}
-
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_params)
         model = AutoModelForCausalLM.from_pretrained(model_name, **model_params)
```
```diff
@@ -76,14 +78,15 @@ def create_from_huggingface(
             f"Failed to load model '{model_name}' from HuggingFace. Error: {e}"
         ) from e
 
-    return cls(model, tokenizer, store)
+    return cls(model, tokenizer, store, device=device)
 
 
 def create_from_local_torch(
     cls: type["LanguageModel"],
     model_path: str,
     tokenizer_path: str,
-    store: Store
+    store: Store,
+    device: str | torch.device | None = None,
 ) -> "LanguageModel":
     """
     Load a language model from local HuggingFace paths.
```
```diff
@@ -93,6 +96,7 @@ def create_from_local_torch(
         model_path: Path to the model directory or file
         tokenizer_path: Path to the tokenizer directory or file
         store: Store instance for persistence
+        device: Optional device string or torch.device (defaults to 'cpu' if None)
 
     Returns:
         LanguageModel instance
```
```diff
@@ -122,5 +126,5 @@ def create_from_local_torch(
             f"model_path={model_path!r}, tokenizer_path={tokenizer_path!r}. Error: {e}"
         ) from e
 
-    return cls(model, tokenizer, store)
+    return cls(model, tokenizer, store, device=device)
 
```
mi_crow/language_model/language_model.py
CHANGED

```diff
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import gc
 from collections import defaultdict
 from pathlib import Path
 from typing import Sequence, Any, Dict, List, TYPE_CHECKING, Set, Tuple
```
```diff
@@ -15,6 +16,7 @@ from mi_crow.language_model.context import LanguageModelContext
 from mi_crow.language_model.inference import InferenceEngine
 from mi_crow.language_model.persistence import save_model, load_model_from_saved_file
 from mi_crow.language_model.initialization import initialize_model_id, create_from_huggingface, create_from_local_torch
+from mi_crow.language_model.device_manager import normalize_device, sync_model_to_context_device
 from mi_crow.store.store import Store
 from mi_crow.utils import get_logger
 
```
```diff
@@ -72,7 +74,7 @@ class LanguageModel:
 
     Provides a unified interface for working with language models, including:
     - Model initialization and configuration
-    - Inference operations
+    - Inference operations through the inference property
     - Hook management (detectors and controllers)
     - Model persistence
     - Activation tracking
```
```diff
@@ -84,6 +86,7 @@ class LanguageModel:
         tokenizer: PreTrainedTokenizerBase,
         store: Store,
         model_id: str | None = None,
+        device: str | torch.device | None = None,
     ):
         """
         Initialize LanguageModel.
```
```diff
@@ -93,6 +96,7 @@ class LanguageModel:
             tokenizer: HuggingFace tokenizer
             store: Store instance for persistence
             model_id: Optional model identifier (auto-extracted if not provided)
+            device: Optional device string or torch.device (defaults to 'cpu' if None)
         """
         self.context = LanguageModelContext(self)
         self.context.model = model
```
```diff
@@ -100,15 +104,17 @@ class LanguageModel:
         self.context.model_id = initialize_model_id(model, model_id)
         self.context.store = store
         self.context.special_token_ids = _extract_special_token_ids(tokenizer)
+        self.context.device = normalize_device(device)
+        sync_model_to_context_device(self)
 
         self.layers = LanguageModelLayers(self.context)
         self.lm_tokenizer = LanguageModelTokenizer(self.context)
         self.activations = LanguageModelActivations(self.context)
         self.inference = InferenceEngine(self)
-        self._inference_engine = self.inference
 
         self._input_tracker: "InputTracker | None" = None
 
+
     @property
     def model(self) -> nn.Module:
         """Get the underlying PyTorch model."""
```
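`device_manager.py` is new in this release (+119 lines) but its body does not appear in this diff. From the call sites above and the "defaults to 'cpu' if None" docstrings, a minimal sketch of the two imported helpers could look like the following; this is an assumption, not the actual implementation:

```python
import torch

def normalize_device(device) -> torch.device:
    # Assumed: fall back to CPU when no device is given, matching the
    # "defaults to 'cpu' if None" docstrings elsewhere in this diff.
    return torch.device(device) if device is not None else torch.device("cpu")

def sync_model_to_context_device(lm) -> None:
    # Assumed: move the wrapped model to whatever lm.context.device says,
    # skipping the copy when the parameters are already there.
    target = torch.device(lm.context.device)
    first_param = next(lm.model.parameters(), None)
    if first_param is None or first_param.device != target:
        lm.model.to(target)
```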
```diff
@@ -147,86 +153,6 @@ class LanguageModel:
         """
         return self.lm_tokenizer.tokenize(texts, **kwargs)
 
-    def forwards(
-        self,
-        texts: Sequence[str],
-        tok_kwargs: Dict | None = None,
-        autocast: bool = True,
-        autocast_dtype: torch.dtype | None = None,
-        with_controllers: bool = True,
-    ) -> Tuple[Any, Any]:
-        """
-        Run forward pass on texts.
-
-        Args:
-            texts: Input texts to process
-            tok_kwargs: Optional tokenizer keyword arguments
-            autocast: Whether to use automatic mixed precision
-            autocast_dtype: Optional dtype for autocast
-            with_controllers: Whether to use controllers during inference
-
-        Returns:
-            Tuple of (model_output, encodings)
-        """
-        return self._inference_engine.execute_inference(
-            texts,
-            tok_kwargs=tok_kwargs,
-            autocast=autocast,
-            autocast_dtype=autocast_dtype,
-            with_controllers=with_controllers
-        )
-
-    def generate(
-        self,
-        texts: Sequence[str],
-        tok_kwargs: Dict | None = None,
-        autocast: bool = True,
-        autocast_dtype: torch.dtype | None = None,
-        with_controllers: bool = True,
-        skip_special_tokens: bool = True,
-    ) -> Sequence[str]:
-        """
-        Run inference and automatically decode the output with the tokenizer.
-
-        Args:
-            texts: Input texts to process
-            tok_kwargs: Optional tokenizer keyword arguments
-            autocast: Whether to use automatic mixed precision
-            autocast_dtype: Optional dtype for autocast
-            with_controllers: Whether to use controllers during inference
-            skip_special_tokens: Whether to skip special tokens when decoding
-
-        Returns:
-            Sequence of decoded text strings
-
-        Raises:
-            ValueError: If texts is empty or tokenizer is None
-        """
-        if not texts:
-            raise ValueError("Texts list cannot be empty")
-
-        if self.tokenizer is None:
-            raise ValueError("Tokenizer is required for decoding but is None")
-
-        output, enc = self._inference_engine.execute_inference(
-            texts,
-            tok_kwargs=tok_kwargs,
-            autocast=autocast,
-            autocast_dtype=autocast_dtype,
-            with_controllers=with_controllers
-        )
-
-        logits = self._inference_engine.extract_logits(output)
-        predicted_token_ids = logits.argmax(dim=-1)
-
-        decoded_texts = []
-        for i in range(predicted_token_ids.shape[0]):
-            token_ids = predicted_token_ids[i].cpu().tolist()
-            decoded_text = self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
-            decoded_texts.append(decoded_text)
-
-        return decoded_texts
-
     def get_input_tracker(self) -> "InputTracker | None":
         """
         Get the input tracker instance if it exists.
```
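`forwards` and `generate` were thin wrappers around the inference engine, which remains exposed as the `inference` attribute (the redundant `_inference_engine` alias was also dropped earlier in this file). A plausible migration for 0.1.2 callers (the input texts are placeholders):

```python
# 0.1.2:
# output, enc = lm.forwards(["Hello world"])

# 1.0.0.post1: call the engine the wrapper used to delegate to
output, enc = lm.inference.execute_inference(
    ["Hello world"],
    autocast=True,
    with_controllers=True,
)
```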
```diff
@@ -248,8 +174,8 @@ class LanguageModel:
         detectors_tensor_metadata: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict)
 
         for detector in detectors:
-            detectors_metadata[detector.layer_signature] = detector.metadata
-            detectors_tensor_metadata[detector.layer_signature] = detector.tensor_metadata
+            detectors_metadata[detector.layer_signature] = dict(detector.metadata)
+            detectors_tensor_metadata[detector.layer_signature] = dict(detector.tensor_metadata)
 
         return detectors_metadata, detectors_tensor_metadata
 
```
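The `dict(...)` copies decouple the returned snapshots from each detector's live metadata dicts, which matters now that `save_detector_metadata` (below) deletes entries from both sides to release tensor references. In miniature:

```python
live = {"activations": [0.1, 0.2]}              # a detector's own metadata dict
snapshot = dict(live)                           # what the hunk above now returns
live.clear()                                    # what the clearing pass later does
assert snapshot == {"activations": [0.1, 0.2]}  # the snapshot is unaffected
```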
```diff
@@ -263,17 +189,14 @@ class LanguageModel:
         """
         detectors = self.layers.get_detectors()
         for detector in detectors:
-            # Clear generic accumulated metadata
             detector.metadata.clear()
             detector.tensor_metadata.clear()
 
-            # Allow detector implementations to provide more specialized
-            # clearing logic (e.g. ModelInputDetector, ModelOutputDetector).
             clear_captured = getattr(detector, "clear_captured", None)
             if callable(clear_captured):
                 clear_captured()
 
-    def save_detector_metadata(self, run_name: str, batch_idx: int | None, unified: bool = False) -> str:
+    def save_detector_metadata(self, run_name: str, batch_idx: int | None, unified: bool = False, clear_after_save: bool = True) -> str:
         """
         Save detector metadata to store.
 
```
```diff
@@ -282,6 +205,8 @@ class LanguageModel:
             batch_idx: Batch index. Ignored when ``unified`` is True.
             unified: If True, save metadata in a single detectors directory
                 for the whole run instead of per-batch directories.
+            clear_after_save: If True, clear detector metadata after saving to free memory.
+                Defaults to True to prevent OOM errors when processing large batches.
 
         Returns:
             Path where metadata was saved
```
```diff
@@ -291,12 +216,36 @@ class LanguageModel:
         """
         if self.store is None:
             raise ValueError("Store must be provided or set on the language model")
+
         detectors_metadata, detectors_tensor_metadata = self.get_all_detector_metadata()
+
         if unified:
-
-
-
-
+            result = self.store.put_run_detector_metadata(run_name, detectors_metadata, detectors_tensor_metadata)
+        else:
+            if batch_idx is None:
+                raise ValueError("batch_idx must be provided when unified is False")
+            result = self.store.put_detector_metadata(run_name, batch_idx, detectors_metadata, detectors_tensor_metadata)
+
+        if clear_after_save:
+            for layer_signature in list(detectors_tensor_metadata.keys()):
+                detector_tensors = detectors_tensor_metadata[layer_signature]
+                for tensor_key in list(detector_tensors.keys()):
+                    del detector_tensors[tensor_key]
+                del detectors_tensor_metadata[layer_signature]
+            detectors_metadata.clear()
+
+            detectors = self.layers.get_detectors()
+            for detector in detectors:
+                clear_captured = getattr(detector, "clear_captured", None)
+                if callable(clear_captured):
+                    clear_captured()
+                for key in list(detector.tensor_metadata.keys()):
+                    del detector.tensor_metadata[key]
+                detector.metadata.clear()
+
+            gc.collect()
+
+        return result
 
     def _ensure_input_tracker(self) -> "InputTracker":
         """
```
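A plausible per-batch loop that leans on the new default (`lm`, the store wiring, the run name, and the batch source are placeholders):

```python
for batch_idx, batch_texts in enumerate(batches):
    lm.inference.execute_inference(batch_texts)
    # clear_after_save defaults to True: after persisting, detector tensors
    # are deleted and gc.collect() runs, keeping memory flat across batches.
    lm.save_detector_metadata(run_name="example-run", batch_idx=batch_idx)
```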
```diff
@@ -339,23 +288,36 @@ class LanguageModel:
         store: Store,
         tokenizer_params: dict = None,
         model_params: dict = None,
+        device: str | torch.device | None = None,
     ) -> "LanguageModel":
         """
         Load a language model from HuggingFace Hub.
 
+        Automatically loads model to GPU if device is "cuda" and CUDA is available.
+        This prevents OOM errors by keeping the model on GPU instead of CPU RAM.
+
         Args:
             model_name: HuggingFace model identifier
             store: Store instance for persistence
             tokenizer_params: Optional tokenizer parameters
             model_params: Optional model parameters
+            device: Target device ("cuda", "cpu", "mps"). If "cuda" and CUDA is available,
+                model will be loaded directly to GPU using device_map="auto"
+                (via the HuggingFace factory helpers).
 
         Returns:
             LanguageModel instance
         """
-        return create_from_huggingface(cls, model_name, store, tokenizer_params, model_params)
+        return create_from_huggingface(cls, model_name, store, tokenizer_params, model_params, device)
 
     @classmethod
-    def from_local_torch(
+    def from_local_torch(
+        cls,
+        model_path: str,
+        tokenizer_path: str,
+        store: Store,
+        device: str | torch.device | None = None,
+    ) -> "LanguageModel":
         """
         Load a language model from local HuggingFace paths.
 
```
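A hypothetical call to `from_huggingface` above with the new parameter (the model id and `store` are placeholders):

```python
lm = LanguageModel.from_huggingface(
    "gpt2",         # placeholder model identifier
    store=store,    # any mi_crow Store instance
    device="cuda",  # load and keep the model on GPU
)
```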
```diff
@@ -363,14 +325,21 @@ class LanguageModel:
             model_path: Path to the model directory or file
             tokenizer_path: Path to the tokenizer directory or file
             store: Store instance for persistence
+            device: Optional device string or torch.device (defaults to 'cpu' if None)
 
         Returns:
             LanguageModel instance
         """
-        return create_from_local_torch(cls, model_path, tokenizer_path, store)
+        return create_from_local_torch(cls, model_path, tokenizer_path, store, device)
 
     @classmethod
-    def from_local(
+    def from_local(
+        cls,
+        saved_path: Path | str,
+        store: Store,
+        model_id: str | None = None,
+        device: str | torch.device | None = None,
+    ) -> "LanguageModel":
         """
         Load a language model from a saved file (created by save_model).
 
```
```diff
@@ -379,6 +348,7 @@ class LanguageModel:
             store: Store instance for persistence
             model_id: Optional model identifier. If not provided, will use the model_id from saved metadata.
                 If provided, will be used to load the model architecture from HuggingFace.
+            device: Optional device string or torch.device (defaults to 'cpu' if None)
 
         Returns:
             LanguageModel instance
```
```diff
@@ -387,4 +357,4 @@ class LanguageModel:
             FileNotFoundError: If the saved file doesn't exist
             ValueError: If the saved file format is invalid or model_id is required but not provided
         """
-        return load_model_from_saved_file(cls, saved_path, store, model_id)
+        return load_model_from_saved_file(cls, saved_path, store, model_id, device)
```
mi_crow/language_model/layers.py
CHANGED
```diff
@@ -320,19 +320,22 @@ class LanguageModelLayers:
 
         layer_signature, hook_type, hook = self.context._hook_id_map[hook_id]
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if layer_signature not in self.context._hook_registry:
+            del self.context._hook_id_map[hook_id]
+            return True
+
+        hook_types = self.context._hook_registry[layer_signature]
+        if hook_type not in hook_types:
+            del self.context._hook_id_map[hook_id]
+            return True
+
+        hooks_list = hook_types[hook_type]
+        for i, (h, handle) in enumerate(hooks_list):
+            if h.id == hook_id:
+                handle.remove()
+                hooks_list.pop(i)
+                break
+
         del self.context._hook_id_map[hook_id]
         return True
 
```
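The rewritten removal path replaces the 13 elided lines with explicit, defensive lookups. The structures it walks, as inferred from the code above (the shapes are not stated elsewhere in this diff):

```python
# context._hook_id_map:   hook_id -> (layer_signature, hook_type, hook)
# context._hook_registry: layer_signature -> {hook_type: [(hook, handle), ...]}
#
# A stale id (its layer or hook type is no longer registered) is treated as
# already removed: the id is dropped from _hook_id_map and True is returned.
# Otherwise the matching torch handle is removed and popped from the list.
```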
mi_crow/language_model/persistence.py
CHANGED

```diff
@@ -89,7 +89,8 @@ def load_model_from_saved_file(
     cls: type["LanguageModel"],
     saved_path: Path | str,
     store: "Store",
-    model_id: str | None = None
+    model_id: str | None = None,
+    device: str | torch.device | None = None,
 ) -> "LanguageModel":
     """
     Load a language model from a saved file (created by save_model).
```
```diff
@@ -100,6 +101,7 @@ def load_model_from_saved_file(
         store: Store instance for persistence
         model_id: Optional model identifier. If not provided, will use the model_id from saved metadata.
             If provided, will be used to load the model architecture from HuggingFace.
+        device: Optional device string or torch.device (defaults to 'cpu' if None)
 
     Returns:
         LanguageModel instance
```
```diff
@@ -164,7 +166,7 @@ def load_model_from_saved_file(
         ) from e
 
     # Create LanguageModel instance
-    lm = cls(model, tokenizer, store, model_id=model_id)
+    lm = cls(model, tokenizer, store, model_id=model_id, device=device)
 
     # Note: Hooks are not automatically restored as they require hook instances
     # The hook metadata is available in metadata_dict["hooks"] if needed
```
mi_crow/language_model/utils.py
CHANGED
```diff
@@ -44,8 +44,11 @@ def get_device_from_model(model: nn.Module) -> torch.device:
     Returns:
         Device where model parameters are located, or CPU if no parameters
     """
-
-
+    try:
+        first_param = next(model.parameters(), None)
+        return first_param.device if first_param is not None else torch.device("cpu")
+    except (TypeError, AttributeError):
+        return torch.device("cpu")
 
 
 def move_tensors_to_device(
```
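The new `get_device_from_model` body degrades gracefully for parameterless modules instead of assuming at least one parameter exists. For example:

```python
import torch
from torch import nn

from mi_crow.language_model.utils import get_device_from_model

assert get_device_from_model(nn.Linear(4, 4)) == torch.device("cpu")
assert get_device_from_model(nn.Identity()) == torch.device("cpu")  # no parameters
```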
```diff
@@ -62,9 +65,6 @@ def move_tensors_to_device(
     Returns:
         Dictionary with tensors moved to device
     """
-    device_type = str(device.type)
-    if device_type == "cuda":
-        return {k: v.to(device, non_blocking=True) for k, v in tensors.items()}
     return {k: v.to(device) for k, v in tensors.items()}
 
 
```
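The CUDA-only `non_blocking=True` fast path is gone; every device now takes the same synchronous copy. After the change the helper reduces to a single comprehension; a self-contained equivalent:

```python
import torch

def move_tensors_to_device(tensors: dict, device: torch.device) -> dict:
    # One code path for all device types (the non_blocking CUDA branch was removed).
    return {k: v.to(device) for k, v in tensors.items()}

batch = {"input_ids": torch.tensor([[1, 2, 3]])}
batch = move_tensors_to_device(batch, torch.device("cpu"))
```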