sae-lens 6.0.0rc2__tar.gz → 6.0.0rc4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/PKG-INFO +1 -1
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/pyproject.toml +2 -1
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/__init__.py +6 -3
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/analysis/neuronpedia_integration.py +3 -3
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/cache_activations_runner.py +7 -6
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/config.py +50 -6
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/constants.py +2 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/evals.py +39 -28
- sae_lens-6.0.0rc4/sae_lens/llm_sae_training_runner.py +377 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/load_model.py +53 -5
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/loading/pretrained_sae_loaders.py +24 -12
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/saes/gated_sae.py +0 -4
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/saes/jumprelu_sae.py +4 -10
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/saes/sae.py +121 -51
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/saes/standard_sae.py +4 -11
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/saes/topk_sae.py +18 -12
- sae_lens-6.0.0rc4/sae_lens/training/activation_scaler.py +53 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/training/activations_store.py +77 -174
- sae_lens-6.0.0rc4/sae_lens/training/mixing_buffer.py +56 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/training/sae_trainer.py +107 -98
- sae_lens-6.0.0rc4/sae_lens/training/types.py +5 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/training/upload_saes_to_huggingface.py +1 -1
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/util.py +19 -0
- sae_lens-6.0.0rc2/sae_lens/sae_training_runner.py +0 -237
- sae_lens-6.0.0rc2/sae_lens/training/geometric_median.py +0 -101
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/LICENSE +0 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/README.md +0 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/registry.py +0 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.0.0rc2 → sae_lens-6.0.0rc4}/sae_lens/tutorial/tsea.py +0 -0
--- sae_lens-6.0.0rc2/pyproject.toml
+++ sae_lens-6.0.0rc4/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.0.0-rc.2"
+version = "6.0.0-rc.4"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
@@ -60,6 +60,7 @@ tabulate = "^0.9.0"
 ruff = "^0.7.4"
 eai-sparsify = "^1.1.1"
 mike = "^2.0.0"
+trio = "^0.30.0"
 
 [tool.poetry.extras]
 mamba = ["mamba-lens"]
--- sae_lens-6.0.0rc2/sae_lens/__init__.py
+++ sae_lens-6.0.0rc4/sae_lens/__init__.py
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.0.0-rc.2"
+__version__ = "6.0.0-rc.4"
 
 import logging
 
@@ -33,16 +33,17 @@ from .cache_activations_runner import CacheActivationsRunner
 from .config import (
     CacheActivationsRunnerConfig,
     LanguageModelSAERunnerConfig,
+    LoggingConfig,
     PretokenizeRunnerConfig,
 )
 from .evals import run_evals
+from .llm_sae_training_runner import LanguageModelSAETrainingRunner, SAETrainingRunner
 from .loading.pretrained_sae_loaders import (
     PretrainedSaeDiskLoader,
     PretrainedSaeHuggingfaceLoader,
 )
 from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
 from .registry import register_sae_class, register_sae_training_class
-from .sae_training_runner import SAETrainingRunner
 from .training.activations_store import ActivationsStore
 from .training.upload_saes_to_huggingface import upload_saes_to_huggingface
 
@@ -54,7 +55,7 @@ __all__ = [
     "HookedSAETransformer",
     "ActivationsStore",
     "LanguageModelSAERunnerConfig",
-    "
+    "LanguageModelSAETrainingRunner",
     "CacheActivationsRunnerConfig",
     "CacheActivationsRunner",
     "PretokenizeRunnerConfig",
@@ -82,6 +83,8 @@ __all__ = [
     "JumpReLUSAEConfig",
     "JumpReLUTrainingSAE",
     "JumpReLUTrainingSAEConfig",
+    "SAETrainingRunner",
+    "LoggingConfig",
 ]
 
 
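Net effect of the `__init__.py` hunks: the training runner now lives in the new `sae_lens/llm_sae_training_runner.py` (+377 lines in the summary), the old `sae_lens/sae_training_runner.py` module is deleted, and `LoggingConfig` plus the renamed `LanguageModelSAETrainingRunner` are exported at the package root, with `SAETrainingRunner` kept as a re-export. A minimal import sketch, assuming rc4 is installed:

```python
# Illustrative only: how the re-exported names resolve after the rc4 import changes above.
from sae_lens import (
    LanguageModelSAETrainingRunner,  # new name, defined in sae_lens/llm_sae_training_runner.py
    LoggingConfig,                   # now exported at package level
    SAETrainingRunner,               # still exported, but re-routed through llm_sae_training_runner
)

# The old module path is removed in rc4, so this would now raise ImportError:
# from sae_lens.sae_training_runner import SAETrainingRunner
```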
--- sae_lens-6.0.0rc2/sae_lens/analysis/neuronpedia_integration.py
+++ sae_lens-6.0.0rc4/sae_lens/analysis/neuronpedia_integration.py
@@ -59,7 +59,7 @@ def NanAndInfReplacer(value: str):
 
 
 def open_neuronpedia_feature_dashboard(sae: SAE[Any], index: int):
-    sae_id = sae.cfg.neuronpedia_id
+    sae_id = sae.cfg.metadata.neuronpedia_id
     if sae_id is None:
         logger.warning(
             "SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
@@ -74,7 +74,7 @@ def get_neuronpedia_quick_list(
     features: list[int],
     name: str = "temporary_list",
 ):
-    sae_id = sae.cfg.neuronpedia_id
+    sae_id = sae.cfg.metadata.neuronpedia_id
     if sae_id is None:
         logger.warning(
             "SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
@@ -86,7 +86,7 @@ def get_neuronpedia_quick_list(
     url = url + "?name=" + name
     list_feature = [
         {
-            "modelId": sae.cfg.model_name,
+            "modelId": sae.cfg.metadata.model_name,
             "layer": sae_id.split("/")[1],
             "index": str(feature),
         }
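All three `neuronpedia_integration.py` hunks are the same rename: identifying fields such as `model_name` and `neuronpedia_id` move from `sae.cfg` onto `sae.cfg.metadata`. A minimal sketch of the new access pattern (the release/ID values are just examples, and it assumes rc4's `SAE.from_pretrained` returns the SAE directly):

```python
from sae_lens import SAE

# Example release/ID pair from the pretrained SAE directory; any valid pair works.
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cpu",
)

# rc2-style attributes (sae.cfg.model_name, sae.cfg.neuronpedia_id) are replaced by:
print(sae.cfg.metadata.model_name)
print(sae.cfg.metadata.neuronpedia_id)
```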
--- sae_lens-6.0.0rc2/sae_lens/cache_activations_runner.py
+++ sae_lens-6.0.0rc4/sae_lens/cache_activations_runner.py
@@ -34,7 +34,6 @@ def _mk_activations_store(
         dataset=override_dataset or cfg.dataset_path,
         streaming=cfg.streaming,
         hook_name=cfg.hook_name,
-        hook_layer=cfg.hook_layer,
         hook_head_index=None,
         context_size=cfg.context_size,
         d_in=cfg.d_in,
@@ -265,7 +264,7 @@ class CacheActivationsRunner:
 
         for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
             try:
-                buffer = self.activations_store.
+                buffer = self.activations_store.get_raw_buffer(
                     self.cfg.n_batches_in_buffer, shuffle=False
                 )
                 shard = self._create_shard(buffer)
@@ -319,7 +318,7 @@ class CacheActivationsRunner:
     def _create_shard(
         self,
         buffer: tuple[
-            Float[torch.Tensor, "(bs context_size)
+            Float[torch.Tensor, "(bs context_size) d_in"],
            Int[torch.Tensor, "(bs context_size)"] | None,
        ],
    ) -> Dataset:
@@ -327,13 +326,15 @@ class CacheActivationsRunner:
         acts, token_ids = buffer
         acts = einops.rearrange(
             acts,
-            "(bs context_size)
+            "(bs context_size) d_in -> bs context_size d_in",
             bs=self.cfg.n_seq_in_buffer,
             context_size=self.context_size,
             d_in=self.cfg.d_in,
-            num_layers=len(hook_names),
         )
-        shard_dict
+        shard_dict: dict[str, object] = {
+            hook_name: act_batch
+            for hook_name, act_batch in zip(hook_names, [acts], strict=True)
+        }
 
         if token_ids is not None:
             token_ids = einops.rearrange(
--- sae_lens-6.0.0rc2/sae_lens/config.py
+++ sae_lens-6.0.0rc4/sae_lens/config.py
@@ -23,7 +23,9 @@ from sae_lens.saes.sae import TrainingSAEConfig
 if TYPE_CHECKING:
     pass
 
-T_TRAINING_SAE_CONFIG = TypeVar(
+T_TRAINING_SAE_CONFIG = TypeVar(
+    "T_TRAINING_SAE_CONFIG", bound=TrainingSAEConfig, covariant=True
+)
 
 HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
 
@@ -102,7 +104,6 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
         hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
         hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
-        hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing.
         hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
         dataset_path (str): A Hugging Face dataset path.
         dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
@@ -159,7 +160,6 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     model_class_name: str = "HookedTransformer"
     hook_name: str = "blocks.0.hook_mlp_out"
     hook_eval: str = "NOT_IN_USE"
-    hook_layer: int = 0
     hook_head_index: int | None = None
     dataset_path: str = ""
     dataset_trust_remote_code: bool = True
@@ -201,7 +201,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     train_batch_size_tokens: int = 4096
 
     ## Adam
-    adam_beta1: float = 0.
+    adam_beta1: float = 0.9
     adam_beta2: float = 0.999
 
     ## Learning Rate Schedule
@@ -375,6 +375,27 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
 
         return cls(**cfg)
 
+    def to_sae_trainer_config(self) -> "SAETrainerConfig":
+        return SAETrainerConfig(
+            n_checkpoints=self.n_checkpoints,
+            checkpoint_path=self.checkpoint_path,
+            total_training_samples=self.total_training_tokens,
+            device=self.device,
+            autocast=self.autocast,
+            lr=self.lr,
+            lr_end=self.lr_end,
+            lr_scheduler_name=self.lr_scheduler_name,
+            lr_warm_up_steps=self.lr_warm_up_steps,
+            adam_beta1=self.adam_beta1,
+            adam_beta2=self.adam_beta2,
+            lr_decay_steps=self.lr_decay_steps,
+            n_restart_cycles=self.n_restart_cycles,
+            train_batch_size_samples=self.train_batch_size_tokens,
+            dead_feature_window=self.dead_feature_window,
+            feature_sampling_window=self.feature_sampling_window,
+            logger=self.logger,
+        )
+
 
 @dataclass
 class CacheActivationsRunnerConfig:
@@ -386,7 +407,6 @@ class CacheActivationsRunnerConfig:
         model_name (str): The name of the model to use.
         model_batch_size (int): How many prompts are in the batch of the language model when generating activations.
         hook_name (str): The name of the hook to use.
-        hook_layer (int): The layer of the final hook. Currently only support a single hook, so this should be the same as hook_name.
         d_in (int): Dimension of the model.
         total_training_tokens (int): Total number of tokens to process.
         context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized.
@@ -416,7 +436,6 @@ class CacheActivationsRunnerConfig:
     model_name: str
     model_batch_size: int
     hook_name: str
-    hook_layer: int
     d_in: int
     training_tokens: int
 
@@ -576,3 +595,28 @@ class PretokenizeRunnerConfig:
     hf_num_shards: int = 64
     hf_revision: str = "main"
     hf_is_private_repo: bool = False
+
+
+@dataclass
+class SAETrainerConfig:
+    n_checkpoints: int
+    checkpoint_path: str
+    total_training_samples: int
+    device: str
+    autocast: bool
+    lr: float
+    lr_end: float | None
+    lr_scheduler_name: str
+    lr_warm_up_steps: int
+    adam_beta1: float
+    adam_beta2: float
+    lr_decay_steps: int
+    n_restart_cycles: int
+    train_batch_size_samples: int
+    dead_feature_window: int
+    feature_sampling_window: int
+    logger: LoggingConfig
+
+    @property
+    def total_training_steps(self) -> int:
+        return self.total_training_samples // self.train_batch_size_samples
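The new `SAETrainerConfig` is a plain dataclass holding only what the trainer needs, and `LanguageModelSAERunnerConfig.to_sae_trainer_config()` is the bridge from the user-facing config. A minimal sketch using only fields shown in the hunks above (the concrete values are illustrative, and it assumes `LoggingConfig` can be constructed with defaults):

```python
from sae_lens.config import LoggingConfig, SAETrainerConfig

trainer_cfg = SAETrainerConfig(
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    total_training_samples=1_000_000,  # mapped from total_training_tokens by to_sae_trainer_config()
    device="cpu",
    autocast=False,
    lr=3e-4,
    lr_end=None,
    lr_scheduler_name="constant",
    lr_warm_up_steps=0,
    adam_beta1=0.9,
    adam_beta2=0.999,
    lr_decay_steps=0,
    n_restart_cycles=1,
    train_batch_size_samples=4096,     # mapped from train_batch_size_tokens
    dead_feature_window=1000,
    feature_sampling_window=2000,
    logger=LoggingConfig(),
)

# total_training_steps is derived rather than stored: 1_000_000 // 4096 == 244
assert trainer_cfg.total_training_steps == 244
```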
@@ -16,3 +16,5 @@ SPARSITY_FILENAME = "sparsity.safetensors"
|
|
|
16
16
|
SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
|
|
17
17
|
SAE_CFG_FILENAME = "cfg.json"
|
|
18
18
|
RUNNER_CFG_FILENAME = "runner_cfg.json"
|
|
19
|
+
ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
|
|
20
|
+
ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
|
|
--- sae_lens-6.0.0rc2/sae_lens/evals.py
+++ sae_lens-6.0.0rc4/sae_lens/evals.py
@@ -4,6 +4,7 @@ import json
 import math
 import re
 import subprocess
+import sys
 from collections import defaultdict
 from collections.abc import Mapping
 from dataclasses import dataclass, field
@@ -15,13 +16,15 @@ from typing import Any
 import einops
 import pandas as pd
 import torch
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from transformer_lens import HookedTransformer
 from transformer_lens.hook_points import HookedRootModule
 
 from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
 from sae_lens.saes.sae import SAE, SAEConfig
+from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.activations_store import ActivationsStore
+from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
 
 
 def get_library_version() -> str:
@@ -103,6 +106,7 @@ def run_evals(
     sae: SAE[Any],
     activation_store: ActivationsStore,
     model: HookedRootModule,
+    activation_scaler: ActivationScaler,
     eval_config: EvalConfig = EvalConfig(),
     model_kwargs: Mapping[str, Any] = {},
     ignore_tokens: set[int | None] = set(),
@@ -140,6 +144,7 @@ def run_evals(
         sae,
         model,
         activation_store,
+        activation_scaler,
         compute_kl=eval_config.compute_kl,
         compute_ce_loss=eval_config.compute_ce_loss,
         n_batches=eval_config.n_eval_reconstruction_batches,
@@ -189,6 +194,7 @@ def run_evals(
         sae,
         model,
         activation_store,
+        activation_scaler,
         compute_l2_norms=eval_config.compute_l2_norms,
         compute_sparsity_metrics=eval_config.compute_sparsity_metrics,
         compute_variance_metrics=eval_config.compute_variance_metrics,
@@ -301,6 +307,7 @@ def get_downstream_reconstruction_metrics(
     sae: SAE[Any],
     model: HookedRootModule,
     activation_store: ActivationsStore,
+    activation_scaler: ActivationScaler,
     compute_kl: bool,
     compute_ce_loss: bool,
     n_batches: int,
@@ -326,8 +333,8 @@ def get_downstream_reconstruction_metrics(
         for metric_name, metric_value in get_recons_loss(
             sae,
             model,
+            activation_scaler,
             batch_tokens,
-            activation_store,
             compute_kl=compute_kl,
             compute_ce_loss=compute_ce_loss,
             ignore_tokens=ignore_tokens,
@@ -369,6 +376,7 @@ def get_sparsity_and_variance_metrics(
     sae: SAE[Any],
     model: HookedRootModule,
     activation_store: ActivationsStore,
+    activation_scaler: ActivationScaler,
     n_batches: int,
     compute_l2_norms: bool,
     compute_sparsity_metrics: bool,
@@ -436,7 +444,7 @@ def get_sparsity_and_variance_metrics(
             batch_tokens,
             prepend_bos=False,
             names_filter=[hook_name],
-            stop_at_layer=
+            stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(hook_name),
             **model_kwargs,
         )
 
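The `stop_at_layer` argument appears to have been derived from the now-removed `hook_layer` config field; rc4 instead parses it from the hook name via a new helper in `sae_lens/util.py` (+19 lines in the summary). The helper's body is not shown in this diff, so the sketch below is a hypothetical reading of what it likely does, not the library's actual code:

```python
import re


def extract_stop_at_layer_from_tlens_hook_name(hook_name: str) -> int | None:
    """Hypothetical sketch only; the real implementation lives in sae_lens/util.py.

    TransformerLens block hooks look like 'blocks.8.hook_resid_pre'. Running the
    model with stop_at_layer=9 is enough to populate that hook, so return the
    block index plus one. Hooks without a block index (e.g. 'hook_embed') return
    None, meaning the full model must run.
    """
    match = re.search(r"blocks\.(\d+)\.", hook_name)
    return int(match.group(1)) + 1 if match else None


assert extract_stop_at_layer_from_tlens_hook_name("blocks.8.hook_resid_pre") == 9
assert extract_stop_at_layer_from_tlens_hook_name("hook_embed") is None
```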
@@ -451,16 +459,14 @@ def get_sparsity_and_variance_metrics(
         original_act = cache[hook_name]
 
         # normalise if necessary (necessary in training only, otherwise we should fold the scaling in)
-
-        original_act = activation_store.apply_norm_scaling_factor(original_act)
+        original_act = activation_scaler.scale(original_act)
 
         # send the (maybe normalised) activations into the SAE
         sae_feature_activations = sae.encode(original_act.to(sae.device))
         sae_out = sae.decode(sae_feature_activations).to(original_act.device)
         del cache
 
-
-        sae_out = activation_store.unscale(sae_out)
+        sae_out = activation_scaler.unscale(sae_out)
 
         flattened_sae_input = einops.rearrange(original_act, "b ctx d -> (b ctx) d")
         flattened_sae_feature_acts = einops.rearrange(
@@ -582,8 +588,8 @@ def get_sparsity_and_variance_metrics(
 def get_recons_loss(
     sae: SAE[SAEConfig],
     model: HookedRootModule,
+    activation_scaler: ActivationScaler,
     batch_tokens: torch.Tensor,
-    activation_store: ActivationsStore,
     compute_kl: bool,
     compute_ce_loss: bool,
     ignore_tokens: set[int | None] = set(),
@@ -618,15 +624,13 @@ def get_recons_loss(
         activations = activations.to(sae.device)
 
         # Handle rescaling if SAE expects it
-
-        activations = activation_store.apply_norm_scaling_factor(activations)
+        activations = activation_scaler.scale(activations)
 
         # SAE class agnost forward forward pass.
         new_activations = sae.decode(sae.encode(activations)).to(activations.dtype)
 
         # Unscale if activations were scaled prior to going into the SAE
-
-        new_activations = activation_store.unscale(new_activations)
+        new_activations = activation_scaler.unscale(new_activations)
 
         new_activations = torch.where(mask[..., None], new_activations, activations)
 
@@ -637,8 +641,7 @@ def get_recons_loss(
         activations = activations.to(sae.device)
 
         # Handle rescaling if SAE expects it
-
-        activations = activation_store.apply_norm_scaling_factor(activations)
+        activations = activation_scaler.scale(activations)
 
         # SAE class agnost forward forward pass.
         new_activations = sae.decode(sae.encode(activations.flatten(-2, -1))).to(
@@ -650,8 +653,7 @@ def get_recons_loss(
         )  # reshape to match original shape
 
         # Unscale if activations were scaled prior to going into the SAE
-
-        new_activations = activation_store.unscale(new_activations)
+        new_activations = activation_scaler.unscale(new_activations)
 
         return new_activations.to(original_device)
 
@@ -660,8 +662,7 @@ def get_recons_loss(
         activations = activations.to(sae.device)
 
         # Handle rescaling if SAE expects it
-
-        activations = activation_store.apply_norm_scaling_factor(activations)
+        activations = activation_scaler.scale(activations)
 
         new_activations = sae.decode(sae.encode(activations[:, :, head_index])).to(
             activations.dtype
@@ -669,8 +670,7 @@ def get_recons_loss(
         activations[:, :, head_index] = new_activations
 
         # Unscale if activations were scaled prior to going into the SAE
-
-        activations = activation_store.unscale(activations)
+        activations = activation_scaler.unscale(activations)
 
         return activations.to(original_device)
 
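Across these hunks, per-batch normalisation moves off `ActivationsStore` (`apply_norm_scaling_factor`/`unscale`) and onto the new `ActivationScaler` (`sae_lens/training/activation_scaler.py`, +53 lines in the summary), which evals now receive explicitly. A minimal usage sketch; the identity-by-default behaviour is inferred from the bare `ActivationScaler()` passed in the `multiple_evals` hunk below, not stated in this diff:

```python
import torch

from sae_lens.training.activation_scaler import ActivationScaler

scaler = ActivationScaler()        # no-arg construction, as used for pretrained-SAE evals
acts = torch.randn(4096, 768)

scaled = scaler.scale(acts)        # applied before sae.encode(...)
restored = scaler.unscale(scaled)  # applied after sae.decode(...), back in model units

# If the default really is an identity scaling (an assumption), this round-trips exactly:
assert torch.allclose(acts, restored)
```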
@@ -815,16 +815,18 @@ def multiple_evals(
            release=sae_release_name,  # see other options in sae_lens/pretrained_saes.yaml
            sae_id=sae_id,  # won't always be a hook point
            device=device,
-        )
+        )
 
        # move SAE to device if not there already
        sae.to(device)
 
-        if current_model_str != sae.cfg.model_name:
+        if current_model_str != sae.cfg.metadata.model_name:
            del current_model  # potentially saves GPU memory
-            current_model_str = sae.cfg.model_name
+            current_model_str = sae.cfg.metadata.model_name
            current_model = HookedTransformer.from_pretrained_no_processing(
-                current_model_str,
+                current_model_str,
+                device=device,
+                **sae.cfg.metadata.model_from_pretrained_kwargs,
            )
            assert current_model is not None
 
@@ -849,6 +851,7 @@ def multiple_evals(
         scalar_metrics, feature_metrics = run_evals(
             sae=sae,
             activation_store=activation_store,
+            activation_scaler=ActivationScaler(),
             model=current_model,
             eval_config=eval_config,
             ignore_tokens={
@@ -941,7 +944,7 @@ def process_results(
     }
 
 
-
+def process_args(args: list[str]) -> argparse.Namespace:
     arg_parser = argparse.ArgumentParser(description="Run evaluations on SAEs")
     arg_parser.add_argument(
         "sae_regex_pattern",
@@ -1031,11 +1034,19 @@ if __name__ == "__main__":
         help="Enable verbose output with tqdm loaders.",
     )
 
-
-
-
+    return arg_parser.parse_args(args)
+
+
+def run_evals_cli(args: list[str]) -> None:
+    opts = process_args(args)
+    eval_results = run_evaluations(opts)
+    output_files = process_results(eval_results, opts.output_dir)
 
     print("Evaluation complete. Output files:")
     print(f"Individual JSONs: {len(output_files['individual_jsons'])}")  # type: ignore
     print(f"Combined JSON: {output_files['combined_json']}")
     print(f"CSV: {output_files['csv']}")
+
+
+if __name__ == "__main__":
+    run_evals_cli(sys.argv[1:])