sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +55 -18
- sae_lens/analysis/hooked_sae_transformer.py +10 -10
- sae_lens/analysis/neuronpedia_integration.py +13 -11
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +105 -235
- sae_lens/constants.py +20 -0
- sae_lens/evals.py +34 -31
- sae_lens/{sae_training_runner.py → llm_sae_training_runner.py} +103 -70
- sae_lens/load_model.py +53 -5
- sae_lens/loading/pretrained_sae_loaders.py +36 -10
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +70 -59
- sae_lens/saes/jumprelu_sae.py +58 -72
- sae_lens/saes/sae.py +248 -273
- sae_lens/saes/standard_sae.py +75 -57
- sae_lens/saes/topk_sae.py +72 -83
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +105 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +134 -158
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +11 -5
- sae_lens/util.py +47 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc3.dist-info/RECORD +38 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/WHEEL +1 -1
- sae_lens/regsitry.py +0 -34
- sae_lens-6.0.0rc1.dist-info/RECORD +0 -32
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/LICENSE +0 -0
sae_lens/util.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from dataclasses import asdict, fields, is_dataclass
|
|
3
|
+
from typing import Sequence, TypeVar
|
|
4
|
+
|
|
5
|
+
K = TypeVar("K")
|
|
6
|
+
V = TypeVar("V")
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def filter_valid_dataclass_fields(
|
|
10
|
+
source: dict[str, V] | object,
|
|
11
|
+
destination: object | type,
|
|
12
|
+
whitelist_fields: Sequence[str] | None = None,
|
|
13
|
+
) -> dict[str, V]:
|
|
14
|
+
"""Filter a source dict or dataclass instance to only include fields that are present in the destination dataclass."""
|
|
15
|
+
|
|
16
|
+
if not is_dataclass(destination):
|
|
17
|
+
raise ValueError(f"{destination} is not a dataclass")
|
|
18
|
+
|
|
19
|
+
if is_dataclass(source) and not isinstance(source, type):
|
|
20
|
+
source_dict = asdict(source)
|
|
21
|
+
elif isinstance(source, dict):
|
|
22
|
+
source_dict = source
|
|
23
|
+
else:
|
|
24
|
+
raise ValueError(f"{source} is not a dict or dataclass")
|
|
25
|
+
|
|
26
|
+
valid_field_names = {field.name for field in fields(destination)}
|
|
27
|
+
if whitelist_fields is not None:
|
|
28
|
+
valid_field_names = valid_field_names.union(whitelist_fields)
|
|
29
|
+
return {key: val for key, val in source_dict.items() if key in valid_field_names}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def extract_stop_at_layer_from_tlens_hook_name(hook_name: str) -> int | None:
    """Return the stop_at layer for a HookedTransformer hook name.

    The result is the hook's layer index plus one. None is returned when
    ``extract_layer_from_tlens_hook_name`` finds no layer index in the name.
    """
    layer_index = extract_layer_from_tlens_hook_name(hook_name)
    if layer_index is None:
        return None
    # NOTE(review): the +1 assumes HookedTransformer's `stop_at_layer` is
    # exclusive, so the block that owns the hook still runs — confirm.
    return layer_index + 1
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def extract_layer_from_tlens_hook_name(hook_name: str) -> int | None:
|
|
42
|
+
"""Extract the layer from a HookedTransformer hook name.
|
|
43
|
+
|
|
44
|
+
Returns None if the hook name is not a valid HookedTransformer hook name.
|
|
45
|
+
"""
|
|
46
|
+
hook_match = re.search(r"\.(\d+)\.", hook_name)
|
|
47
|
+
return None if hook_match is None else int(hook_match.group(1))
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
sae_lens/__init__.py,sha256=881mDkwEifeN32NsH78_CaeH11sKYK4YnqCW502qHE4,2861
|
|
2
|
+
sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
+
sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
|
|
4
|
+
sae_lens/analysis/neuronpedia_integration.py,sha256=DlI08ThI0zwMrBthICt1OFCMyqmaCUDeZxhOk7b7teY,18680
|
|
5
|
+
sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
|
|
6
|
+
sae_lens/config.py,sha256=5Wgr8SsUvYWU2Xmet1JyJ0upAZArMDpYfr3jaK8TvRY,27234
|
|
7
|
+
sae_lens/constants.py,sha256=RJlzWx7wLNMNmrdI63naF7-M3enb55vYRN4x1hXx6vI,593
|
|
8
|
+
sae_lens/evals.py,sha256=WRdHlVeZxXCi33gef7rQE90PSUBF6pjrHnPP6av_Urg,38747
|
|
9
|
+
sae_lens/llm_sae_training_runner.py,sha256=-FPXaHvDfSw5twSaDO8O80aGIzX6T0HywgdpEFFoO-8,9098
|
|
10
|
+
sae_lens/load_model.py,sha256=dBB_9gO6kWyQ4sXHq7qB8T3YUlXm3PGwYcpR4UVW4QY,8633
|
|
11
|
+
sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
|
+
sae_lens/loading/pretrained_sae_loaders.py,sha256=FSAz9Je-8Xl7ccdEyp8-WRn-KFtaJ74zgKMefnfaj3A,30877
|
|
13
|
+
sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
|
|
14
|
+
sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
|
|
15
|
+
sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
|
|
16
|
+
sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
|
|
17
|
+
sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
|
|
18
|
+
sae_lens/saes/gated_sae.py,sha256=IgWvZxeJpdiu7VqeUnJLC-VWVhz6o8OXvmwCS-LJ-WQ,9426
|
|
19
|
+
sae_lens/saes/jumprelu_sae.py,sha256=lkhafpoYYn4-62tBlmmufmUomoo3CmFFQQ3NNylBNSM,12264
|
|
20
|
+
sae_lens/saes/sae.py,sha256=u4kmsUVxa2rnFt8A5jLfj7T6h6qqBK6CkecHslebQgE,34938
|
|
21
|
+
sae_lens/saes/standard_sae.py,sha256=tMs6Z6Cv44PWa7pLo53xhXFnHMvO5BM6eVYHtRPLpos,6652
|
|
22
|
+
sae_lens/saes/topk_sae.py,sha256=CfF59K4J2XwUvztwg4fBbvFO3PyucLkg4Elkxdk0ozs,9786
|
|
23
|
+
sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
|
|
24
|
+
sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
25
|
+
sae_lens/training/activation_scaler.py,sha256=1P-vva3wJhs2NH65YONli4Rw4auvgZkxe_KKwTNMCR0,1714
|
|
26
|
+
sae_lens/training/activations_store.py,sha256=Xvnz7l2aw3XWtOQsQDj4G4bt-XT6egbumGBwrAM1mtA,32722
|
|
27
|
+
sae_lens/training/geometric_median.py,sha256=3kH8ZJAgKStlnZgs6s1uYGDYh004Bl0r4RLhuwT3lBY,3719
|
|
28
|
+
sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
|
|
29
|
+
sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
|
|
30
|
+
sae_lens/training/sae_trainer.py,sha256=rFuMdnBDe82nd7YV_QKVE18V5jCWmohbzkIGL0Z2kIM,15153
|
|
31
|
+
sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
|
|
32
|
+
sae_lens/training/upload_saes_to_huggingface.py,sha256=tXvR4j25IgMjJ8R9oczwSdy00Tg-P_jAtnPHRt8yF64,4489
|
|
33
|
+
sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
|
|
34
|
+
sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
|
|
35
|
+
sae_lens-6.0.0rc3.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
|
|
36
|
+
sae_lens-6.0.0rc3.dist-info/METADATA,sha256=irWiVHtJUXiACNPxZ0fNIVwq1n7n0wxg87c0WSYUkMw,5326
|
|
37
|
+
sae_lens-6.0.0rc3.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
38
|
+
sae_lens-6.0.0rc3.dist-info/RECORD,,
|
sae_lens/regsitry.py
DELETED
|
@@ -1,34 +0,0 @@
|
|
|
1
|
-
from typing import TYPE_CHECKING

# The SAE classes are only needed for annotations here; importing them at
# runtime would create a circular import.
if TYPE_CHECKING:
    from sae_lens.saes.sae import SAE, TrainingSAE

# Architecture name -> SAE class.
SAE_CLASS_REGISTRY: dict[str, "type[SAE]"] = {}
# Architecture name -> training SAE class.
SAE_TRAINING_CLASS_REGISTRY: dict[str, "type[TrainingSAE]"] = {}


def _register(registry: dict, architecture: str, cls: type, label: str) -> None:
    """Store ``cls`` under ``architecture`` in ``registry``, refusing overwrites.

    Raises:
        ValueError: if ``architecture`` already has an entry in ``registry``.
    """
    if architecture in registry:
        raise ValueError(f"{label} for architecture {architecture} already registered.")
    registry[architecture] = cls


def register_sae_class(architecture: str, sae_class: "type[SAE]") -> None:
    """Register ``sae_class`` as the SAE implementation for ``architecture``.

    Raises:
        ValueError: if a class is already registered for ``architecture``.
    """
    _register(SAE_CLASS_REGISTRY, architecture, sae_class, "SAE class")


def register_sae_training_class(
    architecture: str, sae_training_class: "type[TrainingSAE]"
) -> None:
    """Register ``sae_training_class`` as the training SAE for ``architecture``.

    Raises:
        ValueError: if a training class is already registered for
            ``architecture``.
    """
    _register(
        SAE_TRAINING_CLASS_REGISTRY,
        architecture,
        sae_training_class,
        "SAE training class",
    )


def get_sae_class(architecture: str) -> "type[SAE]":
    """Return the SAE class registered for ``architecture``.

    Raises:
        KeyError: if nothing is registered under ``architecture``.
    """
    return SAE_CLASS_REGISTRY[architecture]


