sae-lens 6.3.0__tar.gz → 6.3.1__tar.gz
This diff compares the contents of two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
- {sae_lens-6.3.0 → sae_lens-6.3.1}/PKG-INFO +1 -1
- {sae_lens-6.3.0 → sae_lens-6.3.1}/pyproject.toml +1 -1
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/__init__.py +1 -1
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/llm_sae_training_runner.py +3 -1
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/training/sae_trainer.py +0 -8
- {sae_lens-6.3.0 → sae_lens-6.3.1}/LICENSE +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/README.md +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/config.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/constants.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/evals.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/load_model.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/registry.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/training/types.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.3.1}/sae_lens/util.py +0 -0
|
@@ -17,6 +17,7 @@ from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
|
|
|
17
17
|
from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
|
|
18
18
|
from sae_lens.evals import EvalConfig, run_evals
|
|
19
19
|
from sae_lens.load_model import load_model
|
|
20
|
+
from sae_lens.saes.batchtopk_sae import BatchTopKTrainingSAEConfig
|
|
20
21
|
from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
|
|
21
22
|
from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
|
|
22
23
|
from sae_lens.saes.sae import (
|
|
@@ -291,7 +292,7 @@ def _parse_cfg_args(
|
|
|
291
292
|
architecture_parser.add_argument(
|
|
292
293
|
"--architecture",
|
|
293
294
|
type=str,
|
|
294
|
-
choices=["standard", "gated", "jumprelu", "topk"],
|
|
295
|
+
choices=["standard", "gated", "jumprelu", "topk", "batchtopk"],
|
|
295
296
|
default="standard",
|
|
296
297
|
help="SAE architecture to use",
|
|
297
298
|
)
|
|
@@ -352,6 +353,7 @@ def _parse_cfg_args(
|
|
|
352
353
|
"gated": GatedTrainingSAEConfig,
|
|
353
354
|
"jumprelu": JumpReLUTrainingSAEConfig,
|
|
354
355
|
"topk": TopKTrainingSAEConfig,
|
|
356
|
+
"batchtopk": BatchTopKTrainingSAEConfig,
|
|
355
357
|
}
|
|
356
358
|
|
|
357
359
|
sae_config_type = sae_config_map[architecture]
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import contextlib
|
|
2
|
-
from dataclasses import dataclass
|
|
3
2
|
from pathlib import Path
|
|
4
3
|
from typing import Any, Callable, Generic, Protocol
|
|
5
4
|
|
|
@@ -38,13 +37,6 @@ def _update_sae_lens_training_version(sae: TrainingSAE[Any]) -> None:
|
|
|
38
37
|
sae.cfg.sae_lens_training_version = str(__version__)
|
|
39
38
|
|
|
40
39
|
|
|
41
|
-
@dataclass
|
|
42
|
-
class TrainSAEOutput:
|
|
43
|
-
sae: TrainingSAE[Any]
|
|
44
|
-
checkpoint_path: str
|
|
45
|
-
log_feature_sparsities: torch.Tensor
|
|
46
|
-
|
|
47
|
-
|
|
48
40
|
class SaveCheckpointFn(Protocol):
|
|
49
41
|
def __call__(
|
|
50
42
|
self,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|