sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/util.py ADDED
@@ -0,0 +1,47 @@
1
+ import re
2
+ from dataclasses import asdict, fields, is_dataclass
3
+ from typing import Sequence, TypeVar
4
+
5
# Type variables for the generic signatures in this module.
# V: value type of a ``dict`` source passed to filter_valid_dataclass_fields.
V = TypeVar("V")
# K: not referenced by any function in this file; presumably kept for
# importers elsewhere -- NOTE(review): confirm before removing.
K = TypeVar("K")
7
+
8
+
9
def filter_valid_dataclass_fields(
    source: dict[str, V] | object,
    destination: object | type,
    whitelist_fields: Sequence[str] | None = None,
) -> dict[str, V]:
    """Filter a source dict or dataclass instance to only include fields that are present in the destination dataclass.

    Args:
        source: A ``dict`` or a dataclass *instance*. Dataclass instances are
            converted with ``dataclasses.asdict``, which (per the stdlib docs)
            recurses into nested dataclasses, lists, tuples and dicts and
            deep-copies the values.
        destination: A dataclass type or dataclass instance; its field names
            are the keys that are kept.
        whitelist_fields: Extra key names to keep even though ``destination``
            has no such field. A single ``str`` is treated as one field name.

    Returns:
        A new dict holding only the allowed keys of ``source``, in the order
        they appear in ``source``.

    Raises:
        ValueError: If ``destination`` is not a dataclass, or if ``source`` is
            neither a dict nor a dataclass instance.
    """
    if not is_dataclass(destination):
        raise ValueError(f"{destination} is not a dataclass")

    # A dataclass *class* carries no field values, so only instances are
    # accepted as a dataclass source.
    if is_dataclass(source) and not isinstance(source, type):
        source_dict = asdict(source)
    elif isinstance(source, dict):
        source_dict = source
    else:
        raise ValueError(f"{source} is not a dict or dataclass")

    valid_field_names = {field.name for field in fields(destination)}
    if whitelist_fields is not None:
        # ``str`` is itself a ``Sequence[str]``: without this guard a single
        # name such as "foo" would whitelist the characters "f" and "o".
        if isinstance(whitelist_fields, str):
            whitelist_fields = [whitelist_fields]
        valid_field_names.update(whitelist_fields)
    return {key: val for key, val in source_dict.items() if key in valid_field_names}
30
+
31
+
32
def extract_stop_at_layer_from_tlens_hook_name(hook_name: str) -> int | None:
    """Return the layer index one past the layer named in ``hook_name``.

    The result is ``extract_layer_from_tlens_hook_name(hook_name) + 1``. It is
    ``None`` when that helper finds no ``.<digits>.`` segment in the name,
    which includes layer-less hooks such as ``"hook_embed"``.

    NOTE(review): presumably the result is passed to TransformerLens as
    ``stop_at_layer``, which runs only blocks with index < ``stop_at_layer``.
    The ``+ 1`` would then let the hook's own block execute. Confirm against
    callers.
    """
    layer_index = extract_layer_from_tlens_hook_name(hook_name)
    if layer_index is None:
        return None
    return layer_index + 1
