mi_crow-0.1.1.post12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. amber/__init__.py +15 -0
  2. amber/datasets/__init__.py +11 -0
  3. amber/datasets/base_dataset.py +640 -0
  4. amber/datasets/classification_dataset.py +566 -0
  5. amber/datasets/loading_strategy.py +29 -0
  6. amber/datasets/text_dataset.py +488 -0
  7. amber/hooks/__init__.py +20 -0
  8. amber/hooks/controller.py +171 -0
  9. amber/hooks/detector.py +95 -0
  10. amber/hooks/hook.py +218 -0
  11. amber/hooks/implementations/__init__.py +0 -0
  12. amber/hooks/implementations/function_controller.py +93 -0
  13. amber/hooks/implementations/layer_activation_detector.py +96 -0
  14. amber/hooks/implementations/model_input_detector.py +250 -0
  15. amber/hooks/implementations/model_output_detector.py +132 -0
  16. amber/hooks/utils.py +76 -0
  17. amber/language_model/__init__.py +0 -0
  18. amber/language_model/activations.py +479 -0
  19. amber/language_model/context.py +33 -0
  20. amber/language_model/contracts.py +13 -0
  21. amber/language_model/hook_metadata.py +38 -0
  22. amber/language_model/inference.py +525 -0
  23. amber/language_model/initialization.py +126 -0
  24. amber/language_model/language_model.py +390 -0
  25. amber/language_model/layers.py +460 -0
  26. amber/language_model/persistence.py +177 -0
  27. amber/language_model/tokenizer.py +203 -0
  28. amber/language_model/utils.py +97 -0
  29. amber/mechanistic/__init__.py +0 -0
  30. amber/mechanistic/sae/__init__.py +0 -0
  31. amber/mechanistic/sae/autoencoder_context.py +40 -0
  32. amber/mechanistic/sae/concepts/__init__.py +0 -0
  33. amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
  34. amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
  35. amber/mechanistic/sae/concepts/concept_models.py +9 -0
  36. amber/mechanistic/sae/concepts/input_tracker.py +68 -0
  37. amber/mechanistic/sae/modules/__init__.py +5 -0
  38. amber/mechanistic/sae/modules/l1_sae.py +409 -0
  39. amber/mechanistic/sae/modules/topk_sae.py +459 -0
  40. amber/mechanistic/sae/sae.py +166 -0
  41. amber/mechanistic/sae/sae_trainer.py +604 -0
  42. amber/mechanistic/sae/training/wandb_logger.py +222 -0
  43. amber/store/__init__.py +5 -0
  44. amber/store/local_store.py +437 -0
  45. amber/store/store.py +276 -0
  46. amber/store/store_dataloader.py +124 -0
  47. amber/utils.py +46 -0
  48. mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
  49. mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
  50. mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
  51. mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
amber/hooks/implementations/model_input_detector.py ADDED
@@ -0,0 +1,250 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Dict, Set, List, Optional
+ import torch
+
+ from amber.hooks.detector import Detector
+ from amber.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
+
+ if TYPE_CHECKING:
+     from torch import nn
+
+
+ class ModelInputDetector(Detector):
+     """
+     Detector hook that captures and saves tokenized inputs from the model's forward pass.
+
+     This detector is designed to be attached to the root model module and captures:
+     - Tokenized inputs (input_ids) from the model's forward pass
+     - Attention masks (optional) that exclude both padding and special tokens
+
+     Uses a PRE_FORWARD hook to capture inputs before they are processed.
+     Useful for saving tokenized inputs for analysis or training.
+     """
+
+     def __init__(
+         self,
+         layer_signature: str | int | None = None,
+         hook_id: str | None = None,
+         save_input_ids: bool = True,
+         save_attention_mask: bool = False,
+         special_token_ids: Optional[List[int] | Set[int]] = None
+     ):
+         """
+         Initialize the model input detector.
+
+         Args:
+             layer_signature: Layer to capture from (typically the root model; can be None)
+             hook_id: Unique identifier for this hook
+             save_input_ids: Whether to save the input_ids tensor
+             save_attention_mask: Whether to save the attention_mask tensor (excludes padding and special tokens)
+             special_token_ids: Optional list/set of special token IDs. If None, they are taken from the LanguageModel context.
+         """
+         super().__init__(
+             hook_type=HookType.PRE_FORWARD,
+             hook_id=hook_id,
+             store=None,
+             layer_signature=layer_signature
+         )
+         self.save_input_ids = save_input_ids
+         self.save_attention_mask = save_attention_mask
+         self.special_token_ids = set(special_token_ids) if special_token_ids is not None else None
+
+     def _extract_input_ids(self, input: HOOK_FUNCTION_INPUT) -> torch.Tensor | None:
+         """
+         Extract input_ids from the model input.
+
+         Handles various input formats:
+         - Dict with an 'input_ids' key (most common for HuggingFace models)
+         - Tuple with a dict as the first element
+         - Tuple with a tensor as the first element
+
+         Args:
+             input: Input to the model forward pass
+
+         Returns:
+             input_ids tensor, or None if not found
+         """
+         if not input or len(input) == 0:
+             return None
+
+         first_item = input[0]
+
+         if isinstance(first_item, dict):
+             if 'input_ids' in first_item:
+                 return first_item['input_ids']
+             return None
+
+         if isinstance(first_item, torch.Tensor):
+             return first_item
+
+         return None
+
+     def _extract_attention_mask(self, input: HOOK_FUNCTION_INPUT) -> torch.Tensor | None:
+         """
+         Extract attention_mask from the model input.
+
+         Args:
+             input: Input to the model forward pass
+
+         Returns:
+             attention_mask tensor, or None if not found
+         """
+         if not input or len(input) == 0:
+             return None
+
+         first_item = input[0]
+
+         if isinstance(first_item, dict):
+             if 'attention_mask' in first_item:
+                 return first_item['attention_mask']
+
+         return None
+
+     def _get_special_token_ids(self, module: torch.nn.Module) -> Set[int]:
+         """
+         Get special token IDs from the user-provided set or from the LanguageModel context.
+
+         Priority order:
+         1. self.special_token_ids (user-provided during initialization)
+         2. self.context.special_token_ids (from LanguageModel initialization)
+
+         Args:
+             module: The PyTorch module being hooked (unused, kept for API compatibility)
+
+         Returns:
+             Set of special token IDs, or an empty set if none are available
+         """
+         if self.special_token_ids is not None:
+             return self.special_token_ids
+
+         if self.context is not None and self.context.special_token_ids is not None:
+             return self.context.special_token_ids
+
+         return set()
+
+     def _create_combined_attention_mask(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor | None,
+         module: torch.nn.Module
+     ) -> torch.Tensor:
+         """
+         Create a combined attention mask that excludes both padding and special tokens.
+
+         Args:
+             input_ids: Input token IDs tensor (batch_size × sequence_length)
+             attention_mask: Original attention mask from the tokenizer (None if not provided)
+             module: The PyTorch module being hooked
+
+         Returns:
+             Boolean mask tensor with the same shape as input_ids (True for regular tokens, False for padding/special tokens)
+         """
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+         else:
+             attention_mask = attention_mask.bool()
+
+         special_token_ids = self._get_special_token_ids(module)
+
+         if special_token_ids:
+             special_ids_tensor = torch.tensor(list(special_token_ids), device=input_ids.device, dtype=input_ids.dtype)
+             expanded_input = input_ids.unsqueeze(-1)
+             expanded_special = special_ids_tensor.unsqueeze(0).unsqueeze(0)
+             is_special = (expanded_input == expanded_special).any(dim=-1)
+             attention_mask = attention_mask & ~is_special
+
+         return attention_mask.to(torch.bool)
+
+     def set_inputs_from_encodings(self, encodings: Dict[str, torch.Tensor], module: Optional[torch.nn.Module] = None) -> None:
+         """
+         Manually set inputs from an encodings dictionary.
+
+         This is useful when the model is called with keyword arguments,
+         as PyTorch's pre_forward hook does not receive kwargs.
+
+         Args:
+             encodings: Dictionary of encoded inputs (e.g., from lm.forwards() or lm.tokenize())
+             module: Optional module for extracting special token IDs. If None, a DummyModule is used.
+
+         Raises:
+             RuntimeError: If tensor extraction or storage fails
+         """
+         try:
+             if self.save_input_ids and 'input_ids' in encodings:
+                 input_ids = encodings['input_ids']
+                 self.tensor_metadata['input_ids'] = input_ids.detach().to("cpu")
+                 self.metadata['input_ids_shape'] = tuple(input_ids.shape)
+
+             if self.save_attention_mask and 'input_ids' in encodings:
+                 input_ids = encodings['input_ids']
+                 if module is None:
+                     class DummyModule:
+                         pass
+                     module = DummyModule()
+
+                 original_attention_mask = encodings.get('attention_mask')
+                 combined_mask = self._create_combined_attention_mask(input_ids, original_attention_mask, module)
+                 self.tensor_metadata['attention_mask'] = combined_mask.detach().to("cpu")
+                 self.metadata['attention_mask_shape'] = tuple(combined_mask.shape)
+         except Exception as e:
+             raise RuntimeError(
+                 f"Error setting inputs from encodings in ModelInputDetector {self.id}: {e}"
+             ) from e
+
+     def process_activations(
+         self,
+         module: torch.nn.Module,
+         input: HOOK_FUNCTION_INPUT,
+         output: HOOK_FUNCTION_OUTPUT
+     ) -> None:
+         """
+         Extract and store tokenized inputs.
+
+         Note: For HuggingFace models called with **kwargs, the input tuple may be empty.
+         In such cases, use set_inputs_from_encodings() to manually set inputs from
+         the encodings dictionary returned by lm.forwards().
+
+         Args:
+             module: The PyTorch module being hooked (typically the root model)
+             input: Tuple of input tensors/dicts to the module
+             output: Output from the module (None for PRE_FORWARD hooks)
+
+         Raises:
+             RuntimeError: If tensor extraction or storage fails
+         """
+         try:
+             if self.save_input_ids:
+                 input_ids = self._extract_input_ids(input)
+                 if input_ids is not None:
+                     self.tensor_metadata['input_ids'] = input_ids.detach().to("cpu")
+                     self.metadata['input_ids_shape'] = tuple(input_ids.shape)
+
+             if self.save_attention_mask:
+                 input_ids = self._extract_input_ids(input)
+                 if input_ids is not None:
+                     original_attention_mask = self._extract_attention_mask(input)
+                     combined_mask = self._create_combined_attention_mask(input_ids, original_attention_mask, module)
+                     self.tensor_metadata['attention_mask'] = combined_mask.detach().to("cpu")
+                     self.metadata['attention_mask_shape'] = tuple(combined_mask.shape)
+
+         except Exception as e:
+             raise RuntimeError(
+                 f"Error extracting inputs in ModelInputDetector {self.id}: {e}"
+             ) from e
+
+     def get_captured_input_ids(self) -> torch.Tensor | None:
+         """Get the captured input_ids from the current batch."""
+         return self.tensor_metadata.get('input_ids')
+
+     def get_captured_attention_mask(self) -> torch.Tensor | None:
+         """Get the captured attention_mask from the current batch (excludes padding and special tokens)."""
+         return self.tensor_metadata.get('attention_mask')
+
+     def clear_captured(self) -> None:
+         """Clear all captured inputs for the current batch."""
+         keys_to_remove = ['input_ids', 'attention_mask']
+         for key in keys_to_remove:
+             self.tensor_metadata.pop(key, None)
+             self.metadata.pop(f'{key}_shape', None)
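A minimal usage sketch of ModelInputDetector, based only on the methods shown above. It assumes the Detector base class initializes the tensor_metadata and metadata dictionaries these methods write to, and the token IDs are invented purely for illustration.

import torch
from amber.hooks.implementations.model_input_detector import ModelInputDetector

# Suppose token id 0 is PAD and 1 is BOS (illustrative values only).
detector = ModelInputDetector(
    save_input_ids=True,
    save_attention_mask=True,
    special_token_ids={0, 1},
)

# When the model is called with **kwargs, PyTorch pre_forward hooks see an empty
# positional tuple, so the encodings are handed to the detector manually.
encodings = {
    "input_ids": torch.tensor([[1, 17, 42, 0], [1, 99, 0, 0]]),
    "attention_mask": torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]),
}
detector.set_inputs_from_encodings(encodings)

print(detector.get_captured_input_ids().shape)   # torch.Size([2, 4])
print(detector.get_captured_attention_mask())    # False at the PAD/BOS positions
detector.clear_captured()

The combined mask starts from the tokenizer's attention_mask and additionally zeroes out every position whose token id appears in special_token_ids, which is exactly what the broadcasted comparison in _create_combined_attention_mask computes.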
amber/hooks/implementations/model_output_detector.py ADDED
@@ -0,0 +1,132 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+ import torch
+
+ from amber.hooks.detector import Detector
+ from amber.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
+
+ if TYPE_CHECKING:
+     from torch import nn
+
+
+ class ModelOutputDetector(Detector):
+     """
+     Detector hook that captures and saves model outputs.
+
+     This detector is designed to be attached to the root model module and captures:
+     - Model outputs (logits) from the model's forward pass
+     - Hidden states (optional) from the model's forward pass
+
+     Uses a FORWARD hook to capture outputs after they are computed.
+     Useful for saving model outputs for analysis or training.
+     """
+
+     def __init__(
+         self,
+         layer_signature: str | int | None = None,
+         hook_id: str | None = None,
+         save_output_logits: bool = True,
+         save_output_hidden_state: bool = False
+     ):
+         """
+         Initialize the model output detector.
+
+         Args:
+             layer_signature: Layer to capture from (typically the root model; can be None)
+             hook_id: Unique identifier for this hook
+             save_output_logits: Whether to save output logits (if available)
+             save_output_hidden_state: Whether to save last_hidden_state (if available)
+         """
+         super().__init__(
+             hook_type=HookType.FORWARD,
+             hook_id=hook_id,
+             store=None,
+             layer_signature=layer_signature
+         )
+         self.save_output_logits = save_output_logits
+         self.save_output_hidden_state = save_output_hidden_state
+
+     def _extract_output_tensor(self, output: HOOK_FUNCTION_OUTPUT) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+         """
+         Extract logits and last_hidden_state from the model output.
+
+         Args:
+             output: Output from the model forward pass
+
+         Returns:
+             Tuple of (logits, last_hidden_state); either can be None
+         """
+         logits = None
+         hidden_state = None
+
+         if output is None:
+             return None, None
+
+         # Handle HuggingFace output objects
+         if hasattr(output, "logits"):
+             logits = output.logits
+         if hasattr(output, "last_hidden_state"):
+             hidden_state = output.last_hidden_state
+
+         # Handle tuple output (logits might be the first element)
+         if isinstance(output, (tuple, list)) and len(output) > 0:
+             first_item = output[0]
+             if isinstance(first_item, torch.Tensor) and logits is None:
+                 logits = first_item
+
+         # Handle direct tensor output
+         if isinstance(output, torch.Tensor) and logits is None:
+             logits = output
+
+         return logits, hidden_state
+
+     def process_activations(
+         self,
+         module: torch.nn.Module,
+         input: HOOK_FUNCTION_INPUT,
+         output: HOOK_FUNCTION_OUTPUT
+     ) -> None:
+         """
+         Extract and store model outputs.
+
+         Args:
+             module: The PyTorch module being hooked (typically the root model)
+             input: Tuple of input tensors/dicts to the module
+             output: Output from the module
+
+         Raises:
+             RuntimeError: If tensor extraction or storage fails
+         """
+         try:
+             # Extract and save outputs
+             logits, hidden_state = self._extract_output_tensor(output)
+
+             if self.save_output_logits and logits is not None:
+                 self.tensor_metadata['output_logits'] = logits.detach().to("cpu")
+                 self.metadata['output_logits_shape'] = tuple(logits.shape)
+
+             if self.save_output_hidden_state and hidden_state is not None:
+                 self.tensor_metadata['output_hidden_state'] = hidden_state.detach().to("cpu")
+                 self.metadata['output_hidden_state_shape'] = tuple(hidden_state.shape)
+
+         except Exception as e:
+             raise RuntimeError(
+                 f"Error extracting outputs in ModelOutputDetector {self.id}: {e}"
+             ) from e
+
+     def get_captured_output_logits(self) -> torch.Tensor | None:
+         """Get the captured output logits from the current batch."""
+         return self.tensor_metadata.get('output_logits')
+
+     def get_captured_output_hidden_state(self) -> torch.Tensor | None:
+         """Get the captured output hidden state from the current batch."""
+         return self.tensor_metadata.get('output_hidden_state')
+
+     def clear_captured(self) -> None:
+         """Clear all captured outputs for the current batch."""
+         keys_to_remove = ['output_logits', 'output_hidden_state']
+         for key in keys_to_remove:
+             self.tensor_metadata.pop(key, None)
+             self.metadata.pop(f'{key}_shape', None)
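A similar sketch for ModelOutputDetector: the hook body is invoked directly with a plain tensor output, the simplest of the shapes _extract_output_tensor handles. The same assumption about the Detector base class providing tensor_metadata and metadata applies, and the Linear layer is just a stand-in module.

import torch
from torch import nn
from amber.hooks.implementations.model_output_detector import ModelOutputDetector

detector = ModelOutputDetector(save_output_logits=True, save_output_hidden_state=False)

# Drive the FORWARD-hook body directly with a raw tensor output.
layer = nn.Linear(8, 4)
x = torch.randn(2, 8)
out = layer(x)
detector.process_activations(module=layer, input=(x,), output=out)

print(detector.get_captured_output_logits().shape)   # torch.Size([2, 4])
print(detector.metadata["output_logits_shape"])      # (2, 4)
detector.clear_captured()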
amber/hooks/utils.py ADDED
@@ -0,0 +1,76 @@
+ """Utility functions for hook implementations."""
+
+ from __future__ import annotations
+
+ from typing import Any
+
+ import torch
+
+ from amber.hooks.hook import HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
+
+
+ def extract_tensor_from_input(input: HOOK_FUNCTION_INPUT) -> torch.Tensor | None:
+     """
+     Extract the first tensor from input sequence.
+
+     Handles various input formats:
+     - Direct tensor in first position
+     - Tuple/list of tensors in first position
+     - Empty or None inputs
+
+     Args:
+         input: Input sequence (tuple/list of tensors)
+
+     Returns:
+         First tensor found, or None if no tensor found
+     """
+     if not input or len(input) == 0:
+         return None
+
+     first_item = input[0]
+     if isinstance(first_item, torch.Tensor):
+         return first_item
+
+     if isinstance(first_item, (tuple, list)):
+         for item in first_item:
+             if isinstance(item, torch.Tensor):
+                 return item
+
+     return None
+
+
+ def extract_tensor_from_output(output: HOOK_FUNCTION_OUTPUT) -> torch.Tensor | None:
+     """
+     Extract tensor from output (handles various output types).
+
+     Handles various output formats:
+     - Plain tensors
+     - Tuples/lists of tensors (takes first tensor)
+     - Objects with last_hidden_state attribute (e.g., HuggingFace outputs)
+     - None outputs
+
+     Args:
+         output: Output from module (tensor, tuple, or object with attributes)
+
+     Returns:
+         First tensor found, or None if no tensor found
+     """
+     if output is None:
+         return None
+
+     if isinstance(output, torch.Tensor):
+         return output
+
+     if isinstance(output, (tuple, list)):
+         for item in output:
+             if isinstance(item, torch.Tensor):
+                 return item
+
+     # Try common HuggingFace output objects
+     if hasattr(output, "last_hidden_state"):
+         maybe = getattr(output, "last_hidden_state")
+         if isinstance(maybe, torch.Tensor):
+             return maybe
+
+     return None
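A few self-contained checks of the two helpers above, covering the input and output shapes they are written to handle; FakeHFOutput is a stand-in for a HuggingFace-style output object, used only for illustration.

import torch
from amber.hooks.utils import extract_tensor_from_input, extract_tensor_from_output

t = torch.randn(2, 3)

# Positional-args tuple, as a forward hook receives it
assert extract_tensor_from_input((t,)) is t
# Nested tuple/list in the first position
assert extract_tensor_from_input(((None, t),)) is t
# Empty input (e.g., the model was called with **kwargs only)
assert extract_tensor_from_input(()) is None

# Plain tensor, tuple, and HuggingFace-style outputs
assert extract_tensor_from_output(t) is t
assert extract_tensor_from_output((t, None)) is t

class FakeHFOutput:
    def __init__(self, hidden):
        self.last_hidden_state = hidden

assert extract_tensor_from_output(FakeHFOutput(t)) is t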