sae-lens 6.0.0rc2__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +6 -3
- sae_lens/cache_activations_runner.py +7 -6
- sae_lens/config.py +47 -5
- sae_lens/constants.py +2 -0
- sae_lens/evals.py +19 -19
- sae_lens/{sae_training_runner.py → llm_sae_training_runner.py} +92 -60
- sae_lens/load_model.py +53 -5
- sae_lens/loading/pretrained_sae_loaders.py +0 -7
- sae_lens/saes/sae.py +0 -3
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +77 -172
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/sae_trainer.py +96 -95
- sae_lens/training/types.py +5 -0
- sae_lens/util.py +19 -0
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc3.dist-info}/METADATA +1 -1
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc3.dist-info}/RECORD +19 -16
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc3.dist-info}/LICENSE +0 -0
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc3.dist-info}/WHEEL +0 -0
sae_lens/training/sae_trainer.py
CHANGED
|
@@ -1,16 +1,17 @@
|
|
|
1
1
|
import contextlib
|
|
2
2
|
from dataclasses import dataclass
|
|
3
|
-
from
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Callable, Generic, Protocol
|
|
4
5
|
|
|
5
6
|
import torch
|
|
6
7
|
import wandb
|
|
8
|
+
from safetensors.torch import save_file
|
|
7
9
|
from torch.optim import Adam
|
|
8
10
|
from tqdm import tqdm
|
|
9
|
-
from transformer_lens.hook_points import HookedRootModule
|
|
10
11
|
|
|
11
12
|
from sae_lens import __version__
|
|
12
|
-
from sae_lens.config import
|
|
13
|
-
from sae_lens.
|
|
13
|
+
from sae_lens.config import SAETrainerConfig
|
|
14
|
+
from sae_lens.constants import ACTIVATION_SCALER_CFG_FILENAME, SPARSITY_FILENAME
|
|
14
15
|
from sae_lens.saes.sae import (
|
|
15
16
|
T_TRAINING_SAE,
|
|
16
17
|
T_TRAINING_SAE_CONFIG,
|
|
@@ -19,8 +20,9 @@ from sae_lens.saes.sae import (
|
|
|
19
20
|
TrainStepInput,
|
|
20
21
|
TrainStepOutput,
|
|
21
22
|
)
|
|
22
|
-
from sae_lens.training.
|
|
23
|
+
from sae_lens.training.activation_scaler import ActivationScaler
|
|
23
24
|
from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler
|
|
25
|
+
from sae_lens.training.types import DataProvider
|
|
24
26
|
|
|
25
27
|
|
|
26
28
|
def _log_feature_sparsity(
|
|
@@ -46,33 +48,39 @@ class TrainSAEOutput:
|
|
|
46
48
|
class SaveCheckpointFn(Protocol):
|
|
47
49
|
def __call__(
|
|
48
50
|
self,
|
|
49
|
-
|
|
50
|
-
checkpoint_name: str,
|
|
51
|
-
wandb_aliases: list[str] | None = None,
|
|
51
|
+
checkpoint_path: Path,
|
|
52
52
|
) -> None: ...
|
|
53
53
|
|
|
54
54
|
|
|
55
|
+
Evaluator = Callable[[T_TRAINING_SAE, DataProvider, ActivationScaler], dict[str, Any]]
|
|
56
|
+
|
|
57
|
+
|
|
55
58
|
class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
56
59
|
"""
|
|
57
60
|
Core SAE class used for inference. For training, see TrainingSAE.
|
|
58
61
|
"""
|
|
59
62
|
|
|
63
|
+
data_provider: DataProvider
|
|
64
|
+
activation_scaler: ActivationScaler
|
|
65
|
+
evaluator: Evaluator[T_TRAINING_SAE] | None
|
|
66
|
+
|
|
60
67
|
def __init__(
|
|
61
68
|
self,
|
|
62
|
-
|
|
69
|
+
cfg: SAETrainerConfig,
|
|
63
70
|
sae: T_TRAINING_SAE,
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
71
|
+
data_provider: DataProvider,
|
|
72
|
+
evaluator: Evaluator[T_TRAINING_SAE] | None = None,
|
|
73
|
+
save_checkpoint_fn: SaveCheckpointFn | None = None,
|
|
67
74
|
) -> None:
|
|
68
|
-
self.model = model
|
|
69
75
|
self.sae = sae
|
|
70
|
-
self.
|
|
71
|
-
self.
|
|
76
|
+
self.data_provider = data_provider
|
|
77
|
+
self.evaluator = evaluator
|
|
78
|
+
self.activation_scaler = ActivationScaler()
|
|
79
|
+
self.save_checkpoint_fn = save_checkpoint_fn
|
|
72
80
|
self.cfg = cfg
|
|
73
81
|
|
|
74
82
|
self.n_training_steps: int = 0
|
|
75
|
-
self.
|
|
83
|
+
self.n_training_samples: int = 0
|
|
76
84
|
self.started_fine_tuning: bool = False
|
|
77
85
|
|
|
78
86
|
_update_sae_lens_training_version(self.sae)
|
|
@@ -82,20 +90,16 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
82
90
|
self.checkpoint_thresholds = list(
|
|
83
91
|
range(
|
|
84
92
|
0,
|
|
85
|
-
cfg.
|
|
86
|
-
cfg.
|
|
93
|
+
cfg.total_training_samples,
|
|
94
|
+
cfg.total_training_samples // self.cfg.n_checkpoints,
|
|
87
95
|
)
|
|
88
96
|
)[1:]
|
|
89
97
|
|
|
90
|
-
self.act_freq_scores = torch.zeros(
|
|
91
|
-
cast(int, cfg.sae.d_sae),
|
|
92
|
-
device=cfg.device,
|
|
93
|
-
)
|
|
98
|
+
self.act_freq_scores = torch.zeros(sae.cfg.d_sae, device=cfg.device)
|
|
94
99
|
self.n_forward_passes_since_fired = torch.zeros(
|
|
95
|
-
|
|
96
|
-
device=cfg.device,
|
|
100
|
+
sae.cfg.d_sae, device=cfg.device
|
|
97
101
|
)
|
|
98
|
-
self.
|
|
102
|
+
self.n_frac_active_samples = 0
|
|
99
103
|
# we don't train the scaling factor (initially)
|
|
100
104
|
# set requires grad to false for the scaling factor
|
|
101
105
|
for name, param in self.sae.named_parameters():
|
|
@@ -131,7 +135,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
131
135
|
)
|
|
132
136
|
|
|
133
137
|
# Setup autocast if using
|
|
134
|
-
self.
|
|
138
|
+
self.grad_scaler = torch.amp.GradScaler(
|
|
135
139
|
device=self.cfg.device, enabled=self.cfg.autocast
|
|
136
140
|
)
|
|
137
141
|
|
|
@@ -144,23 +148,9 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
144
148
|
else:
|
|
145
149
|
self.autocast_if_enabled = contextlib.nullcontext()
|
|
146
150
|
|
|
147
|
-
# Set up eval config
|
|
148
|
-
|
|
149
|
-
self.trainer_eval_config = EvalConfig(
|
|
150
|
-
batch_size_prompts=self.cfg.eval_batch_size_prompts,
|
|
151
|
-
n_eval_reconstruction_batches=self.cfg.n_eval_batches,
|
|
152
|
-
n_eval_sparsity_variance_batches=self.cfg.n_eval_batches,
|
|
153
|
-
compute_ce_loss=True,
|
|
154
|
-
compute_l2_norms=True,
|
|
155
|
-
compute_sparsity_metrics=True,
|
|
156
|
-
compute_variance_metrics=True,
|
|
157
|
-
compute_kl=False,
|
|
158
|
-
compute_featurewise_weight_based_metrics=False,
|
|
159
|
-
)
|
|
160
|
-
|
|
161
151
|
@property
|
|
162
152
|
def feature_sparsity(self) -> torch.Tensor:
|
|
163
|
-
return self.act_freq_scores / self.
|
|
153
|
+
return self.act_freq_scores / self.n_frac_active_samples
|
|
164
154
|
|
|
165
155
|
@property
|
|
166
156
|
def log_feature_sparsity(self) -> torch.Tensor:
|
|
@@ -171,19 +161,23 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
171
161
|
return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()
|
|
172
162
|
|
|
173
163
|
def fit(self) -> T_TRAINING_SAE:
|
|
174
|
-
pbar = tqdm(total=self.cfg.
|
|
164
|
+
pbar = tqdm(total=self.cfg.total_training_samples, desc="Training SAE")
|
|
175
165
|
|
|
176
|
-
self.
|
|
166
|
+
if self.sae.cfg.normalize_activations == "expected_average_only_in":
|
|
167
|
+
self.activation_scaler.estimate_scaling_factor(
|
|
168
|
+
d_in=self.sae.cfg.d_in,
|
|
169
|
+
data_provider=self.data_provider,
|
|
170
|
+
n_batches_for_norm_estimate=int(1e3),
|
|
171
|
+
)
|
|
177
172
|
|
|
178
173
|
# Train loop
|
|
179
|
-
while self.
|
|
174
|
+
while self.n_training_samples < self.cfg.total_training_samples:
|
|
180
175
|
# Do a training step.
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
)
|
|
184
|
-
self.n_training_tokens += self.cfg.train_batch_size_tokens
|
|
176
|
+
batch = next(self.data_provider).to(self.sae.device)
|
|
177
|
+
self.n_training_samples += batch.shape[0]
|
|
178
|
+
scaled_batch = self.activation_scaler(batch)
|
|
185
179
|
|
|
186
|
-
step_output = self._train_step(sae=self.sae, sae_in=
|
|
180
|
+
step_output = self._train_step(sae=self.sae, sae_in=scaled_batch)
|
|
187
181
|
|
|
188
182
|
if self.cfg.logger.log_to_wandb:
|
|
189
183
|
self._log_train_step(step_output)
|
|
@@ -194,22 +188,49 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
194
188
|
self._update_pbar(step_output, pbar)
|
|
195
189
|
|
|
196
190
|
# fold the estimated norm scaling factor into the sae weights
|
|
197
|
-
if self.
|
|
191
|
+
if self.activation_scaler.scaling_factor is not None:
|
|
198
192
|
self.sae.fold_activation_norm_scaling_factor(
|
|
199
|
-
self.
|
|
193
|
+
self.activation_scaler.scaling_factor
|
|
200
194
|
)
|
|
201
|
-
self.
|
|
195
|
+
self.activation_scaler.scaling_factor = None
|
|
202
196
|
|
|
203
197
|
# save final sae group to checkpoints folder
|
|
204
198
|
self.save_checkpoint(
|
|
205
|
-
|
|
206
|
-
checkpoint_name=f"final_{self.n_training_tokens}",
|
|
199
|
+
checkpoint_name=f"final_{self.n_training_samples}",
|
|
207
200
|
wandb_aliases=["final_model"],
|
|
208
201
|
)
|
|
209
202
|
|
|
210
203
|
pbar.close()
|
|
211
204
|
return self.sae
|
|
212
205
|
|
|
206
|
+
def save_checkpoint(
|
|
207
|
+
self,
|
|
208
|
+
checkpoint_name: str,
|
|
209
|
+
wandb_aliases: list[str] | None = None,
|
|
210
|
+
) -> None:
|
|
211
|
+
checkpoint_path = Path(self.cfg.checkpoint_path) / checkpoint_name
|
|
212
|
+
checkpoint_path.mkdir(exist_ok=True, parents=True)
|
|
213
|
+
|
|
214
|
+
weights_path, cfg_path = self.sae.save_model(str(checkpoint_path))
|
|
215
|
+
|
|
216
|
+
sparsity_path = checkpoint_path / SPARSITY_FILENAME
|
|
217
|
+
save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)
|
|
218
|
+
|
|
219
|
+
activation_scaler_path = checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
|
|
220
|
+
self.activation_scaler.save(str(activation_scaler_path))
|
|
221
|
+
|
|
222
|
+
if self.cfg.logger.log_to_wandb:
|
|
223
|
+
self.cfg.logger.log(
|
|
224
|
+
self,
|
|
225
|
+
weights_path,
|
|
226
|
+
cfg_path,
|
|
227
|
+
sparsity_path=sparsity_path,
|
|
228
|
+
wandb_aliases=wandb_aliases,
|
|
229
|
+
)
|
|
230
|
+
|
|
231
|
+
if self.save_checkpoint_fn is not None:
|
|
232
|
+
self.save_checkpoint_fn(checkpoint_path=checkpoint_path)
|
|
233
|
+
|
|
213
234
|
def _train_step(
|
|
214
235
|
self,
|
|
215
236
|
sae: T_TRAINING_SAE,
|
|
@@ -242,17 +263,19 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
242
263
|
self.act_freq_scores += (
|
|
243
264
|
(train_step_output.feature_acts.abs() > 0).float().sum(0)
|
|
244
265
|
)
|
|
245
|
-
self.
|
|
266
|
+
self.n_frac_active_samples += self.cfg.train_batch_size_samples
|
|
246
267
|
|
|
247
|
-
#
|
|
248
|
-
self.
|
|
268
|
+
# Grad scaler will rescale gradients if autocast is enabled
|
|
269
|
+
self.grad_scaler.scale(
|
|
249
270
|
train_step_output.loss
|
|
250
271
|
).backward() # loss.backward() if not autocasting
|
|
251
|
-
self.
|
|
272
|
+
self.grad_scaler.unscale_(self.optimizer) # needed to clip correctly
|
|
252
273
|
# TODO: Work out if grad norm clipping should be in config / how to test it.
|
|
253
274
|
torch.nn.utils.clip_grad_norm_(sae.parameters(), 1.0)
|
|
254
|
-
self.
|
|
255
|
-
|
|
275
|
+
self.grad_scaler.step(
|
|
276
|
+
self.optimizer
|
|
277
|
+
) # just ctx.optimizer.step() if not autocasting
|
|
278
|
+
self.grad_scaler.update()
|
|
256
279
|
|
|
257
280
|
self.optimizer.zero_grad()
|
|
258
281
|
self.lr_scheduler.step()
|
|
@@ -267,7 +290,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
267
290
|
wandb.log(
|
|
268
291
|
self._build_train_step_log_dict(
|
|
269
292
|
output=step_output,
|
|
270
|
-
|
|
293
|
+
n_training_samples=self.n_training_samples,
|
|
271
294
|
),
|
|
272
295
|
step=self.n_training_steps,
|
|
273
296
|
)
|
|
@@ -283,7 +306,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
283
306
|
def _build_train_step_log_dict(
|
|
284
307
|
self,
|
|
285
308
|
output: TrainStepOutput,
|
|
286
|
-
|
|
309
|
+
n_training_samples: int,
|
|
287
310
|
) -> dict[str, Any]:
|
|
288
311
|
sae_in = output.sae_in
|
|
289
312
|
sae_out = output.sae_out
|
|
@@ -311,7 +334,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
311
334
|
"sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(),
|
|
312
335
|
"sparsity/dead_features": self.dead_neurons.sum().item(),
|
|
313
336
|
"details/current_learning_rate": current_learning_rate,
|
|
314
|
-
"details/
|
|
337
|
+
"details/n_training_samples": n_training_samples,
|
|
315
338
|
**{
|
|
316
339
|
f"details/{name}_coefficient": scheduler.value
|
|
317
340
|
for name, scheduler in self.coefficient_schedulers.items()
|
|
@@ -331,30 +354,11 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
331
354
|
* self.cfg.logger.eval_every_n_wandb_logs
|
|
332
355
|
) == 0:
|
|
333
356
|
self.sae.eval()
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
eval_metrics, _ = run_evals(
|
|
340
|
-
sae=self.sae,
|
|
341
|
-
activation_store=self.activations_store,
|
|
342
|
-
model=self.model,
|
|
343
|
-
eval_config=self.trainer_eval_config,
|
|
344
|
-
ignore_tokens=ignore_tokens,
|
|
345
|
-
model_kwargs=self.cfg.model_kwargs,
|
|
346
|
-
) # not calculating featurwise metrics here.
|
|
347
|
-
|
|
348
|
-
# Remove eval metrics that are already logged during training
|
|
349
|
-
eval_metrics.pop("metrics/explained_variance", None)
|
|
350
|
-
eval_metrics.pop("metrics/explained_variance_std", None)
|
|
351
|
-
eval_metrics.pop("metrics/l0", None)
|
|
352
|
-
eval_metrics.pop("metrics/l1", None)
|
|
353
|
-
eval_metrics.pop("metrics/mse", None)
|
|
354
|
-
|
|
355
|
-
# Remove metrics that are not useful for wandb logging
|
|
356
|
-
eval_metrics.pop("metrics/total_tokens_evaluated", None)
|
|
357
|
-
|
|
357
|
+
eval_metrics = (
|
|
358
|
+
self.evaluator(self.sae, self.data_provider, self.activation_scaler)
|
|
359
|
+
if self.evaluator is not None
|
|
360
|
+
else {}
|
|
361
|
+
)
|
|
358
362
|
for key, value in self.sae.log_histograms().items():
|
|
359
363
|
eval_metrics[key] = wandb.Histogram(value) # type: ignore
|
|
360
364
|
|
|
@@ -378,21 +382,18 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
378
382
|
@torch.no_grad()
|
|
379
383
|
def _reset_running_sparsity_stats(self) -> None:
|
|
380
384
|
self.act_freq_scores = torch.zeros(
|
|
381
|
-
self.cfg.
|
|
385
|
+
self.sae.cfg.d_sae, # type: ignore
|
|
382
386
|
device=self.cfg.device,
|
|
383
387
|
)
|
|
384
|
-
self.
|
|
388
|
+
self.n_frac_active_samples = 0
|
|
385
389
|
|
|
386
390
|
@torch.no_grad()
|
|
387
391
|
def _checkpoint_if_needed(self):
|
|
388
392
|
if (
|
|
389
393
|
self.checkpoint_thresholds
|
|
390
|
-
and self.
|
|
394
|
+
and self.n_training_samples > self.checkpoint_thresholds[0]
|
|
391
395
|
):
|
|
392
|
-
self.save_checkpoint(
|
|
393
|
-
trainer=self,
|
|
394
|
-
checkpoint_name=str(self.n_training_tokens),
|
|
395
|
-
)
|
|
396
|
+
self.save_checkpoint(checkpoint_name=str(self.n_training_samples))
|
|
396
397
|
self.checkpoint_thresholds.pop(0)
|
|
397
398
|
|
|
398
399
|
@torch.no_grad()
|
|
@@ -408,7 +409,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
408
409
|
for loss_name, loss_value in step_output.losses.items()
|
|
409
410
|
)
|
|
410
411
|
pbar.set_description(f"{self.n_training_steps}| {loss_strs}")
|
|
411
|
-
pbar.update(update_interval * self.cfg.
|
|
412
|
+
pbar.update(update_interval * self.cfg.train_batch_size_samples)
|
|
412
413
|
|
|
413
414
|
|
|
414
415
|
def _unwrap_item(item: float | torch.Tensor) -> float:
|
sae_lens/util.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import re
|
|
1
2
|
from dataclasses import asdict, fields, is_dataclass
|
|
2
3
|
from typing import Sequence, TypeVar
|
|
3
4
|
|
|
@@ -26,3 +27,21 @@ def filter_valid_dataclass_fields(
|
|
|
26
27
|
if whitelist_fields is not None:
|
|
27
28
|
valid_field_names = valid_field_names.union(whitelist_fields)
|
|
28
29
|
return {key: val for key, val in source_dict.items() if key in valid_field_names}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def extract_stop_at_layer_from_tlens_hook_name(hook_name: str) -> int | None:
|
|
33
|
+
"""Extract the stop_at layer from a HookedTransformer hook name.
|
|
34
|
+
|
|
35
|
+
Returns None if the hook name is not a valid HookedTransformer hook name.
|
|
36
|
+
"""
|
|
37
|
+
layer = extract_layer_from_tlens_hook_name(hook_name)
|
|
38
|
+
return None if layer is None else layer + 1
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def extract_layer_from_tlens_hook_name(hook_name: str) -> int | None:
|
|
42
|
+
"""Extract the layer from a HookedTransformer hook name.
|
|
43
|
+
|
|
44
|
+
Returns None if the hook name is not a valid HookedTransformer hook name.
|
|
45
|
+
"""
|
|
46
|
+
hook_match = re.search(r"\.(\d+)\.", hook_name)
|
|
47
|
+
return None if hook_match is None else int(hook_match.group(1))
|
|
@@ -1,35 +1,38 @@
|
|
|
1
|
-
sae_lens/__init__.py,sha256=
|
|
1
|
+
sae_lens/__init__.py,sha256=881mDkwEifeN32NsH78_CaeH11sKYK4YnqCW502qHE4,2861
|
|
2
2
|
sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
|
|
4
4
|
sae_lens/analysis/neuronpedia_integration.py,sha256=DlI08ThI0zwMrBthICt1OFCMyqmaCUDeZxhOk7b7teY,18680
|
|
5
|
-
sae_lens/cache_activations_runner.py,sha256=
|
|
6
|
-
sae_lens/config.py,sha256=
|
|
7
|
-
sae_lens/constants.py,sha256=
|
|
8
|
-
sae_lens/evals.py,sha256=
|
|
9
|
-
sae_lens/
|
|
5
|
+
sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
|
|
6
|
+
sae_lens/config.py,sha256=5Wgr8SsUvYWU2Xmet1JyJ0upAZArMDpYfr3jaK8TvRY,27234
|
|
7
|
+
sae_lens/constants.py,sha256=RJlzWx7wLNMNmrdI63naF7-M3enb55vYRN4x1hXx6vI,593
|
|
8
|
+
sae_lens/evals.py,sha256=WRdHlVeZxXCi33gef7rQE90PSUBF6pjrHnPP6av_Urg,38747
|
|
9
|
+
sae_lens/llm_sae_training_runner.py,sha256=-FPXaHvDfSw5twSaDO8O80aGIzX6T0HywgdpEFFoO-8,9098
|
|
10
|
+
sae_lens/load_model.py,sha256=dBB_9gO6kWyQ4sXHq7qB8T3YUlXm3PGwYcpR4UVW4QY,8633
|
|
10
11
|
sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
-
sae_lens/loading/pretrained_sae_loaders.py,sha256=
|
|
12
|
+
sae_lens/loading/pretrained_sae_loaders.py,sha256=FSAz9Je-8Xl7ccdEyp8-WRn-KFtaJ74zgKMefnfaj3A,30877
|
|
12
13
|
sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
|
|
13
14
|
sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
|
|
14
15
|
sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
|
|
15
16
|
sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
|
|
16
|
-
sae_lens/sae_training_runner.py,sha256=lI_d3ywS312dIz0wctm_Sgt3W9ffBOS7ahnDXBljX1s,8320
|
|
17
17
|
sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
|
|
18
18
|
sae_lens/saes/gated_sae.py,sha256=IgWvZxeJpdiu7VqeUnJLC-VWVhz6o8OXvmwCS-LJ-WQ,9426
|
|
19
19
|
sae_lens/saes/jumprelu_sae.py,sha256=lkhafpoYYn4-62tBlmmufmUomoo3CmFFQQ3NNylBNSM,12264
|
|
20
|
-
sae_lens/saes/sae.py,sha256=
|
|
20
|
+
sae_lens/saes/sae.py,sha256=u4kmsUVxa2rnFt8A5jLfj7T6h6qqBK6CkecHslebQgE,34938
|
|
21
21
|
sae_lens/saes/standard_sae.py,sha256=tMs6Z6Cv44PWa7pLo53xhXFnHMvO5BM6eVYHtRPLpos,6652
|
|
22
22
|
sae_lens/saes/topk_sae.py,sha256=CfF59K4J2XwUvztwg4fBbvFO3PyucLkg4Elkxdk0ozs,9786
|
|
23
23
|
sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
|
|
24
24
|
sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
25
|
-
sae_lens/training/
|
|
25
|
+
sae_lens/training/activation_scaler.py,sha256=1P-vva3wJhs2NH65YONli4Rw4auvgZkxe_KKwTNMCR0,1714
|
|
26
|
+
sae_lens/training/activations_store.py,sha256=Xvnz7l2aw3XWtOQsQDj4G4bt-XT6egbumGBwrAM1mtA,32722
|
|
26
27
|
sae_lens/training/geometric_median.py,sha256=3kH8ZJAgKStlnZgs6s1uYGDYh004Bl0r4RLhuwT3lBY,3719
|
|
28
|
+
sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
|
|
27
29
|
sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
|
|
28
|
-
sae_lens/training/sae_trainer.py,sha256=
|
|
30
|
+
sae_lens/training/sae_trainer.py,sha256=rFuMdnBDe82nd7YV_QKVE18V5jCWmohbzkIGL0Z2kIM,15153
|
|
31
|
+
sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
|
|
29
32
|
sae_lens/training/upload_saes_to_huggingface.py,sha256=tXvR4j25IgMjJ8R9oczwSdy00Tg-P_jAtnPHRt8yF64,4489
|
|
30
33
|
sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
|
|
31
|
-
sae_lens/util.py,sha256=
|
|
32
|
-
sae_lens-6.0.
|
|
33
|
-
sae_lens-6.0.
|
|
34
|
-
sae_lens-6.0.
|
|
35
|
-
sae_lens-6.0.
|
|
34
|
+
sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
|
|
35
|
+
sae_lens-6.0.0rc3.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
|
|
36
|
+
sae_lens-6.0.0rc3.dist-info/METADATA,sha256=irWiVHtJUXiACNPxZ0fNIVwq1n7n0wxg87c0WSYUkMw,5326
|
|
37
|
+
sae_lens-6.0.0rc3.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
38
|
+
sae_lens-6.0.0rc3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|