sae-lens 5.10.3__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +56 -6
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +13 -11
- sae_lens/cache_activations_runner.py +2 -1
- sae_lens/config.py +121 -252
- sae_lens/constants.py +18 -0
- sae_lens/evals.py +32 -17
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +68 -36
- sae_lens/pretrained_saes.yaml +0 -12
- sae_lens/registry.py +49 -0
- sae_lens/sae_training_runner.py +40 -54
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +258 -0
- sae_lens/saes/jumprelu_sae.py +354 -0
- sae_lens/saes/sae.py +948 -0
- sae_lens/saes/standard_sae.py +185 -0
- sae_lens/saes/topk_sae.py +294 -0
- sae_lens/training/activations_store.py +32 -16
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +55 -86
- sae_lens/training/upload_saes_to_huggingface.py +12 -6
- sae_lens/util.py +28 -0
- {sae_lens-5.10.3.dist-info → sae_lens-6.0.0rc2.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc2.dist-info/RECORD +35 -0
- sae_lens/sae.py +0 -747
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.10.3.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.10.3.dist-info → sae_lens-6.0.0rc2.dist-info}/LICENSE +0 -0
- {sae_lens-5.10.3.dist-info → sae_lens-6.0.0rc2.dist-info}/WHEEL +0 -0
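The file moves above imply import-path changes for downstream code. A minimal migration sketch, assuming the importable module paths mirror the new file layout shown in this diff (the `sae_lens.saes.sae` path is confirmed by import changes further down; the `sae_lens.loading` path is inferred from the `toolkit → loading` move):

# Migration sketch; paths inferred from the file moves listed in this diff.
# 5.10.3-era imports:
#   from sae_lens.sae import SAE
#   from sae_lens.toolkit import pretrained_sae_loaders
# 6.0.0rc2-era equivalents:
from sae_lens.saes.sae import SAE
from sae_lens.loading import pretrained_sae_loaders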
sae_lens/training/optim.py
CHANGED
@@ -101,61 +101,85 @@ def _get_main_lr_scheduler(
     raise ValueError(f"Unsupported scheduler: {scheduler_name}")


-class
+class CoefficientScheduler:
+    """Linearly warms up a scalar value from 0.0 to a final value."""
+
     def __init__(
         self,
-
-
-        final_l1_coefficient: float,
+        warm_up_steps: float,
+        final_value: float,
     ):
-        self.
-
-        if self.l1_warmup_steps != 0:
-            self.current_l1_coefficient = 0.0
-        else:
-            self.current_l1_coefficient = final_l1_coefficient
-
-        self.final_l1_coefficient = final_l1_coefficient
-
+        self.warm_up_steps = warm_up_steps
+        self.final_value = final_value
         self.current_step = 0
-
-        if not isinstance(self.
+
+        if not isinstance(self.final_value, (float, int)):
             raise TypeError(
-                f"
+                f"final_value must be float or int, got {type(self.final_value)}."
             )

+        # Initialize current_value based on whether warm-up is used
+        if self.warm_up_steps > 0:
+            self.current_value = 0.0
+        else:
+            self.current_value = self.final_value
+
     def __repr__(self) -> str:
         return (
-            f"
-            f"
-            f"total_steps={self.total_steps})"
+            f"{self.__class__.__name__}(final_value={self.final_value}, "
+            f"warm_up_steps={self.warm_up_steps})"
         )

-    def step(self):
+    def step(self) -> float:
         """
-        Updates the
+        Updates the scalar value based on the current step.
+
+        Returns:
+            The current scalar value after the step.
         """
-
-
-
-
-            )  # type: ignore
+        if self.current_step < self.warm_up_steps:
+            self.current_value = self.final_value * (
+                (self.current_step + 1) / self.warm_up_steps
+            )
         else:
-
+            # Ensure the value stays at final_value after warm-up
+            self.current_value = self.final_value

         self.current_step += 1
+        return self.current_value

-
-
+    @property
+    def value(self) -> float:
+        """Returns the current scalar value."""
+        return self.current_value
+
+    def state_dict(self) -> dict[str, Any]:
+        """State dict for serialization."""
         return {
-            "
-            "
-            "current_l1_coefficient": self.current_l1_coefficient,
-            "final_l1_coefficient": self.final_l1_coefficient,
+            "warm_up_steps": self.warm_up_steps,
+            "final_value": self.final_value,
             "current_step": self.current_step,
+            "current_value": self.current_value,
         }

     def load_state_dict(self, state_dict: dict[str, Any]):
-        """Loads
-
-
+        """Loads the scheduler state."""
+        self.warm_up_steps = state_dict["warm_up_steps"]
+        self.final_value = state_dict["final_value"]
+        self.current_step = state_dict["current_step"]
+        # Maintain consistency: re-calculate current_value based on loaded step
+        # This handles resuming correctly if stopped mid-warmup.
+        if self.current_step <= self.warm_up_steps and self.warm_up_steps > 0:
+            # Use max(0, ...) to handle case where current_step might be loaded as -1 or similar before first step
+            step_for_calc = max(0, self.current_step)
+            # Recalculate based on the step *before* the one about to be taken
+            # Or simply use the saved current_value if available and consistent
+            if "current_value" in state_dict:
+                self.current_value = state_dict["current_value"]
+            else:  # Legacy state dicts might not have current_value
+                self.current_value = self.final_value * (
+                    step_for_calc / self.warm_up_steps
+                )
+
+        else:
+            self.current_value = self.final_value
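A minimal usage sketch of the new CoefficientScheduler, based only on the API added above (constructor, step(), .value, and the state-dict round trip); the numbers are illustrative:

from sae_lens.training.optim import CoefficientScheduler

# Warm a sparsity coefficient up linearly from 0.0 to 5.0 over 4 steps.
scheduler = CoefficientScheduler(warm_up_steps=4, final_value=5.0)
values = [scheduler.step() for _ in range(6)]
# values == [1.25, 2.5, 3.75, 5.0, 5.0, 5.0]

# The state-dict round trip supports resuming from a checkpoint.
restored = CoefficientScheduler(warm_up_steps=4, final_value=5.0)
restored.load_state_dict(scheduler.state_dict())
assert restored.value == scheduler.value == 5.0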
sae_lens/training/sae_trainer.py
CHANGED
@@ -1,6 +1,6 @@
 import contextlib
 from dataclasses import dataclass
-from typing import Any, Protocol, cast
+from typing import Any, Generic, Protocol, cast

 import torch
 import wandb
@@ -11,16 +11,16 @@ from transformer_lens.hook_points import HookedRootModule
 from sae_lens import __version__
 from sae_lens.config import LanguageModelSAERunnerConfig
 from sae_lens.evals import EvalConfig, run_evals
+from sae_lens.saes.sae import (
+    T_TRAINING_SAE,
+    T_TRAINING_SAE_CONFIG,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainStepInput,
+    TrainStepOutput,
+)
 from sae_lens.training.activations_store import ActivationsStore
-from sae_lens.training.optim import
-from sae_lens.training.training_sae import TrainingSAE, TrainStepOutput
-
-# used to map between parameters which are updated during finetuning and the config str.
-FINETUNING_PARAMETERS = {
-    "scale": ["scaling_factor"],
-    "decoder": ["scaling_factor", "W_dec", "b_dec"],
-    "unrotated_decoder": ["scaling_factor", "b_dec"],
-}
+from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler


 def _log_feature_sparsity(
@@ -29,7 +29,7 @@ def _log_feature_sparsity(
     return torch.log10(feature_sparsity + eps).detach().cpu()


-def _update_sae_lens_training_version(sae: TrainingSAE) -> None:
+def _update_sae_lens_training_version(sae: TrainingSAE[Any]) -> None:
     """
     Make sure we record the version of SAELens used for the training run
     """
@@ -38,7 +38,7 @@ def _update_sae_lens_training_version(sae: TrainingSAE) -> None:

 @dataclass
 class TrainSAEOutput:
-    sae: TrainingSAE
+    sae: TrainingSAE[Any]
     checkpoint_path: str
     log_feature_sparsities: torch.Tensor

@@ -46,13 +46,13 @@ class TrainSAEOutput:
 class SaveCheckpointFn(Protocol):
     def __call__(
         self,
-        trainer: "SAETrainer",
+        trainer: "SAETrainer[Any, Any]",
         checkpoint_name: str,
         wandb_aliases: list[str] | None = None,
     ) -> None: ...


-class SAETrainer:
+class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
     """
     Core SAE class used for inference. For training, see TrainingSAE.
     """
@@ -60,10 +60,10 @@ class SAETrainer:
     def __init__(
         self,
         model: HookedRootModule,
-        sae:
+        sae: T_TRAINING_SAE,
         activation_store: ActivationsStore,
         save_checkpoint_fn: SaveCheckpointFn,
-        cfg: LanguageModelSAERunnerConfig,
+        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
     ) -> None:
         self.model = model
         self.sae = sae
@@ -88,11 +88,11 @@ class SAETrainer:
         )[1:]

         self.act_freq_scores = torch.zeros(
-            cast(int, cfg.d_sae),
+            cast(int, cfg.sae.d_sae),
             device=cfg.device,
         )
         self.n_forward_passes_since_fired = torch.zeros(
-            cast(int, cfg.d_sae),
+            cast(int, cfg.sae.d_sae),
             device=cfg.device,
         )
         self.n_frac_active_tokens = 0
@@ -121,11 +121,14 @@ class SAETrainer:
             lr_end=cfg.lr_end,
             num_cycles=cfg.n_restart_cycles,
         )
-        self.
-
-
-
-
+        self.coefficient_schedulers = {}
+        for name, coeff_cfg in self.sae.get_coefficients().items():
+            if not isinstance(coeff_cfg, TrainCoefficientConfig):
+                coeff_cfg = TrainCoefficientConfig(value=coeff_cfg, warm_up_steps=0)
+            self.coefficient_schedulers[name] = CoefficientScheduler(
+                warm_up_steps=coeff_cfg.warm_up_steps,
+                final_value=coeff_cfg.value,
+            )

         # Setup autocast if using
         self.scaler = torch.amp.GradScaler(
@@ -163,15 +166,11 @@ class SAETrainer:
     def log_feature_sparsity(self) -> torch.Tensor:
         return _log_feature_sparsity(self.feature_sparsity)

-    @property
-    def current_l1_coefficient(self) -> float:
-        return self.l1_scheduler.current_l1_coefficient
-
     @property
     def dead_neurons(self) -> torch.Tensor:
         return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()

-    def fit(self) ->
+    def fit(self) -> T_TRAINING_SAE:
         pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE")

         self.activations_store.set_norm_scaling_factor_if_needed()
@@ -186,7 +185,7 @@ class SAETrainer:

             step_output = self._train_step(sae=self.sae, sae_in=layer_acts)

-            if self.cfg.log_to_wandb:
+            if self.cfg.logger.log_to_wandb:
                 self._log_train_step(step_output)
                 self._run_and_log_evals()

@@ -194,9 +193,6 @@ class SAETrainer:
             self.n_training_steps += 1
             self._update_pbar(step_output, pbar)

-            ### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already)
-            self._begin_finetuning_if_needed()
-
         # fold the estimated norm scaling factor into the sae weights
         if self.activations_store.estimated_norm_scaling_factor is not None:
             self.sae.fold_activation_norm_scaling_factor(
@@ -216,17 +212,14 @@ class SAETrainer:

     def _train_step(
         self,
-        sae:
+        sae: T_TRAINING_SAE,
         sae_in: torch.Tensor,
     ) -> TrainStepOutput:
         sae.train()
-        # Make sure the W_dec is still zero-norm
-        if self.cfg.normalize_sae_decoder:
-            sae.set_decoder_norm_to_unit_norm()

         # log and then reset the feature sparsity every feature_sampling_window steps
         if (self.n_training_steps + 1) % self.cfg.feature_sampling_window == 0:
-            if self.cfg.log_to_wandb:
+            if self.cfg.logger.log_to_wandb:
                 sparsity_log_dict = self._build_sparsity_log_dict()
                 wandb.log(sparsity_log_dict, step=self.n_training_steps)
             self._reset_running_sparsity_stats()
@@ -235,9 +228,11 @@ class SAETrainer:
         # https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
         with self.autocast_if_enabled:
             train_step_output = self.sae.training_forward_pass(
-
-
-
+                step_input=TrainStepInput(
+                    sae_in=sae_in,
+                    dead_neuron_mask=self.dead_neurons,
+                    coefficients=self.get_coefficients(),
+                ),
             )

         with torch.no_grad():
@@ -259,18 +254,16 @@ class SAETrainer:
         self.scaler.step(self.optimizer)  # just ctx.optimizer.step() if not autocasting
         self.scaler.update()

-        if self.cfg.normalize_sae_decoder:
-            sae.remove_gradient_parallel_to_decoder_directions()
-
         self.optimizer.zero_grad()
         self.lr_scheduler.step()
-        self.
+        for scheduler in self.coefficient_schedulers.values():
+            scheduler.step()

         return train_step_output

     @torch.no_grad()
     def _log_train_step(self, step_output: TrainStepOutput):
-        if (self.n_training_steps + 1) % self.cfg.wandb_log_frequency == 0:
+        if (self.n_training_steps + 1) % self.cfg.logger.wandb_log_frequency == 0:
             wandb.log(
                 self._build_train_step_log_dict(
                     output=step_output,
@@ -279,6 +272,13 @@ class SAETrainer:
                 step=self.n_training_steps,
             )

+    @torch.no_grad()
+    def get_coefficients(self) -> dict[str, float]:
+        return {
+            name: scheduler.value
+            for name, scheduler in self.coefficient_schedulers.items()
+        }
+
     @torch.no_grad()
     def _build_train_step_log_dict(
         self,
@@ -311,19 +311,15 @@ class SAETrainer:
             "sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(),
             "sparsity/dead_features": self.dead_neurons.sum().item(),
             "details/current_learning_rate": current_learning_rate,
-            "details/current_l1_coefficient": self.current_l1_coefficient,
             "details/n_training_tokens": n_training_tokens,
+            **{
+                f"details/{name}_coefficient": scheduler.value
+                for name, scheduler in self.coefficient_schedulers.items()
+            },
         }
         for loss_name, loss_value in output.losses.items():
             loss_item = _unwrap_item(loss_value)
-
-            if loss_name == "l1_loss":
-                log_dict[f"losses/{loss_name}"] = (
-                    loss_item / self.current_l1_coefficient
-                )
-                log_dict[f"losses/raw_{loss_name}"] = loss_item
-            else:
-                log_dict[f"losses/{loss_name}"] = loss_item
+            log_dict[f"losses/{loss_name}"] = loss_item

         return log_dict

@@ -331,7 +327,8 @@ class SAETrainer:
     def _run_and_log_evals(self):
         # record loss frequently, but not all the time.
         if (self.n_training_steps + 1) % (
-            self.cfg.wandb_log_frequency
+            self.cfg.logger.wandb_log_frequency
+            * self.cfg.logger.eval_every_n_wandb_logs
         ) == 0:
             self.sae.eval()
             ignore_tokens = set()
@@ -358,17 +355,8 @@ class SAETrainer:
             # Remove metrics that are not useful for wandb logging
             eval_metrics.pop("metrics/total_tokens_evaluated", None)

-
-
-
-            if self.sae.cfg.architecture == "standard":
-                b_e_dist = self.sae.b_enc.detach().float().cpu().numpy()
-                eval_metrics["weights/b_e"] = wandb.Histogram(b_e_dist)  # type: ignore
-            elif self.sae.cfg.architecture == "gated":
-                b_gate_dist = self.sae.b_gate.detach().float().cpu().numpy()
-                eval_metrics["weights/b_gate"] = wandb.Histogram(b_gate_dist)  # type: ignore
-                b_mag_dist = self.sae.b_mag.detach().float().cpu().numpy()
-                eval_metrics["weights/b_mag"] = wandb.Histogram(b_mag_dist)  # type: ignore
+            for key, value in self.sae.log_histograms().items():
+                eval_metrics[key] = wandb.Histogram(value)  # type: ignore

             wandb.log(
                 eval_metrics,
@@ -390,7 +378,7 @@ class SAETrainer:
     @torch.no_grad()
     def _reset_running_sparsity_stats(self) -> None:
         self.act_freq_scores = torch.zeros(
-            self.cfg.d_sae,  # type: ignore
+            self.cfg.sae.d_sae,  # type: ignore
             device=self.cfg.device,
         )
         self.n_frac_active_tokens = 0
@@ -422,25 +410,6 @@ class SAETrainer:
         pbar.set_description(f"{self.n_training_steps}| {loss_strs}")
         pbar.update(update_interval * self.cfg.train_batch_size_tokens)

-    def _begin_finetuning_if_needed(self):
-        if (not self.started_fine_tuning) and (
-            self.n_training_tokens > self.cfg.training_tokens
-        ):
-            self.started_fine_tuning = True
-
-            # finetuning method should be set in the config
-            # if not, then we don't finetune
-            if not isinstance(self.cfg.finetuning_method, str):
-                return
-
-            for name, param in self.sae.named_parameters():
-                if name in FINETUNING_PARAMETERS[self.cfg.finetuning_method]:
-                    param.requires_grad = True
-                else:
-                    param.requires_grad = False
-
-            self.finetuning = True
-

 def _unwrap_item(item: float | torch.Tensor) -> float:
     return item.item() if isinstance(item, torch.Tensor) else item
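A sketch of how the trainer's single L1 scheduler is replaced by a dict of per-coefficient schedulers, mirroring the __init__ and get_coefficients changes above. The "l1" key and the numbers are illustrative stand-ins for whatever sae.get_coefficients() reports for a given architecture:

from sae_lens.saes.sae import TrainCoefficientConfig
from sae_lens.training.optim import CoefficientScheduler

# Stand-in for sae.get_coefficients(); real SAEs report their own coefficient names.
coefficient_cfgs = {"l1": TrainCoefficientConfig(value=5.0, warm_up_steps=1000)}

schedulers = {
    name: CoefficientScheduler(warm_up_steps=cfg.warm_up_steps, final_value=cfg.value)
    for name, cfg in coefficient_cfgs.items()
}

# Each step: read the current values (the trainer passes these to the SAE via
# TrainStepInput(coefficients=...)), then advance every scheduler after the optimizer step.
current_coefficients = {name: s.value for name, s in schedulers.items()}
for s in schedulers.values():
    s.step()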
sae_lens/training/upload_saes_to_huggingface.py
CHANGED
@@ -2,23 +2,24 @@ import io
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from textwrap import dedent
-from typing import Iterable
+from typing import Any, Iterable

 from huggingface_hub import HfApi, create_repo, get_hf_file_metadata, hf_hub_url
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
 from tqdm.autonotebook import tqdm

 from sae_lens import logger
-from sae_lens.
+from sae_lens.constants import (
+    RUNNER_CFG_FILENAME,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
     SPARSITY_FILENAME,
 )
-from sae_lens.sae import SAE
+from sae_lens.saes.sae import SAE


 def upload_saes_to_huggingface(
-    saes_dict: dict[str, SAE | Path | str],
+    saes_dict: dict[str, SAE[Any] | Path | str],
     hf_repo_id: str,
     hf_revision: str = "main",
     show_progress: bool = True,
@@ -119,11 +120,16 @@ def _upload_sae(api: HfApi, sae_path: Path, repo_id: str, sae_id: str, revision:
         revision=revision,
         repo_type="model",
         commit_message=f"Upload SAE {sae_id}",
-        allow_patterns=[
+        allow_patterns=[
+            SAE_CFG_FILENAME,
+            SAE_WEIGHTS_FILENAME,
+            SPARSITY_FILENAME,
+            RUNNER_CFG_FILENAME,
+        ],
     )


-def _build_sae_path(sae_ref: SAE | Path | str, tmp_dir: str) -> Path:
+def _build_sae_path(sae_ref: SAE[Any] | Path | str, tmp_dir: str) -> Path:
     if isinstance(sae_ref, SAE):
         sae_ref.save_model(tmp_dir)
         return Path(tmp_dir)
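For reference, a hedged usage sketch of the public entry point whose signature appears above; the repo id, SAE id, and path below are placeholders:

from sae_lens.training.upload_saes_to_huggingface import upload_saes_to_huggingface

upload_saes_to_huggingface(
    saes_dict={
        # Values may be an SAE instance, a Path, or a str pointing at a saved SAE directory.
        "blocks.0.hook_resid_post": "path/to/saved/sae",
    },
    hf_repo_id="your-username/your-sae-repo",
)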
sae_lens/util.py
ADDED
@@ -0,0 +1,28 @@
+from dataclasses import asdict, fields, is_dataclass
+from typing import Sequence, TypeVar
+
+K = TypeVar("K")
+V = TypeVar("V")
+
+
+def filter_valid_dataclass_fields(
+    source: dict[str, V] | object,
+    destination: object | type,
+    whitelist_fields: Sequence[str] | None = None,
+) -> dict[str, V]:
+    """Filter a source dict or dataclass instance to only include fields that are present in the destination dataclass."""
+
+    if not is_dataclass(destination):
+        raise ValueError(f"{destination} is not a dataclass")
+
+    if is_dataclass(source) and not isinstance(source, type):
+        source_dict = asdict(source)
+    elif isinstance(source, dict):
+        source_dict = source
+    else:
+        raise ValueError(f"{source} is not a dict or dataclass")
+
+    valid_field_names = {field.name for field in fields(destination)}
+    if whitelist_fields is not None:
+        valid_field_names = valid_field_names.union(whitelist_fields)
+    return {key: val for key, val in source_dict.items() if key in valid_field_names}
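A quick example of the new helper with a throwaway dataclass (the Destination class and its field names below are invented purely for illustration):

from dataclasses import dataclass

from sae_lens.util import filter_valid_dataclass_fields


@dataclass
class Destination:
    d_in: int = 0
    d_sae: int = 0


source = {"d_in": 512, "d_sae": 4096, "unrelated_option": True}
print(filter_valid_dataclass_fields(source, Destination))
# -> {'d_in': 512, 'd_sae': 4096}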
sae_lens-6.0.0rc2.dist-info/RECORD
ADDED
@@ -0,0 +1,35 @@
+sae_lens/__init__.py,sha256=JZATcdlWGVOXYTHb41hn7dPp7pR2tWgpLAz2ztQOE-A,2747
+sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
+sae_lens/analysis/neuronpedia_integration.py,sha256=DlI08ThI0zwMrBthICt1OFCMyqmaCUDeZxhOk7b7teY,18680
+sae_lens/cache_activations_runner.py,sha256=27jp2hFxZj4foWCRCJJd2VCwYJtMgkvPx6MuIhQBofc,12591
+sae_lens/config.py,sha256=Ff6MRzRlVk8xtgkvHdJEmuPh9Owc10XIWBaUwdypzkU,26062
+sae_lens/constants.py,sha256=HSiSp0j2Umak2buT30seFhkmj7KNuPmB3u4yLXrgfOg,462
+sae_lens/evals.py,sha256=aR0pJMBWBUdZElXPcxUyNnNYWbM2LC5UeaESKAwdOMY,39098
+sae_lens/load_model.py,sha256=tE70sXsyyyGYW7o506O3eiw1MXyyW6DCQojLG49hWYI,6771
+sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+sae_lens/loading/pretrained_sae_loaders.py,sha256=IgQ-XSJ5VTLCzmJavPmk1vExBVB-36wW7w-ZNo7tzPY,31214
+sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
+sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
+sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
+sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
+sae_lens/sae_training_runner.py,sha256=lI_d3ywS312dIz0wctm_Sgt3W9ffBOS7ahnDXBljX1s,8320
+sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
+sae_lens/saes/gated_sae.py,sha256=IgWvZxeJpdiu7VqeUnJLC-VWVhz6o8OXvmwCS-LJ-WQ,9426
+sae_lens/saes/jumprelu_sae.py,sha256=lkhafpoYYn4-62tBlmmufmUomoo3CmFFQQ3NNylBNSM,12264
+sae_lens/saes/sae.py,sha256=edJK3VFzOVBPXUX6QJ5fhhoY0wcfEisDmVXiqFRA7Xg,35089
+sae_lens/saes/standard_sae.py,sha256=tMs6Z6Cv44PWa7pLo53xhXFnHMvO5BM6eVYHtRPLpos,6652
+sae_lens/saes/topk_sae.py,sha256=CfF59K4J2XwUvztwg4fBbvFO3PyucLkg4Elkxdk0ozs,9786
+sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
+sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+sae_lens/training/activations_store.py,sha256=5V5dExeXWoE0dw-ePOZVnQIbBJwrepRMdsQrRam9Lg8,36790
+sae_lens/training/geometric_median.py,sha256=3kH8ZJAgKStlnZgs6s1uYGDYh004Bl0r4RLhuwT3lBY,3719
+sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
+sae_lens/training/sae_trainer.py,sha256=zYAk_9QJ8AJi2TjDZ1qW_lyoovSBqrJvBHzyYgb89ZY,15251
+sae_lens/training/upload_saes_to_huggingface.py,sha256=tXvR4j25IgMjJ8R9oczwSdy00Tg-P_jAtnPHRt8yF64,4489
+sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
+sae_lens/util.py,sha256=4lqtl7HT9OiyRK8fe8nXtkcn2lOR1uX7ANrAClf6Bv8,1026
+sae_lens-6.0.0rc2.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.0.0rc2.dist-info/METADATA,sha256=Z8Zwb6EknAPB5dOvfduYZewr4nldot-1dQoqz50Co3k,5326
+sae_lens-6.0.0rc2.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+sae_lens-6.0.0rc2.dist-info/RECORD,,
|