sae-lens 6.12.2__py3-none-any.whl → 6.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +1 -1
- sae_lens/evals.py +2 -0
- sae_lens/loading/pretrained_sae_loaders.py +6 -0
- sae_lens/pretokenize_runner.py +2 -1
- sae_lens/saes/sae.py +9 -10
- sae_lens/saes/topk_sae.py +211 -9
- sae_lens/tokenization_and_batching.py +1 -1
- sae_lens/training/sae_trainer.py +7 -5
- sae_lens/training/types.py +1 -1
- {sae_lens-6.12.2.dist-info → sae_lens-6.13.0.dist-info}/METADATA +1 -1
- {sae_lens-6.12.2.dist-info → sae_lens-6.13.0.dist-info}/RECORD +13 -13
- {sae_lens-6.12.2.dist-info → sae_lens-6.13.0.dist-info}/WHEEL +0 -0
- {sae_lens-6.12.2.dist-info → sae_lens-6.13.0.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED
sae_lens/evals.py
CHANGED
|
@@ -466,6 +466,8 @@ def get_sparsity_and_variance_metrics(
|
|
|
466
466
|
sae_out_scaled = sae.decode(sae_feature_activations).to(
|
|
467
467
|
original_act_scaled.device
|
|
468
468
|
)
|
|
469
|
+
if sae_feature_activations.is_sparse:
|
|
470
|
+
sae_feature_activations = sae_feature_activations.to_dense()
|
|
469
471
|
del cache
|
|
470
472
|
|
|
471
473
|
sae_out = activation_scaler.unscale(sae_out_scaled)
|
|
sae_lens/loading/pretrained_sae_loaders.py
CHANGED
|
@@ -233,6 +233,12 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
|
|
|
233
233
|
"reshape_activations",
|
|
234
234
|
"hook_z" if "hook_z" in new_cfg.get("hook_name", "") else "none",
|
|
235
235
|
)
|
|
236
|
+
if (
|
|
237
|
+
new_cfg.get("activation_fn") == "topk"
|
|
238
|
+
and new_cfg.get("activation_fn_kwargs", {}).get("k") is not None
|
|
239
|
+
):
|
|
240
|
+
new_cfg["architecture"] = "topk"
|
|
241
|
+
new_cfg["k"] = new_cfg["activation_fn_kwargs"]["k"]
|
|
236
242
|
|
|
237
243
|
if "normalize_activations" in new_cfg and isinstance(
|
|
238
244
|
new_cfg["normalize_activations"], bool
|
sae_lens/pretokenize_runner.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
import io
|
|
2
2
|
import json
|
|
3
3
|
import sys
|
|
4
|
+
from collections.abc import Iterator
|
|
4
5
|
from dataclasses import dataclass
|
|
5
6
|
from pathlib import Path
|
|
6
|
-
from typing import
|
|
7
|
+
from typing import Literal, cast
|
|
7
8
|
|
|
8
9
|
import torch
|
|
9
10
|
from datasets import Dataset, DatasetDict, load_dataset
|
sae_lens/saes/sae.py
CHANGED
|
@@ -14,7 +14,6 @@ from typing import (
|
|
|
14
14
|
Generic,
|
|
15
15
|
Literal,
|
|
16
16
|
NamedTuple,
|
|
17
|
-
Type,
|
|
18
17
|
TypeVar,
|
|
19
18
|
)
|
|
20
19
|
|
|
@@ -534,7 +533,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
|
|
|
534
533
|
@classmethod
|
|
535
534
|
@deprecated("Use load_from_disk instead")
|
|
536
535
|
def load_from_pretrained(
|
|
537
|
-
cls:
|
|
536
|
+
cls: type[T_SAE],
|
|
538
537
|
path: str | Path,
|
|
539
538
|
device: str = "cpu",
|
|
540
539
|
dtype: str | None = None,
|
|
@@ -543,7 +542,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
|
|
|
543
542
|
|
|
544
543
|
@classmethod
|
|
545
544
|
def load_from_disk(
|
|
546
|
-
cls:
|
|
545
|
+
cls: type[T_SAE],
|
|
547
546
|
path: str | Path,
|
|
548
547
|
device: str = "cpu",
|
|
549
548
|
dtype: str | None = None,
|
|
@@ -564,7 +563,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
|
|
|
564
563
|
|
|
565
564
|
@classmethod
|
|
566
565
|
def from_pretrained(
|
|
567
|
-
cls:
|
|
566
|
+
cls: type[T_SAE],
|
|
568
567
|
release: str,
|
|
569
568
|
sae_id: str,
|
|
570
569
|
device: str = "cpu",
|
|
@@ -585,7 +584,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
|
|
|
585
584
|
|
|
586
585
|
@classmethod
|
|
587
586
|
def from_pretrained_with_cfg_and_sparsity(
|
|
588
|
-
cls:
|
|
587
|
+
cls: type[T_SAE],
|
|
589
588
|
release: str,
|
|
590
589
|
sae_id: str,
|
|
591
590
|
device: str = "cpu",
|
|
@@ -684,7 +683,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
|
|
|
684
683
|
return sae, cfg_dict, log_sparsities
|
|
685
684
|
|
|
686
685
|
@classmethod
|
|
687
|
-
def from_dict(cls:
|
|
686
|
+
def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
|
|
688
687
|
"""Create an SAE from a config dictionary."""
|
|
689
688
|
sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
|
|
690
689
|
sae_config_cls = cls.get_sae_config_class_for_architecture(
|
|
@@ -694,8 +693,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
|
|
|
694
693
|
|
|
695
694
|
@classmethod
|
|
696
695
|
def get_sae_class_for_architecture(
|
|
697
|
-
cls:
|
|
698
|
-
) ->
|
|
696
|
+
cls: type[T_SAE], architecture: str
|
|
697
|
+
) -> type[T_SAE]:
|
|
699
698
|
"""Get the SAE class for a given architecture."""
|
|
700
699
|
sae_cls, _ = get_sae_class(architecture)
|
|
701
700
|
if not issubclass(sae_cls, cls):
|
|
@@ -1000,8 +999,8 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
|
|
|
1000
999
|
|
|
1001
1000
|
@classmethod
|
|
1002
1001
|
def get_sae_class_for_architecture(
|
|
1003
|
-
cls:
|
|
1004
|
-
) ->
|
|
1002
|
+
cls: type[T_TRAINING_SAE], architecture: str
|
|
1003
|
+
) -> type[T_TRAINING_SAE]:
|
|
1005
1004
|
"""Get the SAE class for a given architecture."""
|
|
1006
1005
|
sae_cls, _ = get_sae_training_class(architecture)
|
|
1007
1006
|
if not issubclass(sae_cls, cls):
|
sae_lens/saes/topk_sae.py
CHANGED
|
@@ -6,6 +6,7 @@ from typing import Callable
|
|
|
6
6
|
import torch
|
|
7
7
|
from jaxtyping import Float
|
|
8
8
|
from torch import nn
|
|
9
|
+
from transformer_lens.hook_points import HookPoint
|
|
9
10
|
from typing_extensions import override
|
|
10
11
|
|
|
11
12
|
from sae_lens.saes.sae import (
|
|
@@ -15,34 +16,102 @@ from sae_lens.saes.sae import (
|
|
|
15
16
|
TrainingSAE,
|
|
16
17
|
TrainingSAEConfig,
|
|
17
18
|
TrainStepInput,
|
|
19
|
+
_disable_hooks,
|
|
18
20
|
)
|
|
19
21
|
|
|
20
22
|
|
|
23
|
+
class SparseHookPoint(HookPoint):
|
|
24
|
+
"""
|
|
25
|
+
A HookPoint that takes in a sparse tensor.
|
|
26
|
+
Overrides TransformerLens's HookPoint.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
def __init__(self, d_sae: int):
|
|
30
|
+
super().__init__()
|
|
31
|
+
self.d_sae = d_sae
|
|
32
|
+
|
|
33
|
+
@override
|
|
34
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
35
|
+
using_hooks = (
|
|
36
|
+
self._forward_hooks is not None and len(self._forward_hooks) > 0
|
|
37
|
+
) or (self._backward_hooks is not None and len(self._backward_hooks) > 0)
|
|
38
|
+
if using_hooks and x.is_sparse:
|
|
39
|
+
return x.to_dense()
|
|
40
|
+
return x # if no hooks are being used, use passthrough
|
|
41
|
+
|
|
42
|
+
|
|
21
43
|
class TopK(nn.Module):
|
|
22
44
|
"""
|
|
23
45
|
A simple TopK activation that zeroes out all but the top K elements along the last dimension,
|
|
24
46
|
and applies ReLU to the top K elements.
|
|
25
47
|
"""
|
|
26
48
|
|
|
27
|
-
|
|
49
|
+
use_sparse_activations: bool
|
|
28
50
|
|
|
29
51
|
def __init__(
|
|
30
52
|
self,
|
|
31
53
|
k: int,
|
|
54
|
+
use_sparse_activations: bool = False,
|
|
32
55
|
):
|
|
33
56
|
super().__init__()
|
|
34
57
|
self.k = k
|
|
58
|
+
self.use_sparse_activations = use_sparse_activations
|
|
35
59
|
|
|
36
|
-
def forward(
|
|
60
|
+
def forward(
|
|
61
|
+
self,
|
|
62
|
+
x: torch.Tensor,
|
|
63
|
+
) -> torch.Tensor:
|
|
37
64
|
"""
|
|
38
65
|
1) Select top K elements along the last dimension.
|
|
39
66
|
2) Apply ReLU.
|
|
40
67
|
3) Zero out all other entries.
|
|
41
68
|
"""
|
|
42
|
-
|
|
43
|
-
values =
|
|
69
|
+
topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1, sorted=False)
|
|
70
|
+
values = topk_values.relu()
|
|
71
|
+
if self.use_sparse_activations:
|
|
72
|
+
# Produce a COO sparse tensor (use sparse matrix multiply in decode)
|
|
73
|
+
original_shape = x.shape
|
|
74
|
+
|
|
75
|
+
# Create indices for all dimensions
|
|
76
|
+
# For each element in topk_indices, we need to map it back to the original tensor coordinates
|
|
77
|
+
batch_dims = original_shape[:-1] # All dimensions except the last one
|
|
78
|
+
num_batch_elements = torch.prod(torch.tensor(batch_dims)).item()
|
|
79
|
+
|
|
80
|
+
# Create batch indices - each batch element repeated k times
|
|
81
|
+
batch_indices_flat = torch.arange(
|
|
82
|
+
num_batch_elements, device=x.device
|
|
83
|
+
).repeat_interleave(self.k)
|
|
84
|
+
|
|
85
|
+
# Convert flat batch indices back to multi-dimensional indices
|
|
86
|
+
if len(batch_dims) == 1:
|
|
87
|
+
# 2D case: [batch, features]
|
|
88
|
+
sparse_indices = torch.stack(
|
|
89
|
+
[
|
|
90
|
+
batch_indices_flat,
|
|
91
|
+
topk_indices.flatten(),
|
|
92
|
+
]
|
|
93
|
+
)
|
|
94
|
+
else:
|
|
95
|
+
# 3D+ case: need to unravel the batch indices
|
|
96
|
+
batch_indices_multi = []
|
|
97
|
+
remaining = batch_indices_flat
|
|
98
|
+
for dim_size in reversed(batch_dims):
|
|
99
|
+
batch_indices_multi.append(remaining % dim_size)
|
|
100
|
+
remaining = remaining // dim_size
|
|
101
|
+
batch_indices_multi.reverse()
|
|
102
|
+
|
|
103
|
+
sparse_indices = torch.stack(
|
|
104
|
+
[
|
|
105
|
+
*batch_indices_multi,
|
|
106
|
+
topk_indices.flatten(),
|
|
107
|
+
]
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
return torch.sparse_coo_tensor(
|
|
111
|
+
sparse_indices, values.flatten(), original_shape
|
|
112
|
+
)
|
|
44
113
|
result = torch.zeros_like(x)
|
|
45
|
-
result.scatter_(-1,
|
|
114
|
+
result.scatter_(-1, topk_indices, values)
|
|
46
115
|
return result
|
|
47
116
|
|
|
48
117
|
|
|
@@ -60,6 +129,63 @@ class TopKSAEConfig(SAEConfig):
|
|
|
60
129
|
return "topk"
|
|
61
130
|
|
|
62
131
|
|
|
132
|
+
def _sparse_matmul_nd(
|
|
133
|
+
sparse_tensor: torch.Tensor, dense_matrix: torch.Tensor
|
|
134
|
+
) -> torch.Tensor:
|
|
135
|
+
"""
|
|
136
|
+
Multiply a sparse tensor of shape [..., d_sae] with a dense matrix of shape [d_sae, d_out]
|
|
137
|
+
to get a result of shape [..., d_out].
|
|
138
|
+
|
|
139
|
+
This function handles sparse tensors with arbitrary batch dimensions by flattening
|
|
140
|
+
the batch dimensions, performing 2D sparse matrix multiplication, and reshaping back.
|
|
141
|
+
"""
|
|
142
|
+
original_shape = sparse_tensor.shape
|
|
143
|
+
batch_dims = original_shape[:-1]
|
|
144
|
+
d_sae = original_shape[-1]
|
|
145
|
+
d_out = dense_matrix.shape[-1]
|
|
146
|
+
|
|
147
|
+
if sparse_tensor.ndim == 2:
|
|
148
|
+
# Simple 2D case - use torch.sparse.mm directly
|
|
149
|
+
# sparse.mm errors with bfloat16 :(
|
|
150
|
+
with torch.autocast(device_type=sparse_tensor.device.type, enabled=False):
|
|
151
|
+
return torch.sparse.mm(sparse_tensor, dense_matrix)
|
|
152
|
+
|
|
153
|
+
# For 3D+ case, reshape to 2D, multiply, then reshape back
|
|
154
|
+
batch_size = int(torch.prod(torch.tensor(batch_dims)).item())
|
|
155
|
+
|
|
156
|
+
# Ensure tensor is coalesced for efficient access to indices/values
|
|
157
|
+
if not sparse_tensor.is_coalesced():
|
|
158
|
+
sparse_tensor = sparse_tensor.coalesce()
|
|
159
|
+
|
|
160
|
+
# Get indices and values
|
|
161
|
+
indices = sparse_tensor.indices() # [ndim, nnz]
|
|
162
|
+
values = sparse_tensor.values() # [nnz]
|
|
163
|
+
|
|
164
|
+
# Convert multi-dimensional batch indices to flat indices
|
|
165
|
+
flat_batch_indices = torch.zeros_like(indices[0])
|
|
166
|
+
multiplier = 1
|
|
167
|
+
for i in reversed(range(len(batch_dims))):
|
|
168
|
+
flat_batch_indices += indices[i] * multiplier
|
|
169
|
+
multiplier *= batch_dims[i]
|
|
170
|
+
|
|
171
|
+
# Create 2D sparse tensor indices [batch_flat, feature]
|
|
172
|
+
sparse_2d_indices = torch.stack([flat_batch_indices, indices[-1]])
|
|
173
|
+
|
|
174
|
+
# Create 2D sparse tensor
|
|
175
|
+
sparse_2d = torch.sparse_coo_tensor(
|
|
176
|
+
sparse_2d_indices, values, (batch_size, d_sae)
|
|
177
|
+
).coalesce()
|
|
178
|
+
|
|
179
|
+
# sparse.mm errors with bfloat16 :(
|
|
180
|
+
with torch.autocast(device_type=sparse_tensor.device.type, enabled=False):
|
|
181
|
+
# Do the matrix multiplication
|
|
182
|
+
result_2d = torch.sparse.mm(sparse_2d, dense_matrix) # [batch_size, d_out]
|
|
183
|
+
|
|
184
|
+
# Reshape back to original batch dimensions
|
|
185
|
+
result_shape = tuple(batch_dims) + (d_out,)
|
|
186
|
+
return result_2d.view(result_shape)
|
|
187
|
+
|
|
188
|
+
|
|
63
189
|
class TopKSAE(SAE[TopKSAEConfig]):
|
|
64
190
|
"""
|
|
65
191
|
An inference-only sparse autoencoder using a "topk" activation function.
|
|
@@ -96,21 +222,26 @@ class TopKSAE(SAE[TopKSAEConfig]):
|
|
|
96
222
|
return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
|
|
97
223
|
|
|
98
224
|
def decode(
|
|
99
|
-
self,
|
|
225
|
+
self,
|
|
226
|
+
feature_acts: Float[torch.Tensor, "... d_sae"],
|
|
100
227
|
) -> Float[torch.Tensor, "... d_in"]:
|
|
101
228
|
"""
|
|
102
229
|
Reconstructs the input from topk feature activations.
|
|
103
230
|
Applies optional finetuning scaling, hooking to recons, out normalization,
|
|
104
231
|
and optional head reshaping.
|
|
105
232
|
"""
|
|
106
|
-
|
|
233
|
+
# Handle sparse tensors using efficient sparse matrix multiplication
|
|
234
|
+
if feature_acts.is_sparse:
|
|
235
|
+
sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
|
|
236
|
+
else:
|
|
237
|
+
sae_out_pre = feature_acts @ self.W_dec + self.b_dec
|
|
107
238
|
sae_out_pre = self.hook_sae_recons(sae_out_pre)
|
|
108
239
|
sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
|
|
109
240
|
return self.reshape_fn_out(sae_out_pre, self.d_head)
|
|
110
241
|
|
|
111
242
|
@override
|
|
112
243
|
def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
|
|
113
|
-
return TopK(self.cfg.k)
|
|
244
|
+
return TopK(self.cfg.k, use_sparse_activations=False)
|
|
114
245
|
|
|
115
246
|
@override
|
|
116
247
|
@torch.no_grad()
|
|
@@ -124,9 +255,43 @@ class TopKSAE(SAE[TopKSAEConfig]):
|
|
|
124
255
|
class TopKTrainingSAEConfig(TrainingSAEConfig):
|
|
125
256
|
"""
|
|
126
257
|
Configuration class for training a TopKTrainingSAE.
|
|
258
|
+
|
|
259
|
+
Args:
|
|
260
|
+
k (int): Number of top features to keep active. Only the top k features
|
|
261
|
+
with the highest pre-activations will be non-zero. Defaults to 100.
|
|
262
|
+
use_sparse_activations (bool): Whether to use sparse tensor representations
|
|
263
|
+
for activations during training. This can reduce memory usage and improve
|
|
264
|
+
performance when k is small relative to d_sae, but is only worthwhile if
|
|
265
|
+
using float32 and not using autocast. Defaults to False.
|
|
266
|
+
aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
|
|
267
|
+
dead neurons to learn useful features. This loss helps prevent neuron death
|
|
268
|
+
in TopK SAEs by having dead neurons reconstruct the residual error from
|
|
269
|
+
live neurons. Defaults to 1.0.
|
|
270
|
+
decoder_init_norm (float | None): Norm to initialize decoder weights to.
|
|
271
|
+
0.1 corresponds to the "heuristic" initialization from Anthropic's April update.
|
|
272
|
+
Use None to disable. Inherited from TrainingSAEConfig. Defaults to 0.1.
|
|
273
|
+
d_in (int): Input dimension (dimensionality of the activations being encoded).
|
|
274
|
+
Inherited from SAEConfig.
|
|
275
|
+
d_sae (int): SAE latent dimension (number of features in the SAE).
|
|
276
|
+
Inherited from SAEConfig.
|
|
277
|
+
dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
|
|
278
|
+
Defaults to "float32".
|
|
279
|
+
device (str): Device to place the SAE on. Inherited from SAEConfig.
|
|
280
|
+
Defaults to "cpu".
|
|
281
|
+
apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
|
|
282
|
+
before encoding. Inherited from SAEConfig. Defaults to True.
|
|
283
|
+
normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
|
|
284
|
+
Normalization strategy for input activations. Inherited from SAEConfig.
|
|
285
|
+
Defaults to "none".
|
|
286
|
+
reshape_activations (Literal["none", "hook_z"]): How to reshape activations
|
|
287
|
+
(useful for attention head outputs). Inherited from SAEConfig.
|
|
288
|
+
Defaults to "none".
|
|
289
|
+
metadata (SAEMetadata): Metadata about the SAE training (model name, hook name, etc.).
|
|
290
|
+
Inherited from SAEConfig.
|
|
127
291
|
"""
|
|
128
292
|
|
|
129
293
|
k: int = 100
|
|
294
|
+
use_sparse_activations: bool = False
|
|
130
295
|
aux_loss_coefficient: float = 1.0
|
|
131
296
|
|
|
132
297
|
@override
|
|
@@ -144,6 +309,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
|
|
|
144
309
|
|
|
145
310
|
def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
|
|
146
311
|
super().__init__(cfg, use_error_term)
|
|
312
|
+
self.hook_sae_acts_post = SparseHookPoint(self.cfg.d_sae)
|
|
313
|
+
self.setup()
|
|
147
314
|
|
|
148
315
|
@override
|
|
149
316
|
def initialize_weights(self) -> None:
|
|
@@ -163,6 +330,41 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
|
|
|
163
330
|
feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
|
|
164
331
|
return feature_acts, hidden_pre
|
|
165
332
|
|
|
333
|
+
@override
|
|
334
|
+
def decode(
|
|
335
|
+
self,
|
|
336
|
+
feature_acts: Float[torch.Tensor, "... d_sae"],
|
|
337
|
+
) -> Float[torch.Tensor, "... d_in"]:
|
|
338
|
+
"""
|
|
339
|
+
Decodes feature activations back into input space,
|
|
340
|
+
applying optional finetuning scale, hooking, out normalization, etc.
|
|
341
|
+
"""
|
|
342
|
+
# Handle sparse tensors using efficient sparse matrix multiplication
|
|
343
|
+
if feature_acts.is_sparse:
|
|
344
|
+
sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
|
|
345
|
+
else:
|
|
346
|
+
sae_out_pre = feature_acts @ self.W_dec + self.b_dec
|
|
347
|
+
sae_out_pre = self.hook_sae_recons(sae_out_pre)
|
|
348
|
+
sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
|
|
349
|
+
return self.reshape_fn_out(sae_out_pre, self.d_head)
|
|
350
|
+
|
|
351
|
+
@override
|
|
352
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
353
|
+
"""Forward pass through the SAE."""
|
|
354
|
+
feature_acts = self.encode(x)
|
|
355
|
+
sae_out = self.decode(feature_acts)
|
|
356
|
+
|
|
357
|
+
if self.use_error_term:
|
|
358
|
+
with torch.no_grad():
|
|
359
|
+
# Recompute without hooks for true error term
|
|
360
|
+
with _disable_hooks(self):
|
|
361
|
+
feature_acts_clean = self.encode(x)
|
|
362
|
+
x_reconstruct_clean = self.decode(feature_acts_clean)
|
|
363
|
+
sae_error = self.hook_sae_error(x - x_reconstruct_clean)
|
|
364
|
+
sae_out = sae_out + sae_error
|
|
365
|
+
|
|
366
|
+
return self.hook_sae_output(sae_out)
|
|
367
|
+
|
|
166
368
|
@override
|
|
167
369
|
def calculate_aux_loss(
|
|
168
370
|
self,
|
|
@@ -189,7 +391,7 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
|
|
|
189
391
|
|
|
190
392
|
@override
|
|
191
393
|
def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
|
|
192
|
-
return TopK(self.cfg.k)
|
|
394
|
+
return TopK(self.cfg.k, use_sparse_activations=self.cfg.use_sparse_activations)
|
|
193
395
|
|
|
194
396
|
@override
|
|
195
397
|
def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
|
sae_lens/training/sae_trainer.py
CHANGED
|
@@ -253,12 +253,14 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
253
253
|
)
|
|
254
254
|
|
|
255
255
|
with torch.no_grad():
|
|
256
|
-
|
|
256
|
+
# calling .bool() should be equivalent to .abs() > 0, and work with coo tensors
|
|
257
|
+
firing_feats = train_step_output.feature_acts.bool().float()
|
|
258
|
+
did_fire = firing_feats.sum(-2).bool()
|
|
259
|
+
if did_fire.is_sparse:
|
|
260
|
+
did_fire = did_fire.to_dense()
|
|
257
261
|
self.n_forward_passes_since_fired += 1
|
|
258
262
|
self.n_forward_passes_since_fired[did_fire] = 0
|
|
259
|
-
self.act_freq_scores += (
|
|
260
|
-
(train_step_output.feature_acts.abs() > 0).float().sum(0)
|
|
261
|
-
)
|
|
263
|
+
self.act_freq_scores += firing_feats.sum(0)
|
|
262
264
|
self.n_frac_active_samples += self.cfg.train_batch_size_samples
|
|
263
265
|
|
|
264
266
|
# Grad scaler will rescale gradients if autocast is enabled
|
|
@@ -310,7 +312,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
310
312
|
loss = output.loss.item()
|
|
311
313
|
|
|
312
314
|
# metrics for currents acts
|
|
313
|
-
l0 = (
|
|
315
|
+
l0 = feature_acts.bool().float().sum(-1).to_dense().mean()
|
|
314
316
|
current_learning_rate = self.optimizer.param_groups[0]["lr"]
|
|
315
317
|
|
|
316
318
|
per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=-1).squeeze()
|
sae_lens/training/types.py
CHANGED
|
{sae_lens-6.12.2.dist-info → sae_lens-6.13.0.dist-info}/RECORD
CHANGED
|
@@ -1,39 +1,39 @@
|
|
|
1
|
-
sae_lens/__init__.py,sha256=
|
|
1
|
+
sae_lens/__init__.py,sha256=6cL-2l4CIzZJfgyRP5I90zu2Tty196wOgFg1JGlQd1c,3589
|
|
2
2
|
sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
sae_lens/analysis/hooked_sae_transformer.py,sha256=vRu6JseH1lZaEeILD5bEkQEQ1wYHHDcxD-f2olKmE9Y,14275
|
|
4
4
|
sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
|
|
5
5
|
sae_lens/cache_activations_runner.py,sha256=cNeAtp2JQ_vKbeddZVM-tcPLYyyfTWL8NDna5KQpkLI,12583
|
|
6
6
|
sae_lens/config.py,sha256=IdRXSKPfYY3hwUovj-u83eep8z52gkJHII0mY0KseYY,28739
|
|
7
7
|
sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
|
|
8
|
-
sae_lens/evals.py,sha256=
|
|
8
|
+
sae_lens/evals.py,sha256=p4AOueeemhJXyfLx2TxOva8LXxXj63JSKe9Lnib3mHs,39623
|
|
9
9
|
sae_lens/llm_sae_training_runner.py,sha256=sJTcDX1bUJJ_jZLUT88-8KUYIAPeUGoXktX68PsBqw0,15137
|
|
10
10
|
sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
|
|
11
11
|
sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
|
-
sae_lens/loading/pretrained_sae_loaders.py,sha256=
|
|
12
|
+
sae_lens/loading/pretrained_sae_loaders.py,sha256=SM4aT8NM6ezYix5c2u7p72Fz2RfvTtf7gw5RdOSKXhc,49846
|
|
13
13
|
sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
|
|
14
|
-
sae_lens/pretokenize_runner.py,sha256=
|
|
14
|
+
sae_lens/pretokenize_runner.py,sha256=x-reJzVPFDS9iRFbZtrFYSzNguJYki9gd0pbHjYJ3r4,7085
|
|
15
15
|
sae_lens/pretrained_saes.yaml,sha256=6ca3geEB6NyhULUrmdtPDK8ea0YdpLp8_au78vIFC5w,602553
|
|
16
16
|
sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
|
|
17
17
|
sae_lens/saes/__init__.py,sha256=jVwazK8Q6dW5J6_zFXPoNAuBvSxgziQ8eMOjGM3t-X8,1475
|
|
18
18
|
sae_lens/saes/batchtopk_sae.py,sha256=GX_J0vH4vzeLqYxl0mkfsZQpFEoCEHMR4dIG8fz8N8w,3449
|
|
19
19
|
sae_lens/saes/gated_sae.py,sha256=qcmM9JwBA8aZR8z_IRHV1_gQX-q_63tKewWXRnhdXuo,8986
|
|
20
20
|
sae_lens/saes/jumprelu_sae.py,sha256=HHBF1sJ95lZvxwP5vwLSQFKdnJN2KKYK0WAEaLTrta0,13399
|
|
21
|
-
sae_lens/saes/sae.py,sha256=
|
|
21
|
+
sae_lens/saes/sae.py,sha256=nuII6ZmaVtJWhPjyhasHQyiv_Wj-zdAtRQqJRYbVBQs,38274
|
|
22
22
|
sae_lens/saes/standard_sae.py,sha256=9UqYyYtQuThYxXKNaDjYcyowpOx2-7cShG-TeUP6JCQ,5940
|
|
23
|
-
sae_lens/saes/topk_sae.py,sha256=
|
|
23
|
+
sae_lens/saes/topk_sae.py,sha256=pM26I9uDeh_ZWx0HXUyPVFfEV2pfuRJmAPNWR5pmRhY,17615
|
|
24
24
|
sae_lens/saes/transcoder.py,sha256=BfLSbTYVNZh-ruGxseZiZJ_acEL6_7QyTdfqUr0lDOg,12156
|
|
25
|
-
sae_lens/tokenization_and_batching.py,sha256=
|
|
25
|
+
sae_lens/tokenization_and_batching.py,sha256=jV7Rx5wHHcYMmexFhvbSk2q5R0gYBjtKoJKpowAgMEo,4665
|
|
26
26
|
sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
27
27
|
sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
|
|
28
28
|
sae_lens/training/activations_store.py,sha256=2EUY2abqpT5El3T95sypM_JRDgiKL3VeT73U9SQIFGY,32903
|
|
29
29
|
sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
|
|
30
30
|
sae_lens/training/optim.py,sha256=TiI9nbffzXNsI8WjcIsqa2uheW6suxqL_KDDmWXobWI,5312
|
|
31
|
-
sae_lens/training/sae_trainer.py,sha256=
|
|
32
|
-
sae_lens/training/types.py,sha256=
|
|
31
|
+
sae_lens/training/sae_trainer.py,sha256=il4Evf-c4F3Uf2n_v-AOItCasX-uPxYTzn_sZLvLkl0,15633
|
|
32
|
+
sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
|
|
33
33
|
sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
|
|
34
34
|
sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
|
|
35
35
|
sae_lens/util.py,sha256=lW7fBn_b8quvRYlen9PUmB7km60YhKyjmuelB1f6KzQ,2253
|
|
36
|
-
sae_lens-6.
|
|
37
|
-
sae_lens-6.
|
|
38
|
-
sae_lens-6.
|
|
39
|
-
sae_lens-6.
|
|
36
|
+
sae_lens-6.13.0.dist-info/METADATA,sha256=rqSlR_xjf3fqZga4OHpNtrhKzaA4tIrobj-e6yq8sbA,5318
|
|
37
|
+
sae_lens-6.13.0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
|
|
38
|
+
sae_lens-6.13.0.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
|
|
39
|
+
sae_lens-6.13.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|