sae-lens 6.20.1__tar.gz → 6.22.2__tar.gz

This diff shows the contents of two publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
Files changed (41)
  1. {sae_lens-6.20.1 → sae_lens-6.22.2}/PKG-INFO +1 -1
  2. {sae_lens-6.20.1 → sae_lens-6.22.2}/pyproject.toml +1 -1
  3. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/__init__.py +1 -1
  4. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/analysis/hooked_sae_transformer.py +4 -13
  5. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/cache_activations_runner.py +2 -3
  6. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/pretrained_saes.yaml +26 -0
  7. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/gated_sae.py +4 -9
  8. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/jumprelu_sae.py +4 -9
  9. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/matryoshka_batchtopk_sae.py +2 -3
  10. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/sae.py +7 -18
  11. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/standard_sae.py +4 -9
  12. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/temporal_sae.py +5 -12
  13. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/topk_sae.py +7 -10
  14. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/training/activations_store.py +6 -7
  15. {sae_lens-6.20.1 → sae_lens-6.22.2}/LICENSE +0 -0
  16. {sae_lens-6.20.1 → sae_lens-6.22.2}/README.md +0 -0
  17. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/analysis/__init__.py +0 -0
  18. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  19. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/config.py +0 -0
  20. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/constants.py +0 -0
  21. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/evals.py +0 -0
  22. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/llm_sae_training_runner.py +0 -0
  23. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/load_model.py +0 -0
  24. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/loading/__init__.py +0 -0
  25. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
  26. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  27. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/pretokenize_runner.py +0 -0
  28. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/registry.py +0 -0
  29. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/__init__.py +0 -0
  30. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/batchtopk_sae.py +0 -0
  31. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/transcoder.py +0 -0
  32. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/tokenization_and_batching.py +0 -0
  33. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/training/__init__.py +0 -0
  34. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/training/activation_scaler.py +0 -0
  35. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/training/mixing_buffer.py +0 -0
  36. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/training/optim.py +0 -0
  37. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/training/sae_trainer.py +0 -0
  38. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/training/types.py +0 -0
  39. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  40. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/tutorial/tsea.py +0 -0
  41. {sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/util.py +0 -0

{sae_lens-6.20.1 → sae_lens-6.22.2}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.20.1
+Version: 6.22.2
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE

{sae_lens-6.20.1 → sae_lens-6.22.2}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.20.1"
+version = "6.22.2"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"

{sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/__init__.py
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.20.1"
+__version__ = "6.22.2"

 import logging


{sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/analysis/hooked_sae_transformer.py
@@ -3,7 +3,6 @@ from contextlib import contextmanager
 from typing import Any, Callable

 import torch
-from jaxtyping import Float
 from transformer_lens.ActivationCache import ActivationCache
 from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
 from transformer_lens.hook_points import HookPoint  # Hooking utilities
@@ -11,8 +10,8 @@ from transformer_lens.HookedTransformer import HookedTransformer

 from sae_lens.saes.sae import SAE

-SingleLoss = Float[torch.Tensor, ""]  # Type alias for a single element tensor
-LossPerToken = Float[torch.Tensor, "batch pos-1"]
+SingleLoss = torch.Tensor  # Type alias for a single element tensor
+LossPerToken = torch.Tensor
 Loss = SingleLoss | LossPerToken


@@ -171,12 +170,7 @@ class HookedSAETransformer(HookedTransformer):
         reset_saes_end: bool = True,
         use_error_term: bool | None = None,
         **model_kwargs: Any,
-    ) -> (
-        None
-        | Float[torch.Tensor, "batch pos d_vocab"]
-        | Loss
-        | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]
-    ):
+    ) -> None | torch.Tensor | Loss | tuple[torch.Tensor, Loss]:
         """Wrapper around HookedTransformer forward pass.

         Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.
@@ -203,10 +197,7 @@ class HookedSAETransformer(HookedTransformer):
         remove_batch_dim: bool = False,
         **kwargs: Any,
     ) -> tuple[
-        None
-        | Float[torch.Tensor, "batch pos d_vocab"]
-        | Loss
-        | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
+        None | torch.Tensor | Loss | tuple[torch.Tensor, Loss],
         ActivationCache | dict[str, torch.Tensor],
     ]:
         """Wrapper around 'run_with_cache' in HookedTransformer.

{sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/cache_activations_runner.py
@@ -9,7 +9,6 @@ import torch
 from datasets import Array2D, Dataset, Features, Sequence, Value
 from datasets.fingerprint import generate_fingerprint
 from huggingface_hub import HfApi
-from jaxtyping import Float, Int
 from tqdm.auto import tqdm
 from transformer_lens.HookedTransformer import HookedRootModule

@@ -318,8 +317,8 @@ class CacheActivationsRunner:
     def _create_shard(
         self,
         buffer: tuple[
-            Float[torch.Tensor, "(bs context_size) d_in"],
-            Int[torch.Tensor, "(bs context_size)"] | None,
+            torch.Tensor,  # shape: (bs context_size) d_in
+            torch.Tensor | None,  # shape: (bs context_size) or None
         ],
     ) -> Dataset:
         hook_names = [self.cfg.hook_name]
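
Most of this release is the same mechanical change repeated across modules: `jaxtyping` shape annotations (`Float[...]`, `Int[...]`) are replaced with plain `torch.Tensor` hints, with the expected shape kept as a comment where it is still useful, as in the `_create_shard` buffer above. A minimal sketch of that convention, using a hypothetical helper rather than code from the package:

    import torch

    # Hypothetical helper illustrating the annotation style after this change:
    # plain torch.Tensor hints, with expected shapes documented in comments.
    def summarize_buffer(
        activations: torch.Tensor,       # shape: (bs context_size) d_in
        token_ids: torch.Tensor | None,  # shape: (bs context_size) or None
    ) -> tuple[int, int]:
        """Return (number of rows, d_in) for an activation buffer."""
        rows, d_in = activations.shape
        return rows, d_in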

{sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/pretrained_saes.yaml
@@ -14916,6 +14916,30 @@ qwen2.5-7b-instruct-andyrdt:
     path: resid_post_layer_27/trainer_1
     neuronpedia: qwen2.5-7b-it/27-resid-post-aa

+gpt-oss-20b-andyrdt:
+  conversion_func: dictionary_learning_1
+  model: openai/gpt-oss-20b
+  repo_id: andyrdt/saes-gpt-oss-20b
+  saes:
+  - id: resid_post_layer_3_trainer_0
+    path: resid_post_layer_3/trainer_0
+    neuronpedia: gpt-oss-20b/3-resid-post-aa
+  - id: resid_post_layer_7_trainer_0
+    path: resid_post_layer_7/trainer_0
+    neuronpedia: gpt-oss-20b/7-resid-post-aa
+  - id: resid_post_layer_11_trainer_0
+    path: resid_post_layer_11/trainer_0
+    neuronpedia: gpt-oss-20b/11-resid-post-aa
+  - id: resid_post_layer_15_trainer_0
+    path: resid_post_layer_15/trainer_0
+    neuronpedia: gpt-oss-20b/15-resid-post-aa
+  - id: resid_post_layer_19_trainer_0
+    path: resid_post_layer_19/trainer_0
+    neuronpedia: gpt-oss-20b/19-resid-post-aa
+  - id: resid_post_layer_23_trainer_0
+    path: resid_post_layer_23/trainer_0
+    neuronpedia: gpt-oss-20b/23-resid-post-aa
+
 goodfire-llama-3.3-70b-instruct:
   conversion_func: goodfire
   model: meta-llama/Llama-3.3-70B-Instruct
@@ -14924,6 +14948,7 @@ goodfire-llama-3.3-70b-instruct:
   - id: layer_50
     path: Llama-3.3-70B-Instruct-SAE-l50.pt
     l0: 121
+    neuronpedia: llama3.3-70b-it/50-resid-post-gf

 goodfire-llama-3.1-8b-instruct:
   conversion_func: goodfire
@@ -14933,3 +14958,4 @@ goodfire-llama-3.1-8b-instruct:
   - id: layer_19
     path: Llama-3.1-8B-Instruct-SAE-l19.pth
     l0: 91
+    neuronpedia: llama3.1-8b-it/19-resid-post-gf
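
The new `gpt-oss-20b-andyrdt` block registers six residual-stream SAEs for `openai/gpt-oss-20b` hosted at `andyrdt/saes-gpt-oss-20b`, and the two Goodfire Llama releases gain Neuronpedia links. A minimal loading sketch follows, using the release name and one `sae_id` exactly as added above; it assumes the standard `SAE.from_pretrained` loader and that, as in other 6.x releases, it returns the SAE instance:

    from sae_lens import SAE

    # Sketch: load one of the newly registered gpt-oss-20b SAEs.
    # The release name and sae_id come from the YAML entries added above;
    # the return value of SAE.from_pretrained is assumed to be the SAE itself.
    sae = SAE.from_pretrained(
        release="gpt-oss-20b-andyrdt",
        sae_id="resid_post_layer_3_trainer_0",
        device="cpu",
    )
    print(sae.cfg.d_in, sae.cfg.d_sae)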

{sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/gated_sae.py
@@ -2,7 +2,6 @@ from dataclasses import dataclass
 from typing import Any

 import torch
-from jaxtyping import Float
 from numpy.typing import NDArray
 from torch import nn
 from typing_extensions import override
@@ -49,9 +48,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
         super().initialize_weights()
         _init_weights_gated(self)

-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Encode the input tensor into the feature space using a gated encoder.
         This must match the original encode_gated implementation from SAE class.
@@ -72,9 +69,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
         # Combine gating and magnitudes
         return self.hook_sae_acts_post(active_features * feature_magnitudes)

-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decode the feature activations back into the input space:
         1) Apply optional finetuning scaling.
@@ -147,8 +142,8 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
         _init_weights_gated(self)

     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Gated forward pass with pre-activation (for training).
         """

{sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/jumprelu_sae.py
@@ -3,7 +3,6 @@ from typing import Any, Literal

 import numpy as np
 import torch
-from jaxtyping import Float
 from torch import nn
 from typing_extensions import override

@@ -130,9 +129,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
         )

-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Encode the input tensor into the feature space using JumpReLU.
         The threshold parameter determines which units remain active.
@@ -150,9 +147,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
         # 3) Multiply the normally activated units by that mask.
         return self.hook_sae_acts_post(base_acts * jump_relu_mask)

-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decode the feature activations back to the input space.
         Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
@@ -265,8 +260,8 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         return torch.exp(self.log_threshold)

     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         sae_in = self.process_sae_in(x)

         hidden_pre = sae_in @ self.W_enc + self.b_enc

{sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/matryoshka_batchtopk_sae.py
@@ -2,7 +2,6 @@ import warnings
 from dataclasses import dataclass, field

 import torch
-from jaxtyping import Float
 from typing_extensions import override

 from sae_lens.saes.batchtopk_sae import (
@@ -95,10 +94,10 @@ class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):

     def _decode_matryoshka_level(
         self,
-        feature_acts: Float[torch.Tensor, "... d_sae"],
+        feature_acts: torch.Tensor,
         width: int,
         inv_W_dec_norm: torch.Tensor,
-    ) -> Float[torch.Tensor, "... d_in"]:
+    ) -> torch.Tensor:
         """
         Decodes feature activations back into input space for a matryoshka level
         """

{sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/sae.py
@@ -19,7 +19,6 @@ from typing import (

 import einops
 import torch
-from jaxtyping import Float
 from numpy.typing import NDArray
 from safetensors.torch import load_file, save_file
 from torch import nn
@@ -351,16 +350,12 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         self.W_enc = nn.Parameter(w_enc_data)

     @abstractmethod
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """Encode input tensor to feature space."""
         pass

     @abstractmethod
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """Decode feature activations back to input space."""
         pass

@@ -450,9 +445,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

         return super().to(*args, **kwargs)

-    def process_sae_in(
-        self, sae_in: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def process_sae_in(self, sae_in: torch.Tensor) -> torch.Tensor:
         sae_in = sae_in.to(self.dtype)
         sae_in = self.reshape_fn_in(sae_in)

@@ -859,14 +852,12 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):

     @abstractmethod
     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Encode with access to pre-activation values for training."""
         ...

-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         For inference, just encode without returning hidden_pre.
         (training_forward_pass calls encode_with_hidden_pre).
@@ -874,9 +865,7 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
         feature_acts, _ = self.encode_with_hidden_pre(x)
         return feature_acts

-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decodes feature activations back into input space,
         applying optional finetuning scale, hooking, out normalization, etc.
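
With the `jaxtyping` wrappers gone, the abstract `encode`/`decode` pair on `SAE` now takes and returns plain `torch.Tensor`, so the `... d_in` / `... d_sae` shapes live only in docstrings and comments. Below is a toy module mirroring just those two signatures; it is not a real `SAE` subclass (which would also need a config class and the hook machinery), only an illustration of the new hints:

    import torch
    from torch import nn

    class ToyAutoencoder(nn.Module):
        """Toy encode/decode pair with the new plain-torch.Tensor signatures."""

        def __init__(self, d_in: int, d_sae: int):
            super().__init__()
            self.W_enc = nn.Parameter(torch.randn(d_in, d_sae) * 0.01)
            self.b_enc = nn.Parameter(torch.zeros(d_sae))
            self.W_dec = nn.Parameter(torch.randn(d_sae, d_in) * 0.01)
            self.b_dec = nn.Parameter(torch.zeros(d_in))

        def encode(self, x: torch.Tensor) -> torch.Tensor:  # x: ... d_in
            return torch.relu(x @ self.W_enc + self.b_enc)  # -> ... d_sae

        def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:  # ... d_sae
            return feature_acts @ self.W_dec + self.b_dec   # -> ... d_in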

{sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/standard_sae.py
@@ -2,7 +2,6 @@ from dataclasses import dataclass

 import numpy as np
 import torch
-from jaxtyping import Float
 from numpy.typing import NDArray
 from torch import nn
 from typing_extensions import override
@@ -54,9 +53,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
         super().initialize_weights()
         _init_weights_standard(self)

-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Encode the input tensor into the feature space.
         """
@@ -67,9 +64,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
         # Apply the activation function (e.g., ReLU, depending on config)
         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decode the feature activations back to the input space.
         Now, if hook_z reshaping is turned on, we reverse the flattening.
@@ -127,8 +122,8 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
         }

     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # Process the input (including dtype conversion, hook call, and any activation normalization)
         sae_in = self.process_sae_in(x)
         # Compute the pre-activation (and allow for a hook if desired)

{sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/temporal_sae.py
@@ -13,7 +13,6 @@ from typing import Literal

 import torch
 import torch.nn.functional as F
-from jaxtyping import Float
 from torch import nn
 from typing_extensions import override

@@ -250,8 +249,8 @@ class TemporalSAE(SAE[TemporalSAEConfig]):
         )

     def encode_with_predictions(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Encode input to novel codes only.

         Returns only the sparse novel codes (not predicted codes).
@@ -312,14 +311,10 @@ class TemporalSAE(SAE[TemporalSAEConfig]):
         # Return only novel codes (these are the interpretable features)
         return z_novel, z_pred

-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         return self.encode_with_predictions(x)[0]

-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """Decode novel codes to reconstruction.

         Note: This only decodes the novel codes. For full reconstruction,
@@ -342,9 +337,7 @@ class TemporalSAE(SAE[TemporalSAEConfig]):
         return sae_out

     @override
-    def forward(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Full forward pass through TemporalSAE.

         Returns complete reconstruction (predicted + novel).

{sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/saes/topk_sae.py
@@ -4,7 +4,6 @@ from dataclasses import dataclass
 from typing import Any, Callable

 import torch
-from jaxtyping import Float
 from torch import nn
 from transformer_lens.hook_points import HookPoint
 from typing_extensions import override
@@ -235,9 +234,7 @@ class TopKSAE(SAE[TopKSAEConfig]):
         super().initialize_weights()
         _init_weights_topk(self)

-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Converts input x into feature activations.
         Uses topk activation under the hood.
@@ -251,8 +248,8 @@ class TopKSAE(SAE[TopKSAEConfig]):

     def decode(
         self,
-        feature_acts: Float[torch.Tensor, "... d_sae"],
-    ) -> Float[torch.Tensor, "... d_in"]:
+        feature_acts: torch.Tensor,
+    ) -> torch.Tensor:
         """
         Reconstructs the input from topk feature activations.
         Applies optional finetuning scaling, hooking to recons, out normalization,
@@ -354,8 +351,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         _init_weights_topk(self)

     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Similar to the base training method: calculate pre-activations, then apply TopK.
         """
@@ -372,8 +369,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
     @override
     def decode(
         self,
-        feature_acts: Float[torch.Tensor, "... d_sae"],
-    ) -> Float[torch.Tensor, "... d_in"]:
+        feature_acts: torch.Tensor,
+    ) -> torch.Tensor:
         """
         Decodes feature activations back into input space,
         applying optional finetuning scale, hooking, out normalization, etc.

{sae_lens-6.20.1 → sae_lens-6.22.2}/sae_lens/training/activations_store.py
@@ -12,7 +12,6 @@ import torch
 from datasets import Dataset, DatasetDict, IterableDataset, load_dataset
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import HfHubHTTPError
-from jaxtyping import Float, Int
 from requests import HTTPError
 from safetensors.torch import load_file, save_file
 from tqdm.auto import tqdm
@@ -167,9 +166,11 @@ class ActivationsStore:
         disable_concat_sequences: bool = False,
         sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
     ) -> ActivationsStore:
+        if context_size is None:
+            context_size = sae.cfg.metadata.context_size
         if sae.cfg.metadata.hook_name is None:
             raise ValueError("hook_name is required")
-        if sae.cfg.metadata.context_size is None:
+        if context_size is None:
             raise ValueError("context_size is required")
         if sae.cfg.metadata.prepend_bos is None:
             raise ValueError("prepend_bos is required")
@@ -179,9 +180,7 @@ class ActivationsStore:
             d_in=sae.cfg.d_in,
             hook_name=sae.cfg.metadata.hook_name,
             hook_head_index=sae.cfg.metadata.hook_head_index,
-            context_size=sae.cfg.metadata.context_size
-            if context_size is None
-            else context_size,
+            context_size=context_size,
             prepend_bos=sae.cfg.metadata.prepend_bos,
             streaming=streaming,
             store_batch_size_prompts=store_batch_size_prompts,
@@ -542,8 +541,8 @@ class ActivationsStore:
         d_in: int,
         raise_on_epoch_end: bool,
     ) -> tuple[
-        Float[torch.Tensor, "(total_size context_size) num_layers d_in"],
-        Int[torch.Tensor, "(total_size context_size)"] | None,
+        torch.Tensor,
+        torch.Tensor | None,
     ]:
         """
         Loads `total_size` activations from `cached_activation_dataset`
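
Besides the annotation cleanup, the `ActivationsStore` change reorders how `context_size` is resolved: an explicitly passed value now takes precedence, the SAE's `cfg.metadata.context_size` is only a fallback, and the "context_size is required" check runs on the resolved value rather than on the metadata field. A standalone sketch of that resolve-then-validate flow (hypothetical helper, not the package's API):

    # Hypothetical helper mirroring the resolve-then-validate flow shown above.
    def resolve_context_size(explicit: int | None, metadata_value: int | None) -> int:
        context_size = explicit if explicit is not None else metadata_value
        if context_size is None:
            raise ValueError("context_size is required")
        return context_size

    assert resolve_context_size(128, 1024) == 128    # explicit argument wins
    assert resolve_context_size(None, 1024) == 1024  # falls back to SAE metadata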