sae-lens 6.28.0__tar.gz → 6.29.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. {sae_lens-6.28.0 → sae_lens-6.29.0}/PKG-INFO +11 -1
  2. {sae_lens-6.28.0 → sae_lens-6.29.0}/README.md +10 -0
  3. {sae_lens-6.28.0 → sae_lens-6.29.0}/pyproject.toml +1 -1
  4. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/__init__.py +1 -1
  5. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/loading/pretrained_saes_directory.py +18 -0
  6. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/pretrained_saes.yaml +1 -1
  7. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/saes/sae.py +13 -0
  8. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/synthetic/__init__.py +6 -0
  9. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/synthetic/activation_generator.py +105 -6
  10. sae_lens-6.29.0/sae_lens/synthetic/correlation.py +351 -0
  11. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/synthetic/feature_dictionary.py +54 -16
  12. sae_lens-6.29.0/sae_lens/synthetic/hierarchy.py +596 -0
  13. sae_lens-6.28.0/sae_lens/synthetic/correlation.py +0 -170
  14. sae_lens-6.28.0/sae_lens/synthetic/hierarchy.py +0 -335
  15. {sae_lens-6.28.0 → sae_lens-6.29.0}/LICENSE +0 -0
  16. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/analysis/__init__.py +0 -0
  17. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  18. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  19. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/cache_activations_runner.py +0 -0
  20. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/config.py +0 -0
  21. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/constants.py +0 -0
  22. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/evals.py +0 -0
  23. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/llm_sae_training_runner.py +0 -0
  24. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/load_model.py +0 -0
  25. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/loading/__init__.py +0 -0
  26. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
  27. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/pretokenize_runner.py +0 -0
  28. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/registry.py +0 -0
  29. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/saes/__init__.py +0 -0
  30. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/saes/batchtopk_sae.py +0 -0
  31. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/saes/gated_sae.py +0 -0
  32. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/saes/jumprelu_sae.py +0 -0
  33. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/saes/matching_pursuit_sae.py +0 -0
  34. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
  35. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/saes/standard_sae.py +0 -0
  36. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/saes/temporal_sae.py +0 -0
  37. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/saes/topk_sae.py +0 -0
  38. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/saes/transcoder.py +0 -0
  39. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/synthetic/evals.py +0 -0
  40. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/synthetic/firing_probabilities.py +0 -0
  41. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/synthetic/initialization.py +0 -0
  42. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/synthetic/plotting.py +0 -0
  43. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/synthetic/training.py +0 -0
  44. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/tokenization_and_batching.py +0 -0
  45. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/training/__init__.py +0 -0
  46. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/training/activation_scaler.py +0 -0
  47. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/training/activations_store.py +0 -0
  48. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/training/mixing_buffer.py +0 -0
  49. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/training/optim.py +0 -0
  50. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/training/sae_trainer.py +0 -0
  51. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/training/types.py +0 -0
  52. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  53. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/tutorial/tsea.py +0 -0
  54. {sae_lens-6.28.0 → sae_lens-6.29.0}/sae_lens/util.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.28.0
3
+ Version: 6.29.0
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -50,6 +50,8 @@ SAELens exists to help researchers:
50
50
  - Analyse sparse autoencoders / research mechanistic interpretability.
51
51
  - Generate insights which make it easier to create safe and aligned AI systems.
52
52
 
53
+ SAELens inference works with any PyTorch-based model, not just TransformerLens. While we provide deep integration with TransformerLens via `HookedSAETransformer`, SAEs can be used with Hugging Face Transformers, NNsight, or any other framework by extracting activations and passing them to the SAE's `encode()` and `decode()` methods.
54
+
53
55
  Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
54
56
 
55
57
  - Download and Analyse pre-trained sparse autoencoders.
@@ -84,6 +86,14 @@ The new v6 update is a major refactor to SAELens and changes the way training co
84
86
 
85
87
  Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
86
88
 
