mi-crow 0.1.1.post12 (mi_crow-0.1.1.post12-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. amber/__init__.py +15 -0
  2. amber/datasets/__init__.py +11 -0
  3. amber/datasets/base_dataset.py +640 -0
  4. amber/datasets/classification_dataset.py +566 -0
  5. amber/datasets/loading_strategy.py +29 -0
  6. amber/datasets/text_dataset.py +488 -0
  7. amber/hooks/__init__.py +20 -0
  8. amber/hooks/controller.py +171 -0
  9. amber/hooks/detector.py +95 -0
  10. amber/hooks/hook.py +218 -0
  11. amber/hooks/implementations/__init__.py +0 -0
  12. amber/hooks/implementations/function_controller.py +93 -0
  13. amber/hooks/implementations/layer_activation_detector.py +96 -0
  14. amber/hooks/implementations/model_input_detector.py +250 -0
  15. amber/hooks/implementations/model_output_detector.py +132 -0
  16. amber/hooks/utils.py +76 -0
  17. amber/language_model/__init__.py +0 -0
  18. amber/language_model/activations.py +479 -0
  19. amber/language_model/context.py +33 -0
  20. amber/language_model/contracts.py +13 -0
  21. amber/language_model/hook_metadata.py +38 -0
  22. amber/language_model/inference.py +525 -0
  23. amber/language_model/initialization.py +126 -0
  24. amber/language_model/language_model.py +390 -0
  25. amber/language_model/layers.py +460 -0
  26. amber/language_model/persistence.py +177 -0
  27. amber/language_model/tokenizer.py +203 -0
  28. amber/language_model/utils.py +97 -0
  29. amber/mechanistic/__init__.py +0 -0
  30. amber/mechanistic/sae/__init__.py +0 -0
  31. amber/mechanistic/sae/autoencoder_context.py +40 -0
  32. amber/mechanistic/sae/concepts/__init__.py +0 -0
  33. amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
  34. amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
  35. amber/mechanistic/sae/concepts/concept_models.py +9 -0
  36. amber/mechanistic/sae/concepts/input_tracker.py +68 -0
  37. amber/mechanistic/sae/modules/__init__.py +5 -0
  38. amber/mechanistic/sae/modules/l1_sae.py +409 -0
  39. amber/mechanistic/sae/modules/topk_sae.py +459 -0
  40. amber/mechanistic/sae/sae.py +166 -0
  41. amber/mechanistic/sae/sae_trainer.py +604 -0
  42. amber/mechanistic/sae/training/wandb_logger.py +222 -0
  43. amber/store/__init__.py +5 -0
  44. amber/store/local_store.py +437 -0
  45. amber/store/store.py +276 -0
  46. amber/store/store_dataloader.py +124 -0
  47. amber/utils.py +46 -0
  48. mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
  49. mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
  50. mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
  51. mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
amber/hooks/detector.py ADDED
@@ -0,0 +1,95 @@
from __future__ import annotations

import abc
from typing import Any, TYPE_CHECKING, Dict

import torch

from amber.hooks.hook import Hook, HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
from amber.store.store import Store

if TYPE_CHECKING:
    pass


class Detector(Hook):
    """
    Abstract base class for detector hooks that collect metadata during inference.

    Detectors can accumulate data across batches and optionally save it to a Store.
    They are designed to observe and record information without modifying activations.
    """

    def __init__(
        self,
        hook_type: HookType | str = HookType.FORWARD,
        hook_id: str | None = None,
        store: Store | None = None,
        layer_signature: str | int | None = None
    ):
        """
        Initialize a detector hook.

        Args:
            hook_type: Type of hook (HookType.FORWARD or HookType.PRE_FORWARD)
            hook_id: Unique identifier
            store: Optional Store for saving metadata
            layer_signature: Layer to attach to (optional, for compatibility)
        """
        super().__init__(layer_signature=layer_signature, hook_type=hook_type, hook_id=hook_id)
        self.store = store
        self.metadata: Dict[str, Any] = {}
        self.tensor_metadata: Dict[str, torch.Tensor] = {}

    def _hook_fn(
        self,
        module: torch.nn.Module,
        input: HOOK_FUNCTION_INPUT,
        output: HOOK_FUNCTION_OUTPUT
    ) -> None:
        """
        Internal hook function that collects metadata.

        This calls process_activations to allow subclasses to implement
        their specific detection logic.

        Args:
            module: The PyTorch module being hooked
            input: Tuple of input tensors to the module
            output: Output tensor(s) from the module

        Raises:
            Exception: If process_activations raises an exception
        """
        if not self._enabled:
            return None
        try:
            self.process_activations(module, input, output)
        except Exception as e:
            raise RuntimeError(
                f"Error in detector {self.id} process_activations: {e}"
            ) from e
        return None

    @abc.abstractmethod
    def process_activations(
        self,
        module: torch.nn.Module,
        input: HOOK_FUNCTION_INPUT,
        output: HOOK_FUNCTION_OUTPUT
    ) -> None:
        """
        Process activations from the hooked layer.

        This is where detector-specific logic goes (e.g., tracking top activations,
        computing statistics, etc.).

        Args:
            module: The PyTorch module being hooked
            input: Tuple of input tensors to the module
            output: Output tensor(s) from the module

        Raises:
            Exception: Subclasses may raise exceptions for invalid inputs or processing errors
        """
        raise NotImplementedError("process_activations must be implemented by subclasses")
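Note: to illustrate the contract above, a concrete detector only needs to implement process_activations and write into self.metadata / self.tensor_metadata. The sketch below is illustrative only; MaxActivationDetector, its metadata key, and the layer signature are not part of the package.

import torch

from amber.hooks.detector import Detector
from amber.hooks.hook import HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT


class MaxActivationDetector(Detector):
    """Hypothetical detector: records the largest absolute activation per batch."""

    def process_activations(
        self,
        module: torch.nn.Module,
        input: HOOK_FUNCTION_INPUT,
        output: HOOK_FUNCTION_OUTPUT,
    ) -> None:
        # This sketch only handles the plain-tensor output case.
        if isinstance(output, torch.Tensor):
            self.metadata["max_abs_activation"] = output.detach().abs().max().item()


detector = MaxActivationDetector(layer_signature="layer_0")  # illustrative signature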
amber/hooks/hook.py ADDED
@@ -0,0 +1,218 @@
from __future__ import annotations

import abc
import uuid
from enum import Enum
from typing import Callable, TypeAlias, Sequence, Optional, TYPE_CHECKING

import torch
from torch import nn, Tensor
from torch.types import _TensorOrTensors

if TYPE_CHECKING:
    from amber.language_model.context import LanguageModelContext


class HookType(str, Enum):
    """Type of hook to register on a layer."""
    FORWARD = "forward"
    PRE_FORWARD = "pre_forward"


HOOK_FUNCTION_INPUT: TypeAlias = Sequence[Tensor]
HOOK_FUNCTION_OUTPUT: TypeAlias = _TensorOrTensors | None


class HookError(Exception):
    """Exception raised when a hook encounters an error during execution."""

    def __init__(self, hook_id: str, hook_type: str, original_error: Exception):
        """
        Initialize HookError.

        Args:
            hook_id: Unique identifier of the hook that raised the error
            hook_type: Type of hook (e.g., "forward", "pre_forward")
            original_error: The original exception that was raised
        """
        self.hook_id = hook_id
        self.hook_type = hook_type
        self.original_error = original_error
        message = f"Hook {hook_id} (type={hook_type}) raised exception: {original_error}"
        super().__init__(message)


class Hook(abc.ABC):
    """
    Abstract base class for hooks that can be registered on language model layers.

    Hooks provide a way to intercept and process activations during model inference.
    They expose PyTorch-compatible callables via get_torch_hook() while providing
    additional functionality like enable/disable and unique identification.
    """

    def __init__(
        self,
        layer_signature: str | int | None = None,
        hook_type: HookType | str = HookType.FORWARD,
        hook_id: str | None = None
    ):
        """
        Initialize a hook.

        Args:
            layer_signature: Layer name or index to attach hook to
            hook_type: Type of hook - HookType.FORWARD or HookType.PRE_FORWARD
            hook_id: Unique identifier (auto-generated if not provided)

        Raises:
            ValueError: If hook_type string is invalid
        """
        self.layer_signature = layer_signature
        self.hook_type = self._normalize_hook_type(hook_type)
        self.id = hook_id if hook_id is not None else str(uuid.uuid4())
        self._enabled = True
        self._torch_hook_handle = None
        self._context: Optional["LanguageModelContext"] = None

    def _normalize_hook_type(self, hook_type: HookType | str) -> HookType:
        """Normalize hook_type to HookType enum.

        Args:
            hook_type: HookType enum or string value

        Returns:
            HookType enum value

        Raises:
            ValueError: If hook_type string is not a valid HookType value
        """
        if isinstance(hook_type, HookType):
            return hook_type

        if isinstance(hook_type, str):
            try:
                return HookType(hook_type)
            except ValueError:
                valid_values = [ht.value for ht in HookType]
                raise ValueError(
                    f"Invalid hook_type string '{hook_type}'. "
                    f"Must be one of: {valid_values}"
                ) from None

        raise ValueError(
            f"hook_type must be HookType enum or string, got: {type(hook_type)}"
        )

    def _create_pre_forward_wrapper(self) -> Callable:
        """Create a pre-forward hook wrapper function.

        Returns:
            Wrapper function for pre-forward hooks
        """
        def pre_forward_wrapper(module: nn.Module, input: HOOK_FUNCTION_INPUT) -> None | HOOK_FUNCTION_INPUT:
            if not self._enabled:
                return None
            try:
                result = self._hook_fn(module, input, None)
                return result if result is not None else None
            except Exception as e:
                raise HookError(self.id, self.hook_type.value, e) from e

        return pre_forward_wrapper

    def _create_forward_wrapper(self) -> Callable:
        """Create a forward hook wrapper function.

        Returns:
            Wrapper function for forward hooks
        """
        def forward_wrapper(module: nn.Module, input: HOOK_FUNCTION_INPUT, output: HOOK_FUNCTION_OUTPUT) -> None:
            if not self._enabled:
                return None
            try:
                self._hook_fn(module, input, output)
                return None
            except Exception as e:
                raise HookError(self.id, self.hook_type.value, e) from e

        return forward_wrapper

    @property
    def enabled(self) -> bool:
        """Whether this hook is currently enabled."""
        return self._enabled

    def enable(self) -> None:
        """Enable this hook."""
        self._enabled = True

    def disable(self) -> None:
        """Disable this hook."""
        self._enabled = False

    @property
    def context(self) -> Optional["LanguageModelContext"]:
        """Get the LanguageModelContext associated with this hook."""
        return self._context

    def set_context(self, context: "LanguageModelContext") -> None:
        """Set the LanguageModelContext for this hook.

        Args:
            context: The LanguageModelContext instance
        """
        self._context = context

    def _is_both_controller_and_detector(self) -> bool:
        """
        Check if this hook instance inherits from both Controller and Detector.

        Uses MRO (Method Resolution Order) to check for both class names
        without requiring imports, avoiding circular dependencies.

        Returns:
            True if the instance inherits from both Controller and Detector, False otherwise
        """
        mro_class_names = [cls.__name__ for cls in type(self).__mro__]
        return 'Controller' in mro_class_names and 'Detector' in mro_class_names

    def get_torch_hook(self) -> Callable:
        """
        Return a PyTorch-compatible hook function.

        The returned callable will check the enabled flag before executing
        and call the abstract _hook_fn method.

        Returns:
            A callable compatible with PyTorch's register_forward_hook or
            register_forward_pre_hook APIs.
        """
        if self.hook_type == HookType.PRE_FORWARD:
            return self._create_pre_forward_wrapper()
        else:
            return self._create_forward_wrapper()

    @abc.abstractmethod
    def _hook_fn(
        self,
        module: torch.nn.Module,
        input: HOOK_FUNCTION_INPUT,
        output: HOOK_FUNCTION_OUTPUT
    ) -> None | HOOK_FUNCTION_INPUT:
        """
        Internal hook function to be implemented by subclasses.

        Args:
            module: The PyTorch module being hooked
            input: Tuple of input tensors to the module
            output: Output tensor(s) from the module (None for pre_forward hooks)

        Returns:
            For pre_forward hooks: modified inputs (tuple) or None to keep original
            For forward hooks: None (forward hooks cannot modify output in PyTorch)

        Raises:
            Exception: Subclasses may raise exceptions which will be caught by the wrapper
        """
        raise NotImplementedError("_hook_fn must be implemented by subclasses")
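Note: a minimal sketch of how the wrappers above plug into PyTorch's native hook registration via get_torch_hook(). The PrintShapeHook class, the bare nn.Linear, and the identifiers are hypothetical examples, not part of mi-crow.

import torch
from torch import nn

from amber.hooks.hook import Hook, HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT


class PrintShapeHook(Hook):
    """Hypothetical hook: prints the output shape of the hooked module."""

    def _hook_fn(self, module, input: HOOK_FUNCTION_INPUT, output: HOOK_FUNCTION_OUTPUT) -> None:
        if isinstance(output, torch.Tensor):
            print(f"{self.id}: output shape {tuple(output.shape)}")


layer = nn.Linear(8, 4)
hook = PrintShapeHook(layer_signature="linear", hook_type=HookType.FORWARD, hook_id="shape-printer")
handle = layer.register_forward_hook(hook.get_torch_hook())  # forward wrapper created above

layer(torch.randn(2, 8))   # prints "shape-printer: output shape (2, 4)"
hook.disable()             # wrapper now returns early without calling _hook_fn
layer(torch.randn(2, 8))   # prints nothing
handle.remove()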
amber/hooks/implementations/__init__.py ADDED (empty file, no content shown)
amber/hooks/implementations/function_controller.py ADDED
@@ -0,0 +1,93 @@
from __future__ import annotations

from typing import Callable, TYPE_CHECKING
import torch

from amber.hooks.controller import Controller
from amber.hooks.hook import HookType

if TYPE_CHECKING:
    from torch import nn


class FunctionController(Controller):
    """
    A controller that applies a user-provided function to tensors during inference.

    This controller allows users to pass any function and apply it to activations.
    The function will be applied to:
    - Single tensors directly
    - All tensors in tuples/lists (default behavior)

    Example:
        >>> # Scale activations by 2
        >>> controller = FunctionController(
        ...     layer_signature="layer_0",
        ...     function=lambda x: x * 2.0
        ... )
    """

    def __init__(
        self,
        layer_signature: str | int,
        function: Callable[[torch.Tensor], torch.Tensor],
        hook_type: HookType | str = HookType.FORWARD,
        hook_id: str | None = None,
    ):
        """
        Initialize a function controller.

        Args:
            layer_signature: Layer to attach to
            function: Function to apply to tensors. Must take a torch.Tensor and return a torch.Tensor
            hook_type: Type of hook (HookType.FORWARD or HookType.PRE_FORWARD)
            hook_id: Unique identifier

        Raises:
            ValueError: If function is None or not callable
        """
        if function is None:
            raise ValueError("function cannot be None")

        if not callable(function):
            raise ValueError(f"function must be callable, got: {type(function)}")

        super().__init__(hook_type=hook_type, hook_id=hook_id, layer_signature=layer_signature)
        self.function = function

    def modify_activations(
        self,
        module: "nn.Module",
        inputs: torch.Tensor | None,
        output: torch.Tensor | None
    ) -> torch.Tensor | None:
        """
        Apply the user-provided function to activations.

        Args:
            module: The PyTorch module being hooked
            inputs: Input tensor (None for forward hooks)
            output: Output tensor (None for pre_forward hooks)

        Returns:
            Modified tensor with function applied, or None if target tensor is None

        Raises:
            RuntimeError: If function raises an exception when applied to tensor
        """
        target = output if self.hook_type == HookType.FORWARD else inputs

        if target is None or not isinstance(target, torch.Tensor):
            return target

        try:
            result = self.function(target)
            if not isinstance(result, torch.Tensor):
                raise TypeError(
                    f"Function must return a torch.Tensor, got: {type(result)}"
                )
            return result
        except Exception as e:
            raise RuntimeError(
                f"Error applying function in FunctionController {self.id}: {e}"
            ) from e
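Note: the sketch below exercises the modify_activations contract directly for illustration; in the package it is the Controller base class (amber/hooks/controller.py, not shown in this excerpt) that invokes it during a hooked forward pass. The zero-ablation function and the layer signature are made up for the example.

import torch
from torch import nn

from amber.hooks.hook import HookType
from amber.hooks.implementations.function_controller import FunctionController

# Zero-ablation: replace every activation with zeros.
zero_ablate = FunctionController(
    layer_signature="layer_0",              # illustrative layer signature
    function=lambda x: torch.zeros_like(x),
    hook_type=HookType.FORWARD,
)

out = torch.randn(2, 4)
# For FORWARD hooks, modify_activations operates on the output tensor.
modified = zero_ablate.modify_activations(nn.Identity(), inputs=None, output=out)
assert modified is not None and modified.abs().sum().item() == 0.0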
amber/hooks/implementations/layer_activation_detector.py ADDED
@@ -0,0 +1,96 @@
from __future__ import annotations

from typing import TYPE_CHECKING
import torch

from amber.hooks.detector import Detector
from amber.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
from amber.hooks.utils import extract_tensor_from_output

if TYPE_CHECKING:
    from torch import nn


class LayerActivationDetector(Detector):
    """
    Detector hook that captures and saves activations during inference.

    This detector extracts activations from layer outputs and stores them
    for later use (e.g., saving to disk, further analysis).
    """

    def __init__(
        self,
        layer_signature: str | int,
        hook_id: str | None = None
    ):
        """
        Initialize the activation saver detector.

        Args:
            layer_signature: Layer to capture activations from
            hook_id: Unique identifier for this hook

        Raises:
            ValueError: If layer_signature is None
        """
        if layer_signature is None:
            raise ValueError("layer_signature cannot be None for LayerActivationDetector")

        super().__init__(
            hook_type=HookType.FORWARD,
            hook_id=hook_id,
            store=None,
            layer_signature=layer_signature
        )

    def process_activations(
        self,
        module: torch.nn.Module,
        input: HOOK_FUNCTION_INPUT,
        output: HOOK_FUNCTION_OUTPUT
    ) -> None:
        """
        Extract and store activations from output.

        Handles various output types:
        - Plain tensors
        - Tuples/lists of tensors (takes first tensor)
        - Objects with last_hidden_state attribute (e.g., HuggingFace outputs)

        Args:
            module: The PyTorch module being hooked
            input: Tuple of input tensors to the module
            output: Output tensor(s) from the module

        Raises:
            RuntimeError: If tensor extraction or storage fails
        """
        try:
            tensor = extract_tensor_from_output(output)

            if tensor is not None:
                tensor_cpu = tensor.detach().to("cpu")
                # Store current batch's tensor (overwrites previous)
                self.tensor_metadata['activations'] = tensor_cpu
                # Store activations shape to metadata
                self.metadata['activations_shape'] = tuple(tensor_cpu.shape)
        except Exception as e:
            raise RuntimeError(
                f"Error extracting activations in LayerActivationDetector {self.id}: {e}"
            ) from e

    def get_captured(self) -> torch.Tensor | None:
        """
        Get the captured activations from the current batch.

        Returns:
            The captured activation tensor from the current batch or None if no activations captured yet
        """
        return self.tensor_metadata.get('activations')

    def clear_captured(self) -> None:
        """Clear captured activations for current batch."""
        self.tensor_metadata.pop('activations', None)
        self.metadata.pop('activations_shape', None)
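Note: a minimal capture sketch, registering the detector's torch hook on a plain nn.Linear by hand. Inside mi-crow the detector is presumably attached through the language-model layer machinery instead; the layer, shapes, and hook_id below are illustrative assumptions.

import torch
from torch import nn

from amber.hooks.implementations.layer_activation_detector import LayerActivationDetector

layer = nn.Linear(16, 8)
detector = LayerActivationDetector(layer_signature="linear", hook_id="act-saver")
handle = layer.register_forward_hook(detector.get_torch_hook())

layer(torch.randn(4, 16))                        # forward pass populates the detector
acts = detector.get_captured()                   # CPU tensor, expected shape (4, 8)
print(detector.metadata["activations_shape"])    # expected (4, 8)

detector.clear_captured()                        # drop the captured batch
handle.remove()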