sae-lens 6.14.1__py3-none-any.whl → 6.22.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
@@ -1,3 +1,35 @@
+ temporal-sae-gemma-2-2b:
+   conversion_func: temporal
+   model: gemma-2-2b
+   repo_id: canrager/temporalSAEs
+   config_overrides:
+     model_name: gemma-2-2b
+     hook_name: blocks.12.hook_resid_post
+     dataset_path: monology/pile-uncopyrighted
+   saes:
+   - id: blocks.12.hook_resid_post
+     l0: 192
+     norm_scaling_factor: 0.00666666667
+     path: gemma-2-2B/layer_12/temporal
+     neuronpedia: gemma-2-2b/12-temporal-res
+ temporal-sae-llama-3.1-8b:
+   conversion_func: temporal
+   model: meta-llama/Llama-3.1-8B
+   repo_id: canrager/temporalSAEs
+   config_overrides:
+     model_name: meta-llama/Llama-3.1-8B
+     dataset_path: monology/pile-uncopyrighted
+   saes:
+   - id: blocks.15.hook_resid_post
+     l0: 256
+     norm_scaling_factor: 0.029
+     path: llama-3.1-8B/layer_15/temporal
+     neuronpedia: llama3.1-8b/15-temporal-res
+   - id: blocks.26.hook_resid_post
+     l0: 256
+     norm_scaling_factor: 0.029
+     path: llama-3.1-8B/layer_26/temporal
+     neuronpedia: llama3.1-8b/26-temporal-res
  deepseek-r1-distill-llama-8b-qresearch:
    conversion_func: deepseek_r1
    model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
@@ -14882,4 +14914,48 @@ qwen2.5-7b-instruct-andyrdt:
      neuronpedia: qwen2.5-7b-it/23-resid-post-aa
    - id: resid_post_layer_27_trainer_1
      path: resid_post_layer_27/trainer_1
-     neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+     neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+
+ gpt-oss-20b-andyrdt:
+   conversion_func: dictionary_learning_1
+   model: openai/gpt-oss-20b
+   repo_id: andyrdt/saes-gpt-oss-20b
+   saes:
+   - id: resid_post_layer_3_trainer_0
+     path: resid_post_layer_3/trainer_0
+     neuronpedia: gpt-oss-20b/3-resid-post-aa
+   - id: resid_post_layer_7_trainer_0
+     path: resid_post_layer_7/trainer_0
+     neuronpedia: gpt-oss-20b/7-resid-post-aa
+   - id: resid_post_layer_11_trainer_0
+     path: resid_post_layer_11/trainer_0
+     neuronpedia: gpt-oss-20b/11-resid-post-aa
+   - id: resid_post_layer_15_trainer_0
+     path: resid_post_layer_15/trainer_0
+     neuronpedia: gpt-oss-20b/15-resid-post-aa
+   - id: resid_post_layer_19_trainer_0
+     path: resid_post_layer_19/trainer_0
+     neuronpedia: gpt-oss-20b/19-resid-post-aa
+   - id: resid_post_layer_23_trainer_0
+     path: resid_post_layer_23/trainer_0
+     neuronpedia: gpt-oss-20b/23-resid-post-aa
+
+ goodfire-llama-3.3-70b-instruct:
+   conversion_func: goodfire
+   model: meta-llama/Llama-3.3-70B-Instruct
+   repo_id: Goodfire/Llama-3.3-70B-Instruct-SAE-l50
+   saes:
+   - id: layer_50
+     path: Llama-3.3-70B-Instruct-SAE-l50.pt
+     l0: 121
+     neuronpedia: llama3.3-70b-it/50-resid-post-gf
+
+ goodfire-llama-3.1-8b-instruct:
+   conversion_func: goodfire
+   model: meta-llama/Llama-3.1-8B-Instruct
+   repo_id: Goodfire/Llama-3.1-8B-Instruct-SAE-l19
+   saes:
+   - id: layer_19
+     path: Llama-3.1-8B-Instruct-SAE-l19.pth
+     l0: 91
+     neuronpedia: llama3.1-8b-it/19-resid-post-gf
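The entries above register new pretrained releases (temporal SAEs for gemma-2-2b and Llama-3.1-8B, dictionary-learning SAEs for gpt-oss-20b, and Goodfire SAEs for Llama 3.3/3.1 Instruct) in the pretrained-SAE lookup. A minimal loading sketch, using the release and sae_id strings from the YAML keys above; the exact return signature of `SAE.from_pretrained` may differ between sae-lens versions, so treat this as illustrative rather than definitive:

```python
# Sketch only: load one of the newly registered releases by its YAML key.
# Assumes sae-lens 6.x, where SAE.from_pretrained returns the SAE directly
# (older versions returned a (sae, cfg_dict, sparsity) tuple).
import torch
from sae_lens import SAE

device = "cuda" if torch.cuda.is_available() else "cpu"

sae = SAE.from_pretrained(
    release="temporal-sae-gemma-2-2b",    # YAML key added above
    sae_id="blocks.12.hook_resid_post",   # id listed under its `saes:` entry
    device=device,
)
print(sae.cfg.d_in, sae.cfg.d_sae)
```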
sae_lens/saes/__init__.py CHANGED
@@ -14,6 +14,10 @@ from .jumprelu_sae import (
      JumpReLUTrainingSAE,
      JumpReLUTrainingSAEConfig,
  )
+ from .matryoshka_batchtopk_sae import (
+     MatryoshkaBatchTopKTrainingSAE,
+     MatryoshkaBatchTopKTrainingSAEConfig,
+ )
  from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
  from .standard_sae import (
      StandardSAE,
@@ -21,6 +25,7 @@ from .standard_sae import (
      StandardTrainingSAE,
      StandardTrainingSAEConfig,
  )
+ from .temporal_sae import TemporalSAE, TemporalSAEConfig
  from .topk_sae import (
      TopKSAE,
      TopKSAEConfig,
@@ -65,4 +70,8 @@ __all__ = [
      "SkipTranscoderConfig",
      "JumpReLUTranscoder",
      "JumpReLUTranscoderConfig",
+     "MatryoshkaBatchTopKTrainingSAE",
+     "MatryoshkaBatchTopKTrainingSAEConfig",
+     "TemporalSAE",
+     "TemporalSAEConfig",
  ]
@@ -23,7 +23,9 @@ class BatchTopK(nn.Module):
      def forward(self, x: torch.Tensor) -> torch.Tensor:
          acts = x.relu()
          flat_acts = acts.flatten()
-         acts_topk_flat = torch.topk(flat_acts, int(self.k * acts.shape[0]), dim=-1)
+         # Calculate total number of samples across all non-feature dimensions
+         num_samples = acts.shape[:-1].numel()
+         acts_topk_flat = torch.topk(flat_acts, int(self.k * num_samples), dim=-1)
          return (
              torch.zeros_like(flat_acts)
              .scatter(-1, acts_topk_flat.indices, acts_topk_flat.values)
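The BatchTopK change above replaces `acts.shape[0]` with the total number of samples across all non-feature dimensions, so an input shaped `(batch, seq, d_sae)` keeps `k` activations per position on average rather than per leading batch row. A small standalone illustration of the difference (not the library class itself; shapes are made up):

```python
# Standalone sketch of the batch-level top-k sample count.
import torch

k = 4
acts = torch.rand(2, 3, 16)               # (batch=2, seq=3, d_sae=16), already non-negative

flat_acts = acts.flatten()
num_samples = acts.shape[:-1].numel()     # 2 * 3 = 6 samples (all non-feature dims)

# Old behaviour: int(k * acts.shape[0]) = 8 kept activations.
# New behaviour: int(k * num_samples) = 24, i.e. k per (batch, seq) position on average.
topk = torch.topk(flat_acts, int(k * num_samples), dim=-1)
sparse = torch.zeros_like(flat_acts).scatter(-1, topk.indices, topk.values)
print(sparse.count_nonzero().item())      # at most k * num_samples
```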
@@ -2,7 +2,6 @@ from dataclasses import dataclass
  from typing import Any
  
  import torch
- from jaxtyping import Float
  from numpy.typing import NDArray
  from torch import nn
  from typing_extensions import override
@@ -49,9 +48,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
          super().initialize_weights()
          _init_weights_gated(self)
  
-     def encode(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> Float[torch.Tensor, "... d_sae"]:
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
          """
          Encode the input tensor into the feature space using a gated encoder.
          This must match the original encode_gated implementation from SAE class.
@@ -72,9 +69,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
          # Combine gating and magnitudes
          return self.hook_sae_acts_post(active_features * feature_magnitudes)
  
-     def decode(
-         self, feature_acts: Float[torch.Tensor, "... d_sae"]
-     ) -> Float[torch.Tensor, "... d_in"]:
+     def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
          """
          Decode the feature activations back into the input space:
          1) Apply optional finetuning scaling.
@@ -147,8 +142,8 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
          _init_weights_gated(self)
  
      def encode_with_hidden_pre(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+         self, x: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor]:
          """
          Gated forward pass with pre-activation (for training).
          """
@@ -3,7 +3,6 @@ from typing import Any, Literal
  
  import numpy as np
  import torch
- from jaxtyping import Float
  from torch import nn
  from typing_extensions import override
  
@@ -130,9 +129,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
              torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
          )
  
-     def encode(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> Float[torch.Tensor, "... d_sae"]:
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
          """
          Encode the input tensor into the feature space using JumpReLU.
          The threshold parameter determines which units remain active.
@@ -150,9 +147,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
          # 3) Multiply the normally activated units by that mask.
          return self.hook_sae_acts_post(base_acts * jump_relu_mask)
  
-     def decode(
-         self, feature_acts: Float[torch.Tensor, "... d_sae"]
-     ) -> Float[torch.Tensor, "... d_in"]:
+     def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
          """
          Decode the feature activations back to the input space.
          Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
@@ -265,8 +260,8 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
          return torch.exp(self.log_threshold)
  
      def encode_with_hidden_pre(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+         self, x: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor]:
          sae_in = self.process_sae_in(x)
  
          hidden_pre = sae_in @ self.W_enc + self.b_enc
@@ -0,0 +1,136 @@
+ import warnings
+ from dataclasses import dataclass, field
+
+ import torch
+ from typing_extensions import override
+
+ from sae_lens.saes.batchtopk_sae import (
+     BatchTopKTrainingSAE,
+     BatchTopKTrainingSAEConfig,
+ )
+ from sae_lens.saes.sae import TrainStepInput, TrainStepOutput
+ from sae_lens.saes.topk_sae import _sparse_matmul_nd
+
+
+ @dataclass
+ class MatryoshkaBatchTopKTrainingSAEConfig(BatchTopKTrainingSAEConfig):
+     """
+     Configuration class for training a MatryoshkaBatchTopKTrainingSAE.
+
+     [Matryoshka SAEs](https://arxiv.org/pdf/2503.17547) use a series of nested reconstruction
+     losses of different widths during training to avoid feature absorption. This also has a
+     nice side-effect of encouraging higher-frequency features to be learned in earlier levels.
+     However, this SAE has more hyperparameters to tune than standard BatchTopK SAEs, and takes
+     longer to train due to requiring multiple forward passes per training step.
+
+     After training, MatryoshkaBatchTopK SAEs are saved as JumpReLU SAEs.
+
+     Args:
+         matryoshka_widths (list[int]): The widths of the matryoshka levels. Defaults to an empty list.
+         k (float): The number of features to keep active. Inherited from BatchTopKTrainingSAEConfig.
+             Defaults to 100.
+         topk_threshold_lr (float): Learning rate for updating the global topk threshold.
+             The threshold is updated using an exponential moving average of the minimum
+             positive activation value. Defaults to 0.01.
+         aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
+             dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
+             Defaults to 1.0.
+         rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
+             Inherited from TopKTrainingSAEConfig. Defaults to True.
+         decoder_init_norm (float | None): Norm to initialize decoder weights to.
+             Inherited from TrainingSAEConfig. Defaults to 0.1.
+         d_in (int): Input dimension (dimensionality of the activations being encoded).
+             Inherited from SAEConfig.
+         d_sae (int): SAE latent dimension (number of features in the SAE).
+             Inherited from SAEConfig.
+         dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+             Defaults to "float32".
+         device (str): Device to place the SAE on. Inherited from SAEConfig.
+             Defaults to "cpu".
+     """
+
+     matryoshka_widths: list[int] = field(default_factory=list)
+
+     @override
+     @classmethod
+     def architecture(cls) -> str:
+         return "matryoshka_batchtopk"
+
+
+ class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
+     """
+     Global Batch TopK Training SAE
+
+     This SAE will maintain the k on average across the batch, rather than enforcing the k per-sample as in standard TopK.
+
+     BatchTopK SAEs are saved as JumpReLU SAEs after training.
+     """
+
+     cfg: MatryoshkaBatchTopKTrainingSAEConfig  # type: ignore[assignment]
+
+     def __init__(
+         self, cfg: MatryoshkaBatchTopKTrainingSAEConfig, use_error_term: bool = False
+     ):
+         super().__init__(cfg, use_error_term)
+         _validate_matryoshka_config(cfg)
+
+     @override
+     def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
+         base_output = super().training_forward_pass(step_input)
+         inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
+         # the outer matryoshka level is the base SAE, so we don't need to add an extra loss for it
+         for width in self.cfg.matryoshka_widths[:-1]:
+             inner_reconstruction = self._decode_matryoshka_level(
+                 base_output.feature_acts, width, inv_W_dec_norm
+             )
+             inner_mse_loss = (
+                 self.mse_loss_fn(inner_reconstruction, step_input.sae_in)
+                 .sum(dim=-1)
+                 .mean()
+             )
+             base_output.losses[f"inner_mse_loss_{width}"] = inner_mse_loss
+             base_output.loss = base_output.loss + inner_mse_loss
+         return base_output
+
+     def _decode_matryoshka_level(
+         self,
+         feature_acts: torch.Tensor,
+         width: int,
+         inv_W_dec_norm: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Decodes feature activations back into input space for a matryoshka level
+         """
+         inner_feature_acts = feature_acts[:, :width]
+         # Handle sparse tensors using efficient sparse matrix multiplication
+         if self.cfg.rescale_acts_by_decoder_norm:
+             # need to multiply by the inverse of the norm because division is illegal with sparse tensors
+             inner_feature_acts = inner_feature_acts * inv_W_dec_norm[:width]
+         if inner_feature_acts.is_sparse:
+             sae_out_pre = (
+                 _sparse_matmul_nd(inner_feature_acts, self.W_dec[:width]) + self.b_dec
+             )
+         else:
+             sae_out_pre = inner_feature_acts @ self.W_dec[:width] + self.b_dec
+         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+         return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+
+ def _validate_matryoshka_config(cfg: MatryoshkaBatchTopKTrainingSAEConfig) -> None:
+     if cfg.matryoshka_widths[-1] != cfg.d_sae:
+         # warn the users that we will add a final matryoshka level
+         warnings.warn(
+             "WARNING: The final matryoshka level width is not set to cfg.d_sae. "
+             "A final matryoshka level of width=cfg.d_sae will be added."
+         )
+         cfg.matryoshka_widths.append(cfg.d_sae)
+
+     for prev_width, curr_width in zip(
+         cfg.matryoshka_widths[:-1], cfg.matryoshka_widths[1:]
+     ):
+         if prev_width >= curr_width:
+             raise ValueError("cfg.matryoshka_widths must be strictly increasing.")
+     if len(cfg.matryoshka_widths) == 1:
+         warnings.warn(
+             "WARNING: You have only set one matryoshka level. This is equivalent to using a standard BatchTopK SAE and is likely not what you want."
+         )
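A hedged configuration sketch for the new Matryoshka BatchTopK training SAE. The dimensions, `k`, and widths below are illustrative only; per the validation above, the widths must be strictly increasing and the final width should equal `d_sae` (otherwise a final level is appended with a warning):

```python
# Sketch only: illustrative dimensions and widths, not a recommended setup.
from sae_lens.saes import (
    MatryoshkaBatchTopKTrainingSAE,
    MatryoshkaBatchTopKTrainingSAEConfig,
)

cfg = MatryoshkaBatchTopKTrainingSAEConfig(
    d_in=768,                               # width of the hooked activations (illustrative)
    d_sae=16384,                            # total SAE latent dimension (illustrative)
    k=64,                                   # average active features per sample
    matryoshka_widths=[2048, 8192, 16384],  # nested levels; last one equals d_sae
)
sae = MatryoshkaBatchTopKTrainingSAE(cfg)
# Each training_forward_pass adds an inner_mse_loss_<width> term for every level
# except the outermost one, which is covered by the base BatchTopK loss.
```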
sae_lens/saes/sae.py CHANGED
@@ -19,9 +19,8 @@ from typing import (
  
  import einops
  import torch
- from jaxtyping import Float
  from numpy.typing import NDArray
- from safetensors.torch import save_file
+ from safetensors.torch import load_file, save_file
  from torch import nn
  from transformer_lens.hook_points import HookedRootModule, HookPoint
  from typing_extensions import deprecated, overload, override
@@ -155,9 +154,9 @@ class SAEConfig(ABC):
      dtype: str = "float32"
      device: str = "cpu"
      apply_b_dec_to_input: bool = True
-     normalize_activations: Literal[
-         "none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"
-     ] = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
+     normalize_activations: Literal["none", "expected_average_only_in", "layer_norm"] = (
+         "none"  # none, expected_average_only_in (Anthropic April Update)
+     )
      reshape_activations: Literal["none", "hook_z"] = "none"
      metadata: SAEMetadata = field(default_factory=SAEMetadata)
  
@@ -217,6 +216,7 @@ class TrainStepInput:
      sae_in: torch.Tensor
      coefficients: dict[str, float]
      dead_neuron_mask: torch.Tensor | None
+     n_training_steps: int
  
  
  class TrainCoefficientConfig(NamedTuple):
@@ -308,6 +308,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
  
              self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
              self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+
          elif self.cfg.normalize_activations == "layer_norm":
              # we need to scale the norm of the input and store the scaling factor
              def run_time_activation_ln_in(
@@ -349,16 +350,12 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
          self.W_enc = nn.Parameter(w_enc_data)
  
      @abstractmethod
-     def encode(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> Float[torch.Tensor, "... d_sae"]:
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
          """Encode input tensor to feature space."""
          pass
  
      @abstractmethod
-     def decode(
-         self, feature_acts: Float[torch.Tensor, "... d_sae"]
-     ) -> Float[torch.Tensor, "... d_in"]:
+     def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
          """Decode feature activations back to input space."""
          pass
  
@@ -448,26 +445,15 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
  
          return super().to(*args, **kwargs)
  
-     def process_sae_in(
-         self, sae_in: Float[torch.Tensor, "... d_in"]
-     ) -> Float[torch.Tensor, "... d_in"]:
-         # print(f"Input shape to process_sae_in: {sae_in.shape}")
-         # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
-         # print(f"self.b_dec shape: {self.b_dec.shape}")
-         # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")
-
+     def process_sae_in(self, sae_in: torch.Tensor) -> torch.Tensor:
          sae_in = sae_in.to(self.dtype)
-
-         # print(f"Shape before reshape_fn_in: {sae_in.shape}")
          sae_in = self.reshape_fn_in(sae_in)
-         # print(f"Shape after reshape_fn_in: {sae_in.shape}")
  
          sae_in = self.hook_sae_input(sae_in)
          sae_in = self.run_time_activation_norm_fn_in(sae_in)
  
          # Here's where the error happens
          bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
-         # print(f"Bias term shape: {bias_term.shape}")
  
          return sae_in - bias_term
  
@@ -866,14 +852,12 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
  
      @abstractmethod
      def encode_with_hidden_pre(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+         self, x: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor]:
          """Encode with access to pre-activation values for training."""
          ...
  
-     def encode(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> Float[torch.Tensor, "... d_sae"]:
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
          """
          For inference, just encode without returning hidden_pre.
          (training_forward_pass calls encode_with_hidden_pre).
@@ -881,9 +865,7 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
          feature_acts, _ = self.encode_with_hidden_pre(x)
          return feature_acts
  
-     def decode(
-         self, feature_acts: Float[torch.Tensor, "... d_sae"]
-     ) -> Float[torch.Tensor, "... d_in"]:
+     def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
          """
          Decodes feature activations back into input space,
          applying optional finetuning scale, hooking, out normalization, etc.
@@ -1017,6 +999,12 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
      ) -> type[TrainingSAEConfig]:
          return get_sae_training_class(architecture)[1]
  
+     def load_weights_from_checkpoint(self, checkpoint_path: Path | str) -> None:
+         checkpoint_path = Path(checkpoint_path)
+         state_dict = load_file(checkpoint_path / SAE_WEIGHTS_FILENAME)
+         self.process_state_dict_for_loading(state_dict)
+         self.load_state_dict(state_dict)
+
  
  _blank_hook = nn.Identity()
  
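TrainingSAE also gains a `load_weights_from_checkpoint` helper that reads the safetensors weights file (SAE_WEIGHTS_FILENAME) from a checkpoint directory into an existing training SAE. A minimal sketch, assuming a standard training SAE with illustrative dimensions and a hypothetical checkpoint path:

```python
# Sketch only: the config values and checkpoint directory are illustrative.
from pathlib import Path

from sae_lens.saes import StandardTrainingSAE, StandardTrainingSAEConfig

cfg = StandardTrainingSAEConfig(d_in=768, d_sae=16384)  # other fields left at defaults
sae = StandardTrainingSAE(cfg)

# Directory expected to contain the saved SAE weights file referenced by
# SAE_WEIGHTS_FILENAME (typically a .safetensors file).
checkpoint_dir = Path("checkpoints/my_run/final")       # hypothetical path
sae.load_weights_from_checkpoint(checkpoint_dir)
```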
@@ -2,7 +2,6 @@ from dataclasses import dataclass
  
  import numpy as np
  import torch
- from jaxtyping import Float
  from numpy.typing import NDArray
  from torch import nn
  from typing_extensions import override
@@ -54,9 +53,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
          super().initialize_weights()
          _init_weights_standard(self)
  
-     def encode(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> Float[torch.Tensor, "... d_sae"]:
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
          """
          Encode the input tensor into the feature space.
          """
@@ -67,9 +64,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
          # Apply the activation function (e.g., ReLU, depending on config)
          return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
  
-     def decode(
-         self, feature_acts: Float[torch.Tensor, "... d_sae"]
-     ) -> Float[torch.Tensor, "... d_in"]:
+     def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
          """
          Decode the feature activations back to the input space.
          Now, if hook_z reshaping is turned on, we reverse the flattening.
@@ -127,8 +122,8 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
          }
  
      def encode_with_hidden_pre(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+         self, x: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor]:
          # Process the input (including dtype conversion, hook call, and any activation normalization)
          sae_in = self.process_sae_in(x)
          # Compute the pre-activation (and allow for a hook if desired)