mi_crow-0.1.1.post12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. amber/__init__.py +15 -0
  2. amber/datasets/__init__.py +11 -0
  3. amber/datasets/base_dataset.py +640 -0
  4. amber/datasets/classification_dataset.py +566 -0
  5. amber/datasets/loading_strategy.py +29 -0
  6. amber/datasets/text_dataset.py +488 -0
  7. amber/hooks/__init__.py +20 -0
  8. amber/hooks/controller.py +171 -0
  9. amber/hooks/detector.py +95 -0
  10. amber/hooks/hook.py +218 -0
  11. amber/hooks/implementations/__init__.py +0 -0
  12. amber/hooks/implementations/function_controller.py +93 -0
  13. amber/hooks/implementations/layer_activation_detector.py +96 -0
  14. amber/hooks/implementations/model_input_detector.py +250 -0
  15. amber/hooks/implementations/model_output_detector.py +132 -0
  16. amber/hooks/utils.py +76 -0
  17. amber/language_model/__init__.py +0 -0
  18. amber/language_model/activations.py +479 -0
  19. amber/language_model/context.py +33 -0
  20. amber/language_model/contracts.py +13 -0
  21. amber/language_model/hook_metadata.py +38 -0
  22. amber/language_model/inference.py +525 -0
  23. amber/language_model/initialization.py +126 -0
  24. amber/language_model/language_model.py +390 -0
  25. amber/language_model/layers.py +460 -0
  26. amber/language_model/persistence.py +177 -0
  27. amber/language_model/tokenizer.py +203 -0
  28. amber/language_model/utils.py +97 -0
  29. amber/mechanistic/__init__.py +0 -0
  30. amber/mechanistic/sae/__init__.py +0 -0
  31. amber/mechanistic/sae/autoencoder_context.py +40 -0
  32. amber/mechanistic/sae/concepts/__init__.py +0 -0
  33. amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
  34. amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
  35. amber/mechanistic/sae/concepts/concept_models.py +9 -0
  36. amber/mechanistic/sae/concepts/input_tracker.py +68 -0
  37. amber/mechanistic/sae/modules/__init__.py +5 -0
  38. amber/mechanistic/sae/modules/l1_sae.py +409 -0
  39. amber/mechanistic/sae/modules/topk_sae.py +459 -0
  40. amber/mechanistic/sae/sae.py +166 -0
  41. amber/mechanistic/sae/sae_trainer.py +604 -0
  42. amber/mechanistic/sae/training/wandb_logger.py +222 -0
  43. amber/store/__init__.py +5 -0
  44. amber/store/local_store.py +437 -0
  45. amber/store/store.py +276 -0
  46. amber/store/store_dataloader.py +124 -0
  47. amber/utils.py +46 -0
  48. mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
  49. mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
  50. mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
  51. mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
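
The wheel ships a single top-level import package, amber (see top_level.txt), even though the distribution itself is named mi_crow. A minimal import check; the PyPI project name mi-crow is an assumption based on the title above:

# pip install mi-crow  (assumed distribution name)
import amber
from amber.mechanistic.sae.modules import L1Sae, TopKSae
from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary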
amber/mechanistic/sae/concepts/concept_dictionary.py
@@ -0,0 +1,206 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass, asdict
+ from pathlib import Path
+ from typing import Dict, Sequence, TYPE_CHECKING, Optional
+ import json
+ import csv
+
+ from amber.store.store import Store
+
+ if TYPE_CHECKING:
+     from amber.mechanistic.sae.concepts.concept_models import NeuronText
+
+
+ @dataclass
+ class Concept:
+     name: str
+     score: float
+
+
+ class ConceptDictionary:
+     def __init__(
+         self,
+         n_size: int,
+         store: Store | None = None
+     ) -> None:
+         self.n_size = n_size
+         self.concepts_map: Dict[int, Concept] = {}
+         self.store = store
+         self._directory: Path | None = None
+
+     def set_directory(self, directory: Path | str) -> None:
+         p = Path(directory)
+         p.mkdir(parents=True, exist_ok=True)
+         self._directory = p
+
+     def add(self, index: int, name: str, score: float) -> None:
+         if not (0 <= index < self.n_size):
+             raise IndexError(f"index {index} out of bounds for n_size={self.n_size}")
+         # Only allow 1 concept per neuron - replace if it exists
+         self.concepts_map[index] = Concept(name=name, score=score)
+
+     def get(self, index: int) -> Optional[Concept]:
+         if not (0 <= index < self.n_size):
+             raise IndexError(f"index {index} out of bounds for n_size={self.n_size}")
+         return self.concepts_map.get(index)
+
+     def get_many(self, indices: Sequence[int]) -> Dict[int, Optional[Concept]]:
+         return {i: self.get(i) for i in indices}
+
+     def save(self, directory: Path | str | None = None) -> Path:
+         if directory is not None:
+             self.set_directory(directory)
+         if self._directory is None:
+             raise ValueError("No directory set. Call save(directory=...) or set_directory() first.")
+         path = self._directory / "concepts.json"
+         serializable = {str(k): asdict(v) for k, v in self.concepts_map.items()}
+         meta = {
+             "n_size": self.n_size,
+             "concepts": serializable,
+         }
+         with path.open("w", encoding="utf-8") as f:
+             json.dump(meta, f, ensure_ascii=False, indent=2)
+         return path
+
+     def load(self, directory: Path | str | None = None) -> None:
+         if directory is not None:
+             self.set_directory(directory)
+         if self._directory is None:
+             raise ValueError("No directory set. Call load(directory=...) or set_directory() first.")
+         path = self._directory / "concepts.json"
+         if not path.exists():
+             raise FileNotFoundError(path)
+         with path.open("r", encoding="utf-8") as f:
+             meta = json.load(f)
+         self.n_size = int(meta.get("n_size", self.n_size))
+         concepts = meta.get("concepts", {})
+         # Handle both the old format (list) and the new format (single dict)
+         self.concepts_map = {}
+         for k, v in concepts.items():
+             if isinstance(v, list):
+                 # Old format: take the first concept if it is a list
+                 if v:
+                     self.concepts_map[int(k)] = Concept(**v[0])
+             else:
+                 # New format: single concept dict
+                 self.concepts_map[int(k)] = Concept(**v)
+
+     @classmethod
+     def from_csv(
+         cls,
+         csv_filepath: Path | str,
+         n_size: int,
+         store: Store | None = None
+     ) -> "ConceptDictionary":
+         csv_path = Path(csv_filepath)
+         if not csv_path.exists():
+             raise FileNotFoundError(f"CSV file not found: {csv_path}")
+
+         concept_dict = cls(n_size=n_size, store=store)
+
+         # Track the best concept per neuron (highest score)
+         neuron_concepts: Dict[int, tuple[str, float]] = {}
+
+         with csv_path.open("r", encoding="utf-8") as f:
+             reader = csv.DictReader(f)
+             for row in reader:
+                 neuron_idx = int(row["neuron_idx"])
+                 concept_name = row["concept_name"]
+                 score = float(row["score"])
+
+                 # Keep only the concept with the highest score per neuron
+                 if neuron_idx not in neuron_concepts or score > neuron_concepts[neuron_idx][1]:
+                     neuron_concepts[neuron_idx] = (concept_name, score)
+
+         # Add the best concept for each neuron
+         for neuron_idx, (concept_name, score) in neuron_concepts.items():
+             concept_dict.add(neuron_idx, concept_name, score)
+
+         return concept_dict
+
+     @classmethod
+     def from_json(
+         cls,
+         json_filepath: Path | str,
+         n_size: int,
+         store: Store | None = None
+     ) -> "ConceptDictionary":
+         json_path = Path(json_filepath)
+         if not json_path.exists():
+             raise FileNotFoundError(f"JSON file not found: {json_path}")
+
+         concept_dict = cls(n_size=n_size, store=store)
+
+         with json_path.open("r", encoding="utf-8") as f:
+             data = json.load(f)
+
+         for neuron_idx_str, concepts in data.items():
+             neuron_idx = int(neuron_idx_str)
+
+             # Handle both the old format (list) and the new format (single dict)
+             if isinstance(concepts, list):
+                 # Old format: take the concept with the highest score
+                 best_concept = None
+                 best_score = float('-inf')
+                 for concept in concepts:
+                     if not isinstance(concept, dict):
+                         continue
+                     score = float(concept["score"])
+                     if score > best_score:
+                         best_score = score
+                         best_concept = concept
+
+                 if best_concept is not None:
+                     concept_dict.add(neuron_idx, best_concept["name"], best_score)
+             elif isinstance(concepts, dict):
+                 # New format: single concept dict
+                 concept_name = concepts["name"]
+                 score = float(concepts["score"])
+                 concept_dict.add(neuron_idx, concept_name, score)
+
+         return concept_dict
+
+     @classmethod
+     def from_llm(
+         cls,
+         neuron_texts: list[list["NeuronText"]],
+         n_size: int,
+         store: Store | None = None,
+         llm_provider: str | None = None
+     ) -> "ConceptDictionary":
+         concept_dict = cls(n_size=n_size, store=store)
+
+         for neuron_idx, texts in enumerate(neuron_texts):
+             if not texts:
+                 continue
+
+             # Extract texts and their specific activated tokens
+             texts_with_tokens = []
+             for nt in texts:
+                 texts_with_tokens.append({
+                     "text": nt.text,
+                     "score": nt.score,
+                     "token_str": nt.token_str,
+                     "token_idx": nt.token_idx
+                 })
+
+             # Generate concept names using the LLM
+             concept_names = cls._generate_concept_names_llm(texts_with_tokens, llm_provider)
+
+             # Add only the best concept (highest score) to the dictionary
+             if concept_names:
+                 # Sort by score descending and take the first one
+                 concept_names_sorted = sorted(concept_names, key=lambda x: x[1], reverse=True)
+                 concept_name, score = concept_names_sorted[0]
+                 concept_dict.add(neuron_idx, concept_name, score)
+
+         return concept_dict
+
+     @staticmethod
+     def _generate_concept_names_llm(
+         texts_with_tokens: list[dict],
+         llm_provider: str | None = None
+     ) -> list[tuple[str, float]]:
+         raise NotImplementedError(
+             "LLM provider not configured. Please implement the _generate_concept_names_llm "
+             "method with your preferred LLM provider (OpenAI, Anthropic, etc.)"
+         )
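
For orientation, a minimal usage sketch of the ConceptDictionary API above. The CSV columns (neuron_idx, concept_name, score) are the ones from_csv reads; the file names and n_size here are hypothetical:

from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary

# Build from a labeling CSV; when a neuron appears in several rows,
# only its highest-scoring concept is kept
cd = ConceptDictionary.from_csv("labels.csv", n_size=4096)

cd.add(7, name="citrus fruit", score=0.83)   # replaces any existing concept at index 7
print(cd.get(7))                             # Concept(name='citrus fruit', score=0.83)
print(cd.get_many([0, 7]))                   # {0: None, 7: Concept(...)}

cd.save("concepts_dir")                      # writes concepts_dir/concepts.json

cd2 = ConceptDictionary(n_size=4096)
cd2.load("concepts_dir")                     # restores n_size and the concept map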
amber/mechanistic/sae/concepts/concept_models.py
@@ -0,0 +1,9 @@
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class NeuronText:
+     score: float
+     text: str
+     token_idx: int
+     token_str: str
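
NeuronText is the record that ConceptDictionary.from_llm above consumes, one inner list per neuron. A hypothetical input; note that from_llm raises NotImplementedError until _generate_concept_names_llm is implemented for a concrete LLM provider:

from amber.mechanistic.sae.concepts.concept_models import NeuronText

neuron_texts = [
    [NeuronText(score=0.91, text="lemons and limes", token_idx=0, token_str="lemons")],
    [],  # neurons with no top texts are skipped by from_llm
]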
amber/mechanistic/sae/concepts/input_tracker.py
@@ -0,0 +1,68 @@
+ from typing import TYPE_CHECKING, Sequence
+
+ from amber.utils import get_logger
+
+ if TYPE_CHECKING:
+     from amber.language_model.language_model import LanguageModel
+
+ logger = get_logger(__name__)
+
+
+ class InputTracker:
+     """
+     Simple listener that saves input texts before tokenization.
+
+     This is a singleton per LanguageModel instance. It's used as a listener
+     during inference to capture texts before they are tokenized. SAE hooks
+     can then access these texts to track top activating texts for their neurons.
+     """
+
+     def __init__(
+         self,
+         language_model: "LanguageModel",
+     ) -> None:
+         """
+         Initialize InputTracker.
+
+         Args:
+             language_model: Language model instance
+         """
+         self.language_model = language_model
+
+         # Flag to control whether to save inputs
+         self._enabled: bool = False
+
+         # Runtime state - only stores texts
+         self._current_texts: list[str] = []
+
+     @property
+     def enabled(self) -> bool:
+         """Whether input tracking is enabled."""
+         return self._enabled
+
+     def enable(self) -> None:
+         """Enable input tracking."""
+         self._enabled = True
+
+     def disable(self) -> None:
+         """Disable input tracking."""
+         self._enabled = False
+
+     def reset(self) -> None:
+         """Reset stored texts."""
+         self._current_texts.clear()
+
+     def set_current_texts(self, texts: Sequence[str]) -> None:
+         """
+         Set the current batch of texts being processed.
+
+         This is called by LanguageModel._inference() before tokenization
+         if tracking is enabled.
+         """
+         if self._enabled:
+             self._current_texts = list(texts)
+
+     def get_current_texts(self) -> list[str]:
+         """Get the current batch of texts."""
+         return self._current_texts.copy()
+
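
A sketch of the intended tracker lifecycle, assuming lm is a LanguageModel whose get_input_tracker() returns this singleton (l1_sae.py below calls it that way) and whose inference path calls set_current_texts() as the docstring describes; the inference call itself is hypothetical:

tracker = lm.get_input_tracker()
tracker.enable()                      # from now on, inference captures raw input texts

# ... run inference on lm; LanguageModel._inference() calls
# tracker.set_current_texts(texts) before tokenization ...

texts = tracker.get_current_texts()   # copy of the most recent batch of texts
tracker.reset()                       # clear stored texts
tracker.disable()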
amber/mechanistic/sae/modules/__init__.py
@@ -0,0 +1,5 @@
+ from amber.mechanistic.sae.modules.topk_sae import TopKSae, TopKSaeTrainingConfig
+ from amber.mechanistic.sae.modules.l1_sae import L1Sae, L1SaeTrainingConfig
+
+ __all__ = ["TopKSae", "TopKSaeTrainingConfig", "L1Sae", "L1SaeTrainingConfig"]
+
amber/mechanistic/sae/modules/l1_sae.py
@@ -0,0 +1,409 @@
+ from pathlib import Path
+ from typing import Any
+
+ import torch
+ from overcomplete import SAE as OvercompleteSAE
+ from amber.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
+ from amber.mechanistic.sae.sae import Sae
+ from amber.mechanistic.sae.sae_trainer import SaeTrainingConfig
+ from amber.store.store import Store
+ from amber.utils import get_logger
+
+ logger = get_logger(__name__)
+
+
+ class L1SaeTrainingConfig(SaeTrainingConfig):
+     """Training configuration for L1 SAE models.
+
+     This class extends SaeTrainingConfig to provide a type-safe configuration
+     interface specifically for L1 SAE models. While it currently uses the same
+     training parameters as the base SaeTrainingConfig, this design allows for:
+
+     1. **Type Safety**: Ensures that L1-specific training methods receive the
+        correct configuration type, preventing accidental use of incompatible configs.
+
+     2. **Future Extensibility**: Provides a clear extension point for L1-specific
+        training parameters that may be needed in the future (e.g., L1 regularization
+        scheduling, sparsity target parameters, etc.).
+
+     3. **API Clarity**: Makes the intent explicit in the codebase - when you see
+        L1SaeTrainingConfig, you know it's specifically for L1 SAE training.
+
+     For now, you can use this class exactly like SaeTrainingConfig. All parameters
+     from SaeTrainingConfig are available and work identically.
+     """
+     pass
+
+
+ class L1Sae(Sae):
+     """L1 Sparse Autoencoder implementation.
+
+     Uses L1 regularization to enforce sparsity in the latent activations.
+     This implementation uses the base SAE class from the overcomplete library.
+     """
+
+     def __init__(
+         self,
+         n_latents: int,
+         n_inputs: int,
+         hook_id: str | None = None,
+         device: str = 'cpu',
+         store: Store | None = None,
+         *args: Any,
+         **kwargs: Any
+     ) -> None:
+         super().__init__(n_latents, n_inputs, hook_id, device, store, *args, **kwargs)
+
+     def _initialize_sae_engine(self) -> OvercompleteSAE:
+         """Initialize the SAE engine.
+
+         Returns:
+             OvercompleteSAE instance configured for L1 regularization
+         """
+         return OvercompleteSAE(
+             input_shape=self.context.n_inputs,
+             nb_concepts=self.context.n_latents,
+             device=self.context.device
+         )
+
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Encode input using sae_engine.
+
+         Args:
+             x: Input tensor of shape [batch_size, n_inputs]
+
+         Returns:
+             Encoded latents (L1 sparse activations)
+         """
+         # Overcomplete SAE encode returns (pre_codes, codes)
+         _, codes = self.sae_engine.encode(x)
+         return codes
+
+     def decode(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Decode latents using sae_engine.
+
+         Args:
+             x: Encoded tensor of shape [batch_size, n_latents]
+
+         Returns:
+             Reconstructed tensor of shape [batch_size, n_inputs]
+         """
+         return self.sae_engine.decode(x)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass using sae_engine.
+
+         Args:
+             x: Input tensor of shape [batch_size, n_inputs]
+
+         Returns:
+             Reconstructed tensor of shape [batch_size, n_inputs]
+         """
+         # Overcomplete SAE forward returns (pre_codes, codes, x_reconstructed)
+         _, _, x_reconstructed = self.sae_engine.forward(x)
+         return x_reconstructed
+
+     def train(
+         self,
+         store: Store,
+         run_id: str,
+         layer_signature: str | int,
+         config: L1SaeTrainingConfig | None = None,
+         training_run_id: str | None = None
+     ) -> dict[str, Any]:
+         """
+         Train the L1 SAE using activations from a Store.
+
+         This method delegates to the SaeTrainer composite class.
+
+         Args:
+             store: Store instance containing activations
+             run_id: Run ID to train on
+             layer_signature: Layer signature to train on
+             config: Training configuration
+             training_run_id: Optional training run ID
+
+         Returns:
+             Dictionary with keys:
+                 - "history": Training history dictionary
+                 - "training_run_id": Training run ID where outputs were saved
+         """
+         if config is None:
+             config = L1SaeTrainingConfig()
+         return self.trainer.train(store, run_id, layer_signature, config, training_run_id)
+
+     def modify_activations(
+         self,
+         module: "torch.nn.Module",
+         inputs: HOOK_FUNCTION_INPUT,
+         output: HOOK_FUNCTION_OUTPUT
+     ) -> Any:
+         """
+         Modify activations using the L1 SAE (Controller hook interface).
+
+         Extracts a tensor from inputs/output, applies the SAE forward pass,
+         and optionally applies concept manipulation.
+
+         Args:
+             module: The PyTorch module being hooked
+             inputs: Tuple of inputs to the module
+             output: Output from the module (None for pre_forward hooks)
+
+         Returns:
+             Modified activations with the same shape as the input
+         """
+         # Extract tensor from output/inputs, handling objects with last_hidden_state
+         if self.hook_type == HookType.FORWARD:
+             if isinstance(output, torch.Tensor):
+                 tensor = output
+             elif hasattr(output, "last_hidden_state") and isinstance(output.last_hidden_state, torch.Tensor):
+                 tensor = output.last_hidden_state
+             elif isinstance(output, (tuple, list)):
+                 # Try to find the first tensor in the tuple/list
+                 tensor = next((item for item in output if isinstance(item, torch.Tensor)), None)
+             else:
+                 tensor = None
+         else:
+             tensor = inputs[0] if len(inputs) > 0 and isinstance(inputs[0], torch.Tensor) else None
+
+         if tensor is None or not isinstance(tensor, torch.Tensor):
+             return output if self.hook_type == HookType.FORWARD else inputs
+
+         original_shape = tensor.shape
+
+         # Flatten to 2D for SAE processing: (batch, seq_len, hidden) -> (batch * seq_len, hidden)
+         # or keep as 2D if already 2D: (batch, hidden)
+         if len(original_shape) > 2:
+             batch_size, seq_len = original_shape[:2]
+             tensor_flat = tensor.reshape(-1, original_shape[-1])
+         else:
+             batch_size = original_shape[0]
+             seq_len = 1
+             tensor_flat = tensor
+
+         # Get full activations (pre_codes) and sparse codes
+         # Overcomplete SAE encode returns (pre_codes, codes)
+         pre_codes, codes = self.sae_engine.encode(tensor_flat)
+
+         # Save SAE activations (pre_codes) as a 3D tensor: (batch, seq, n_latents)
+         latents_cpu = pre_codes.detach().cpu()
+         latents_3d = latents_cpu.reshape(batch_size, seq_len, -1)
+
+         # Save to tensor_metadata
+         self.tensor_metadata['neurons'] = latents_3d
+         self.tensor_metadata['activations'] = latents_3d
+
+         # Process each item in the batch individually for metadata
+         batch_items = []
+         n_items = latents_cpu.shape[0]
+         for item_idx in range(n_items):
+             item_latents = latents_cpu[item_idx]  # [n_latents]
+
+             # Find nonzero indices for this item
+             nonzero_mask = item_latents != 0
+             nonzero_indices = torch.nonzero(nonzero_mask, as_tuple=False).flatten().tolist()
+
+             # Create a map of nonzero indices to activations
+             activations_map = {
+                 int(idx): float(item_latents[idx].item())
+                 for idx in nonzero_indices
+             }
+
+             # Create item metadata
+             item_metadata = {
+                 "nonzero_indices": nonzero_indices,
+                 "activations": activations_map
+             }
+             batch_items.append(item_metadata)
+
+         # Save batch items metadata
+         self.metadata['batch_items'] = batch_items
+
+         # Use sparse codes for reconstruction
+         latents = codes
+
+         # Update top texts if text tracking is enabled
+         if self._text_tracking_enabled and self.context.lm is not None:
+             input_tracker = self.context.lm.get_input_tracker()
+             if input_tracker is not None:
+                 texts = input_tracker.get_current_texts()
+                 if texts:
+                     # Use pre_codes (full activations) for text tracking
+                     self.concepts.update_top_texts_from_latents(
+                         latents_cpu,
+                         texts,
+                         original_shape
+                     )
+
+         # Apply concept manipulation if parameters are set
+         # Check if multiplication or bias differ from defaults (ones)
+         if not torch.allclose(self.concepts.multiplication, torch.ones_like(self.concepts.multiplication)) or \
+                 not torch.allclose(self.concepts.bias, torch.ones_like(self.concepts.bias)):
+             # Apply manipulation: latents = latents * multiplication + bias
+             latents = latents * self.concepts.multiplication + self.concepts.bias
+
+         # Decode to get the reconstruction
+         reconstructed = self.decode(latents)
+
+         # Reshape back to the original shape
+         if len(original_shape) > 2:
+             reconstructed = reconstructed.reshape(original_shape)
+
+         # Return in the appropriate format
+         if self.hook_type == HookType.FORWARD:
+             if isinstance(output, torch.Tensor):
+                 return reconstructed
+             elif isinstance(output, (tuple, list)):
+                 # Replace the first tensor in the tuple/list
+                 result = list(output)
+                 for i, item in enumerate(result):
+                     if isinstance(item, torch.Tensor):
+                         result[i] = reconstructed
+                         break
+                 return tuple(result) if isinstance(output, tuple) else result
+             else:
+                 # For objects with attributes, try to set last_hidden_state
+                 if hasattr(output, "last_hidden_state"):
+                     output.last_hidden_state = reconstructed
+                 return output
+         else:  # PRE_FORWARD
+             # Return the modified inputs tuple
+             result = list(inputs)
+             if len(result) > 0:
+                 result[0] = reconstructed
+             return tuple(result)
+
+     def save(self, name: str, path: str | Path | None = None) -> None:
+         """
+         Save the model using overcomplete's state dict + our metadata.
+
+         Args:
+             name: Model name
+             path: Directory path to save to (defaults to the current directory)
+         """
+         if path is None:
+             path = Path.cwd()
+         save_dir = Path(path)
+         save_dir.mkdir(parents=True, exist_ok=True)
+         save_path = save_dir / f"{name}.pt"
+
+         # Save the overcomplete model state dict
+         sae_state_dict = self.sae_engine.state_dict()
+
+         amber_metadata = {
+             "concepts_state": {
+                 'multiplication': self.concepts.multiplication.data,
+                 'bias': self.concepts.bias.data,
+             },
+             "n_latents": self.context.n_latents,
+             "n_inputs": self.context.n_inputs,
+             "device": self.context.device,
+             "layer_signature": self.context.lm_layer_signature,
+             "model_id": self.context.model_id,
+         }
+
+         payload = {
+             "sae_state_dict": sae_state_dict,
+             "amber_metadata": amber_metadata,
+         }
+
+         torch.save(payload, save_path)
+         logger.info(f"Saved L1SAE to {save_path}")
+
+     @classmethod
+     def load(cls, path: Path) -> "L1Sae":
+         """
+         Load an L1SAE from a saved file using overcomplete's state dict + our metadata.
+
+         Args:
+             path: Path to the saved model file
+
+         Returns:
+             Loaded L1Sae instance
+         """
+         p = Path(path)
+
+         # Load the payload onto the best available device
+         if torch.cuda.is_available():
+             map_location = 'cuda'
+         elif torch.backends.mps.is_available():
+             map_location = 'mps'
+         else:
+             map_location = 'cpu'
+         payload = torch.load(p, map_location=map_location)
+
+         # Extract our metadata
+         if "amber_metadata" not in payload:
+             raise ValueError(f"Invalid L1SAE save format: missing 'amber_metadata' key in {p}")
+
+         amber_meta = payload["amber_metadata"]
+         n_latents = int(amber_meta["n_latents"])
+         n_inputs = int(amber_meta["n_inputs"])
+         device = amber_meta.get("device", "cpu")
+         layer_signature = amber_meta.get("layer_signature")
+         model_id = amber_meta.get("model_id")
+         concepts_state = amber_meta.get("concepts_state", {})
+
+         # Create the L1Sae instance
+         l1_sae = L1Sae(
+             n_latents=n_latents,
+             n_inputs=n_inputs,
+             device=device
+         )
+
+         # Load the overcomplete model state dict
+         if "sae_state_dict" in payload:
+             l1_sae.sae_engine.load_state_dict(payload["sae_state_dict"])
+         elif "model" in payload:
+             # Backward compatibility with the old format
+             l1_sae.sae_engine.load_state_dict(payload["model"])
+         else:
+             # Assume the payload is the state dict itself (backward compatibility)
+             l1_sae.sae_engine.load_state_dict(payload)
+
+         # Load the concepts state
+         if concepts_state:
+             device = l1_sae.context.device
+             if isinstance(device, str):
+                 device = torch.device(device)
+             if "multiplication" in concepts_state:
+                 l1_sae.concepts.multiplication.data = concepts_state["multiplication"].to(device)
+             if "bias" in concepts_state:
+                 l1_sae.concepts.bias.data = concepts_state["bias"].to(device)
+
+         # Note: Top texts loading was removed as the serialization methods were removed.
+         # Top texts should be exported/imported separately if needed.
+
+         # Set context metadata
+         l1_sae.context.lm_layer_signature = layer_signature
+         l1_sae.context.model_id = model_id
+
+         params_str = f"n_latents={n_latents}, n_inputs={n_inputs}"
+         logger.info(f"\nLoaded L1SAE from {p}\n{params_str}")
+
+         return l1_sae
+
+     def process_activations(
+         self,
+         module: torch.nn.Module,
+         input: HOOK_FUNCTION_INPUT,
+         output: HOOK_FUNCTION_OUTPUT
+     ) -> None:
+         """
+         Process activations (Detector interface).
+
+         Metadata saving is handled in modify_activations to avoid duplicate work.
+         This method is kept for interface compatibility but does nothing, since
+         modify_activations already saves the metadata when called.
+
+         Args:
+             module: The PyTorch module being hooked
+             input: Tuple of input tensors to the module
+             output: Output tensor(s) from the module
+         """
+         # Metadata saving is done in modify_activations to avoid duplicate encoding
+         pass
+
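
To round out the L1Sae API above, a small sketch of the standalone encode/decode path and the save/load round trip. The sizes, file names, and the random input are hypothetical; construction mirrors what load() itself does:

from pathlib import Path
import torch

from amber.mechanistic.sae.modules import L1Sae

sae = L1Sae(n_latents=4096, n_inputs=768, device="cpu")

x = torch.randn(8, 768)    # [batch_size, n_inputs]
codes = sae.encode(x)      # sparse codes, [8, 4096]
x_hat = sae.decode(codes)  # reconstruction, [8, 768]

sae.save("l1_sae_demo", path="checkpoints")  # writes checkpoints/l1_sae_demo.pt
restored = L1Sae.load(Path("checkpoints/l1_sae_demo.pt"))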