mi-crow 0.1.1.post12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amber/__init__.py +15 -0
- amber/datasets/__init__.py +11 -0
- amber/datasets/base_dataset.py +640 -0
- amber/datasets/classification_dataset.py +566 -0
- amber/datasets/loading_strategy.py +29 -0
- amber/datasets/text_dataset.py +488 -0
- amber/hooks/__init__.py +20 -0
- amber/hooks/controller.py +171 -0
- amber/hooks/detector.py +95 -0
- amber/hooks/hook.py +218 -0
- amber/hooks/implementations/__init__.py +0 -0
- amber/hooks/implementations/function_controller.py +93 -0
- amber/hooks/implementations/layer_activation_detector.py +96 -0
- amber/hooks/implementations/model_input_detector.py +250 -0
- amber/hooks/implementations/model_output_detector.py +132 -0
- amber/hooks/utils.py +76 -0
- amber/language_model/__init__.py +0 -0
- amber/language_model/activations.py +479 -0
- amber/language_model/context.py +33 -0
- amber/language_model/contracts.py +13 -0
- amber/language_model/hook_metadata.py +38 -0
- amber/language_model/inference.py +525 -0
- amber/language_model/initialization.py +126 -0
- amber/language_model/language_model.py +390 -0
- amber/language_model/layers.py +460 -0
- amber/language_model/persistence.py +177 -0
- amber/language_model/tokenizer.py +203 -0
- amber/language_model/utils.py +97 -0
- amber/mechanistic/__init__.py +0 -0
- amber/mechanistic/sae/__init__.py +0 -0
- amber/mechanistic/sae/autoencoder_context.py +40 -0
- amber/mechanistic/sae/concepts/__init__.py +0 -0
- amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
- amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
- amber/mechanistic/sae/concepts/concept_models.py +9 -0
- amber/mechanistic/sae/concepts/input_tracker.py +68 -0
- amber/mechanistic/sae/modules/__init__.py +5 -0
- amber/mechanistic/sae/modules/l1_sae.py +409 -0
- amber/mechanistic/sae/modules/topk_sae.py +459 -0
- amber/mechanistic/sae/sae.py +166 -0
- amber/mechanistic/sae/sae_trainer.py +604 -0
- amber/mechanistic/sae/training/wandb_logger.py +222 -0
- amber/store/__init__.py +5 -0
- amber/store/local_store.py +437 -0
- amber/store/store.py +276 -0
- amber/store/store_dataloader.py +124 -0
- amber/utils.py +46 -0
- mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
- mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
- mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
- mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,525 @@
|
|
|
1
|
+
"""Inference engine for language models."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import datetime
|
|
6
|
+
from typing import Sequence, Any, Dict, List, TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch import nn
|
|
10
|
+
|
|
11
|
+
from amber.language_model.utils import get_device_from_model, move_tensors_to_device, extract_logits_from_output
|
|
12
|
+
from amber.utils import get_logger
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from amber.language_model.language_model import LanguageModel
|
|
16
|
+
from amber.hooks.controller import Controller
|
|
17
|
+
from amber.datasets import BaseDataset
|
|
18
|
+
from amber.store.store import Store
|
|
19
|
+
|
|
20
|
+
# Module-level logger used by the inference engine for progress (info) and
# metadata-save failure (warning) messages.
logger = get_logger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class _EarlyStopInference(Exception):
    """Private signal that aborts the model forward pass after a chosen layer.

    A temporary forward hook raises this exception; the inference engine
    catches it and returns ``output`` in place of the full model output.
    """

    def __init__(self, output: Any):
        message = "Early stop after requested layer"
        super().__init__(message)
        # Output of the layer at which the forward pass was interrupted.
        self.output = output
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class InferenceEngine:
    """Handles inference operations for LanguageModel.

    Provides a single forward pass over a list of texts (``execute_inference``)
    and two higher-level drivers that optionally persist run and detector
    metadata through ``LanguageModel.store``: ``infer_texts`` for in-memory
    texts and ``infer_dataset`` for a ``BaseDataset``.
    """

    def __init__(self, language_model: "LanguageModel"):
        """
        Initialize inference engine.

        Args:
            language_model: LanguageModel instance whose model, tokenizer,
                layers and store are used for inference
        """
        self.lm = language_model

    def _prepare_tokenizer_kwargs(self, tok_kwargs: Dict | None) -> Dict[str, Any]:
        """
        Prepare tokenizer keyword arguments with defaults.

        Defaults are ``padding=True``, ``truncation=True`` and
        ``return_tensors="pt"``; any key in ``tok_kwargs`` overrides them.
        The caller's dict is not mutated.

        Args:
            tok_kwargs: Optional tokenizer keyword arguments

        Returns:
            New dictionary of tokenizer kwargs with defaults applied
        """
        if tok_kwargs is None:
            tok_kwargs = {}
        return {
            "padding": True,
            "truncation": True,
            "return_tensors": "pt",
            **tok_kwargs,
        }

    def _setup_trackers(self, texts: Sequence[str]) -> None:
        """
        Setup input trackers for current texts.

        Only forwards the texts when an input tracker exists and is enabled.

        Args:
            texts: Sequence of input texts
        """
        if self.lm._input_tracker is not None and self.lm._input_tracker.enabled:
            self.lm._input_tracker.set_current_texts(texts)

    def _setup_model_input_detectors(self, enc: Dict[str, torch.Tensor]) -> None:
        """
        Automatically set inputs from encodings for all registered ModelInputDetector hooks.

        This is necessary because PyTorch's pre_forward hook doesn't receive kwargs,
        so ModelInputDetector hooks can't automatically capture attention masks when
        models are called with **kwargs (e.g., model(**encodings)).

        Args:
            enc: Encoded inputs dictionary
        """
        # Local import, as in the original module layout (keeps hook
        # implementations out of this module's import-time dependencies).
        from amber.hooks.implementations.model_input_detector import ModelInputDetector

        detectors = self.lm.layers.get_detectors()
        for detector in detectors:
            if isinstance(detector, ModelInputDetector):
                detector.set_inputs_from_encodings(enc, module=self.lm.model)

    def _prepare_controllers(self, with_controllers: bool) -> List["Controller"]:
        """
        Prepare controllers for inference, disabling if needed.

        Args:
            with_controllers: Whether to keep controllers enabled

        Returns:
            List of controllers that were disabled (to restore later); empty
            when ``with_controllers`` is True or no controller was enabled
        """
        controllers_to_restore = []
        if not with_controllers:
            controllers = self.lm.layers.get_controllers()
            for controller in controllers:
                # Only disable (and later re-enable) controllers that were on,
                # so already-disabled ones keep their state.
                if controller.enabled:
                    controller.disable()
                    controllers_to_restore.append(controller)
        return controllers_to_restore

    def _restore_controllers(self, controllers_to_restore: List["Controller"]) -> None:
        """
        Restore controllers that were disabled.

        Args:
            controllers_to_restore: List of controllers to restore
        """
        for controller in controllers_to_restore:
            controller.enable()

    def _run_model_forward(
        self,
        enc: Dict[str, torch.Tensor],
        autocast: bool,
        device_type: str,
        autocast_dtype: torch.dtype | None,
    ) -> Any:
        """
        Run model forward pass with optional autocast.

        The pass runs under ``torch.inference_mode()``. Autocast is applied
        only on CUDA devices (default dtype ``torch.float16``).

        Args:
            enc: Encoded inputs dictionary
            autocast: Whether to use automatic mixed precision
            device_type: Device type string ("cuda", "cpu", etc.)
            autocast_dtype: Optional dtype for autocast

        Returns:
            Model output, or the captured layer output when an early-stop
            hook interrupted the pass
        """
        try:
            with torch.inference_mode():
                if autocast and device_type == "cuda":
                    amp_dtype = autocast_dtype or torch.float16
                    with torch.autocast(device_type, dtype=amp_dtype):
                        return self.lm.model(**enc)
                return self.lm.model(**enc)
        except _EarlyStopInference as e:
            # Early stopping hook raised this to short-circuit the remaining forward pass.
            # We return the output captured at the requested layer.
            return e.output

    def execute_inference(
        self,
        texts: Sequence[str],
        tok_kwargs: Dict | None = None,
        autocast: bool = True,
        autocast_dtype: torch.dtype | None = None,
        with_controllers: bool = True,
        stop_after_layer: str | int | None = None,
    ) -> tuple[Any, Dict[str, torch.Tensor]]:
        """
        Execute inference on texts.

        Args:
            texts: Sequence of input texts
            tok_kwargs: Optional tokenizer keyword arguments
            autocast: Whether to use automatic mixed precision
            autocast_dtype: Optional dtype for autocast
            with_controllers: Whether to use controllers during inference
            stop_after_layer: Optional layer signature (name or index) after which
                the forward pass should be stopped early

        Returns:
            Tuple of (model_output, encodings)

        Raises:
            ValueError: If texts is empty, or the tokenizer or model is not
                initialized
        """
        if not texts:
            raise ValueError("Texts list cannot be empty")

        if self.lm.tokenizer is None:
            raise ValueError("Tokenizer must be initialized before running inference")

        # Fail fast with a clear message instead of an AttributeError deep in
        # device lookup / ``.eval()`` when no model is attached.
        if self.lm.model is None:
            raise ValueError("Model must be initialized before running inference")

        tok_kwargs = self._prepare_tokenizer_kwargs(tok_kwargs)
        enc = self.lm.tokenize(texts, **tok_kwargs)

        device = get_device_from_model(self.lm.model)
        device_type = str(device.type)
        enc = move_tensors_to_device(enc, device)

        self.lm.model.eval()

        self._setup_trackers(texts)
        self._setup_model_input_detectors(enc)

        controllers_to_restore = self._prepare_controllers(with_controllers)

        hook_handle = None
        try:
            if stop_after_layer is not None:
                # Register a temporary forward hook that stops the forward pass
                def _early_stop_hook(module: nn.Module, inputs: tuple, output: Any):
                    raise _EarlyStopInference(output)

                hook_handle = self.lm.layers.register_forward_hook_for_layer(
                    stop_after_layer, _early_stop_hook
                )

            output = self._run_model_forward(enc, autocast, device_type, autocast_dtype)
            return output, enc
        finally:
            if hook_handle is not None:
                try:
                    hook_handle.remove()
                except Exception:
                    # Best-effort cleanup: never mask the real error from the
                    # forward pass, but leave a trace instead of swallowing it.
                    logger.debug("Failed to remove early-stop hook", exc_info=True)
            # Always re-enable controllers, even if the forward pass raised.
            self._restore_controllers(controllers_to_restore)

    def extract_logits(self, output: Any) -> torch.Tensor:
        """
        Extract logits tensor from model output.

        Args:
            output: Model output

        Returns:
            Logits tensor
        """
        return extract_logits_from_output(output)

    def _extract_dataset_info(self, dataset: "BaseDataset | None") -> Dict[str, Any]:
        """
        Extract dataset information for metadata.

        Args:
            dataset: Optional dataset instance

        Returns:
            ``{}`` when ``dataset`` is None; otherwise a dict with
            ``dataset_dir`` (str, "" if absent) and ``length``. On failure to
            read either value, returns ``{"dataset_dir": "", "length": -1}``.
        """
        if dataset is None:
            return {}

        try:
            ds_id = str(getattr(dataset, "dataset_dir", ""))
            ds_len = int(len(dataset))
            return {
                "dataset_dir": ds_id,
                "length": ds_len,
            }
        except (AttributeError, TypeError, ValueError, RuntimeError):
            return {
                "dataset_dir": "",
                "length": -1,
            }

    def _prepare_run_metadata(
        self,
        dataset: "BaseDataset | None" = None,
        run_name: str | None = None,
        options: Dict[str, Any] | None = None,
    ) -> tuple[str, Dict[str, Any]]:
        """
        Prepare run metadata dictionary.

        Args:
            dataset: Optional dataset (for dataset info)
            run_name: Optional run name (generates ``run_YYYYmmdd_HHMMSS`` if None)
            options: Optional dict of options to include (shallow-copied)

        Returns:
            Tuple of (run_name, metadata_dict)
        """
        if run_name is None:
            run_name = f"run_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"

        if options is None:
            options = {}

        dataset_info = self._extract_dataset_info(dataset)

        meta: Dict[str, Any] = {
            "run_name": run_name,
            "model": getattr(self.lm.model, "model_name", self.lm.model.__class__.__name__),
            "options": options.copy(),
        }

        if dataset_info:
            meta["dataset"] = dataset_info

        return run_name, meta

    @staticmethod
    def _save_run_metadata(
        store: "Store",
        run_name: str,
        meta: Dict[str, Any],
        verbose: bool = False,
    ) -> None:
        """
        Save run metadata to store.

        Saving is best-effort: I/O, value and runtime errors are swallowed
        (and logged as a warning only when ``verbose`` is True).

        Args:
            store: Store to save to
            run_name: Run name
            meta: Metadata dictionary
            verbose: Whether to log
        """
        try:
            store.put_run_metadata(run_name, meta)
        except (OSError, IOError, ValueError, RuntimeError) as e:
            if verbose:
                logger.warning(f"Failed to save run metadata for {run_name}: {e}")

    def infer_texts(
        self,
        texts: Sequence[str],
        run_name: str | None = None,
        batch_size: int | None = None,
        tok_kwargs: Dict | None = None,
        autocast: bool = True,
        autocast_dtype: torch.dtype | None = None,
        with_controllers: bool = True,
        clear_detectors_before: bool = False,
        verbose: bool = False,
        stop_after_layer: str | int | None = None,
        save_in_batches: bool = True,
    ) -> tuple[Any, Dict[str, torch.Tensor]] | tuple[List[Any], List[Dict[str, torch.Tensor]]]:
        """
        Run inference on list of strings with optional metadata saving.

        Args:
            texts: Sequence of input texts
            run_name: Optional run name for saving metadata (if None, no metadata saved)
            batch_size: Optional positive batch size for processing (if None,
                processes all at once)
            tok_kwargs: Optional tokenizer keyword arguments
            autocast: Whether to use automatic mixed precision
            autocast_dtype: Optional dtype for autocast
            with_controllers: Whether to use controllers during inference
            clear_detectors_before: If True, clears all detector state before running
            verbose: Whether to log progress
            stop_after_layer: Optional layer signature (name or index) after which
                the forward pass should be stopped early
            save_in_batches: If True, save detector metadata in per‑batch
                directories. If False, aggregate all detector metadata for
                the run under a single detectors directory.

        Returns:
            If batch_size is None or >= len(texts): Tuple of (model_output, encodings)
            If batch_size < len(texts): Tuple of (list of outputs, list of encodings)

        Raises:
            ValueError: If texts is empty, tokenizer is not initialized,
                batch_size is not positive, or run_name is given without a store
        """
        if not texts:
            raise ValueError("Texts list cannot be empty")

        if self.lm.tokenizer is None:
            raise ValueError("Tokenizer must be initialized before running inference")

        # A non-positive batch size would either crash inside ``range`` (0) or
        # silently produce an empty result (negative step); reject it up front.
        if batch_size is not None and batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got: {batch_size!r}")

        if clear_detectors_before:
            self.lm.clear_detectors()

        store = self.lm.store
        if run_name is not None and store is None:
            raise ValueError("Store must be provided to save metadata")

        if batch_size is None or batch_size >= len(texts):
            output, enc = self.execute_inference(
                texts,
                tok_kwargs=tok_kwargs,
                autocast=autocast,
                autocast_dtype=autocast_dtype,
                with_controllers=with_controllers,
                stop_after_layer=stop_after_layer,
            )

            if run_name is not None:
                options = {
                    "batch_size": len(texts),
                    "max_length": tok_kwargs.get("max_length") if tok_kwargs else None,
                }
                _, meta = self._prepare_run_metadata(dataset=None, run_name=run_name, options=options)
                self._save_run_metadata(store, run_name, meta, verbose)
                self.lm.save_detector_metadata(run_name, 0, unified=not save_in_batches)

            return output, enc

        all_outputs = []
        all_encodings = []
        batch_counter = 0

        if run_name is not None:
            options = {
                "batch_size": batch_size,
                "max_length": tok_kwargs.get("max_length") if tok_kwargs else None,
            }
            _, meta = self._prepare_run_metadata(dataset=None, run_name=run_name, options=options)
            self._save_run_metadata(store, run_name, meta, verbose)

        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            output, enc = self.execute_inference(
                batch_texts,
                tok_kwargs=tok_kwargs,
                autocast=autocast,
                autocast_dtype=autocast_dtype,
                with_controllers=with_controllers,
                stop_after_layer=stop_after_layer,
            )

            all_outputs.append(output)
            all_encodings.append(enc)

            if run_name is not None:
                self.lm.save_detector_metadata(run_name, batch_counter, unified=not save_in_batches)
                if verbose:
                    logger.info(f"Saved batch {batch_counter} for run={run_name}")

            batch_counter += 1

        return all_outputs, all_encodings

    def infer_dataset(
        self,
        dataset: "BaseDataset",
        run_name: str | None = None,
        batch_size: int = 32,
        tok_kwargs: Dict | None = None,
        autocast: bool = True,
        autocast_dtype: torch.dtype | None = None,
        with_controllers: bool = True,
        free_cuda_cache_every: int | None = 0,
        clear_detectors_before: bool = False,
        verbose: bool = False,
        stop_after_layer: str | int | None = None,
        save_in_batches: bool = True,
    ) -> str:
        """
        Run inference on whole dataset with metadata saving.

        Empty batches yielded by the dataset are skipped, but detector
        metadata is saved under the batch's position from
        ``dataset.iter_batches`` (so indices can have gaps).

        Args:
            dataset: Dataset to process
            run_name: Optional run name (generated if None)
            batch_size: Positive batch size for processing
            tok_kwargs: Optional tokenizer keyword arguments
            autocast: Whether to use automatic mixed precision
            autocast_dtype: Optional dtype for autocast
            with_controllers: Whether to use controllers during inference
            free_cuda_cache_every: Clear CUDA cache every N batches (0 or None to disable)
            clear_detectors_before: If True, clears all detector state before running
            verbose: Whether to log progress
            stop_after_layer: Optional layer signature (name or index) after which
                the forward pass should be stopped early
            save_in_batches: If True, save detector metadata in per‑batch
                directories. If False, aggregate all detector metadata for
                the run under a single detectors directory.

        Returns:
            Run name used for saving

        Raises:
            ValueError: If batch_size is not positive, or model or store is
                not initialized
        """
        # Validate before any side effect (e.g. clearing detectors).
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got: {batch_size!r}")

        if clear_detectors_before:
            self.lm.clear_detectors()

        model: nn.Module | None = self.lm.model
        if model is None:
            raise ValueError("Model must be initialized before running")

        store = self.lm.store
        if store is None:
            raise ValueError("Store must be provided or set on the language model")

        device = get_device_from_model(model)
        device_type = str(device.type)

        options = {
            "max_length": tok_kwargs.get("max_length") if tok_kwargs else None,
            "batch_size": int(batch_size),
        }

        run_name, meta = self._prepare_run_metadata(dataset=dataset, run_name=run_name, options=options)

        if verbose:
            logger.info(
                f"Starting infer_dataset: run={run_name}, "
                f"batch_size={batch_size}, device={device_type}"
            )

        self._save_run_metadata(store, run_name, meta, verbose)

        # Counts batches actually processed (skipped empty batches excluded).
        batch_counter = 0

        with torch.inference_mode():
            for batch_index, batch in enumerate(dataset.iter_batches(batch_size)):
                if not batch:
                    continue

                texts = dataset.extract_texts_from_batch(batch)

                self.execute_inference(
                    texts,
                    tok_kwargs=tok_kwargs,
                    autocast=autocast,
                    autocast_dtype=autocast_dtype,
                    with_controllers=with_controllers,
                    stop_after_layer=stop_after_layer,
                )

                self.lm.save_detector_metadata(run_name, batch_index, unified=not save_in_batches)

                batch_counter += 1

                if device_type == "cuda" and free_cuda_cache_every and free_cuda_cache_every > 0:
                    if (batch_counter % free_cuda_cache_every) == 0:
                        torch.cuda.empty_cache()
                        if verbose:
                            logger.info("Emptied CUDA cache")

                if verbose:
                    logger.info(f"Saved batch {batch_index} for run={run_name}")

        if verbose:
            logger.info(f"Completed infer_dataset: run={run_name}, batches_saved={batch_counter}")

        return run_name
|
|
525
|
+
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""Model initialization and factory methods."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch import nn
|
|
10
|
+
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
|
|
11
|
+
|
|
12
|
+
from amber.store.store import Store
|
|
13
|
+
from amber.language_model.utils import extract_model_id
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from amber.language_model.language_model import LanguageModel
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def initialize_model_id(
    model: nn.Module,
    provided_model_id: str | None = None,
) -> str:
    """
    Resolve the model ID a LanguageModel should use.

    Thin delegation to ``extract_model_id``; the optional user-supplied ID is
    forwarded unchanged and the resolution rules live in that helper.

    Args:
        model: PyTorch model module
        provided_model_id: Optional model ID provided by user

    Returns:
        Model ID string
    """
    resolved_id = extract_model_id(model, provided_model_id)
    return resolved_id
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def create_from_huggingface(
    cls: type["LanguageModel"],
    model_name: str,
    store: Store,
    tokenizer_params: dict | None = None,
    model_params: dict | None = None,
) -> "LanguageModel":
    """
    Build a LanguageModel from a HuggingFace Hub model identifier.

    Inputs are validated before anything is downloaded. The tokenizer and
    causal-LM model are then loaded with ``from_pretrained``. Any loading
    failure is re-raised as ``RuntimeError`` chained to the original error.

    Args:
        cls: LanguageModel class
        model_name: HuggingFace model identifier (non-empty, non-blank string)
        store: Store instance for persistence
        tokenizer_params: Optional kwargs for ``AutoTokenizer.from_pretrained``
        model_params: Optional kwargs for ``AutoModelForCausalLM.from_pretrained``

    Returns:
        LanguageModel instance

    Raises:
        ValueError: If model_name is invalid or store is None
        RuntimeError: If model loading fails
    """
    name_is_valid = isinstance(model_name, str) and bool(model_name.strip())
    if not name_is_valid:
        raise ValueError(f"model_name must be a non-empty string, got: {model_name!r}")

    if store is None:
        raise ValueError("store cannot be None")

    tok_kwargs = {} if tokenizer_params is None else tokenizer_params
    mdl_kwargs = {} if model_params is None else model_params

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, **tok_kwargs)
        model = AutoModelForCausalLM.from_pretrained(model_name, **mdl_kwargs)
    except Exception as exc:
        raise RuntimeError(
            f"Failed to load model '{model_name}' from HuggingFace. Error: {exc}"
        ) from exc

    return cls(model, tokenizer, store)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def create_from_local_torch(
    cls: type["LanguageModel"],
    model_path: str | Path,
    tokenizer_path: str | Path,
    store: Store,
    tokenizer_params: dict | None = None,
    model_params: dict | None = None,
) -> "LanguageModel":
    """
    Load a language model from local HuggingFace paths.

    Args:
        cls: LanguageModel class
        model_path: Path to the model directory or file
        tokenizer_path: Path to the tokenizer directory or file
        store: Store instance for persistence
        tokenizer_params: Optional kwargs forwarded to
            ``AutoTokenizer.from_pretrained`` (mirrors ``create_from_huggingface``)
        model_params: Optional kwargs forwarded to
            ``AutoModelForCausalLM.from_pretrained`` (mirrors ``create_from_huggingface``)

    Returns:
        LanguageModel instance

    Raises:
        ValueError: If store is None
        FileNotFoundError: If model or tokenizer paths are empty or don't exist
        RuntimeError: If model loading fails
    """
    if store is None:
        raise ValueError("store cannot be None")

    if tokenizer_params is None:
        tokenizer_params = {}
    if model_params is None:
        model_params = {}

    model_path_obj = Path(model_path)
    tokenizer_path_obj = Path(tokenizer_path)

    # ``Path("")`` normalises to ``Path(".")``, which always exists; reject
    # empty strings explicitly so we never silently load from the CWD.
    if not str(model_path) or not model_path_obj.exists():
        raise FileNotFoundError(f"Model path does not exist: {model_path}")

    if not str(tokenizer_path) or not tokenizer_path_obj.exists():
        raise FileNotFoundError(f"Tokenizer path does not exist: {tokenizer_path}")

    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_params)
        model = AutoModelForCausalLM.from_pretrained(model_path, **model_params)
    except Exception as e:
        raise RuntimeError(
            f"Failed to load model from local paths. "
            f"model_path={model_path!r}, tokenizer_path={tokenizer_path!r}. Error: {e}"
        ) from e

    return cls(model, tokenizer, store)
|
|
126
|
+
|