mi-crow 0.1.1.post12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amber/__init__.py +15 -0
- amber/datasets/__init__.py +11 -0
- amber/datasets/base_dataset.py +640 -0
- amber/datasets/classification_dataset.py +566 -0
- amber/datasets/loading_strategy.py +29 -0
- amber/datasets/text_dataset.py +488 -0
- amber/hooks/__init__.py +20 -0
- amber/hooks/controller.py +171 -0
- amber/hooks/detector.py +95 -0
- amber/hooks/hook.py +218 -0
- amber/hooks/implementations/__init__.py +0 -0
- amber/hooks/implementations/function_controller.py +93 -0
- amber/hooks/implementations/layer_activation_detector.py +96 -0
- amber/hooks/implementations/model_input_detector.py +250 -0
- amber/hooks/implementations/model_output_detector.py +132 -0
- amber/hooks/utils.py +76 -0
- amber/language_model/__init__.py +0 -0
- amber/language_model/activations.py +479 -0
- amber/language_model/context.py +33 -0
- amber/language_model/contracts.py +13 -0
- amber/language_model/hook_metadata.py +38 -0
- amber/language_model/inference.py +525 -0
- amber/language_model/initialization.py +126 -0
- amber/language_model/language_model.py +390 -0
- amber/language_model/layers.py +460 -0
- amber/language_model/persistence.py +177 -0
- amber/language_model/tokenizer.py +203 -0
- amber/language_model/utils.py +97 -0
- amber/mechanistic/__init__.py +0 -0
- amber/mechanistic/sae/__init__.py +0 -0
- amber/mechanistic/sae/autoencoder_context.py +40 -0
- amber/mechanistic/sae/concepts/__init__.py +0 -0
- amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
- amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
- amber/mechanistic/sae/concepts/concept_models.py +9 -0
- amber/mechanistic/sae/concepts/input_tracker.py +68 -0
- amber/mechanistic/sae/modules/__init__.py +5 -0
- amber/mechanistic/sae/modules/l1_sae.py +409 -0
- amber/mechanistic/sae/modules/topk_sae.py +459 -0
- amber/mechanistic/sae/sae.py +166 -0
- amber/mechanistic/sae/sae_trainer.py +604 -0
- amber/mechanistic/sae/training/wandb_logger.py +222 -0
- amber/store/__init__.py +5 -0
- amber/store/local_store.py +437 -0
- amber/store/store.py +276 -0
- amber/store/store_dataloader.py +124 -0
- amber/utils.py +46 -0
- mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
- mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
- mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
- mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, asdict
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Dict, Sequence, TYPE_CHECKING, Optional
|
|
6
|
+
import json
|
|
7
|
+
import csv
|
|
8
|
+
|
|
9
|
+
from amber.store.store import Store
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
pass
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class Concept:
    """A single named concept assigned to one SAE neuron/latent."""

    # Human-readable concept label (e.g. loaded from CSV/JSON or produced by an LLM).
    name: str
    # Score associated with this concept assignment; higher wins when several
    # candidates compete for the same neuron (see ConceptDictionary.from_csv/from_json).
    score: float
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ConceptDictionary:
    """Maps SAE neuron indices to at most one named :class:`Concept` each.

    Supports JSON persistence (``concepts.json``) and alternate constructors
    from CSV, JSON, or LLM-generated concept names.
    """

    def __init__(
        self,
        n_size: int,
        store: Store | None = None
    ) -> None:
        self.n_size = n_size
        self.concepts_map: Dict[int, Concept] = {}
        self.store = store
        self._directory: Path | None = None

    def set_directory(self, directory: Path | str) -> None:
        """Remember *directory* (created if missing) as the persistence target."""
        target = Path(directory)
        target.mkdir(parents=True, exist_ok=True)
        self._directory = target

    def _check_bounds(self, index: int) -> None:
        # Shared range guard used by add() and get().
        if index < 0 or index >= self.n_size:
            raise IndexError(f"index {index} out of bounds for n_size={self.n_size}")

    def add(self, index: int, name: str, score: float) -> None:
        """Assign the single concept for *index*, replacing any existing one."""
        self._check_bounds(index)
        self.concepts_map[index] = Concept(name=name, score=score)

    def get(self, index: int) -> Optional[Concept]:
        """Return the concept at *index*, or ``None`` when unassigned."""
        self._check_bounds(index)
        return self.concepts_map.get(index)

    def get_many(self, indices: Sequence[int]) -> Dict[int, Optional[Concept]]:
        """Look up several indices at once; unassigned entries map to ``None``."""
        result: Dict[int, Optional[Concept]] = {}
        for idx in indices:
            result[idx] = self.get(idx)
        return result

    def save(self, directory: Path | str | None = None) -> Path:
        """Serialize the dictionary to ``concepts.json`` and return its path.

        Raises:
            ValueError: if no target directory has been configured.
        """
        if directory is not None:
            self.set_directory(directory)
        if self._directory is None:
            raise ValueError("No directory set. Call save(directory=...) or set_directory() first.")
        destination = self._directory / "concepts.json"
        payload = {
            "n_size": self.n_size,
            "concepts": {str(idx): asdict(concept) for idx, concept in self.concepts_map.items()},
        }
        with destination.open("w", encoding="utf-8") as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=2)
        return destination

    def load(self, directory: Path | str | None = None) -> None:
        """Replace in-memory state with the contents of ``concepts.json``.

        Raises:
            ValueError: if no source directory has been configured.
            FileNotFoundError: if ``concepts.json`` does not exist there.
        """
        if directory is not None:
            self.set_directory(directory)
        if self._directory is None:
            raise ValueError("No directory set. Call load(directory=...) or set_directory() first.")
        source = self._directory / "concepts.json"
        if not source.exists():
            raise FileNotFoundError(source)
        with source.open("r", encoding="utf-8") as handle:
            payload = json.load(handle)
        self.n_size = int(payload.get("n_size", self.n_size))
        self.concepts_map = {}
        for key, value in payload.get("concepts", {}).items():
            # Legacy files stored a list of concepts per neuron; current files
            # store a single dict. For legacy entries the first item is kept.
            if isinstance(value, list):
                if value:
                    self.concepts_map[int(key)] = Concept(**value[0])
            else:
                self.concepts_map[int(key)] = Concept(**value)

    @classmethod
    def from_csv(
        cls,
        csv_filepath: Path | str,
        n_size: int,
        store: Store | None = None
    ) -> "ConceptDictionary":
        """Build a dictionary from a CSV with columns neuron_idx/concept_name/score.

        When a neuron appears in several rows, only the highest-scoring row wins.
        """
        csv_path = Path(csv_filepath)
        if not csv_path.exists():
            raise FileNotFoundError(f"CSV file not found: {csv_path}")

        dictionary = cls(n_size=n_size, store=store)
        best_per_neuron: Dict[int, tuple[str, float]] = {}

        with csv_path.open("r", encoding="utf-8") as handle:
            for row in csv.DictReader(handle):
                idx = int(row["neuron_idx"])
                label = row["concept_name"]
                value = float(row["score"])
                existing = best_per_neuron.get(idx)
                if existing is None or value > existing[1]:
                    best_per_neuron[idx] = (label, value)

        for idx, (label, value) in best_per_neuron.items():
            dictionary.add(idx, label, value)

        return dictionary

    @classmethod
    def from_json(
        cls,
        json_filepath: Path | str,
        n_size: int,
        store: Store | None = None
    ) -> "ConceptDictionary":
        """Build a dictionary from a JSON mapping of neuron index to concept(s).

        Legacy files map a neuron to a list of concept dicts (highest score
        wins, non-dict items are skipped); current files map it to one dict.
        """
        json_path = Path(json_filepath)
        if not json_path.exists():
            raise FileNotFoundError(f"JSON file not found: {json_path}")

        dictionary = cls(n_size=n_size, store=store)

        with json_path.open("r", encoding="utf-8") as handle:
            raw = json.load(handle)

        for key, entry in raw.items():
            idx = int(key)
            if isinstance(entry, list):
                winner = None
                winner_score = float('-inf')
                for candidate in entry:
                    if not isinstance(candidate, dict):
                        continue
                    candidate_score = float(candidate["score"])
                    if candidate_score > winner_score:
                        winner_score = candidate_score
                        winner = candidate
                if winner is not None:
                    dictionary.add(idx, winner["name"], winner_score)
            elif isinstance(entry, dict):
                dictionary.add(idx, entry["name"], float(entry["score"]))

        return dictionary

    @classmethod
    def from_llm(
        cls,
        neuron_texts: list[list["NeuronText"]],
        n_size: int,
        store: Store | None = None,
        llm_provider: str | None = None
    ) -> "ConceptDictionary":
        """Name neurons via an LLM using each neuron's top-activating texts.

        Requires :meth:`_generate_concept_names_llm` to be implemented; only
        the highest-scoring suggested name is kept per neuron.
        """
        dictionary = cls(n_size=n_size, store=store)

        for idx, samples in enumerate(neuron_texts):
            if not samples:
                continue

            # Package each text together with the token that activated the neuron.
            prompts = [
                {
                    "text": sample.text,
                    "score": sample.score,
                    "token_str": sample.token_str,
                    "token_idx": sample.token_idx
                }
                for sample in samples
            ]

            suggestions = cls._generate_concept_names_llm(prompts, llm_provider)

            if suggestions:
                # Keep only the best suggestion (highest score) for this neuron.
                label, value = max(suggestions, key=lambda pair: pair[1])
                dictionary.add(idx, label, value)

        return dictionary

    @staticmethod
    def _generate_concept_names_llm(texts_with_tokens: list[dict], llm_provider: str | None = None) -> list[
        tuple[str, float]]:
        """Placeholder hook: produce (name, score) pairs for one neuron's texts."""
        raise NotImplementedError(
            "LLM provider not configured. Please implement _generate_concept_names_llm "
            "method with your preferred LLM provider (OpenAI, Anthropic, etc.)"
        )
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Sequence
|
|
2
|
+
|
|
3
|
+
from amber.utils import get_logger
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from amber.language_model.language_model import LanguageModel
|
|
7
|
+
|
|
8
|
+
logger = get_logger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class InputTracker:
    """
    Records the raw input texts of the current batch before tokenization.

    One instance exists per LanguageModel. While enabled, the model stores
    each batch's texts here prior to tokenizing them, so SAE hooks can
    correlate latent activations with the texts that produced them.
    """

    def __init__(
        self,
        language_model: "LanguageModel",
    ) -> None:
        """
        Initialize InputTracker.

        Args:
            language_model: Language model instance
        """
        self.language_model = language_model
        # Capture stays off until enable() is called.
        self._enabled: bool = False
        # Texts of the batch currently in flight.
        self._current_texts: list[str] = []

    @property
    def enabled(self) -> bool:
        """Whether input tracking is enabled."""
        return self._enabled

    def enable(self) -> None:
        """Turn on capture of incoming batch texts."""
        self._enabled = True

    def disable(self) -> None:
        """Turn off capture of incoming batch texts."""
        self._enabled = False

    def reset(self) -> None:
        """Forget any texts captured so far."""
        del self._current_texts[:]

    def set_current_texts(self, texts: Sequence[str]) -> None:
        """
        Record *texts* as the batch currently being processed.

        Invoked by LanguageModel._inference() before tokenization; silently
        ignored while tracking is disabled.
        """
        if not self._enabled:
            return
        self._current_texts = list(texts)

    def get_current_texts(self) -> list[str]:
        """Return a copy of the most recently recorded batch of texts."""
        return list(self._current_texts)
