sae-lens 6.9.0__py3-none-any.whl → 6.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +1 -1
- sae_lens/config.py +13 -5
- sae_lens/llm_sae_training_runner.py +53 -8
- sae_lens/loading/pretrained_sae_loaders.py +91 -11
- sae_lens/training/sae_trainer.py +25 -29
- sae_lens/util.py +18 -0
- {sae_lens-6.9.0.dist-info → sae_lens-6.10.0.dist-info}/METADATA +1 -1
- {sae_lens-6.9.0.dist-info → sae_lens-6.10.0.dist-info}/RECORD +10 -10
- {sae_lens-6.9.0.dist-info → sae_lens-6.10.0.dist-info}/LICENSE +0 -0
- {sae_lens-6.9.0.dist-info → sae_lens-6.10.0.dist-info}/WHEEL +0 -0
sae_lens/__init__.py
CHANGED
sae_lens/config.py
CHANGED
@@ -169,8 +169,10 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         eval_batch_size_prompts (int, optional): The batch size for evaluation, in prompts. Useful if evals cause OOM.
         logger (LoggingConfig): Configuration for logging (e.g. W&B).
         n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
-        checkpoint_path (str): The path to save checkpoints. A unique ID will be appended to this path.
-
+        checkpoint_path (str | None): The path to save checkpoints. A unique ID will be appended to this path. Set to None to disable checkpoint saving. (default is "checkpoints")
+        save_final_checkpoint (bool): Whether to include an additional final checkpoint when training is finished. (default is False).
+        output_path (str | None): The path to save outputs. Set to None to disable output saving. (default is "output")
+        verbose (bool): Whether to print verbose output. (default is True)
         model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
         model_from_pretrained_kwargs (dict[str, Any], optional): Additional keyword arguments to pass to the model's `from_pretrained` method.
         sae_lens_version (str): The version of the sae_lens library.
@@ -254,9 +256,13 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
 
     logger: LoggingConfig = field(default_factory=LoggingConfig)
 
-    #
+    # Outputs/Checkpoints
     n_checkpoints: int = 0
-    checkpoint_path: str = "checkpoints"
+    checkpoint_path: str | None = "checkpoints"
+    save_final_checkpoint: bool = False
+    output_path: str | None = "output"
+
+    # Misc
     verbose: bool = True
     model_kwargs: dict[str, Any] = dict_field(default={})
     model_from_pretrained_kwargs: dict[str, Any] | None = dict_field(default=None)
@@ -394,6 +400,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         return SAETrainerConfig(
             n_checkpoints=self.n_checkpoints,
             checkpoint_path=self.checkpoint_path,
+            save_final_checkpoint=self.save_final_checkpoint,
             total_training_samples=self.total_training_tokens,
             device=self.device,
             autocast=self.autocast,
@@ -618,7 +625,8 @@ class PretokenizeRunnerConfig:
 @dataclass
 class SAETrainerConfig:
     n_checkpoints: int
-    checkpoint_path: str
+    checkpoint_path: str | None
+    save_final_checkpoint: bool
     total_training_samples: int
     device: str
     autocast: bool
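Taken together, these changes split saving into three independent switches: periodic training checkpoints, an optional extra end-of-run checkpoint, and the final inference-ready output. A minimal sketch of just the new knobs, assuming the remaining required model/dataset/SAE fields are filled in as usual:

```python
from sae_lens.config import LanguageModelSAERunnerConfig

# Sketch only: every required model/dataset/SAE field is elided here and must
# be supplied as in any normal training run.
cfg = LanguageModelSAERunnerConfig(
    # ... model, dataset, and SAE settings ...
    n_checkpoints=0,
    checkpoint_path=None,         # None now disables intermediate checkpoints
    save_final_checkpoint=False,  # opt in to an extra "final_<n>" checkpoint
    output_path="output",         # inference-ready SAE lands here; None disables
)
```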
sae_lens/llm_sae_training_runner.py
CHANGED

@@ -8,13 +8,18 @@ from typing import Any, Generic
 
 import torch
 import wandb
+from safetensors.torch import save_file
 from simple_parsing import ArgumentParser
 from transformer_lens.hook_points import HookedRootModule
 from typing_extensions import deprecated
 
 from sae_lens import logger
 from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
-from sae_lens.constants import
+from sae_lens.constants import (
+    ACTIVATIONS_STORE_STATE_FILENAME,
+    RUNNER_CFG_FILENAME,
+    SPARSITY_FILENAME,
+)
 from sae_lens.evals import EvalConfig, run_evals
 from sae_lens.load_model import load_model
 from sae_lens.saes.batchtopk_sae import BatchTopKTrainingSAEConfig
@@ -185,11 +190,47 @@ class LanguageModelSAETrainingRunner:
         self._compile_if_needed()
         sae = self.run_trainer_with_interruption_handling(trainer)
 
+        if self.cfg.output_path is not None:
+            self.save_final_sae(
+                sae=sae,
+                output_path=self.cfg.output_path,
+                log_feature_sparsity=trainer.log_feature_sparsity,
+            )
+
         if self.cfg.logger.log_to_wandb:
             wandb.finish()
 
         return sae
 
+    def save_final_sae(
+        self,
+        sae: TrainingSAE[Any],
+        output_path: str,
+        log_feature_sparsity: torch.Tensor | None = None,
+    ):
+        base_output_path = Path(output_path)
+        base_output_path.mkdir(exist_ok=True, parents=True)
+
+        weights_path, cfg_path = sae.save_inference_model(str(base_output_path))
+
+        sparsity_path = None
+        if log_feature_sparsity is not None:
+            sparsity_path = base_output_path / SPARSITY_FILENAME
+            save_file({"sparsity": log_feature_sparsity}, sparsity_path)
+
+        runner_config = self.cfg.to_dict()
+        with open(base_output_path / RUNNER_CFG_FILENAME, "w") as f:
+            json.dump(runner_config, f)
+
+        if self.cfg.logger.log_to_wandb:
+            self.cfg.logger.log(
+                self,
+                weights_path,
+                cfg_path,
+                sparsity_path=sparsity_path,
+                wandb_aliases=["final_model"],
+            )
+
     def _set_sae_metadata(self):
         self.sae.cfg.metadata.dataset_path = self.cfg.dataset_path
         self.sae.cfg.metadata.hook_name = self.cfg.hook_name
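With output_path set, the run directory ends up containing the inference-ready SAE weights and config, an optional feature-sparsity tensor, and the serialized runner config. A hedged sketch of reading the sparsity tensor back; the literal filename below is a guess, since the real name is the SPARSITY_FILENAME constant in sae_lens.constants:

```python
from safetensors.torch import load_file

# "sparsity.safetensors" is a hypothetical stand-in for SPARSITY_FILENAME.
sparsity = load_file("output/sparsity.safetensors")["sparsity"]
print(sparsity.shape)  # one log-frequency value per SAE latent
```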
@@ -247,20 +288,24 @@ class LanguageModelSAETrainingRunner:
             sae = trainer.fit()
 
         except (KeyboardInterrupt, InterruptedException):
-
-
-
-
-
-
+            if self.cfg.checkpoint_path is not None:
+                logger.warning("interrupted, saving progress")
+                checkpoint_path = Path(self.cfg.checkpoint_path) / str(
+                    trainer.n_training_samples
+                )
+                self.save_checkpoint(checkpoint_path)
+                logger.info("done saving")
             raise
 
         return sae
 
     def save_checkpoint(
         self,
-        checkpoint_path: Path,
+        checkpoint_path: Path | None,
     ) -> None:
+        if checkpoint_path is None:
+            return
+
         self.activations_store.save(
             str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
         )
sae_lens/loading/pretrained_sae_loaders.py
CHANGED

@@ -1,12 +1,14 @@
 import json
 import re
+import warnings
 from pathlib import Path
 from typing import Any, Protocol
 
 import numpy as np
+import requests
 import torch
 import yaml
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, hf_hub_url
 from huggingface_hub.utils import EntryNotFoundError
 from packaging.version import Version
 from safetensors import safe_open
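The new hf_hub_url import turns a repo id and filename into the direct download URL that the range requests below will target, without fetching anything itself. A small sketch (the repo and filename are hypothetical):

```python
from huggingface_hub import hf_hub_url

# Pure string construction, no network access:
url = hf_hub_url("some-org/some-clt", "W_enc_5.safetensors")
print(url)
# https://huggingface.co/some-org/some-clt/resolve/main/W_enc_5.safetensors
```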
@@ -1330,6 +1332,48 @@ def mwhanna_transcoder_huggingface_loader(
     return cfg_dict, state_dict, None
 
 
+def get_safetensors_tensor_shapes(url: str) -> dict[str, list[int]]:
+    """
+    Get tensor shapes from a safetensors file using HTTP range requests
+    without downloading the entire file.
+
+    Args:
+        url: Direct URL to the safetensors file
+
+    Returns:
+        Dictionary mapping tensor names to their shapes
+    """
+    # Check if server supports range requests
+    response = requests.head(url, timeout=10)
+    response.raise_for_status()
+
+    accept_ranges = response.headers.get("Accept-Ranges", "")
+    if "bytes" not in accept_ranges:
+        raise ValueError("Server does not support range requests")
+
+    # Fetch first 8 bytes to get metadata size
+    headers = {"Range": "bytes=0-7"}
+    response = requests.get(url, headers=headers, timeout=10)
+    if response.status_code != 206:
+        raise ValueError("Failed to fetch initial bytes for metadata size")
+
+    meta_size = int.from_bytes(response.content, byteorder="little")
+
+    # Fetch the metadata header
+    headers = {"Range": f"bytes=8-{8 + meta_size - 1}"}
+    response = requests.get(url, headers=headers, timeout=10)
+    if response.status_code != 206:
+        raise ValueError("Failed to fetch metadata header")
+
+    metadata_json = response.content.decode("utf-8").strip()
+    metadata = json.loads(metadata_json)
+
+    # Extract tensor shapes, excluding the __metadata__ key
+    return {
+        name: info["shape"] for name, info in metadata.items() if name != "__metadata__"
+    }
+
+
 def mntss_clt_layer_huggingface_loader(
     repo_id: str,
     folder_name: str,
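get_safetensors_tensor_shapes leans on the safetensors on-disk layout: the first 8 bytes are a little-endian unsigned 64-bit integer giving the length of a JSON header that maps each tensor name to its dtype, shape, and data offsets, so two small range requests recover every shape. The same parsing can be demonstrated offline against an in-memory file:

```python
import json

import torch
from safetensors.torch import save

# Serialize two tensors to bytes, then parse the header by hand, mirroring
# what the two HTTP range requests fetch (tensor names are made up).
blob = save({"W_enc_5": torch.zeros(16, 4), "b_enc_5": torch.zeros(4)})
header_len = int.from_bytes(blob[:8], byteorder="little")
header = json.loads(blob[8 : 8 + header_len])
shapes = {k: v["shape"] for k, v in header.items() if k != "__metadata__"}
print(shapes)  # e.g. {'W_enc_5': [16, 4], 'b_enc_5': [4]}
```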
@@ -1341,11 +1385,20 @@ def mntss_clt_layer_huggingface_loader(
     Load a MNTSS CLT layer as a single layer transcoder.
     The assumption is that the `folder_name` is the layer to load as an int
     """
-
-
+
+    # warn that this sums the decoders together, so should only be used to find feature activations, not for reconstruction
+    warnings.warn(
+        "This loads the CLT layer as a single layer transcoder by summing all decoders together. This should only be used to find feature activations, not for reconstruction",
+        UserWarning,
+    )
+
+    cfg_dict = get_mntss_clt_layer_config_from_hf(
+        repo_id,
+        folder_name,
+        device,
+        force_download,
+        cfg_overrides,
     )
-    with open(base_config_path) as f:
-        cfg_info: dict[str, Any] = yaml.safe_load(f)
 
     # We need to actually load the weights, since the config is missing most information
     encoder_path = hf_hub_download(
@@ -1370,11 +1423,39 @@
         "W_dec": decoder_state_dict[f"W_dec_{folder_name}"].sum(dim=1), # type: ignore
     }
 
-    cfg_dict
+    return cfg_dict, state_dict, None
+
+
+def get_mntss_clt_layer_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False, # noqa: ARG001
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """
+    Load a MNTSS CLT layer as a single layer transcoder.
+    The assumption is that the `folder_name` is the layer to load as an int
+    """
+    base_config_path = hf_hub_download(
+        repo_id, "config.yaml", force_download=force_download
+    )
+    with open(base_config_path) as f:
+        cfg_info: dict[str, Any] = yaml.safe_load(f)
+
+    # Get tensor shapes without downloading full files using HTTP range requests
+    encoder_url = hf_hub_url(repo_id, f"W_enc_{folder_name}.safetensors")
+    encoder_shapes = get_safetensors_tensor_shapes(encoder_url)
+
+    # Extract shapes for the required tensors
+    b_dec_shape = encoder_shapes[f"b_dec_{folder_name}"]
+    b_enc_shape = encoder_shapes[f"b_enc_{folder_name}"]
+
+    return {
         "architecture": "transcoder",
-        "d_in":
-        "d_out":
-        "d_sae":
+        "d_in": b_dec_shape[0],
+        "d_out": b_dec_shape[0],
+        "d_sae": b_enc_shape[0],
         "dtype": "float32",
         "device": device if device is not None else "cpu",
         "activation_fn": "relu",
@@ -1387,8 +1468,6 @@
         **(cfg_overrides or {}),
     }
 
-    return cfg_dict, state_dict, None
-
 
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
@@ -1416,4 +1495,5 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
     "sparsify": get_sparsify_config_from_hf,
     "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
     "mwhanna_transcoder": get_mwhanna_transcoder_config_from_hf,
+    "mntss_clt_layer_transcoder": get_mntss_clt_layer_config_from_hf,
 }
sae_lens/training/sae_trainer.py
CHANGED
@@ -22,6 +22,7 @@ from sae_lens.saes.sae import (
 from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler
 from sae_lens.training.types import DataProvider
+from sae_lens.util import path_or_tmp_dir
 
 
 def _log_feature_sparsity(
@@ -40,7 +41,7 @@ def _update_sae_lens_training_version(sae: TrainingSAE[Any]) -> None:
 class SaveCheckpointFn(Protocol):
     def __call__(
         self,
-        checkpoint_path: Path,
+        checkpoint_path: Path | None,
     ) -> None: ...
 
 
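Since checkpoint_path can now be None, any custom callback implementing this protocol has to tolerate the no-checkpoint case. A minimal conforming callback (a sketch):

```python
from pathlib import Path


def my_save_checkpoint(checkpoint_path: Path | None) -> None:
    # None signals that no persistent checkpoint directory was produced.
    if checkpoint_path is None:
        return
    print(f"checkpoint written to {checkpoint_path}")
```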
@@ -187,12 +188,8 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
             )
             self.activation_scaler.scaling_factor = None
 
-
-
-            checkpoint_name=f"final_{self.n_training_samples}",
-            wandb_aliases=["final_model"],
-            save_inference_model=True,
-        )
+        if self.cfg.save_final_checkpoint:
+            self.save_checkpoint(checkpoint_name=f"final_{self.n_training_samples}")
 
         pbar.close()
         return self.sae
@@ -201,32 +198,31 @@
         self,
         checkpoint_name: str,
         wandb_aliases: list[str] | None = None,
-        save_inference_model: bool = False,
     ) -> None:
-        checkpoint_path =
-        checkpoint_path.
-
-
-
-            if save_inference_model
-            else self.sae.save_model
-        )
-        weights_path, cfg_path = save_fn(str(checkpoint_path))
+        checkpoint_path = None
+        if self.cfg.checkpoint_path is not None or self.cfg.logger.log_to_wandb:
+            with path_or_tmp_dir(self.cfg.checkpoint_path) as base_checkpoint_path:
+                checkpoint_path = base_checkpoint_path / checkpoint_name
+                checkpoint_path.mkdir(exist_ok=True, parents=True)
 
-
-        save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)
+                weights_path, cfg_path = self.sae.save_model(str(checkpoint_path))
 
-
-
+                sparsity_path = checkpoint_path / SPARSITY_FILENAME
+                save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)
 
-
-
-
-
-
-
-
-
+                activation_scaler_path = (
+                    checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
+                )
+                self.activation_scaler.save(str(activation_scaler_path))
+
+                if self.cfg.logger.log_to_wandb:
+                    self.cfg.logger.log(
+                        self,
+                        weights_path,
+                        cfg_path,
+                        sparsity_path=sparsity_path,
+                        wandb_aliases=wandb_aliases,
+                    )
 
         if self.save_checkpoint_fn is not None:
             self.save_checkpoint_fn(checkpoint_path=checkpoint_path)
sae_lens/util.py
CHANGED
@@ -1,5 +1,8 @@
 import re
+import tempfile
+from contextlib import contextmanager
 from dataclasses import asdict, fields, is_dataclass
+from pathlib import Path
 from typing import Sequence, TypeVar
 
 K = TypeVar("K")
@@ -45,3 +48,18 @@ def extract_layer_from_tlens_hook_name(hook_name: str) -> int | None:
     """
     hook_match = re.search(r"\.(\d+)\.", hook_name)
     return None if hook_match is None else int(hook_match.group(1))
+
+
+@contextmanager
+def path_or_tmp_dir(path: str | Path | None):
+    """Context manager that yields a concrete Path for path.
+
+    - If path is None, creates a TemporaryDirectory and yields its Path.
+      The directory is cleaned up on context exit.
+    - Otherwise, yields Path(path) without creating or cleaning.
+    """
+    if path is None:
+        with tempfile.TemporaryDirectory() as td:
+            yield Path(td)
+    else:
+        yield Path(path)
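This helper is what lets save_checkpoint above handle the wandb-only case: artifacts are staged in a throwaway directory just long enough to upload, and nothing persists unless a real checkpoint_path is configured. A small runnable demonstration:

```python
from sae_lens.util import path_or_tmp_dir

# A concrete path is passed through untouched (not created, not cleaned up).
with path_or_tmp_dir("checkpoints") as p:
    print(p)  # checkpoints

# None yields a temporary directory that disappears when the block exits.
with path_or_tmp_dir(None) as p:
    (p / "weights.txt").write_text("staged for upload")
    print(p.exists())  # True
print(p.exists())  # False
```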
{sae_lens-6.9.0.dist-info → sae_lens-6.10.0.dist-info}/RECORD
CHANGED

@@ -1,15 +1,15 @@
-sae_lens/__init__.py,sha256=
+sae_lens/__init__.py,sha256=k8M2SyKNE3KpipPxODICdLG8KJNVvf1Zab4KNJuGWMQ,3589
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=vRu6JseH1lZaEeILD5bEkQEQ1wYHHDcxD-f2olKmE9Y,14275
 sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
 sae_lens/cache_activations_runner.py,sha256=cNeAtp2JQ_vKbeddZVM-tcPLYyyfTWL8NDna5KQpkLI,12583
-sae_lens/config.py,sha256=
+sae_lens/config.py,sha256=IdRXSKPfYY3hwUovj-u83eep8z52gkJHII0mY0KseYY,28739
 sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
 sae_lens/evals.py,sha256=4hanbyG8qZLItWqft94F4ZjUoytPVB7fw5s0P4Oi0VE,39504
-sae_lens/llm_sae_training_runner.py,sha256=
+sae_lens/llm_sae_training_runner.py,sha256=sJTcDX1bUJJ_jZLUT88-8KUYIAPeUGoXktX68PsBqw0,15137
 sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
 sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/loading/pretrained_sae_loaders.py,sha256=
+sae_lens/loading/pretrained_sae_loaders.py,sha256=CVzHntSUKR1X3_gAqn8K_Ajq8D85qBrmrgEgU93IV4A,49609
 sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
 sae_lens/pretokenize_runner.py,sha256=w0f6SfZLAxbp5eAAKnet8RqUB_DKofZ9RGsoJwFnYbA,7058
 sae_lens/pretrained_saes.yaml,sha256=d6FYfWTdVAPlOCM55C1ICS6lF9nWPPVNwjlXCa9p7NU,600468
@@ -28,12 +28,12 @@ sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOw
 sae_lens/training/activations_store.py,sha256=2EUY2abqpT5El3T95sypM_JRDgiKL3VeT73U9SQIFGY,32903
 sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
 sae_lens/training/optim.py,sha256=TiI9nbffzXNsI8WjcIsqa2uheW6suxqL_KDDmWXobWI,5312
-sae_lens/training/sae_trainer.py,sha256=
+sae_lens/training/sae_trainer.py,sha256=Jh5AyBGtfZjnprv9H3k0p_luWWnM7YFjlmHuO1W_J6U,15465
 sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
-sae_lens/util.py,sha256=
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
+sae_lens/util.py,sha256=lW7fBn_b8quvRYlen9PUmB7km60YhKyjmuelB1f6KzQ,2253
+sae_lens-6.10.0.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.10.0.dist-info/METADATA,sha256=7Yq4_hrZVc2CBB4nMvgy_BGFjT5FrF3SfOo8LnJ18Rg,5245
+sae_lens-6.10.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+sae_lens-6.10.0.dist-info/RECORD,,
{sae_lens-6.9.0.dist-info → sae_lens-6.10.0.dist-info}/LICENSE
File without changes

{sae_lens-6.9.0.dist-info → sae_lens-6.10.0.dist-info}/WHEEL
File without changes