sae-lens 6.22.0__tar.gz → 6.22.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.22.0 → sae_lens-6.22.3}/PKG-INFO +1 -1
- {sae_lens-6.22.0 → sae_lens-6.22.3}/pyproject.toml +1 -1
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/__init__.py +1 -1
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/analysis/hooked_sae_transformer.py +4 -13
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/cache_activations_runner.py +2 -3
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/saes/gated_sae.py +6 -11
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/saes/jumprelu_sae.py +8 -13
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/saes/matryoshka_batchtopk_sae.py +2 -3
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/saes/sae.py +8 -19
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/saes/standard_sae.py +4 -9
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/saes/temporal_sae.py +5 -12
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/saes/topk_sae.py +8 -11
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/training/activations_store.py +6 -7
- {sae_lens-6.22.0 → sae_lens-6.22.3}/LICENSE +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/README.md +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/config.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/constants.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/evals.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/load_model.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/registry.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/training/types.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.22.0 → sae_lens-6.22.3}/sae_lens/util.py +0 -0

sae_lens/analysis/hooked_sae_transformer.py

@@ -3,7 +3,6 @@ from contextlib import contextmanager
 from typing import Any, Callable
 
 import torch
-from jaxtyping import Float
 from transformer_lens.ActivationCache import ActivationCache
 from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
 from transformer_lens.hook_points import HookPoint  # Hooking utilities
@@ -11,8 +10,8 @@ from transformer_lens.HookedTransformer import HookedTransformer
 
 from sae_lens.saes.sae import SAE
 
-SingleLoss = Float[torch.Tensor, ""]  # Type alias for a single element tensor
-LossPerToken = Float[torch.Tensor, "batch pos-1"]
+SingleLoss = torch.Tensor  # Type alias for a single element tensor
+LossPerToken = torch.Tensor
 Loss = SingleLoss | LossPerToken
 
 
@@ -171,12 +170,7 @@ class HookedSAETransformer(HookedTransformer):
         reset_saes_end: bool = True,
         use_error_term: bool | None = None,
         **model_kwargs: Any,
-    ) -> (
-        None
-        | Float[torch.Tensor, "batch pos d_vocab"]
-        | Loss
-        | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]
-    ):
+    ) -> None | torch.Tensor | Loss | tuple[torch.Tensor, Loss]:
         """Wrapper around HookedTransformer forward pass.
 
         Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.
@@ -203,10 +197,7 @@ class HookedSAETransformer(HookedTransformer):
         remove_batch_dim: bool = False,
         **kwargs: Any,
     ) -> tuple[
-        None
-        | Float[torch.Tensor, "batch pos d_vocab"]
-        | Loss
-        | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
+        None | torch.Tensor | Loss | tuple[torch.Tensor, Loss],
         ActivationCache | dict[str, torch.Tensor],
     ]:
        """Wrapper around 'run_with_cache' in HookedTransformer.
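
The annotation change above is representative of the whole release: jaxtyping's Float[...]/Int[...] wrappers are metadata over torch.Tensor and, unless a runtime checker is wired in, do nothing at run time, so swapping them for plain torch.Tensor is behaviour-preserving. A minimal sketch of the idea (illustration, not sae_lens code):

import torch

SingleLoss = torch.Tensor  # as in 6.22.3; previously a jaxtyping Float alias
LossPerToken = torch.Tensor
Loss = SingleLoss | LossPerToken

def mean_loss(loss_per_token: LossPerToken) -> SingleLoss:
    # Both aliases are plain torch.Tensor at runtime, so existing callers are unaffected.
    return loss_per_token.mean()

assert isinstance(mean_loss(torch.rand(2, 7)), torch.Tensor)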

sae_lens/cache_activations_runner.py

@@ -9,7 +9,6 @@ import torch
 from datasets import Array2D, Dataset, Features, Sequence, Value
 from datasets.fingerprint import generate_fingerprint
 from huggingface_hub import HfApi
-from jaxtyping import Float, Int
 from tqdm.auto import tqdm
 from transformer_lens.HookedTransformer import HookedRootModule
 
@@ -318,8 +317,8 @@ class CacheActivationsRunner:
     def _create_shard(
         self,
         buffer: tuple[
-            Float[torch.Tensor, "(bs context_size) d_in"],
-            Int[torch.Tensor, "(bs context_size)"] | None,
+            torch.Tensor,  # shape: (bs context_size) d_in
+            torch.Tensor | None,  # shape: (bs context_size) or None
         ],
     ) -> Dataset:
         hook_names = [self.cfg.hook_name]

sae_lens/saes/gated_sae.py

@@ -2,7 +2,6 @@ from dataclasses import dataclass
 from typing import Any
 
 import torch
-from jaxtyping import Float
 from numpy.typing import NDArray
 from torch import nn
 from typing_extensions import override
@@ -49,9 +48,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
         super().initialize_weights()
         _init_weights_gated(self)
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Encode the input tensor into the feature space using a gated encoder.
         This must match the original encode_gated implementation from SAE class.
@@ -72,9 +69,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
         # Combine gating and magnitudes
         return self.hook_sae_acts_post(active_features * feature_magnitudes)
 
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decode the feature activations back into the input space:
         1) Apply optional finetuning scaling.
@@ -94,7 +89,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Override to handle gated-specific parameters."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
@@ -147,8 +142,8 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
         _init_weights_gated(self)
 
     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Gated forward pass with pre-activation (for training).
         """
@@ -222,7 +217,7 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Override to handle gated-specific parameters."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T

sae_lens/saes/jumprelu_sae.py

@@ -3,7 +3,6 @@ from typing import Any, Literal
 
 import numpy as np
 import torch
-from jaxtyping import Float
 from torch import nn
 from typing_extensions import override
 
@@ -130,9 +129,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
         )
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Encode the input tensor into the feature space using JumpReLU.
         The threshold parameter determines which units remain active.
@@ -150,9 +147,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
         # 3) Multiply the normally activated units by that mask.
         return self.hook_sae_acts_post(base_acts * jump_relu_mask)
 
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decode the feature activations back to the input space.
         Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
@@ -172,8 +167,8 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
         # Save the current threshold before calling parent method
         current_thresh = self.threshold.clone()
 
-        # Get W_dec norms that will be used for scaling
-        W_dec_norms = self.W_dec.norm(dim=-1)
+        # Get W_dec norms that will be used for scaling (clamped to avoid division by zero)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8)
 
         # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
         super().fold_W_dec_norm()
@@ -265,8 +260,8 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         return torch.exp(self.log_threshold)
 
     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         sae_in = self.process_sae_in(x)
 
         hidden_pre = sae_in @ self.W_enc + self.b_enc
@@ -330,8 +325,8 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         # Save the current threshold before we call the parent method
         current_thresh = self.threshold.clone()
 
-        # Get W_dec norms
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        # Get W_dec norms (clamped to avoid division by zero)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
 
         # Call parent implementation to handle W_enc and W_dec adjustment
         super().fold_W_dec_norm()

sae_lens/saes/matryoshka_batchtopk_sae.py

@@ -2,7 +2,6 @@ import warnings
 from dataclasses import dataclass, field
 
 import torch
-from jaxtyping import Float
 from typing_extensions import override
 
 from sae_lens.saes.batchtopk_sae import (
@@ -95,10 +94,10 @@ class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
 
     def _decode_matryoshka_level(
         self,
-        feature_acts: Float[torch.Tensor, "... d_sae"],
+        feature_acts: torch.Tensor,
         width: int,
         inv_W_dec_norm: torch.Tensor,
-    ) -> Float[torch.Tensor, "... d_in"]:
+    ) -> torch.Tensor:
         """
         Decodes feature activations back into input space for a matryoshka level
         """

sae_lens/saes/sae.py

@@ -19,7 +19,6 @@ from typing import (
 
 import einops
 import torch
-from jaxtyping import Float
 from numpy.typing import NDArray
 from safetensors.torch import load_file, save_file
 from torch import nn
@@ -351,16 +350,12 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         self.W_enc = nn.Parameter(w_enc_data)
 
     @abstractmethod
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """Encode input tensor to feature space."""
         pass
 
     @abstractmethod
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """Decode feature activations back to input space."""
         pass
 
@@ -450,9 +445,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
         return super().to(*args, **kwargs)
 
-    def process_sae_in(
-        self, sae_in: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def process_sae_in(self, sae_in: torch.Tensor) -> torch.Tensor:
         sae_in = sae_in.to(self.dtype)
         sae_in = self.reshape_fn_in(sae_in)
 
@@ -491,7 +484,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Fold decoder norms into encoder."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
@@ -859,14 +852,12 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
 
     @abstractmethod
     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Encode with access to pre-activation values for training."""
         ...
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         For inference, just encode without returning hidden_pre.
         (training_forward_pass calls encode_with_hidden_pre).
@@ -874,9 +865,7 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
         feature_acts, _ = self.encode_with_hidden_pre(x)
         return feature_acts
 
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decodes feature activations back into input space,
         applying optional finetuning scale, hooking, out normalization, etc.
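
The recurring functional change in the SAE classes is the .clamp(min=1e-8) added before dividing by decoder-row norms in fold_W_dec_norm (and in _fold_norm_topk below). A short standalone sketch of what that guards against, using made-up weights rather than library code: a latent whose decoder row is all zeros has norm 0, and the unclamped fold turns that row into NaNs, which then leak into W_enc as well.

import torch

d_sae, d_in = 4, 8
W_dec = torch.randn(d_sae, d_in)
W_dec[2] = 0.0  # a "dead" latent: all-zero decoder row
W_enc = torch.randn(d_in, d_sae)

# Unclamped fold, as in 6.22.0: 0 / 0 produces NaN for the dead row.
bad = W_dec / W_dec.norm(dim=-1).unsqueeze(1)
print(torch.isnan(bad).any())  # tensor(True)

# Clamped fold, as in 6.22.3: rows with non-trivial norms are unaffected
# (their norms are far above 1e-8), and the dead row stays at zero instead of NaN.
norms = W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
W_dec_folded = W_dec / norms
W_enc_folded = W_enc * norms.T
print(torch.isnan(W_dec_folded).any())  # tensor(False)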

sae_lens/saes/standard_sae.py

@@ -2,7 +2,6 @@ from dataclasses import dataclass
 
 import numpy as np
 import torch
-from jaxtyping import Float
 from numpy.typing import NDArray
 from torch import nn
 from typing_extensions import override
@@ -54,9 +53,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
         super().initialize_weights()
         _init_weights_standard(self)
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Encode the input tensor into the feature space.
         """
@@ -67,9 +64,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
         # Apply the activation function (e.g., ReLU, depending on config)
         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
 
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decode the feature activations back to the input space.
         Now, if hook_z reshaping is turned on, we reverse the flattening.
@@ -127,8 +122,8 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
         }
 
     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # Process the input (including dtype conversion, hook call, and any activation normalization)
         sae_in = self.process_sae_in(x)
         # Compute the pre-activation (and allow for a hook if desired)

sae_lens/saes/temporal_sae.py

@@ -13,7 +13,6 @@ from typing import Literal
 
 import torch
 import torch.nn.functional as F
-from jaxtyping import Float
 from torch import nn
 from typing_extensions import override
 
@@ -250,8 +249,8 @@ class TemporalSAE(SAE[TemporalSAEConfig]):
         )
 
     def encode_with_predictions(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Encode input to novel codes only.
 
         Returns only the sparse novel codes (not predicted codes).
@@ -312,14 +311,10 @@ class TemporalSAE(SAE[TemporalSAEConfig]):
         # Return only novel codes (these are the interpretable features)
         return z_novel, z_pred
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         return self.encode_with_predictions(x)[0]
 
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """Decode novel codes to reconstruction.
 
         Note: This only decodes the novel codes. For full reconstruction,
@@ -342,9 +337,7 @@ class TemporalSAE(SAE[TemporalSAEConfig]):
         return sae_out
 
     @override
-    def forward(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Full forward pass through TemporalSAE.
 
         Returns complete reconstruction (predicted + novel).

sae_lens/saes/topk_sae.py

@@ -4,7 +4,6 @@ from dataclasses import dataclass
 from typing import Any, Callable
 
 import torch
-from jaxtyping import Float
 from torch import nn
 from transformer_lens.hook_points import HookPoint
 from typing_extensions import override
@@ -235,9 +234,7 @@ class TopKSAE(SAE[TopKSAEConfig]):
         super().initialize_weights()
         _init_weights_topk(self)
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Converts input x into feature activations.
         Uses topk activation under the hood.
@@ -251,8 +248,8 @@ class TopKSAE(SAE[TopKSAEConfig]):
 
     def decode(
         self,
-        feature_acts: Float[torch.Tensor, "... d_sae"],
-    ) -> Float[torch.Tensor, "... d_in"]:
+        feature_acts: torch.Tensor,
+    ) -> torch.Tensor:
         """
         Reconstructs the input from topk feature activations.
         Applies optional finetuning scaling, hooking to recons, out normalization,
@@ -354,8 +351,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         _init_weights_topk(self)
 
     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Similar to the base training method: calculate pre-activations, then apply TopK.
         """
@@ -372,8 +369,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
     @override
     def decode(
         self,
-        feature_acts: Float[torch.Tensor, "... d_sae"],
-    ) -> Float[torch.Tensor, "... d_in"]:
+        feature_acts: torch.Tensor,
+    ) -> torch.Tensor:
         """
         Decodes feature activations back into input space,
         applying optional finetuning scale, hooking, out normalization, etc.
@@ -534,7 +531,7 @@ def _fold_norm_topk(
     b_enc: torch.Tensor,
     W_dec: torch.Tensor,
 ) -> None:
-    W_dec_norm = W_dec.norm(dim=-1)
+    W_dec_norm = W_dec.norm(dim=-1).clamp(min=1e-8)
     b_enc.data = b_enc.data * W_dec_norm
     W_dec_norms = W_dec_norm.unsqueeze(1)
     W_dec.data = W_dec.data / W_dec_norms

sae_lens/training/activations_store.py

@@ -12,7 +12,6 @@ import torch
 from datasets import Dataset, DatasetDict, IterableDataset, load_dataset
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import HfHubHTTPError
-from jaxtyping import Float, Int
 from requests import HTTPError
 from safetensors.torch import load_file, save_file
 from tqdm.auto import tqdm
@@ -167,9 +166,11 @@ class ActivationsStore:
         disable_concat_sequences: bool = False,
         sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
     ) -> ActivationsStore:
+        if context_size is None:
+            context_size = sae.cfg.metadata.context_size
         if sae.cfg.metadata.hook_name is None:
             raise ValueError("hook_name is required")
-        if sae.cfg.metadata.context_size is None and context_size is None:
+        if context_size is None:
             raise ValueError("context_size is required")
         if sae.cfg.metadata.prepend_bos is None:
             raise ValueError("prepend_bos is required")
@@ -179,9 +180,7 @@ class ActivationsStore:
             d_in=sae.cfg.d_in,
             hook_name=sae.cfg.metadata.hook_name,
             hook_head_index=sae.cfg.metadata.hook_head_index,
-            context_size=sae.cfg.metadata.context_size
-            if context_size is None
-            else context_size,
+            context_size=context_size,
             prepend_bos=sae.cfg.metadata.prepend_bos,
             streaming=streaming,
             store_batch_size_prompts=store_batch_size_prompts,
@@ -542,8 +541,8 @@ class ActivationsStore:
         d_in: int,
         raise_on_epoch_end: bool,
     ) -> tuple[
-
-
+        torch.Tensor,
+        torch.Tensor | None,
     ]:
         """
         Loads `total_size` activations from `cached_activation_dataset`
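
The from_sae change above moves the context_size fallback ahead of the validation: the argument now defaults to the SAE's metadata before the None check, and the already-resolved value is passed straight through to from_config. A hedged sketch of the resulting resolution order (the helper name is illustrative, not part of sae_lens):

def resolve_context_size(context_size: int | None, metadata_context_size: int | None) -> int:
    # Prefer the explicit argument, then the value stored in the SAE's metadata.
    if context_size is None:
        context_size = metadata_context_size
    if context_size is None:
        raise ValueError("context_size is required")
    return context_size

assert resolve_context_size(None, 128) == 128
assert resolve_context_size(256, 128) == 256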