mi-crow 0.1.1.post12__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (51)
  1. amber/__init__.py +15 -0
  2. amber/datasets/__init__.py +11 -0
  3. amber/datasets/base_dataset.py +640 -0
  4. amber/datasets/classification_dataset.py +566 -0
  5. amber/datasets/loading_strategy.py +29 -0
  6. amber/datasets/text_dataset.py +488 -0
  7. amber/hooks/__init__.py +20 -0
  8. amber/hooks/controller.py +171 -0
  9. amber/hooks/detector.py +95 -0
  10. amber/hooks/hook.py +218 -0
  11. amber/hooks/implementations/__init__.py +0 -0
  12. amber/hooks/implementations/function_controller.py +93 -0
  13. amber/hooks/implementations/layer_activation_detector.py +96 -0
  14. amber/hooks/implementations/model_input_detector.py +250 -0
  15. amber/hooks/implementations/model_output_detector.py +132 -0
  16. amber/hooks/utils.py +76 -0
  17. amber/language_model/__init__.py +0 -0
  18. amber/language_model/activations.py +479 -0
  19. amber/language_model/context.py +33 -0
  20. amber/language_model/contracts.py +13 -0
  21. amber/language_model/hook_metadata.py +38 -0
  22. amber/language_model/inference.py +525 -0
  23. amber/language_model/initialization.py +126 -0
  24. amber/language_model/language_model.py +390 -0
  25. amber/language_model/layers.py +460 -0
  26. amber/language_model/persistence.py +177 -0
  27. amber/language_model/tokenizer.py +203 -0
  28. amber/language_model/utils.py +97 -0
  29. amber/mechanistic/__init__.py +0 -0
  30. amber/mechanistic/sae/__init__.py +0 -0
  31. amber/mechanistic/sae/autoencoder_context.py +40 -0
  32. amber/mechanistic/sae/concepts/__init__.py +0 -0
  33. amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
  34. amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
  35. amber/mechanistic/sae/concepts/concept_models.py +9 -0
  36. amber/mechanistic/sae/concepts/input_tracker.py +68 -0
  37. amber/mechanistic/sae/modules/__init__.py +5 -0
  38. amber/mechanistic/sae/modules/l1_sae.py +409 -0
  39. amber/mechanistic/sae/modules/topk_sae.py +459 -0
  40. amber/mechanistic/sae/sae.py +166 -0
  41. amber/mechanistic/sae/sae_trainer.py +604 -0
  42. amber/mechanistic/sae/training/wandb_logger.py +222 -0
  43. amber/store/__init__.py +5 -0
  44. amber/store/local_store.py +437 -0
  45. amber/store/store.py +276 -0
  46. amber/store/store_dataloader.py +124 -0
  47. amber/utils.py +46 -0
  48. mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
  49. mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
  50. mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
  51. mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
@@ -0,0 +1,525 @@
+ """Inference engine for language models."""
+
+ from __future__ import annotations
+
+ import datetime
+ from typing import Sequence, Any, Dict, List, TYPE_CHECKING
+
+ import torch
+ from torch import nn
+
+ from amber.language_model.utils import get_device_from_model, move_tensors_to_device, extract_logits_from_output
+ from amber.utils import get_logger
+
+ if TYPE_CHECKING:
+     from amber.language_model.language_model import LanguageModel
+     from amber.hooks.controller import Controller
+     from amber.datasets import BaseDataset
+     from amber.store.store import Store
+
+ logger = get_logger(__name__)
+
+
+ class _EarlyStopInference(Exception):
+     """Internal exception used to stop model forward pass after a specific layer."""
+
+     def __init__(self, output: Any):
+         super().__init__("Early stop after requested layer")
+         self.output = output
+
+
+ class InferenceEngine:
+     """Handles inference operations for LanguageModel."""
+
+     def __init__(self, language_model: "LanguageModel"):
+         """
+         Initialize inference engine.
+
+         Args:
+             language_model: LanguageModel instance
+         """
+         self.lm = language_model
+
+     def _prepare_tokenizer_kwargs(self, tok_kwargs: Dict | None) -> Dict[str, Any]:
+         """
+         Prepare tokenizer keyword arguments with defaults.
+
+         Args:
+             tok_kwargs: Optional tokenizer keyword arguments
+
+         Returns:
+             Dictionary of tokenizer kwargs with defaults applied
+         """
+         if tok_kwargs is None:
+             tok_kwargs = {}
+         return {
+             "padding": True,
+             "truncation": True,
+             "return_tensors": "pt",
+             **tok_kwargs,
+         }
+
+     def _setup_trackers(self, texts: Sequence[str]) -> None:
+         """
+         Setup input trackers for current texts.
+
+         Args:
+             texts: Sequence of input texts
+         """
+         if self.lm._input_tracker is not None and self.lm._input_tracker.enabled:
+             self.lm._input_tracker.set_current_texts(texts)
+
+     def _setup_model_input_detectors(self, enc: Dict[str, torch.Tensor]) -> None:
+         """
+         Automatically set inputs from encodings for all registered ModelInputDetector hooks.
+
+         This is necessary because PyTorch's pre_forward hook doesn't receive kwargs,
+         so ModelInputDetector hooks can't automatically capture attention masks when
+         models are called with **kwargs (e.g., model(**encodings)).
+
+         Args:
+             enc: Encoded inputs dictionary
+         """
+         from amber.hooks.implementations.model_input_detector import ModelInputDetector
+
+         detectors = self.lm.layers.get_detectors()
+         for detector in detectors:
+             if isinstance(detector, ModelInputDetector):
+                 detector.set_inputs_from_encodings(enc, module=self.lm.model)
+
+     def _prepare_controllers(self, with_controllers: bool) -> List["Controller"]:
+         """
+         Prepare controllers for inference, disabling if needed.
+
+         Args:
+             with_controllers: Whether to keep controllers enabled
+
+         Returns:
+             List of controllers that were disabled (to restore later)
+         """
+         controllers_to_restore = []
+         if not with_controllers:
+             controllers = self.lm.layers.get_controllers()
+             for controller in controllers:
+                 if controller.enabled:
+                     controller.disable()
+                     controllers_to_restore.append(controller)
+         return controllers_to_restore
+
+     def _restore_controllers(self, controllers_to_restore: List["Controller"]) -> None:
+         """
+         Restore controllers that were disabled.
+
+         Args:
+             controllers_to_restore: List of controllers to restore
+         """
+         for controller in controllers_to_restore:
+             controller.enable()
+
+     def _run_model_forward(
+         self,
+         enc: Dict[str, torch.Tensor],
+         autocast: bool,
+         device_type: str,
+         autocast_dtype: torch.dtype | None,
+     ) -> Any:
+         """
+         Run model forward pass with optional autocast.
+
+         Args:
+             enc: Encoded inputs dictionary
+             autocast: Whether to use automatic mixed precision
+             device_type: Device type string ("cuda", "cpu", etc.)
+             autocast_dtype: Optional dtype for autocast
+
+         Returns:
+             Model output
+         """
+         try:
+             with torch.inference_mode():
+                 if autocast and device_type == "cuda":
+                     amp_dtype = autocast_dtype or torch.float16
+                     with torch.autocast(device_type, dtype=amp_dtype):
+                         return self.lm.model(**enc)
+                 return self.lm.model(**enc)
+         except _EarlyStopInference as e:
+             # Early stopping hook raised this to short-circuit the remaining forward pass.
+             # We return the output captured at the requested layer.
+             return e.output
+
+     def execute_inference(
+         self,
+         texts: Sequence[str],
+         tok_kwargs: Dict | None = None,
+         autocast: bool = True,
+         autocast_dtype: torch.dtype | None = None,
+         with_controllers: bool = True,
+         stop_after_layer: str | int | None = None,
+     ) -> tuple[Any, Dict[str, torch.Tensor]]:
+         """
+         Execute inference on texts.
+
+         Args:
+             texts: Sequence of input texts
+             tok_kwargs: Optional tokenizer keyword arguments
+             autocast: Whether to use automatic mixed precision
+             autocast_dtype: Optional dtype for autocast
+             with_controllers: Whether to use controllers during inference
+             stop_after_layer: Optional layer signature (name or index) after which
+                 the forward pass should be stopped early
+
+         Returns:
+             Tuple of (model_output, encodings)
+
+         Raises:
+             ValueError: If texts is empty or tokenizer is not initialized
+         """
+         if not texts:
+             raise ValueError("Texts list cannot be empty")
+
+         if self.lm.tokenizer is None:
+             raise ValueError("Tokenizer must be initialized before running inference")
+
+         tok_kwargs = self._prepare_tokenizer_kwargs(tok_kwargs)
+         enc = self.lm.tokenize(texts, **tok_kwargs)
+
+         device = get_device_from_model(self.lm.model)
+         device_type = str(device.type)
+         enc = move_tensors_to_device(enc, device)
+
+         self.lm.model.eval()
+
+         self._setup_trackers(texts)
+         self._setup_model_input_detectors(enc)
+
+         controllers_to_restore = self._prepare_controllers(with_controllers)
+
+         hook_handle = None
+         try:
+             if stop_after_layer is not None:
+                 # Register a temporary forward hook that stops the forward pass
+                 def _early_stop_hook(module: nn.Module, inputs: tuple, output: Any):
+                     raise _EarlyStopInference(output)
+
+                 hook_handle = self.lm.layers.register_forward_hook_for_layer(
+                     stop_after_layer, _early_stop_hook
+                 )
+
+             output = self._run_model_forward(enc, autocast, device_type, autocast_dtype)
+             return output, enc
+         finally:
+             if hook_handle is not None:
+                 try:
+                     hook_handle.remove()
+                 except Exception:
+                     pass
+             self._restore_controllers(controllers_to_restore)
+
+     def extract_logits(self, output: Any) -> torch.Tensor:
+         """
+         Extract logits tensor from model output.
+
+         Args:
+             output: Model output
+
+         Returns:
+             Logits tensor
+         """
+         return extract_logits_from_output(output)
+
+     def _extract_dataset_info(self, dataset: "BaseDataset | None") -> Dict[str, Any]:
+         """
+         Extract dataset information for metadata.
+
+         Args:
+             dataset: Optional dataset instance
+
+         Returns:
+             Dictionary with dataset information
+         """
+         if dataset is None:
+             return {}
+
+         try:
+             ds_id = str(getattr(dataset, "dataset_dir", ""))
+             ds_len = int(len(dataset))
+             return {
+                 "dataset_dir": ds_id,
+                 "length": ds_len,
+             }
+         except (AttributeError, TypeError, ValueError, RuntimeError):
+             return {
+                 "dataset_dir": "",
+                 "length": -1,
+             }
+
+     def _prepare_run_metadata(
+         self,
+         dataset: "BaseDataset | None" = None,
+         run_name: str | None = None,
+         options: Dict[str, Any] | None = None,
+     ) -> tuple[str, Dict[str, Any]]:
+         """
+         Prepare run metadata dictionary.
+
+         Args:
+             dataset: Optional dataset (for dataset info)
+             run_name: Optional run name (generated if None)
+             options: Optional dict of options to include
+
+         Returns:
+             Tuple of (run_name, metadata_dict)
+         """
+         if run_name is None:
+             run_name = f"run_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
+
+         if options is None:
+             options = {}
+
+         dataset_info = self._extract_dataset_info(dataset)
+
+         meta: Dict[str, Any] = {
+             "run_name": run_name,
+             "model": getattr(self.lm.model, "model_name", self.lm.model.__class__.__name__),
+             "options": options.copy(),
+         }
+
+         if dataset_info:
+             meta["dataset"] = dataset_info
+
+         return run_name, meta
+
+     @staticmethod
+     def _save_run_metadata(
+         store: "Store",
+         run_name: str,
+         meta: Dict[str, Any],
+         verbose: bool = False,
+     ) -> None:
+         """
+         Save run metadata to store.
+
+         Args:
+             store: Store to save to
+             run_name: Run name
+             meta: Metadata dictionary
+             verbose: Whether to log
+         """
+         try:
+             store.put_run_metadata(run_name, meta)
+         except (OSError, IOError, ValueError, RuntimeError) as e:
+             if verbose:
+                 logger.warning(f"Failed to save run metadata for {run_name}: {e}")
+
+     def infer_texts(
+         self,
+         texts: Sequence[str],
+         run_name: str | None = None,
+         batch_size: int | None = None,
+         tok_kwargs: Dict | None = None,
+         autocast: bool = True,
+         autocast_dtype: torch.dtype | None = None,
+         with_controllers: bool = True,
+         clear_detectors_before: bool = False,
+         verbose: bool = False,
+         stop_after_layer: str | int | None = None,
+         save_in_batches: bool = True,
+     ) -> tuple[Any, Dict[str, torch.Tensor]] | tuple[List[Any], List[Dict[str, torch.Tensor]]]:
+         """
+         Run inference on list of strings with optional metadata saving.
+
+         Args:
+             texts: Sequence of input texts
+             run_name: Optional run name for saving metadata (if None, no metadata saved)
+             batch_size: Optional batch size for processing (if None, processes all at once)
+             tok_kwargs: Optional tokenizer keyword arguments
+             autocast: Whether to use automatic mixed precision
+             autocast_dtype: Optional dtype for autocast
+             with_controllers: Whether to use controllers during inference
+             clear_detectors_before: If True, clears all detector state before running
+             verbose: Whether to log progress
+             stop_after_layer: Optional layer signature (name or index) after which
+                 the forward pass should be stopped early
+             save_in_batches: If True, save detector metadata in per-batch
+                 directories. If False, aggregate all detector metadata for
+                 the run under a single detectors directory.
+
+         Returns:
+             If batch_size is None or >= len(texts): Tuple of (model_output, encodings)
+             If batch_size < len(texts): Tuple of (list of outputs, list of encodings)
+
+         Raises:
+             ValueError: If texts is empty or tokenizer is not initialized
+         """
+         if not texts:
+             raise ValueError("Texts list cannot be empty")
+
+         if self.lm.tokenizer is None:
+             raise ValueError("Tokenizer must be initialized before running inference")
+
+         if clear_detectors_before:
+             self.lm.clear_detectors()
+
+         store = self.lm.store
+         if run_name is not None and store is None:
+             raise ValueError("Store must be provided to save metadata")
+
+         if batch_size is None or batch_size >= len(texts):
+             output, enc = self.execute_inference(
+                 texts,
+                 tok_kwargs=tok_kwargs,
+                 autocast=autocast,
+                 autocast_dtype=autocast_dtype,
+                 with_controllers=with_controllers,
+                 stop_after_layer=stop_after_layer,
+             )
+
+             if run_name is not None:
+                 options = {
+                     "batch_size": len(texts),
+                     "max_length": tok_kwargs.get("max_length") if tok_kwargs else None,
+                 }
+                 _, meta = self._prepare_run_metadata(dataset=None, run_name=run_name, options=options)
+                 self._save_run_metadata(store, run_name, meta, verbose)
+                 self.lm.save_detector_metadata(run_name, 0, unified=not save_in_batches)
+
+             return output, enc
+
+         all_outputs = []
+         all_encodings = []
+         batch_counter = 0
+
+         if run_name is not None:
+             options = {
+                 "batch_size": batch_size,
+                 "max_length": tok_kwargs.get("max_length") if tok_kwargs else None,
+             }
+             _, meta = self._prepare_run_metadata(dataset=None, run_name=run_name, options=options)
+             self._save_run_metadata(store, run_name, meta, verbose)
+
+         for i in range(0, len(texts), batch_size):
+             batch_texts = texts[i:i + batch_size]
+             output, enc = self.execute_inference(
+                 batch_texts,
+                 tok_kwargs=tok_kwargs,
+                 autocast=autocast,
+                 autocast_dtype=autocast_dtype,
+                 with_controllers=with_controllers,
+                 stop_after_layer=stop_after_layer,
+             )
+
+             all_outputs.append(output)
+             all_encodings.append(enc)
+
+             if run_name is not None:
+                 self.lm.save_detector_metadata(run_name, batch_counter, unified=not save_in_batches)
+                 if verbose:
+                     logger.info(f"Saved batch {batch_counter} for run={run_name}")
+
+             batch_counter += 1
+
+         return all_outputs, all_encodings
+
+     def infer_dataset(
+         self,
+         dataset: "BaseDataset",
+         run_name: str | None = None,
+         batch_size: int = 32,
+         tok_kwargs: Dict | None = None,
+         autocast: bool = True,
+         autocast_dtype: torch.dtype | None = None,
+         with_controllers: bool = True,
+         free_cuda_cache_every: int | None = 0,
+         clear_detectors_before: bool = False,
+         verbose: bool = False,
+         stop_after_layer: str | int | None = None,
+         save_in_batches: bool = True,
+     ) -> str:
+         """
+         Run inference on whole dataset with metadata saving.
+
+         Args:
+             dataset: Dataset to process
+             run_name: Optional run name (generated if None)
+             batch_size: Batch size for processing
+             tok_kwargs: Optional tokenizer keyword arguments
+             autocast: Whether to use automatic mixed precision
+             autocast_dtype: Optional dtype for autocast
+             with_controllers: Whether to use controllers during inference
+             free_cuda_cache_every: Clear CUDA cache every N batches (0 or None to disable)
+             clear_detectors_before: If True, clears all detector state before running
+             verbose: Whether to log progress
+             stop_after_layer: Optional layer signature (name or index) after which
+                 the forward pass should be stopped early
+
+         Returns:
+             Run name used for saving
+
+         Raises:
+             ValueError: If model or store is not initialized
+         """
+         if clear_detectors_before:
+             self.lm.clear_detectors()
+
+         model: nn.Module | None = self.lm.model
+         if model is None:
+             raise ValueError("Model must be initialized before running")
+
+         store = self.lm.store
+         if store is None:
+             raise ValueError("Store must be provided or set on the language model")
+
+         device = get_device_from_model(model)
+         device_type = str(device.type)
+
+         options = {
+             "max_length": tok_kwargs.get("max_length") if tok_kwargs else None,
+             "batch_size": int(batch_size),
+         }
+
+         run_name, meta = self._prepare_run_metadata(dataset=dataset, run_name=run_name, options=options)
+
+         if verbose:
+             logger.info(
+                 f"Starting infer_dataset: run={run_name}, "
+                 f"batch_size={batch_size}, device={device_type}"
+             )
+
+         self._save_run_metadata(store, run_name, meta, verbose)
+
+         batch_counter = 0
+
+         with torch.inference_mode():
+             for batch_index, batch in enumerate(dataset.iter_batches(batch_size)):
+                 if not batch:
+                     continue
+
+                 texts = dataset.extract_texts_from_batch(batch)
+
+                 self.execute_inference(
+                     texts,
+                     tok_kwargs=tok_kwargs,
+                     autocast=autocast,
+                     autocast_dtype=autocast_dtype,
+                     with_controllers=with_controllers,
+                     stop_after_layer=stop_after_layer,
+                 )
+
+                 self.lm.save_detector_metadata(run_name, batch_index, unified=not save_in_batches)
+
+                 batch_counter += 1
+
+                 if device_type == "cuda" and free_cuda_cache_every and free_cuda_cache_every > 0:
+                     if (batch_counter % free_cuda_cache_every) == 0:
+                         torch.cuda.empty_cache()
+                         if verbose:
+                             logger.info("Emptied CUDA cache")
+
+                 if verbose:
+                     logger.info(f"Saved batch {batch_index} for run={run_name}")
+
+         if verbose:
+             logger.info(f"Completed infer_dataset: run={run_name}, batches_saved={batch_counter}")
+
+         return run_name
+
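The usage sketch below is editorial commentary, not part of the packaged diff; it illustrates how the InferenceEngine entry points above are meant to be combined. The LocalStore constructor signature and the assumption that LanguageModel exposes the store passed to the factory as lm.store are not confirmed by this file and are marked as assumptions.

# Hypothetical usage sketch; LocalStore("./runs") and the lm.store wiring are assumptions.
from amber.language_model.language_model import LanguageModel
from amber.language_model.inference import InferenceEngine
from amber.language_model.initialization import create_from_huggingface
from amber.store.local_store import LocalStore

store = LocalStore("./runs")  # assumed constructor signature
lm = create_from_huggingface(LanguageModel, "gpt2", store)
engine = InferenceEngine(lm)

# Single forward pass over a small batch; returns (model_output, encodings).
output, enc = engine.execute_inference(["Hello world"], autocast=False)
logits = engine.extract_logits(output)

# Batched run that also persists run metadata and detector state per batch.
outputs, encodings = engine.infer_texts(
    ["first text", "second text", "third text"],
    run_name="demo_run",
    batch_size=2,
    stop_after_layer=None,  # or a layer name/index to short-circuit the forward pass
    verbose=True,
)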
@@ -0,0 +1,126 @@
+ """Model initialization and factory methods."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ import torch
+ from torch import nn
+ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
+
+ from amber.store.store import Store
+ from amber.language_model.utils import extract_model_id
+
+ if TYPE_CHECKING:
+     from amber.language_model.language_model import LanguageModel
+
+
+ def initialize_model_id(
+     model: nn.Module,
+     provided_model_id: str | None = None
+ ) -> str:
+     """
+     Initialize model ID for LanguageModel.
+
+     Args:
+         model: PyTorch model module
+         provided_model_id: Optional model ID provided by user
+
+     Returns:
+         Model ID string
+     """
+     return extract_model_id(model, provided_model_id)
+
+
+ def create_from_huggingface(
+     cls: type["LanguageModel"],
+     model_name: str,
+     store: Store,
+     tokenizer_params: dict | None = None,
+     model_params: dict | None = None,
+ ) -> "LanguageModel":
+     """
+     Load a language model from HuggingFace Hub.
+
+     Args:
+         cls: LanguageModel class
+         model_name: HuggingFace model identifier
+         store: Store instance for persistence
+         tokenizer_params: Optional tokenizer parameters
+         model_params: Optional model parameters
+
+     Returns:
+         LanguageModel instance
+
+     Raises:
+         ValueError: If model_name is invalid
+         RuntimeError: If model loading fails
+     """
+     if not model_name or not isinstance(model_name, str) or not model_name.strip():
+         raise ValueError(f"model_name must be a non-empty string, got: {model_name!r}")
+
+     if store is None:
+         raise ValueError("store cannot be None")
+
+     if tokenizer_params is None:
+         tokenizer_params = {}
+     if model_params is None:
+         model_params = {}
+
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_params)
+         model = AutoModelForCausalLM.from_pretrained(model_name, **model_params)
+     except Exception as e:
+         raise RuntimeError(
+             f"Failed to load model '{model_name}' from HuggingFace. Error: {e}"
+         ) from e
+
+     return cls(model, tokenizer, store)
+
+
+ def create_from_local_torch(
+     cls: type["LanguageModel"],
+     model_path: str,
+     tokenizer_path: str,
+     store: Store
+ ) -> "LanguageModel":
+     """
+     Load a language model from local HuggingFace paths.
+
+     Args:
+         cls: LanguageModel class
+         model_path: Path to the model directory or file
+         tokenizer_path: Path to the tokenizer directory or file
+         store: Store instance for persistence
+
+     Returns:
+         LanguageModel instance
+
+     Raises:
+         FileNotFoundError: If model or tokenizer paths don't exist
+         RuntimeError: If model loading fails
+     """
+     if store is None:
+         raise ValueError("store cannot be None")
+
+     model_path_obj = Path(model_path)
+     tokenizer_path_obj = Path(tokenizer_path)
+
+     if not model_path_obj.exists():
+         raise FileNotFoundError(f"Model path does not exist: {model_path}")
+
+     if not tokenizer_path_obj.exists():
+         raise FileNotFoundError(f"Tokenizer path does not exist: {tokenizer_path}")
+
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+         model = AutoModelForCausalLM.from_pretrained(model_path)
+     except Exception as e:
+         raise RuntimeError(
+             f"Failed to load model from local paths. "
+             f"model_path={model_path!r}, tokenizer_path={tokenizer_path!r}. Error: {e}"
+         ) from e
+
+     return cls(model, tokenizer, store)
+
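As a closing editorial note (not part of the wheel contents), the local-path factory can be sketched as follows; the LocalStore constructor and the on-disk layout under ./artifacts are assumptions used only for illustration.

# Hypothetical sketch of create_from_local_torch; paths and LocalStore signature are assumed.
from amber.language_model.language_model import LanguageModel
from amber.language_model.initialization import create_from_local_torch
from amber.store.local_store import LocalStore

store = LocalStore("./runs")  # assumed constructor signature
try:
    lm = create_from_local_torch(
        LanguageModel,
        model_path="./artifacts/model",          # hypothetical local model directory
        tokenizer_path="./artifacts/tokenizer",  # hypothetical local tokenizer directory
        store=store,
    )
except FileNotFoundError as err:
    # Raised before any HuggingFace call when either path is missing.
    print(f"Missing artifact: {err}")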