sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +60 -7
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +16 -14
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +170 -258
- sae_lens/constants.py +21 -0
- sae_lens/evals.py +59 -44
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +52 -4
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +254 -0
- sae_lens/saes/jumprelu_sae.py +348 -0
- sae_lens/saes/sae.py +1076 -0
- sae_lens/saes/standard_sae.py +178 -0
- sae_lens/saes/topk_sae.py +300 -0
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +103 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +155 -177
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +13 -7
- sae_lens/util.py +47 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
- sae_lens-6.0.0.dist-info/RECORD +37 -0
- sae_lens/sae.py +0 -747
- sae_lens/sae_training_runner.py +0 -251
- sae_lens/training/geometric_median.py +0 -101
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.11.0.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
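The most visible part of this reorganization for downstream code is the set of import-path changes: `sae_lens/toolkit/` becomes `sae_lens/loading/`, the SAE implementations move into a new `sae_lens/saes/` package, and the old `sae_lens/sae_training_runner.py` is removed in favour of the new `sae_lens/llm_sae_training_runner.py`. A hedged before/after sketch for one of the pure renames (the function name is taken from 5.11.0 and assumed to survive the move, since the file itself is unchanged at +0 -0):

```python
# sae-lens 5.11.0:
#   from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

# sae-lens 6.0.0: same module, relocated per the {toolkit -> loading} rename above.
from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory

# List a few of the pretrained SAE releases the directory knows about.
print(sorted(get_pretrained_saes_directory().keys())[:5])
```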
sae_lens/training/optim.py
CHANGED
@@ -101,61 +101,85 @@ def _get_main_lr_scheduler(
     raise ValueError(f"Unsupported scheduler: {scheduler_name}")


-class
+class CoefficientScheduler:
+    """Linearly warms up a scalar value from 0.0 to a final value."""
+
     def __init__(
         self,
-
-
-        final_l1_coefficient: float,
+        warm_up_steps: float,
+        final_value: float,
     ):
-        self.
-
-        if self.l1_warmup_steps != 0:
-            self.current_l1_coefficient = 0.0
-        else:
-            self.current_l1_coefficient = final_l1_coefficient
-
-        self.final_l1_coefficient = final_l1_coefficient
-
+        self.warm_up_steps = warm_up_steps
+        self.final_value = final_value
         self.current_step = 0
-
-        if not isinstance(self.
+
+        if not isinstance(self.final_value, (float, int)):
             raise TypeError(
-                f"
+                f"final_value must be float or int, got {type(self.final_value)}."
             )

+        # Initialize current_value based on whether warm-up is used
+        if self.warm_up_steps > 0:
+            self.current_value = 0.0
+        else:
+            self.current_value = self.final_value
+
     def __repr__(self) -> str:
         return (
-            f"
-            f"
-            f"total_steps={self.total_steps})"
+            f"{self.__class__.__name__}(final_value={self.final_value}, "
+            f"warm_up_steps={self.warm_up_steps})"
         )

-    def step(self):
+    def step(self) -> float:
         """
-        Updates the
+        Updates the scalar value based on the current step.
+
+        Returns:
+            The current scalar value after the step.
         """
-
-
-
-
-            ) # type: ignore
+        if self.current_step < self.warm_up_steps:
+            self.current_value = self.final_value * (
+                (self.current_step + 1) / self.warm_up_steps
+            )
         else:
-
+            # Ensure the value stays at final_value after warm-up
+            self.current_value = self.final_value

         self.current_step += 1
+        return self.current_value

-
-
+    @property
+    def value(self) -> float:
+        """Returns the current scalar value."""
+        return self.current_value
+
+    def state_dict(self) -> dict[str, Any]:
+        """State dict for serialization."""
         return {
-            "
-            "
-            "current_l1_coefficient": self.current_l1_coefficient,
-            "final_l1_coefficient": self.final_l1_coefficient,
+            "warm_up_steps": self.warm_up_steps,
+            "final_value": self.final_value,
             "current_step": self.current_step,
+            "current_value": self.current_value,
         }

     def load_state_dict(self, state_dict: dict[str, Any]):
-        """Loads
-
-
+        """Loads the scheduler state."""
+        self.warm_up_steps = state_dict["warm_up_steps"]
+        self.final_value = state_dict["final_value"]
+        self.current_step = state_dict["current_step"]
+        # Maintain consistency: re-calculate current_value based on loaded step
+        # This handles resuming correctly if stopped mid-warmup.
+        if self.current_step <= self.warm_up_steps and self.warm_up_steps > 0:
+            # Use max(0, ...) to handle case where current_step might be loaded as -1 or similar before first step
+            step_for_calc = max(0, self.current_step)
+            # Recalculate based on the step *before* the one about to be taken
+            # Or simply use the saved current_value if available and consistent
+            if "current_value" in state_dict:
+                self.current_value = state_dict["current_value"]
+            else: # Legacy state dicts might not have current_value
+                self.current_value = self.final_value * (
+                    step_for_calc / self.warm_up_steps
+                )
+
+        else:
+            self.current_value = self.final_value
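The class added above replaces the L1-specific scheduler that previously lived in this module with a generic linear warm-up for any loss coefficient. A minimal usage sketch, exercising only the constructor, `step()`, `value`, and the `state_dict()`/`load_state_dict()` round trip shown in the hunk (the values in the comments follow directly from the linear warm-up formula):

```python
from sae_lens.training.optim import CoefficientScheduler

# Warm a coefficient up to 8.0 over 4 steps.
sched = CoefficientScheduler(warm_up_steps=4, final_value=8.0)
print(sched.value)                       # 0.0 before any step is taken
print([sched.step() for _ in range(6)])  # [2.0, 4.0, 6.0, 8.0, 8.0, 8.0]

# The scheduler round-trips through a plain dict, which is how a training run
# can be checkpointed and resumed mid-warm-up.
state = sched.state_dict()
resumed = CoefficientScheduler(warm_up_steps=4, final_value=8.0)
resumed.load_state_dict(state)
print(resumed.value)                     # 8.0, matching the original scheduler

# With warm_up_steps=0 the coefficient starts (and stays) at its final value.
print(CoefficientScheduler(warm_up_steps=0, final_value=8.0).value)  # 8.0
```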
sae_lens/training/sae_trainer.py
CHANGED
@@ -1,26 +1,28 @@
 import contextlib
 from dataclasses import dataclass
-from
+from pathlib import Path
+from typing import Any, Callable, Generic, Protocol

 import torch
 import wandb
+from safetensors.torch import save_file
 from torch.optim import Adam
-from tqdm import tqdm
-from transformer_lens.hook_points import HookedRootModule
+from tqdm.auto import tqdm

 from sae_lens import __version__
-from sae_lens.config import
-from sae_lens.
-from sae_lens.
-
-
-
-
-
-
-
-
-
+from sae_lens.config import SAETrainerConfig
+from sae_lens.constants import ACTIVATION_SCALER_CFG_FILENAME, SPARSITY_FILENAME
+from sae_lens.saes.sae import (
+    T_TRAINING_SAE,
+    T_TRAINING_SAE_CONFIG,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainStepInput,
+    TrainStepOutput,
+)
+from sae_lens.training.activation_scaler import ActivationScaler
+from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler
+from sae_lens.training.types import DataProvider


 def _log_feature_sparsity(
@@ -29,7 +31,7 @@ def _log_feature_sparsity(
     return torch.log10(feature_sparsity + eps).detach().cpu()


-def _update_sae_lens_training_version(sae: TrainingSAE) -> None:
+def _update_sae_lens_training_version(sae: TrainingSAE[Any]) -> None:
     """
     Make sure we record the version of SAELens used for the training run
     """
@@ -38,7 +40,7 @@ def _update_sae_lens_training_version(sae: TrainingSAE) -> None:

 @dataclass
 class TrainSAEOutput:
-    sae: TrainingSAE
+    sae: TrainingSAE[Any]
     checkpoint_path: str
     log_feature_sparsities: torch.Tensor

@@ -46,33 +48,39 @@ class TrainSAEOutput:
 class SaveCheckpointFn(Protocol):
     def __call__(
         self,
-
-        checkpoint_name: str,
-        wandb_aliases: list[str] | None = None,
+        checkpoint_path: Path,
     ) -> None: ...


-
+Evaluator = Callable[[T_TRAINING_SAE, DataProvider, ActivationScaler], dict[str, Any]]
+
+
+class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
     """
     Core SAE class used for inference. For training, see TrainingSAE.
     """

+    data_provider: DataProvider
+    activation_scaler: ActivationScaler
+    evaluator: Evaluator[T_TRAINING_SAE] | None
+
     def __init__(
         self,
-
-        sae:
-
-
-
+        cfg: SAETrainerConfig,
+        sae: T_TRAINING_SAE,
+        data_provider: DataProvider,
+        evaluator: Evaluator[T_TRAINING_SAE] | None = None,
+        save_checkpoint_fn: SaveCheckpointFn | None = None,
     ) -> None:
-        self.model = model
         self.sae = sae
-        self.
-        self.
+        self.data_provider = data_provider
+        self.evaluator = evaluator
+        self.activation_scaler = ActivationScaler()
+        self.save_checkpoint_fn = save_checkpoint_fn
         self.cfg = cfg

         self.n_training_steps: int = 0
-        self.
+        self.n_training_samples: int = 0
         self.started_fine_tuning: bool = False

         _update_sae_lens_training_version(self.sae)
@@ -82,20 +90,16 @@ class SAETrainer:
         self.checkpoint_thresholds = list(
             range(
                 0,
-                cfg.
-                cfg.
+                cfg.total_training_samples,
+                cfg.total_training_samples // self.cfg.n_checkpoints,
             )
         )[1:]

-        self.act_freq_scores = torch.zeros(
-            cast(int, cfg.d_sae),
-            device=cfg.device,
-        )
+        self.act_freq_scores = torch.zeros(sae.cfg.d_sae, device=cfg.device)
         self.n_forward_passes_since_fired = torch.zeros(
-
-            device=cfg.device,
+            sae.cfg.d_sae, device=cfg.device
         )
-        self.
+        self.n_frac_active_samples = 0
         # we don't train the scaling factor (initially)
         # set requires grad to false for the scaling factor
         for name, param in self.sae.named_parameters():
@@ -121,14 +125,17 @@ class SAETrainer:
             lr_end=cfg.lr_end,
             num_cycles=cfg.n_restart_cycles,
         )
-        self.
-
-
-
-
+        self.coefficient_schedulers = {}
+        for name, coeff_cfg in self.sae.get_coefficients().items():
+            if not isinstance(coeff_cfg, TrainCoefficientConfig):
+                coeff_cfg = TrainCoefficientConfig(value=coeff_cfg, warm_up_steps=0)
+            self.coefficient_schedulers[name] = CoefficientScheduler(
+                warm_up_steps=coeff_cfg.warm_up_steps,
+                final_value=coeff_cfg.value,
+            )

         # Setup autocast if using
-        self.
+        self.grad_scaler = torch.amp.GradScaler(
             device=self.cfg.device, enabled=self.cfg.autocast
         )

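The scheduler wiring in the hunk above is the trainer-side counterpart of the new `CoefficientScheduler`: one scheduler is built per entry of `sae.get_coefficients()`, and bare numbers are promoted to a zero-warm-up `TrainCoefficientConfig`. A minimal sketch of that loop, with a hand-written coefficient dict standing in for a real `TrainingSAE` (the dict contents here are hypothetical):

```python
from sae_lens.saes.sae import TrainCoefficientConfig
from sae_lens.training.optim import CoefficientScheduler

# Hypothetical stand-in for sae.get_coefficients(): values may be a full
# TrainCoefficientConfig or a bare number (treated as "no warm-up").
coefficients = {
    "l1": TrainCoefficientConfig(value=5.0, warm_up_steps=1_000),
    "aux": 0.1,
}

schedulers: dict[str, CoefficientScheduler] = {}
for name, coeff_cfg in coefficients.items():
    if not isinstance(coeff_cfg, TrainCoefficientConfig):
        coeff_cfg = TrainCoefficientConfig(value=coeff_cfg, warm_up_steps=0)
    schedulers[name] = CoefficientScheduler(
        warm_up_steps=coeff_cfg.warm_up_steps,
        final_value=coeff_cfg.value,
    )

# Each training step the trainer reads the current values and then advances
# every scheduler once (see get_coefficients and _train_step below).
print({name: s.value for name, s in schedulers.items()})  # {'l1': 0.0, 'aux': 0.1}
for s in schedulers.values():
    s.step()
```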
@@ -141,52 +148,39 @@ class SAETrainer:
         else:
             self.autocast_if_enabled = contextlib.nullcontext()

-        # Set up eval config
-
-        self.trainer_eval_config = EvalConfig(
-            batch_size_prompts=self.cfg.eval_batch_size_prompts,
-            n_eval_reconstruction_batches=self.cfg.n_eval_batches,
-            n_eval_sparsity_variance_batches=self.cfg.n_eval_batches,
-            compute_ce_loss=True,
-            compute_l2_norms=True,
-            compute_sparsity_metrics=True,
-            compute_variance_metrics=True,
-            compute_kl=False,
-            compute_featurewise_weight_based_metrics=False,
-        )
-
     @property
     def feature_sparsity(self) -> torch.Tensor:
-        return self.act_freq_scores / self.
+        return self.act_freq_scores / self.n_frac_active_samples

     @property
     def log_feature_sparsity(self) -> torch.Tensor:
         return _log_feature_sparsity(self.feature_sparsity)

-    @property
-    def current_l1_coefficient(self) -> float:
-        return self.l1_scheduler.current_l1_coefficient
-
     @property
     def dead_neurons(self) -> torch.Tensor:
         return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()

-    def fit(self) ->
-
+    def fit(self) -> T_TRAINING_SAE:
+        self.sae.to(self.cfg.device)
+        pbar = tqdm(total=self.cfg.total_training_samples, desc="Training SAE")

-        self.
+        if self.sae.cfg.normalize_activations == "expected_average_only_in":
+            self.activation_scaler.estimate_scaling_factor(
+                d_in=self.sae.cfg.d_in,
+                data_provider=self.data_provider,
+                n_batches_for_norm_estimate=int(1e3),
+            )

         # Train loop
-        while self.
+        while self.n_training_samples < self.cfg.total_training_samples:
             # Do a training step.
-
-
-            )
-            self.n_training_tokens += self.cfg.train_batch_size_tokens
+            batch = next(self.data_provider).to(self.sae.device)
+            self.n_training_samples += batch.shape[0]
+            scaled_batch = self.activation_scaler(batch)

-            step_output = self._train_step(sae=self.sae, sae_in=
+            step_output = self._train_step(sae=self.sae, sae_in=scaled_batch)

-            if self.cfg.log_to_wandb:
+            if self.cfg.logger.log_to_wandb:
                 self._log_train_step(step_output)
                 self._run_and_log_evals()

@@ -194,39 +188,67 @@ class SAETrainer:
             self.n_training_steps += 1
             self._update_pbar(step_output, pbar)

-            ### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already)
-            self._begin_finetuning_if_needed()
-
         # fold the estimated norm scaling factor into the sae weights
-        if self.
+        if self.activation_scaler.scaling_factor is not None:
             self.sae.fold_activation_norm_scaling_factor(
-                self.
+                self.activation_scaler.scaling_factor
             )
-            self.
+            self.activation_scaler.scaling_factor = None

-        # save final sae group to checkpoints folder
+        # save final inference sae group to checkpoints folder
         self.save_checkpoint(
-
-            checkpoint_name=f"final_{self.n_training_tokens}",
+            checkpoint_name=f"final_{self.n_training_samples}",
             wandb_aliases=["final_model"],
+            save_inference_model=True,
         )

         pbar.close()
         return self.sae

+    def save_checkpoint(
+        self,
+        checkpoint_name: str,
+        wandb_aliases: list[str] | None = None,
+        save_inference_model: bool = False,
+    ) -> None:
+        checkpoint_path = Path(self.cfg.checkpoint_path) / checkpoint_name
+        checkpoint_path.mkdir(exist_ok=True, parents=True)
+
+        save_fn = (
+            self.sae.save_inference_model
+            if save_inference_model
+            else self.sae.save_model
+        )
+        weights_path, cfg_path = save_fn(str(checkpoint_path))
+
+        sparsity_path = checkpoint_path / SPARSITY_FILENAME
+        save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)
+
+        activation_scaler_path = checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
+        self.activation_scaler.save(str(activation_scaler_path))
+
+        if self.cfg.logger.log_to_wandb:
+            self.cfg.logger.log(
+                self,
+                weights_path,
+                cfg_path,
+                sparsity_path=sparsity_path,
+                wandb_aliases=wandb_aliases,
+            )
+
+        if self.save_checkpoint_fn is not None:
+            self.save_checkpoint_fn(checkpoint_path=checkpoint_path)
+
     def _train_step(
         self,
-        sae:
+        sae: T_TRAINING_SAE,
         sae_in: torch.Tensor,
     ) -> TrainStepOutput:
         sae.train()
-        # Make sure the W_dec is still zero-norm
-        if self.cfg.normalize_sae_decoder:
-            sae.set_decoder_norm_to_unit_norm()

         # log and then reset the feature sparsity every feature_sampling_window steps
         if (self.n_training_steps + 1) % self.cfg.feature_sampling_window == 0:
-            if self.cfg.log_to_wandb:
+            if self.cfg.logger.log_to_wandb:
                 sparsity_log_dict = self._build_sparsity_log_dict()
                 wandb.log(sparsity_log_dict, step=self.n_training_steps)
             self._reset_running_sparsity_stats()
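The new `save_checkpoint` method above fixes the on-disk layout of a checkpoint directory: SAE weights and config written via `save_model` (or `save_inference_model` for the final checkpoint), a `sparsity` tensor stored with safetensors under `SPARSITY_FILENAME`, and the activation-scaler state under `ACTIVATION_SCALER_CFG_FILENAME`. A hedged sketch of reading the sparsity tensor back out of such a directory (the directory path is hypothetical; the filename constant and the `"sparsity"` key come from the hunk above):

```python
from pathlib import Path

from safetensors.torch import load_file

from sae_lens.constants import SPARSITY_FILENAME

# Hypothetical directory written by SAETrainer.save_checkpoint.
checkpoint_dir = Path("checkpoints/my_run/final_122880000")

# The trainer stores log10 firing frequencies under the "sparsity" key,
# one value per SAE latent.
log_sparsity = load_file(str(checkpoint_dir / SPARSITY_FILENAME))["sparsity"]
print(log_sparsity.shape)
```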
@@ -235,9 +257,11 @@ class SAETrainer:
         # https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
         with self.autocast_if_enabled:
             train_step_output = self.sae.training_forward_pass(
-
-
-
+                step_input=TrainStepInput(
+                    sae_in=sae_in,
+                    dead_neuron_mask=self.dead_neurons,
+                    coefficients=self.get_coefficients(),
+                ),
             )

         with torch.no_grad():
@@ -247,43 +271,50 @@ class SAETrainer:
             self.act_freq_scores += (
                 (train_step_output.feature_acts.abs() > 0).float().sum(0)
             )
-            self.
+            self.n_frac_active_samples += self.cfg.train_batch_size_samples

-        #
-        self.
+        # Grad scaler will rescale gradients if autocast is enabled
+        self.grad_scaler.scale(
             train_step_output.loss
         ).backward() # loss.backward() if not autocasting
-        self.
+        self.grad_scaler.unscale_(self.optimizer) # needed to clip correctly
         # TODO: Work out if grad norm clipping should be in config / how to test it.
         torch.nn.utils.clip_grad_norm_(sae.parameters(), 1.0)
-        self.
-
-
-
-            sae.remove_gradient_parallel_to_decoder_directions()
+        self.grad_scaler.step(
+            self.optimizer
+        ) # just ctx.optimizer.step() if not autocasting
+        self.grad_scaler.update()

         self.optimizer.zero_grad()
         self.lr_scheduler.step()
-        self.
+        for scheduler in self.coefficient_schedulers.values():
+            scheduler.step()

         return train_step_output

     @torch.no_grad()
     def _log_train_step(self, step_output: TrainStepOutput):
-        if (self.n_training_steps + 1) % self.cfg.wandb_log_frequency == 0:
+        if (self.n_training_steps + 1) % self.cfg.logger.wandb_log_frequency == 0:
             wandb.log(
                 self._build_train_step_log_dict(
                     output=step_output,
-
+                    n_training_samples=self.n_training_samples,
                 ),
                 step=self.n_training_steps,
             )

+    @torch.no_grad()
+    def get_coefficients(self) -> dict[str, float]:
+        return {
+            name: scheduler.value
+            for name, scheduler in self.coefficient_schedulers.items()
+        }
+
     @torch.no_grad()
     def _build_train_step_log_dict(
         self,
         output: TrainStepOutput,
-
+        n_training_samples: int,
     ) -> dict[str, Any]:
         sae_in = output.sae_in
         sae_out = output.sae_out
@@ -311,19 +342,15 @@ class SAETrainer:
             "sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(),
             "sparsity/dead_features": self.dead_neurons.sum().item(),
             "details/current_learning_rate": current_learning_rate,
-            "details/
-
+            "details/n_training_samples": n_training_samples,
+            **{
+                f"details/{name}_coefficient": scheduler.value
+                for name, scheduler in self.coefficient_schedulers.items()
+            },
         }
         for loss_name, loss_value in output.losses.items():
             loss_item = _unwrap_item(loss_value)
-
-            if loss_name == "l1_loss":
-                log_dict[f"losses/{loss_name}"] = (
-                    loss_item / self.current_l1_coefficient
-                )
-                log_dict[f"losses/raw_{loss_name}"] = loss_item
-            else:
-                log_dict[f"losses/{loss_name}"] = loss_item
+            log_dict[f"losses/{loss_name}"] = loss_item

         return log_dict

@@ -331,44 +358,17 @@ class SAETrainer:
     def _run_and_log_evals(self):
         # record loss frequently, but not all the time.
         if (self.n_training_steps + 1) % (
-            self.cfg.wandb_log_frequency
+            self.cfg.logger.wandb_log_frequency
+            * self.cfg.logger.eval_every_n_wandb_logs
         ) == 0:
             self.sae.eval()
-
-
-
-
-
-
-
-                activation_store=self.activations_store,
-                model=self.model,
-                eval_config=self.trainer_eval_config,
-                ignore_tokens=ignore_tokens,
-                model_kwargs=self.cfg.model_kwargs,
-            ) # not calculating featurwise metrics here.
-
-            # Remove eval metrics that are already logged during training
-            eval_metrics.pop("metrics/explained_variance", None)
-            eval_metrics.pop("metrics/explained_variance_std", None)
-            eval_metrics.pop("metrics/l0", None)
-            eval_metrics.pop("metrics/l1", None)
-            eval_metrics.pop("metrics/mse", None)
-
-            # Remove metrics that are not useful for wandb logging
-            eval_metrics.pop("metrics/total_tokens_evaluated", None)
-
-            W_dec_norm_dist = self.sae.W_dec.detach().float().norm(dim=1).cpu().numpy()
-            eval_metrics["weights/W_dec_norms"] = wandb.Histogram(W_dec_norm_dist) # type: ignore
-
-            if self.sae.cfg.architecture == "standard":
-                b_e_dist = self.sae.b_enc.detach().float().cpu().numpy()
-                eval_metrics["weights/b_e"] = wandb.Histogram(b_e_dist) # type: ignore
-            elif self.sae.cfg.architecture == "gated":
-                b_gate_dist = self.sae.b_gate.detach().float().cpu().numpy()
-                eval_metrics["weights/b_gate"] = wandb.Histogram(b_gate_dist) # type: ignore
-                b_mag_dist = self.sae.b_mag.detach().float().cpu().numpy()
-                eval_metrics["weights/b_mag"] = wandb.Histogram(b_mag_dist) # type: ignore
+            eval_metrics = (
+                self.evaluator(self.sae, self.data_provider, self.activation_scaler)
+                if self.evaluator is not None
+                else {}
+            )
+            for key, value in self.sae.log_histograms().items():
+                eval_metrics[key] = wandb.Histogram(value) # type: ignore

             wandb.log(
                 eval_metrics,
@@ -390,21 +390,18 @@ class SAETrainer:
     @torch.no_grad()
     def _reset_running_sparsity_stats(self) -> None:
         self.act_freq_scores = torch.zeros(
-            self.cfg.d_sae, # type: ignore
+            self.sae.cfg.d_sae, # type: ignore
            device=self.cfg.device,
         )
-        self.
+        self.n_frac_active_samples = 0

     @torch.no_grad()
     def _checkpoint_if_needed(self):
         if (
             self.checkpoint_thresholds
-            and self.
+            and self.n_training_samples > self.checkpoint_thresholds[0]
         ):
-            self.save_checkpoint(
-                trainer=self,
-                checkpoint_name=str(self.n_training_tokens),
-            )
+            self.save_checkpoint(checkpoint_name=str(self.n_training_samples))
             self.checkpoint_thresholds.pop(0)

     @torch.no_grad()
@@ -420,26 +417,7 @@ class SAETrainer:
             for loss_name, loss_value in step_output.losses.items()
         )
         pbar.set_description(f"{self.n_training_steps}| {loss_strs}")
-        pbar.update(update_interval * self.cfg.
-
-    def _begin_finetuning_if_needed(self):
-        if (not self.started_fine_tuning) and (
-            self.n_training_tokens > self.cfg.training_tokens
-        ):
-            self.started_fine_tuning = True
-
-            # finetuning method should be set in the config
-            # if not, then we don't finetune
-            if not isinstance(self.cfg.finetuning_method, str):
-                return
-
-            for name, param in self.sae.named_parameters():
-                if name in FINETUNING_PARAMETERS[self.cfg.finetuning_method]:
-                    param.requires_grad = True
-                else:
-                    param.requires_grad = False
-
-            self.finetuning = True
+        pbar.update(update_interval * self.cfg.train_batch_size_samples)


 def _unwrap_item(item: float | torch.Tensor) -> float:
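With evaluation and checkpoint upload factored out of the trainer, both are injected as callables: `save_checkpoint_fn` must match the slimmed-down `SaveCheckpointFn` protocol (it now receives only the checkpoint directory), and `evaluator` must match the `Evaluator` alias `Callable[[T_TRAINING_SAE, DataProvider, ActivationScaler], dict[str, Any]]`. A minimal sketch of two such callbacks, meant to be passed as `evaluator=` and `save_checkpoint_fn=` when constructing `SAETrainer`; the metric computed here is a trivial placeholder, not something the library provides:

```python
from pathlib import Path
from typing import Any

from sae_lens.saes.sae import TrainingSAE
from sae_lens.training.activation_scaler import ActivationScaler
from sae_lens.training.types import DataProvider


def upload_checkpoint(checkpoint_path: Path) -> None:
    # SaveCheckpointFn: called after the trainer has written the weights, cfg,
    # sparsity, and activation-scaler files into checkpoint_path.
    print(f"checkpoint ready for upload at {checkpoint_path}")


def simple_evaluator(
    sae: TrainingSAE[Any],
    data_provider: DataProvider,
    activation_scaler: ActivationScaler,
) -> dict[str, Any]:
    # Evaluator: pull one batch the same way fit() does and return metrics
    # to be merged into the wandb eval log.
    batch = activation_scaler(next(data_provider).to(sae.device))
    return {"eval/mean_scaled_input_norm": batch.norm(dim=-1).mean().item()}
```

Because the protocol no longer receives the trainer or the wandb aliases, custom upload logic only needs to know where the checkpoint files landed.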
|