def get_sae_training_class(architecture: str) -> "type[TrainingSAE]":
    """Return the training SAE class registered for ``architecture``.

    Raises:
        KeyError: if nothing is registered under ``architecture``.
    """
    return SAE_TRAINING_CLASS_REGISTRY[architecture]
|
|
@@ -1,32 +0,0 @@
|
|
|
1
|
-
sae_lens/__init__.py,sha256=ofQyurU7LtxIsg89QFCZe13QsdYpxErRI0x0tiCpB04,2074
|
|
2
|
-
sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
-
sae_lens/analysis/hooked_sae_transformer.py,sha256=RK0mcLhymXdJInXHcagQggxW9Qf4ptePnH7sKXvGGaU,13727
|
|
4
|
-
sae_lens/analysis/neuronpedia_integration.py,sha256=dFiKRWfuT5iUfTPBPmZydSaNG3VwqZ1asuNbbQv_NCM,18488
|
|
5
|
-
sae_lens/cache_activations_runner.py,sha256=dGK5EHJMHAKDAFyr25fy1COSm-61q-q6kpWENHFMaKk,12561
|
|
6
|
-
sae_lens/config.py,sha256=SPjziXrTyOBjObSi-3s0_mza3Z7WH8gd9NT9pVUfosg,34375
|
|
7
|
-
sae_lens/evals.py,sha256=tjDKmkUM4fBbP9LHZuBLCx37ux8Px9CliTMme3Wjt1A,38898
|
|
8
|
-
sae_lens/load_model.py,sha256=tE70sXsyyyGYW7o506O3eiw1MXyyW6DCQojLG49hWYI,6771
|
|
9
|
-
sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
|
-
sae_lens/loading/pretrained_sae_loaders.py,sha256=NcqyH2KDL8Dg66-hjXsBAq1-IwdLEpYfKwbkHxSQbrg,29961
|
|
11
|
-
sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
|
|
12
|
-
sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
|
|
13
|
-
sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
|
|
14
|
-
sae_lens/regsitry.py,sha256=yCse5NmVH-ZaPET3jW8r7C_py2DL3yoox40GxGzJ0TI,1098
|
|
15
|
-
sae_lens/sae_training_runner.py,sha256=VRNSAIsZLfcQMfZB8qdnK45PUXwoNvJ-rKt9BVYjMMY,8244
|
|
16
|
-
sae_lens/saes/gated_sae.py,sha256=l5ucq7AZHya6ZClWNNE7CionGSf1ms5m1Ah3IoN6SH4,9916
|
|
17
|
-
sae_lens/saes/jumprelu_sae.py,sha256=DRWgY58894cNh_sYAlefObI4rr0Eb6KHu1WuhTCcvB4,13468
|
|
18
|
-
sae_lens/saes/sae.py,sha256=fd7OEsSXbmVii6QoYI_TRti6dwaxAQyrBcKyX7PxERw,36779
|
|
19
|
-
sae_lens/saes/standard_sae.py,sha256=m2eNL_w6ave-_g7F1eQiwI4qbjMwwjzvxp96RN_WVAw,7110
|
|
20
|
-
sae_lens/saes/topk_sae.py,sha256=aBET4F55A4xMIvZ8AazPtyl3oL-9S7krKx78li0uKGk,11370
|
|
21
|
-
sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
|
|
22
|
-
sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
23
|
-
sae_lens/training/activations_store.py,sha256=ilJdcnZWfTDus1bdoqIb1wF_7H8_HWLmf8OCGrybmlA,35998
|
|
24
|
-
sae_lens/training/geometric_median.py,sha256=3kH8ZJAgKStlnZgs6s1uYGDYh004Bl0r4RLhuwT3lBY,3719
|
|
25
|
-
sae_lens/training/optim.py,sha256=AImcc-MAaGDLOBP2hJ4alDFCtaqqgm4cc2eBxIxiQAo,5784
|
|
26
|
-
sae_lens/training/sae_trainer.py,sha256=6TkqbzA0fYluRM8ouI_nU9sz-FaP63axxcnDrVfw37E,16279
|
|
27
|
-
sae_lens/training/upload_saes_to_huggingface.py,sha256=tVC-2Txw7-9XttGlKzM0OSqU8CK7HDO9vIzDMqEwAYU,4366
|
|
28
|
-
sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
|
|
29
|
-
sae_lens-6.0.0rc1.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
|
|
30
|
-
sae_lens-6.0.0rc1.dist-info/METADATA,sha256=wHH-VRtquu-FjZEOHdPJi3zYW3ns7MCT1fVerbPEylc,5326
|
|
31
|
-
sae_lens-6.0.0rc1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
|
32
|
-
sae_lens-6.0.0rc1.dist-info/RECORD,,
|
|
File without changes
|