sae-lens 6.2.0__py3-none-any.whl → 6.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +1 -1
- sae_lens/config.py +32 -0
- sae_lens/llm_sae_training_runner.py +6 -0
- sae_lens/pretokenize_runner.py +3 -0
- sae_lens/tokenization_and_batching.py +8 -0
- sae_lens/training/activations_store.py +20 -3
- {sae_lens-6.2.0.dist-info → sae_lens-6.3.0.dist-info}/METADATA +1 -1
- {sae_lens-6.2.0.dist-info → sae_lens-6.3.0.dist-info}/RECORD +10 -10
- {sae_lens-6.2.0.dist-info → sae_lens-6.3.0.dist-info}/LICENSE +0 -0
- {sae_lens-6.2.0.dist-info → sae_lens-6.3.0.dist-info}/WHEEL +0 -0
sae_lens/__init__.py
CHANGED
sae_lens/config.py
CHANGED
@@ -46,6 +46,29 @@ def dict_field(default: dict[str, Any] | None, **kwargs: Any) -> Any:  # type: i
     return simple_parsing.helpers.dict_field(default, type=json_dict, **kwargs)
 
 
+def special_token(s: str) -> Any:
+    """Parse special token value from string."""
+    if s.lower() == "none":
+        return None
+    if s in ["bos", "eos", "sep"]:
+        return s
+    try:
+        return int(s)
+    except ValueError:
+        raise ValueError(
+            f"Expected 'bos', 'eos', 'sep', an integer, or 'none', got {s}"
+        )
+
+
+def special_token_field(
+    default: int | Literal["bos", "eos", "sep"] | None, **kwargs: Any
+) -> Any:  # type: ignore
+    """
+    Helper to wrap simple_parsing.helpers.field so we can load special token fields from the command line.
+    """
+    return simple_parsing.helpers.field(default=default, type=special_token, **kwargs)
+
+
 @dataclass
 class LoggingConfig:
     # WANDB
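The new `special_token` parser normalizes string values for separator and begin-token options. A minimal sketch of its behaviour, assuming it is importable from `sae_lens.config` as added above:

```python
from sae_lens.config import special_token

assert special_token("none") is None      # "none"/"None" disables the token entirely
assert special_token("bos") == "bos"      # named special tokens pass through as strings
assert special_token("50256") == 50256    # anything else must parse as an integer token id

try:
    special_token("not-a-token")          # unrecognised values raise ValueError
except ValueError as err:
    print(err)
```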
@@ -116,6 +139,8 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         training_tokens (int): The number of training tokens.
         store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
         seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
+        disable_concat_sequences (bool): Whether to disable concatenating sequences and ignore sequences shorter than the context size. If True, disables concatenating and ignores short sequences.
+        sequence_separator_token (int | Literal["bos", "eos", "sep"] | None): If not `None`, this token will be placed between sentences in a batch to act as a separator. By default, this is the `<bos>` token.
         device (str): The device to use. Usually "cuda".
         act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
         seed (int): The seed to use.
@@ -178,6 +203,10 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     training_tokens: int = 2_000_000
     store_batch_size_prompts: int = 32
     seqpos_slice: tuple[int | None, ...] = (None,)
+    disable_concat_sequences: bool = False
+    sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
+        special_token_field(default="bos")
+    )
 
     # Misc
     device: str = "cpu"
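Because the new field is declared with `special_token_field`, string values can be coerced when the config is populated from the command line via `simple_parsing`. A minimal sketch of that mechanism using a hypothetical stand-in dataclass (the real `LanguageModelSAERunnerConfig` has many more fields), and assuming `simple_parsing.parse` accepts an explicit argument list:

```python
from dataclasses import dataclass
from typing import Literal

import simple_parsing

from sae_lens.config import special_token_field


@dataclass
class DemoConfig:
    # Same declaration pattern as LanguageModelSAERunnerConfig above.
    sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
        special_token_field(default="bos")
    )
    disable_concat_sequences: bool = False


# "bos", "eos", "sep", "none", or an integer id are all accepted on the CLI.
cfg = simple_parsing.parse(DemoConfig, args=["--sequence_separator_token", "50256"])
assert cfg.sequence_separator_token == 50256
```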
@@ -564,6 +593,9 @@ class PretokenizeRunnerConfig:
     begin_sequence_token: int | Literal["bos", "eos", "sep"] | None = None
     sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos"
 
+    # sequence processing
+    disable_concat_sequences: bool = False
+
     # if saving locally, set save_path
     save_path: str | None = None
 
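For pretokenized datasets the same switch is exposed on `PretokenizeRunnerConfig`. A rough usage sketch; the tokenizer/dataset field names and the runner invocation are assumptions based on the existing pretokenization API, and only `disable_concat_sequences` is new in this release:

```python
from sae_lens import PretokenizeRunner, PretokenizeRunnerConfig

cfg = PretokenizeRunnerConfig(
    tokenizer_name="gpt2",             # assumed existing field
    dataset_path="NeelNanda/c4-10k",   # assumed existing field; dataset is illustrative
    context_size=128,                  # assumed existing field
    sequence_separator_token="bos",
    disable_concat_sequences=True,     # keep sequences separate; drop those shorter than context_size
    save_path="./c4-10k-tokenized",
)

dataset = PretokenizeRunner(cfg).run()
```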
sae_lens/llm_sae_training_runner.py
CHANGED
@@ -202,6 +202,12 @@ class LanguageModelSAETrainingRunner:
         )
         self.sae.cfg.metadata.prepend_bos = self.cfg.prepend_bos
         self.sae.cfg.metadata.exclude_special_tokens = self.cfg.exclude_special_tokens
+        self.sae.cfg.metadata.sequence_separator_token = (
+            self.cfg.sequence_separator_token
+        )
+        self.sae.cfg.metadata.disable_concat_sequences = (
+            self.cfg.disable_concat_sequences
+        )
 
     def _compile_if_needed(self):
         # Compile model and SAE
sae_lens/pretokenize_runner.py
CHANGED
@@ -35,6 +35,7 @@ class PretokenizedDatasetMetadata:
     begin_batch_token: int | Literal["bos", "eos", "sep"] | None
     begin_sequence_token: int | Literal["bos", "eos", "sep"] | None
     sequence_separator_token: int | Literal["bos", "eos", "sep"] | None
+    disable_concat_sequences: bool
 
 
 def metadata_from_config(cfg: PretokenizeRunnerConfig) -> PretokenizedDatasetMetadata:
@@ -52,6 +53,7 @@ def metadata_from_config(cfg: PretokenizeRunnerConfig) -> PretokenizedDatasetMet
         begin_batch_token=cfg.begin_batch_token,
         begin_sequence_token=cfg.begin_sequence_token,
         sequence_separator_token=cfg.sequence_separator_token,
+        disable_concat_sequences=cfg.disable_concat_sequences,
     )
 
 
@@ -99,6 +101,7 @@ def pretokenize_dataset(
             sequence_separator_token_id=get_special_token_from_cfg(
                 cfg.sequence_separator_token, tokenizer
             ),
+            disable_concat_sequences=cfg.disable_concat_sequences,
         )
     )
 }
sae_lens/tokenization_and_batching.py
CHANGED
@@ -64,6 +64,7 @@ def concat_and_batch_sequences(
     begin_batch_token_id: int | None = None,
     begin_sequence_token_id: int | None = None,
     sequence_separator_token_id: int | None = None,
+    disable_concat_sequences: bool = False,
 ) -> Generator[torch.Tensor, None, None]:
     """
     Generator to concat token sequences together from the tokens_interator, yielding
@@ -75,8 +76,15 @@ def concat_and_batch_sequences(
         begin_batch_token_id: If provided, this token will be at position 0 of each batch
         begin_sequence_token_id: If provided, this token will be the first token of each sequence
         sequence_separator_token_id: If provided, this token will be inserted between concatenated sequences
+        disable_concat_sequences: If True, disable concatenating sequences and ignore sequences shorter than context_size
         max_batches: If not provided, the iterator will be run to completion.
     """
+    if disable_concat_sequences:
+        for tokens in tokens_iterator:
+            if len(tokens) >= context_size:
+                yield tokens[:context_size]
+        return
+
     batch: torch.Tensor | None = None
     for tokens in tokens_iterator:
         if len(tokens.shape) != 1:
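The early-return path above changes the generator's contract: instead of stitching sequences together into exactly `context_size`-token batches, it yields each sequence truncated to `context_size` and silently skips anything shorter. A small self-contained sketch of the difference (token values are arbitrary):

```python
import torch

from sae_lens.tokenization_and_batching import concat_and_batch_sequences

seqs = [torch.arange(3), torch.arange(6), torch.arange(10)]  # lengths 3, 6, 10

# Default behaviour: concatenate (inserting the separator token between
# sequences) and re-chunk into batches of exactly context_size tokens.
concatenated = list(
    concat_and_batch_sequences(
        tokens_iterator=iter(seqs),
        context_size=5,
        sequence_separator_token_id=0,
    )
)
assert all(len(batch) == 5 for batch in concatenated)

# New behaviour: one batch per sequence, truncated to context_size, with the
# length-3 sequence dropped because it is shorter than context_size.
separate = list(
    concat_and_batch_sequences(
        tokens_iterator=iter(seqs),
        context_size=5,
        disable_concat_sequences=True,
    )
)
assert len(separate) == 2
assert all(len(batch) == 5 for batch in separate)
```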
sae_lens/training/activations_store.py
CHANGED
@@ -25,6 +25,7 @@ from sae_lens.config import (
     LanguageModelSAERunnerConfig,
 )
 from sae_lens.constants import DTYPE_MAP
+from sae_lens.pretokenize_runner import get_special_token_from_cfg
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
 from sae_lens.training.mixing_buffer import mixing_buffer
@@ -141,6 +142,8 @@ class ActivationsStore:
             dataset_trust_remote_code=cfg.dataset_trust_remote_code,
             seqpos_slice=cfg.seqpos_slice,
             exclude_special_tokens=exclude_special_tokens,
+            disable_concat_sequences=cfg.disable_concat_sequences,
+            sequence_separator_token=cfg.sequence_separator_token,
         )
 
     @classmethod
@@ -157,6 +160,8 @@ class ActivationsStore:
         train_batch_size_tokens: int = 4096,
         total_tokens: int = 10**9,
         device: str = "cpu",
+        disable_concat_sequences: bool = False,
+        sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
     ) -> ActivationsStore:
         if sae.cfg.metadata.hook_name is None:
             raise ValueError("hook_name is required")
@@ -184,6 +189,8 @@ class ActivationsStore:
             dtype=sae.cfg.dtype,
             device=torch.device(device),
             seqpos_slice=sae.cfg.metadata.seqpos_slice or (None,),
+            disable_concat_sequences=disable_concat_sequences,
+            sequence_separator_token=sequence_separator_token,
         )
 
     def __init__(
@@ -209,6 +216,8 @@ class ActivationsStore:
         dataset_trust_remote_code: bool | None = None,
         seqpos_slice: tuple[int | None, ...] = (None,),
         exclude_special_tokens: torch.Tensor | None = None,
+        disable_concat_sequences: bool = False,
+        sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
     ):
         self.model = model
         if model_kwargs is None:
@@ -252,6 +261,10 @@ class ActivationsStore:
         self.seqpos_slice = seqpos_slice
         self.training_context_size = len(range(context_size)[slice(*seqpos_slice)])
         self.exclude_special_tokens = exclude_special_tokens
+        self.disable_concat_sequences = disable_concat_sequences
+        self.sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
+            sequence_separator_token
+        )
 
         self.n_dataset_processed = 0
 
@@ -361,14 +374,18 @@ class ActivationsStore:
         else:
             tokenizer = getattr(self.model, "tokenizer", None)
             bos_token_id = None if tokenizer is None else tokenizer.bos_token_id
+
         yield from concat_and_batch_sequences(
             tokens_iterator=self._iterate_raw_dataset_tokens(),
             context_size=self.context_size,
             begin_batch_token_id=(bos_token_id if self.prepend_bos else None),
             begin_sequence_token_id=None,
-            sequence_separator_token_id=(
-
-            )
+            sequence_separator_token_id=get_special_token_from_cfg(
+                self.sequence_separator_token, tokenizer
+            )
+            if tokenizer is not None
+            else None,
+            disable_concat_sequences=self.disable_concat_sequences,
         )
 
     def load_cached_activation_dataset(self) -> Dataset | None:
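Both knobs are plumbed through `ActivationsStore.from_config` and `ActivationsStore.from_sae`, so activation generation can match the pretokenization settings. A rough sketch of `from_sae` with the new keyword arguments; the model/SAE loading calls, positional arguments, and release/id strings are assumptions about the existing API rather than part of this diff:

```python
from transformer_lens import HookedTransformer

from sae_lens import SAE, ActivationsStore

# Assumed existing loading APIs; the pretrained release and hook id are illustrative.
model = HookedTransformer.from_pretrained("gpt2")
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

store = ActivationsStore.from_sae(
    model,
    sae,
    train_batch_size_tokens=4096,
    total_tokens=10_000_000,
    device="cpu",
    disable_concat_sequences=True,   # one prompt per row; prompts shorter than the context size are skipped
    sequence_separator_token=None,   # irrelevant once concatenation is disabled
)
```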
{sae_lens-6.2.0.dist-info → sae_lens-6.3.0.dist-info}/RECORD
CHANGED
@@ -1,17 +1,17 @@
-sae_lens/__init__.py,sha256=
+sae_lens/__init__.py,sha256=Fu85qhIdVyFWNMOI2q9UMgfXux7kGY_AkNdzuxsO8C0,3073
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
 sae_lens/analysis/neuronpedia_integration.py,sha256=MrENqc81Mc2SMbxGjbwHzpkGUCAFKSf0i4EdaUF2Oj4,18707
 sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
-sae_lens/config.py,sha256=
+sae_lens/config.py,sha256=6xATsLdg80mXnEsW12x-cvCbAu6SjnONqbRz2eEbqAU,27796
 sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
 sae_lens/evals.py,sha256=kQyrzczKaVD9rHwfFa_DxL_gMXDxsoIVHmsFIPIU2bY,38696
-sae_lens/llm_sae_training_runner.py,sha256=
+sae_lens/llm_sae_training_runner.py,sha256=nSzNI4zZkh8hn8Z5eWQpql1zi718WINW8bPxI3-c_dI,13584
 sae_lens/load_model.py,sha256=dBB_9gO6kWyQ4sXHq7qB8T3YUlXm3PGwYcpR4UVW4QY,8633
 sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/loading/pretrained_sae_loaders.py,sha256=5XEU4uFFeGCePwqDwhlE7CrFGRSI0U9Cu-UQVa33Y1E,36432
 sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
-sae_lens/pretokenize_runner.py,sha256=
+sae_lens/pretokenize_runner.py,sha256=w0f6SfZLAxbp5eAAKnet8RqUB_DKofZ9RGsoJwFnYbA,7058
 sae_lens/pretrained_saes.yaml,sha256=nhHW1auhyi4GHYrjUnHQqbNVhI5cMJv-HThzbzU1xG0,574145
 sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
 sae_lens/saes/__init__.py,sha256=RYqE1qkMws-kwQLmBZFhA_VCa69zVtBjGPIy_UAk2pw,1159
@@ -21,10 +21,10 @@ sae_lens/saes/jumprelu_sae.py,sha256=3xkhBcCol2mEpIBLceymCpudocm2ypOjTeTXbpiXoA4
 sae_lens/saes/sae.py,sha256=McpF4pTh70r6SQUbHFm0YQ9X2c2qPULBUSd_YmnEk4Y,38284
 sae_lens/saes/standard_sae.py,sha256=9UqYyYtQuThYxXKNaDjYcyowpOx2-7cShG-TeUP6JCQ,5940
 sae_lens/saes/topk_sae.py,sha256=CXMBI6CFvI5829bOhoQ350VXR9d8uFHUDlULTIWHXoU,8686
-sae_lens/tokenization_and_batching.py,sha256=
+sae_lens/tokenization_and_batching.py,sha256=now7caLbU3p-iGokNwmqZDyIvxYoXgnG1uklhgiLZN4,4656
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
-sae_lens/training/activations_store.py,sha256=
+sae_lens/training/activations_store.py,sha256=x2Fwt5QY7M83v6Vf1CSa821j2_WKMw9oPu1cdlLblvg,32887
 sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
 sae_lens/training/optim.py,sha256=TiI9nbffzXNsI8WjcIsqa2uheW6suxqL_KDDmWXobWI,5312
 sae_lens/training/sae_trainer.py,sha256=2xcO-02OozFunob5vwoHud-hVMhVl9d28_F9gDCiL6o,15529
@@ -32,7 +32,7 @@ sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
 sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
+sae_lens-6.3.0.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.3.0.dist-info/METADATA,sha256=ImDu3LLHXp0eR4EA0mPQD37xdQpURh5fc5bLm5-3nWM,5555
+sae_lens-6.3.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+sae_lens-6.3.0.dist-info/RECORD,,
{sae_lens-6.2.0.dist-info → sae_lens-6.3.0.dist-info}/LICENSE
File without changes
{sae_lens-6.2.0.dist-info → sae_lens-6.3.0.dist-info}/WHEEL
File without changes