sae-lens 6.6.0__tar.gz → 6.6.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {sae_lens-6.6.0 → sae_lens-6.6.1}/PKG-INFO +7 -12
  2. {sae_lens-6.6.0 → sae_lens-6.6.1}/pyproject.toml +7 -12
  3. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/__init__.py +1 -1
  4. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/cache_activations_runner.py +1 -1
  5. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/evals.py +6 -4
  6. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/training/activations_store.py +1 -1
  7. {sae_lens-6.6.0 → sae_lens-6.6.1}/LICENSE +0 -0
  8. {sae_lens-6.6.0 → sae_lens-6.6.1}/README.md +0 -0
  9. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/analysis/__init__.py +0 -0
  10. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  11. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  12. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/config.py +0 -0
  13. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/constants.py +0 -0
  14. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/llm_sae_training_runner.py +0 -0
  15. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/load_model.py +0 -0
  16. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/loading/__init__.py +0 -0
  17. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
  18. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  19. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/pretokenize_runner.py +0 -0
  20. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/pretrained_saes.yaml +0 -0
  21. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/registry.py +0 -0
  22. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/saes/__init__.py +0 -0
  23. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/saes/batchtopk_sae.py +0 -0
  24. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/saes/gated_sae.py +0 -0
  25. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/saes/jumprelu_sae.py +0 -0
  26. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/saes/sae.py +0 -0
  27. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/saes/standard_sae.py +0 -0
  28. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/saes/topk_sae.py +0 -0
  29. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/saes/transcoder.py +0 -0
  30. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/tokenization_and_batching.py +0 -0
  31. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/training/__init__.py +0 -0
  32. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/training/activation_scaler.py +0 -0
  33. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/training/mixing_buffer.py +0 -0
  34. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/training/optim.py +0 -0
  35. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/training/sae_trainer.py +0 -0
  36. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/training/types.py +0 -0
  37. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  38. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/tutorial/tsea.py +0 -0
  39. {sae_lens-6.6.0 → sae_lens-6.6.1}/sae_lens/util.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sae-lens
3
- Version: 6.6.0
3
+ Version: 6.6.1
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -16,24 +16,19 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
16
16
  Provides-Extra: mamba
17
17
  Requires-Dist: automated-interpretability (>=0.0.5,<1.0.0)
18
18
  Requires-Dist: babe (>=0.0.7,<0.0.8)
19
- Requires-Dist: datasets (>=3.1.0,<4.0.0)
19
+ Requires-Dist: datasets (>=3.1.0)
20
20
  Requires-Dist: mamba-lens (>=0.0.4,<0.0.5) ; extra == "mamba"
21
- Requires-Dist: matplotlib (>=3.8.3,<4.0.0)
22
- Requires-Dist: matplotlib-inline (>=0.1.6,<0.2.0)
23
21
  Requires-Dist: nltk (>=3.8.1,<4.0.0)
24
- Requires-Dist: plotly (>=5.19.0,<6.0.0)
25
- Requires-Dist: plotly-express (>=0.4.1,<0.5.0)
26
- Requires-Dist: pytest-profiling (>=1.7.0,<2.0.0)
27
- Requires-Dist: python-dotenv (>=1.0.1,<2.0.0)
22
+ Requires-Dist: plotly (>=5.19.0)
23
+ Requires-Dist: plotly-express (>=0.4.1)
24
+ Requires-Dist: python-dotenv (>=1.0.1)
28
25
  Requires-Dist: pyyaml (>=6.0.1,<7.0.0)
29
- Requires-Dist: pyzmq (==26.0.0)
30
- Requires-Dist: safetensors (>=0.4.2,<0.5.0)
26
+ Requires-Dist: safetensors (>=0.4.2,<1.0.0)
31
27
  Requires-Dist: simple-parsing (>=0.1.6,<0.2.0)
28
+ Requires-Dist: tenacity (>=9.0.0)
32
29
  Requires-Dist: transformer-lens (>=2.16.1,<3.0.0)
33
30
  Requires-Dist: transformers (>=4.38.1,<5.0.0)
34
- Requires-Dist: typer (>=0.12.3,<0.13.0)
35
31
  Requires-Dist: typing-extensions (>=4.10.0,<5.0.0)
