sae-lens 6.5.1 → 6.5.3 (py3-none-any.whl)

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.5.1"
+ __version__ = "6.5.3"
 
  import logging
 
sae_lens/evals.py CHANGED
@@ -776,6 +776,7 @@ def multiple_evals(
      n_eval_sparsity_variance_batches: int,
      eval_batch_size_prompts: int = 8,
      datasets: list[str] = ["Skylion007/openwebtext", "lighteval/MATH"],
+     dataset_trust_remote_code: bool = False,
      ctx_lens: list[int] = [128],
      output_dir: str = "eval_results",
      verbose: bool = False,
@@ -822,7 +823,11 @@ def multiple_evals(
      for ctx_len in ctx_lens:
          for dataset in datasets:
              activation_store = ActivationsStore.from_sae(
-                 current_model, sae, context_size=ctx_len, dataset=dataset
+                 current_model,
+                 sae,
+                 context_size=ctx_len,
+                 dataset=dataset,
+                 dataset_trust_remote_code=dataset_trust_remote_code,
              )
              activation_store.shuffle_input_dataset(seed=42)
 
@@ -882,6 +887,7 @@ def run_evaluations(args: argparse.Namespace) -> list[dict[str, Any]]:
          n_eval_sparsity_variance_batches=args.n_eval_sparsity_variance_batches,
          eval_batch_size_prompts=args.batch_size_prompts,
          datasets=args.datasets,
+         dataset_trust_remote_code=args.dataset_trust_remote_code,
          ctx_lens=args.ctx_lens,
          output_dir=args.output_dir,
          verbose=args.verbose,
@@ -1004,6 +1010,11 @@ def process_args(args: list[str]) -> argparse.Namespace:
          default=["Skylion007/openwebtext"],
          help="Datasets to evaluate on, such as 'Skylion007/openwebtext' or 'lighteval/MATH'.",
      )
+     arg_parser.add_argument(
+         "--dataset_trust_remote_code",
+         action="store_true",
+         help="Allow execution of remote code when loading datasets for evaluation.",
+     )
      arg_parser.add_argument(
          "--ctx_lens",
          nargs="+",
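The new flag defaults to off, so datasets that depend on a loading script fail loudly unless the caller opts in, either via `--dataset_trust_remote_code` on the CLI or by passing the keyword through `multiple_evals`. A minimal sketch of where the opt-in lands, assuming `model` and `sae` objects have already been loaded elsewhere (both are placeholders here, not code from this diff):

    from sae_lens import ActivationsStore  # assumed top-level export

    # Mirrors the plumbing above: the flag travels from the CLI through
    # multiple_evals() into ActivationsStore.from_sae(), defaulting to False.
    store = ActivationsStore.from_sae(
        model,  # hypothetical hooked model loaded by the caller
        sae,    # hypothetical SAE under evaluation
        context_size=128,
        dataset="lighteval/MATH",
        dataset_trust_remote_code=True,  # allow the dataset's loading script to run
    )
    store.shuffle_input_dataset(seed=42)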
sae_lens/load_model.py CHANGED
@@ -52,7 +52,7 @@ def load_model(
      if model_class_name == "AutoModelForCausalLM":
          hf_model = AutoModelForCausalLM.from_pretrained(
              model_name, **model_from_pretrained_kwargs
-         ).to(device)
+         ).to(device)  # type: ignore
          tokenizer = AutoTokenizer.from_pretrained(model_name)
          return HookedProxyLM(hf_model, tokenizer)
 
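The `# type: ignore` only silences the type checker: `from_pretrained` is typed as returning a union, and `.to(device)` on that union does not narrow cleanly. Runtime behavior is unchanged. For context, a hedged sketch of calling this path, assuming `load_model` keeps the signature implied by the hunk above:

    from sae_lens.load_model import load_model

    # Wraps a vanilla Hugging Face causal LM in HookedProxyLM so the rest of
    # sae-lens can treat it like a hooked TransformerLens-style model.
    proxy_model = load_model(
        model_class_name="AutoModelForCausalLM",
        model_name="gpt2",  # any HF causal LM id; gpt2 is illustrative
        device="cpu",
    )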
sae_lens/loading/pretrained_sae_loaders.py CHANGED
@@ -1177,6 +1177,7 @@ def get_gemma_2_transcoder_config_from_hf(
          "prepend_bos": True,
          "dataset_path": "monology/pile-uncopyrighted",
          "context_size": 1024,
+         "apply_b_dec_to_input": False,
          **(cfg_overrides or {}),
      }
 
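Because the new default sits before `**(cfg_overrides or {})` in the dict literal, an explicit override still wins; this matters only in that an override forcing `apply_b_dec_to_input=True` will now be rejected by `TranscoderConfig` below. A small illustration of the merge order (variable names are illustrative, not the loader's actual locals):

    # In a dict literal, later keys overwrite earlier ones, so cfg_overrides
    # takes precedence over the baked-in default.
    defaults = {"context_size": 1024, "apply_b_dec_to_input": False}
    cfg_overrides = {"apply_b_dec_to_input": True}
    cfg = {**defaults, **(cfg_overrides or {})}
    assert cfg["apply_b_dec_to_input"] is True  # override wins; validation rejects it later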
sae_lens/saes/transcoder.py CHANGED
@@ -52,6 +52,11 @@ class TranscoderConfig(SAEConfig):
 
          return res
 
+     def __post_init__(self):
+         if self.apply_b_dec_to_input:
+             raise ValueError("apply_b_dec_to_input is not supported for transcoders")
+         return super().__post_init__()
+
 
  class Transcoder(SAE[TranscoderConfig]):
      """
sae_lens-6.5.1.dist-info/METADATA → sae_lens-6.5.3.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: sae-lens
- Version: 6.5.1
+ Version: 6.5.3
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
  License: MIT
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -16,7 +16,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Provides-Extra: mamba
  Requires-Dist: automated-interpretability (>=0.0.5,<1.0.0)
  Requires-Dist: babe (>=0.0.7,<0.0.8)
- Requires-Dist: datasets (>=2.17.1,<3.0.0)
+ Requires-Dist: datasets (>=3.1.0,<4.0.0)
  Requires-Dist: mamba-lens (>=0.0.4,<0.0.5) ; extra == "mamba"
  Requires-Dist: matplotlib (>=3.8.3,<4.0.0)
  Requires-Dist: matplotlib-inline (>=0.1.6,<0.2.0)
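The jump in the `datasets` floor from 2.x to 3.x is what motivates the `dataset_trust_remote_code` plumbing above: recent `datasets` releases no longer run dataset loading scripts unless the caller opts in explicitly. A hedged sketch of the underlying call (whether a given dataset actually needs the flag depends on how it is published):

    from datasets import load_dataset

    # Plain data-file datasets load as before; script-backed ones need the opt-in.
    ds = load_dataset("lighteval/MATH", split="train", trust_remote_code=True)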
sae_lens-6.5.1.dist-info/RECORD → sae_lens-6.5.3.dist-info/RECORD CHANGED
@@ -1,15 +1,15 @@
- sae_lens/__init__.py,sha256=v3iyvnzUIs3XkzrtiW_D0FH9tDRqcxpPEzlBdGkhiRo,3588
+ sae_lens/__init__.py,sha256=HBSxvQf7TM7R4n1Al0RbtS0bVFgfb3eDBMuLXZs2cKw,3588
  sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
  sae_lens/analysis/neuronpedia_integration.py,sha256=MrENqc81Mc2SMbxGjbwHzpkGUCAFKSf0i4EdaUF2Oj4,18707
  sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
  sae_lens/config.py,sha256=IrjbsKBbaZoFXYrsPJ5xBwIqi9uZJIIFXjV_uoErJaE,28176
  sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
- sae_lens/evals.py,sha256=kQyrzczKaVD9rHwfFa_DxL_gMXDxsoIVHmsFIPIU2bY,38696
+ sae_lens/evals.py,sha256=rVmKqdSa4xlZw-78qpgdhDleJIO1EI1QHRiPsObW1Tc,39129
  sae_lens/llm_sae_training_runner.py,sha256=exxNX_OEhdiUrlgmBP9bjX9DOf0HUcNQGO4unKeDjKM,13713
- sae_lens/load_model.py,sha256=dBB_9gO6kWyQ4sXHq7qB8T3YUlXm3PGwYcpR4UVW4QY,8633
+ sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
  sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- sae_lens/loading/pretrained_sae_loaders.py,sha256=RfZhE8l2y5_ZW5-fW-iN-SJGJfHRaZo4d6AiooK1Xuc,40890
+ sae_lens/loading/pretrained_sae_loaders.py,sha256=0nb3PG26XG1EEvVIf2Jp-4_g2E30XulA5u5ok4Ah8no,40929
  sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
  sae_lens/pretokenize_runner.py,sha256=w0f6SfZLAxbp5eAAKnet8RqUB_DKofZ9RGsoJwFnYbA,7058
  sae_lens/pretrained_saes.yaml,sha256=BncZpkf-NhVOWV1FhrDPr_TPxAINP4-eiPZBw-HO4N4,578197
@@ -21,7 +21,7 @@ sae_lens/saes/jumprelu_sae.py,sha256=3xkhBcCol2mEpIBLceymCpudocm2ypOjTeTXbpiXoA4
  sae_lens/saes/sae.py,sha256=gdUZuLaOHQrPjbDj-nZI813B6-_mNAnV9i9z4qTnpHk,38255
  sae_lens/saes/standard_sae.py,sha256=9UqYyYtQuThYxXKNaDjYcyowpOx2-7cShG-TeUP6JCQ,5940
  sae_lens/saes/topk_sae.py,sha256=CXMBI6CFvI5829bOhoQ350VXR9d8uFHUDlULTIWHXoU,8686
- sae_lens/saes/transcoder.py,sha256=qbsvIeRy7M1nuHapDut0URNg9fR-UtfBvwW4_hgsAac,11963
+ sae_lens/saes/transcoder.py,sha256=BfLSbTYVNZh-ruGxseZiZJ_acEL6_7QyTdfqUr0lDOg,12156
  sae_lens/tokenization_and_batching.py,sha256=now7caLbU3p-iGokNwmqZDyIvxYoXgnG1uklhgiLZN4,4656
  sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
@@ -33,7 +33,7 @@ sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
  sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
- sae_lens-6.5.1.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
- sae_lens-6.5.1.dist-info/METADATA,sha256=wpZeKz6pPucYSC_9HOM3FR_aq8rxPSKxn2WW4rnz7Vo,5555
- sae_lens-6.5.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- sae_lens-6.5.1.dist-info/RECORD,,
+ sae_lens-6.5.3.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+ sae_lens-6.5.3.dist-info/METADATA,sha256=LP8y6l6lH8Ibsm9_HxTX-iem1uKUsailabbkBii5YIk,5554
+ sae_lens-6.5.3.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ sae_lens-6.5.3.dist-info/RECORD,,