39
+
40
+
41
+ def extract_layer_from_tlens_hook_name(hook_name: str) -> int | None:
42
+ """Extract the layer from a HookedTransformer hook name.
43
+
44
+ Returns None if the hook name is not a valid HookedTransformer hook name.
45
+ """
46
+ hook_match = re.search(r"\.(\d+)\.", hook_name)
47
+ return None if hook_match is None else int(hook_match.group(1))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sae-lens
3
- Version: 6.0.0rc1
3
+ Version: 6.0.0rc3
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -0,0 +1,38 @@
1
+ sae_lens/__init__.py,sha256=881mDkwEifeN32NsH78_CaeH11sKYK4YnqCW502qHE4,2861
2
+ sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
4
+ sae_lens/analysis/neuronpedia_integration.py,sha256=DlI08ThI0zwMrBthICt1OFCMyqmaCUDeZxhOk7b7teY,18680
5
+ sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
6
+ sae_lens/config.py,sha256=5Wgr8SsUvYWU2Xmet1JyJ0upAZArMDpYfr3jaK8TvRY,27234
7
+ sae_lens/constants.py,sha256=RJlzWx7wLNMNmrdI63naF7-M3enb55vYRN4x1hXx6vI,593
8
+ sae_lens/evals.py,sha256=WRdHlVeZxXCi33gef7rQE90PSUBF6pjrHnPP6av_Urg,38747
9
+ sae_lens/llm_sae_training_runner.py,sha256=-FPXaHvDfSw5twSaDO8O80aGIzX6T0HywgdpEFFoO-8,9098
10
+ sae_lens/load_model.py,sha256=dBB_9gO6kWyQ4sXHq7qB8T3YUlXm3PGwYcpR4UVW4QY,8633
11
+ sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
+ sae_lens/loading/pretrained_sae_loaders.py,sha256=FSAz9Je-8Xl7ccdEyp8-WRn-KFtaJ74zgKMefnfaj3A,30877
13
+ sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
14
+ sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
15
+ sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
16
+ sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
+ sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
18
+ sae_lens/saes/gated_sae.py,sha256=IgWvZxeJpdiu7VqeUnJLC-VWVhz6o8OXvmwCS-LJ-WQ,9426
19
+ sae_lens/saes/jumprelu_sae.py,sha256=lkhafpoYYn4-62tBlmmufmUomoo3CmFFQQ3NNylBNSM,12264
20
+ sae_lens/saes/sae.py,sha256=u4kmsUVxa2rnFt8A5jLfj7T6h6qqBK6CkecHslebQgE,34938
21
+ sae_lens/saes/standard_sae.py,sha256=tMs6Z6Cv44PWa7pLo53xhXFnHMvO5BM6eVYHtRPLpos,6652
22
+ sae_lens/saes/topk_sae.py,sha256=CfF59K4J2XwUvztwg4fBbvFO3PyucLkg4Elkxdk0ozs,9786
23
+ sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
24
+ sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
+ sae_lens/training/activation_scaler.py,sha256=1P-vva3wJhs2NH65YONli4Rw4auvgZkxe_KKwTNMCR0,1714
26
+ sae_lens/training/activations_store.py,sha256=Xvnz7l2aw3XWtOQsQDj4G4bt-XT6egbumGBwrAM1mtA,32722
27
+ sae_lens/training/geometric_median.py,sha256=3kH8ZJAgKStlnZgs6s1uYGDYh004Bl0r4RLhuwT3lBY,3719
28
+ sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
29
+ sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
30
+ sae_lens/training/sae_trainer.py,sha256=rFuMdnBDe82nd7YV_QKVE18V5jCWmohbzkIGL0Z2kIM,15153
31
+ sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
32
+ sae_lens/training/upload_saes_to_huggingface.py,sha256=tXvR4j25IgMjJ8R9oczwSdy00Tg-P_jAtnPHRt8yF64,4489
33
+ sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
34
+ sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
35
+ sae_lens-6.0.0rc3.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
36
+ sae_lens-6.0.0rc3.dist-info/METADATA,sha256=irWiVHtJUXiACNPxZ0fNIVwq1n7n0wxg87c0WSYUkMw,5326
37
+ sae_lens-6.0.0rc3.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
38
+ sae_lens-6.0.0rc3.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 2.1.2
2
+ Generator: poetry-core 2.1.3
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
sae_lens/regsitry.py DELETED
@@ -1,34 +0,0 @@
1
from typing import TYPE_CHECKING

# avoid circular imports
if TYPE_CHECKING:
    from sae_lens.saes.sae import SAE, TrainingSAE

# Architecture name -> SAE class, filled in by register_sae_class.
SAE_CLASS_REGISTRY: dict[str, "type[SAE]"] = {}
# Architecture name -> training SAE class, filled in by register_sae_training_class.
SAE_TRAINING_CLASS_REGISTRY: dict[str, "type[TrainingSAE]"] = {}


def register_sae_class(architecture: str, sae_class: "type[SAE]") -> None:
    """Register ``sae_class`` under the name ``architecture``.

    Raises:
        ValueError: If a class is already registered for ``architecture``.
    """
    if architecture in SAE_CLASS_REGISTRY:
        raise ValueError(
            f"SAE class for architecture {architecture} already registered."
        )
    SAE_CLASS_REGISTRY[architecture] = sae_class


def register_sae_training_class(
    architecture: str, sae_training_class: "type[TrainingSAE]"
) -> None:
    """Register ``sae_training_class`` under the name ``architecture``.

    Raises:
        ValueError: If a training class is already registered for
            ``architecture``.
    """
    if architecture in SAE_TRAINING_CLASS_REGISTRY:
        raise ValueError(
            f"SAE training class for architecture {architecture} already registered."
        )
    SAE_TRAINING_CLASS_REGISTRY[architecture] = sae_training_class


def get_sae_class(architecture: str) -> "type[SAE]":
    """Return the SAE class registered for ``architecture``.

    Raises:
        KeyError: If nothing is registered under ``architecture``. The message
            lists the architectures that are registered.
    """
    try:
        return SAE_CLASS_REGISTRY[architecture]
    except KeyError:
        # Still a KeyError for existing callers, but with actionable context
        # instead of just the missing key.
        raise KeyError(
            f"No SAE class registered for architecture {architecture!r}. "
            f"Registered architectures: {sorted(SAE_CLASS_REGISTRY)}"
        ) from None


