sae-lens 6.22.1-py3-none-any.whl → 6.25.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/saes/__init__.py CHANGED
@@ -33,6 +33,8 @@ from .topk_sae import (
     TopKTrainingSAEConfig,
 )
 from .transcoder import (
+    JumpReLUSkipTranscoder,
+    JumpReLUSkipTranscoderConfig,
     JumpReLUTranscoder,
     JumpReLUTranscoderConfig,
     SkipTranscoder,
@@ -70,6 +72,8 @@ __all__ = [
     "SkipTranscoderConfig",
     "JumpReLUTranscoder",
     "JumpReLUTranscoderConfig",
+    "JumpReLUSkipTranscoder",
+    "JumpReLUSkipTranscoderConfig",
     "MatryoshkaBatchTopKTrainingSAE",
     "MatryoshkaBatchTopKTrainingSAEConfig",
     "TemporalSAE",
sae_lens/saes/gated_sae.py CHANGED
@@ -89,7 +89,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Override to handle gated-specific parameters."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
@@ -217,7 +217,7 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Override to handle gated-specific parameters."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
sae_lens/saes/jumprelu_sae.py CHANGED
@@ -167,8 +167,8 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
         # Save the current threshold before calling parent method
         current_thresh = self.threshold.clone()
 
-        # Get W_dec norms that will be used for scaling
-        W_dec_norms = self.W_dec.norm(dim=-1)
+        # Get W_dec norms that will be used for scaling (clamped to avoid division by zero)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8)
 
         # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
         super().fold_W_dec_norm()
@@ -325,8 +325,8 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         # Save the current threshold before we call the parent method
         current_thresh = self.threshold.clone()
 
-        # Get W_dec norms
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        # Get W_dec norms (clamped to avoid division by zero)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
 
         # Call parent implementation to handle W_enc and W_dec adjustment
         super().fold_W_dec_norm()
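The recurring `clamp(min=1e-8)` change in this release guards norm folding against dead latents whose decoder rows have zero norm. Below is a minimal sketch (not from the package) of the failure mode being fixed; tensor names mirror the SAE parameters above, and the shapes are illustrative.

```python
import torch

# Illustrative shapes: d_sae = 3 latents, d_in = 4 model dimensions.
W_dec = torch.randn(3, 4)
W_dec[1] = 0.0  # a dead latent: its decoder row has zero norm
W_enc = torch.randn(4, 3)

# Unclamped folding divides the dead row by 0, producing 0/0 = NaN:
nan_fold = W_dec / W_dec.norm(dim=-1).unsqueeze(1)
assert nan_fold.isnan().any()

# The clamped version used throughout this release stays finite:
norms = W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
assert (W_dec / norms).isfinite().all()
assert (W_enc * norms.T).isfinite().all()
```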
sae_lens/saes/sae.py CHANGED
@@ -27,11 +27,10 @@ from typing_extensions import deprecated, overload, override
 
 from sae_lens import __version__
 from sae_lens.constants import (
-    DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
 )
-from sae_lens.util import filter_valid_dataclass_fields
+from sae_lens.util import dtype_to_str, filter_valid_dataclass_fields, str_to_dtype
 
 if TYPE_CHECKING:
     from sae_lens.config import LanguageModelSAERunnerConfig
@@ -253,7 +252,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
             stacklevel=1,
         )
 
-        self.dtype = DTYPE_MAP[cfg.dtype]
+        self.dtype = str_to_dtype(cfg.dtype)
         self.device = torch.device(cfg.device)
         self.use_error_term = use_error_term
 
@@ -437,8 +436,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
         # Update dtype in config if provided
         if dtype_arg is not None:
-            # Update the cfg.dtype
-            self.cfg.dtype = str(dtype_arg)
+            # Update the cfg.dtype (use canonical short form like "float32")
+            self.cfg.dtype = dtype_to_str(dtype_arg)
 
         # Update the dtype property
         self.dtype = dtype_arg
@@ -484,7 +483,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Fold decoder norms into encoder."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
@@ -534,6 +533,15 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         dtype: str | None = None,
         converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
     ) -> T_SAE:
+        """
+        Load a SAE from disk.
+
+        Args:
+            path: The path to the SAE weights and config.
+            device: The device to load the SAE on, defaults to "cpu".
+            dtype: The dtype to load the SAE as, defaults to None. If None, the dtype will be inferred from the SAE config.
+            converter: The converter used to load the SAE, defaults to sae_lens_disk_loader.
+        """
         overrides = {"dtype": dtype} if dtype is not None else None
         cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
         cfg_dict = handle_config_defaulting(cfg_dict)
@@ -542,10 +550,17 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         )
         sae_cfg = sae_config_cls.from_dict(cfg_dict)
         sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
+        # hack to avoid using double memory when loading the SAE:
+        # first put the SAE on the meta device, then load the weights.
+        device = sae_cfg.device
+        sae_cfg.device = "meta"
         sae = sae_cls(sae_cfg)
+        sae.cfg.device = device
         sae.process_state_dict_for_loading(state_dict)
-        sae.load_state_dict(state_dict)
-        return sae
+        sae.load_state_dict(state_dict, assign=True)
+        # the loaders should already handle the dtype / device conversion,
+        # but this is a fallback to guarantee the SAE is on the correct device and dtype
+        return sae.to(dtype=str_to_dtype(sae_cfg.dtype), device=device)
 
     @classmethod
     def from_pretrained(
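Both loading paths now construct the module with its config pointed at PyTorch's `meta` device, then attach the loaded tensors via `load_state_dict(..., assign=True)`, so the randomly initialized parameters never occupy real memory. A standalone sketch of the same trick, using `nn.Linear` as a stand-in for the SAE (not sae_lens code; `assign=True` requires PyTorch 2.1+):

```python
import torch
import torch.nn as nn

state_dict = {"weight": torch.randn(16, 8), "bias": torch.randn(16)}

# Parameters created under the meta device have shapes but no storage.
with torch.device("meta"):
    module = nn.Linear(8, 16)

# assign=True swaps the meta parameters for the loaded tensors outright,
# instead of copying into existing storage (which meta tensors lack).
module.load_state_dict(state_dict, assign=True)
assert module.weight.device.type == "cpu"
```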
@@ -553,6 +568,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         release: str,
         sae_id: str,
         device: str = "cpu",
+        dtype: str = "float32",
         force_download: bool = False,
         converter: PretrainedSaeHuggingfaceLoader | None = None,
     ) -> T_SAE:
@@ -562,10 +578,18 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         Args:
             release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
             id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
-            device: The device to load the SAE on.
+            device: The device to load the SAE on, defaults to "cpu".
+            dtype: The dtype to load the SAE as, defaults to "float32".
+            force_download: Whether to force download the SAE weights and config, defaults to False.
+            converter: The converter used to load the SAE, defaults to None. If None, the converter will be inferred from the release.
         """
         return cls.from_pretrained_with_cfg_and_sparsity(
-            release, sae_id, device, force_download, converter=converter
+            release,
+            sae_id,
+            device,
+            force_download=force_download,
+            dtype=dtype,
+            converter=converter,
         )[0]
 
     @classmethod
@@ -574,6 +598,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         release: str,
         sae_id: str,
         device: str = "cpu",
+        dtype: str = "float32",
         force_download: bool = False,
         converter: PretrainedSaeHuggingfaceLoader | None = None,
     ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
@@ -584,7 +609,10 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         Args:
             release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
             id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
-            device: The device to load the SAE on.
+            device: The device to load the SAE on, defaults to "cpu".
+            dtype: The dtype to load the SAE as, defaults to "float32".
+            force_download: Whether to force download the SAE weights and config, defaults to False.
+            converter: The converter used to load the SAE, defaults to None. If None, the converter will be inferred from the release.
         """
 
         # get sae directory
@@ -634,6 +662,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
         config_overrides = get_config_overrides(release, sae_id)
         config_overrides["device"] = device
+        config_overrides["dtype"] = dtype
 
         # Load config and weights
         cfg_dict, state_dict, log_sparsities = conversion_loader(
@@ -651,9 +680,14 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         )
         sae_cfg = sae_config_cls.from_dict(cfg_dict)
         sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
+        # hack to avoid using double memory when loading the SAE:
+        # first put the SAE on the meta device, then load the weights.
+        device = sae_cfg.device
+        sae_cfg.device = "meta"
         sae = sae_cls(sae_cfg)
+        sae.cfg.device = device
         sae.process_state_dict_for_loading(state_dict)
-        sae.load_state_dict(state_dict)
+        sae.load_state_dict(state_dict, assign=True)
 
         # Apply normalization if needed
         if cfg_dict.get("normalize_activations") == "expected_average_only_in":
@@ -666,7 +700,13 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
                 f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
             )
 
-        return sae, cfg_dict, log_sparsities
+        # the loaders should already handle the dtype / device conversion,
+        # but this is a fallback to guarantee the SAE is on the correct device and dtype
+        return (
+            sae.to(dtype=str_to_dtype(dtype), device=device),
+            cfg_dict,
+            log_sparsities,
+        )
 
     @classmethod
     def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
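With `dtype` now threaded through `from_pretrained` into the config overrides and the final `.to(...)` fallback, loading a pretrained SAE in reduced precision is a single call. A usage sketch; the release and id strings are illustrative and assume matching entries in pretrained_saes.yaml:

```python
import torch

from sae_lens import SAE

# "gpt2-small-res-jb" / "blocks.8.hook_resid_pre" are example identifiers;
# substitute any release/id pair listed in pretrained_saes.yaml.
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cpu",
    dtype="bfloat16",  # new parameter; defaults to "float32"
)
assert sae.W_dec.dtype == torch.bfloat16
```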
sae_lens/saes/topk_sae.py CHANGED
@@ -531,7 +531,7 @@ def _fold_norm_topk(
     b_enc: torch.Tensor,
     W_dec: torch.Tensor,
 ) -> None:
-    W_dec_norm = W_dec.norm(dim=-1)
+    W_dec_norm = W_dec.norm(dim=-1).clamp(min=1e-8)
     b_enc.data = b_enc.data * W_dec_norm
     W_dec_norms = W_dec_norm.unsqueeze(1)
     W_dec.data = W_dec.data / W_dec_norms
sae_lens/saes/transcoder.py CHANGED
@@ -368,3 +368,44 @@ class JumpReLUTranscoder(Transcoder):
     def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUTranscoder":
         cfg = JumpReLUTranscoderConfig.from_dict(config_dict)
         return cls(cfg)
+
+
+@dataclass
+class JumpReLUSkipTranscoderConfig(JumpReLUTranscoderConfig):
+    """Configuration for a JumpReLU transcoder with a skip connection."""
+
+    @classmethod
+    def architecture(cls) -> str:
+        """Return the architecture name for this config."""
+        return "jumprelu_skip_transcoder"
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUSkipTranscoderConfig":
+        """Create a JumpReLUSkipTranscoderConfig from a dictionary."""
+        # Filter to only include valid dataclass fields
+        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)
+
+        # Create the config instance
+        res = cls(**filtered_config_dict)
+
+        # Handle metadata if present
+        if "metadata" in config_dict:
+            res.metadata = SAEMetadata(**config_dict["metadata"])
+
+        return res
+
+
+class JumpReLUSkipTranscoder(JumpReLUTranscoder, SkipTranscoder):
+    """
+    A transcoder with a learnable skip connection and JumpReLU activation function.
+    """
+
+    cfg: JumpReLUSkipTranscoderConfig  # type: ignore[assignment]
+
+    def __init__(self, cfg: JumpReLUSkipTranscoderConfig):
+        super().__init__(cfg)
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUSkipTranscoder":
+        cfg = JumpReLUSkipTranscoderConfig.from_dict(config_dict)
+        return cls(cfg)
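The new class picks up JumpReLU encoding from `JumpReLUTranscoder` and the learnable skip path from `SkipTranscoder` via multiple inheritance. A construction sketch; the config keys shown are assumptions about fields inherited from the transcoder configs, not something this diff spells out:

```python
from sae_lens.saes import JumpReLUSkipTranscoder

# Hypothetical config values; the real field names and required keys come
# from the inherited transcoder config and may differ.
transcoder = JumpReLUSkipTranscoder.from_dict(
    {
        "d_in": 768,
        "d_out": 768,
        "d_sae": 16384,
        "dtype": "float32",
        "device": "cpu",
    }
)
assert transcoder.cfg.architecture() == "jumprelu_skip_transcoder"
```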
sae_lens/training/activations_store.py CHANGED
@@ -24,7 +24,7 @@ from sae_lens.config import (
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, DTYPE_MAP
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME
 from sae_lens.pretokenize_runner import get_special_token_from_cfg
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
@@ -32,6 +32,7 @@ from sae_lens.training.mixing_buffer import mixing_buffer
 from sae_lens.util import (
     extract_stop_at_layer_from_tlens_hook_name,
     get_special_token_ids,
+    str_to_dtype,
 )
 
 
@@ -166,9 +167,11 @@ class ActivationsStore:
         disable_concat_sequences: bool = False,
         sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
     ) -> ActivationsStore:
+        if context_size is None:
+            context_size = sae.cfg.metadata.context_size
         if sae.cfg.metadata.hook_name is None:
             raise ValueError("hook_name is required")
-        if sae.cfg.metadata.context_size is None:
+        if context_size is None:
             raise ValueError("context_size is required")
         if sae.cfg.metadata.prepend_bos is None:
             raise ValueError("prepend_bos is required")
@@ -178,9 +181,7 @@ class ActivationsStore:
             d_in=sae.cfg.d_in,
             hook_name=sae.cfg.metadata.hook_name,
             hook_head_index=sae.cfg.metadata.hook_head_index,
-            context_size=sae.cfg.metadata.context_size
-            if context_size is None
-            else context_size,
+            context_size=context_size,
             prepend_bos=sae.cfg.metadata.prepend_bos,
             streaming=streaming,
             store_batch_size_prompts=store_batch_size_prompts,
@@ -230,7 +231,7 @@ class ActivationsStore:
             load_dataset(
                 dataset,
                 split="train",
-                streaming=streaming,
+                streaming=streaming,  # type: ignore
                 trust_remote_code=dataset_trust_remote_code,  # type: ignore
             )
             if isinstance(dataset, str)
@@ -258,7 +259,7 @@ class ActivationsStore:
         self.prepend_bos = prepend_bos
         self.normalize_activations = normalize_activations
         self.device = torch.device(device)
-        self.dtype = DTYPE_MAP[dtype]
+        self.dtype = str_to_dtype(dtype)
         self.cached_activations_path = cached_activations_path
         self.autocast_lm = autocast_lm
         self.seqpos_slice = seqpos_slice
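The `from_sae` change turns the old inline precedence expression into an early fallback: an explicit `context_size` wins, and `None` now falls back to the SAE's metadata before the validation checks run. A usage sketch, assuming `model` and `sae` are an already-loaded model and SAE whose metadata carries `hook_name`, `context_size`, and `prepend_bos`:

```python
from sae_lens import ActivationsStore

# None (the default) falls back to sae.cfg.metadata.context_size;
# if the metadata lacks it too, a ValueError is raised.
store = ActivationsStore.from_sae(model, sae, context_size=None)

# An explicit value overrides the metadata entirely.
store = ActivationsStore.from_sae(model, sae, context_size=128)
```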
sae_lens/util.py CHANGED
@@ -5,8 +5,11 @@ from dataclasses import asdict, fields, is_dataclass
 from pathlib import Path
 from typing import Sequence, TypeVar
 
+import torch
 from transformers import PreTrainedTokenizerBase
 
+from sae_lens.constants import DTYPE_MAP, DTYPE_TO_STR
+
 K = TypeVar("K")
 V = TypeVar("V")
 
@@ -90,3 +93,21 @@ def get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
         special_tokens.add(token_id)
 
     return list(special_tokens)
+
+
+def str_to_dtype(dtype: str) -> torch.dtype:
+    """Convert a string to a torch.dtype."""
+    if dtype not in DTYPE_MAP:
+        raise ValueError(
+            f"Invalid dtype: {dtype}. Must be one of {list(DTYPE_MAP.keys())}"
+        )
+    return DTYPE_MAP[dtype]
+
+
+def dtype_to_str(dtype: torch.dtype) -> str:
+    """Convert a torch.dtype to a string."""
+    if dtype not in DTYPE_TO_STR:
+        raise ValueError(
+            f"Invalid dtype: {dtype}. Must be one of {list(DTYPE_TO_STR.keys())}"
+        )
+    return DTYPE_TO_STR[dtype]
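This pair of helpers replaces raw `DTYPE_MAP` lookups with strict, symmetric conversions that fail loudly on unknown names. A quick sketch, assuming "float32" and "bfloat16" are among the mapped names:

```python
import torch

from sae_lens.util import dtype_to_str, str_to_dtype

assert str_to_dtype("float32") is torch.float32
assert dtype_to_str(torch.bfloat16) == "bfloat16"

# Unknown names raise a ValueError listing the valid keys:
# str_to_dtype("float33")
```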
sae_lens-6.22.1.dist-info/METADATA → sae_lens-6.25.1.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.22.1
+Version: 6.25.1
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -62,7 +62,7 @@ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [
 
 ## Loading Pre-trained SAEs.
 
-Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/sae_table/) in the readme for a list of all SAEs.
+Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.
 
 ## Migrating to SAELens v6
 
sae_lens-6.22.1.dist-info/RECORD → sae_lens-6.25.1.dist-info/RECORD RENAMED
@@ -1,41 +1,41 @@
-sae_lens/__init__.py,sha256=v-2uKiNW5UNVCRt7vyBrvI0olJsXIxaPp9TJvo-m9wg,4033
+sae_lens/__init__.py,sha256=vWuA8EbynIJadj666RoFNCTIvoH9-HFpUxuHwoYt8Ks,4268
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
 sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
-sae_lens/cache_activations_runner.py,sha256=RNN_nDQkH0lqEIxTAIDx3g1cgAzRxQWBSBEXA6nbWh0,12565
-sae_lens/config.py,sha256=fxvpQxFfPOVUkryiHD19q9O1AJDSkIguWeYlbJuTxmY,30329
-sae_lens/constants.py,sha256=qX12uAE_xkha6hjss_0MGTbakI7gEkJzHABkZaHWQFU,683
+sae_lens/cache_activations_runner.py,sha256=Lvlz-k5-3XxVRtUdC4b1CiKyx5s0ckLa8GDGv9_kcxs,12566
+sae_lens/config.py,sha256=JmcrXT4orJV2OulbEZAciz8RQmYv7DrtUtRbOLsNQ2Y,30330
+sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
 sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
 sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHFzGGuqU,15196
 sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
 sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/loading/pretrained_sae_loaders.py,sha256=X-gVZ4A74E85lSMFMsZ_rEQhHlR9AYFwhxvoA_vt2CQ,56051
+sae_lens/loading/pretrained_sae_loaders.py,sha256=hq-dhxsEdUmlAnZEiZBqX7lNyQQwZ6KXmXZWpzAc5FY,63638
 sae_lens/loading/pretrained_saes_directory.py,sha256=hejNfLUepYCSGPalRfQwxxCEUqMMUPsn1tufwvwct5k,3820
-sae_lens/pretokenize_runner.py,sha256=x-reJzVPFDS9iRFbZtrFYSzNguJYki9gd0pbHjYJ3r4,7085
-sae_lens/pretrained_saes.yaml,sha256=VzgJ_t-IEWpO2MabgQY6CAcg8FFsqZWiOVXjqvqfgeE,604973
+sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
+sae_lens/pretrained_saes.yaml,sha256=Hy9mk4Liy50B0CIBD4ER1ETcho2drFFiIy-bPVCN_lc,1510210
 sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
-sae_lens/saes/__init__.py,sha256=nTNPnJ7edyfedo1MX96xwn9WOG8504yHbT9LFw9od_0,1778
+sae_lens/saes/__init__.py,sha256=fYVujOzNnUgpzLL0MBLBt_DNX2CPcTaheukzCd2bEPo,1906
 sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
-sae_lens/saes/gated_sae.py,sha256=Jq74JGtqpO6tW3XdJGbURTTWN_fAoAMKu9T7O-MZTeE,8793
-sae_lens/saes/jumprelu_sae.py,sha256=zUGHWOFXbeDBS3mjkOE3ikxlEniq2EX9rCAizLMOpp4,13206
+sae_lens/saes/gated_sae.py,sha256=mHnmw-RD7hqIbP9_EBj3p2SK0OqQIkZivdOKRygeRgw,8825
+sae_lens/saes/jumprelu_sae.py,sha256=udjGHp3WTABQSL2Qq57j-bINWX61GCmo68EmdjMOXoo,13310
 sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
-sae_lens/saes/sae.py,sha256=q8ylAdqtkNAms7X-3y1QIBfHOZ-FvKHvCap7Tw_cnzE,37733
+sae_lens/saes/sae.py,sha256=fzXv8lwHskSxsf8hm_wlKPkpq50iafmBjBNQzwZ6a00,40050
 sae_lens/saes/standard_sae.py,sha256=nEVETwAmRD2tyX7ESIic1fij48gAq1Dh7s_GQ2fqCZ4,5747
 sae_lens/saes/temporal_sae.py,sha256=DsecivcHWId-MTuJpQbz8OhqtmGhZACxJauYZGHo0Ok,13272
-sae_lens/saes/topk_sae.py,sha256=D1N4LHGOeV8dhHW0i3HqBT1cqA-E1Plq11uMJtVfNBo,21057
-sae_lens/saes/transcoder.py,sha256=BfLSbTYVNZh-ruGxseZiZJ_acEL6_7QyTdfqUr0lDOg,12156
+sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
+sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
 sae_lens/tokenization_and_batching.py,sha256=D_o7cXvRqhT89H3wNzoRymNALNE6eHojBWLdXOUwUGE,5438
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
-sae_lens/training/activations_store.py,sha256=1ozCANGXO8Vx9d_l-heb-MsSpUoYcHagcve5JLGwZYY,33762
+sae_lens/training/activations_store.py,sha256=rQadexm2BiwK7_MZIPlRkcKSqabi3iuOTC-R8aJchS8,33778
 sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
 sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
 sae_lens/training/sae_trainer.py,sha256=zhkabyIKxI_tZTV3_kwz6zMrHZ95Ecr97krmwc-9ffs,17600
 sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
-sae_lens/util.py,sha256=tCovQ-eZa1L7thPpNDL6PGOJrIMML2yLI5e0EHCOpS8,3309
-sae_lens-6.22.1.dist-info/METADATA,sha256=QoCu9iHTvA66XSkU2aR_4VxP7wGFr_NQPJUZwxvaOak,5369
-sae_lens-6.22.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
-sae_lens-6.22.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
-sae_lens-6.22.1.dist-info/RECORD,,
+sae_lens/util.py,sha256=spkcmQUsjVYFn5H2032nQYr1CKGVnv3tAdfIpY59-Mg,3919
+sae_lens-6.25.1.dist-info/METADATA,sha256=gClFVWzEWNNjrXsGqvCY6ry6ehXIFwp8PB0jIOhmQvc,5361
+sae_lens-6.25.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+sae_lens-6.25.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.25.1.dist-info/RECORD,,