sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl
- sae_lens/__init__.py +50 -16
- sae_lens/analysis/hooked_sae_transformer.py +10 -10
- sae_lens/analysis/neuronpedia_integration.py +13 -11
- sae_lens/cache_activations_runner.py +2 -1
- sae_lens/config.py +59 -231
- sae_lens/constants.py +18 -0
- sae_lens/evals.py +16 -13
- sae_lens/loading/pretrained_sae_loaders.py +36 -3
- sae_lens/registry.py +49 -0
- sae_lens/sae_training_runner.py +22 -21
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +70 -59
- sae_lens/saes/jumprelu_sae.py +58 -72
- sae_lens/saes/sae.py +250 -272
- sae_lens/saes/standard_sae.py +75 -57
- sae_lens/saes/topk_sae.py +72 -83
- sae_lens/training/activations_store.py +31 -15
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +44 -69
- sae_lens/training/upload_saes_to_huggingface.py +11 -5
- sae_lens/util.py +28 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc2.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc2.dist-info/RECORD +35 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc2.dist-info}/WHEEL +1 -1
- sae_lens/regsitry.py +0 -34
- sae_lens-6.0.0rc1.dist-info/RECORD +0 -32
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc2.dist-info}/LICENSE +0 -0
sae_lens/saes/sae.py
CHANGED
@@ -4,9 +4,18 @@ import json
 import warnings
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from dataclasses import dataclass, field, fields
+from dataclasses import asdict, dataclass, field, fields, replace
 from pathlib import Path
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Generic,
+    Literal,
+    NamedTuple,
+    Type,
+    TypeVar,
+)

 import einops
 import torch
@@ -15,16 +24,19 @@ from numpy.typing import NDArray
 from safetensors.torch import save_file
 from torch import nn
 from transformer_lens.hook_points import HookedRootModule, HookPoint
-from typing_extensions import deprecated, overload
+from typing_extensions import deprecated, overload, override

-from sae_lens import logger
-from sae_lens.
+from sae_lens import __version__, logger
+from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
-    SPARSITY_FILENAME,
-    LanguageModelSAERunnerConfig,
 )
+from sae_lens.util import filter_valid_dataclass_fields
+
+if TYPE_CHECKING:
+    from sae_lens.config import LanguageModelSAERunnerConfig
+
 from sae_lens.loading.pretrained_sae_loaders import (
     NAMED_PRETRAINED_SAE_LOADERS,
     PretrainedSaeDiskLoader,
@@ -39,57 +51,82 @@ from sae_lens.loading.pretrained_saes_directory import (
     get_pretrained_saes_directory,
     get_repo_id_and_folder_name,
 )
-from sae_lens.
+from sae_lens.registry import get_sae_class, get_sae_training_class

-
+T_SAE_CONFIG = TypeVar("T_SAE_CONFIG", bound="SAEConfig")
+T_TRAINING_SAE_CONFIG = TypeVar("T_TRAINING_SAE_CONFIG", bound="TrainingSAEConfig")
+T_SAE = TypeVar("T_SAE", bound="SAE")  # type: ignore
+T_TRAINING_SAE = TypeVar("T_TRAINING_SAE", bound="TrainingSAE")  # type: ignore
+
+
+@dataclass
+class SAEMetadata:
+    """Core metadata about how this SAE should be used, if known."""
+
+    model_name: str | None = None
+    hook_name: str | None = None
+    model_class_name: str | None = None
+    hook_layer: int | None = None
+    hook_head_index: int | None = None
+    model_from_pretrained_kwargs: dict[str, Any] | None = None
+    prepend_bos: bool | None = None
+    exclude_special_tokens: bool | list[int] | None = None
+    neuronpedia_id: str | None = None
+    context_size: int | None = None
+    seqpos_slice: tuple[int | None, ...] | None = None
+    dataset_path: str | None = None
+    sae_lens_version: str = field(default_factory=lambda: __version__)
+    sae_lens_training_version: str = field(default_factory=lambda: __version__)


 @dataclass
-class SAEConfig:
+class SAEConfig(ABC):
     """Base configuration for SAE models."""

-    architecture: str
     d_in: int
     d_sae: int
-    dtype: str
-    device: str
-
-
-
-
-
-
-    apply_b_dec_to_input: bool
-    finetuning_scaling_factor: bool
-    normalize_activations: str
-    context_size: int
-    dataset_path: str
-    dataset_trust_remote_code: bool
-    sae_lens_training_version: str
-    model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
-    seqpos_slice: tuple[int, ...] | None = None
-    prepend_bos: bool = False
-    neuronpedia_id: str | None = None
-
-    def to_dict(self) -> dict[str, Any]:
-        return {field.name: getattr(self, field.name) for field in fields(self)}
+    dtype: str = "float32"
+    device: str = "cpu"
+    apply_b_dec_to_input: bool = True
+    normalize_activations: Literal[
+        "none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"
+    ] = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
+    reshape_activations: Literal["none", "hook_z"] = "none"
+    metadata: SAEMetadata = field(default_factory=SAEMetadata)

     @classmethod
-
-
-        valid_config_dict = {
-            key: val for key, val in config_dict.items() if key in valid_field_names
-        }
+    @abstractmethod
+    def architecture(cls) -> str: ...

-
-
-
-
-
-        ):
-            valid_config_dict["seqpos_slice"] = tuple(valid_config_dict["seqpos_slice"])
+    def to_dict(self) -> dict[str, Any]:
+        res = {field.name: getattr(self, field.name) for field in fields(self)}
+        res["metadata"] = asdict(self.metadata)
+        res["architecture"] = self.architecture()
+        return res

-
+    @classmethod
+    def from_dict(cls: type[T_SAE_CONFIG], config_dict: dict[str, Any]) -> T_SAE_CONFIG:
+        cfg_class = get_sae_class(config_dict["architecture"])[1]
+        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
+        res = cfg_class(**filtered_config_dict)
+        if "metadata" in config_dict:
+            res.metadata = SAEMetadata(**config_dict["metadata"])
+        if not isinstance(res, cls):
+            raise ValueError(
+                f"SAE config class {cls} does not match dict config class {type(res)}"
+            )
+        return res
+
+    def __post_init__(self):
+        if self.normalize_activations not in [
+            "none",
+            "expected_average_only_in",
+            "constant_norm_rescale",
+            "layer_norm",
+        ]:
+            raise ValueError(
+                f"normalize_activations must be none, expected_average_only_in, constant_norm_rescale, or layer_norm. Got {self.normalize_activations}"
+            )


 @dataclass
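The hunks above replace the flat model/hook fields on `SAEConfig` with a nested `SAEMetadata` dataclass and turn `SAEConfig` into an abstract base whose `from_dict` dispatches on the `"architecture"` key through the new registry. A minimal sketch of the resulting round-trip; the `"standard"` architecture name and the example dimensions are assumptions, not taken from this diff:

```python
# Hypothetical round-trip with the rc2 config API shown above. The "standard"
# architecture name is assumed to be registered; d_in/d_sae are example values.
from sae_lens.saes.sae import SAEConfig, SAEMetadata

cfg_dict = {
    "architecture": "standard",  # from_dict uses this key to pick the concrete config class
    "d_in": 768,
    "d_sae": 16384,
    "metadata": {"model_name": "gpt2", "hook_name": "blocks.8.hook_resid_pre"},
}

cfg = SAEConfig.from_dict(cfg_dict)           # returns the registered concrete subclass
assert isinstance(cfg.metadata, SAEMetadata)  # model/hook info now lives on cfg.metadata
restored = cfg.to_dict()                      # re-adds "architecture" and a nested "metadata" dict
```

Note that `dtype`, `device`, `apply_b_dec_to_input`, and `normalize_activations` now carry defaults, so only `d_in` and `d_sae` are required at the base level.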
@@ -109,14 +146,19 @@ class TrainStepInput:
     """Input to a training step."""

     sae_in: torch.Tensor
-
+    coefficients: dict[str, float]
     dead_neuron_mask: torch.Tensor | None


-class
+class TrainCoefficientConfig(NamedTuple):
+    value: float
+    warm_up_steps: int
+
+
+class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     """Abstract base class for all SAE architectures."""

-    cfg:
+    cfg: T_SAE_CONFIG
     dtype: torch.dtype
     device: torch.device
     use_error_term: bool
@@ -127,13 +169,13 @@ class SAE(HookedRootModule, ABC):
     W_dec: nn.Parameter
     b_dec: nn.Parameter

-    def __init__(self, cfg:
+    def __init__(self, cfg: T_SAE_CONFIG, use_error_term: bool = False):
         """Initialize the SAE."""
         super().__init__()

         self.cfg = cfg

-        if cfg.model_from_pretrained_kwargs:
+        if cfg.metadata and cfg.metadata.model_from_pretrained_kwargs:
             warnings.warn(
                 "\nThis SAE has non-empty model_from_pretrained_kwargs. "
                 "\nFor optimal performance, load the model like so:\n"
@@ -147,19 +189,11 @@ class SAE(HookedRootModule, ABC):
         self.use_error_term = use_error_term

         # Set up activation function
-        self.activation_fn = self.
+        self.activation_fn = self.get_activation_fn()

         # Initialize weights
         self.initialize_weights()

-        # Handle presence / absence of scaling factor
-        if self.cfg.finetuning_scaling_factor:
-            self.apply_finetuning_scaling_factor = (
-                lambda x: x * self.finetuning_scaling_factor
-            )
-        else:
-            self.apply_finetuning_scaling_factor = lambda x: x
-
         # Set up hooks
         self.hook_sae_input = HookPoint()
         self.hook_sae_acts_pre = HookPoint()
@@ -169,11 +203,9 @@ class SAE(HookedRootModule, ABC):
         self.hook_sae_error = HookPoint()

         # handle hook_z reshaping if needed.
-        if self.cfg.
-            # print(f"Setting up hook_z reshaping for {self.cfg.hook_name}")
+        if self.cfg.reshape_activations == "hook_z":
             self.turn_on_forward_pass_hook_z_reshaping()
         else:
-            # print(f"No hook_z reshaping needed for {self.cfg.hook_name}")
             self.turn_off_forward_pass_hook_z_reshaping()

         # Set up activation normalization
@@ -188,40 +220,9 @@ class SAE(HookedRootModule, ABC):
             self.b_dec.data /= scaling_factor  # type: ignore
         self.cfg.normalize_activations = "none"

-    def
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
         """Get the activation function specified in config."""
-        return
-            self.cfg.activation_fn, **(self.cfg.activation_fn_kwargs or {})
-        )
-
-    @staticmethod
-    def _get_activation_fn_static(
-        activation_fn: str, **kwargs: Any
-    ) -> Callable[[torch.Tensor], torch.Tensor]:
-        """Get the activation function from a string specification."""
-        if activation_fn == "relu":
-            return torch.nn.ReLU()
-        if activation_fn == "tanh-relu":
-
-            def tanh_relu(input: torch.Tensor) -> torch.Tensor:
-                input = torch.relu(input)
-                return torch.tanh(input)
-
-            return tanh_relu
-        if activation_fn == "topk":
-            if "k" not in kwargs:
-                raise ValueError("TopK activation function requires a k value.")
-            k = kwargs.get("k", 1)  # Default k to 1 if not provided
-
-            def topk_fn(x: torch.Tensor) -> torch.Tensor:
-                topk = torch.topk(x.flatten(start_dim=-1), k=k, dim=-1)
-                values = torch.relu(topk.values)
-                result = torch.zeros_like(x.flatten(start_dim=-1))
-                result.scatter_(-1, topk.indices, values)
-                return result.view_as(x)
-
-            return topk_fn
-        raise ValueError(f"Unknown activation function: {activation_fn}")
+        return nn.ReLU()

     def _setup_activation_normalization(self):
         """Set up activation normalization functions based on config."""
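With this change the string-driven activation lookup (`"relu"`, `"tanh-relu"`, `"topk"` plus `activation_fn_kwargs`) is gone: the base `get_activation_fn()` simply returns `nn.ReLU()`, and architecture subclasses (for example the TopK SAE in `sae_lens/saes/topk_sae.py` from the file list) supply their own activation. For reference, the removed TopK helper is reproduced below as a standalone function; the function name is ours, not part of the package:

```python
# Standalone reproduction of the TopK activation removed from SAE's generic
# string lookup (logic copied from the deleted _get_activation_fn_static).
import torch


def topk_activation(x: torch.Tensor, k: int) -> torch.Tensor:
    topk = torch.topk(x.flatten(start_dim=-1), k=k, dim=-1)
    values = torch.relu(topk.values)           # keep only non-negative top-k values
    result = torch.zeros_like(x.flatten(start_dim=-1))
    result.scatter_(-1, topk.indices, values)  # scatter them back into a zero tensor
    return result.view_as(x)
```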
@@ -264,10 +265,20 @@ class SAE(HookedRootModule, ABC):
             self.run_time_activation_norm_fn_in = lambda x: x
             self.run_time_activation_norm_fn_out = lambda x: x

-    @abstractmethod
     def initialize_weights(self):
         """Initialize model weights."""
-
+        self.b_dec = nn.Parameter(
+            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+        )
+
+        w_dec_data = torch.empty(
+            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+        )
+        nn.init.kaiming_uniform_(w_dec_data)
+        self.W_dec = nn.Parameter(w_dec_data)
+
+        w_enc_data = self.W_dec.data.T.clone().detach().contiguous()
+        self.W_enc = nn.Parameter(w_enc_data)

     @abstractmethod
     def encode(
@@ -284,7 +295,10 @@ class SAE(HookedRootModule, ABC):
         pass

     def turn_on_forward_pass_hook_z_reshaping(self):
-        if
+        if (
+            self.cfg.metadata.hook_name is not None
+            and not self.cfg.metadata.hook_name.endswith("_z")
+        ):
             raise ValueError("This method should only be called for hook_z SAEs.")

         # print(f"Turning on hook_z reshaping for {self.cfg.hook_name}")
@@ -313,19 +327,19 @@ class SAE(HookedRootModule, ABC):

     @overload
     def to(
-        self:
+        self: T_SAE,
         device: torch.device | str | None = ...,
         dtype: torch.dtype | None = ...,
         non_blocking: bool = ...,
-    ) ->
+    ) -> T_SAE: ...

     @overload
-    def to(self:
+    def to(self: T_SAE, dtype: torch.dtype, non_blocking: bool = ...) -> T_SAE: ...

     @overload
-    def to(self:
+    def to(self: T_SAE, tensor: torch.Tensor, non_blocking: bool = ...) -> T_SAE: ...

-    def to(self:
+    def to(self: T_SAE, *args: Any, **kwargs: Any) -> T_SAE:  # type: ignore
         device_arg = None
         dtype_arg = None

@@ -426,15 +440,12 @@ class SAE(HookedRootModule, ABC):

     def get_name(self):
         """Generate a name for this SAE."""
-        return f"sae_{self.cfg.model_name}_{self.cfg.hook_name}_{self.cfg.d_sae}"
+        return f"sae_{self.cfg.metadata.model_name}_{self.cfg.metadata.hook_name}_{self.cfg.d_sae}"

-    def save_model(
-
-    ) -> tuple[Path, Path, Path | None]:
-        """Save model weights, config, and optional sparsity tensor to disk."""
+    def save_model(self, path: str | Path) -> tuple[Path, Path]:
+        """Save model weights and config to disk."""
         path = Path(path)
-
-        path.mkdir(parents=True)
+        path.mkdir(parents=True, exist_ok=True)

         # Generate the weights
         state_dict = self.state_dict()  # Use internal SAE state dict
@@ -448,13 +459,7 @@ class SAE(HookedRootModule, ABC):
         with open(cfg_path, "w") as f:
             json.dump(config, f)

-
-            sparsity_in_dict = {"sparsity": sparsity}
-            sparsity_path = path / SPARSITY_FILENAME
-            save_file(sparsity_in_dict, sparsity_path)
-            return model_weights_path, cfg_path, sparsity_path
-
-        return model_weights_path, cfg_path, None
+        return model_weights_path, cfg_path

     ## Initialization Methods
     @torch.no_grad()
@@ -482,18 +487,21 @@ class SAE(HookedRootModule, ABC):
     @classmethod
     @deprecated("Use load_from_disk instead")
     def load_from_pretrained(
-        cls: Type[
-
+        cls: Type[T_SAE],
+        path: str | Path,
+        device: str = "cpu",
+        dtype: str | None = None,
+    ) -> T_SAE:
         return cls.load_from_disk(path, device=device, dtype=dtype)

     @classmethod
     def load_from_disk(
-        cls: Type[
+        cls: Type[T_SAE],
         path: str | Path,
         device: str = "cpu",
         dtype: str | None = None,
         converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
-    ) ->
+    ) -> T_SAE:
         overrides = {"dtype": dtype} if dtype is not None else None
         cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
         cfg_dict = handle_config_defaulting(cfg_dict)
@@ -501,7 +509,7 @@ class SAE(HookedRootModule, ABC):
             cfg_dict["architecture"]
         )
         sae_cfg = sae_config_cls.from_dict(cfg_dict)
-        sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture)
+        sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
         sae = sae_cls(sae_cfg)
         sae.process_state_dict_for_loading(state_dict)
         sae.load_state_dict(state_dict)
@@ -509,13 +517,13 @@ class SAE(HookedRootModule, ABC):

     @classmethod
     def from_pretrained(
-        cls,
+        cls: Type[T_SAE],
         release: str,
         sae_id: str,
         device: str = "cpu",
         force_download: bool = False,
         converter: PretrainedSaeHuggingfaceLoader | None = None,
-    ) -> tuple[
+    ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
         """
         Load a pretrained SAE from the Hugging Face model hub.

@@ -584,47 +592,31 @@ class SAE(HookedRootModule, ABC):
         )
         cfg_dict = handle_config_defaulting(cfg_dict)

-        # Rename keys to match SAEConfig field names
-        renamed_cfg_dict = {}
-        rename_map = {
-            "hook_point": "hook_name",
-            "hook_point_layer": "hook_layer",
-            "hook_point_head_index": "hook_head_index",
-            "activation_fn": "activation_fn",
-        }
-
-        for k, v in cfg_dict.items():
-            renamed_cfg_dict[rename_map.get(k, k)] = v
-
-        # Set default values for required fields
-        renamed_cfg_dict.setdefault("activation_fn_kwargs", {})
-        renamed_cfg_dict.setdefault("seqpos_slice", None)
-
         # Create SAE with appropriate architecture
         sae_config_cls = cls.get_sae_config_class_for_architecture(
-
+            cfg_dict["architecture"]
         )
-        sae_cfg = sae_config_cls.from_dict(
-        sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture)
+        sae_cfg = sae_config_cls.from_dict(cfg_dict)
+        sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
         sae = sae_cls(sae_cfg)
         sae.process_state_dict_for_loading(state_dict)
         sae.load_state_dict(state_dict)

         # Apply normalization if needed
-        if
+        if cfg_dict.get("normalize_activations") == "expected_average_only_in":
             norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
             if norm_scaling_factor is not None:
                 sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
-
+                cfg_dict["normalize_activations"] = "none"
             else:
                 warnings.warn(
                     f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
                 )

-        return sae,
+        return sae, cfg_dict, log_sparsities

     @classmethod
-    def from_dict(cls: Type[
+    def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
         """Create an SAE from a config dictionary."""
         sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
         sae_config_cls = cls.get_sae_config_class_for_architecture(
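Taken together, the loading and saving hunks above change the public signatures: `save_model` now writes only weights and config (the sparsity file is gone), `from_pretrained` is typed to return the calling class as part of a three-tuple, and the config dictionary is handed to the architecture-specific config class without the old key-renaming shim. A small sketch of the round-trip; the release and `sae_id` arguments are placeholders for whatever entries exist in the pretrained SAE directory:

```python
# Sketch of loading, saving, and reloading with the rc2 signatures. The release
# and sae_id values are supplied by the caller, not verified against the registry.
from sae_lens.saes.sae import SAE


def save_and_reload(release: str, sae_id: str, out_dir: str) -> SAE:
    sae, cfg_dict, log_sparsities = SAE.from_pretrained(release, sae_id, device="cpu")
    weights_path, cfg_path = sae.save_model(out_dir)  # rc1 also returned an optional sparsity path
    return SAE.load_from_disk(out_dir, device="cpu")
```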
@@ -633,9 +625,11 @@ class SAE(HookedRootModule, ABC):
         return sae_cls(sae_config_cls.from_dict(config_dict))

     @classmethod
-    def get_sae_class_for_architecture(
+    def get_sae_class_for_architecture(
+        cls: Type[T_SAE], architecture: str
+    ) -> Type[T_SAE]:
         """Get the SAE class for a given architecture."""
-        sae_cls = get_sae_class(architecture)
+        sae_cls, _ = get_sae_class(architecture)
         if not issubclass(sae_cls, cls):
             raise ValueError(
                 f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
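The unpacking above reflects the new `sae_lens.registry` module (which replaces the misspelled `regsitry.py` listed at the top of this diff): lookups now return a `(sae_class, config_class)` pair rather than a single class. A short sketch, assuming `"standard"` is a registered architecture name:

```python
# Registry lookups in rc2 return (sae_class, config_class) pairs; "standard" is
# an assumed architecture name, not confirmed by this diff.
from sae_lens.registry import get_sae_class, get_sae_training_class

sae_cls, sae_cfg_cls = get_sae_class("standard")
training_sae_cls, training_cfg_cls = get_sae_training_class("standard")
```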
@@ -645,161 +639,107 @@ class SAE(HookedRootModule, ABC):
     # in the future, this can be used to load different config classes for different architectures
     @classmethod
     def get_sae_config_class_for_architecture(
-        cls
+        cls,
         architecture: str,  # noqa: ARG003
     ) -> type[SAEConfig]:
         return SAEConfig


 @dataclass(kw_only=True)
-class TrainingSAEConfig(SAEConfig):
-
-
-
-
-
-
-
-
-
-
-    decoder_heuristic_init: bool
-    init_encoder_as_decoder_transpose: bool
-    scale_sparsity_penalty_by_decoder_norm: bool
+class TrainingSAEConfig(SAEConfig, ABC):
+    noise_scale: float = 0.0
+    mse_loss_normalization: str | None = None
+    b_dec_init_method: Literal["zeros", "geometric_median", "mean"] = "zeros"
+    # https://transformer-circuits.pub/2024/april-update/index.html#training-saes
+    # 0.1 corresponds to the "heuristic" initialization, use None to disable
+    decoder_init_norm: float | None = 0.1
+
+    @classmethod
+    @abstractmethod
+    def architecture(cls) -> str: ...

     @classmethod
     def from_sae_runner_config(
-        cls
-
-
-
-            architecture=cfg.architecture,
-            d_in=cfg.d_in,
-            d_sae=cfg.d_sae,  # type: ignore
-            dtype=cfg.dtype,
-            device=cfg.device,
+        cls: type[T_TRAINING_SAE_CONFIG],
+        cfg: "LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]",
+    ) -> T_TRAINING_SAE_CONFIG:
+        metadata = SAEMetadata(
             model_name=cfg.model_name,
             hook_name=cfg.hook_name,
             hook_layer=cfg.hook_layer,
             hook_head_index=cfg.hook_head_index,
-            activation_fn=cfg.activation_fn,
-            activation_fn_kwargs=cfg.activation_fn_kwargs,
-            apply_b_dec_to_input=cfg.apply_b_dec_to_input,
-            finetuning_scaling_factor=cfg.finetuning_method is not None,
-            sae_lens_training_version=cfg.sae_lens_training_version,
             context_size=cfg.context_size,
-            dataset_path=cfg.dataset_path,
             prepend_bos=cfg.prepend_bos,
-            seqpos_slice=
-            if cfg.seqpos_slice is not None
-            else None,
-            # Training cfg
-            l1_coefficient=cfg.l1_coefficient,
-            lp_norm=cfg.lp_norm,
-            use_ghost_grads=cfg.use_ghost_grads,
-            normalize_sae_decoder=cfg.normalize_sae_decoder,
-            noise_scale=cfg.noise_scale,
-            decoder_orthogonal_init=cfg.decoder_orthogonal_init,
-            mse_loss_normalization=cfg.mse_loss_normalization,
-            decoder_heuristic_init=cfg.decoder_heuristic_init,
-            init_encoder_as_decoder_transpose=cfg.init_encoder_as_decoder_transpose,
-            scale_sparsity_penalty_by_decoder_norm=cfg.scale_sparsity_penalty_by_decoder_norm,
-            normalize_activations=cfg.normalize_activations,
-            dataset_trust_remote_code=cfg.dataset_trust_remote_code,
+            seqpos_slice=cfg.seqpos_slice,
             model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {},
-            jumprelu_init_threshold=cfg.jumprelu_init_threshold,
-            jumprelu_bandwidth=cfg.jumprelu_bandwidth,
         )
+        if not isinstance(cfg.sae, cls):
+            raise ValueError(
+                f"SAE config class {cls} does not match SAE runner config class {type(cfg.sae)}"
+            )
+        return replace(cfg.sae, metadata=metadata)

     @classmethod
-    def from_dict(
+    def from_dict(
+        cls: type[T_TRAINING_SAE_CONFIG], config_dict: dict[str, Any]
+    ) -> T_TRAINING_SAE_CONFIG:
         # remove any keys that are not in the dataclass
         # since we sometimes enhance the config with the whole LM runner config
-
-
-
-
-
-
-
-
-        if "
-
-
-                valid_config_dict["seqpos_slice"]
-            )
-        elif not isinstance(valid_config_dict["seqpos_slice"], tuple):
-            valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],)
-
-        return TrainingSAEConfig(**valid_config_dict)
+        valid_config_dict = filter_valid_dataclass_fields(config_dict, cls)
+        cfg_class = cls
+        if "architecture" in config_dict:
+            cfg_class = get_sae_training_class(config_dict["architecture"])[1]
+        if not issubclass(cfg_class, cls):
+            raise ValueError(
+                f"SAE config class {cls} does not match dict config class {type(cfg_class)}"
+            )
+        if "metadata" in config_dict:
+            valid_config_dict["metadata"] = SAEMetadata(**config_dict["metadata"])
+        return cfg_class(**valid_config_dict)

     def to_dict(self) -> dict[str, Any]:
         return {
             **super().to_dict(),
-
-            "
-            "use_ghost_grads": self.use_ghost_grads,
-            "normalize_sae_decoder": self.normalize_sae_decoder,
-            "noise_scale": self.noise_scale,
-            "decoder_orthogonal_init": self.decoder_orthogonal_init,
-            "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
-            "mse_loss_normalization": self.mse_loss_normalization,
-            "decoder_heuristic_init": self.decoder_heuristic_init,
-            "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
-            "normalize_activations": self.normalize_activations,
-            "jumprelu_init_threshold": self.jumprelu_init_threshold,
-            "jumprelu_bandwidth": self.jumprelu_bandwidth,
+            **asdict(self),
+            "architecture": self.architecture(),
         }

     # this needs to exist so we can initialize the parent sae cfg without the training specific
     # parameters. Maybe there's a cleaner way to do this
     def get_base_sae_cfg_dict(self) -> dict[str, Any]:
-
-
-
-
-
-
-
-
-            "model_name": self.model_name,
-            "hook_name": self.hook_name,
-            "hook_layer": self.hook_layer,
-            "hook_head_index": self.hook_head_index,
-            "device": self.device,
-            "context_size": self.context_size,
-            "prepend_bos": self.prepend_bos,
-            "finetuning_scaling_factor": self.finetuning_scaling_factor,
-            "normalize_activations": self.normalize_activations,
-            "dataset_path": self.dataset_path,
-            "dataset_trust_remote_code": self.dataset_trust_remote_code,
-            "sae_lens_training_version": self.sae_lens_training_version,
-            "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
-            "seqpos_slice": self.seqpos_slice,
-            "neuronpedia_id": self.neuronpedia_id,
+        """
+        Creates a dictionary containing attributes corresponding to the fields
+        defined in the base SAEConfig class.
+        """
+        base_config_field_names = {f.name for f in fields(SAEConfig)}
+        result_dict = {
+            field_name: getattr(self, field_name)
+            for field_name in base_config_field_names
         }
+        result_dict["architecture"] = self.architecture()
+        return result_dict


-class TrainingSAE(SAE, ABC):
+class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
     """Abstract base class for training versions of SAEs."""

-    cfg:
-
-    def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
+    def __init__(self, cfg: T_TRAINING_SAE_CONFIG, use_error_term: bool = False):
         super().__init__(cfg, use_error_term)

         # Turn off hook_z reshaping for training mode - the activation store
         # is expected to handle reshaping before passing data to the SAE
         self.turn_off_forward_pass_hook_z_reshaping()
-
         self.mse_loss_fn = self._get_mse_loss_fn()

+    @abstractmethod
+    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...
+
     @abstractmethod
     def encode_with_hidden_pre(
         self, x: Float[torch.Tensor, "... d_in"]
     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
         """Encode with access to pre-activation values for training."""
-
+        ...

     def encode(
         self, x: Float[torch.Tensor, "... d_in"]
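`TrainingSAEConfig` is now an abstract base as well: architecture-specific loss settings move into per-architecture config classes, `decoder_init_norm` effectively replaces the old `decoder_heuristic_init`/`init_encoder_as_decoder_transpose` flags, and every `TrainingSAE` must declare its sparsity coefficients via `get_coefficients`, using the `TrainCoefficientConfig` named tuple introduced earlier. A sketch of what such an implementation might return; the `"l1"` key and the numbers are illustrative only:

```python
# Illustrative get_coefficients result for a hypothetical L1-penalised training
# SAE; the "l1" coefficient name and the values are assumptions.
from sae_lens.saes.sae import TrainCoefficientConfig


def example_coefficients() -> dict[str, float | TrainCoefficientConfig]:
    return {
        # warm the sparsity coefficient up over the first 1000 steps
        "l1": TrainCoefficientConfig(value=5.0, warm_up_steps=1000),
    }
```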
@@ -818,12 +758,20 @@ class TrainingSAE(SAE, ABC):
         Decodes feature activations back into input space,
         applying optional finetuning scale, hooking, out normalization, etc.
         """
-
-        sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
         sae_out_pre = self.hook_sae_recons(sae_out_pre)
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
         return self.reshape_fn_out(sae_out_pre, self.d_head)

+    @override
+    def initialize_weights(self):
+        super().initialize_weights()
+        if self.cfg.decoder_init_norm is not None:
+            with torch.no_grad():
+                self.W_dec.data /= self.W_dec.norm(dim=-1, keepdim=True)
+                self.W_dec.data *= self.cfg.decoder_init_norm
+                self.W_enc.data = self.W_dec.data.T.clone().detach().contiguous()
+
     @abstractmethod
     def calculate_aux_loss(
         self,
@@ -833,7 +781,7 @@ class TrainingSAE(SAE, ABC):
         sae_out: torch.Tensor,
     ) -> torch.Tensor | dict[str, torch.Tensor]:
         """Calculate architecture-specific auxiliary loss terms."""
-
+        ...

     def training_forward_pass(
         self,
@@ -883,6 +831,39 @@ class TrainingSAE(SAE, ABC):
             losses=losses,
         )

+    def save_inference_model(self, path: str | Path) -> tuple[Path, Path]:
+        """Save inference version of model weights and config to disk."""
+        path = Path(path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        # Generate the weights
+        state_dict = self.state_dict()  # Use internal SAE state dict
+        self.process_state_dict_for_saving_inference(state_dict)
+        model_weights_path = path / SAE_WEIGHTS_FILENAME
+        save_file(state_dict, model_weights_path)
+
+        # Save the config
+        config = self.to_inference_config_dict()
+        cfg_path = path / SAE_CFG_FILENAME
+        with open(cfg_path, "w") as f:
+            json.dump(config, f)
+
+        return model_weights_path, cfg_path
+
+    @abstractmethod
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        """Convert the config into an inference SAE config dict."""
+        ...
+
+    def process_state_dict_for_saving_inference(
+        self, state_dict: dict[str, Any]
+    ) -> None:
+        """
+        Process the state dict for saving the inference model.
+        This is a hook that can be overridden to change how the state dict is processed for the inference model.
+        """
+        return self.process_state_dict_for_saving(state_dict)
+
     def _get_mse_loss_fn(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
         """Get the MSE loss function based on config."""

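`save_inference_model` is the new training-side counterpart to `SAE.save_model`: it runs the state dict through `process_state_dict_for_saving_inference` and writes the inference config produced by `to_inference_config_dict`, using the same weights and config filenames. A sketch of exporting a trained SAE and reloading it for inference; the helper function and its arguments are ours:

```python
# Sketch: export a trained SAE for inference and reload it with the plain SAE
# loader. Assumes `training_sae` is any concrete TrainingSAE instance.
from pathlib import Path

from sae_lens.saes.sae import SAE, TrainingSAE


def export_for_inference(training_sae: TrainingSAE, out_dir: str | Path) -> SAE:
    weights_path, cfg_path = training_sae.save_inference_model(out_dir)
    # save_inference_model writes the same SAE_WEIGHTS/SAE_CFG filenames as
    # SAE.save_model, so the standard disk loader can read the directory back.
    return SAE.load_from_disk(out_dir, device="cpu")
```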
@@ -921,11 +902,6 @@ class TrainingSAE(SAE, ABC):
             "d_sae, d_sae d_in -> d_sae d_in",
         )

-    @torch.no_grad()
-    def set_decoder_norm_to_unit_norm(self):
-        """Normalize decoder columns to unit norm."""
-        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-
     @torch.no_grad()
     def log_histograms(self) -> dict[str, NDArray[Any]]:
         """Log histograms of the weights and biases."""
@@ -935,9 +911,11 @@ class TrainingSAE(SAE, ABC):
         }

     @classmethod
-    def get_sae_class_for_architecture(
+    def get_sae_class_for_architecture(
+        cls: Type[T_TRAINING_SAE], architecture: str
+    ) -> Type[T_TRAINING_SAE]:
         """Get the SAE class for a given architecture."""
-        sae_cls = get_sae_training_class(architecture)
+        sae_cls, _ = get_sae_training_class(architecture)
         if not issubclass(sae_cls, cls):
             raise ValueError(
                 f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
@@ -947,17 +925,17 @@ class TrainingSAE(SAE, ABC):
     # in the future, this can be used to load different config classes for different architectures
     @classmethod
     def get_sae_config_class_for_architecture(
-        cls
+        cls,
         architecture: str,  # noqa: ARG003
-    ) -> type[
-        return
+    ) -> type[TrainingSAEConfig]:
+        return get_sae_training_class(architecture)[1]


 _blank_hook = nn.Identity()


 @contextmanager
-def _disable_hooks(sae: SAE):
+def _disable_hooks(sae: SAE[Any]):
     """
     Temporarily disable hooks for the SAE. Swaps out all the hooks with a fake modules that does nothing.
     """