sae-lens 6.3.0__tar.gz → 6.4.0__tar.gz

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (38)
  1. {sae_lens-6.3.0 → sae_lens-6.4.0}/PKG-INFO +1 -1
  2. {sae_lens-6.3.0 → sae_lens-6.4.0}/pyproject.toml +1 -1
  3. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/__init__.py +1 -1
  4. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/config.py +10 -1
  5. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/llm_sae_training_runner.py +3 -1
  6. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/training/sae_trainer.py +0 -8
  7. {sae_lens-6.3.0 → sae_lens-6.4.0}/LICENSE +0 -0
  8. {sae_lens-6.3.0 → sae_lens-6.4.0}/README.md +0 -0
  9. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/analysis/__init__.py +0 -0
  10. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  11. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  12. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/cache_activations_runner.py +0 -0
  13. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/constants.py +0 -0
  14. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/evals.py +0 -0
  15. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/load_model.py +0 -0
  16. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/loading/__init__.py +0 -0
  17. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
  18. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  19. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/pretokenize_runner.py +0 -0
  20. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/pretrained_saes.yaml +0 -0
  21. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/registry.py +0 -0
  22. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/saes/__init__.py +0 -0
  23. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/saes/batchtopk_sae.py +0 -0
  24. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/saes/gated_sae.py +0 -0
  25. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/saes/jumprelu_sae.py +0 -0
  26. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/saes/sae.py +0 -0
  27. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/saes/standard_sae.py +0 -0
  28. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/saes/topk_sae.py +0 -0
  29. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/tokenization_and_batching.py +0 -0
  30. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/training/__init__.py +0 -0
  31. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/training/activation_scaler.py +0 -0
  32. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/training/activations_store.py +0 -0
  33. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/training/mixing_buffer.py +0 -0
  34. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/training/optim.py +0 -0
  35. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/training/types.py +0 -0
  36. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  37. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/tutorial/tsea.py +0 -0
  38. {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/util.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sae-lens
3
- Version: 6.3.0
3
+ Version: 6.4.0
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "sae-lens"
3
- version = "6.3.0"
3
+ version = "6.4.0"
4
4
  description = "Training and Analyzing Sparse Autoencoders (SAEs)"
5
5
  authors = ["Joseph Bloom"]
6
6
  readme = "README.md"
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.3.0"
2
+ __version__ = "6.4.0"
3
3
 
4
4
  import logging
5
5
 
@@ -1,5 +1,6 @@
1
1
  import json
2
2
  import math
3
+ import warnings
3
4
  from dataclasses import asdict, dataclass, field
4
5
  from pathlib import Path
5
6
  from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
@@ -125,7 +126,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
125
126
  model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
126
127
  model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
127
128
  hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
128
- hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
129
+ hook_eval (str): DEPRECATED: Will be removed in v7.0.0. NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
129
130
  hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
130
131
  dataset_path (str): A Hugging Face dataset path.
131
132
  dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
@@ -264,6 +265,14 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
264
265
  exclude_special_tokens: bool | list[int] = False
265
266
 
266
267
  def __post_init__(self):
268
+ if self.hook_eval != "NOT_IN_USE":
269
+ warnings.warn(
270
+ "The 'hook_eval' field is deprecated and will be removed in v7.0.0. "
271
+ "It is not currently used and can be safely removed from your config.",
272
+ DeprecationWarning,
273
+ stacklevel=2,
274
+ )
275
+
267
276
  if self.use_cached_activations and self.cached_activations_path is None:
268
277
  self.cached_activations_path = _default_cached_activations_path(
269
278
  self.dataset_path,
@@ -17,6 +17,7 @@ from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
17
17
  from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
18
18
  from sae_lens.evals import EvalConfig, run_evals
19
19
  from sae_lens.load_model import load_model
20
+ from sae_lens.saes.batchtopk_sae import BatchTopKTrainingSAEConfig
20
21
  from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
21
22
  from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
22
23
  from sae_lens.saes.sae import (
@@ -291,7 +292,7 @@ def _parse_cfg_args(
291
292
  architecture_parser.add_argument(
292
293
  "--architecture",
293
294
  type=str,
294
- choices=["standard", "gated", "jumprelu", "topk"],
295
+ choices=["standard", "gated", "jumprelu", "topk", "batchtopk"],
295
296
  default="standard",
296
297
  help="SAE architecture to use",
297
298
  )
@@ -352,6 +353,7 @@ def _parse_cfg_args(
352
353
  "gated": GatedTrainingSAEConfig,
353
354
  "jumprelu": JumpReLUTrainingSAEConfig,
354
355
  "topk": TopKTrainingSAEConfig,
356
+ "batchtopk": BatchTopKTrainingSAEConfig,
355
357
  }
356
358
 
357
359
  sae_config_type = sae_config_map[architecture]
@@ -1,5 +1,4 @@
1
1
  import contextlib
2
- from dataclasses import dataclass
3
2
  from pathlib import Path
4
3
  from typing import Any, Callable, Generic, Protocol
5
4
 
@@ -38,13 +37,6 @@ def _update_sae_lens_training_version(sae: TrainingSAE[Any]) -> None:
38
37
  sae.cfg.sae_lens_training_version = str(__version__)
39
38
 
40
39
 
41
- @dataclass
42
- class TrainSAEOutput:
43
- sae: TrainingSAE[Any]
44
- checkpoint_path: str
45
- log_feature_sparsities: torch.Tensor
46
-
47
-
48
40
  class SaveCheckpointFn(Protocol):
49
41
  def __call__(
50
42
  self,
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes