sae-lens 6.0.0rc2__py3-none-any.whl → 6.0.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +6 -3
- sae_lens/analysis/neuronpedia_integration.py +3 -3
- sae_lens/cache_activations_runner.py +7 -6
- sae_lens/config.py +50 -6
- sae_lens/constants.py +2 -0
- sae_lens/evals.py +39 -28
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +53 -5
- sae_lens/loading/pretrained_sae_loaders.py +24 -12
- sae_lens/saes/gated_sae.py +0 -4
- sae_lens/saes/jumprelu_sae.py +4 -10
- sae_lens/saes/sae.py +121 -51
- sae_lens/saes/standard_sae.py +4 -11
- sae_lens/saes/topk_sae.py +18 -12
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +77 -174
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/sae_trainer.py +107 -98
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +1 -1
- sae_lens/util.py +19 -0
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc4.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc4.dist-info/RECORD +37 -0
- sae_lens/sae_training_runner.py +0 -237
- sae_lens/training/geometric_median.py +0 -101
- sae_lens-6.0.0rc2.dist-info/RECORD +0 -35
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc4.dist-info}/LICENSE +0 -0
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc4.dist-info}/WHEEL +0 -0
sae_lens/training/sae_trainer.py
CHANGED
|
@@ -1,16 +1,17 @@
|
|
|
1
1
|
import contextlib
|
|
2
2
|
from dataclasses import dataclass
|
|
3
|
-
from
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Callable, Generic, Protocol
|
|
4
5
|
|
|
5
6
|
import torch
|
|
6
7
|
import wandb
|
|
8
|
+
from safetensors.torch import save_file
|
|
7
9
|
from torch.optim import Adam
|
|
8
|
-
from tqdm import tqdm
|
|
9
|
-
from transformer_lens.hook_points import HookedRootModule
|
|
10
|
+
from tqdm.auto import tqdm
|
|
10
11
|
|
|
11
12
|
from sae_lens import __version__
|
|
12
|
-
from sae_lens.config import
|
|
13
|
-
from sae_lens.
|
|
13
|
+
from sae_lens.config import SAETrainerConfig
|
|
14
|
+
from sae_lens.constants import ACTIVATION_SCALER_CFG_FILENAME, SPARSITY_FILENAME
|
|
14
15
|
from sae_lens.saes.sae import (
|
|
15
16
|
T_TRAINING_SAE,
|
|
16
17
|
T_TRAINING_SAE_CONFIG,
|
|
@@ -19,8 +20,9 @@ from sae_lens.saes.sae import (
|
|
|
19
20
|
TrainStepInput,
|
|
20
21
|
TrainStepOutput,
|
|
21
22
|
)
|
|
22
|
-
from sae_lens.training.
|
|
23
|
+
from sae_lens.training.activation_scaler import ActivationScaler
|
|
23
24
|
from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler
|
|
25
|
+
from sae_lens.training.types import DataProvider
|
|
24
26
|
|
|
25
27
|
|
|
26
28
|
def _log_feature_sparsity(
|
|
@@ -46,33 +48,39 @@ class TrainSAEOutput:
|
|
|
46
48
|
class SaveCheckpointFn(Protocol):
|
|
47
49
|
def __call__(
|
|
48
50
|
self,
|
|
49
|
-
|
|
50
|
-
checkpoint_name: str,
|
|
51
|
-
wandb_aliases: list[str] | None = None,
|
|
51
|
+
checkpoint_path: Path,
|
|
52
52
|
) -> None: ...
|
|
53
53
|
|
|
54
54
|
|
|
55
|
+
Evaluator = Callable[[T_TRAINING_SAE, DataProvider, ActivationScaler], dict[str, Any]]
|
|
56
|
+
|
|
57
|
+
|
|
55
58
|
class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
56
59
|
"""
|
|
57
60
|
Core SAE class used for inference. For training, see TrainingSAE.
|
|
58
61
|
"""
|
|
59
62
|
|
|
63
|
+
data_provider: DataProvider
|
|
64
|
+
activation_scaler: ActivationScaler
|
|
65
|
+
evaluator: Evaluator[T_TRAINING_SAE] | None
|
|
66
|
+
|
|
60
67
|
def __init__(
|
|
61
68
|
self,
|
|
62
|
-
|
|
69
|
+
cfg: SAETrainerConfig,
|
|
63
70
|
sae: T_TRAINING_SAE,
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
71
|
+
data_provider: DataProvider,
|
|
72
|
+
evaluator: Evaluator[T_TRAINING_SAE] | None = None,
|
|
73
|
+
save_checkpoint_fn: SaveCheckpointFn | None = None,
|
|
67
74
|
) -> None:
|
|
68
|
-
self.model = model
|
|
69
75
|
self.sae = sae
|
|
70
|
-
self.
|
|
71
|
-
self.
|
|
76
|
+
self.data_provider = data_provider
|
|
77
|
+
self.evaluator = evaluator
|
|
78
|
+
self.activation_scaler = ActivationScaler()
|
|
79
|
+
self.save_checkpoint_fn = save_checkpoint_fn
|
|
72
80
|
self.cfg = cfg
|
|
73
81
|
|
|
74
82
|
self.n_training_steps: int = 0
|
|
75
|
-
self.
|
|
83
|
+
self.n_training_samples: int = 0
|
|
76
84
|
self.started_fine_tuning: bool = False
|
|
77
85
|
|
|
78
86
|
_update_sae_lens_training_version(self.sae)
|
|
@@ -82,20 +90,16 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
82
90
|
self.checkpoint_thresholds = list(
|
|
83
91
|
range(
|
|
84
92
|
0,
|
|
85
|
-
cfg.
|
|
86
|
-
cfg.
|
|
93
|
+
cfg.total_training_samples,
|
|
94
|
+
cfg.total_training_samples // self.cfg.n_checkpoints,
|
|
87
95
|
)
|
|
88
96
|
)[1:]
|
|
89
97
|
|
|
90
|
-
self.act_freq_scores = torch.zeros(
|
|
91
|
-
cast(int, cfg.sae.d_sae),
|
|
92
|
-
device=cfg.device,
|
|
93
|
-
)
|
|
98
|
+
self.act_freq_scores = torch.zeros(sae.cfg.d_sae, device=cfg.device)
|
|
94
99
|
self.n_forward_passes_since_fired = torch.zeros(
|
|
95
|
-
|
|
96
|
-
device=cfg.device,
|
|
100
|
+
sae.cfg.d_sae, device=cfg.device
|
|
97
101
|
)
|
|
98
|
-
self.
|
|
102
|
+
self.n_frac_active_samples = 0
|
|
99
103
|
# we don't train the scaling factor (initially)
|
|
100
104
|
# set requires grad to false for the scaling factor
|
|
101
105
|
for name, param in self.sae.named_parameters():
|
|
@@ -131,7 +135,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
131
135
|
)
|
|
132
136
|
|
|
133
137
|
# Setup autocast if using
|
|
134
|
-
self.
|
|
138
|
+
self.grad_scaler = torch.amp.GradScaler(
|
|
135
139
|
device=self.cfg.device, enabled=self.cfg.autocast
|
|
136
140
|
)
|
|
137
141
|
|
|
@@ -144,23 +148,9 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
144
148
|
else:
|
|
145
149
|
self.autocast_if_enabled = contextlib.nullcontext()
|
|
146
150
|
|
|
147
|
-
# Set up eval config
|
|
148
|
-
|
|
149
|
-
self.trainer_eval_config = EvalConfig(
|
|
150
|
-
batch_size_prompts=self.cfg.eval_batch_size_prompts,
|
|
151
|
-
n_eval_reconstruction_batches=self.cfg.n_eval_batches,
|
|
152
|
-
n_eval_sparsity_variance_batches=self.cfg.n_eval_batches,
|
|
153
|
-
compute_ce_loss=True,
|
|
154
|
-
compute_l2_norms=True,
|
|
155
|
-
compute_sparsity_metrics=True,
|
|
156
|
-
compute_variance_metrics=True,
|
|
157
|
-
compute_kl=False,
|
|
158
|
-
compute_featurewise_weight_based_metrics=False,
|
|
159
|
-
)
|
|
160
|
-
|
|
161
151
|
@property
|
|
162
152
|
def feature_sparsity(self) -> torch.Tensor:
|
|
163
|
-
return self.act_freq_scores / self.
|
|
153
|
+
return self.act_freq_scores / self.n_frac_active_samples
|
|
164
154
|
|
|
165
155
|
@property
|
|
166
156
|
def log_feature_sparsity(self) -> torch.Tensor:
|
|
@@ -171,19 +161,24 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
171
161
|
return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()
|
|
172
162
|
|
|
173
163
|
def fit(self) -> T_TRAINING_SAE:
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
164
|
+
self.sae.to(self.cfg.device)
|
|
165
|
+
pbar = tqdm(total=self.cfg.total_training_samples, desc="Training SAE")
|
|
166
|
+
|
|
167
|
+
if self.sae.cfg.normalize_activations == "expected_average_only_in":
|
|
168
|
+
self.activation_scaler.estimate_scaling_factor(
|
|
169
|
+
d_in=self.sae.cfg.d_in,
|
|
170
|
+
data_provider=self.data_provider,
|
|
171
|
+
n_batches_for_norm_estimate=int(1e3),
|
|
172
|
+
)
|
|
177
173
|
|
|
178
174
|
# Train loop
|
|
179
|
-
while self.
|
|
175
|
+
while self.n_training_samples < self.cfg.total_training_samples:
|
|
180
176
|
# Do a training step.
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
)
|
|
184
|
-
self.n_training_tokens += self.cfg.train_batch_size_tokens
|
|
177
|
+
batch = next(self.data_provider).to(self.sae.device)
|
|
178
|
+
self.n_training_samples += batch.shape[0]
|
|
179
|
+
scaled_batch = self.activation_scaler(batch)
|
|
185
180
|
|
|
186
|
-
step_output = self._train_step(sae=self.sae, sae_in=
|
|
181
|
+
step_output = self._train_step(sae=self.sae, sae_in=scaled_batch)
|
|
187
182
|
|
|
188
183
|
if self.cfg.logger.log_to_wandb:
|
|
189
184
|
self._log_train_step(step_output)
|
|
@@ -194,22 +189,56 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
194
189
|
self._update_pbar(step_output, pbar)
|
|
195
190
|
|
|
196
191
|
# fold the estimated norm scaling factor into the sae weights
|
|
197
|
-
if self.
|
|
192
|
+
if self.activation_scaler.scaling_factor is not None:
|
|
198
193
|
self.sae.fold_activation_norm_scaling_factor(
|
|
199
|
-
self.
|
|
194
|
+
self.activation_scaler.scaling_factor
|
|
200
195
|
)
|
|
201
|
-
self.
|
|
196
|
+
self.activation_scaler.scaling_factor = None
|
|
202
197
|
|
|
203
|
-
# save final sae group to checkpoints folder
|
|
198
|
+
# save final inference sae group to checkpoints folder
|
|
204
199
|
self.save_checkpoint(
|
|
205
|
-
|
|
206
|
-
checkpoint_name=f"final_{self.n_training_tokens}",
|
|
200
|
+
checkpoint_name=f"final_{self.n_training_samples}",
|
|
207
201
|
wandb_aliases=["final_model"],
|
|
202
|
+
save_inference_model=True,
|
|
208
203
|
)
|
|
209
204
|
|
|
210
205
|
pbar.close()
|
|
211
206
|
return self.sae
|
|
212
207
|
|
|
208
|
+
def save_checkpoint(
|
|
209
|
+
self,
|
|
210
|
+
checkpoint_name: str,
|
|
211
|
+
wandb_aliases: list[str] | None = None,
|
|
212
|
+
save_inference_model: bool = False,
|
|
213
|
+
) -> None:
|
|
214
|
+
checkpoint_path = Path(self.cfg.checkpoint_path) / checkpoint_name
|
|
215
|
+
checkpoint_path.mkdir(exist_ok=True, parents=True)
|
|
216
|
+
|
|
217
|
+
save_fn = (
|
|
218
|
+
self.sae.save_inference_model
|
|
219
|
+
if save_inference_model
|
|
220
|
+
else self.sae.save_model
|
|
221
|
+
)
|
|
222
|
+
weights_path, cfg_path = save_fn(str(checkpoint_path))
|
|
223
|
+
|
|
224
|
+
sparsity_path = checkpoint_path / SPARSITY_FILENAME
|
|
225
|
+
save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)
|
|
226
|
+
|
|
227
|
+
activation_scaler_path = checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
|
|
228
|
+
self.activation_scaler.save(str(activation_scaler_path))
|
|
229
|
+
|
|
230
|
+
if self.cfg.logger.log_to_wandb:
|
|
231
|
+
self.cfg.logger.log(
|
|
232
|
+
self,
|
|
233
|
+
weights_path,
|
|
234
|
+
cfg_path,
|
|
235
|
+
sparsity_path=sparsity_path,
|
|
236
|
+
wandb_aliases=wandb_aliases,
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
if self.save_checkpoint_fn is not None:
|
|
240
|
+
self.save_checkpoint_fn(checkpoint_path=checkpoint_path)
|
|
241
|
+
|
|
213
242
|
def _train_step(
|
|
214
243
|
self,
|
|
215
244
|
sae: T_TRAINING_SAE,
|
|
@@ -242,17 +271,19 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
242
271
|
self.act_freq_scores += (
|
|
243
272
|
(train_step_output.feature_acts.abs() > 0).float().sum(0)
|
|
244
273
|
)
|
|
245
|
-
self.
|
|
274
|
+
self.n_frac_active_samples += self.cfg.train_batch_size_samples
|
|
246
275
|
|
|
247
|
-
#
|
|
248
|
-
self.
|
|
276
|
+
# Grad scaler will rescale gradients if autocast is enabled
|
|
277
|
+
self.grad_scaler.scale(
|
|
249
278
|
train_step_output.loss
|
|
250
279
|
).backward() # loss.backward() if not autocasting
|
|
251
|
-
self.
|
|
280
|
+
self.grad_scaler.unscale_(self.optimizer) # needed to clip correctly
|
|
252
281
|
# TODO: Work out if grad norm clipping should be in config / how to test it.
|
|
253
282
|
torch.nn.utils.clip_grad_norm_(sae.parameters(), 1.0)
|
|
254
|
-
self.
|
|
255
|
-
|
|
283
|
+
self.grad_scaler.step(
|
|
284
|
+
self.optimizer
|
|
285
|
+
) # just ctx.optimizer.step() if not autocasting
|
|
286
|
+
self.grad_scaler.update()
|
|
256
287
|
|
|
257
288
|
self.optimizer.zero_grad()
|
|
258
289
|
self.lr_scheduler.step()
|
|
@@ -267,7 +298,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
267
298
|
wandb.log(
|
|
268
299
|
self._build_train_step_log_dict(
|
|
269
300
|
output=step_output,
|
|
270
|
-
|
|
301
|
+
n_training_samples=self.n_training_samples,
|
|
271
302
|
),
|
|
272
303
|
step=self.n_training_steps,
|
|
273
304
|
)
|
|
@@ -283,7 +314,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
283
314
|
def _build_train_step_log_dict(
|
|
284
315
|
self,
|
|
285
316
|
output: TrainStepOutput,
|
|
286
|
-
|
|
317
|
+
n_training_samples: int,
|
|
287
318
|
) -> dict[str, Any]:
|
|
288
319
|
sae_in = output.sae_in
|
|
289
320
|
sae_out = output.sae_out
|
|
@@ -311,7 +342,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
311
342
|
"sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(),
|
|
312
343
|
"sparsity/dead_features": self.dead_neurons.sum().item(),
|
|
313
344
|
"details/current_learning_rate": current_learning_rate,
|
|
314
|
-
"details/
|
|
345
|
+
"details/n_training_samples": n_training_samples,
|
|
315
346
|
**{
|
|
316
347
|
f"details/{name}_coefficient": scheduler.value
|
|
317
348
|
for name, scheduler in self.coefficient_schedulers.items()
|
|
@@ -331,30 +362,11 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
331
362
|
* self.cfg.logger.eval_every_n_wandb_logs
|
|
332
363
|
) == 0:
|
|
333
364
|
self.sae.eval()
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
eval_metrics, _ = run_evals(
|
|
340
|
-
sae=self.sae,
|
|
341
|
-
activation_store=self.activations_store,
|
|
342
|
-
model=self.model,
|
|
343
|
-
eval_config=self.trainer_eval_config,
|
|
344
|
-
ignore_tokens=ignore_tokens,
|
|
345
|
-
model_kwargs=self.cfg.model_kwargs,
|
|
346
|
-
) # not calculating featurwise metrics here.
|
|
347
|
-
|
|
348
|
-
# Remove eval metrics that are already logged during training
|
|
349
|
-
eval_metrics.pop("metrics/explained_variance", None)
|
|
350
|
-
eval_metrics.pop("metrics/explained_variance_std", None)
|
|
351
|
-
eval_metrics.pop("metrics/l0", None)
|
|
352
|
-
eval_metrics.pop("metrics/l1", None)
|
|
353
|
-
eval_metrics.pop("metrics/mse", None)
|
|
354
|
-
|
|
355
|
-
# Remove metrics that are not useful for wandb logging
|
|
356
|
-
eval_metrics.pop("metrics/total_tokens_evaluated", None)
|
|
357
|
-
|
|
365
|
+
eval_metrics = (
|
|
366
|
+
self.evaluator(self.sae, self.data_provider, self.activation_scaler)
|
|
367
|
+
if self.evaluator is not None
|
|
368
|
+
else {}
|
|
369
|
+
)
|
|
358
370
|
for key, value in self.sae.log_histograms().items():
|
|
359
371
|
eval_metrics[key] = wandb.Histogram(value) # type: ignore
|
|
360
372
|
|
|
@@ -378,21 +390,18 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
378
390
|
@torch.no_grad()
|
|
379
391
|
def _reset_running_sparsity_stats(self) -> None:
|
|
380
392
|
self.act_freq_scores = torch.zeros(
|
|
381
|
-
self.cfg.
|
|
393
|
+
self.sae.cfg.d_sae, # type: ignore
|
|
382
394
|
device=self.cfg.device,
|
|
383
395
|
)
|
|
384
|
-
self.
|
|
396
|
+
self.n_frac_active_samples = 0
|
|
385
397
|
|
|
386
398
|
@torch.no_grad()
|
|
387
399
|
def _checkpoint_if_needed(self):
|
|
388
400
|
if (
|
|
389
401
|
self.checkpoint_thresholds
|
|
390
|
-
and self.
|
|
402
|
+
and self.n_training_samples > self.checkpoint_thresholds[0]
|
|
391
403
|
):
|
|
392
|
-
self.save_checkpoint(
|
|
393
|
-
trainer=self,
|
|
394
|
-
checkpoint_name=str(self.n_training_tokens),
|
|
395
|
-
)
|
|
404
|
+
self.save_checkpoint(checkpoint_name=str(self.n_training_samples))
|
|
396
405
|
self.checkpoint_thresholds.pop(0)
|
|
397
406
|
|
|
398
407
|
@torch.no_grad()
|
|
@@ -408,7 +417,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
408
417
|
for loss_name, loss_value in step_output.losses.items()
|
|
409
418
|
)
|
|
410
419
|
pbar.set_description(f"{self.n_training_steps}| {loss_strs}")
|
|
411
|
-
pbar.update(update_interval * self.cfg.
|
|
420
|
+
pbar.update(update_interval * self.cfg.train_batch_size_samples)
|
|
412
421
|
|
|
413
422
|
|
|
414
423
|
def _unwrap_item(item: float | torch.Tensor) -> float:
|
|
@@ -88,7 +88,7 @@ def _create_default_readme(repo_id: str, sae_ids: Iterable[str]) -> str:
|
|
|
88
88
|
```python
|
|
89
89
|
from sae_lens import SAE
|
|
90
90
|
|
|
91
|
-
sae
|
|
91
|
+
sae = SAE.from_pretrained("{repo_id}", "<sae_id>")
|
|
92
92
|
```
|
|
93
93
|
"""
|
|
94
94
|
)
|
sae_lens/util.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import re
|
|
1
2
|
from dataclasses import asdict, fields, is_dataclass
|
|
2
3
|
from typing import Sequence, TypeVar
|
|
3
4
|
|
|
@@ -26,3 +27,21 @@ def filter_valid_dataclass_fields(
|
|
|
26
27
|
if whitelist_fields is not None:
|
|
27
28
|
valid_field_names = valid_field_names.union(whitelist_fields)
|
|
28
29
|
return {key: val for key, val in source_dict.items() if key in valid_field_names}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def extract_stop_at_layer_from_tlens_hook_name(hook_name: str) -> int | None:
|
|
33
|
+
"""Extract the stop_at layer from a HookedTransformer hook name.
|
|
34
|
+
|
|
35
|
+
Returns None if the hook name is not a valid HookedTransformer hook name.
|
|
36
|
+
"""
|
|
37
|
+
layer = extract_layer_from_tlens_hook_name(hook_name)
|
|
38
|
+
return None if layer is None else layer + 1
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def extract_layer_from_tlens_hook_name(hook_name: str) -> int | None:
|
|
42
|
+
"""Extract the layer from a HookedTransformer hook name.
|
|
43
|
+
|
|
44
|
+
Returns None if the hook name is not a valid HookedTransformer hook name.
|
|
45
|
+
"""
|
|
46
|
+
hook_match = re.search(r"\.(\d+)\.", hook_name)
|
|
47
|
+
return None if hook_match is None else int(hook_match.group(1))
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
sae_lens/__init__.py,sha256=dGZU3Y6iwiuW5oQVTfNvUmfnHO3bHWWbpU-nvXvw9M8,2861
|
|
2
|
+
sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
+
sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
|
|
4
|
+
sae_lens/analysis/neuronpedia_integration.py,sha256=MrENqc81Mc2SMbxGjbwHzpkGUCAFKSf0i4EdaUF2Oj4,18707
|
|
5
|
+
sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
|
|
6
|
+
sae_lens/config.py,sha256=9Lg4HkQvj1t9QZJdmC071lyJMc_iqNQknosT7zOYfwM,27278
|
|
7
|
+
sae_lens/constants.py,sha256=RJlzWx7wLNMNmrdI63naF7-M3enb55vYRN4x1hXx6vI,593
|
|
8
|
+
sae_lens/evals.py,sha256=PIMGQobE9o2bHksDAtQe5bnTMYyHoZKB_elFhDOjrmo,38991
|
|
9
|
+
sae_lens/llm_sae_training_runner.py,sha256=58XbDylw2fPOD7C-ZfSAjeNqJLXB05uHGTuiYVVbXXY,13354
|
|
10
|
+
sae_lens/load_model.py,sha256=dBB_9gO6kWyQ4sXHq7qB8T3YUlXm3PGwYcpR4UVW4QY,8633
|
|
11
|
+
sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
|
+
sae_lens/loading/pretrained_sae_loaders.py,sha256=kbirwfCg4Ks9Cg3rt78bYxIHMhz5h015n0UTRJQLJY0,31291
|
|
13
|
+
sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
|
|
14
|
+
sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
|
|
15
|
+
sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
|
|
16
|
+
sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
|
|
17
|
+
sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
|
|
18
|
+
sae_lens/saes/gated_sae.py,sha256=0zd66bH04nsaGk3bxHk10hsZofa2GrFbMo15LOsuqgU,9233
|
|
19
|
+
sae_lens/saes/jumprelu_sae.py,sha256=iwmPQJ4XpIxzgosty680u8Zj7x1uVZhM75kPOT3obi0,12060
|
|
20
|
+
sae_lens/saes/sae.py,sha256=HAGkJAj_FIDzbSR1dsG8b2AyMq8UauUU_yx-LvdfjuE,37465
|
|
21
|
+
sae_lens/saes/standard_sae.py,sha256=PfkGLsw_6La3PXHOQL0u7qQsaZsXCJqYCeCcRDj5n64,6274
|
|
22
|
+
sae_lens/saes/topk_sae.py,sha256=kmry1FE1H06OvCfn84V-j2JfWGKcU5b2urwAq_Oq5j4,9893
|
|
23
|
+
sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
|
|
24
|
+
sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
25
|
+
sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
|
|
26
|
+
sae_lens/training/activations_store.py,sha256=s3Qvztv2siuuXSuXEUDZYSKq1QQCsqsGXY767kv6grc,32609
|
|
27
|
+
sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
|
|
28
|
+
sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
|
|
29
|
+
sae_lens/training/sae_trainer.py,sha256=9K0VudwSTJp9OlCVzaU_ngZ0WlYNrN6-ozTCCAxR9_k,15421
|
|
30
|
+
sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
|
|
31
|
+
sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
|
|
32
|
+
sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
|
|
33
|
+
sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
|
|
34
|
+
sae_lens-6.0.0rc4.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
|
|
35
|
+
sae_lens-6.0.0rc4.dist-info/METADATA,sha256=wOQMSV4yNlpgpGxuE4DI0-q4KzTRYOg1m9ZxpdCsNjk,5326
|
|
36
|
+
sae_lens-6.0.0rc4.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
37
|
+
sae_lens-6.0.0rc4.dist-info/RECORD,,
|