sae-lens 6.3.0__tar.gz → 6.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {sae_lens-6.3.0 → sae_lens-6.3.1}/PKG-INFO +1 -1
  2. {sae_lens-6.3.0 → sae_lens-6.3.1}/pyproject.toml +1 -1
  3. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/__init__.py +1 -1
  4. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/llm_sae_training_runner.py +3 -1
  5. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/training/sae_trainer.py +0 -8
  6. {sae_lens-6.3.0 → sae_lens-6.3.1}/LICENSE +0 -0
  7. {sae_lens-6.3.0 → sae_lens-6.3.1}/README.md +0 -0
  8. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/analysis/__init__.py +0 -0
  9. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  10. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  11. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/cache_activations_runner.py +0 -0
  12. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/config.py +0 -0
  13. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/constants.py +0 -0
  14. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/evals.py +0 -0
  15. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/load_model.py +0 -0
  16. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/loading/__init__.py +0 -0
  17. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
  18. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  19. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/pretokenize_runner.py +0 -0
  20. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/pretrained_saes.yaml +0 -0
  21. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/registry.py +0 -0
  22. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/saes/__init__.py +0 -0
  23. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/saes/batchtopk_sae.py +0 -0
  24. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/saes/gated_sae.py +0 -0
  25. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/saes/jumprelu_sae.py +0 -0
  26. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/saes/sae.py +0 -0
  27. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/saes/standard_sae.py +0 -0
  28. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/saes/topk_sae.py +0 -0
  29. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/tokenization_and_batching.py +0 -0
  30. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/training/__init__.py +0 -0
  31. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/training/activation_scaler.py +0 -0
  32. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/training/activations_store.py +0 -0
  33. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/training/mixing_buffer.py +0 -0
  34. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/training/optim.py +0 -0
  35. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/training/types.py +0 -0
  36. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  37. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/tutorial/tsea.py +0 -0
  38. {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/util.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sae-lens
3
- Version: 6.3.0
3
+ Version: 6.3.1
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "sae-lens"
3
- version = "6.3.0"
3
+ version = "6.3.1"
4
4
  description = "Training and Analyzing Sparse Autoencoders (SAEs)"
5
5
  authors = ["Joseph Bloom"]
6
6
  readme = "README.md"
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.3.0"
2
+ __version__ = "6.3.1"
3
3
 
4
4
  import logging
5
5
 
@@ -17,6 +17,7 @@ from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
17
17
  from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
18
18
  from sae_lens.evals import EvalConfig, run_evals
19
19
  from sae_lens.load_model import load_model
20
+ from sae_lens.saes.batchtopk_sae import BatchTopKTrainingSAEConfig
20
21
  from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
21
22
  from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
22
23
  from sae_lens.saes.sae import (
@@ -291,7 +292,7 @@ def _parse_cfg_args(
291
292
  architecture_parser.add_argument(
292
293
  "--architecture",
293
294
  type=str,
294
- choices=["standard", "gated", "jumprelu", "topk"],
295
+ choices=["standard", "gated", "jumprelu", "topk", "batchtopk"],
295
296
  default="standard",
296
297
  help="SAE architecture to use",
297
298
  )
@@ -352,6 +353,7 @@ def _parse_cfg_args(
352
353
  "gated": GatedTrainingSAEConfig,
353
354
  "jumprelu": JumpReLUTrainingSAEConfig,
354
355
  "topk": TopKTrainingSAEConfig,
356
+ "batchtopk": BatchTopKTrainingSAEConfig,
355
357
  }
356
358
 
357
359
  sae_config_type = sae_config_map[architecture]
@@ -1,5 +1,4 @@
1
1
  import contextlib
2
- from dataclasses import dataclass
3
2
  from pathlib import Path
4
3
  from typing import Any, Callable, Generic, Protocol
5
4
 
@@ -38,13 +37,6 @@ def _update_sae_lens_training_version(sae: TrainingSAE[Any]) -> None:
38
37
  sae.cfg.sae_lens_training_version = str(__version__)
39
38
 
40
39
 
41
- @dataclass
42
- class TrainSAEOutput:
43
- sae: TrainingSAE[Any]
44
- checkpoint_path: str
45
- log_feature_sparsities: torch.Tensor
46
-
47
-
48
40
  class SaveCheckpointFn(Protocol):
49
41
  def __call__(
50
42
  self,
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes