mi_crow-0.1.1.post12-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amber/__init__.py +15 -0
- amber/datasets/__init__.py +11 -0
- amber/datasets/base_dataset.py +640 -0
- amber/datasets/classification_dataset.py +566 -0
- amber/datasets/loading_strategy.py +29 -0
- amber/datasets/text_dataset.py +488 -0
- amber/hooks/__init__.py +20 -0
- amber/hooks/controller.py +171 -0
- amber/hooks/detector.py +95 -0
- amber/hooks/hook.py +218 -0
- amber/hooks/implementations/__init__.py +0 -0
- amber/hooks/implementations/function_controller.py +93 -0
- amber/hooks/implementations/layer_activation_detector.py +96 -0
- amber/hooks/implementations/model_input_detector.py +250 -0
- amber/hooks/implementations/model_output_detector.py +132 -0
- amber/hooks/utils.py +76 -0
- amber/language_model/__init__.py +0 -0
- amber/language_model/activations.py +479 -0
- amber/language_model/context.py +33 -0
- amber/language_model/contracts.py +13 -0
- amber/language_model/hook_metadata.py +38 -0
- amber/language_model/inference.py +525 -0
- amber/language_model/initialization.py +126 -0
- amber/language_model/language_model.py +390 -0
- amber/language_model/layers.py +460 -0
- amber/language_model/persistence.py +177 -0
- amber/language_model/tokenizer.py +203 -0
- amber/language_model/utils.py +97 -0
- amber/mechanistic/__init__.py +0 -0
- amber/mechanistic/sae/__init__.py +0 -0
- amber/mechanistic/sae/autoencoder_context.py +40 -0
- amber/mechanistic/sae/concepts/__init__.py +0 -0
- amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
- amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
- amber/mechanistic/sae/concepts/concept_models.py +9 -0
- amber/mechanistic/sae/concepts/input_tracker.py +68 -0
- amber/mechanistic/sae/modules/__init__.py +5 -0
- amber/mechanistic/sae/modules/l1_sae.py +409 -0
- amber/mechanistic/sae/modules/topk_sae.py +459 -0
- amber/mechanistic/sae/sae.py +166 -0
- amber/mechanistic/sae/sae_trainer.py +604 -0
- amber/mechanistic/sae/training/wandb_logger.py +222 -0
- amber/store/__init__.py +5 -0
- amber/store/local_store.py +437 -0
- amber/store/store.py +276 -0
- amber/store/store_dataloader.py +124 -0
- amber/utils.py +46 -0
- mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
- mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
- mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
- mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
amber/language_model/language_model.py
@@ -0,0 +1,390 @@
+from __future__ import annotations
+
+from collections import defaultdict
+from pathlib import Path
+from typing import Sequence, Any, Dict, List, TYPE_CHECKING, Set
+
+import torch
+from torch import nn, Tensor
+from transformers import PreTrainedTokenizerBase
+
+from amber.language_model.layers import LanguageModelLayers
+from amber.language_model.tokenizer import LanguageModelTokenizer
+from amber.language_model.activations import LanguageModelActivations
+from amber.language_model.context import LanguageModelContext
+from amber.language_model.inference import InferenceEngine
+from amber.language_model.persistence import save_model, load_model_from_saved_file
+from amber.language_model.initialization import initialize_model_id, create_from_huggingface, create_from_local_torch
+from amber.store.store import Store
+from amber.utils import get_logger
+
+if TYPE_CHECKING:
+    from amber.mechanistic.sae.concepts.input_tracker import InputTracker
+
+logger = get_logger(__name__)
+
+
+def _extract_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> Set[int]:
+    """
+    Extract special token IDs from a tokenizer.
+
+    Prioritizes the common case (all_special_ids) and falls back to
+    individual token ID attributes for edge cases.
+
+    Handles cases where token_id attributes may be lists (e.g., eos_token_id: [4, 2]).
+
+    Args:
+        tokenizer: HuggingFace tokenizer
+
+    Returns:
+        Set of special token IDs
+    """
+    special_ids = set()
+
+    # Common case: most tokenizers have all_special_ids
+    if hasattr(tokenizer, 'all_special_ids'):
+        all_special_ids = tokenizer.all_special_ids
+        if all_special_ids and isinstance(all_special_ids, (list, tuple, set)):
+            special_ids.update(all_special_ids)
+            return special_ids  # Early return for common case
+
+    # Fallback: extract from individual token ID attributes
+    def add_token_id(token_id):
+        if token_id is None:
+            return
+        if isinstance(token_id, (list, tuple)):
+            special_ids.update(token_id)
+        else:
+            special_ids.add(token_id)
+
+    token_id_attrs = ['pad_token_id', 'eos_token_id', 'bos_token_id', 'unk_token_id',
+                      'cls_token_id', 'sep_token_id', 'mask_token_id']
+    for attr in token_id_attrs:
+        token_id = getattr(tokenizer, attr, None)
+        add_token_id(token_id)
+
+    return special_ids
+
+
+class LanguageModel:
+    """
+    Fence-style language model wrapper.
+
+    Provides a unified interface for working with language models, including:
+    - Model initialization and configuration
+    - Inference operations
+    - Hook management (detectors and controllers)
+    - Model persistence
+    - Activation tracking
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        tokenizer: PreTrainedTokenizerBase,
+        store: Store,
+        model_id: str | None = None,
+    ):
+        """
+        Initialize LanguageModel.
+
+        Args:
+            model: PyTorch model module
+            tokenizer: HuggingFace tokenizer
+            store: Store instance for persistence
+            model_id: Optional model identifier (auto-extracted if not provided)
+        """
+        self.context = LanguageModelContext(self)
+        self.context.model = model
+        self.context.tokenizer = tokenizer
+        self.context.model_id = initialize_model_id(model, model_id)
+        self.context.store = store
+        self.context.special_token_ids = _extract_special_token_ids(tokenizer)
+
+        self.layers = LanguageModelLayers(self.context)
+        self.lm_tokenizer = LanguageModelTokenizer(self.context)
+        self.activations = LanguageModelActivations(self.context)
+        self.inference = InferenceEngine(self)
+        self._inference_engine = self.inference
+
+        self._input_tracker: "InputTracker | None" = None
+
+    @property
+    def model(self) -> nn.Module:
+        """Get the underlying PyTorch model."""
+        return self.context.model
+
+    @property
+    def tokenizer(self) -> PreTrainedTokenizerBase:
+        """Get the tokenizer."""
+        return self.context.tokenizer
+
+    @property
+    def model_id(self) -> str:
+        """Get the model identifier."""
+        return self.context.model_id
+
+    @property
+    def store(self) -> Store:
+        """Get the store instance."""
+        return self.context.store
+
+    @store.setter
+    def store(self, value: Store) -> None:
+        """Set the store instance."""
+        self.context.store = value
+
+    def tokenize(self, texts: Sequence[str], **kwargs: Any):
+        """
+        Tokenize texts using the language model tokenizer.
+
+        Args:
+            texts: Sequence of text strings to tokenize
+            **kwargs: Additional tokenizer arguments
+
+        Returns:
+            Tokenized encodings
+        """
+        return self.lm_tokenizer.tokenize(texts, **kwargs)
+
+    def forwards(
+        self,
+        texts: Sequence[str],
+        tok_kwargs: Dict | None = None,
+        autocast: bool = True,
+        autocast_dtype: torch.dtype | None = None,
+        with_controllers: bool = True,
+    ):
+        """
+        Run forward pass on texts.
+
+        Args:
+            texts: Input texts to process
+            tok_kwargs: Optional tokenizer keyword arguments
+            autocast: Whether to use automatic mixed precision
+            autocast_dtype: Optional dtype for autocast
+            with_controllers: Whether to use controllers during inference
+
+        Returns:
+            Tuple of (model_output, encodings)
+        """
+        return self._inference_engine.execute_inference(
+            texts,
+            tok_kwargs=tok_kwargs,
+            autocast=autocast,
+            autocast_dtype=autocast_dtype,
+            with_controllers=with_controllers
+        )
+
+    def generate(
+        self,
+        texts: Sequence[str],
+        tok_kwargs: Dict | None = None,
+        autocast: bool = True,
+        autocast_dtype: torch.dtype | None = None,
+        with_controllers: bool = True,
+        skip_special_tokens: bool = True,
+    ) -> Sequence[str]:
+        """
+        Run inference and automatically decode the output with the tokenizer.
+
+        Args:
+            texts: Input texts to process
+            tok_kwargs: Optional tokenizer keyword arguments
+            autocast: Whether to use automatic mixed precision
+            autocast_dtype: Optional dtype for autocast
+            with_controllers: Whether to use controllers during inference
+            skip_special_tokens: Whether to skip special tokens when decoding
+
+        Returns:
+            Sequence of decoded text strings
+
+        Raises:
+            ValueError: If texts is empty or tokenizer is None
+        """
+        if not texts:
+            raise ValueError("Texts list cannot be empty")
+
+        if self.tokenizer is None:
+            raise ValueError("Tokenizer is required for decoding but is None")
+
+        output, enc = self._inference_engine.execute_inference(
+            texts,
+            tok_kwargs=tok_kwargs,
+            autocast=autocast,
+            autocast_dtype=autocast_dtype,
+            with_controllers=with_controllers
+        )
+
+        logits = self._inference_engine.extract_logits(output)
+        predicted_token_ids = logits.argmax(dim=-1)
+
+        decoded_texts = []
+        for i in range(predicted_token_ids.shape[0]):
+            token_ids = predicted_token_ids[i].cpu().tolist()
+            decoded_text = self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+            decoded_texts.append(decoded_text)
+
+        return decoded_texts
+
+    def get_input_tracker(self) -> "InputTracker | None":
+        """
+        Get the input tracker instance if it exists.
+
+        Returns:
+            InputTracker instance or None
+        """
+        return self._input_tracker
+
+    def get_all_detector_metadata(self) -> tuple[dict[str, dict[str, Any]], dict[str, dict[str, Tensor]]]:
+        """
+        Get metadata from all registered detectors.
+
+        Returns:
+            Tuple of (detectors_metadata, detectors_tensor_metadata)
+        """
+        detectors = self.layers.get_detectors()
+        detectors_metadata: Dict[str, Dict[str, Any]] = defaultdict(dict)
+        detectors_tensor_metadata: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict)
+
+        for detector in detectors:
+            detectors_metadata[detector.layer_signature] = detector.metadata
+            detectors_tensor_metadata[detector.layer_signature] = detector.tensor_metadata
+
+        return detectors_metadata, detectors_tensor_metadata
+
+    def clear_detectors(self) -> None:
+        """
+        Clear all accumulated metadata for registered detectors.
+
+        This is useful when running multiple independent inference runs
+        (e.g. separate `infer_texts` / `infer_dataset` calls) and you want
+        to ensure that detector state does not leak between runs.
+        """
+        detectors = self.layers.get_detectors()
+        for detector in detectors:
+            # Clear generic accumulated metadata
+            detector.metadata.clear()
+            detector.tensor_metadata.clear()
+
+            # Allow detector implementations to provide more specialized
+            # clearing logic (e.g. ModelInputDetector, ModelOutputDetector).
+            clear_captured = getattr(detector, "clear_captured", None)
+            if callable(clear_captured):
+                clear_captured()
+
+    def save_detector_metadata(self, run_name: str, batch_idx: int | None, unified: bool = False) -> str:
+        """
+        Save detector metadata to store.
+
+        Args:
+            run_name: Name of the run
+            batch_idx: Batch index. Ignored when ``unified`` is True.
+            unified: If True, save metadata in a single detectors directory
+                for the whole run instead of per-batch directories.
+
+        Returns:
+            Path where metadata was saved
+
+        Raises:
+            ValueError: If store is not set
+        """
+        if self.store is None:
+            raise ValueError("Store must be provided or set on the language model")
+        detectors_metadata, detectors_tensor_metadata = self.get_all_detector_metadata()
+        if unified:
+            return self.store.put_run_detector_metadata(run_name, detectors_metadata, detectors_tensor_metadata)
+        if batch_idx is None:
+            raise ValueError("batch_idx must be provided when unified is False")
+        return self.store.put_detector_metadata(run_name, batch_idx, detectors_metadata, detectors_tensor_metadata)
+
+    def _ensure_input_tracker(self) -> "InputTracker":
+        """
+        Ensure InputTracker singleton exists.
+
+        Returns:
+            The InputTracker instance
+        """
+        if self._input_tracker is not None:
+            return self._input_tracker
+
+        from amber.mechanistic.sae.concepts.input_tracker import InputTracker
+
+        self._input_tracker = InputTracker(language_model=self)
+
+        logger.debug(f"Created InputTracker singleton for {self.context.model_id}")
+
+        return self._input_tracker
+
+    def save_model(self, path: Path | str | None = None) -> Path:
+        """
+        Save the model and its metadata to the store.
+
+        Args:
+            path: Optional path to save the model. If None, defaults to {model_id}/model.pt
+                relative to the store base path.
+
+        Returns:
+            Path where the model was saved
+
+        Raises:
+            ValueError: If store is not set
+        """
+        return save_model(self, path)
+
+    @classmethod
+    def from_huggingface(
+        cls,
+        model_name: str,
+        store: Store,
+        tokenizer_params: dict = None,
+        model_params: dict = None,
+    ) -> "LanguageModel":
+        """
+        Load a language model from HuggingFace Hub.
+
+        Args:
+            model_name: HuggingFace model identifier
+            store: Store instance for persistence
+            tokenizer_params: Optional tokenizer parameters
+            model_params: Optional model parameters
+
+        Returns:
+            LanguageModel instance
+        """
+        return create_from_huggingface(cls, model_name, store, tokenizer_params, model_params)
+
+    @classmethod
+    def from_local_torch(cls, model_path: str, tokenizer_path: str, store: Store) -> "LanguageModel":
+        """
+        Load a language model from local HuggingFace paths.
+
+        Args:
+            model_path: Path to the model directory or file
+            tokenizer_path: Path to the tokenizer directory or file
+            store: Store instance for persistence
+
+        Returns:
+            LanguageModel instance
+        """
+        return create_from_local_torch(cls, model_path, tokenizer_path, store)
+
+    @classmethod
+    def from_local(cls, saved_path: Path | str, store: Store, model_id: str | None = None) -> "LanguageModel":
+        """
+        Load a language model from a saved file (created by save_model).
+
+        Args:
+            saved_path: Path to the saved model file (.pt file)
+            store: Store instance for persistence
+            model_id: Optional model identifier. If not provided, will use the model_id from saved metadata.
+                If provided, will be used to load the model architecture from HuggingFace.
+
+        Returns:
+            LanguageModel instance
+
+        Raises:
+            FileNotFoundError: If the saved file doesn't exist
+            ValueError: If the saved file format is invalid or model_id is required but not provided
+        """
+        return load_model_from_saved_file(cls, saved_path, store, model_id)
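
For orientation, here is a minimal usage sketch of the LanguageModel API added in this diff. It is inferred only from the signatures above and is not part of the package; the model name, the store base path, and the LocalStore constructor argument are illustrative assumptions (the Store implementations live under amber/store/ and their constructors are not shown in this file).

    # Hypothetical usage sketch; constructor arguments are assumptions.
    from amber.language_model.language_model import LanguageModel
    from amber.store.local_store import LocalStore  # assumed: LocalStore(base_path)

    store = LocalStore("./amber_runs")                  # placeholder base path
    lm = LanguageModel.from_huggingface("gpt2", store)  # any HF model identifier

    # generate() runs a single forward pass and decodes the argmax token at
    # each position; per the implementation above, it is not autoregressive
    # sampling.
    outputs = lm.generate(["The quick brown fox"], skip_special_tokens=True)

    # Detector hooks accumulate metadata across forward passes; clear state
    # between independent runs, then persist it per batch (or per run with
    # unified=True).
    lm.clear_detectors()
    lm.forwards(["Another input"])
    lm.save_detector_metadata(run_name="demo", batch_idx=0)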