def get_sae_training_class(architecture: str) -> "type[TrainingSAE]":
    """Return the training SAE class registered for ``architecture``.

    Raises:
        KeyError: If nothing is registered under ``architecture``. The message
            lists the architectures that are registered.
    """
    try:
        return SAE_TRAINING_CLASS_REGISTRY[architecture]
    except KeyError:
        raise KeyError(
            f"No SAE training class registered for architecture {architecture!r}. "
            f"Registered architectures: {sorted(SAE_TRAINING_CLASS_REGISTRY)}"
        ) from None
@@ -1,32 +0,0 @@
1
- sae_lens/__init__.py,sha256=ofQyurU7LtxIsg89QFCZe13QsdYpxErRI0x0tiCpB04,2074
2
- sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- sae_lens/analysis/hooked_sae_transformer.py,sha256=RK0mcLhymXdJInXHcagQggxW9Qf4ptePnH7sKXvGGaU,13727
4
- sae_lens/analysis/neuronpedia_integration.py,sha256=dFiKRWfuT5iUfTPBPmZydSaNG3VwqZ1asuNbbQv_NCM,18488
5
- sae_lens/cache_activations_runner.py,sha256=dGK5EHJMHAKDAFyr25fy1COSm-61q-q6kpWENHFMaKk,12561
6
- sae_lens/config.py,sha256=SPjziXrTyOBjObSi-3s0_mza3Z7WH8gd9NT9pVUfosg,34375
7
- sae_lens/evals.py,sha256=tjDKmkUM4fBbP9LHZuBLCx37ux8Px9CliTMme3Wjt1A,38898
8
- sae_lens/load_model.py,sha256=tE70sXsyyyGYW7o506O3eiw1MXyyW6DCQojLG49hWYI,6771
9
- sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
- sae_lens/loading/pretrained_sae_loaders.py,sha256=NcqyH2KDL8Dg66-hjXsBAq1-IwdLEpYfKwbkHxSQbrg,29961
11
- sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
12
- sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
13
- sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
14
- sae_lens/regsitry.py,sha256=yCse5NmVH-ZaPET3jW8r7C_py2DL3yoox40GxGzJ0TI,1098
15
- sae_lens/sae_training_runner.py,sha256=VRNSAIsZLfcQMfZB8qdnK45PUXwoNvJ-rKt9BVYjMMY,8244
16
- sae_lens/saes/gated_sae.py,sha256=l5ucq7AZHya6ZClWNNE7CionGSf1ms5m1Ah3IoN6SH4,9916
17
- sae_lens/saes/jumprelu_sae.py,sha256=DRWgY58894cNh_sYAlefObI4rr0Eb6KHu1WuhTCcvB4,13468
18
- sae_lens/saes/sae.py,sha256=fd7OEsSXbmVii6QoYI_TRti6dwaxAQyrBcKyX7PxERw,36779
19
- sae_lens/saes/standard_sae.py,sha256=m2eNL_w6ave-_g7F1eQiwI4qbjMwwjzvxp96RN_WVAw,7110
20
- sae_lens/saes/topk_sae.py,sha256=aBET4F55A4xMIvZ8AazPtyl3oL-9S7krKx78li0uKGk,11370
21
- sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
22
- sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
- sae_lens/training/activations_store.py,sha256=ilJdcnZWfTDus1bdoqIb1wF_7H8_HWLmf8OCGrybmlA,35998
24
- sae_lens/training/geometric_median.py,sha256=3kH8ZJAgKStlnZgs6s1uYGDYh004Bl0r4RLhuwT3lBY,3719
25
- sae_lens/training/optim.py,sha256=AImcc-MAaGDLOBP2hJ4alDFCtaqqgm4cc2eBxIxiQAo,5784
26
- sae_lens/training/sae_trainer.py,sha256=6TkqbzA0fYluRM8ouI_nU9sz-FaP63axxcnDrVfw37E,16279
27
- sae_lens/training/upload_saes_to_huggingface.py,sha256=tVC-2Txw7-9XttGlKzM0OSqU8CK7HDO9vIzDMqEwAYU,4366
28
- sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
29
- sae_lens-6.0.0rc1.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
30
- sae_lens-6.0.0rc1.dist-info/METADATA,sha256=wHH-VRtquu-FjZEOHdPJi3zYW3ns7MCT1fVerbPEylc,5326
31
- sae_lens-6.0.0rc1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
32
- sae_lens-6.0.0rc1.dist-info/RECORD,,