sae-lens 6.22.1-py3-none-any.whl → 6.25.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +8 -1
- sae_lens/cache_activations_runner.py +2 -2
- sae_lens/config.py +2 -2
- sae_lens/constants.py +8 -0
- sae_lens/loading/pretrained_sae_loaders.py +298 -80
- sae_lens/pretokenize_runner.py +3 -3
- sae_lens/pretrained_saes.yaml +26949 -97
- sae_lens/saes/__init__.py +4 -0
- sae_lens/saes/gated_sae.py +2 -2
- sae_lens/saes/jumprelu_sae.py +4 -4
- sae_lens/saes/sae.py +53 -13
- sae_lens/saes/topk_sae.py +1 -1
- sae_lens/saes/transcoder.py +41 -0
- sae_lens/training/activations_store.py +8 -7
- sae_lens/util.py +21 -0
- {sae_lens-6.22.1.dist-info → sae_lens-6.25.1.dist-info}/METADATA +2 -2
- {sae_lens-6.22.1.dist-info → sae_lens-6.25.1.dist-info}/RECORD +19 -19
- {sae_lens-6.22.1.dist-info → sae_lens-6.25.1.dist-info}/WHEEL +0 -0
- {sae_lens-6.22.1.dist-info → sae_lens-6.25.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/saes/__init__.py
CHANGED
@@ -33,6 +33,8 @@ from .topk_sae import (
     TopKTrainingSAEConfig,
 )
 from .transcoder import (
+    JumpReLUSkipTranscoder,
+    JumpReLUSkipTranscoderConfig,
     JumpReLUTranscoder,
     JumpReLUTranscoderConfig,
     SkipTranscoder,
@@ -70,6 +72,8 @@ __all__ = [
     "SkipTranscoderConfig",
     "JumpReLUTranscoder",
     "JumpReLUTranscoderConfig",
+    "JumpReLUSkipTranscoder",
+    "JumpReLUSkipTranscoderConfig",
     "MatryoshkaBatchTopKTrainingSAE",
     "MatryoshkaBatchTopKTrainingSAEConfig",
     "TemporalSAE",
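
The two new names are importable from sae_lens.saes in 6.25.1 alongside the existing transcoder exports; a minimal check (the printed string comes from the architecture() classmethod added in transcoder.py further down):

    from sae_lens.saes import JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig

    # registered architecture name for the new config class
    print(JumpReLUSkipTranscoderConfig.architecture())  # jumprelu_skip_transcoder
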
sae_lens/saes/gated_sae.py
CHANGED
@@ -89,7 +89,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Override to handle gated-specific parameters."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
@@ -217,7 +217,7 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Override to handle gated-specific parameters."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
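
The same clamp shows up again in jumprelu_sae.py, sae.py, and topk_sae.py below. A toy illustration (not SAELens code) of what it guards against: a dead latent whose decoder row is all zeros would otherwise hit a 0/0 division and fill that row, and the matching encoder column, with NaNs:

    import torch

    W_dec = torch.zeros(3, 8)
    W_dec[0] = torch.randn(8)  # one live latent, two dead ones

    unsafe = W_dec / W_dec.norm(dim=-1).unsqueeze(1)                # dead rows become NaN
    safe = W_dec / W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)  # dead rows stay zero

    print(torch.isnan(unsafe).any().item(), torch.isnan(safe).any().item())  # True False
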
sae_lens/saes/jumprelu_sae.py
CHANGED
@@ -167,8 +167,8 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
         # Save the current threshold before calling parent method
         current_thresh = self.threshold.clone()
 
-        # Get W_dec norms that will be used for scaling
-        W_dec_norms = self.W_dec.norm(dim=-1)
+        # Get W_dec norms that will be used for scaling (clamped to avoid division by zero)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8)
 
         # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
         super().fold_W_dec_norm()
@@ -325,8 +325,8 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         # Save the current threshold before we call the parent method
         current_thresh = self.threshold.clone()
 
-        # Get W_dec norms
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        # Get W_dec norms (clamped to avoid division by zero)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
 
         # Call parent implementation to handle W_enc and W_dec adjustment
         super().fold_W_dec_norm()
sae_lens/saes/sae.py
CHANGED
@@ -27,11 +27,10 @@ from typing_extensions import deprecated, overload, override
 
 from sae_lens import __version__
 from sae_lens.constants import (
-    DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
 )
-from sae_lens.util import filter_valid_dataclass_fields
+from sae_lens.util import dtype_to_str, filter_valid_dataclass_fields, str_to_dtype
 
 if TYPE_CHECKING:
     from sae_lens.config import LanguageModelSAERunnerConfig
@@ -253,7 +252,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
             stacklevel=1,
         )
 
-        self.dtype =
+        self.dtype = str_to_dtype(cfg.dtype)
         self.device = torch.device(cfg.device)
         self.use_error_term = use_error_term
 
@@ -437,8 +436,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
         # Update dtype in config if provided
         if dtype_arg is not None:
-            # Update the cfg.dtype
-            self.cfg.dtype =
+            # Update the cfg.dtype (use canonical short form like "float32")
+            self.cfg.dtype = dtype_to_str(dtype_arg)
 
             # Update the dtype property
             self.dtype = dtype_arg
@@ -484,7 +483,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Fold decoder norms into encoder."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
@@ -534,6 +533,15 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         dtype: str | None = None,
         converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
     ) -> T_SAE:
+        """
+        Load a SAE from disk.
+
+        Args:
+            path: The path to the SAE weights and config.
+            device: The device to load the SAE on, defaults to "cpu".
+            dtype: The dtype to load the SAE on, defaults to None. If None, the dtype will be inferred from the SAE config.
+            converter: The converter to use to load the SAE, defaults to sae_lens_disk_loader.
+        """
         overrides = {"dtype": dtype} if dtype is not None else None
         cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
         cfg_dict = handle_config_defaulting(cfg_dict)
@@ -542,10 +550,17 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         )
         sae_cfg = sae_config_cls.from_dict(cfg_dict)
         sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
+        # hack to avoid using double memory when loading the SAE.
+        # first put the SAE on the meta device, then load the weights.
+        device = sae_cfg.device
+        sae_cfg.device = "meta"
         sae = sae_cls(sae_cfg)
+        sae.cfg.device = device
         sae.process_state_dict_for_loading(state_dict)
-        sae.load_state_dict(state_dict)
-
+        sae.load_state_dict(state_dict, assign=True)
+        # the loaders should already handle the dtype / device conversion
+        # but this is a fallback to guarantee the SAE is on the correct device and dtype
+        return sae.to(dtype=str_to_dtype(sae_cfg.dtype), device=device)
 
     @classmethod
     def from_pretrained(
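
The load path above constructs the module on the meta device and then calls load_state_dict with assign=True, so the checkpoint tensors are adopted directly instead of being copied into freshly allocated parameters. A minimal, self-contained sketch of that PyTorch pattern (plain torch, not SAELens code):

    import torch
    import torch.nn as nn

    # Parameters created under the meta device allocate no real storage.
    with torch.device("meta"):
        module = nn.Linear(4096, 4096)

    # Stand-in for tensors read from a checkpoint on disk.
    state_dict = {"weight": torch.randn(4096, 4096), "bias": torch.randn(4096)}

    # assign=True swaps the loaded tensors into the module in place of the meta
    # parameters, so peak memory is one copy of the weights rather than two.
    module.load_state_dict(state_dict, assign=True)
    print(module.weight.device)  # cpu, taken from the loaded tensors
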
@@ -553,6 +568,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         release: str,
         sae_id: str,
         device: str = "cpu",
+        dtype: str = "float32",
         force_download: bool = False,
         converter: PretrainedSaeHuggingfaceLoader | None = None,
     ) -> T_SAE:
@@ -562,10 +578,18 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         Args:
             release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
             id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
-            device: The device to load the SAE on.
+            device: The device to load the SAE on, defaults to "cpu".
+            dtype: The dtype to load the SAE on, defaults to "float32".
+            force_download: Whether to force download the SAE weights and config, defaults to False.
+            converter: The converter to use to load the SAE, defaults to None. If None, the converter will be inferred from the release.
         """
         return cls.from_pretrained_with_cfg_and_sparsity(
-            release,
+            release,
+            sae_id,
+            device,
+            force_download=force_download,
+            dtype=dtype,
+            converter=converter,
         )[0]
 
     @classmethod
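
Usage sketch for the new keyword; the release and sae_id values are illustrative entries from the pretrained SAE directory, not something this diff introduces:

    from sae_lens import SAE

    sae = SAE.from_pretrained(
        release="gpt2-small-res-jb",
        sae_id="blocks.8.hook_resid_pre",
        device="cuda",
        dtype="bfloat16",  # new keyword argument added in this release range
    )
    print(sae.W_dec.dtype)  # torch.bfloat16
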
@@ -574,6 +598,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         release: str,
         sae_id: str,
         device: str = "cpu",
+        dtype: str = "float32",
         force_download: bool = False,
         converter: PretrainedSaeHuggingfaceLoader | None = None,
     ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
@@ -584,7 +609,10 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         Args:
             release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
             id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
-            device: The device to load the SAE on.
+            device: The device to load the SAE on, defaults to "cpu".
+            dtype: The dtype to load the SAE on, defaults to "float32".
+            force_download: Whether to force download the SAE weights and config, defaults to False.
+            converter: The converter to use to load the SAE, defaults to None. If None, the converter will be inferred from the release.
         """
 
         # get sae directory
@@ -634,6 +662,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
         config_overrides = get_config_overrides(release, sae_id)
         config_overrides["device"] = device
+        config_overrides["dtype"] = dtype
 
         # Load config and weights
         cfg_dict, state_dict, log_sparsities = conversion_loader(
@@ -651,9 +680,14 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         )
         sae_cfg = sae_config_cls.from_dict(cfg_dict)
         sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
+        # hack to avoid using double memory when loading the SAE.
+        # first put the SAE on the meta device, then load the weights.
+        device = sae_cfg.device
+        sae_cfg.device = "meta"
         sae = sae_cls(sae_cfg)
+        sae.cfg.device = device
         sae.process_state_dict_for_loading(state_dict)
-        sae.load_state_dict(state_dict)
+        sae.load_state_dict(state_dict, assign=True)
 
         # Apply normalization if needed
         if cfg_dict.get("normalize_activations") == "expected_average_only_in":
@@ -666,7 +700,13 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
                 f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
             )
 
-
+        # the loaders should already handle the dtype / device conversion
+        # but this is a fallback to guarantee the SAE is on the correct device and dtype
+        return (
+            sae.to(dtype=str_to_dtype(dtype), device=device),
+            cfg_dict,
+            log_sparsities,
+        )
 
     @classmethod
     def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
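
Callers that also want the raw config dict and the log-sparsity tensor can use the lower-level method, which now accepts the same dtype keyword; a hedged sketch reusing the illustrative release/id from above:

    from sae_lens import SAE

    sae, cfg_dict, log_sparsities = SAE.from_pretrained_with_cfg_and_sparsity(
        "gpt2-small-res-jb",
        "blocks.8.hook_resid_pre",
        device="cpu",
        dtype="float32",
    )
    print(cfg_dict.get("normalize_activations"))
    print(None if log_sparsities is None else log_sparsities.shape)
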
sae_lens/saes/topk_sae.py
CHANGED
@@ -531,7 +531,7 @@ def _fold_norm_topk(
     b_enc: torch.Tensor,
     W_dec: torch.Tensor,
 ) -> None:
-    W_dec_norm = W_dec.norm(dim=-1)
+    W_dec_norm = W_dec.norm(dim=-1).clamp(min=1e-8)
     b_enc.data = b_enc.data * W_dec_norm
     W_dec_norms = W_dec_norm.unsqueeze(1)
     W_dec.data = W_dec.data / W_dec_norms
sae_lens/saes/transcoder.py
CHANGED
@@ -368,3 +368,44 @@ class JumpReLUTranscoder(Transcoder):
     def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUTranscoder":
         cfg = JumpReLUTranscoderConfig.from_dict(config_dict)
         return cls(cfg)
+
+
+@dataclass
+class JumpReLUSkipTranscoderConfig(JumpReLUTranscoderConfig):
+    """Configuration for JumpReLU transcoder."""
+
+    @classmethod
+    def architecture(cls) -> str:
+        """Return the architecture name for this config."""
+        return "jumprelu_skip_transcoder"
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUSkipTranscoderConfig":
+        """Create a JumpReLUSkipTranscoderConfig from a dictionary."""
+        # Filter to only include valid dataclass fields
+        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)
+
+        # Create the config instance
+        res = cls(**filtered_config_dict)
+
+        # Handle metadata if present
+        if "metadata" in config_dict:
+            res.metadata = SAEMetadata(**config_dict["metadata"])
+
+        return res
+
+
+class JumpReLUSkipTranscoder(JumpReLUTranscoder, SkipTranscoder):
+    """
+    A transcoder with a learnable skip connection and JumpReLU activation function.
+    """
+
+    cfg: JumpReLUSkipTranscoderConfig  # type: ignore[assignment]
+
+    def __init__(self, cfg: JumpReLUSkipTranscoderConfig):
+        super().__init__(cfg)
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUSkipTranscoder":
+        cfg = JumpReLUSkipTranscoderConfig.from_dict(config_dict)
+        return cls(cfg)
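
The new class mixes the two existing transcoder variants via multiple inheritance, so it is a subclass of both; a quick check, assuming the 6.25.1 wheel is installed:

    from sae_lens.saes import (
        JumpReLUSkipTranscoder,
        JumpReLUSkipTranscoderConfig,
        JumpReLUTranscoder,
        JumpReLUTranscoderConfig,
        SkipTranscoder,
    )

    assert issubclass(JumpReLUSkipTranscoder, JumpReLUTranscoder)
    assert issubclass(JumpReLUSkipTranscoder, SkipTranscoder)
    assert issubclass(JumpReLUSkipTranscoderConfig, JumpReLUTranscoderConfig)
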
sae_lens/training/activations_store.py
CHANGED
@@ -24,7 +24,7 @@ from sae_lens.config import (
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME
 from sae_lens.pretokenize_runner import get_special_token_from_cfg
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
@@ -32,6 +32,7 @@ from sae_lens.training.mixing_buffer import mixing_buffer
 from sae_lens.util import (
     extract_stop_at_layer_from_tlens_hook_name,
     get_special_token_ids,
+    str_to_dtype,
 )
 
 
@@ -166,9 +167,11 @@ class ActivationsStore:
         disable_concat_sequences: bool = False,
         sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
     ) -> ActivationsStore:
+        if context_size is None:
+            context_size = sae.cfg.metadata.context_size
         if sae.cfg.metadata.hook_name is None:
             raise ValueError("hook_name is required")
-        if
+        if context_size is None:
             raise ValueError("context_size is required")
         if sae.cfg.metadata.prepend_bos is None:
             raise ValueError("prepend_bos is required")
@@ -178,9 +181,7 @@ class ActivationsStore:
             d_in=sae.cfg.d_in,
             hook_name=sae.cfg.metadata.hook_name,
             hook_head_index=sae.cfg.metadata.hook_head_index,
-            context_size=
-            if context_size is None
-            else context_size,
+            context_size=context_size,
             prepend_bos=sae.cfg.metadata.prepend_bos,
             streaming=streaming,
             store_batch_size_prompts=store_batch_size_prompts,
@@ -230,7 +231,7 @@ class ActivationsStore:
             load_dataset(
                 dataset,
                 split="train",
-                streaming=streaming,
+                streaming=streaming,  # type: ignore
                 trust_remote_code=dataset_trust_remote_code,  # type: ignore
             )
             if isinstance(dataset, str)
@@ -258,7 +259,7 @@ class ActivationsStore:
         self.prepend_bos = prepend_bos
         self.normalize_activations = normalize_activations
         self.device = torch.device(device)
-        self.dtype =
+        self.dtype = str_to_dtype(dtype)
         self.cached_activations_path = cached_activations_path
         self.autocast_lm = autocast_lm
         self.seqpos_slice = seqpos_slice
sae_lens/util.py
CHANGED
@@ -5,8 +5,11 @@ from dataclasses import asdict, fields, is_dataclass
 from pathlib import Path
 from typing import Sequence, TypeVar
 
+import torch
 from transformers import PreTrainedTokenizerBase
 
+from sae_lens.constants import DTYPE_MAP, DTYPE_TO_STR
+
 K = TypeVar("K")
 V = TypeVar("V")
 
@@ -90,3 +93,21 @@ def get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
             special_tokens.add(token_id)
 
     return list(special_tokens)
+
+
+def str_to_dtype(dtype: str) -> torch.dtype:
+    """Convert a string to a torch.dtype."""
+    if dtype not in DTYPE_MAP:
+        raise ValueError(
+            f"Invalid dtype: {dtype}. Must be one of {list(DTYPE_MAP.keys())}"
+        )
+    return DTYPE_MAP[dtype]
+
+
+def dtype_to_str(dtype: torch.dtype) -> str:
+    """Convert a torch.dtype to a string."""
+    if dtype not in DTYPE_TO_STR:
+        raise ValueError(
+            f"Invalid dtype: {dtype}. Must be one of {list(DTYPE_TO_STR.keys())}"
+        )
+    return DTYPE_TO_STR[dtype]
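
Usage sketch for the two helpers; it assumes DTYPE_MAP and DTYPE_TO_STR in sae_lens.constants cover the usual names such as "float32" and "bfloat16":

    import torch

    from sae_lens.util import dtype_to_str, str_to_dtype

    assert str_to_dtype("float32") is torch.float32
    assert dtype_to_str(torch.bfloat16) == "bfloat16"
    str_to_dtype("fp97")  # raises ValueError listing the supported names
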
{sae_lens-6.22.1.dist-info → sae_lens-6.25.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.22.1
+Version: 6.25.1
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -62,7 +62,7 @@ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [
 
 ## Loading Pre-trained SAEs.
 
-Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/
+Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.
 
 ## Migrating to SAELens v6
 
{sae_lens-6.22.1.dist-info → sae_lens-6.25.1.dist-info}/RECORD
CHANGED
@@ -1,41 +1,41 @@
-sae_lens/__init__.py,sha256=
+sae_lens/__init__.py,sha256=vWuA8EbynIJadj666RoFNCTIvoH9-HFpUxuHwoYt8Ks,4268
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
 sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
-sae_lens/cache_activations_runner.py,sha256=
-sae_lens/config.py,sha256=
-sae_lens/constants.py,sha256=
+sae_lens/cache_activations_runner.py,sha256=Lvlz-k5-3XxVRtUdC4b1CiKyx5s0ckLa8GDGv9_kcxs,12566
+sae_lens/config.py,sha256=JmcrXT4orJV2OulbEZAciz8RQmYv7DrtUtRbOLsNQ2Y,30330
+sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
 sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
 sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHFzGGuqU,15196
 sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
 sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/loading/pretrained_sae_loaders.py,sha256=
+sae_lens/loading/pretrained_sae_loaders.py,sha256=hq-dhxsEdUmlAnZEiZBqX7lNyQQwZ6KXmXZWpzAc5FY,63638
 sae_lens/loading/pretrained_saes_directory.py,sha256=hejNfLUepYCSGPalRfQwxxCEUqMMUPsn1tufwvwct5k,3820
-sae_lens/pretokenize_runner.py,sha256=
-sae_lens/pretrained_saes.yaml,sha256=
+sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
+sae_lens/pretrained_saes.yaml,sha256=Hy9mk4Liy50B0CIBD4ER1ETcho2drFFiIy-bPVCN_lc,1510210
 sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
-sae_lens/saes/__init__.py,sha256=
+sae_lens/saes/__init__.py,sha256=fYVujOzNnUgpzLL0MBLBt_DNX2CPcTaheukzCd2bEPo,1906
 sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
-sae_lens/saes/gated_sae.py,sha256=
-sae_lens/saes/jumprelu_sae.py,sha256=
+sae_lens/saes/gated_sae.py,sha256=mHnmw-RD7hqIbP9_EBj3p2SK0OqQIkZivdOKRygeRgw,8825
+sae_lens/saes/jumprelu_sae.py,sha256=udjGHp3WTABQSL2Qq57j-bINWX61GCmo68EmdjMOXoo,13310
 sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
-sae_lens/saes/sae.py,sha256=
+sae_lens/saes/sae.py,sha256=fzXv8lwHskSxsf8hm_wlKPkpq50iafmBjBNQzwZ6a00,40050
 sae_lens/saes/standard_sae.py,sha256=nEVETwAmRD2tyX7ESIic1fij48gAq1Dh7s_GQ2fqCZ4,5747
 sae_lens/saes/temporal_sae.py,sha256=DsecivcHWId-MTuJpQbz8OhqtmGhZACxJauYZGHo0Ok,13272
-sae_lens/saes/topk_sae.py,sha256=
-sae_lens/saes/transcoder.py,sha256=
+sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
+sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
 sae_lens/tokenization_and_batching.py,sha256=D_o7cXvRqhT89H3wNzoRymNALNE6eHojBWLdXOUwUGE,5438
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
-sae_lens/training/activations_store.py,sha256=
+sae_lens/training/activations_store.py,sha256=rQadexm2BiwK7_MZIPlRkcKSqabi3iuOTC-R8aJchS8,33778
 sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
 sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
 sae_lens/training/sae_trainer.py,sha256=zhkabyIKxI_tZTV3_kwz6zMrHZ95Ecr97krmwc-9ffs,17600
 sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
-sae_lens/util.py,sha256=
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
+sae_lens/util.py,sha256=spkcmQUsjVYFn5H2032nQYr1CKGVnv3tAdfIpY59-Mg,3919
+sae_lens-6.25.1.dist-info/METADATA,sha256=gClFVWzEWNNjrXsGqvCY6ry6ehXIFwp8PB0jIOhmQvc,5361
+sae_lens-6.25.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+sae_lens-6.25.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.25.1.dist-info/RECORD,,
{sae_lens-6.22.1.dist-info → sae_lens-6.25.1.dist-info}/WHEEL
File without changes
{sae_lens-6.22.1.dist-info → sae_lens-6.25.1.dist-info}/licenses/LICENSE
File without changes