sae-lens 6.2.0__py3-none-any.whl → 6.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.2.0"
+ __version__ = "6.3.1"

  import logging

sae_lens/config.py CHANGED
@@ -46,6 +46,29 @@ def dict_field(default: dict[str, Any] | None, **kwargs: Any) -> Any: # type: i
      return simple_parsing.helpers.dict_field(default, type=json_dict, **kwargs)


+ def special_token(s: str) -> Any:
+     """Parse special token value from string."""
+     if s.lower() == "none":
+         return None
+     if s in ["bos", "eos", "sep"]:
+         return s
+     try:
+         return int(s)
+     except ValueError:
+         raise ValueError(
+             f"Expected 'bos', 'eos', 'sep', an integer, or 'none', got {s}"
+         )
+
+
+ def special_token_field(
+     default: int | Literal["bos", "eos", "sep"] | None, **kwargs: Any
+ ) -> Any: # type: ignore
+     """
+     Helper to wrap simple_parsing.helpers.field so we can load special token fields from the command line.
+     """
+     return simple_parsing.helpers.field(default=default, type=special_token, **kwargs)
+
+
  @dataclass
  class LoggingConfig:
      # WANDB
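The new `special_token` helper is what lets these union-typed fields round-trip through `simple_parsing` from the command line. A quick sanity check of the mapping it implements (behavior taken directly from the hunk above):

```python
from sae_lens.config import special_token

assert special_token("bos") == "bos"    # named special tokens pass through as strings
assert special_token("50256") == 50256  # numeric strings become explicit token ids
assert special_token("none") is None    # "none" (any casing) disables the token
# any other value, e.g. special_token("cls"), raises ValueError
```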
@@ -116,6 +139,8 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
          training_tokens (int): The number of training tokens.
          store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
          seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
+         disable_concat_sequences (bool): Whether to disable concatenating sequences and ignore sequences shorter than the context size. If True, disables concatenating and ignores short sequences.
+         sequence_separator_token (int | Literal["bos", "eos", "sep"] | None): If not `None`, this token will be placed between sentences in a batch to act as a separator. By default, this is the `<bos>` token.
          device (str): The device to use. Usually "cuda".
          act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
          seed (int): The seed to use.
@@ -178,6 +203,10 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
      training_tokens: int = 2_000_000
      store_batch_size_prompts: int = 32
      seqpos_slice: tuple[int | None, ...] = (None,)
+     disable_concat_sequences: bool = False
+     sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
+         special_token_field(default="bos")
+     )

      # Misc
      device: str = "cpu"
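On the training side, the two new fields are set like any other runner option. A hedged sketch of a config that opts into the new behavior; the `sae` sub-config and its `d_in`/`d_sae` values are illustrative placeholders based on the 6.x API, not part of this diff:

```python
from sae_lens import LanguageModelSAERunnerConfig
from sae_lens.saes.standard_sae import StandardTrainingSAEConfig

cfg = LanguageModelSAERunnerConfig(
    sae=StandardTrainingSAEConfig(d_in=768, d_sae=16_384),
    disable_concat_sequences=True,   # keep only sequences that fill a full context window
    sequence_separator_token="eos",  # "bos" (default), "eos", "sep", an int id, or None
)
```

Because the field is declared with `special_token_field`, the same values ("bos", "eos", "sep", an integer, or "none") also parse from the command line.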
@@ -564,6 +593,9 @@ class PretokenizeRunnerConfig:
      begin_sequence_token: int | Literal["bos", "eos", "sep"] | None = None
      sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos"

+     # sequence processing
+     disable_concat_sequences: bool = False
+
      # if saving locally, set save_path
      save_path: str | None = None

sae_lens/llm_sae_training_runner.py CHANGED
@@ -17,6 +17,7 @@ from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
  from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
  from sae_lens.evals import EvalConfig, run_evals
  from sae_lens.load_model import load_model
+ from sae_lens.saes.batchtopk_sae import BatchTopKTrainingSAEConfig
  from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
  from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
  from sae_lens.saes.sae import (
@@ -202,6 +203,12 @@ class LanguageModelSAETrainingRunner:
          )
          self.sae.cfg.metadata.prepend_bos = self.cfg.prepend_bos
          self.sae.cfg.metadata.exclude_special_tokens = self.cfg.exclude_special_tokens
+         self.sae.cfg.metadata.sequence_separator_token = (
+             self.cfg.sequence_separator_token
+         )
+         self.sae.cfg.metadata.disable_concat_sequences = (
+             self.cfg.disable_concat_sequences
+         )

      def _compile_if_needed(self):
          # Compile model and SAE
@@ -285,7 +292,7 @@ def _parse_cfg_args(
      architecture_parser.add_argument(
          "--architecture",
          type=str,
-         choices=["standard", "gated", "jumprelu", "topk"],
+         choices=["standard", "gated", "jumprelu", "topk", "batchtopk"],
          default="standard",
          help="SAE architecture to use",
      )
@@ -346,6 +353,7 @@ def _parse_cfg_args(
          "gated": GatedTrainingSAEConfig,
          "jumprelu": JumpReLUTrainingSAEConfig,
          "topk": TopKTrainingSAEConfig,
+         "batchtopk": BatchTopKTrainingSAEConfig,
      }

      sae_config_type = sae_config_map[architecture]
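With the CLI change above, `--architecture batchtopk` is now accepted and resolves to `BatchTopKTrainingSAEConfig`. A hedged sketch of selecting the same config directly in Python; the constructor arguments (`d_in`, `d_sae`, `k`) are assumed from the TopK family and are not shown in this diff:

```python
from sae_lens.saes.batchtopk_sae import BatchTopKTrainingSAEConfig

# Hypothetical field values; check the class definition for the exact signature.
sae_cfg = BatchTopKTrainingSAEConfig(d_in=768, d_sae=16_384, k=64)
```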
sae_lens/pretokenize_runner.py CHANGED
@@ -35,6 +35,7 @@ class PretokenizedDatasetMetadata:
      begin_batch_token: int | Literal["bos", "eos", "sep"] | None
      begin_sequence_token: int | Literal["bos", "eos", "sep"] | None
      sequence_separator_token: int | Literal["bos", "eos", "sep"] | None
+     disable_concat_sequences: bool


  def metadata_from_config(cfg: PretokenizeRunnerConfig) -> PretokenizedDatasetMetadata:
@@ -52,6 +53,7 @@ def metadata_from_config(cfg: PretokenizeRunnerConfig) -> PretokenizedDatasetMet
          begin_batch_token=cfg.begin_batch_token,
          begin_sequence_token=cfg.begin_sequence_token,
          sequence_separator_token=cfg.sequence_separator_token,
+         disable_concat_sequences=cfg.disable_concat_sequences,
      )


@@ -99,6 +101,7 @@ def pretokenize_dataset(
              sequence_separator_token_id=get_special_token_from_cfg(
                  cfg.sequence_separator_token, tokenizer
              ),
+             disable_concat_sequences=cfg.disable_concat_sequences,
          )
      )
  }
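For pretokenized datasets the flag is threaded from `PretokenizeRunnerConfig` down into `concat_and_batch_sequences`, so short documents are simply skipped instead of being stitched together. A hedged sketch of a pretokenization run; the dataset, tokenizer, and save path are placeholders, and the other field names follow the existing `PretokenizeRunnerConfig` API:

```python
from sae_lens import PretokenizeRunner, PretokenizeRunnerConfig

cfg = PretokenizeRunnerConfig(
    tokenizer_name="gpt2",
    dataset_path="NeelNanda/c4-10k",
    context_size=128,
    disable_concat_sequences=True,  # drop examples shorter than 128 tokens, truncate longer ones
    save_path="./c4-10k-tokenized",
)
PretokenizeRunner(cfg).run()
```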
sae_lens/tokenization_and_batching.py CHANGED
@@ -64,6 +64,7 @@ def concat_and_batch_sequences(
      begin_batch_token_id: int | None = None,
      begin_sequence_token_id: int | None = None,
      sequence_separator_token_id: int | None = None,
+     disable_concat_sequences: bool = False,
  ) -> Generator[torch.Tensor, None, None]:
      """
      Generator to concat token sequences together from the tokens_interator, yielding
@@ -75,8 +76,15 @@ def concat_and_batch_sequences(
          begin_batch_token_id: If provided, this token will be at position 0 of each batch
          begin_sequence_token_id: If provided, this token will be the first token of each sequence
          sequence_separator_token_id: If provided, this token will be inserted between concatenated sequences
+         disable_concat_sequences: If True, disable concatenating sequences and ignore sequences shorter than context_size
          max_batches: If not provided, the iterator will be run to completion.
      """
+     if disable_concat_sequences:
+         for tokens in tokens_iterator:
+             if len(tokens) >= context_size:
+                 yield tokens[:context_size]
+         return
+
      batch: torch.Tensor | None = None
      for tokens in tokens_iterator:
          if len(tokens.shape) != 1:
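The early-return branch above is the whole feature at the tokenization level: when `disable_concat_sequences` is set, each incoming sequence is either truncated to `context_size` or dropped, and none of the packing or separator logic runs. A small illustration with toy tensors:

```python
import torch

from sae_lens.tokenization_and_batching import concat_and_batch_sequences

seqs = [torch.arange(6), torch.arange(2), torch.arange(4)]
batches = list(
    concat_and_batch_sequences(
        tokens_iterator=iter(seqs),
        context_size=4,
        disable_concat_sequences=True,
    )
)
# batches == [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]; the length-2 sequence is skipped
```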
sae_lens/training/activations_store.py CHANGED
@@ -25,6 +25,7 @@ from sae_lens.config import (
      LanguageModelSAERunnerConfig,
  )
  from sae_lens.constants import DTYPE_MAP
+ from sae_lens.pretokenize_runner import get_special_token_from_cfg
  from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
  from sae_lens.tokenization_and_batching import concat_and_batch_sequences
  from sae_lens.training.mixing_buffer import mixing_buffer
@@ -141,6 +142,8 @@ class ActivationsStore:
              dataset_trust_remote_code=cfg.dataset_trust_remote_code,
              seqpos_slice=cfg.seqpos_slice,
              exclude_special_tokens=exclude_special_tokens,
+             disable_concat_sequences=cfg.disable_concat_sequences,
+             sequence_separator_token=cfg.sequence_separator_token,
          )

      @classmethod
@@ -157,6 +160,8 @@ class ActivationsStore:
          train_batch_size_tokens: int = 4096,
          total_tokens: int = 10**9,
          device: str = "cpu",
+         disable_concat_sequences: bool = False,
+         sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
      ) -> ActivationsStore:
          if sae.cfg.metadata.hook_name is None:
              raise ValueError("hook_name is required")
@@ -184,6 +189,8 @@ class ActivationsStore:
              dtype=sae.cfg.dtype,
              device=torch.device(device),
              seqpos_slice=sae.cfg.metadata.seqpos_slice or (None,),
+             disable_concat_sequences=disable_concat_sequences,
+             sequence_separator_token=sequence_separator_token,
          )

      def __init__(
@@ -209,6 +216,8 @@ class ActivationsStore:
          dataset_trust_remote_code: bool | None = None,
          seqpos_slice: tuple[int | None, ...] = (None,),
          exclude_special_tokens: torch.Tensor | None = None,
+         disable_concat_sequences: bool = False,
+         sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
      ):
          self.model = model
          if model_kwargs is None:
@@ -252,6 +261,10 @@ class ActivationsStore:
          self.seqpos_slice = seqpos_slice
          self.training_context_size = len(range(context_size)[slice(*seqpos_slice)])
          self.exclude_special_tokens = exclude_special_tokens
+         self.disable_concat_sequences = disable_concat_sequences
+         self.sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
+             sequence_separator_token
+         )

          self.n_dataset_processed = 0

@@ -361,14 +374,18 @@ class ActivationsStore:
          else:
              tokenizer = getattr(self.model, "tokenizer", None)
              bos_token_id = None if tokenizer is None else tokenizer.bos_token_id
+
          yield from concat_and_batch_sequences(
              tokens_iterator=self._iterate_raw_dataset_tokens(),
              context_size=self.context_size,
              begin_batch_token_id=(bos_token_id if self.prepend_bos else None),
              begin_sequence_token_id=None,
-             sequence_separator_token_id=(
-                 bos_token_id if self.prepend_bos else None
-             ),
+             sequence_separator_token_id=get_special_token_from_cfg(
+                 self.sequence_separator_token, tokenizer
+             )
+             if tokenizer is not None
+             else None,
+             disable_concat_sequences=self.disable_concat_sequences,
          )

      def load_cached_activation_dataset(self) -> Dataset | None:
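Previously the activations store always reused the BOS token as the separator whenever `prepend_bos` was set; it now resolves whatever `sequence_separator_token` the config asks for via `get_special_token_from_cfg`. A hedged sketch of that resolution, assuming the helper maps "bos"/"eos"/"sep" to the tokenizer's special token ids, passes integers through, and returns None for None (consistent with the field's type and the call above):

```python
from transformers import AutoTokenizer

from sae_lens.pretokenize_runner import get_special_token_from_cfg

tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(get_special_token_from_cfg("eos", tokenizer))  # gpt2's eos id (50256)
print(get_special_token_from_cfg(1234, tokenizer))   # explicit ids pass through unchanged
print(get_special_token_from_cfg(None, tokenizer))   # None -> no separator token
```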
sae_lens/training/sae_trainer.py CHANGED
@@ -1,5 +1,4 @@
  import contextlib
- from dataclasses import dataclass
  from pathlib import Path
  from typing import Any, Callable, Generic, Protocol

@@ -38,13 +37,6 @@ def _update_sae_lens_training_version(sae: TrainingSAE[Any]) -> None:
      sae.cfg.sae_lens_training_version = str(__version__)


- @dataclass
- class TrainSAEOutput:
-     sae: TrainingSAE[Any]
-     checkpoint_path: str
-     log_feature_sparsities: torch.Tensor
-
-
  class SaveCheckpointFn(Protocol):
      def __call__(
          self,
sae_lens-6.3.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: sae-lens
- Version: 6.2.0
+ Version: 6.3.1
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
  License: MIT
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
sae_lens-6.3.1.dist-info/RECORD CHANGED
@@ -1,17 +1,17 @@
- sae_lens/__init__.py,sha256=ByxdNdLeg_pvK89IX1lHa6iHgs2ab-UulX55Y0hUhY4,3073
+ sae_lens/__init__.py,sha256=8vvwKdk-cv0-h2R1ah18VSmNjcBHt7X9gV3A1LtrroM,3073
  sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
  sae_lens/analysis/neuronpedia_integration.py,sha256=MrENqc81Mc2SMbxGjbwHzpkGUCAFKSf0i4EdaUF2Oj4,18707
  sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
- sae_lens/config.py,sha256=qMMx9KuiXTD5lG3g0VzaekWOnvdAzGFSq8j1n-GObEQ,26467
+ sae_lens/config.py,sha256=6xATsLdg80mXnEsW12x-cvCbAu6SjnONqbRz2eEbqAU,27796
  sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
  sae_lens/evals.py,sha256=kQyrzczKaVD9rHwfFa_DxL_gMXDxsoIVHmsFIPIU2bY,38696
- sae_lens/llm_sae_training_runner.py,sha256=58XbDylw2fPOD7C-ZfSAjeNqJLXB05uHGTuiYVVbXXY,13354
+ sae_lens/llm_sae_training_runner.py,sha256=exxNX_OEhdiUrlgmBP9bjX9DOf0HUcNQGO4unKeDjKM,13713
  sae_lens/load_model.py,sha256=dBB_9gO6kWyQ4sXHq7qB8T3YUlXm3PGwYcpR4UVW4QY,8633
  sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sae_lens/loading/pretrained_sae_loaders.py,sha256=5XEU4uFFeGCePwqDwhlE7CrFGRSI0U9Cu-UQVa33Y1E,36432
  sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
- sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
+ sae_lens/pretokenize_runner.py,sha256=w0f6SfZLAxbp5eAAKnet8RqUB_DKofZ9RGsoJwFnYbA,7058
  sae_lens/pretrained_saes.yaml,sha256=nhHW1auhyi4GHYrjUnHQqbNVhI5cMJv-HThzbzU1xG0,574145
  sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
  sae_lens/saes/__init__.py,sha256=RYqE1qkMws-kwQLmBZFhA_VCa69zVtBjGPIy_UAk2pw,1159
@@ -21,18 +21,18 @@ sae_lens/saes/jumprelu_sae.py,sha256=3xkhBcCol2mEpIBLceymCpudocm2ypOjTeTXbpiXoA4
  sae_lens/saes/sae.py,sha256=McpF4pTh70r6SQUbHFm0YQ9X2c2qPULBUSd_YmnEk4Y,38284
  sae_lens/saes/standard_sae.py,sha256=9UqYyYtQuThYxXKNaDjYcyowpOx2-7cShG-TeUP6JCQ,5940
  sae_lens/saes/topk_sae.py,sha256=CXMBI6CFvI5829bOhoQ350VXR9d8uFHUDlULTIWHXoU,8686
- sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
+ sae_lens/tokenization_and_batching.py,sha256=now7caLbU3p-iGokNwmqZDyIvxYoXgnG1uklhgiLZN4,4656
  sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
- sae_lens/training/activations_store.py,sha256=HBN3oEib3PlPUDJb_yVFabQp0JcN9rWbnUN1s2DBMAs,31933
+ sae_lens/training/activations_store.py,sha256=x2Fwt5QY7M83v6Vf1CSa821j2_WKMw9oPu1cdlLblvg,32887
  sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
  sae_lens/training/optim.py,sha256=TiI9nbffzXNsI8WjcIsqa2uheW6suxqL_KDDmWXobWI,5312
- sae_lens/training/sae_trainer.py,sha256=2xcO-02OozFunob5vwoHud-hVMhVl9d28_F9gDCiL6o,15529
+ sae_lens/training/sae_trainer.py,sha256=6HPf5wtmY1wMUTkLFRg9DujNMMXJkVMPdAhB2svvlkk,15368
  sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
  sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
- sae_lens-6.2.0.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
- sae_lens-6.2.0.dist-info/METADATA,sha256=Fqsq0scF5Uia0YBmeZQwVi4m4DX16_Ck-cKokbuch7U,5555
- sae_lens-6.2.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- sae_lens-6.2.0.dist-info/RECORD,,
+ sae_lens-6.3.1.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+ sae_lens-6.3.1.dist-info/METADATA,sha256=d-dAwcr-WiSFkybEqtOdFxhnJJBX0xiFec8uvln3ztE,5555
+ sae_lens-6.3.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ sae_lens-6.3.1.dist-info/RECORD,,