sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sae_lens/__init__.py +55 -18
- sae_lens/analysis/hooked_sae_transformer.py +10 -10
- sae_lens/analysis/neuronpedia_integration.py +13 -11
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +105 -235
- sae_lens/constants.py +20 -0
- sae_lens/evals.py +34 -31
- sae_lens/{sae_training_runner.py → llm_sae_training_runner.py} +103 -70
- sae_lens/load_model.py +53 -5
- sae_lens/loading/pretrained_sae_loaders.py +36 -10
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +70 -59
- sae_lens/saes/jumprelu_sae.py +58 -72
- sae_lens/saes/sae.py +248 -273
- sae_lens/saes/standard_sae.py +75 -57
- sae_lens/saes/topk_sae.py +72 -83
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +105 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +134 -158
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +11 -5
- sae_lens/util.py +47 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc3.dist-info/RECORD +38 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/WHEEL +1 -1
- sae_lens/regsitry.py +0 -34
- sae_lens-6.0.0rc1.dist-info/RECORD +0 -32
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/LICENSE +0 -0
sae_lens/saes/sae.py
CHANGED
@@ -4,9 +4,18 @@ import json
 import warnings
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from dataclasses import dataclass, field, fields
+from dataclasses import asdict, dataclass, field, fields, replace
 from pathlib import Path
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Generic,
+    Literal,
+    NamedTuple,
+    Type,
+    TypeVar,
+)

 import einops
 import torch
@@ -15,16 +24,19 @@ from numpy.typing import NDArray
 from safetensors.torch import save_file
 from torch import nn
 from transformer_lens.hook_points import HookedRootModule, HookPoint
-from typing_extensions import deprecated, overload
+from typing_extensions import deprecated, overload, override

-from sae_lens import logger
-from sae_lens.
+from sae_lens import __version__, logger
+from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
-    SPARSITY_FILENAME,
-    LanguageModelSAERunnerConfig,
 )
+from sae_lens.util import filter_valid_dataclass_fields
+
+if TYPE_CHECKING:
+    from sae_lens.config import LanguageModelSAERunnerConfig
+
 from sae_lens.loading.pretrained_sae_loaders import (
     NAMED_PRETRAINED_SAE_LOADERS,
     PretrainedSaeDiskLoader,
@@ -39,57 +51,81 @@ from sae_lens.loading.pretrained_saes_directory import (
     get_pretrained_saes_directory,
     get_repo_id_and_folder_name,
 )
-from sae_lens.
+from sae_lens.registry import get_sae_class, get_sae_training_class

-
+T_SAE_CONFIG = TypeVar("T_SAE_CONFIG", bound="SAEConfig")
+T_TRAINING_SAE_CONFIG = TypeVar("T_TRAINING_SAE_CONFIG", bound="TrainingSAEConfig")
+T_SAE = TypeVar("T_SAE", bound="SAE")  # type: ignore
+T_TRAINING_SAE = TypeVar("T_TRAINING_SAE", bound="TrainingSAE")  # type: ignore
+
+
+@dataclass
+class SAEMetadata:
+    """Core metadata about how this SAE should be used, if known."""
+
+    model_name: str | None = None
+    hook_name: str | None = None
+    model_class_name: str | None = None
+    hook_head_index: int | None = None
+    model_from_pretrained_kwargs: dict[str, Any] | None = None
+    prepend_bos: bool | None = None
+    exclude_special_tokens: bool | list[int] | None = None
+    neuronpedia_id: str | None = None
+    context_size: int | None = None
+    seqpos_slice: tuple[int | None, ...] | None = None
+    dataset_path: str | None = None
+    sae_lens_version: str = field(default_factory=lambda: __version__)
+    sae_lens_training_version: str = field(default_factory=lambda: __version__)


 @dataclass
-class SAEConfig:
+class SAEConfig(ABC):
     """Base configuration for SAE models."""

-    architecture: str
     d_in: int
     d_sae: int
-    dtype: str
-    device: str
-
-
-
-
-
-
-    apply_b_dec_to_input: bool
-    finetuning_scaling_factor: bool
-    normalize_activations: str
-    context_size: int
-    dataset_path: str
-    dataset_trust_remote_code: bool
-    sae_lens_training_version: str
-    model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
-    seqpos_slice: tuple[int, ...] | None = None
-    prepend_bos: bool = False
-    neuronpedia_id: str | None = None
-
-    def to_dict(self) -> dict[str, Any]:
-        return {field.name: getattr(self, field.name) for field in fields(self)}
+    dtype: str = "float32"
+    device: str = "cpu"
+    apply_b_dec_to_input: bool = True
+    normalize_activations: Literal[
+        "none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"
+    ] = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
+    reshape_activations: Literal["none", "hook_z"] = "none"
+    metadata: SAEMetadata = field(default_factory=SAEMetadata)

     @classmethod
-
-
-        valid_config_dict = {
-            key: val for key, val in config_dict.items() if key in valid_field_names
-        }
+    @abstractmethod
+    def architecture(cls) -> str: ...

-
-
-
-
-
-        ):
-            valid_config_dict["seqpos_slice"] = tuple(valid_config_dict["seqpos_slice"])
+    def to_dict(self) -> dict[str, Any]:
+        res = {field.name: getattr(self, field.name) for field in fields(self)}
+        res["metadata"] = asdict(self.metadata)
+        res["architecture"] = self.architecture()
+        return res

-
+    @classmethod
+    def from_dict(cls: type[T_SAE_CONFIG], config_dict: dict[str, Any]) -> T_SAE_CONFIG:
+        cfg_class = get_sae_class(config_dict["architecture"])[1]
+        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
+        res = cfg_class(**filtered_config_dict)
+        if "metadata" in config_dict:
+            res.metadata = SAEMetadata(**config_dict["metadata"])
+        if not isinstance(res, cls):
+            raise ValueError(
+                f"SAE config class {cls} does not match dict config class {type(res)}"
+            )
+        return res
+
+    def __post_init__(self):
+        if self.normalize_activations not in [
+            "none",
+            "expected_average_only_in",
+            "constant_norm_rescale",
+            "layer_norm",
+        ]:
+            raise ValueError(
+                f"normalize_activations must be none, expected_average_only_in, constant_norm_rescale, or layer_norm. Got {self.normalize_activations}"
+            )


 @dataclass
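The hunk above replaces the old flat SAEConfig with an abstract base: each subclass now identifies itself via architecture(), and run-specific details (model name, hook name, context size, and so on) move into the new SAEMetadata dataclass. The following is a minimal sketch of an architecture-specific config under that interface; the subclass name and the field values are illustrative assumptions, not part of sae-lens.

from dataclasses import dataclass

from sae_lens.saes.sae import SAEConfig, SAEMetadata


@dataclass
class MyCustomSAEConfig(SAEConfig):
    """Hypothetical architecture-specific config; d_in, d_sae, etc. are inherited."""

    @classmethod
    def architecture(cls) -> str:
        return "my_custom"


cfg = MyCustomSAEConfig(
    d_in=768,
    d_sae=16384,
    metadata=SAEMetadata(model_name="gpt2", hook_name="blocks.8.hook_resid_pre"),
)
assert cfg.to_dict()["architecture"] == "my_custom"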
@@ -109,14 +145,19 @@ class TrainStepInput:
     """Input to a training step."""

     sae_in: torch.Tensor
-
+    coefficients: dict[str, float]
     dead_neuron_mask: torch.Tensor | None


-class
+class TrainCoefficientConfig(NamedTuple):
+    value: float
+    warm_up_steps: int
+
+
+class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     """Abstract base class for all SAE architectures."""

-    cfg:
+    cfg: T_SAE_CONFIG
     dtype: torch.dtype
     device: torch.device
     use_error_term: bool
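TrainStepInput now carries a per-loss coefficients dict, and the new TrainCoefficientConfig NamedTuple pairs a coefficient's target value with a warm-up length. A sketch of how a training architecture might declare its coefficients via the get_coefficients() hook introduced further down in this diff; the "l1" key and the numbers are illustrative.

from sae_lens.saes.sae import TrainCoefficientConfig


def get_coefficients() -> dict[str, float | TrainCoefficientConfig]:
    # Mirrors the TrainingSAE.get_coefficients() signature; on a real subclass
    # this would be a method. Ramp the sparsity coefficient from 0 to 5.0 over
    # the first 1_000 steps.
    return {"l1": TrainCoefficientConfig(value=5.0, warm_up_steps=1_000)}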
@@ -127,13 +168,13 @@ class SAE(HookedRootModule, ABC):
     W_dec: nn.Parameter
     b_dec: nn.Parameter

-    def __init__(self, cfg:
+    def __init__(self, cfg: T_SAE_CONFIG, use_error_term: bool = False):
         """Initialize the SAE."""
         super().__init__()

         self.cfg = cfg

-        if cfg.model_from_pretrained_kwargs:
+        if cfg.metadata and cfg.metadata.model_from_pretrained_kwargs:
             warnings.warn(
                 "\nThis SAE has non-empty model_from_pretrained_kwargs. "
                 "\nFor optimal performance, load the model like so:\n"
@@ -147,19 +188,11 @@ class SAE(HookedRootModule, ABC):
         self.use_error_term = use_error_term

         # Set up activation function
-        self.activation_fn = self.
+        self.activation_fn = self.get_activation_fn()

         # Initialize weights
         self.initialize_weights()

-        # Handle presence / absence of scaling factor
-        if self.cfg.finetuning_scaling_factor:
-            self.apply_finetuning_scaling_factor = (
-                lambda x: x * self.finetuning_scaling_factor
-            )
-        else:
-            self.apply_finetuning_scaling_factor = lambda x: x
-
         # Set up hooks
         self.hook_sae_input = HookPoint()
         self.hook_sae_acts_pre = HookPoint()
@@ -169,11 +202,9 @@ class SAE(HookedRootModule, ABC):
         self.hook_sae_error = HookPoint()

         # handle hook_z reshaping if needed.
-        if self.cfg.
-            # print(f"Setting up hook_z reshaping for {self.cfg.hook_name}")
+        if self.cfg.reshape_activations == "hook_z":
             self.turn_on_forward_pass_hook_z_reshaping()
         else:
-            # print(f"No hook_z reshaping needed for {self.cfg.hook_name}")
             self.turn_off_forward_pass_hook_z_reshaping()

         # Set up activation normalization
@@ -188,40 +219,9 @@ class SAE(HookedRootModule, ABC):
         self.b_dec.data /= scaling_factor  # type: ignore
         self.cfg.normalize_activations = "none"

-    def
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
         """Get the activation function specified in config."""
-        return
-            self.cfg.activation_fn, **(self.cfg.activation_fn_kwargs or {})
-        )
-
-    @staticmethod
-    def _get_activation_fn_static(
-        activation_fn: str, **kwargs: Any
-    ) -> Callable[[torch.Tensor], torch.Tensor]:
-        """Get the activation function from a string specification."""
-        if activation_fn == "relu":
-            return torch.nn.ReLU()
-        if activation_fn == "tanh-relu":
-
-            def tanh_relu(input: torch.Tensor) -> torch.Tensor:
-                input = torch.relu(input)
-                return torch.tanh(input)
-
-            return tanh_relu
-        if activation_fn == "topk":
-            if "k" not in kwargs:
-                raise ValueError("TopK activation function requires a k value.")
-            k = kwargs.get("k", 1)  # Default k to 1 if not provided
-
-            def topk_fn(x: torch.Tensor) -> torch.Tensor:
-                topk = torch.topk(x.flatten(start_dim=-1), k=k, dim=-1)
-                values = torch.relu(topk.values)
-                result = torch.zeros_like(x.flatten(start_dim=-1))
-                result.scatter_(-1, topk.indices, values)
-                return result.view_as(x)
-
-            return topk_fn
-        raise ValueError(f"Unknown activation function: {activation_fn}")
+        return nn.ReLU()

     def _setup_activation_normalization(self):
         """Set up activation normalization functions based on config."""
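The string-keyed activation factory ("relu", "tanh-relu", "topk") is removed in this hunk; the base class now simply returns ReLU, and each architecture overrides get_activation_fn itself (TopK behaviour lives in topk_sae.py per the file list above). Purely as an illustration, the removed "tanh-relu" option could be recreated as a small module that a concrete SAE subclass returns from its own get_activation_fn override:

import torch
from torch import nn


class TanhReLU(nn.Module):
    """tanh(relu(x)); equivalent to the removed "tanh-relu" string option."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(torch.relu(x))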
@@ -264,10 +264,20 @@ class SAE(HookedRootModule, ABC):
             self.run_time_activation_norm_fn_in = lambda x: x
             self.run_time_activation_norm_fn_out = lambda x: x

-    @abstractmethod
     def initialize_weights(self):
         """Initialize model weights."""
-
+        self.b_dec = nn.Parameter(
+            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+        )
+
+        w_dec_data = torch.empty(
+            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+        )
+        nn.init.kaiming_uniform_(w_dec_data)
+        self.W_dec = nn.Parameter(w_dec_data)
+
+        w_enc_data = self.W_dec.data.T.clone().detach().contiguous()
+        self.W_enc = nn.Parameter(w_enc_data)

     @abstractmethod
     def encode(
@@ -284,7 +294,10 @@ class SAE(HookedRootModule, ABC):
         pass

     def turn_on_forward_pass_hook_z_reshaping(self):
-        if
+        if (
+            self.cfg.metadata.hook_name is not None
+            and not self.cfg.metadata.hook_name.endswith("_z")
+        ):
             raise ValueError("This method should only be called for hook_z SAEs.")

         # print(f"Turning on hook_z reshaping for {self.cfg.hook_name}")
@@ -313,19 +326,19 @@ class SAE(HookedRootModule, ABC):

     @overload
     def to(
-        self:
+        self: T_SAE,
         device: torch.device | str | None = ...,
         dtype: torch.dtype | None = ...,
         non_blocking: bool = ...,
-    ) ->
+    ) -> T_SAE: ...

     @overload
-    def to(self:
+    def to(self: T_SAE, dtype: torch.dtype, non_blocking: bool = ...) -> T_SAE: ...

     @overload
-    def to(self:
+    def to(self: T_SAE, tensor: torch.Tensor, non_blocking: bool = ...) -> T_SAE: ...

-    def to(self:
+    def to(self: T_SAE, *args: Any, **kwargs: Any) -> T_SAE:  # type: ignore
         device_arg = None
         dtype_arg = None
@@ -426,15 +439,12 @@ class SAE(HookedRootModule, ABC):

     def get_name(self):
         """Generate a name for this SAE."""
-        return f"sae_{self.cfg.model_name}_{self.cfg.hook_name}_{self.cfg.d_sae}"
+        return f"sae_{self.cfg.metadata.model_name}_{self.cfg.metadata.hook_name}_{self.cfg.d_sae}"

-    def save_model(
-
-    ) -> tuple[Path, Path, Path | None]:
-        """Save model weights, config, and optional sparsity tensor to disk."""
+    def save_model(self, path: str | Path) -> tuple[Path, Path]:
+        """Save model weights and config to disk."""
         path = Path(path)
-
-        path.mkdir(parents=True)
+        path.mkdir(parents=True, exist_ok=True)

         # Generate the weights
         state_dict = self.state_dict()  # Use internal SAE state dict
@@ -448,13 +458,7 @@ class SAE(HookedRootModule, ABC):
         with open(cfg_path, "w") as f:
             json.dump(config, f)

-
-            sparsity_in_dict = {"sparsity": sparsity}
-            sparsity_path = path / SPARSITY_FILENAME
-            save_file(sparsity_in_dict, sparsity_path)
-            return model_weights_path, cfg_path, sparsity_path
-
-        return model_weights_path, cfg_path, None
+        return model_weights_path, cfg_path

     ## Initialization Methods
     @torch.no_grad()
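save_model no longer accepts or returns a sparsity tensor: it creates the target directory if needed and writes exactly two files, the weights and the config JSON. A round-trip sketch follows; StandardSAE and StandardSAEConfig are assumed to be the concrete classes provided by sae_lens/saes/standard_sae.py (see the file list above) and may be named differently.

from pathlib import Path

from sae_lens.saes.sae import SAE
from sae_lens.saes.standard_sae import StandardSAE, StandardSAEConfig  # assumed names

sae = StandardSAE(StandardSAEConfig(d_in=768, d_sae=16384))
weights_path, cfg_path = sae.save_model(Path("out/my_sae"))  # no sparsity path anymore
reloaded = SAE.load_from_disk(Path("out/my_sae"), device="cpu")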
@@ -482,18 +486,21 @@ class SAE(HookedRootModule, ABC):
     @classmethod
     @deprecated("Use load_from_disk instead")
     def load_from_pretrained(
-        cls: Type[
-
+        cls: Type[T_SAE],
+        path: str | Path,
+        device: str = "cpu",
+        dtype: str | None = None,
+    ) -> T_SAE:
         return cls.load_from_disk(path, device=device, dtype=dtype)

     @classmethod
     def load_from_disk(
-        cls: Type[
+        cls: Type[T_SAE],
         path: str | Path,
         device: str = "cpu",
         dtype: str | None = None,
         converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
-    ) ->
+    ) -> T_SAE:
         overrides = {"dtype": dtype} if dtype is not None else None
         cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
         cfg_dict = handle_config_defaulting(cfg_dict)
@@ -501,7 +508,7 @@ class SAE(HookedRootModule, ABC):
             cfg_dict["architecture"]
         )
         sae_cfg = sae_config_cls.from_dict(cfg_dict)
-        sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture)
+        sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
         sae = sae_cls(sae_cfg)
         sae.process_state_dict_for_loading(state_dict)
         sae.load_state_dict(state_dict)
@@ -509,13 +516,13 @@ class SAE(HookedRootModule, ABC):

     @classmethod
     def from_pretrained(
-        cls,
+        cls: Type[T_SAE],
         release: str,
         sae_id: str,
         device: str = "cpu",
         force_download: bool = False,
         converter: PretrainedSaeHuggingfaceLoader | None = None,
-    ) -> tuple[
+    ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
         """
         Load a pretrained SAE from the Hugging Face model hub.
@@ -584,47 +591,31 @@ class SAE(HookedRootModule, ABC):
         )
         cfg_dict = handle_config_defaulting(cfg_dict)

-        # Rename keys to match SAEConfig field names
-        renamed_cfg_dict = {}
-        rename_map = {
-            "hook_point": "hook_name",
-            "hook_point_layer": "hook_layer",
-            "hook_point_head_index": "hook_head_index",
-            "activation_fn": "activation_fn",
-        }
-
-        for k, v in cfg_dict.items():
-            renamed_cfg_dict[rename_map.get(k, k)] = v
-
-        # Set default values for required fields
-        renamed_cfg_dict.setdefault("activation_fn_kwargs", {})
-        renamed_cfg_dict.setdefault("seqpos_slice", None)
-
         # Create SAE with appropriate architecture
         sae_config_cls = cls.get_sae_config_class_for_architecture(
-
+            cfg_dict["architecture"]
         )
-        sae_cfg = sae_config_cls.from_dict(
-        sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture)
+        sae_cfg = sae_config_cls.from_dict(cfg_dict)
+        sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
         sae = sae_cls(sae_cfg)
         sae.process_state_dict_for_loading(state_dict)
         sae.load_state_dict(state_dict)

         # Apply normalization if needed
-        if
+        if cfg_dict.get("normalize_activations") == "expected_average_only_in":
             norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
             if norm_scaling_factor is not None:
                 sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
-
+                cfg_dict["normalize_activations"] = "none"
             else:
                 warnings.warn(
                     f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
                 )

-        return sae,
+        return sae, cfg_dict, log_sparsities

     @classmethod
-    def from_dict(cls: Type[
+    def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
         """Create an SAE from a config dictionary."""
         sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
         sae_config_cls = cls.get_sae_config_class_for_architecture(
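from_pretrained is now typed against the calling class and returns the SAE together with its raw config dict and optional log-sparsity tensor, without the old key-renaming shim. A usage sketch with placeholder identifiers; pick real values from get_pretrained_saes_directory().

from sae_lens.saes.sae import SAE

# "<release>" and "<sae-id>" are placeholders, not real directory entries.
sae, cfg_dict, log_sparsities = SAE.from_pretrained(
    release="<release>",
    sae_id="<sae-id>",
    device="cpu",
)
print(type(sae).__name__, cfg_dict["architecture"])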
@@ -633,9 +624,11 @@ class SAE(HookedRootModule, ABC):
         return sae_cls(sae_config_cls.from_dict(config_dict))

     @classmethod
-    def get_sae_class_for_architecture(
+    def get_sae_class_for_architecture(
+        cls: Type[T_SAE], architecture: str
+    ) -> Type[T_SAE]:
         """Get the SAE class for a given architecture."""
-        sae_cls = get_sae_class(architecture)
+        sae_cls, _ = get_sae_class(architecture)
         if not issubclass(sae_cls, cls):
             raise ValueError(
                 f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
@@ -645,161 +638,105 @@ class SAE(HookedRootModule, ABC):
     # in the future, this can be used to load different config classes for different architectures
     @classmethod
     def get_sae_config_class_for_architecture(
-        cls
+        cls,
         architecture: str,  # noqa: ARG003
     ) -> type[SAEConfig]:
         return SAEConfig


 @dataclass(kw_only=True)
-class TrainingSAEConfig(SAEConfig):
-
-
-
-
-
-
-
-
-
-    jumprelu_bandwidth: float
-    decoder_heuristic_init: bool
-    init_encoder_as_decoder_transpose: bool
-    scale_sparsity_penalty_by_decoder_norm: bool
+class TrainingSAEConfig(SAEConfig, ABC):
+    noise_scale: float = 0.0
+    mse_loss_normalization: str | None = None
+    # https://transformer-circuits.pub/2024/april-update/index.html#training-saes
+    # 0.1 corresponds to the "heuristic" initialization, use None to disable
+    decoder_init_norm: float | None = 0.1
+
+    @classmethod
+    @abstractmethod
+    def architecture(cls) -> str: ...

     @classmethod
     def from_sae_runner_config(
-        cls
-
-
-
-            architecture=cfg.architecture,
-            d_in=cfg.d_in,
-            d_sae=cfg.d_sae,  # type: ignore
-            dtype=cfg.dtype,
-            device=cfg.device,
+        cls: type[T_TRAINING_SAE_CONFIG],
+        cfg: "LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]",
+    ) -> T_TRAINING_SAE_CONFIG:
+        metadata = SAEMetadata(
             model_name=cfg.model_name,
             hook_name=cfg.hook_name,
-            hook_layer=cfg.hook_layer,
             hook_head_index=cfg.hook_head_index,
-            activation_fn=cfg.activation_fn,
-            activation_fn_kwargs=cfg.activation_fn_kwargs,
-            apply_b_dec_to_input=cfg.apply_b_dec_to_input,
-            finetuning_scaling_factor=cfg.finetuning_method is not None,
-            sae_lens_training_version=cfg.sae_lens_training_version,
             context_size=cfg.context_size,
-            dataset_path=cfg.dataset_path,
             prepend_bos=cfg.prepend_bos,
-            seqpos_slice=
-            if cfg.seqpos_slice is not None
-            else None,
-            # Training cfg
-            l1_coefficient=cfg.l1_coefficient,
-            lp_norm=cfg.lp_norm,
-            use_ghost_grads=cfg.use_ghost_grads,
-            normalize_sae_decoder=cfg.normalize_sae_decoder,
-            noise_scale=cfg.noise_scale,
-            decoder_orthogonal_init=cfg.decoder_orthogonal_init,
-            mse_loss_normalization=cfg.mse_loss_normalization,
-            decoder_heuristic_init=cfg.decoder_heuristic_init,
-            init_encoder_as_decoder_transpose=cfg.init_encoder_as_decoder_transpose,
-            scale_sparsity_penalty_by_decoder_norm=cfg.scale_sparsity_penalty_by_decoder_norm,
-            normalize_activations=cfg.normalize_activations,
-            dataset_trust_remote_code=cfg.dataset_trust_remote_code,
+            seqpos_slice=cfg.seqpos_slice,
             model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {},
-            jumprelu_init_threshold=cfg.jumprelu_init_threshold,
-            jumprelu_bandwidth=cfg.jumprelu_bandwidth,
         )
+        if not isinstance(cfg.sae, cls):
+            raise ValueError(
+                f"SAE config class {cls} does not match SAE runner config class {type(cfg.sae)}"
+            )
+        return replace(cfg.sae, metadata=metadata)

     @classmethod
-    def from_dict(
+    def from_dict(
+        cls: type[T_TRAINING_SAE_CONFIG], config_dict: dict[str, Any]
+    ) -> T_TRAINING_SAE_CONFIG:
         # remove any keys that are not in the dataclass
         # since we sometimes enhance the config with the whole LM runner config
-
-
-
-
-
-
-
-
-        if "
-
-
-                valid_config_dict["seqpos_slice"]
-            )
-        elif not isinstance(valid_config_dict["seqpos_slice"], tuple):
-            valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],)
-
-        return TrainingSAEConfig(**valid_config_dict)
+        valid_config_dict = filter_valid_dataclass_fields(config_dict, cls)
+        cfg_class = cls
+        if "architecture" in config_dict:
+            cfg_class = get_sae_training_class(config_dict["architecture"])[1]
+            if not issubclass(cfg_class, cls):
+                raise ValueError(
+                    f"SAE config class {cls} does not match dict config class {type(cfg_class)}"
+                )
+        if "metadata" in config_dict:
+            valid_config_dict["metadata"] = SAEMetadata(**config_dict["metadata"])
+        return cfg_class(**valid_config_dict)

     def to_dict(self) -> dict[str, Any]:
         return {
             **super().to_dict(),
-
-            "
-            "use_ghost_grads": self.use_ghost_grads,
-            "normalize_sae_decoder": self.normalize_sae_decoder,
-            "noise_scale": self.noise_scale,
-            "decoder_orthogonal_init": self.decoder_orthogonal_init,
-            "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
-            "mse_loss_normalization": self.mse_loss_normalization,
-            "decoder_heuristic_init": self.decoder_heuristic_init,
-            "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
-            "normalize_activations": self.normalize_activations,
-            "jumprelu_init_threshold": self.jumprelu_init_threshold,
-            "jumprelu_bandwidth": self.jumprelu_bandwidth,
+            **asdict(self),
+            "architecture": self.architecture(),
         }

     # this needs to exist so we can initialize the parent sae cfg without the training specific
     # parameters. Maybe there's a cleaner way to do this
     def get_base_sae_cfg_dict(self) -> dict[str, Any]:
-
-
-
-
-
-
-
-
-            "model_name": self.model_name,
-            "hook_name": self.hook_name,
-            "hook_layer": self.hook_layer,
-            "hook_head_index": self.hook_head_index,
-            "device": self.device,
-            "context_size": self.context_size,
-            "prepend_bos": self.prepend_bos,
-            "finetuning_scaling_factor": self.finetuning_scaling_factor,
-            "normalize_activations": self.normalize_activations,
-            "dataset_path": self.dataset_path,
-            "dataset_trust_remote_code": self.dataset_trust_remote_code,
-            "sae_lens_training_version": self.sae_lens_training_version,
-            "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
-            "seqpos_slice": self.seqpos_slice,
-            "neuronpedia_id": self.neuronpedia_id,
+        """
+        Creates a dictionary containing attributes corresponding to the fields
+        defined in the base SAEConfig class.
+        """
+        base_config_field_names = {f.name for f in fields(SAEConfig)}
+        result_dict = {
+            field_name: getattr(self, field_name)
+            for field_name in base_config_field_names
         }
+        result_dict["architecture"] = self.architecture()
+        return result_dict


-class TrainingSAE(SAE, ABC):
+class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
     """Abstract base class for training versions of SAEs."""

-    cfg:
-
-    def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
+    def __init__(self, cfg: T_TRAINING_SAE_CONFIG, use_error_term: bool = False):
         super().__init__(cfg, use_error_term)

         # Turn off hook_z reshaping for training mode - the activation store
         # is expected to handle reshaping before passing data to the SAE
         self.turn_off_forward_pass_hook_z_reshaping()
-
         self.mse_loss_fn = self._get_mse_loss_fn()

+    @abstractmethod
+    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...
+
     @abstractmethod
     def encode_with_hidden_pre(
         self, x: Float[torch.Tensor, "... d_in"]
     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
         """Encode with access to pre-activation values for training."""
-
+        ...

     def encode(
         self, x: Float[torch.Tensor, "... d_in"]
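TrainingSAEConfig is likewise abstract now: the long list of architecture-specific knobs is gone, only noise_scale, mse_loss_normalization, and decoder_init_norm remain on the base, and from_dict dispatches to the registered per-architecture config class. A minimal sketch of a custom subclass under this interface; the class name and the extra field are hypothetical.

from dataclasses import dataclass

from sae_lens.saes.sae import TrainingSAEConfig


@dataclass(kw_only=True)
class MyTrainingSAEConfig(TrainingSAEConfig):
    my_sparsity_coefficient: float = 1.0  # hypothetical extra knob

    @classmethod
    def architecture(cls) -> str:
        return "my_custom"


cfg = MyTrainingSAEConfig(d_in=768, d_sae=16384)
assert cfg.decoder_init_norm == 0.1  # new default "heuristic" init scale
assert cfg.get_base_sae_cfg_dict()["architecture"] == "my_custom"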
@@ -818,12 +755,20 @@ class TrainingSAE(SAE, ABC):
         Decodes feature activations back into input space,
         applying optional finetuning scale, hooking, out normalization, etc.
         """
-
-        sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
         sae_out_pre = self.hook_sae_recons(sae_out_pre)
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
         return self.reshape_fn_out(sae_out_pre, self.d_head)

+    @override
+    def initialize_weights(self):
+        super().initialize_weights()
+        if self.cfg.decoder_init_norm is not None:
+            with torch.no_grad():
+                self.W_dec.data /= self.W_dec.norm(dim=-1, keepdim=True)
+                self.W_dec.data *= self.cfg.decoder_init_norm
+                self.W_enc.data = self.W_dec.data.T.clone().detach().contiguous()
+
     @abstractmethod
     def calculate_aux_loss(
         self,
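With decoder_init_norm, the old decoder_heuristic_init / init_encoder_as_decoder_transpose pair collapses into one rule: after the base Kaiming init, decoder rows are rescaled to the configured norm and the encoder is re-tied to the decoder transpose. A standalone sketch of that rule (the shapes and the 0.1 value are illustrative):

import torch
from torch import nn

d_sae, d_in, decoder_init_norm = 16384, 768, 0.1

W_dec = torch.empty(d_sae, d_in)
nn.init.kaiming_uniform_(W_dec)
W_dec /= W_dec.norm(dim=-1, keepdim=True)  # unit-norm decoder rows...
W_dec *= decoder_init_norm                 # ...rescaled to the "heuristic" 0.1
W_enc = W_dec.T.clone().contiguous()       # encoder tied to decoder transpose

assert torch.allclose(W_dec.norm(dim=-1), torch.tensor(decoder_init_norm), atol=1e-5)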
@@ -833,7 +778,7 @@ class TrainingSAE(SAE, ABC):
         sae_out: torch.Tensor,
     ) -> torch.Tensor | dict[str, torch.Tensor]:
         """Calculate architecture-specific auxiliary loss terms."""
-
+        ...

     def training_forward_pass(
         self,
@@ -883,6 +828,39 @@ class TrainingSAE(SAE, ABC):
             losses=losses,
         )

+    def save_inference_model(self, path: str | Path) -> tuple[Path, Path]:
+        """Save inference version of model weights and config to disk."""
+        path = Path(path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        # Generate the weights
+        state_dict = self.state_dict()  # Use internal SAE state dict
+        self.process_state_dict_for_saving_inference(state_dict)
+        model_weights_path = path / SAE_WEIGHTS_FILENAME
+        save_file(state_dict, model_weights_path)
+
+        # Save the config
+        config = self.to_inference_config_dict()
+        cfg_path = path / SAE_CFG_FILENAME
+        with open(cfg_path, "w") as f:
+            json.dump(config, f)
+
+        return model_weights_path, cfg_path
+
+    @abstractmethod
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        """Convert the config into an inference SAE config dict."""
+        ...
+
+    def process_state_dict_for_saving_inference(
+        self, state_dict: dict[str, Any]
+    ) -> None:
+        """
+        Process the state dict for saving the inference model.
+        This is a hook that can be overridden to change how the state dict is processed for the inference model.
+        """
+        return self.process_state_dict_for_saving(state_dict)
+
     def _get_mse_loss_fn(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
         """Get the MSE loss function based on config."""
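Training SAEs now separate the training checkpoint from the exported inference artifact: save_inference_model runs the state dict through process_state_dict_for_saving_inference and writes the config produced by the architecture's to_inference_config_dict, so the result loads back as a plain inference SAE. A sketch follows; StandardTrainingSAE and StandardTrainingSAEConfig are assumed names for the concrete classes in sae_lens/saes/standard_sae.py and may differ.

from pathlib import Path

from sae_lens.saes.sae import SAE
from sae_lens.saes.standard_sae import (  # assumed names
    StandardTrainingSAE,
    StandardTrainingSAEConfig,
)

training_sae = StandardTrainingSAE(StandardTrainingSAEConfig(d_in=768, d_sae=16384))
weights_path, cfg_path = training_sae.save_inference_model(Path("out/inference_sae"))
inference_sae = SAE.load_from_disk(Path("out/inference_sae"), device="cpu")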
@@ -921,11 +899,6 @@ class TrainingSAE(SAE, ABC):
             "d_sae, d_sae d_in -> d_sae d_in",
         )

-    @torch.no_grad()
-    def set_decoder_norm_to_unit_norm(self):
-        """Normalize decoder columns to unit norm."""
-        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-
     @torch.no_grad()
     def log_histograms(self) -> dict[str, NDArray[Any]]:
         """Log histograms of the weights and biases."""
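set_decoder_norm_to_unit_norm is deleted from the base training class; decoder norm handling now happens at init via decoder_init_norm, and any architecture that still needs it keeps its own copy. Code that relied on the removed helper can inline the same logic as a free function; a sketch:

import torch

from sae_lens.saes.sae import TrainingSAE


@torch.no_grad()
def set_decoder_norm_to_unit_norm(sae: TrainingSAE) -> None:
    """Same body as the helper removed above, applied externally."""
    sae.W_dec.data /= torch.norm(sae.W_dec.data, dim=1, keepdim=True)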
@@ -935,9 +908,11 @@ class TrainingSAE(SAE, ABC):
         }

     @classmethod
-    def get_sae_class_for_architecture(
+    def get_sae_class_for_architecture(
+        cls: Type[T_TRAINING_SAE], architecture: str
+    ) -> Type[T_TRAINING_SAE]:
         """Get the SAE class for a given architecture."""
-        sae_cls = get_sae_training_class(architecture)
+        sae_cls, _ = get_sae_training_class(architecture)
         if not issubclass(sae_cls, cls):
             raise ValueError(
                 f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
@@ -947,17 +922,17 @@ class TrainingSAE(SAE, ABC):
     # in the future, this can be used to load different config classes for different architectures
     @classmethod
     def get_sae_config_class_for_architecture(
-        cls
+        cls,
         architecture: str,  # noqa: ARG003
-    ) -> type[
-        return
+    ) -> type[TrainingSAEConfig]:
+        return get_sae_training_class(architecture)[1]


 _blank_hook = nn.Identity()


 @contextmanager
-def _disable_hooks(sae: SAE):
+def _disable_hooks(sae: SAE[Any]):
     """
     Temporarily disable hooks for the SAE. Swaps out all the hooks with a fake modules that does nothing.
     """