sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +55 -18
- sae_lens/analysis/hooked_sae_transformer.py +10 -10
- sae_lens/analysis/neuronpedia_integration.py +13 -11
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +105 -235
- sae_lens/constants.py +20 -0
- sae_lens/evals.py +34 -31
- sae_lens/{sae_training_runner.py → llm_sae_training_runner.py} +103 -70
- sae_lens/load_model.py +53 -5
- sae_lens/loading/pretrained_sae_loaders.py +36 -10
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +70 -59
- sae_lens/saes/jumprelu_sae.py +58 -72
- sae_lens/saes/sae.py +248 -273
- sae_lens/saes/standard_sae.py +75 -57
- sae_lens/saes/topk_sae.py +72 -83
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +105 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +134 -158
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +11 -5
- sae_lens/util.py +47 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc3.dist-info/RECORD +38 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/WHEEL +1 -1
- sae_lens/regsitry.py +0 -34
- sae_lens-6.0.0rc1.dist-info/RECORD +0 -32
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/LICENSE +0 -0
sae_lens/training/optim.py
CHANGED
@@ -101,61 +101,85 @@ def _get_main_lr_scheduler(
     raise ValueError(f"Unsupported scheduler: {scheduler_name}")


-class
+class CoefficientScheduler:
+    """Linearly warms up a scalar value from 0.0 to a final value."""
+
     def __init__(
         self,
-
-
-        final_l1_coefficient: float,
+        warm_up_steps: float,
+        final_value: float,
     ):
-        self.
-
-        if self.l1_warmup_steps != 0:
-            self.current_l1_coefficient = 0.0
-        else:
-            self.current_l1_coefficient = final_l1_coefficient
-
-        self.final_l1_coefficient = final_l1_coefficient
-
+        self.warm_up_steps = warm_up_steps
+        self.final_value = final_value
         self.current_step = 0
-
-        if not isinstance(self.
+
+        if not isinstance(self.final_value, (float, int)):
             raise TypeError(
-                f"
+                f"final_value must be float or int, got {type(self.final_value)}."
             )

+        # Initialize current_value based on whether warm-up is used
+        if self.warm_up_steps > 0:
+            self.current_value = 0.0
+        else:
+            self.current_value = self.final_value
+
     def __repr__(self) -> str:
         return (
-            f"
-            f"
-            f"total_steps={self.total_steps})"
+            f"{self.__class__.__name__}(final_value={self.final_value}, "
+            f"warm_up_steps={self.warm_up_steps})"
         )

-    def step(self):
+    def step(self) -> float:
         """
-        Updates the
+        Updates the scalar value based on the current step.
+
+        Returns:
+            The current scalar value after the step.
         """
-
-
-
-
-            ) # type: ignore
+        if self.current_step < self.warm_up_steps:
+            self.current_value = self.final_value * (
+                (self.current_step + 1) / self.warm_up_steps
+            )
         else:
-
+            # Ensure the value stays at final_value after warm-up
+            self.current_value = self.final_value

         self.current_step += 1
+        return self.current_value

-
-
+    @property
+    def value(self) -> float:
+        """Returns the current scalar value."""
+        return self.current_value
+
+    def state_dict(self) -> dict[str, Any]:
+        """State dict for serialization."""
         return {
-            "
-            "
-            "current_l1_coefficient": self.current_l1_coefficient,
-            "final_l1_coefficient": self.final_l1_coefficient,
+            "warm_up_steps": self.warm_up_steps,
+            "final_value": self.final_value,
             "current_step": self.current_step,
+            "current_value": self.current_value,
         }

     def load_state_dict(self, state_dict: dict[str, Any]):
-        """Loads
-
-
+        """Loads the scheduler state."""
+        self.warm_up_steps = state_dict["warm_up_steps"]
+        self.final_value = state_dict["final_value"]
+        self.current_step = state_dict["current_step"]
+        # Maintain consistency: re-calculate current_value based on loaded step
+        # This handles resuming correctly if stopped mid-warmup.
+        if self.current_step <= self.warm_up_steps and self.warm_up_steps > 0:
+            # Use max(0, ...) to handle case where current_step might be loaded as -1 or similar before first step
+            step_for_calc = max(0, self.current_step)
+            # Recalculate based on the step *before* the one about to be taken
+            # Or simply use the saved current_value if available and consistent
+            if "current_value" in state_dict:
+                self.current_value = state_dict["current_value"]
+            else:  # Legacy state dicts might not have current_value
+                self.current_value = self.final_value * (
+                    step_for_calc / self.warm_up_steps
+                )
+
+        else:
+            self.current_value = self.final_value
sae_lens/training/sae_trainer.py
CHANGED
@@ -1,26 +1,28 @@
 import contextlib
 from dataclasses import dataclass
-from
+from pathlib import Path
+from typing import Any, Callable, Generic, Protocol

 import torch
 import wandb
+from safetensors.torch import save_file
 from torch.optim import Adam
 from tqdm import tqdm
-from transformer_lens.hook_points import HookedRootModule

 from sae_lens import __version__
-from sae_lens.config import
-from sae_lens.
-from sae_lens.saes.sae import
-
-
-
-
-
-
-
-
-
+from sae_lens.config import SAETrainerConfig
+from sae_lens.constants import ACTIVATION_SCALER_CFG_FILENAME, SPARSITY_FILENAME
+from sae_lens.saes.sae import (
+    T_TRAINING_SAE,
+    T_TRAINING_SAE_CONFIG,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainStepInput,
+    TrainStepOutput,
+)
+from sae_lens.training.activation_scaler import ActivationScaler
+from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler
+from sae_lens.training.types import DataProvider


 def _log_feature_sparsity(
@@ -29,7 +31,7 @@ def _log_feature_sparsity(
     return torch.log10(feature_sparsity + eps).detach().cpu()


-def _update_sae_lens_training_version(sae: TrainingSAE) -> None:
+def _update_sae_lens_training_version(sae: TrainingSAE[Any]) -> None:
     """
     Make sure we record the version of SAELens used for the training run
     """
@@ -38,7 +40,7 @@ def _update_sae_lens_training_version(sae: TrainingSAE) -> None:

 @dataclass
 class TrainSAEOutput:
-    sae: TrainingSAE
+    sae: TrainingSAE[Any]
     checkpoint_path: str
     log_feature_sparsities: torch.Tensor

@@ -46,33 +48,39 @@ class TrainSAEOutput:
 class SaveCheckpointFn(Protocol):
     def __call__(
         self,
-
-        checkpoint_name: str,
-        wandb_aliases: list[str] | None = None,
+        checkpoint_path: Path,
     ) -> None: ...


-
+Evaluator = Callable[[T_TRAINING_SAE, DataProvider, ActivationScaler], dict[str, Any]]
+
+
+class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
     """
     Core SAE class used for inference. For training, see TrainingSAE.
     """

+    data_provider: DataProvider
+    activation_scaler: ActivationScaler
+    evaluator: Evaluator[T_TRAINING_SAE] | None
+
     def __init__(
         self,
-
-        sae:
-
-
-
+        cfg: SAETrainerConfig,
+        sae: T_TRAINING_SAE,
+        data_provider: DataProvider,
+        evaluator: Evaluator[T_TRAINING_SAE] | None = None,
+        save_checkpoint_fn: SaveCheckpointFn | None = None,
     ) -> None:
-        self.model = model
         self.sae = sae
-        self.
-        self.
+        self.data_provider = data_provider
+        self.evaluator = evaluator
+        self.activation_scaler = ActivationScaler()
+        self.save_checkpoint_fn = save_checkpoint_fn
         self.cfg = cfg

         self.n_training_steps: int = 0
-        self.
+        self.n_training_samples: int = 0
         self.started_fine_tuning: bool = False

         _update_sae_lens_training_version(self.sae)
@@ -82,20 +90,16 @@ class SAETrainer:
         self.checkpoint_thresholds = list(
             range(
                 0,
-                cfg.
-                cfg.
+                cfg.total_training_samples,
+                cfg.total_training_samples // self.cfg.n_checkpoints,
             )
         )[1:]

-        self.act_freq_scores = torch.zeros(
-            cast(int, cfg.d_sae),
-            device=cfg.device,
-        )
+        self.act_freq_scores = torch.zeros(sae.cfg.d_sae, device=cfg.device)
         self.n_forward_passes_since_fired = torch.zeros(
-
-            device=cfg.device,
+            sae.cfg.d_sae, device=cfg.device
         )
-        self.
+        self.n_frac_active_samples = 0
         # we don't train the scaling factor (initially)
         # set requires grad to false for the scaling factor
         for name, param in self.sae.named_parameters():
@@ -121,14 +125,17 @@ class SAETrainer:
             lr_end=cfg.lr_end,
             num_cycles=cfg.n_restart_cycles,
         )
-        self.
-
-
-
-
+        self.coefficient_schedulers = {}
+        for name, coeff_cfg in self.sae.get_coefficients().items():
+            if not isinstance(coeff_cfg, TrainCoefficientConfig):
+                coeff_cfg = TrainCoefficientConfig(value=coeff_cfg, warm_up_steps=0)
+            self.coefficient_schedulers[name] = CoefficientScheduler(
+                warm_up_steps=coeff_cfg.warm_up_steps,
+                final_value=coeff_cfg.value,
+            )

         # Setup autocast if using
-        self.
+        self.grad_scaler = torch.amp.GradScaler(
             device=self.cfg.device, enabled=self.cfg.autocast
         )

@@ -141,50 +148,36 @@ class SAETrainer:
         else:
             self.autocast_if_enabled = contextlib.nullcontext()

-        # Set up eval config
-
-        self.trainer_eval_config = EvalConfig(
-            batch_size_prompts=self.cfg.eval_batch_size_prompts,
-            n_eval_reconstruction_batches=self.cfg.n_eval_batches,
-            n_eval_sparsity_variance_batches=self.cfg.n_eval_batches,
-            compute_ce_loss=True,
-            compute_l2_norms=True,
-            compute_sparsity_metrics=True,
-            compute_variance_metrics=True,
-            compute_kl=False,
-            compute_featurewise_weight_based_metrics=False,
-        )
-
     @property
     def feature_sparsity(self) -> torch.Tensor:
-        return self.act_freq_scores / self.
+        return self.act_freq_scores / self.n_frac_active_samples

     @property
     def log_feature_sparsity(self) -> torch.Tensor:
         return _log_feature_sparsity(self.feature_sparsity)

-    @property
-    def current_l1_coefficient(self) -> float:
-        return self.l1_scheduler.current_l1_coefficient
-
     @property
     def dead_neurons(self) -> torch.Tensor:
         return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()

-    def fit(self) ->
-        pbar = tqdm(total=self.cfg.
+    def fit(self) -> T_TRAINING_SAE:
+        pbar = tqdm(total=self.cfg.total_training_samples, desc="Training SAE")

-        self.
+        if self.sae.cfg.normalize_activations == "expected_average_only_in":
+            self.activation_scaler.estimate_scaling_factor(
+                d_in=self.sae.cfg.d_in,
+                data_provider=self.data_provider,
+                n_batches_for_norm_estimate=int(1e3),
+            )

         # Train loop
-        while self.
+        while self.n_training_samples < self.cfg.total_training_samples:
             # Do a training step.
-
-
-            )
-            self.n_training_tokens += self.cfg.train_batch_size_tokens
+            batch = next(self.data_provider).to(self.sae.device)
+            self.n_training_samples += batch.shape[0]
+            scaled_batch = self.activation_scaler(batch)

-            step_output = self._train_step(sae=self.sae, sae_in=
+            step_output = self._train_step(sae=self.sae, sae_in=scaled_batch)

             if self.cfg.logger.log_to_wandb:
                 self._log_train_step(step_output)
@@ -194,35 +187,56 @@ class SAETrainer:
             self.n_training_steps += 1
             self._update_pbar(step_output, pbar)

-            ### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already)
-            self._begin_finetuning_if_needed()
-
         # fold the estimated norm scaling factor into the sae weights
-        if self.
+        if self.activation_scaler.scaling_factor is not None:
             self.sae.fold_activation_norm_scaling_factor(
-                self.
+                self.activation_scaler.scaling_factor
             )
-            self.
+            self.activation_scaler.scaling_factor = None

         # save final sae group to checkpoints folder
         self.save_checkpoint(
-
-            checkpoint_name=f"final_{self.n_training_tokens}",
+            checkpoint_name=f"final_{self.n_training_samples}",
             wandb_aliases=["final_model"],
         )

         pbar.close()
         return self.sae

+    def save_checkpoint(
+        self,
+        checkpoint_name: str,
+        wandb_aliases: list[str] | None = None,
+    ) -> None:
+        checkpoint_path = Path(self.cfg.checkpoint_path) / checkpoint_name
+        checkpoint_path.mkdir(exist_ok=True, parents=True)
+
+        weights_path, cfg_path = self.sae.save_model(str(checkpoint_path))
+
+        sparsity_path = checkpoint_path / SPARSITY_FILENAME
+        save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)
+
+        activation_scaler_path = checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
+        self.activation_scaler.save(str(activation_scaler_path))
+
+        if self.cfg.logger.log_to_wandb:
+            self.cfg.logger.log(
+                self,
+                weights_path,
+                cfg_path,
+                sparsity_path=sparsity_path,
+                wandb_aliases=wandb_aliases,
+            )
+
+        if self.save_checkpoint_fn is not None:
+            self.save_checkpoint_fn(checkpoint_path=checkpoint_path)
+
     def _train_step(
         self,
-        sae:
+        sae: T_TRAINING_SAE,
         sae_in: torch.Tensor,
     ) -> TrainStepOutput:
         sae.train()
-        # Make sure the W_dec is still zero-norm
-        if self.cfg.normalize_sae_decoder:
-            sae.set_decoder_norm_to_unit_norm()

         # log and then reset the feature sparsity every feature_sampling_window steps
         if (self.n_training_steps + 1) % self.cfg.feature_sampling_window == 0:
@@ -238,7 +252,7 @@ class SAETrainer:
             step_input=TrainStepInput(
                 sae_in=sae_in,
                 dead_neuron_mask=self.dead_neurons,
-
+                coefficients=self.get_coefficients(),
             ),
         )

@@ -249,24 +263,24 @@ class SAETrainer:
         self.act_freq_scores += (
             (train_step_output.feature_acts.abs() > 0).float().sum(0)
         )
-        self.
+        self.n_frac_active_samples += self.cfg.train_batch_size_samples

-        #
-        self.
+        # Grad scaler will rescale gradients if autocast is enabled
+        self.grad_scaler.scale(
             train_step_output.loss
         ).backward() # loss.backward() if not autocasting
-        self.
+        self.grad_scaler.unscale_(self.optimizer) # needed to clip correctly
         # TODO: Work out if grad norm clipping should be in config / how to test it.
         torch.nn.utils.clip_grad_norm_(sae.parameters(), 1.0)
-        self.
-
-
-
-            sae.remove_gradient_parallel_to_decoder_directions()
+        self.grad_scaler.step(
+            self.optimizer
+        ) # just ctx.optimizer.step() if not autocasting
+        self.grad_scaler.update()

         self.optimizer.zero_grad()
         self.lr_scheduler.step()
-        self.
+        for scheduler in self.coefficient_schedulers.values():
+            scheduler.step()

         return train_step_output

@@ -276,16 +290,23 @@ class SAETrainer:
         wandb.log(
             self._build_train_step_log_dict(
                 output=step_output,
-
+                n_training_samples=self.n_training_samples,
             ),
             step=self.n_training_steps,
         )

+    @torch.no_grad()
+    def get_coefficients(self) -> dict[str, float]:
+        return {
+            name: scheduler.value
+            for name, scheduler in self.coefficient_schedulers.items()
+        }
+
     @torch.no_grad()
     def _build_train_step_log_dict(
         self,
         output: TrainStepOutput,
-
+        n_training_samples: int,
     ) -> dict[str, Any]:
         sae_in = output.sae_in
         sae_out = output.sae_out
@@ -313,19 +334,15 @@ class SAETrainer:
             "sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(),
             "sparsity/dead_features": self.dead_neurons.sum().item(),
             "details/current_learning_rate": current_learning_rate,
-            "details/
-
+            "details/n_training_samples": n_training_samples,
+            **{
+                f"details/{name}_coefficient": scheduler.value
+                for name, scheduler in self.coefficient_schedulers.items()
+            },
         }
         for loss_name, loss_value in output.losses.items():
             loss_item = _unwrap_item(loss_value)
-
-            if loss_name == "l1_loss":
-                log_dict[f"losses/{loss_name}"] = (
-                    loss_item / self.current_l1_coefficient
-                )
-                log_dict[f"losses/raw_{loss_name}"] = loss_item
-            else:
-                log_dict[f"losses/{loss_name}"] = loss_item
+            log_dict[f"losses/{loss_name}"] = loss_item

         return log_dict

@@ -337,30 +354,11 @@ class SAETrainer:
             * self.cfg.logger.eval_every_n_wandb_logs
         ) == 0:
             self.sae.eval()
-
-
-
-
-
-            eval_metrics, _ = run_evals(
-                sae=self.sae,
-                activation_store=self.activations_store,
-                model=self.model,
-                eval_config=self.trainer_eval_config,
-                ignore_tokens=ignore_tokens,
-                model_kwargs=self.cfg.model_kwargs,
-            ) # not calculating featurwise metrics here.
-
-            # Remove eval metrics that are already logged during training
-            eval_metrics.pop("metrics/explained_variance", None)
-            eval_metrics.pop("metrics/explained_variance_std", None)
-            eval_metrics.pop("metrics/l0", None)
-            eval_metrics.pop("metrics/l1", None)
-            eval_metrics.pop("metrics/mse", None)
-
-            # Remove metrics that are not useful for wandb logging
-            eval_metrics.pop("metrics/total_tokens_evaluated", None)
-
+            eval_metrics = (
+                self.evaluator(self.sae, self.data_provider, self.activation_scaler)
+                if self.evaluator is not None
+                else {}
+            )
             for key, value in self.sae.log_histograms().items():
                 eval_metrics[key] = wandb.Histogram(value) # type: ignore

@@ -384,21 +382,18 @@ class SAETrainer:
     @torch.no_grad()
     def _reset_running_sparsity_stats(self) -> None:
         self.act_freq_scores = torch.zeros(
-            self.cfg.d_sae, # type: ignore
+            self.sae.cfg.d_sae, # type: ignore
             device=self.cfg.device,
         )
-        self.
+        self.n_frac_active_samples = 0

     @torch.no_grad()
     def _checkpoint_if_needed(self):
         if (
             self.checkpoint_thresholds
-            and self.
+            and self.n_training_samples > self.checkpoint_thresholds[0]
         ):
-            self.save_checkpoint(
-                trainer=self,
-                checkpoint_name=str(self.n_training_tokens),
-            )
+            self.save_checkpoint(checkpoint_name=str(self.n_training_samples))
             self.checkpoint_thresholds.pop(0)

     @torch.no_grad()
@@ -414,26 +409,7 @@ class SAETrainer:
             for loss_name, loss_value in step_output.losses.items()
         )
         pbar.set_description(f"{self.n_training_steps}| {loss_strs}")
-        pbar.update(update_interval * self.cfg.
-
-    def _begin_finetuning_if_needed(self):
-        if (not self.started_fine_tuning) and (
-            self.n_training_tokens > self.cfg.training_tokens
-        ):
-            self.started_fine_tuning = True
-
-            # finetuning method should be set in the config
-            # if not, then we don't finetune
-            if not isinstance(self.cfg.finetuning_method, str):
-                return
-
-            for name, param in self.sae.named_parameters():
-                if name in FINETUNING_PARAMETERS[self.cfg.finetuning_method]:
-                    param.requires_grad = True
-                else:
-                    param.requires_grad = False
-
-            self.finetuning = True
+        pbar.update(update_interval * self.cfg.train_batch_size_samples)


 def _unwrap_item(item: float | torch.Tensor) -> float:
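The trainer no longer owns a model or an activations store; it pulls activation batches from a DataProvider and delegates evaluation and checkpoint notification to callables. A rough sketch of the new wiring, where `cfg` (an SAETrainerConfig) and `sae` (a TrainingSAE subclass) are assumed to already exist, since their construction is not part of this diff:

```python
from pathlib import Path

import torch

from sae_lens.training.sae_trainer import SAETrainer

# `cfg` and `sae` are placeholders for objects built elsewhere.

def activation_batches():
    # A DataProvider is an iterator yielding [batch, d_in] activation tensors.
    while True:
        yield torch.randn(cfg.train_batch_size_samples, sae.cfg.d_in)

def on_checkpoint(checkpoint_path: Path) -> None:
    # Matches the new SaveCheckpointFn protocol: called with the directory
    # that save_checkpoint() just wrote.
    print(f"checkpoint written to {checkpoint_path}")

trainer = SAETrainer(
    cfg=cfg,
    sae=sae,
    data_provider=activation_batches(),
    evaluator=None,  # or a callable (sae, data_provider, activation_scaler) -> dict
    save_checkpoint_fn=on_checkpoint,
)
trained_sae = trainer.fit()
```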
sae_lens/training/upload_saes_to_huggingface.py
CHANGED
@@ -2,14 +2,15 @@ import io
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from textwrap import dedent
-from typing import Iterable
+from typing import Any, Iterable

 from huggingface_hub import HfApi, create_repo, get_hf_file_metadata, hf_hub_url
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
 from tqdm.autonotebook import tqdm

 from sae_lens import logger
-from sae_lens.
+from sae_lens.constants import (
+    RUNNER_CFG_FILENAME,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
     SPARSITY_FILENAME,
@@ -18,7 +19,7 @@ from sae_lens.saes.sae import SAE


 def upload_saes_to_huggingface(
-    saes_dict: dict[str, SAE | Path | str],
+    saes_dict: dict[str, SAE[Any] | Path | str],
     hf_repo_id: str,
     hf_revision: str = "main",
     show_progress: bool = True,
@@ -119,11 +120,16 @@ def _upload_sae(api: HfApi, sae_path: Path, repo_id: str, sae_id: str, revision:
         revision=revision,
         repo_type="model",
         commit_message=f"Upload SAE {sae_id}",
-        allow_patterns=[
+        allow_patterns=[
+            SAE_CFG_FILENAME,
+            SAE_WEIGHTS_FILENAME,
+            SPARSITY_FILENAME,
+            RUNNER_CFG_FILENAME,
+        ],
    )


-def _build_sae_path(sae_ref: SAE | Path | str, tmp_dir: str) -> Path:
+def _build_sae_path(sae_ref: SAE[Any] | Path | str, tmp_dir: str) -> Path:
     if isinstance(sae_ref, SAE):
         sae_ref.save_model(tmp_dir)
     return Path(tmp_dir)