sae-lens 6.30.1__tar.gz → 6.31.0__tar.gz
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- {sae_lens-6.30.1 → sae_lens-6.31.0}/PKG-INFO +1 -1
- {sae_lens-6.30.1 → sae_lens-6.31.0}/pyproject.toml +1 -1
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/__init__.py +1 -1
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/config.py +9 -1
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/evals.py +2 -2
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/saes/temporal_sae.py +1 -1
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/training/activation_scaler.py +3 -1
- {sae_lens-6.30.1 → sae_lens-6.31.0}/LICENSE +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/README.md +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/constants.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/saes/matching_pursuit_sae.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/synthetic/__init__.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/synthetic/activation_generator.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/synthetic/correlation.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/synthetic/evals.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/synthetic/feature_dictionary.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/synthetic/firing_probabilities.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/synthetic/hierarchy.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/synthetic/initialization.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/synthetic/plotting.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/synthetic/training.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.30.1 → sae_lens-6.31.0}/sae_lens/util.py +0 -0
|
@@ -82,6 +82,7 @@ class LoggingConfig:
|
|
|
82
82
|
log_to_wandb: bool = True
|
|
83
83
|
log_activations_store_to_wandb: bool = False
|
|
84
84
|
log_optimizer_state_to_wandb: bool = False
|
|
85
|
+
log_weights_to_wandb: bool = True
|
|
85
86
|
wandb_project: str = "sae_lens_training"
|
|
86
87
|
wandb_id: str | None = None
|
|
87
88
|
run_name: str | None = None
|
|
@@ -107,7 +108,8 @@ class LoggingConfig:
|
|
|
107
108
|
type="model",
|
|
108
109
|
metadata=dict(trainer.cfg.__dict__),
|
|
109
110
|
)
|
|
110
|
-
|
|
111
|
+
if self.log_weights_to_wandb:
|
|
112
|
+
model_artifact.add_file(str(weights_path))
|
|
111
113
|
model_artifact.add_file(str(cfg_path))
|
|
112
114
|
wandb.log_artifact(model_artifact, aliases=wandb_aliases)
|
|
113
115
|
|
|
@@ -557,6 +559,12 @@ class CacheActivationsRunnerConfig:
|
|
|
557
559
|
context_size=self.context_size,
|
|
558
560
|
)
|
|
559
561
|
|
|
562
|
+
if self.context_size > self.training_tokens:
|
|
563
|
+
raise ValueError(
|
|
564
|
+
f"context_size ({self.context_size}) is greater than training_tokens "
|
|
565
|
+
f"({self.training_tokens}). Please reduce context_size or increase training_tokens."
|
|
566
|
+
)
|
|
567
|
+
|
|
560
568
|
if self.new_cached_activations_path is None:
|
|
561
569
|
self.new_cached_activations_path = _default_cached_activations_path( # type: ignore
|
|
562
570
|
self.dataset_path, self.model_name, self.hook_name, None
|
|
@@ -335,7 +335,7 @@ def get_downstream_reconstruction_metrics(
|
|
|
335
335
|
|
|
336
336
|
batch_iter = range(n_batches)
|
|
337
337
|
if verbose:
|
|
338
|
-
batch_iter = tqdm(batch_iter, desc="Reconstruction Batches")
|
|
338
|
+
batch_iter = tqdm(batch_iter, desc="Reconstruction Batches", leave=False)
|
|
339
339
|
|
|
340
340
|
for _ in batch_iter:
|
|
341
341
|
batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
|
|
@@ -430,7 +430,7 @@ def get_sparsity_and_variance_metrics(
|
|
|
430
430
|
|
|
431
431
|
batch_iter = range(n_batches)
|
|
432
432
|
if verbose:
|
|
433
|
-
batch_iter = tqdm(batch_iter, desc="Sparsity and Variance Batches")
|
|
433
|
+
batch_iter = tqdm(batch_iter, desc="Sparsity and Variance Batches", leave=False)
|
|
434
434
|
|
|
435
435
|
for _ in batch_iter:
|
|
436
436
|
batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
|
|
@@ -28,7 +28,9 @@ class ActivationScaler:
|
|
|
28
28
|
) -> float:
|
|
29
29
|
norms_per_batch: list[float] = []
|
|
30
30
|
for _ in tqdm(
|
|
31
|
-
range(n_batches_for_norm_estimate),
|
|
31
|
+
range(n_batches_for_norm_estimate),
|
|
32
|
+
desc="Estimating norm scaling factor",
|
|
33
|
+
leave=False,
|
|
32
34
|
):
|
|
33
35
|
acts = next(data_provider)
|
|
34
36
|
norms_per_batch.append(acts.norm(dim=-1).mean().item())
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|