89
+ ## Other SAE Projects
90
+
91
+ - [dictionary-learning](https://github.com/saprmarks/dictionary_learning): An SAE training library that focuses on having hackable code.
92
+ - [Sparsify](https://github.com/EleutherAI/sparsify): A lean SAE training library focused on TopK SAEs.
93
+ - [Overcomplete](https://github.com/KempnerInstitute/overcomplete): SAE training library focused on vision models.
94
+ - [SAE-Vis](https://github.com/callummcdougall/sae_vis): A library for visualizing SAE features, works with SAELens.
95
+ - [SAEBench](https://github.com/adamkarvonen/SAEBench): A suite of LLM SAE benchmarks, works with SAELens.
96
+
87
97
  ## Citation
88
98
 
89
99
  Please cite the package as follows:
@@ -14,6 +14,8 @@ SAELens exists to help researchers:
14
14
  - Analyse sparse autoencoders / research mechanistic interpretability.
15
15
  - Generate insights which make it easier to create safe and aligned AI systems.
16
16
 
17
+ SAELens inference works with any PyTorch-based model, not just TransformerLens. While we provide deep integration with TransformerLens via `HookedSAETransformer`, SAEs can be used with Hugging Face Transformers, NNsight, or any other framework by extracting activations and passing them to the SAE's `encode()` and `decode()` methods.
18
+
17
19
  Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
18
20
 
19
21
  - Download and Analyse pre-trained sparse autoencoders.
@@ -48,6 +50,14 @@ The new v6 update is a major refactor to SAELens and changes the way training co
48
50
 
49
51
  Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
50
52
 
53
+ ## Other SAE Projects
54
+
55
+ - [dictionary-learning](https://github.com/saprmarks/dictionary_learning): An SAE training library that focuses on having hackable code.
56
+ - [Sparsify](https://github.com/EleutherAI/sparsify): A lean SAE training library focused on TopK SAEs.
57
+ - [Overcomplete](https://github.com/KempnerInstitute/overcomplete): SAE training library focused on vision models.
58
+ - [SAE-Vis](https://github.com/callummcdougall/sae_vis): A library for visualizing SAE features, works with SAELens.
59
+ - [SAEBench](https://github.com/adamkarvonen/SAEBench): A suite of LLM SAE benchmarks, works with SAELens.
60
+
51
61
  ## Citation
52
62
 
53
63
  Please cite the package as follows:
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "sae-lens"
3
- version = "6.28.0"
3
+ version = "6.29.0"
4
4
  description = "Training and Analyzing Sparse Autoencoders (SAEs)"
5
5
  authors = ["Joseph Bloom"]
6
6
  readme = "README.md"
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.28.0"
2
+ __version__ = "6.29.0"
3
3
 
4
4
  import logging
5
5
 
@@ -103,3 +103,21 @@ def get_config_overrides(release: str, sae_id: str) -> dict[str, Any]:
103
103
  if sae_info.neuronpedia_id is not None and sae_id in sae_info.neuronpedia_id:
104
104
  config_overrides["neuronpedia_id"] = sae_info.neuronpedia_id[sae_id]
105
105
  return config_overrides
106
+
107
+
108
def get_releases_for_repo_id(repo_id: str) -> list[str]:
    """
    Return every release name registered against a HuggingFace repo_id.

    Args:
        repo_id: The HuggingFace repository ID to look up. Matching is an
            exact string comparison against each release's ``repo_id``.

    Returns:
        The matching release names, in the iteration order of the pretrained
        SAEs directory. Empty when no release uses ``repo_id``.
    """
    directory = get_pretrained_saes_directory()
    return [name for name in directory if directory[name].repo_id == repo_id]
@@ -40631,7 +40631,7 @@ gemma-3-1b-res-matryoshka-dc:
40631
40631
  conversion_func: null
40632
40632
  links:
40633
40633
  model: https://huggingface.co/google/gemma-3-1b-pt
40634
- model: gemma-3-1b
40634
+ model: google/gemma-3-1b-pt
40635
40635
  repo_id: chanind/gemma-3-1b-batch-topk-matryoshka-saes-w-32k-l0-40
40636
40636
  saes:
40637
40637
  - id: blocks.0.hook_resid_post
@@ -47,6 +47,7 @@ from sae_lens.loading.pretrained_saes_directory import (
47
47
  get_config_overrides,
48
48
  get_norm_scaling_factor,
49
49
  get_pretrained_saes_directory,
50
+ get_releases_for_repo_id,
50
51
  get_repo_id_and_folder_name,
51
52
  )
52
53
  from sae_lens.registry import get_sae_class, get_sae_training_class
@@ -624,6 +625,18 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
624
625
  raise ValueError(
625
626
  f"Release {release} not found in pretrained SAEs directory, and is not a valid huggingface repo."
626
627
  )
628
+ # Check if the user passed a repo_id that's in the pretrained SAEs list
629
+ matching_releases = get_releases_for_repo_id(release)
630
+ if matching_releases:
631
+ warnings.warn(
632
+ f"You are loading an SAE using the HuggingFace repo_id '{release}' directly. "
633
+ f"This repo is registered in the official pretrained SAEs list with release name(s): {matching_releases}. "
634
+ f"For better compatibility and to access additional metadata, consider loading with: "
635
+ f"SAE.from_pretrained(release='{matching_releases[0]}', sae_id='<sae_id>'). "
636
+ f"See the full list of pretrained SAEs at: https://decoderesearch.github.io/SAELens/latest/pretrained_saes/",
637
+ UserWarning,
638
+ stacklevel=2,
639
+ )
627
640
  elif sae_id not in sae_directory[release].saes_map:
628
641
  # Handle special cases like Gemma Scope
629
642
  if (
@@ -17,11 +17,14 @@ from sae_lens.synthetic.activation_generator import (
17
17
  ActivationGenerator,
18
18
  ActivationsModifier,
19
19
  ActivationsModifierInput,
20
+ CorrelationMatrixInput,
20
21
  )
21
22
  from sae_lens.synthetic.correlation import (
23
+ LowRankCorrelationMatrix,
22
24
  create_correlation_matrix_from_correlations,
23
25
  generate_random_correlation_matrix,
24
26
  generate_random_correlations,
27
+ generate_random_low_rank_correlation_matrix,
25
28
  )
26
29
  from sae_lens.synthetic.evals import (
27
30
  SyntheticDataEvalResult,
@@ -66,6 +69,9 @@ __all__ = [
66
69
  "create_correlation_matrix_from_correlations",
67
70
  "generate_random_correlations",
68
71
  "generate_random_correlation_matrix",
72
+ "generate_random_low_rank_correlation_matrix",
73
+ "LowRankCorrelationMatrix",
74
+ "CorrelationMatrixInput",
69
75
  # Feature modifiers
70
76
  "ActivationsModifier",
71
77
  "ActivationsModifierInput",
@@ -7,12 +7,16 @@ from collections.abc import Callable, Sequence
7
7
  import torch
8
8
  from scipy.stats import norm
9
9
  from torch import nn
10
- from torch.distributions import MultivariateNormal
10
+ from torch.distributions import LowRankMultivariateNormal, MultivariateNormal
11
11
 
12
+ from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
12
13
  from sae_lens.util import str_to_dtype
13
14
 
14
15
  ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]
15
16
  ActivationsModifierInput = ActivationsModifier | Sequence[ActivationsModifier] | None
17
+ CorrelationMatrixInput = (
18
+ torch.Tensor | LowRankCorrelationMatrix | tuple[torch.Tensor, torch.Tensor]
19
+ )
16
20
 
17
21
 
18
22
  class ActivationGenerator(nn.Module):
@@ -28,6 +32,7 @@ class ActivationGenerator(nn.Module):
28
32
  mean_firing_magnitudes: torch.Tensor
29
33
  modify_activations: ActivationsModifier | None
30
34
  correlation_matrix: torch.Tensor | None
35
+ low_rank_correlation: tuple[torch.Tensor, torch.Tensor] | None
31
36
  correlation_thresholds: torch.Tensor | None
32
37
 
33
38
  def __init__(
@@ -37,7 +42,7 @@ class ActivationGenerator(nn.Module):
37
42
  std_firing_magnitudes: torch.Tensor | float = 0.0,
38
43
  mean_firing_magnitudes: torch.Tensor | float = 1.0,
39
44
  modify_activations: ActivationsModifierInput = None,
40
- correlation_matrix: torch.Tensor | None = None,
45
+ correlation_matrix: CorrelationMatrixInput | None = None,
41
46
  device: torch.device | str = "cpu",
42
47
  dtype: torch.dtype | str = "float32",
43
48
  ):
@@ -54,15 +59,32 @@ class ActivationGenerator(nn.Module):
54
59
  )
55
60
  self.modify_activations = _normalize_modifiers(modify_activations)
56
61
  self.correlation_thresholds = None
62
+ self.correlation_matrix = None
63
+ self.low_rank_correlation = None
64
+
57
65
  if correlation_matrix is not None:
58
- _validate_correlation_matrix(correlation_matrix, num_features)
66
+ if isinstance(correlation_matrix, torch.Tensor):
67
+ # Full correlation matrix
68
+ _validate_correlation_matrix(correlation_matrix, num_features)
69
+ self.correlation_matrix = correlation_matrix
70
+ else:
71
+ # Low-rank correlation (tuple or LowRankCorrelationMatrix)
72
+ correlation_factor, correlation_diag = (
73
+ correlation_matrix[0],
74
+ correlation_matrix[1],
75
+ )
76
+ _validate_low_rank_correlation(
77
+ correlation_factor, correlation_diag, num_features
78
+ )
79
+ self.low_rank_correlation = (correlation_factor, correlation_diag)
80
+
59
81
  self.correlation_thresholds = torch.tensor(
60
82
  [norm.ppf(1 - p.item()) for p in self.firing_probabilities],
61
83
  device=device,
62
84
  dtype=self.firing_probabilities.dtype,
63
85
  )
64
- self.correlation_matrix = correlation_matrix
65
86
 
87
+ @torch.no_grad()
66
88
  def sample(self, batch_size: int) -> torch.Tensor:
67
89
  """
68
90
  Generate a batch of feature activations with controlled properties.
@@ -89,6 +111,15 @@ class ActivationGenerator(nn.Module):
89
111
  self.correlation_thresholds,
90
112
  device,
91
113
  )
114
+ elif self.low_rank_correlation is not None:
115
+ assert self.correlation_thresholds is not None
116
+ firing_features = _generate_low_rank_correlated_features(
117
+ batch_size,
118
+ self.low_rank_correlation[0],
119
+ self.low_rank_correlation[1],
120
+ self.correlation_thresholds,
121
+ device,
122
+ )
92
123
  else:
93
124
  firing_features = torch.bernoulli(
94
125
  self.firing_probabilities.unsqueeze(0).expand(batch_size, -1)
@@ -132,7 +163,7 @@ def _generate_correlated_features(
132
163
  device: Device to generate samples on
133
164
 
134
165
  Returns:
135
- Binary feature matrix of shape [batch_size, num_features]
166
+ Binary feature matrix of shape (batch_size, num_features)
136
167
  """
137
168
  num_features = correlation_matrix.shape[0]
138
169
 
@@ -145,6 +176,41 @@ def _generate_correlated_features(
145
176
  return (gaussian_samples > thresholds.unsqueeze(0)).float()
146
177
 
147
178
 
179
def _generate_low_rank_correlated_features(
    batch_size: int,
    correlation_factor: torch.Tensor,
    correlation_diag: torch.Tensor,
    thresholds: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """
    Generate correlated binary features using low-rank multivariate Gaussian sampling.

    Uses the Gaussian copula approach with a low-rank covariance structure for scalability.
    The covariance is represented as: cov = factor @ factor.T + diag(diag_term)

    ``thresholds`` are quantiles of a *standard* normal, so feature ``i`` only
    fires with its intended probability if its latent Gaussian has unit
    variance. A true correlation matrix already has a unit diagonal. For any
    other valid covariance, the thresholds are rescaled by each feature's
    marginal standard deviation. This leaves the copula (the correlation
    structure) unchanged while keeping the per-feature firing probabilities
    correct.

    Args:
        batch_size: Number of samples to generate
        correlation_factor: Factor matrix of shape (num_features, rank)
        correlation_diag: Diagonal term of shape (num_features,)
        thresholds: Pre-computed thresholds for each feature (from inverse normal CDF)
        device: Device to generate samples on

    Returns:
        Binary feature matrix of shape (batch_size, num_features)
    """
    dtype = thresholds.dtype
    factor = correlation_factor.to(device=device, dtype=dtype)
    diag = correlation_diag.to(device=device, dtype=dtype)

    mvn = LowRankMultivariateNormal(
        loc=torch.zeros(factor.shape[0], device=device, dtype=dtype),
        cov_factor=factor,
        cov_diag=diag,
    )
    gaussian_samples = mvn.sample((batch_size,))

    # Marginal variance of feature i is sum_k factor[i, k]**2 + diag[i]; this is
    # ~1 for a correlation matrix, making the rescale a no-op in that case.
    marginal_std = (factor.square().sum(dim=1) + diag).sqrt()
    scaled_thresholds = thresholds * marginal_std
    return (gaussian_samples > scaled_thresholds.unsqueeze(0)).float()
212
+
213
+
148
214
  def _to_tensor(
149
215
  value: torch.Tensor | float,
150
216
  num_features: int,
@@ -193,7 +259,7 @@ def _validate_correlation_matrix(
193
259
 
194
260
  Args:
195
261
  correlation_matrix: The matrix to validate
196
- num_features: Expected number of features (matrix should be [num_features, num_features])
262
+ num_features: Expected number of features (matrix should be (num_features, num_features))
197
263
 
198
264
  Raises:
199
265
  ValueError: If the matrix has incorrect shape, non-unit diagonal, or is not positive definite
@@ -213,3 +279,36 @@ def _validate_correlation_matrix(
213
279
  torch.linalg.cholesky(correlation_matrix)
214
280
  except RuntimeError as e:
215
281
  raise ValueError("Correlation matrix must be positive definite") from e
282
+
283
+
284
def _validate_low_rank_correlation(
    correlation_factor: torch.Tensor,
    correlation_diag: torch.Tensor,
    num_features: int,
) -> None:
    """Validate that low-rank correlation parameters have correct properties.

    The implied covariance is ``correlation_factor @ correlation_factor.T +
    diag(correlation_diag)``. With a finite factor and a strictly positive,
    finite diagonal term, that matrix is positive definite.

    Args:
        correlation_factor: Factor matrix of shape (num_features, rank)
        correlation_diag: Diagonal term of shape (num_features,)
        num_features: Expected number of features

    Raises:
        ValueError: If shapes are incorrect, either tensor contains NaN or
            infinite values, or diagonal terms are not positive
    """
    if correlation_factor.ndim != 2:
        raise ValueError(
            f"correlation_factor must be 2D, got {correlation_factor.ndim}D"
        )
    if correlation_factor.shape[0] != num_features:
        raise ValueError(
            f"correlation_factor must have shape ({num_features}, rank), "
            f"got {tuple(correlation_factor.shape)}"
        )
    if correlation_diag.shape != (num_features,):
        raise ValueError(
            f"correlation_diag must have shape ({num_features},), "
            f"got {tuple(correlation_diag.shape)}"
        )
    # NaN compares False against everything, so it would slip past the
    # `<= 0` check below; reject non-finite values explicitly first.
    if not torch.all(torch.isfinite(correlation_factor)):
        raise ValueError("correlation_factor must contain only finite values")
    if not torch.all(torch.isfinite(correlation_diag)):
        raise ValueError("correlation_diag must contain only finite values")
    if torch.any(correlation_diag <= 0):
        raise ValueError("correlation_diag must have all positive values")