mi-crow 0.1.1.post12 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. amber/__init__.py +15 -0
  2. amber/datasets/__init__.py +11 -0
  3. amber/datasets/base_dataset.py +640 -0
  4. amber/datasets/classification_dataset.py +566 -0
  5. amber/datasets/loading_strategy.py +29 -0
  6. amber/datasets/text_dataset.py +488 -0
  7. amber/hooks/__init__.py +20 -0
  8. amber/hooks/controller.py +171 -0
  9. amber/hooks/detector.py +95 -0
  10. amber/hooks/hook.py +218 -0
  11. amber/hooks/implementations/__init__.py +0 -0
  12. amber/hooks/implementations/function_controller.py +93 -0
  13. amber/hooks/implementations/layer_activation_detector.py +96 -0
  14. amber/hooks/implementations/model_input_detector.py +250 -0
  15. amber/hooks/implementations/model_output_detector.py +132 -0
  16. amber/hooks/utils.py +76 -0
  17. amber/language_model/__init__.py +0 -0
  18. amber/language_model/activations.py +479 -0
  19. amber/language_model/context.py +33 -0
  20. amber/language_model/contracts.py +13 -0
  21. amber/language_model/hook_metadata.py +38 -0
  22. amber/language_model/inference.py +525 -0
  23. amber/language_model/initialization.py +126 -0
  24. amber/language_model/language_model.py +390 -0
  25. amber/language_model/layers.py +460 -0
  26. amber/language_model/persistence.py +177 -0
  27. amber/language_model/tokenizer.py +203 -0
  28. amber/language_model/utils.py +97 -0
  29. amber/mechanistic/__init__.py +0 -0
  30. amber/mechanistic/sae/__init__.py +0 -0
  31. amber/mechanistic/sae/autoencoder_context.py +40 -0
  32. amber/mechanistic/sae/concepts/__init__.py +0 -0
  33. amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
  34. amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
  35. amber/mechanistic/sae/concepts/concept_models.py +9 -0
  36. amber/mechanistic/sae/concepts/input_tracker.py +68 -0
  37. amber/mechanistic/sae/modules/__init__.py +5 -0
  38. amber/mechanistic/sae/modules/l1_sae.py +409 -0
  39. amber/mechanistic/sae/modules/topk_sae.py +459 -0
  40. amber/mechanistic/sae/sae.py +166 -0
  41. amber/mechanistic/sae/sae_trainer.py +604 -0
  42. amber/mechanistic/sae/training/wandb_logger.py +222 -0
  43. amber/store/__init__.py +5 -0
  44. amber/store/local_store.py +437 -0
  45. amber/store/store.py +276 -0
  46. amber/store/store_dataloader.py +124 -0
  47. amber/utils.py +46 -0
  48. mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
  49. mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
  50. mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
  51. mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
@@ -0,0 +1,390 @@
+ from __future__ import annotations
+
+ from collections import defaultdict
+ from pathlib import Path
+ from typing import Sequence, Any, Dict, List, TYPE_CHECKING, Set
+
+ import torch
+ from torch import nn, Tensor
+ from transformers import PreTrainedTokenizerBase
+
+ from amber.language_model.layers import LanguageModelLayers
+ from amber.language_model.tokenizer import LanguageModelTokenizer
+ from amber.language_model.activations import LanguageModelActivations
+ from amber.language_model.context import LanguageModelContext
+ from amber.language_model.inference import InferenceEngine
+ from amber.language_model.persistence import save_model, load_model_from_saved_file
+ from amber.language_model.initialization import initialize_model_id, create_from_huggingface, create_from_local_torch
+ from amber.store.store import Store
+ from amber.utils import get_logger
+
+ if TYPE_CHECKING:
+     from amber.mechanistic.sae.concepts.input_tracker import InputTracker
+
+ logger = get_logger(__name__)
+
+
+ def _extract_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> Set[int]:
+     """
+     Extract special token IDs from a tokenizer.
+
+     Prioritizes the common case (all_special_ids) and falls back to
+     individual token ID attributes for edge cases.
+
+     Handles cases where token_id attributes may be lists (e.g., eos_token_id: [4, 2]).
+
+     Args:
+         tokenizer: HuggingFace tokenizer
+
+     Returns:
+         Set of special token IDs
+     """
+     special_ids = set()
+
+     # Common case: most tokenizers have all_special_ids
+     if hasattr(tokenizer, 'all_special_ids'):
+         all_special_ids = tokenizer.all_special_ids
+         if all_special_ids and isinstance(all_special_ids, (list, tuple, set)):
+             special_ids.update(all_special_ids)
+         return special_ids  # Early return for common case
+
+     # Fallback: extract from individual token ID attributes
+     def add_token_id(token_id):
+         if token_id is None:
+             return
+         if isinstance(token_id, (list, tuple)):
+             special_ids.update(token_id)
+         else:
+             special_ids.add(token_id)
+
+     token_id_attrs = ['pad_token_id', 'eos_token_id', 'bos_token_id', 'unk_token_id',
+                       'cls_token_id', 'sep_token_id', 'mask_token_id']
+     for attr in token_id_attrs:
+         token_id = getattr(tokenizer, attr, None)
+         add_token_id(token_id)
+
+     return special_ids
+
+
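For context, a short illustration of the inputs this helper inspects, using a public HuggingFace tokenizer (a sketch only; it does not import the private helper itself):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 exposes all_special_ids, so the common-case branch applies and
    # the helper would return {50256} (the id of <|endoftext|>).
    print(tok.all_special_ids)  # [50256]
    print(tok.eos_token_id)     # 50256 -- one of the fallback attributes
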
+ class LanguageModel:
+     """
+     Fence-style language model wrapper.
+
+     Provides a unified interface for working with language models, including:
+     - Model initialization and configuration
+     - Inference operations
+     - Hook management (detectors and controllers)
+     - Model persistence
+     - Activation tracking
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         tokenizer: PreTrainedTokenizerBase,
+         store: Store,
+         model_id: str | None = None,
+     ):
+         """
+         Initialize LanguageModel.
+
+         Args:
+             model: PyTorch model module
+             tokenizer: HuggingFace tokenizer
+             store: Store instance for persistence
+             model_id: Optional model identifier (auto-extracted if not provided)
+         """
+         self.context = LanguageModelContext(self)
+         self.context.model = model
+         self.context.tokenizer = tokenizer
+         self.context.model_id = initialize_model_id(model, model_id)
+         self.context.store = store
+         self.context.special_token_ids = _extract_special_token_ids(tokenizer)
+
+         self.layers = LanguageModelLayers(self.context)
+         self.lm_tokenizer = LanguageModelTokenizer(self.context)
+         self.activations = LanguageModelActivations(self.context)
+         self.inference = InferenceEngine(self)
+         self._inference_engine = self.inference
+
+         self._input_tracker: "InputTracker | None" = None
+
+     @property
+     def model(self) -> nn.Module:
+         """Get the underlying PyTorch model."""
+         return self.context.model
+
+     @property
+     def tokenizer(self) -> PreTrainedTokenizerBase:
+         """Get the tokenizer."""
+         return self.context.tokenizer
+
+     @property
+     def model_id(self) -> str:
+         """Get the model identifier."""
+         return self.context.model_id
+
+     @property
+     def store(self) -> Store:
+         """Get the store instance."""
+         return self.context.store
+
+     @store.setter
+     def store(self, value: Store) -> None:
+         """Set the store instance."""
+         self.context.store = value
+
+     def tokenize(self, texts: Sequence[str], **kwargs: Any):
+         """
+         Tokenize texts using the language model tokenizer.
+
+         Args:
+             texts: Sequence of text strings to tokenize
+             **kwargs: Additional tokenizer arguments
+
+         Returns:
+             Tokenized encodings
+         """
+         return self.lm_tokenizer.tokenize(texts, **kwargs)
+
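A brief usage sketch for the tokenization entry point (assumes `lm` was built via one of the constructors defined later in this class and `my_store` is a Store instance; `padding`/`truncation` are ordinary HuggingFace tokenizer kwargs forwarded through **kwargs):

    lm = LanguageModel.from_huggingface("gpt2", store=my_store)
    enc = lm.tokenize(["hello world", "a second example"], padding=True, truncation=True)
    # enc is whatever LanguageModelTokenizer.tokenize returns -- typically a
    # BatchEncoding with input_ids / attention_mask tensors.
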
+     def forwards(
+         self,
+         texts: Sequence[str],
+         tok_kwargs: Dict | None = None,
+         autocast: bool = True,
+         autocast_dtype: torch.dtype | None = None,
+         with_controllers: bool = True,
+     ):
+         """
+         Run forward pass on texts.
+
+         Args:
+             texts: Input texts to process
+             tok_kwargs: Optional tokenizer keyword arguments
+             autocast: Whether to use automatic mixed precision
+             autocast_dtype: Optional dtype for autocast
+             with_controllers: Whether to use controllers during inference
+
+         Returns:
+             Tuple of (model_output, encodings)
+         """
+         return self._inference_engine.execute_inference(
+             texts,
+             tok_kwargs=tok_kwargs,
+             autocast=autocast,
+             autocast_dtype=autocast_dtype,
+             with_controllers=with_controllers
+         )
+
+     def generate(
+         self,
+         texts: Sequence[str],
+         tok_kwargs: Dict | None = None,
+         autocast: bool = True,
+         autocast_dtype: torch.dtype | None = None,
+         with_controllers: bool = True,
+         skip_special_tokens: bool = True,
+     ) -> Sequence[str]:
+         """
+         Run inference and automatically decode the output with the tokenizer.
+
+         Args:
+             texts: Input texts to process
+             tok_kwargs: Optional tokenizer keyword arguments
+             autocast: Whether to use automatic mixed precision
+             autocast_dtype: Optional dtype for autocast
+             with_controllers: Whether to use controllers during inference
+             skip_special_tokens: Whether to skip special tokens when decoding
+
+         Returns:
+             Sequence of decoded text strings
+
+         Raises:
+             ValueError: If texts is empty or tokenizer is None
+         """
+         if not texts:
+             raise ValueError("Texts list cannot be empty")
+
+         if self.tokenizer is None:
+             raise ValueError("Tokenizer is required for decoding but is None")
+
+         output, enc = self._inference_engine.execute_inference(
+             texts,
+             tok_kwargs=tok_kwargs,
+             autocast=autocast,
+             autocast_dtype=autocast_dtype,
+             with_controllers=with_controllers
+         )
+
+         logits = self._inference_engine.extract_logits(output)
+         predicted_token_ids = logits.argmax(dim=-1)
+
+         decoded_texts = []
+         for i in range(predicted_token_ids.shape[0]):
+             token_ids = predicted_token_ids[i].cpu().tolist()
+             decoded_text = self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+             decoded_texts.append(decoded_text)
+
+         return decoded_texts
+
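For orientation, a minimal usage sketch (assuming `lm` is a constructed LanguageModel). Note that `generate` as defined above runs a single forward pass, greedily argmaxes the logits at every position, and decodes the result; it is not autoregressive, token-by-token generation:

    output, enc = lm.forwards(["The capital of France is"])  # (model_output, encodings)
    decoded = lm.generate(["The capital of France is"])      # list of decoded strings,
                                                             # one greedy token per input position
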
+     def get_input_tracker(self) -> "InputTracker | None":
+         """
+         Get the input tracker instance if it exists.
+
+         Returns:
+             InputTracker instance or None
+         """
+         return self._input_tracker
+
+     def get_all_detector_metadata(self) -> tuple[dict[str, dict[str, Any]], dict[str, dict[str, Tensor]]]:
+         """
+         Get metadata from all registered detectors.
+
+         Returns:
+             Tuple of (detectors_metadata, detectors_tensor_metadata)
+         """
+         detectors = self.layers.get_detectors()
+         detectors_metadata: Dict[str, Dict[str, Any]] = defaultdict(dict)
+         detectors_tensor_metadata: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict)
+
+         for detector in detectors:
+             detectors_metadata[detector.layer_signature] = detector.metadata
+             detectors_tensor_metadata[detector.layer_signature] = detector.tensor_metadata
+
+         return detectors_metadata, detectors_tensor_metadata
+
+     def clear_detectors(self) -> None:
+         """
+         Clear all accumulated metadata for registered detectors.
+
+         This is useful when running multiple independent inference runs
+         (e.g. separate `infer_texts` / `infer_dataset` calls) and you want
+         to ensure that detector state does not leak between runs.
+         """
+         detectors = self.layers.get_detectors()
+         for detector in detectors:
+             # Clear generic accumulated metadata
+             detector.metadata.clear()
+             detector.tensor_metadata.clear()
+
+             # Allow detector implementations to provide more specialized
+             # clearing logic (e.g. ModelInputDetector, ModelOutputDetector).
+             clear_captured = getattr(detector, "clear_captured", None)
+             if callable(clear_captured):
+                 clear_captured()
+
+     def save_detector_metadata(self, run_name: str, batch_idx: int | None, unified: bool = False) -> str:
+         """
+         Save detector metadata to store.
+
+         Args:
+             run_name: Name of the run
+             batch_idx: Batch index. Ignored when ``unified`` is True.
+             unified: If True, save metadata in a single detectors directory
+                 for the whole run instead of per-batch directories.
+
+         Returns:
+             Path where metadata was saved
+
+         Raises:
+             ValueError: If store is not set
+         """
+         if self.store is None:
+             raise ValueError("Store must be provided or set on the language model")
+         detectors_metadata, detectors_tensor_metadata = self.get_all_detector_metadata()
+         if unified:
+             return self.store.put_run_detector_metadata(run_name, detectors_metadata, detectors_tensor_metadata)
+         if batch_idx is None:
+             raise ValueError("batch_idx must be provided when unified is False")
+         return self.store.put_detector_metadata(run_name, batch_idx, detectors_metadata, detectors_tensor_metadata)
+
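A sketch of the detector bookkeeping cycle these three methods support (`lm` and the run name are illustrative; attaching detectors happens elsewhere, via the hooks API):

    # After a batch of inference with detectors attached:
    meta, tensor_meta = lm.get_all_detector_metadata()

    # Persist per batch (batch_idx is required when unified=False)...
    lm.save_detector_metadata("my_run", batch_idx=0)
    # ...or in one unified directory for the whole run:
    lm.save_detector_metadata("my_run", batch_idx=None, unified=True)

    # Reset detector state so the next run starts clean.
    lm.clear_detectors()
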
+     def _ensure_input_tracker(self) -> "InputTracker":
+         """
+         Ensure the InputTracker singleton exists.
+
+         Returns:
+             The InputTracker instance
+         """
+         if self._input_tracker is not None:
+             return self._input_tracker
+
+         from amber.mechanistic.sae.concepts.input_tracker import InputTracker
+
+         self._input_tracker = InputTracker(language_model=self)
+
+         logger.debug(f"Created InputTracker singleton for {self.context.model_id}")
+
+         return self._input_tracker
+
+     def save_model(self, path: Path | str | None = None) -> Path:
+         """
+         Save the model and its metadata to the store.
+
+         Args:
+             path: Optional path to save the model. If None, defaults to {model_id}/model.pt
+                 relative to the store base path.
+
+         Returns:
+             Path where the model was saved
+
+         Raises:
+             ValueError: If store is not set
+         """
+         return save_model(self, path)
+
+     @classmethod
+     def from_huggingface(
+         cls,
+         model_name: str,
+         store: Store,
+         tokenizer_params: dict | None = None,
+         model_params: dict | None = None,
+     ) -> "LanguageModel":
+         """
+         Load a language model from HuggingFace Hub.
+
+         Args:
+             model_name: HuggingFace model identifier
+             store: Store instance for persistence
+             tokenizer_params: Optional tokenizer parameters
+             model_params: Optional model parameters
+
+         Returns:
+             LanguageModel instance
+         """
+         return create_from_huggingface(cls, model_name, store, tokenizer_params, model_params)
+
+     @classmethod
+     def from_local_torch(cls, model_path: str, tokenizer_path: str, store: Store) -> "LanguageModel":
+         """
+         Load a language model from local HuggingFace paths.
+
+         Args:
+             model_path: Path to the model directory or file
+             tokenizer_path: Path to the tokenizer directory or file
+             store: Store instance for persistence
+
+         Returns:
+             LanguageModel instance
+         """
+         return create_from_local_torch(cls, model_path, tokenizer_path, store)
+
+     @classmethod
+     def from_local(cls, saved_path: Path | str, store: Store, model_id: str | None = None) -> "LanguageModel":
+         """
+         Load a language model from a saved file (created by save_model).
+
+         Args:
+             saved_path: Path to the saved model file (.pt file)
+             store: Store instance for persistence
+             model_id: Optional model identifier. If not provided, will use the model_id from saved metadata.
+                 If provided, will be used to load the model architecture from HuggingFace.
+
+         Returns:
+             LanguageModel instance
+
+         Raises:
+             FileNotFoundError: If the saved file doesn't exist
+             ValueError: If the saved file format is invalid or model_id is required but not provided
+         """
+         return load_model_from_saved_file(cls, saved_path, store, model_id)
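Taken together, the constructors and `save_model` support a save/load round trip. A minimal sketch under stated assumptions (`my_store` and the "gpt2" model name are illustrative, not part of the package):

    lm = LanguageModel.from_huggingface("gpt2", store=my_store)
    saved = lm.save_model()  # defaults to {model_id}/model.pt relative to the store base path
    lm2 = LanguageModel.from_local(saved, store=my_store)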