|
@@ -0,0 +1,409 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from overcomplete import SAE as OvercompleteSAE
|
|
6
|
+
from amber.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
|
|
7
|
+
from amber.mechanistic.sae.sae import Sae
|
|
8
|
+
from amber.mechanistic.sae.sae_trainer import SaeTrainingConfig
|
|
9
|
+
from amber.store.store import Store
|
|
10
|
+
from amber.utils import get_logger
|
|
11
|
+
|
|
12
|
+
logger = get_logger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class L1SaeTrainingConfig(SaeTrainingConfig):
    """Configuration for training L1 SAE models.

    Today this is behaviorally identical to SaeTrainingConfig; it exists as a
    distinct type so that:

    * L1-specific training entry points can require the matching config type,
      rejecting incompatible configurations at type-check time;
    * future L1-only parameters (for example an L1-regularization schedule or
      sparsity-target settings) have an obvious home without touching the
      shared base class;
    * call sites self-document which SAE variant they are configuring.

    All parameters from SaeTrainingConfig are available and behave identically,
    so this class can be used as a drop-in wherever SaeTrainingConfig would be.
    """
    pass
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class L1Sae(Sae):
    """L1 Sparse Autoencoder implementation.

    Uses L1 regularization to enforce sparsity in the latent activations.
    This implementation uses the base SAE class from the overcomplete library.

    The class serves two hook roles: a Controller (``modify_activations``
    replaces a layer's activations with their SAE reconstruction) and a
    Detector (``process_activations`` is a no-op here, since metadata is
    captured during modification).
    """

    def __init__(
        self,
        n_latents: int,
        n_inputs: int,
        hook_id: str | None = None,
        device: str = 'cpu',
        store: Store | None = None,
        *args: Any,
        **kwargs: Any
    ) -> None:
        """Initialize an L1 SAE.

        Args:
            n_latents: Number of latent (concept) dimensions
            n_inputs: Dimensionality of the activations being autoencoded
            hook_id: Optional identifier of the hook this SAE attaches to
            device: Torch device string (e.g. 'cpu', 'cuda')
            store: Optional Store used for persisting activations
        """
        super().__init__(n_latents, n_inputs, hook_id, device, store, *args, **kwargs)

    def _initialize_sae_engine(self) -> OvercompleteSAE:
        """Initialize the SAE engine.

        Returns:
            OvercompleteSAE instance configured for L1 regularization
        """
        return OvercompleteSAE(
            input_shape=self.context.n_inputs,
            nb_concepts=self.context.n_latents,
            device=self.context.device
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode input using sae_engine.

        Args:
            x: Input tensor of shape [batch_size, n_inputs]

        Returns:
            Encoded latents (L1 sparse activations)
        """
        # Overcomplete SAE encode returns (pre_codes, codes)
        _, codes = self.sae_engine.encode(x)
        return codes

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Decode latents using sae_engine.

        Args:
            x: Encoded tensor of shape [batch_size, n_latents]

        Returns:
            Reconstructed tensor of shape [batch_size, n_inputs]
        """
        return self.sae_engine.decode(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass using sae_engine.

        Args:
            x: Input tensor of shape [batch_size, n_inputs]

        Returns:
            Reconstructed tensor of shape [batch_size, n_inputs]
        """
        # Overcomplete SAE forward returns (pre_codes, codes, x_reconstructed)
        _, _, x_reconstructed = self.sae_engine.forward(x)
        return x_reconstructed

    def train(
        self,
        store: Store,
        run_id: str,
        layer_signature: str | int,
        config: L1SaeTrainingConfig | None = None,
        training_run_id: str | None = None
    ) -> dict[str, Any]:
        """
        Train L1SAE using activations from a Store.

        This method delegates to the SaeTrainer composite class.

        NOTE(review): if Sae ultimately subclasses torch.nn.Module, this
        shadows nn.Module.train(mode) with an incompatible signature —
        confirm that is intentional.

        Args:
            store: Store instance containing activations
            run_id: Run ID to train on
            layer_signature: Layer signature to train on
            config: Training configuration
            training_run_id: Optional training run ID

        Returns:
            Dictionary with keys:
            - "history": Training history dictionary
            - "training_run_id": Training run ID where outputs were saved
        """
        if config is None:
            config = L1SaeTrainingConfig()
        return self.trainer.train(store, run_id, layer_signature, config, training_run_id)

    def modify_activations(
        self,
        module: "torch.nn.Module",
        inputs: torch.Tensor | None,
        output: torch.Tensor | None
    ) -> torch.Tensor | None:
        """
        Modify activations using L1SAE (Controller hook interface).

        Extracts a tensor from inputs/output (depending on hook type), applies
        the SAE encode/decode round trip, records per-item activation metadata,
        and optionally applies concept manipulation before returning the
        reconstruction in the same container shape the hook received.

        NOTE(review): despite the annotation, for PRE_FORWARD hooks ``inputs``
        is treated as a tuple of tensors (``inputs[0]``, ``len(inputs)``) —
        confirm against the Controller hook contract.

        Args:
            module: The PyTorch module being hooked
            inputs: Tuple of inputs to the module
            output: Output from the module (None for pre_forward hooks)

        Returns:
            Modified activations with same shape as input
        """
        # Extract tensor from output/inputs, handling objects with last_hidden_state
        if self.hook_type == HookType.FORWARD:
            if isinstance(output, torch.Tensor):
                tensor = output
            elif hasattr(output, "last_hidden_state") and isinstance(output.last_hidden_state, torch.Tensor):
                tensor = output.last_hidden_state
            elif isinstance(output, (tuple, list)):
                # Try to find first tensor in tuple/list
                tensor = next((item for item in output if isinstance(item, torch.Tensor)), None)
            else:
                tensor = None
        else:
            tensor = inputs[0] if len(inputs) > 0 and isinstance(inputs[0], torch.Tensor) else None

        # Nothing usable was found: pass the hook payload through untouched.
        if tensor is None or not isinstance(tensor, torch.Tensor):
            return output if self.hook_type == HookType.FORWARD else inputs

        original_shape = tensor.shape

        # Flatten to 2D for SAE processing: (batch, seq_len, hidden) -> (batch * seq_len, hidden)
        # or keep as 2D if already 2D: (batch, hidden)
        if len(original_shape) > 2:
            batch_size, seq_len = original_shape[:2]
            tensor_flat = tensor.reshape(-1, original_shape[-1])
        else:
            batch_size = original_shape[0]
            seq_len = 1
            tensor_flat = tensor

        # Get full activations (pre_codes) and sparse codes
        # Overcomplete SAE encode returns (pre_codes, codes)
        pre_codes, codes = self.sae_engine.encode(tensor_flat)

        # Save SAE activations (pre_codes) as 3D tensor: (batch, seq, n_latents)
        latents_cpu = pre_codes.detach().cpu()
        latents_3d = latents_cpu.reshape(batch_size, seq_len, -1)

        # Save to tensor_metadata
        self.tensor_metadata['neurons'] = latents_3d
        self.tensor_metadata['activations'] = latents_3d

        # Process each item in the batch individually for metadata
        batch_items = []
        n_items = latents_cpu.shape[0]
        for item_idx in range(n_items):
            item_latents = latents_cpu[item_idx]  # [n_latents]

            # Find nonzero indices for this item
            nonzero_mask = item_latents != 0
            nonzero_indices = torch.nonzero(nonzero_mask, as_tuple=False).flatten().tolist()

            # Create map of nonzero indices to activations
            activations_map = {
                int(idx): float(item_latents[idx].item())
                for idx in nonzero_indices
            }

            # Create item metadata
            item_metadata = {
                "nonzero_indices": nonzero_indices,
                "activations": activations_map
            }
            batch_items.append(item_metadata)

        # Save batch items metadata
        self.metadata['batch_items'] = batch_items

        # Use sparse codes for reconstruction
        latents = codes

        # Update top texts if text tracking is enabled
        if self._text_tracking_enabled and self.context.lm is not None:
            input_tracker = self.context.lm.get_input_tracker()
            if input_tracker is not None:
                texts = input_tracker.get_current_texts()
                if texts:
                    # Use pre_codes (full activations) for text tracking
                    self.concepts.update_top_texts_from_latents(
                        latents_cpu,
                        texts,
                        original_shape
                    )

        # Apply concept manipulation if parameters are set
        # Check if multiplication or bias differ from defaults (ones)
        # NOTE(review): this assumes both multiplication AND bias default to
        # all-ones; if bias actually defaults to zeros, the (numerically
        # harmless) manipulation below is applied on every call — confirm
        # against the concepts module's defaults.
        if not torch.allclose(self.concepts.multiplication, torch.ones_like(self.concepts.multiplication)) or \
                not torch.allclose(self.concepts.bias, torch.ones_like(self.concepts.bias)):
            # Apply manipulation: latents = latents * multiplication + bias
            latents = latents * self.concepts.multiplication + self.concepts.bias

        # Decode to get reconstruction
        reconstructed = self.decode(latents)

        # Reshape back to original shape
        if len(original_shape) > 2:
            reconstructed = reconstructed.reshape(original_shape)

        # Return in appropriate format
        if self.hook_type == HookType.FORWARD:
            if isinstance(output, torch.Tensor):
                return reconstructed
            elif isinstance(output, (tuple, list)):
                # Replace first tensor in tuple/list
                result = list(output)
                for i, item in enumerate(result):
                    if isinstance(item, torch.Tensor):
                        result[i] = reconstructed
                        break
                return tuple(result) if isinstance(output, tuple) else result
            else:
                # For objects with attributes, try to set last_hidden_state
                if hasattr(output, "last_hidden_state"):
                    output.last_hidden_state = reconstructed
                return output
        else:  # PRE_FORWARD
            # Return modified inputs tuple
            result = list(inputs)
            if len(result) > 0:
                result[0] = reconstructed
            return tuple(result)

    def save(self, name: str, path: str | Path | None = None) -> None:
        """
        Save model using overcomplete's state dict + our metadata.

        Writes ``<path>/<name>.pt`` containing the engine state dict plus
        concept-manipulation tensors and context metadata needed by load().

        Args:
            name: Model name
            path: Directory path to save to (defaults to current directory)
        """
        if path is None:
            path = Path.cwd()
        save_dir = Path(path)
        save_dir.mkdir(parents=True, exist_ok=True)
        save_path = save_dir / f"{name}.pt"

        # Save overcomplete model state dict
        sae_state_dict = self.sae_engine.state_dict()

        amber_metadata = {
            "concepts_state": {
                'multiplication': self.concepts.multiplication.data,
                'bias': self.concepts.bias.data,
            },
            "n_latents": self.context.n_latents,
            "n_inputs": self.context.n_inputs,
            "device": self.context.device,
            "layer_signature": self.context.lm_layer_signature,
            "model_id": self.context.model_id,
        }

        payload = {
            "sae_state_dict": sae_state_dict,
            "amber_metadata": amber_metadata,
        }

        torch.save(payload, save_path)
        logger.info(f"Saved L1SAE to {save_path}")

    @classmethod
    def load(cls, path: Path) -> "L1Sae":
        """
        Load L1SAE from saved file using overcomplete's load method + our metadata.

        Args:
            path: Path to saved model file

        Returns:
            Loaded L1Sae instance (or subclass, when called on a subclass)
        """
        p = Path(path)

        # Load payload onto the best available device
        if torch.cuda.is_available():
            map_location = 'cuda'
        elif torch.backends.mps.is_available():
            map_location = 'mps'
        else:
            map_location = 'cpu'
        payload = torch.load(p, map_location=map_location)

        # Extract our metadata
        if "amber_metadata" not in payload:
            raise ValueError(f"Invalid L1SAE save format: missing 'amber_metadata' key in {p}")

        amber_meta = payload["amber_metadata"]
        n_latents = int(amber_meta["n_latents"])
        n_inputs = int(amber_meta["n_inputs"])
        device = amber_meta.get("device", "cpu")
        layer_signature = amber_meta.get("layer_signature")
        model_id = amber_meta.get("model_id")
        concepts_state = amber_meta.get("concepts_state", {})

        # Instantiate via cls (not a hard-coded L1Sae) so subclasses load as
        # themselves when they inherit this classmethod.
        l1_sae = cls(
            n_latents=n_latents,
            n_inputs=n_inputs,
            device=device
        )

        # Load overcomplete model state dict
        if "sae_state_dict" in payload:
            l1_sae.sae_engine.load_state_dict(payload["sae_state_dict"])
        elif "model" in payload:
            # Backward compatibility with old format
            l1_sae.sae_engine.load_state_dict(payload["model"])
        else:
            # Assume payload is the state dict itself (backward compatibility)
            l1_sae.sae_engine.load_state_dict(payload)

        # Load concepts state
        if concepts_state:
            device = l1_sae.context.device
            if isinstance(device, str):
                device = torch.device(device)
            if "multiplication" in concepts_state:
                l1_sae.concepts.multiplication.data = concepts_state["multiplication"].to(device)
            if "bias" in concepts_state:
                l1_sae.concepts.bias.data = concepts_state["bias"].to(device)

        # Note: Top texts loading was removed as serialization methods were removed
        # Top texts should be exported/imported separately if needed

        # Set context metadata
        l1_sae.context.lm_layer_signature = layer_signature
        l1_sae.context.model_id = model_id

        params_str = f"n_latents={n_latents}, n_inputs={n_inputs}"
        logger.info(f"\nLoaded L1SAE from {p}\n{params_str}")

        return l1_sae

    def process_activations(
        self,
        module: torch.nn.Module,
        input: HOOK_FUNCTION_INPUT,
        output: HOOK_FUNCTION_OUTPUT
    ) -> None:
        """
        Process activations (Detector interface).

        Metadata saving is handled in modify_activations to avoid duplicate work.
        This method is kept for interface compatibility but does nothing since
        modify_activations already saves the metadata when called.

        Args:
            module: The PyTorch module being hooked
            input: Tuple of input tensors to the module
            output: Output tensor(s) from the module
        """
        # Metadata saving is done in modify_activations to avoid duplicate encoding
        pass
|
409
|
+
|