sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. sae_lens/__init__.py +60 -7
  2. sae_lens/analysis/hooked_sae_transformer.py +12 -12
  3. sae_lens/analysis/neuronpedia_integration.py +16 -14
  4. sae_lens/cache_activations_runner.py +9 -7
  5. sae_lens/config.py +170 -258
  6. sae_lens/constants.py +21 -0
  7. sae_lens/evals.py +59 -44
  8. sae_lens/llm_sae_training_runner.py +377 -0
  9. sae_lens/load_model.py +52 -4
  10. sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
  11. sae_lens/registry.py +49 -0
  12. sae_lens/saes/__init__.py +48 -0
  13. sae_lens/saes/gated_sae.py +254 -0
  14. sae_lens/saes/jumprelu_sae.py +348 -0
  15. sae_lens/saes/sae.py +1076 -0
  16. sae_lens/saes/standard_sae.py +178 -0
  17. sae_lens/saes/topk_sae.py +300 -0
  18. sae_lens/training/activation_scaler.py +53 -0
  19. sae_lens/training/activations_store.py +103 -184
  20. sae_lens/training/mixing_buffer.py +56 -0
  21. sae_lens/training/optim.py +60 -36
  22. sae_lens/training/sae_trainer.py +155 -177
  23. sae_lens/training/types.py +5 -0
  24. sae_lens/training/upload_saes_to_huggingface.py +13 -7
  25. sae_lens/util.py +47 -0
  26. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
  27. sae_lens-6.0.0.dist-info/RECORD +37 -0
  28. sae_lens/sae.py +0 -747
  29. sae_lens/sae_training_runner.py +0 -251
  30. sae_lens/training/geometric_median.py +0 -101
  31. sae_lens/training/training_sae.py +0 -710
  32. sae_lens-5.11.0.dist-info/RECORD +0 -28
  33. /sae_lens/{toolkit → loading}/__init__.py +0 -0
  34. /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
  35. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
  36. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
@@ -2,23 +2,24 @@ import io
2
2
  from pathlib import Path
3
3
  from tempfile import TemporaryDirectory
4
4
  from textwrap import dedent
5
- from typing import Iterable
5
+ from typing import Any, Iterable
6
6
 
7
7
  from huggingface_hub import HfApi, create_repo, get_hf_file_metadata, hf_hub_url
8
8
  from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
9
9
  from tqdm.autonotebook import tqdm
10
10
 
11
11
  from sae_lens import logger
12
- from sae_lens.config import (
12
+ from sae_lens.constants import (
13
+ RUNNER_CFG_FILENAME,
13
14
  SAE_CFG_FILENAME,
14
15
  SAE_WEIGHTS_FILENAME,
15
16
  SPARSITY_FILENAME,
16
17
  )
17
- from sae_lens.sae import SAE
18
+ from sae_lens.saes.sae import SAE
18
19
 
19
20
 
20
21
  def upload_saes_to_huggingface(
21
- saes_dict: dict[str, SAE | Path | str],
22
+ saes_dict: dict[str, SAE[Any] | Path | str],
22
23
  hf_repo_id: str,
23
24
  hf_revision: str = "main",
24
25
  show_progress: bool = True,
@@ -87,7 +88,7 @@ def _create_default_readme(repo_id: str, sae_ids: Iterable[str]) -> str:
87
88
  ```python
88
89
  from sae_lens import SAE
89
90
 
90
- sae, cfg_dict, sparsity = SAE.from_pretrained("{repo_id}", "<sae_id>")
91
+ sae = SAE.from_pretrained("{repo_id}", "<sae_id>")
91
92
  ```
92
93
  """
93
94
  )
@@ -119,11 +120,16 @@ def _upload_sae(api: HfApi, sae_path: Path, repo_id: str, sae_id: str, revision:
119
120
  revision=revision,
120
121
  repo_type="model",
121
122
  commit_message=f"Upload SAE {sae_id}",
122
- allow_patterns=[SAE_CFG_FILENAME, SAE_WEIGHTS_FILENAME, SPARSITY_FILENAME],
123
+ allow_patterns=[
124
+ SAE_CFG_FILENAME,
125
+ SAE_WEIGHTS_FILENAME,
126
+ SPARSITY_FILENAME,
127
+ RUNNER_CFG_FILENAME,
128
+ ],
123
129
  )
124
130
 
125
131
 
126
- def _build_sae_path(sae_ref: SAE | Path | str, tmp_dir: str) -> Path:
132
+ def _build_sae_path(sae_ref: SAE[Any] | Path | str, tmp_dir: str) -> Path:
127
133
  if isinstance(sae_ref, SAE):
128
134
  sae_ref.save_model(tmp_dir)
129
135
  return Path(tmp_dir)
sae_lens/util.py ADDED
@@ -0,0 +1,47 @@
1
+ import re
2
+ from dataclasses import asdict, fields, is_dataclass
3
+ from typing import Sequence, TypeVar
4
+
5
+ K = TypeVar("K")
6
+ V = TypeVar("V")
7
+
8
+
9
+ def filter_valid_dataclass_fields(
10
+ source: dict[str, V] | object,
11
+ destination: object | type,
12
+ whitelist_fields: Sequence[str] | None = None,
13
+ ) -> dict[str, V]:
14
+ """Filter a source dict or dataclass instance to only include fields that are present in the destination dataclass."""
15
+
16
+ if not is_dataclass(destination):
17
+ raise ValueError(f"{destination} is not a dataclass")
18
+
19
+ if is_dataclass(source) and not isinstance(source, type):
20
+ source_dict = asdict(source)
21
+ elif isinstance(source, dict):
22
+ source_dict = source
23
+ else:
24
+ raise ValueError(f"{source} is not a dict or dataclass")
25
+
26
+ valid_field_names = {field.name for field in fields(destination)}
27
+ if whitelist_fields is not None:
28
+ valid_field_names = valid_field_names.union(whitelist_fields)
29
+ return {key: val for key, val in source_dict.items() if key in valid_field_names}
30
+
31
+
32
+ def extract_stop_at_layer_from_tlens_hook_name(hook_name: str) -> int | None:
33
+ """Extract the stop_at layer from a HookedTransformer hook name.
34
+
35
+ Returns None if the hook name is not a valid HookedTransformer hook name.
36
+ """
37
+ layer = extract_layer_from_tlens_hook_name(hook_name)
38
+ return None if layer is None else layer + 1
39
+
40
+
41
+ def extract_layer_from_tlens_hook_name(hook_name: str) -> int | None:
42
+ """Extract the layer from a HookedTransformer hook name.
43
+
44
+ Returns None if the hook name is not a valid HookedTransformer hook name.
45
+ """
46
+ hook_match = re.search(r"\.(\d+)\.", hook_name)
47
+ return None if hook_match is None else int(hook_match.group(1))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sae-lens
3
- Version: 5.11.0
3
+ Version: 6.0.0
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -0,0 +1,37 @@
1
+ sae_lens/__init__.py,sha256=kAAUdxuMPOM3qLV1vK245fOEruEdsKd_GgAEKJsvxgY,2856
2
+ sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
4
+ sae_lens/analysis/neuronpedia_integration.py,sha256=MrENqc81Mc2SMbxGjbwHzpkGUCAFKSf0i4EdaUF2Oj4,18707
5
+ sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
6
+ sae_lens/config.py,sha256=9Lg4HkQvj1t9QZJdmC071lyJMc_iqNQknosT7zOYfwM,27278
7
+ sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
8
+ sae_lens/evals.py,sha256=kQyrzczKaVD9rHwfFa_DxL_gMXDxsoIVHmsFIPIU2bY,38696
9
+ sae_lens/llm_sae_training_runner.py,sha256=58XbDylw2fPOD7C-ZfSAjeNqJLXB05uHGTuiYVVbXXY,13354
10
+ sae_lens/load_model.py,sha256=dBB_9gO6kWyQ4sXHq7qB8T3YUlXm3PGwYcpR4UVW4QY,8633
11
+ sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
+ sae_lens/loading/pretrained_sae_loaders.py,sha256=5XEU4uFFeGCePwqDwhlE7CrFGRSI0U9Cu-UQVa33Y1E,36432
13
+ sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
14
+ sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
15
+ sae_lens/pretrained_saes.yaml,sha256=nhHW1auhyi4GHYrjUnHQqbNVhI5cMJv-HThzbzU1xG0,574145
16
+ sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
+ sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
18
+ sae_lens/saes/gated_sae.py,sha256=0zd66bH04nsaGk3bxHk10hsZofa2GrFbMo15LOsuqgU,9233
19
+ sae_lens/saes/jumprelu_sae.py,sha256=iwmPQJ4XpIxzgosty680u8Zj7x1uVZhM75kPOT3obi0,12060
20
+ sae_lens/saes/sae.py,sha256=ZEXEXFVtrtFrzuOV3nyweTBleNCV4EDGh1ImaF32uqg,39618
21
+ sae_lens/saes/standard_sae.py,sha256=PfkGLsw_6La3PXHOQL0u7qQsaZsXCJqYCeCcRDj5n64,6274
22
+ sae_lens/saes/topk_sae.py,sha256=kmry1FE1H06OvCfn84V-j2JfWGKcU5b2urwAq_Oq5j4,9893
23
+ sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
24
+ sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
+ sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
26
+ sae_lens/training/activations_store.py,sha256=z8erbiB6ODbsqlu-bwEWbyj4XZvgsVgjCRBuQovqp2Q,32612
27
+ sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
28
+ sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
29
+ sae_lens/training/sae_trainer.py,sha256=9K0VudwSTJp9OlCVzaU_ngZ0WlYNrN6-ozTCCAxR9_k,15421
30
+ sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
31
+ sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
32
+ sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
33
+ sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
34
+ sae_lens-6.0.0.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
35
+ sae_lens-6.0.0.dist-info/METADATA,sha256=LsDx7IEn0Mll0pUSKkBw6jlAXM9NwOVwseomhTb0AiE,5323
36
+ sae_lens-6.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
37
+ sae_lens-6.0.0.dist-info/RECORD,,