sae-lens 6.24.1__py3-none-any.whl → 6.25.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +1 -1
- sae_lens/cache_activations_runner.py +2 -2
- sae_lens/config.py +2 -2
- sae_lens/constants.py +8 -0
- sae_lens/loading/pretrained_sae_loaders.py +66 -66
- sae_lens/pretrained_saes.yaml +16 -0
- sae_lens/saes/sae.py +52 -12
- sae_lens/training/activations_store.py +3 -2
- sae_lens/util.py +21 -0
- {sae_lens-6.24.1.dist-info → sae_lens-6.25.1.dist-info}/METADATA +1 -1
- {sae_lens-6.24.1.dist-info → sae_lens-6.25.1.dist-info}/RECORD +13 -13
- {sae_lens-6.24.1.dist-info → sae_lens-6.25.1.dist-info}/WHEEL +0 -0
- {sae_lens-6.24.1.dist-info → sae_lens-6.25.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED

sae_lens/cache_activations_runner.py
CHANGED

@@ -14,9 +14,9 @@ from transformer_lens.HookedTransformer import HookedRootModule

 from sae_lens import logger
 from sae_lens.config import CacheActivationsRunnerConfig
-from sae_lens.constants import DTYPE_MAP
 from sae_lens.load_model import load_model
 from sae_lens.training.activations_store import ActivationsStore
+from sae_lens.util import str_to_dtype


 def _mk_activations_store(
@@ -97,7 +97,7 @@ class CacheActivationsRunner:
         bytes_per_token = (
             self.cfg.d_in * self.cfg.dtype.itemsize
             if isinstance(self.cfg.dtype, torch.dtype)
-            else DTYPE_MAP[self.cfg.dtype].itemsize
+            else str_to_dtype(self.cfg.dtype).itemsize
         )
         total_training_tokens = self.cfg.n_seq_in_dataset * self.context_size
         total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9
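The `bytes_per_token` estimate above multiplies `d_in` by the element size of the configured dtype, resolving string dtypes through the new `str_to_dtype` helper. A minimal sketch of the same arithmetic, assuming an illustrative `d_in` of 768 (not taken from any real config):

```python
import torch

from sae_lens.util import str_to_dtype


def bytes_per_token(d_in: int, dtype: str | torch.dtype) -> int:
    # Mirrors the arithmetic above: element size of the dtype times activations per token.
    itemsize = dtype.itemsize if isinstance(dtype, torch.dtype) else str_to_dtype(dtype).itemsize
    return d_in * itemsize


print(bytes_per_token(768, "float32"))       # 768 * 4 = 3072 bytes per token
print(bytes_per_token(768, torch.bfloat16))  # 768 * 2 = 1536 bytes per token
```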
sae_lens/config.py
CHANGED

@@ -17,9 +17,9 @@ from datasets import (
 )

 from sae_lens import __version__, logger
-from sae_lens.constants import DTYPE_MAP
 from sae_lens.registry import get_sae_training_class
 from sae_lens.saes.sae import TrainingSAEConfig
+from sae_lens.util import str_to_dtype

 if TYPE_CHECKING:
     pass
@@ -563,7 +563,7 @@ class CacheActivationsRunnerConfig:

     @property
     def bytes_per_token(self) -> int:
-        return self.d_in * DTYPE_MAP[self.dtype].itemsize
+        return self.d_in * str_to_dtype(self.dtype).itemsize

     @property
     def n_tokens_in_buffer(self) -> int:
sae_lens/constants.py
CHANGED

@@ -11,6 +11,14 @@ DTYPE_MAP = {
     "torch.bfloat16": torch.bfloat16,
 }

+# Reverse mapping from torch.dtype to canonical string format
+DTYPE_TO_STR = {
+    torch.float32: "float32",
+    torch.float64: "float64",
+    torch.float16: "float16",
+    torch.bfloat16: "bfloat16",
+}
+

 SPARSITY_FILENAME = "sparsity.safetensors"
 SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
sae_lens/loading/pretrained_sae_loaders.py
CHANGED

@@ -12,11 +12,9 @@ from huggingface_hub import hf_hub_download, hf_hub_url
 from huggingface_hub.utils import EntryNotFoundError, build_hf_headers
 from packaging.version import Version
 from safetensors import safe_open
-from safetensors.torch import load_file

 from sae_lens import logger
 from sae_lens.constants import (
-    DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
     SPARSIFY_WEIGHTS_FILENAME,
@@ -28,7 +26,7 @@ from sae_lens.loading.pretrained_saes_directory import (
     get_repo_id_and_folder_name,
 )
 from sae_lens.registry import get_sae_class
-from sae_lens.util import filter_valid_dataclass_fields
+from sae_lens.util import filter_valid_dataclass_fields, str_to_dtype

 LLM_METADATA_KEYS = {
     "model_name",
@@ -51,6 +49,21 @@ LLM_METADATA_KEYS = {
 }


+def load_safetensors_weights(
+    path: str | Path, device: str = "cpu", dtype: torch.dtype | str | None = None
+) -> dict[str, torch.Tensor]:
+    """Load safetensors weights and optionally convert to a different dtype"""
+    loaded_weights = {}
+    dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype
+    with safe_open(path, framework="pt", device=device) as f:
+        for k in f.keys():  # noqa: SIM118
+            weight = f.get_tensor(k)
+            if dtype is not None:
+                weight = weight.to(dtype=dtype)
+            loaded_weights[k] = weight
+    return loaded_weights
+
+
 # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
 class PretrainedSaeHuggingfaceLoader(Protocol):
     def __call__(
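The new `load_safetensors_weights` helper centralizes what the individual loaders used to do by hand: open a safetensors file, pull every tensor onto the requested device, and optionally cast to a target dtype (strings are routed through `str_to_dtype`). A hedged usage sketch; the file path below is hypothetical:

```python
import torch

from sae_lens.loading.pretrained_sae_loaders import load_safetensors_weights

# Hypothetical local file; any safetensors checkpoint holding SAE weights would do.
weights = load_safetensors_weights(
    "path/to/sae_weights.safetensors",
    device="cpu",
    dtype="bfloat16",  # also accepts a torch.dtype, or None to keep the stored dtype
)
assert all(w.dtype == torch.bfloat16 for w in weights.values())
```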
@@ -341,7 +354,7 @@ def read_sae_components_from_disk(
     Given a loaded dictionary and a path to a weight file, load the weights and return the state_dict.
     """
     if dtype is None:
-        dtype = DTYPE_MAP[cfg_dict["dtype"]]
+        dtype = str_to_dtype(cfg_dict["dtype"])

     state_dict = {}
     with safe_open(weight_path, framework="pt", device=device) as f:  # type: ignore
@@ -682,15 +695,6 @@ def gemma_3_sae_huggingface_loader(
         cfg_overrides,
     )

-    # replace folder name of 65k with 64k
-    # TODO: remove this workaround once weights are fixed
-    if "270m-pt" in repo_id:
-        if "65k" in folder_name:
-            folder_name = folder_name.replace("65k", "64k")
-        # replace folder name of 262k with 250k
-        if "262k" in folder_name:
-            folder_name = folder_name.replace("262k", "250k")
-
     params_file = "params.safetensors"
     if "clt" in folder_name:
         params_file = folder_name.split("/")[-1] + ".safetensors"
@@ -704,7 +708,9 @@ def gemma_3_sae_huggingface_loader(
         force_download=force_download,
     )

-    raw_state_dict = load_file(sae_path, device=device)
+    raw_state_dict = load_safetensors_weights(
+        sae_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     with torch.no_grad():
         w_dec = raw_state_dict["w_dec"]
@@ -791,11 +797,13 @@ def get_goodfire_huggingface_loader(
     )
     raw_state_dict = torch.load(sae_path, map_location=device)

+    target_dtype = str_to_dtype(cfg_dict.get("dtype", "float32"))
+
     state_dict = {
-        "W_enc": raw_state_dict["encoder_linear.weight"].T,
-        "W_dec": raw_state_dict["decoder_linear.weight"].T,
-        "b_enc": raw_state_dict["encoder_linear.bias"],
-        "b_dec": raw_state_dict["decoder_linear.bias"],
+        "W_enc": raw_state_dict["encoder_linear.weight"].T.to(dtype=target_dtype),
+        "W_dec": raw_state_dict["decoder_linear.weight"].T.to(dtype=target_dtype),
+        "b_enc": raw_state_dict["encoder_linear.bias"].to(dtype=target_dtype),
+        "b_dec": raw_state_dict["decoder_linear.bias"].to(dtype=target_dtype),
     }

     return cfg_dict, state_dict, None
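The Goodfire loader transposes the `nn.Linear` weights because PyTorch stores a Linear's weight as `[out_features, in_features]`, while SAELens keeps `W_enc` as `[d_in, d_sae]` and `W_dec` as `[d_sae, d_in]`. A small illustrative sketch of that transposition, with made-up shapes:

```python
import torch
from torch import nn

# Illustrative shapes only.
d_in, d_sae = 16, 64
encoder_linear = nn.Linear(d_in, d_sae)  # weight stored as [d_sae, d_in]
decoder_linear = nn.Linear(d_sae, d_in)  # weight stored as [d_in, d_sae]

target_dtype = torch.float32
W_enc = encoder_linear.weight.T.to(dtype=target_dtype)  # -> [d_in, d_sae], SAELens layout
W_dec = decoder_linear.weight.T.to(dtype=target_dtype)  # -> [d_sae, d_in]
assert W_enc.shape == (d_in, d_sae) and W_dec.shape == (d_sae, d_in)
```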
@@ -898,26 +906,19 @@ def llama_scope_sae_huggingface_loader(
         force_download=force_download,
     )

-    state_dict_loaded = load_file(sae_path, device=device)
+    state_dict_loaded = load_safetensors_weights(
+        sae_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     # Convert and organize the weights
     state_dict = {
-        "W_enc": state_dict_loaded["encoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "W_dec": state_dict_loaded["decoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "b_enc": state_dict_loaded["encoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
-        "b_dec": state_dict_loaded["decoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
+        "W_enc": state_dict_loaded["encoder.weight"].T,
+        "W_dec": state_dict_loaded["decoder.weight"].T,
+        "b_enc": state_dict_loaded["encoder.bias"],
+        "b_dec": state_dict_loaded["decoder.bias"],
         "threshold": torch.ones(
             cfg_dict["d_sae"],
-            dtype=DTYPE_MAP[cfg_dict["dtype"]],
+            dtype=str_to_dtype(cfg_dict["dtype"]),
             device=cfg_dict["device"],
         )
         * cfg_dict["jump_relu_threshold"],
@@ -1228,26 +1229,17 @@ def llama_scope_r1_distill_sae_huggingface_loader(
         force_download=force_download,
     )

-    state_dict_loaded = load_file(sae_path, device=device)
+    state_dict_loaded = load_safetensors_weights(
+        sae_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     # Convert and organize the weights
     state_dict = {
-        "W_enc": state_dict_loaded["encoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "W_dec": state_dict_loaded["decoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "b_enc": state_dict_loaded["encoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
-        "b_dec": state_dict_loaded["decoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
-        "threshold": state_dict_loaded["log_jumprelu_threshold"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .exp(),
+        "W_enc": state_dict_loaded["encoder.weight"].T,
+        "W_dec": state_dict_loaded["decoder.weight"].T,
+        "b_enc": state_dict_loaded["encoder.bias"],
+        "b_dec": state_dict_loaded["decoder.bias"],
+        "threshold": state_dict_loaded["log_jumprelu_threshold"].exp(),
     }

     # No sparsity tensor for Llama Scope SAEs
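Note the `threshold` entry: these checkpoints store the JumpReLU threshold in log-space, so the loader exponentiates it once at load time (the explicit dtype cast now happens inside `load_safetensors_weights` instead). A tiny illustration with made-up values:

```python
import torch

# Made-up log-thresholds; exponentiating recovers the strictly positive thresholds
# that gate each latent in a JumpReLU SAE.
log_jumprelu_threshold = torch.tensor([-2.0, 0.0, 1.5])
threshold = log_jumprelu_threshold.exp()
print(threshold)  # tensor([0.1353, 1.0000, 4.4817])
```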
@@ -1367,34 +1359,34 @@ def sparsify_disk_loader(
     cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)

     weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
-    state_dict_loaded = load_file(weight_path, device=device)
-    dtype = DTYPE_MAP[cfg_dict["dtype"]]
+    state_dict_loaded = load_safetensors_weights(
+        weight_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     W_enc = (
         state_dict_loaded["W_enc"]
         if "W_enc" in state_dict_loaded
         else state_dict_loaded["encoder.weight"].T
     )

     if "W_dec" in state_dict_loaded:
         W_dec = state_dict_loaded["W_dec"].T
     else:
         W_dec = state_dict_loaded["decoder.weight"].T

     if "b_enc" in state_dict_loaded:
         b_enc = state_dict_loaded["b_enc"]
     elif "encoder.bias" in state_dict_loaded:
         b_enc = state_dict_loaded["encoder.bias"]
     else:
-        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=W_dec.dtype, device=device)

     if "b_dec" in state_dict_loaded:
         b_dec = state_dict_loaded["b_dec"]
     elif "decoder.bias" in state_dict_loaded:
         b_dec = state_dict_loaded["decoder.bias"]
     else:
-        b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+        b_dec = torch.zeros(cfg_dict["d_in"], dtype=W_dec.dtype, device=device)

     state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
     return cfg_dict, state_dict
@@ -1625,7 +1617,9 @@ def mwhanna_transcoder_huggingface_loader(
     )

     # Load weights from safetensors
-    state_dict = load_file(file_path, device=device)
+    state_dict = load_safetensors_weights(
+        file_path, device=device, dtype=cfg_dict.get("dtype")
+    )
     state_dict["W_enc"] = state_dict["W_enc"].T

     return cfg_dict, state_dict, None
@@ -1709,8 +1703,12 @@ def mntss_clt_layer_huggingface_loader(
         force_download=force_download,
     )

-    encoder_state_dict = load_file(encoder_path, device=device)
-    decoder_state_dict = load_file(decoder_path, device=device)
+    encoder_state_dict = load_safetensors_weights(
+        encoder_path, device=device, dtype=cfg_dict.get("dtype")
+    )
+    decoder_state_dict = load_safetensors_weights(
+        decoder_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     with torch.no_grad():
         state_dict = {
@@ -1853,7 +1851,9 @@ def temporal_sae_huggingface_loader(
     )

     # Load checkpoint from safetensors
-    state_dict_raw = load_file(ckpt_path, device=device)
+    state_dict_raw = load_safetensors_weights(
+        ckpt_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     # Convert to SAELens naming convention
     # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
sae_lens/pretrained_saes.yaml
CHANGED

@@ -10197,6 +10197,7 @@ gemma-scope-2-27b-it-res:
   - id: layer_16_width_16k_l0_medium
     path: resid_post/layer_16_width_16k_l0_medium
     l0: 53
+    neuronpedia: gemma-3-27b-it/16-gemmascope-2-res-16k
   - id: layer_16_width_16k_l0_small
     path: resid_post/layer_16_width_16k_l0_small
     l0: 17
@@ -10206,6 +10207,7 @@ gemma-scope-2-27b-it-res:
   - id: layer_16_width_1m_l0_medium
     path: resid_post/layer_16_width_1m_l0_medium
     l0: 53
+    neuronpedia: gemma-3-27b-it/16-gemmascope-2-res-1m
   - id: layer_16_width_1m_l0_small
     path: resid_post/layer_16_width_1m_l0_small
     l0: 17
@@ -10215,6 +10217,7 @@ gemma-scope-2-27b-it-res:
   - id: layer_16_width_262k_l0_medium
     path: resid_post/layer_16_width_262k_l0_medium
     l0: 53
+    neuronpedia: gemma-3-27b-it/16-gemmascope-2-res-262k
   - id: layer_16_width_262k_l0_medium_seed_1
     path: resid_post/layer_16_width_262k_l0_medium_seed_1
     l0: 53
@@ -10227,6 +10230,7 @@ gemma-scope-2-27b-it-res:
   - id: layer_16_width_65k_l0_medium
     path: resid_post/layer_16_width_65k_l0_medium
     l0: 53
+    neuronpedia: gemma-3-27b-it/16-gemmascope-2-res-65k
   - id: layer_16_width_65k_l0_small
     path: resid_post/layer_16_width_65k_l0_small
     l0: 17
@@ -10236,6 +10240,7 @@ gemma-scope-2-27b-it-res:
   - id: layer_31_width_16k_l0_medium
     path: resid_post/layer_31_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-27b-it/31-gemmascope-2-res-16k
   - id: layer_31_width_16k_l0_small
     path: resid_post/layer_31_width_16k_l0_small
     l0: 20
@@ -10245,6 +10250,7 @@ gemma-scope-2-27b-it-res:
   - id: layer_31_width_1m_l0_medium
     path: resid_post/layer_31_width_1m_l0_medium
     l0: 60
+    neuronpedia: gemma-3-27b-it/31-gemmascope-2-res-1m
   - id: layer_31_width_1m_l0_small
     path: resid_post/layer_31_width_1m_l0_small
     l0: 20
@@ -10254,6 +10260,7 @@ gemma-scope-2-27b-it-res:
   - id: layer_31_width_262k_l0_medium
     path: resid_post/layer_31_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-27b-it/31-gemmascope-2-res-262k
   - id: layer_31_width_262k_l0_medium_seed_1
     path: resid_post/layer_31_width_262k_l0_medium_seed_1
     l0: 60
@@ -10266,6 +10273,7 @@ gemma-scope-2-27b-it-res:
   - id: layer_31_width_65k_l0_medium
     path: resid_post/layer_31_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-27b-it/31-gemmascope-2-res-65k
   - id: layer_31_width_65k_l0_small
     path: resid_post/layer_31_width_65k_l0_small
     l0: 20
@@ -10275,6 +10283,7 @@ gemma-scope-2-27b-it-res:
   - id: layer_40_width_16k_l0_medium
     path: resid_post/layer_40_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-27b-it/40-gemmascope-2-res-16k
   - id: layer_40_width_16k_l0_small
     path: resid_post/layer_40_width_16k_l0_small
     l0: 20
@@ -10284,6 +10293,7 @@ gemma-scope-2-27b-it-res:
   - id: layer_40_width_1m_l0_medium
     path: resid_post/layer_40_width_1m_l0_medium
     l0: 60
+    neuronpedia: gemma-3-27b-it/40-gemmascope-2-res-1m
   - id: layer_40_width_1m_l0_small
     path: resid_post/layer_40_width_1m_l0_small
     l0: 20
@@ -10293,6 +10303,7 @@ gemma-scope-2-27b-it-res:
   - id: layer_40_width_262k_l0_medium
     path: resid_post/layer_40_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-27b-it/40-gemmascope-2-res-262k
   - id: layer_40_width_262k_l0_medium_seed_1
     path: resid_post/layer_40_width_262k_l0_medium_seed_1
     l0: 60
@@ -10305,6 +10316,7 @@ gemma-scope-2-27b-it-res:
   - id: layer_40_width_65k_l0_medium
     path: resid_post/layer_40_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-27b-it/40-gemmascope-2-res-65k
   - id: layer_40_width_65k_l0_small
     path: resid_post/layer_40_width_65k_l0_small
     l0: 20
@@ -10314,6 +10326,7 @@ gemma-scope-2-27b-it-res:
   - id: layer_53_width_16k_l0_medium
     path: resid_post/layer_53_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-27b-it/53-gemmascope-2-res-16k
   - id: layer_53_width_16k_l0_small
     path: resid_post/layer_53_width_16k_l0_small
     l0: 20
@@ -10323,6 +10336,7 @@ gemma-scope-2-27b-it-res:
   - id: layer_53_width_1m_l0_medium
     path: resid_post/layer_53_width_1m_l0_medium
     l0: 60
+    neuronpedia: gemma-3-27b-it/53-gemmascope-2-res-1m
   - id: layer_53_width_1m_l0_small
     path: resid_post/layer_53_width_1m_l0_small
     l0: 20
@@ -10332,6 +10346,7 @@ gemma-scope-2-27b-it-res:
   - id: layer_53_width_262k_l0_medium
     path: resid_post/layer_53_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-27b-it/53-gemmascope-2-res-262k
   - id: layer_53_width_262k_l0_medium_seed_1
     path: resid_post/layer_53_width_262k_l0_medium_seed_1
     l0: 60
@@ -10344,6 +10359,7 @@ gemma-scope-2-27b-it-res:
   - id: layer_53_width_65k_l0_medium
     path: resid_post/layer_53_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-27b-it/53-gemmascope-2-res-65k
   - id: layer_53_width_65k_l0_small
     path: resid_post/layer_53_width_65k_l0_small
     l0: 20
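Each new `neuronpedia:` key attaches a Neuronpedia dashboard identifier to a Gemma Scope 2 27B IT SAE. A hedged sketch for inspecting these entries straight from the packaged YAML; it assumes the release's SAE list sits under a `saes` key (as elsewhere in this file) and goes through PyYAML and `importlib.resources` rather than any SAELens API:

```python
from importlib import resources

import yaml

# Read the directory file shipped inside the sae_lens package.
with resources.files("sae_lens").joinpath("pretrained_saes.yaml").open() as f:
    directory = yaml.safe_load(f)

# Print the Neuronpedia IDs added in this release (assumes a `saes` list per release).
for sae in directory["gemma-scope-2-27b-it-res"]["saes"]:
    if "neuronpedia" in sae:
        print(f'{sae["id"]} -> {sae["neuronpedia"]}')
```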
sae_lens/saes/sae.py
CHANGED

@@ -27,11 +27,10 @@ from typing_extensions import deprecated, overload, override

 from sae_lens import __version__
 from sae_lens.constants import (
-    DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
 )
-from sae_lens.util import filter_valid_dataclass_fields
+from sae_lens.util import dtype_to_str, filter_valid_dataclass_fields, str_to_dtype

 if TYPE_CHECKING:
     from sae_lens.config import LanguageModelSAERunnerConfig
@@ -253,7 +252,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
             stacklevel=1,
         )

-        self.dtype = DTYPE_MAP[cfg.dtype]
+        self.dtype = str_to_dtype(cfg.dtype)
         self.device = torch.device(cfg.device)
         self.use_error_term = use_error_term

@@ -437,8 +436,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

         # Update dtype in config if provided
         if dtype_arg is not None:
-            # Update the cfg.dtype
-            self.cfg.dtype = str(dtype_arg)
+            # Update the cfg.dtype (use canonical short form like "float32")
+            self.cfg.dtype = dtype_to_str(dtype_arg)

             # Update the dtype property
             self.dtype = dtype_arg
@@ -534,6 +533,15 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         dtype: str | None = None,
         converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
     ) -> T_SAE:
+        """
+        Load a SAE from disk.
+
+        Args:
+            path: The path to the SAE weights and config.
+            device: The device to load the SAE on, defaults to "cpu".
+            dtype: The dtype to load the SAE on, defaults to None. If None, the dtype will be inferred from the SAE config.
+            converter: The converter to use to load the SAE, defaults to sae_lens_disk_loader.
+        """
         overrides = {"dtype": dtype} if dtype is not None else None
         cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
         cfg_dict = handle_config_defaulting(cfg_dict)
@@ -542,10 +550,17 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         )
         sae_cfg = sae_config_cls.from_dict(cfg_dict)
         sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
+        # hack to avoid using double memory when loading the SAE.
+        # first put the SAE on the meta device, then load the weights.
+        device = sae_cfg.device
+        sae_cfg.device = "meta"
         sae = sae_cls(sae_cfg)
+        sae.cfg.device = device
         sae.process_state_dict_for_loading(state_dict)
-        sae.load_state_dict(state_dict)
-        return sae
+        sae.load_state_dict(state_dict, assign=True)
+        # the loaders should already handle the dtype / device conversion
+        # but this is a fallback to guarantee the SAE is on the correct device and dtype
+        return sae.to(dtype=str_to_dtype(sae_cfg.dtype), device=device)

     @classmethod
     def from_pretrained(
@@ -553,6 +568,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         release: str,
         sae_id: str,
         device: str = "cpu",
+        dtype: str = "float32",
         force_download: bool = False,
         converter: PretrainedSaeHuggingfaceLoader | None = None,
     ) -> T_SAE:
@@ -562,10 +578,18 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         Args:
             release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
             id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
-            device: The device to load the SAE on.
+            device: The device to load the SAE on, defaults to "cpu".
+            dtype: The dtype to load the SAE on, defaults to "float32".
+            force_download: Whether to force download the SAE weights and config, defaults to False.
+            converter: The converter to use to load the SAE, defaults to None. If None, the converter will be inferred from the release.
         """
         return cls.from_pretrained_with_cfg_and_sparsity(
-            release, sae_id, device, force_download=force_download, converter=converter
+            release,
+            sae_id,
+            device,
+            force_download=force_download,
+            dtype=dtype,
+            converter=converter,
         )[0]

     @classmethod
@@ -574,6 +598,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         release: str,
         sae_id: str,
         device: str = "cpu",
+        dtype: str = "float32",
         force_download: bool = False,
         converter: PretrainedSaeHuggingfaceLoader | None = None,
     ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
@@ -584,7 +609,10 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         Args:
             release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
             id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
-            device: The device to load the SAE on.
+            device: The device to load the SAE on, defaults to "cpu".
+            dtype: The dtype to load the SAE on, defaults to "float32".
+            force_download: Whether to force download the SAE weights and config, defaults to False.
+            converter: The converter to use to load the SAE, defaults to None. If None, the converter will be inferred from the release.
         """

         # get sae directory
@@ -634,6 +662,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
         config_overrides = get_config_overrides(release, sae_id)
         config_overrides["device"] = device
+        config_overrides["dtype"] = dtype

         # Load config and weights
         cfg_dict, state_dict, log_sparsities = conversion_loader(
@@ -651,9 +680,14 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         )
         sae_cfg = sae_config_cls.from_dict(cfg_dict)
         sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
+        # hack to avoid using double memory when loading the SAE.
+        # first put the SAE on the meta device, then load the weights.
+        device = sae_cfg.device
+        sae_cfg.device = "meta"
         sae = sae_cls(sae_cfg)
+        sae.cfg.device = device
         sae.process_state_dict_for_loading(state_dict)
-        sae.load_state_dict(state_dict)
+        sae.load_state_dict(state_dict, assign=True)

         # Apply normalization if needed
         if cfg_dict.get("normalize_activations") == "expected_average_only_in":
@@ -666,7 +700,13 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
                 f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
             )

-        return sae, cfg_dict, log_sparsities
+        # the loaders should already handle the dtype / device conversion
+        # but this is a fallback to guarantee the SAE is on the correct device and dtype
+        return (
+            sae.to(dtype=str_to_dtype(dtype), device=device),
+            cfg_dict,
+            log_sparsities,
+        )

     @classmethod
     def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
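Taken together, these changes mean `from_pretrained` now accepts a `dtype` string that is forwarded to the loader via `config_overrides` and enforced by the final `.to(...)`, while the meta-device construction avoids materializing randomly initialized parameters before the real weights are assigned with `assign=True`. A hedged usage sketch; the release and SAE id come from the `pretrained_saes.yaml` entries earlier in this diff, and `device="cuda"` is only illustrative:

```python
from sae_lens import SAE

sae = SAE.from_pretrained(
    "gemma-scope-2-27b-it-res",
    "layer_16_width_16k_l0_medium",
    device="cuda",
    dtype="bfloat16",  # new in 6.25.1: weights are cast while loading rather than afterwards
)
print(sae.W_enc.dtype, sae.W_enc.device)  # expected: torch.bfloat16 cuda:0
```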
sae_lens/training/activations_store.py
CHANGED

@@ -24,7 +24,7 @@ from sae_lens.config import (
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, DTYPE_MAP
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME
 from sae_lens.pretokenize_runner import get_special_token_from_cfg
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
@@ -32,6 +32,7 @@ from sae_lens.training.mixing_buffer import mixing_buffer
 from sae_lens.util import (
     extract_stop_at_layer_from_tlens_hook_name,
     get_special_token_ids,
+    str_to_dtype,
 )


@@ -258,7 +259,7 @@ class ActivationsStore:
         self.prepend_bos = prepend_bos
         self.normalize_activations = normalize_activations
         self.device = torch.device(device)
-        self.dtype = DTYPE_MAP[dtype]
+        self.dtype = str_to_dtype(dtype)
         self.cached_activations_path = cached_activations_path
         self.autocast_lm = autocast_lm
         self.seqpos_slice = seqpos_slice
sae_lens/util.py
CHANGED

@@ -5,8 +5,11 @@ from dataclasses import asdict, fields, is_dataclass
 from pathlib import Path
 from typing import Sequence, TypeVar

+import torch
 from transformers import PreTrainedTokenizerBase

+from sae_lens.constants import DTYPE_MAP, DTYPE_TO_STR
+
 K = TypeVar("K")
 V = TypeVar("V")

@@ -90,3 +93,21 @@ def get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
             special_tokens.add(token_id)

     return list(special_tokens)
+
+
+def str_to_dtype(dtype: str) -> torch.dtype:
+    """Convert a string to a torch.dtype."""
+    if dtype not in DTYPE_MAP:
+        raise ValueError(
+            f"Invalid dtype: {dtype}. Must be one of {list(DTYPE_MAP.keys())}"
+        )
+    return DTYPE_MAP[dtype]
+
+
+def dtype_to_str(dtype: torch.dtype) -> str:
+    """Convert a torch.dtype to a string."""
+    if dtype not in DTYPE_TO_STR:
+        raise ValueError(
+            f"Invalid dtype: {dtype}. Must be one of {list(DTYPE_TO_STR.keys())}"
+        )
+    return DTYPE_TO_STR[dtype]
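The two helpers are inverses over the float dtypes registered in `constants.py`: `str_to_dtype` accepts the keys of `DTYPE_MAP` (including the `"torch.*"` spellings shown there), while `dtype_to_str` always returns the canonical short form used in SAE configs. A quick sketch:

```python
import torch

from sae_lens.util import dtype_to_str, str_to_dtype

# Round-trip between config strings and torch dtypes with the new helpers.
assert str_to_dtype("bfloat16") == torch.bfloat16
assert str_to_dtype("torch.bfloat16") == torch.bfloat16  # DTYPE_MAP also carries the "torch.*" spelling
assert dtype_to_str(torch.bfloat16) == "bfloat16"        # always the canonical short form

try:
    str_to_dtype("not-a-dtype")
except ValueError as err:
    print(err)  # lists the accepted keys from DTYPE_MAP
```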
{sae_lens-6.24.1.dist-info → sae_lens-6.25.1.dist-info}/RECORD
CHANGED

@@ -1,25 +1,25 @@
-sae_lens/__init__.py,sha256=
+sae_lens/__init__.py,sha256=vWuA8EbynIJadj666RoFNCTIvoH9-HFpUxuHwoYt8Ks,4268
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
 sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
-sae_lens/cache_activations_runner.py,sha256=
-sae_lens/config.py,sha256=
-sae_lens/constants.py,sha256=
+sae_lens/cache_activations_runner.py,sha256=Lvlz-k5-3XxVRtUdC4b1CiKyx5s0ckLa8GDGv9_kcxs,12566
+sae_lens/config.py,sha256=JmcrXT4orJV2OulbEZAciz8RQmYv7DrtUtRbOLsNQ2Y,30330
+sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
 sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
 sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHFzGGuqU,15196
 sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
 sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/loading/pretrained_sae_loaders.py,sha256=
+sae_lens/loading/pretrained_sae_loaders.py,sha256=hq-dhxsEdUmlAnZEiZBqX7lNyQQwZ6KXmXZWpzAc5FY,63638
 sae_lens/loading/pretrained_saes_directory.py,sha256=hejNfLUepYCSGPalRfQwxxCEUqMMUPsn1tufwvwct5k,3820
 sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
-sae_lens/pretrained_saes.yaml,sha256=
+sae_lens/pretrained_saes.yaml,sha256=Hy9mk4Liy50B0CIBD4ER1ETcho2drFFiIy-bPVCN_lc,1510210
 sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
 sae_lens/saes/__init__.py,sha256=fYVujOzNnUgpzLL0MBLBt_DNX2CPcTaheukzCd2bEPo,1906
 sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
 sae_lens/saes/gated_sae.py,sha256=mHnmw-RD7hqIbP9_EBj3p2SK0OqQIkZivdOKRygeRgw,8825
 sae_lens/saes/jumprelu_sae.py,sha256=udjGHp3WTABQSL2Qq57j-bINWX61GCmo68EmdjMOXoo,13310
 sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
-sae_lens/saes/sae.py,sha256=
+sae_lens/saes/sae.py,sha256=fzXv8lwHskSxsf8hm_wlKPkpq50iafmBjBNQzwZ6a00,40050
 sae_lens/saes/standard_sae.py,sha256=nEVETwAmRD2tyX7ESIic1fij48gAq1Dh7s_GQ2fqCZ4,5747
 sae_lens/saes/temporal_sae.py,sha256=DsecivcHWId-MTuJpQbz8OhqtmGhZACxJauYZGHo0Ok,13272
 sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
@@ -27,15 +27,15 @@ sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,1
 sae_lens/tokenization_and_batching.py,sha256=D_o7cXvRqhT89H3wNzoRymNALNE6eHojBWLdXOUwUGE,5438
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
-sae_lens/training/activations_store.py,sha256=
+sae_lens/training/activations_store.py,sha256=rQadexm2BiwK7_MZIPlRkcKSqabi3iuOTC-R8aJchS8,33778
 sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
 sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
 sae_lens/training/sae_trainer.py,sha256=zhkabyIKxI_tZTV3_kwz6zMrHZ95Ecr97krmwc-9ffs,17600
 sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
-sae_lens/util.py,sha256=
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
+sae_lens/util.py,sha256=spkcmQUsjVYFn5H2032nQYr1CKGVnv3tAdfIpY59-Mg,3919
+sae_lens-6.25.1.dist-info/METADATA,sha256=gClFVWzEWNNjrXsGqvCY6ry6ehXIFwp8PB0jIOhmQvc,5361
+sae_lens-6.25.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+sae_lens-6.25.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.25.1.dist-info/RECORD,,

{sae_lens-6.24.1.dist-info → sae_lens-6.25.1.dist-info}/WHEEL
File without changes

{sae_lens-6.24.1.dist-info → sae_lens-6.25.1.dist-info}/licenses/LICENSE
File without changes