36
- Requires-Dist: zstandard (>=0.22.0,<0.23.0)
37
32
  Project-URL: Homepage, https://jbloomaus.github.io/SAELens
38
33
  Project-URL: Repository, https://github.com/jbloomAus/SAELens
39
34
  Description-Content-Type: text/markdown
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "sae-lens"
3
- version = "6.6.0"
3
+ version = "6.6.1"
4
4
  description = "Training and Analyzing Sparse Autoencoders (SAEs)"
5
5
  authors = ["Joseph Bloom"]
6
6
  readme = "README.md"
@@ -21,24 +21,19 @@ classifiers = ["Topic :: Scientific/Engineering :: Artificial Intelligence"]
21
21
  python = "^3.10"
22
22
  transformer-lens = "^2.16.1"
23
23
  transformers = "^4.38.1"
24
- plotly = "^5.19.0"
25
- plotly-express = "^0.4.1"
26
- matplotlib = "^3.8.3"
27
- matplotlib-inline = "^0.1.6"
28
- datasets = "^3.1.0"
24
+ plotly = ">=5.19.0"
25
+ plotly-express = ">=0.4.1"
26
+ datasets = ">=3.1.0"
29
27
  babe = "^0.0.7"
30
28
  nltk = "^3.8.1"
31
- safetensors = "^0.4.2"
32
- typer = "^0.12.3"
29
+ safetensors = ">=0.4.2,<1.0.0"
33
30
  mamba-lens = { version = "^0.0.4", optional = true }
34
- pyzmq = "26.0.0"
35
31
  automated-interpretability = ">=0.0.5,<1.0.0"
36
- python-dotenv = "^1.0.1"
32
+ python-dotenv = ">=1.0.1"
37
33
  pyyaml = "^6.0.1"
38
- pytest-profiling = "^1.7.0"
39
- zstandard = "^0.22.0"
40
34
  typing-extensions = "^4.10.0"
41
35
  simple-parsing = "^0.1.6"
36
+ tenacity = ">=9.0.0"
42
37
 
43
38
  [tool.poetry.group.dev.dependencies]
44
39
  pytest = "^8.0.2"
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.6.0"
2
+ __version__ = "6.6.1"
3
3
 
4
4
  import logging
5
5
 
@@ -82,7 +82,7 @@ class CacheActivationsRunner:
82
82
  )
83
83
  for hook_name in [self.cfg.hook_name]
84
84
  }
85
- features_dict["token_ids"] = Sequence(
85
+ features_dict["token_ids"] = Sequence( # type: ignore
86
86
  Value(dtype="int32"), length=self.context_size
87
87
  )
88
88
  self.features = Features(features_dict)
@@ -459,14 +459,16 @@ def get_sparsity_and_variance_metrics(
459
459
  original_act = cache[hook_name]
460
460
 
461
461
  # normalise if necessary (necessary in training only, otherwise we should fold the scaling in)
462
- original_act = activation_scaler.scale(original_act)
462
+ original_act_scaled = activation_scaler.scale(original_act)
463
463
 
464
464
  # send the (maybe normalised) activations into the SAE
465
- sae_feature_activations = sae.encode(original_act.to(sae.device))
466
- sae_out = sae.decode(sae_feature_activations).to(original_act.device)
465
+ sae_feature_activations = sae.encode(original_act_scaled.to(sae.device))
466
+ sae_out_scaled = sae.decode(sae_feature_activations).to(
467
+ original_act_scaled.device
468
+ )
467
469
  del cache
468
470
 
469
- sae_out = activation_scaler.unscale(sae_out)
471
+ sae_out = activation_scaler.unscale(sae_out_scaled)
470
472
 
471
473
  flattened_sae_input = einops.rearrange(original_act, "b ctx d -> (b ctx) d")
472
474
  flattened_sae_feature_acts = einops.rearrange(
@@ -289,7 +289,7 @@ class ActivationsStore:
289
289
  "Dataset must have a 'tokens', 'input_ids', 'text', or 'problem' column."
290
290
  )
291
291
  if self.is_dataset_tokenized:
292
- ds_context_size = len(dataset_sample[self.tokens_column])
292
+ ds_context_size = len(dataset_sample[self.tokens_column]) # type: ignore
293
293
  if ds_context_size < self.context_size:
294
294
  raise ValueError(
295
295
  f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}.
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes