sae-lens 6.9.0__tar.gz → 6.9.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {sae_lens-6.9.0 → sae_lens-6.9.1}/PKG-INFO +1 -1
  2. {sae_lens-6.9.0 → sae_lens-6.9.1}/pyproject.toml +1 -1
  3. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/__init__.py +1 -1
  4. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/loading/pretrained_sae_loaders.py +91 -11
  5. {sae_lens-6.9.0 → sae_lens-6.9.1}/LICENSE +0 -0
  6. {sae_lens-6.9.0 → sae_lens-6.9.1}/README.md +0 -0
  7. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/analysis/__init__.py +0 -0
  8. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  9. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  10. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/cache_activations_runner.py +0 -0
  11. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/config.py +0 -0
  12. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/constants.py +0 -0
  13. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/evals.py +0 -0
  14. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/llm_sae_training_runner.py +0 -0
  15. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/load_model.py +0 -0
  16. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/loading/__init__.py +0 -0
  17. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  18. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/pretokenize_runner.py +0 -0
  19. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/pretrained_saes.yaml +0 -0
  20. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/registry.py +0 -0
  21. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/saes/__init__.py +0 -0
  22. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/saes/batchtopk_sae.py +0 -0
  23. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/saes/gated_sae.py +0 -0
  24. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/saes/jumprelu_sae.py +0 -0
  25. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/saes/sae.py +0 -0
  26. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/saes/standard_sae.py +0 -0
  27. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/saes/topk_sae.py +0 -0
  28. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/saes/transcoder.py +0 -0
  29. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/tokenization_and_batching.py +0 -0
  30. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/training/__init__.py +0 -0
  31. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/training/activation_scaler.py +0 -0
  32. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/training/activations_store.py +0 -0
  33. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/training/mixing_buffer.py +0 -0
  34. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/training/optim.py +0 -0
  35. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/training/sae_trainer.py +0 -0
  36. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/training/types.py +0 -0
  37. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  38. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/tutorial/tsea.py +0 -0
  39. {sae_lens-6.9.0 → sae_lens-6.9.1}/sae_lens/util.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sae-lens
3
- Version: 6.9.0
3
+ Version: 6.9.1
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "sae-lens"
3
- version = "6.9.0"
3
+ version = "6.9.1"
4
4
  description = "Training and Analyzing Sparse Autoencoders (SAEs)"
5
5
  authors = ["Joseph Bloom"]
6
6
  readme = "README.md"
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.9.0"
2
+ __version__ = "6.9.1"
3
3
 
4
4
  import logging
5
5
 
@@ -1,12 +1,14 @@
1
1
  import json
2
2
  import re
3
+ import warnings
3
4
  from pathlib import Path
4
5
  from typing import Any, Protocol
5
6
 
6
7
  import numpy as np
8
+ import requests
7
9
  import torch
8
10
  import yaml
9
- from huggingface_hub import hf_hub_download
11
+ from huggingface_hub import hf_hub_download, hf_hub_url
10
12
  from huggingface_hub.utils import EntryNotFoundError
11
13
  from packaging.version import Version
12
14
  from safetensors import safe_open
@@ -1330,6 +1332,48 @@ def mwhanna_transcoder_huggingface_loader(
1330
1332
  return cfg_dict, state_dict, None
1331
1333
 
1332
1334
 
1335
+ def get_safetensors_tensor_shapes(url: str) -> dict[str, list[int]]:
1336
+ """
1337
+ Get tensor shapes from a safetensors file using HTTP range requests
1338
+ without downloading the entire file.
1339
+
1340
+ Args:
1341
+ url: Direct URL to the safetensors file
1342
+
1343
+ Returns:
1344
+ Dictionary mapping tensor names to their shapes
1345
+ """
1346
+ # Check if server supports range requests
1347
+ response = requests.head(url, timeout=10)
1348
+ response.raise_for_status()
1349
+
1350
+ accept_ranges = response.headers.get("Accept-Ranges", "")
1351
+ if "bytes" not in accept_ranges:
1352
+ raise ValueError("Server does not support range requests")
1353
+
1354
+ # Fetch first 8 bytes to get metadata size
1355
+ headers = {"Range": "bytes=0-7"}
1356
+ response = requests.get(url, headers=headers, timeout=10)
1357
+ if response.status_code != 206:
1358
+ raise ValueError("Failed to fetch initial bytes for metadata size")
1359
+
1360
+ meta_size = int.from_bytes(response.content, byteorder="little")
1361
+
1362
+ # Fetch the metadata header
1363
+ headers = {"Range": f"bytes=8-{8 + meta_size - 1}"}
1364
+ response = requests.get(url, headers=headers, timeout=10)
1365
+ if response.status_code != 206:
1366
+ raise ValueError("Failed to fetch metadata header")
1367
+
1368
+ metadata_json = response.content.decode("utf-8").strip()
1369
+ metadata = json.loads(metadata_json)
1370
+
1371
+ # Extract tensor shapes, excluding the __metadata__ key
1372
+ return {
1373
+ name: info["shape"] for name, info in metadata.items() if name != "__metadata__"
1374
+ }
1375
+
1376
+
1333
1377
  def mntss_clt_layer_huggingface_loader(
1334
1378
  repo_id: str,
1335
1379
  folder_name: str,
@@ -1341,11 +1385,20 @@ def mntss_clt_layer_huggingface_loader(
1341
1385
  Load a MNTSS CLT layer as a single layer transcoder.
1342
1386
  The assumption is that the `folder_name` is the layer to load as an int
1343
1387
  """
1344
- base_config_path = hf_hub_download(
1345
- repo_id, "config.yaml", force_download=force_download
1388
+
1389
+ # warn that this sums the decoders together, so should only be used to find feature activations, not for reconstruction
1390
+ warnings.warn(
1391
+ "This loads the CLT layer as a single layer transcoder by summing all decoders together. This should only be used to find feature activations, not for reconstruction",
1392
+ UserWarning,
1393
+ )
1394
+
1395
+ cfg_dict = get_mntss_clt_layer_config_from_hf(
1396
+ repo_id,
1397
+ folder_name,
1398
+ device,
1399
+ force_download,
1400
+ cfg_overrides,
1346
1401
  )
1347
- with open(base_config_path) as f:
1348
- cfg_info: dict[str, Any] = yaml.safe_load(f)
1349
1402
 
1350
1403
  # We need to actually load the weights, since the config is missing most information
1351
1404
  encoder_path = hf_hub_download(
@@ -1370,11 +1423,39 @@ def mntss_clt_layer_huggingface_loader(
1370
1423
  "W_dec": decoder_state_dict[f"W_dec_{folder_name}"].sum(dim=1), # type: ignore
1371
1424
  }
1372
1425
 
1373
- cfg_dict = {
1426
+ return cfg_dict, state_dict, None
1427
+
1428
+
1429
+ def get_mntss_clt_layer_config_from_hf(
1430
+ repo_id: str,
1431
+ folder_name: str,
1432
+ device: str,
1433
+ force_download: bool = False, # noqa: ARG001
1434
+ cfg_overrides: dict[str, Any] | None = None,
1435
+ ) -> dict[str, Any]:
1436
+ """
1437
+ Load a MNTSS CLT layer as a single layer transcoder.
1438
+ The assumption is that the `folder_name` is the layer to load as an int
1439
+ """
1440
+ base_config_path = hf_hub_download(
1441
+ repo_id, "config.yaml", force_download=force_download
1442
+ )
1443
+ with open(base_config_path) as f:
1444
+ cfg_info: dict[str, Any] = yaml.safe_load(f)
1445
+
1446
+ # Get tensor shapes without downloading full files using HTTP range requests
1447
+ encoder_url = hf_hub_url(repo_id, f"W_enc_{folder_name}.safetensors")
1448
+ encoder_shapes = get_safetensors_tensor_shapes(encoder_url)
1449
+
1450
+ # Extract shapes for the required tensors
1451
+ b_dec_shape = encoder_shapes[f"b_dec_{folder_name}"]
1452
+ b_enc_shape = encoder_shapes[f"b_enc_{folder_name}"]
1453
+
1454
+ return {
1374
1455
  "architecture": "transcoder",
1375
- "d_in": state_dict["b_dec"].shape[0],
1376
- "d_out": state_dict["b_dec"].shape[0],
1377
- "d_sae": state_dict["b_enc"].shape[0],
1456
+ "d_in": b_dec_shape[0],
1457
+ "d_out": b_dec_shape[0],
1458
+ "d_sae": b_enc_shape[0],
1378
1459
  "dtype": "float32",
1379
1460
  "device": device if device is not None else "cpu",
1380
1461
  "activation_fn": "relu",
@@ -1387,8 +1468,6 @@ def mntss_clt_layer_huggingface_loader(
1387
1468
  **(cfg_overrides or {}),
1388
1469
  }
1389
1470
 
1390
- return cfg_dict, state_dict, None
1391
-
1392
1471
 
1393
1472
  NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
1394
1473
  "sae_lens": sae_lens_huggingface_loader,
@@ -1416,4 +1495,5 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
1416
1495
  "sparsify": get_sparsify_config_from_hf,
1417
1496
  "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
1418
1497
  "mwhanna_transcoder": get_mwhanna_transcoder_config_from_hf,
1498
+ "mntss_clt_layer_transcoder": get_mntss_clt_layer_config_from_hf,
1419
1499
  }
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes