sae-lens 6.3.0__tar.gz → 6.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.3.0 → sae_lens-6.4.0}/PKG-INFO +1 -1
- {sae_lens-6.3.0 → sae_lens-6.4.0}/pyproject.toml +1 -1
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/__init__.py +1 -1
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/config.py +10 -1
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/llm_sae_training_runner.py +3 -1
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/training/sae_trainer.py +0 -8
- {sae_lens-6.3.0 → sae_lens-6.4.0}/LICENSE +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/README.md +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/constants.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.4.0}/sae_lens/util.py +0 -0
sae_lens/config.py:

@@ -1,5 +1,6 @@
 import json
 import math
+import warnings
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
@@ -125,7 +126,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
         model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
         hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
-        hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
+        hook_eval (str): DEPRECATED: Will be removed in v7.0.0. NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
         hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
         dataset_path (str): A Hugging Face dataset path.
         dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
@@ -264,6 +265,14 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     exclude_special_tokens: bool | list[int] = False

     def __post_init__(self):
+        if self.hook_eval != "NOT_IN_USE":
+            warnings.warn(
+                "The 'hook_eval' field is deprecated and will be removed in v7.0.0. "
+                "It is not currently used and can be safely removed from your config.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
         if self.use_cached_activations and self.cached_activations_path is None:
             self.cached_activations_path = _default_cached_activations_path(
                 self.dataset_path,
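The practical effect of the config.py change: setting hook_eval to anything other than its default now emits a DeprecationWarning when the config is constructed. Below is a minimal, self-contained sketch of that pattern; ExampleConfig and its single field are hypothetical stand-ins for LanguageModelSAERunnerConfig, and the "NOT_IN_USE" default is inferred from the comparison in the diff.

    # Minimal sketch of the deprecation pattern added above. ExampleConfig is a
    # hypothetical stand-in; only the __post_init__ check is mirrored.
    import warnings
    from dataclasses import dataclass

    @dataclass
    class ExampleConfig:
        hook_eval: str = "NOT_IN_USE"  # default value inferred from the check in the diff

        def __post_init__(self):
            if self.hook_eval != "NOT_IN_USE":
                warnings.warn(
                    "The 'hook_eval' field is deprecated and will be removed in v7.0.0. "
                    "It is not currently used and can be safely removed from your config.",
                    DeprecationWarning,
                    stacklevel=2,
                )

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        ExampleConfig(hook_eval="blocks.0.hook_resid_post")  # emits DeprecationWarning
        ExampleConfig()  # default value stays silent

    print(len(caught))  # 1

Configs that leave hook_eval at its default are unaffected; configs that set it can simply drop the field, per the warning message.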
sae_lens/llm_sae_training_runner.py:

@@ -17,6 +17,7 @@ from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
 from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
 from sae_lens.evals import EvalConfig, run_evals
 from sae_lens.load_model import load_model
+from sae_lens.saes.batchtopk_sae import BatchTopKTrainingSAEConfig
 from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
 from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
 from sae_lens.saes.sae import (
@@ -291,7 +292,7 @@ def _parse_cfg_args(
     architecture_parser.add_argument(
         "--architecture",
         type=str,
-        choices=["standard", "gated", "jumprelu", "topk"],
+        choices=["standard", "gated", "jumprelu", "topk", "batchtopk"],
         default="standard",
         help="SAE architecture to use",
     )
@@ -352,6 +353,7 @@ def _parse_cfg_args(
         "gated": GatedTrainingSAEConfig,
         "jumprelu": JumpReLUTrainingSAEConfig,
         "topk": TopKTrainingSAEConfig,
+        "batchtopk": BatchTopKTrainingSAEConfig,
     }

     sae_config_type = sae_config_map[architecture]
sae_lens/training/sae_trainer.py:

@@ -1,5 +1,4 @@
 import contextlib
-from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Callable, Generic, Protocol

@@ -38,13 +37,6 @@ def _update_sae_lens_training_version(sae: TrainingSAE[Any]) -> None:
     sae.cfg.sae_lens_training_version = str(__version__)


-@dataclass
-class TrainSAEOutput:
-    sae: TrainingSAE[Any]
-    checkpoint_path: str
-    log_feature_sparsities: torch.Tensor
-
-
 class SaveCheckpointFn(Protocol):
     def __call__(
         self,
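The TrainSAEOutput dataclass removed above was unused by the trainer itself. If downstream code happened to import it, one possible migration, reconstructed directly from the deleted lines, is to keep a local copy; the sae_lens.saes.sae import path for TrainingSAE is an assumption and may need adjusting.

    # Local copy of the removed TrainSAEOutput dataclass, reconstructed from the
    # deleted lines above. Only needed if external code imported it from
    # sae_lens.training.sae_trainer; the TrainingSAE import path is an assumption.
    from dataclasses import dataclass
    from typing import Any

    import torch

    from sae_lens.saes.sae import TrainingSAE

    @dataclass
    class TrainSAEOutput:
        sae: TrainingSAE[Any]
        checkpoint_path: str
        log_feature_sparsities: torch.Tensor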