sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +60 -7
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +16 -14
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +170 -258
- sae_lens/constants.py +21 -0
- sae_lens/evals.py +59 -44
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +52 -4
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +254 -0
- sae_lens/saes/jumprelu_sae.py +348 -0
- sae_lens/saes/sae.py +1076 -0
- sae_lens/saes/standard_sae.py +178 -0
- sae_lens/saes/topk_sae.py +300 -0
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +103 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +155 -177
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +13 -7
- sae_lens/util.py +47 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
- sae_lens-6.0.0.dist-info/RECORD +37 -0
- sae_lens/sae.py +0 -747
- sae_lens/sae_training_runner.py +0 -251
- sae_lens/training/geometric_median.py +0 -101
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.11.0.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
|
@@ -2,23 +2,24 @@ import io
|
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
from tempfile import TemporaryDirectory
|
|
4
4
|
from textwrap import dedent
|
|
5
|
-
from typing import Iterable
|
|
5
|
+
from typing import Any, Iterable
|
|
6
6
|
|
|
7
7
|
from huggingface_hub import HfApi, create_repo, get_hf_file_metadata, hf_hub_url
|
|
8
8
|
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
|
|
9
9
|
from tqdm.autonotebook import tqdm
|
|
10
10
|
|
|
11
11
|
from sae_lens import logger
|
|
12
|
-
from sae_lens.
|
|
12
|
+
from sae_lens.constants import (
|
|
13
|
+
RUNNER_CFG_FILENAME,
|
|
13
14
|
SAE_CFG_FILENAME,
|
|
14
15
|
SAE_WEIGHTS_FILENAME,
|
|
15
16
|
SPARSITY_FILENAME,
|
|
16
17
|
)
|
|
17
|
-
from sae_lens.sae import SAE
|
|
18
|
+
from sae_lens.saes.sae import SAE
|
|
18
19
|
|
|
19
20
|
|
|
20
21
|
def upload_saes_to_huggingface(
|
|
21
|
-
saes_dict: dict[str, SAE | Path | str],
|
|
22
|
+
saes_dict: dict[str, SAE[Any] | Path | str],
|
|
22
23
|
hf_repo_id: str,
|
|
23
24
|
hf_revision: str = "main",
|
|
24
25
|
show_progress: bool = True,
|
|
@@ -87,7 +88,7 @@ def _create_default_readme(repo_id: str, sae_ids: Iterable[str]) -> str:
|
|
|
87
88
|
```python
|
|
88
89
|
from sae_lens import SAE
|
|
89
90
|
|
|
90
|
-
sae
|
|
91
|
+
sae = SAE.from_pretrained("{repo_id}", "<sae_id>")
|
|
91
92
|
```
|
|
92
93
|
"""
|
|
93
94
|
)
|
|
@@ -119,11 +120,16 @@ def _upload_sae(api: HfApi, sae_path: Path, repo_id: str, sae_id: str, revision:
|
|
|
119
120
|
revision=revision,
|
|
120
121
|
repo_type="model",
|
|
121
122
|
commit_message=f"Upload SAE {sae_id}",
|
|
122
|
-
allow_patterns=[
|
|
123
|
+
allow_patterns=[
|
|
124
|
+
SAE_CFG_FILENAME,
|
|
125
|
+
SAE_WEIGHTS_FILENAME,
|
|
126
|
+
SPARSITY_FILENAME,
|
|
127
|
+
RUNNER_CFG_FILENAME,
|
|
128
|
+
],
|
|
123
129
|
)
|
|
124
130
|
|
|
125
131
|
|
|
126
|
-
def _build_sae_path(sae_ref: SAE | Path | str, tmp_dir: str) -> Path:
|
|
132
|
+
def _build_sae_path(sae_ref: SAE[Any] | Path | str, tmp_dir: str) -> Path:
|
|
127
133
|
if isinstance(sae_ref, SAE):
|
|
128
134
|
sae_ref.save_model(tmp_dir)
|
|
129
135
|
return Path(tmp_dir)
|
sae_lens/util.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from dataclasses import asdict, fields, is_dataclass
|
|
3
|
+
from typing import Sequence, TypeVar
|
|
4
|
+
|
|
5
|
+
K = TypeVar("K")
|
|
6
|
+
V = TypeVar("V")
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def filter_valid_dataclass_fields(
|
|
10
|
+
source: dict[str, V] | object,
|
|
11
|
+
destination: object | type,
|
|
12
|
+
whitelist_fields: Sequence[str] | None = None,
|
|
13
|
+
) -> dict[str, V]:
|
|
14
|
+
"""Filter a source dict or dataclass instance to only include fields that are present in the destination dataclass."""
|
|
15
|
+
|
|
16
|
+
if not is_dataclass(destination):
|
|
17
|
+
raise ValueError(f"{destination} is not a dataclass")
|
|
18
|
+
|
|
19
|
+
if is_dataclass(source) and not isinstance(source, type):
|
|
20
|
+
source_dict = asdict(source)
|
|
21
|
+
elif isinstance(source, dict):
|
|
22
|
+
source_dict = source
|
|
23
|
+
else:
|
|
24
|
+
raise ValueError(f"{source} is not a dict or dataclass")
|
|
25
|
+
|
|
26
|
+
valid_field_names = {field.name for field in fields(destination)}
|
|
27
|
+
if whitelist_fields is not None:
|
|
28
|
+
valid_field_names = valid_field_names.union(whitelist_fields)
|
|
29
|
+
return {key: val for key, val in source_dict.items() if key in valid_field_names}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def extract_stop_at_layer_from_tlens_hook_name(hook_name: str) -> int | None:
    """Compute a ``stop_at`` layer for a HookedTransformer hook name.

    The result is the layer index parsed by
    ``extract_layer_from_tlens_hook_name`` plus one. Presumably this is so that
    a forward pass stopped at this value still runs the hook's own layer;
    NOTE(review): confirm against HookedTransformer's ``stop_at_layer`` semantics.

    Returns None when no layer index can be parsed from ``hook_name``.
    """
    layer_index = extract_layer_from_tlens_hook_name(hook_name)
    if layer_index is None:
        return None
    return layer_index + 1
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def extract_layer_from_tlens_hook_name(hook_name: str) -> int | None:
|
|
42
|
+
"""Extract the layer from a HookedTransformer hook name.
|
|
43
|
+
|
|
44
|
+
Returns None if the hook name is not a valid HookedTransformer hook name.
|
|
45
|
+
"""
|
|
46
|
+
hook_match = re.search(r"\.(\d+)\.", hook_name)
|
|
47
|
+
return None if hook_match is None else int(hook_match.group(1))
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
sae_lens/__init__.py,sha256=kAAUdxuMPOM3qLV1vK245fOEruEdsKd_GgAEKJsvxgY,2856
|
|
2
|
+
sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
+
sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
|
|
4
|
+
sae_lens/analysis/neuronpedia_integration.py,sha256=MrENqc81Mc2SMbxGjbwHzpkGUCAFKSf0i4EdaUF2Oj4,18707
|
|
5
|
+
sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
|
|
6
|
+
sae_lens/config.py,sha256=9Lg4HkQvj1t9QZJdmC071lyJMc_iqNQknosT7zOYfwM,27278
|
|
7
|
+
sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
|
|
8
|
+
sae_lens/evals.py,sha256=kQyrzczKaVD9rHwfFa_DxL_gMXDxsoIVHmsFIPIU2bY,38696
|
|
9
|
+
sae_lens/llm_sae_training_runner.py,sha256=58XbDylw2fPOD7C-ZfSAjeNqJLXB05uHGTuiYVVbXXY,13354
|
|
10
|
+
sae_lens/load_model.py,sha256=dBB_9gO6kWyQ4sXHq7qB8T3YUlXm3PGwYcpR4UVW4QY,8633
|
|
11
|
+
sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
|
+
sae_lens/loading/pretrained_sae_loaders.py,sha256=5XEU4uFFeGCePwqDwhlE7CrFGRSI0U9Cu-UQVa33Y1E,36432
|
|
13
|
+
sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
|
|
14
|
+
sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
|
|
15
|
+
sae_lens/pretrained_saes.yaml,sha256=nhHW1auhyi4GHYrjUnHQqbNVhI5cMJv-HThzbzU1xG0,574145
|
|
16
|
+
sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
|
|
17
|
+
sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
|
|
18
|
+
sae_lens/saes/gated_sae.py,sha256=0zd66bH04nsaGk3bxHk10hsZofa2GrFbMo15LOsuqgU,9233
|
|
19
|
+
sae_lens/saes/jumprelu_sae.py,sha256=iwmPQJ4XpIxzgosty680u8Zj7x1uVZhM75kPOT3obi0,12060
|
|
20
|
+
sae_lens/saes/sae.py,sha256=ZEXEXFVtrtFrzuOV3nyweTBleNCV4EDGh1ImaF32uqg,39618
|
|
21
|
+
sae_lens/saes/standard_sae.py,sha256=PfkGLsw_6La3PXHOQL0u7qQsaZsXCJqYCeCcRDj5n64,6274
|
|
22
|
+
sae_lens/saes/topk_sae.py,sha256=kmry1FE1H06OvCfn84V-j2JfWGKcU5b2urwAq_Oq5j4,9893
|
|
23
|
+
sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
|
|
24
|
+
sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
25
|
+
sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
|
|
26
|
+
sae_lens/training/activations_store.py,sha256=z8erbiB6ODbsqlu-bwEWbyj4XZvgsVgjCRBuQovqp2Q,32612
|
|
27
|
+
sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
|
|
28
|
+
sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
|
|
29
|
+
sae_lens/training/sae_trainer.py,sha256=9K0VudwSTJp9OlCVzaU_ngZ0WlYNrN6-ozTCCAxR9_k,15421
|
|
30
|
+
sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
|
|
31
|
+
sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
|
|
32
|
+
sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
|
|
33
|
+
sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
|
|
34
|
+
sae_lens-6.0.0.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
|
|
35
|
+
sae_lens-6.0.0.dist-info/METADATA,sha256=LsDx7IEn0Mll0pUSKkBw6jlAXM9NwOVwseomhTb0AiE,5323
|
|
36
|
+
sae_lens-6.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
37
|
+
sae_lens-6.0.0.dist-info/RECORD,,
|