sae-lens 5.10.3__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,710 +0,0 @@
- """Most of this is just copied over from Arthur's code and slightly simplified:
- https://github.com/ArthurConmy/sae/blob/main/sae/model.py
- """
-
- from dataclasses import dataclass, fields
- from typing import Any
-
- import einops
- import numpy as np
- import torch
- from jaxtyping import Float
- from torch import nn
- from typing_extensions import deprecated
-
- from sae_lens import logger
- from sae_lens.config import LanguageModelSAERunnerConfig
- from sae_lens.sae import SAE, SAEConfig
- from sae_lens.toolkit.pretrained_sae_loaders import (
-     PretrainedSaeDiskLoader,
-     handle_config_defaulting,
-     sae_lens_disk_loader,
- )
-
- SPARSITY_PATH = "sparsity.safetensors"
- SAE_WEIGHTS_PATH = "sae_weights.safetensors"
- SAE_CFG_PATH = "cfg.json"
-
-
- def rectangle(x: torch.Tensor) -> torch.Tensor:
-     return ((x > -0.5) & (x < 0.5)).to(x)
-
-
- class Step(torch.autograd.Function):
-     @staticmethod
-     def forward(
-         x: torch.Tensor,
-         threshold: torch.Tensor,
-         bandwidth: float,  # noqa: ARG004
-     ) -> torch.Tensor:
-         return (x > threshold).to(x)
-
-     @staticmethod
-     def setup_context(
-         ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
-     ) -> None:
-         x, threshold, bandwidth = inputs
-         del output
-         ctx.save_for_backward(x, threshold)
-         ctx.bandwidth = bandwidth
-
-     @staticmethod
-     def backward(  # type: ignore[override]
-         ctx: Any, grad_output: torch.Tensor
-     ) -> tuple[None, torch.Tensor, None]:
-         x, threshold = ctx.saved_tensors
-         bandwidth = ctx.bandwidth
-         threshold_grad = torch.sum(
-             -(1.0 / bandwidth) * rectangle((x - threshold) / bandwidth) * grad_output,
-             dim=0,
-         )
-         return None, threshold_grad, None
-
-
- class JumpReLU(torch.autograd.Function):
-     @staticmethod
-     def forward(
-         x: torch.Tensor,
-         threshold: torch.Tensor,
-         bandwidth: float,  # noqa: ARG004
-     ) -> torch.Tensor:
-         return (x * (x > threshold)).to(x)
-
-     @staticmethod
-     def setup_context(
-         ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
-     ) -> None:
-         x, threshold, bandwidth = inputs
-         del output
-         ctx.save_for_backward(x, threshold)
-         ctx.bandwidth = bandwidth
-
-     @staticmethod
-     def backward(  # type: ignore[override]
-         ctx: Any, grad_output: torch.Tensor
-     ) -> tuple[torch.Tensor, torch.Tensor, None]:
-         x, threshold = ctx.saved_tensors
-         bandwidth = ctx.bandwidth
-         x_grad = (x > threshold) * grad_output  # We don't apply STE to x input
-         threshold_grad = torch.sum(
-             -(threshold / bandwidth)
-             * rectangle((x - threshold) / bandwidth)
-             * grad_output,
-             dim=0,
-         )
-         return x_grad, threshold_grad, None
-
-
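The two autograd functions above implement straight-through estimators: the forward pass is a hard gate, and `backward` routes a surrogate gradient to `threshold` through the `rectangle` window, so only pre-activations within `bandwidth / 2` of the threshold contribute. A minimal sketch of this behaviour in plain PyTorch, assuming the `Step` and `rectangle` definitions above:

    import torch

    x = torch.tensor([[0.2, 0.9, 1.5]])  # (batch, d_sae)
    threshold = torch.tensor([1.0, 1.0, 1.0], requires_grad=True)
    bandwidth = 0.5

    out = Step.apply(x, threshold, bandwidth)  # tensor([[0., 0., 1.]])
    out.sum().backward()

    # Only x = 0.9 lies within bandwidth / 2 of its threshold, so only that
    # feature's threshold receives surrogate gradient: -(1 / 0.5) * 1 = -2.
    print(threshold.grad)  # tensor([ 0., -2.,  0.])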
- @dataclass
- class TrainStepOutput:
-     sae_in: torch.Tensor
-     sae_out: torch.Tensor
-     feature_acts: torch.Tensor
-     hidden_pre: torch.Tensor
-     loss: torch.Tensor  # we need to call backwards on this
-     losses: dict[str, float | torch.Tensor]
-
-
- @dataclass(kw_only=True)
- class TrainingSAEConfig(SAEConfig):
-     # Sparsity Loss Calculations
-     l1_coefficient: float
-     lp_norm: float
-     use_ghost_grads: bool
-     normalize_sae_decoder: bool
-     noise_scale: float
-     decoder_orthogonal_init: bool
-     mse_loss_normalization: str | None
-     jumprelu_init_threshold: float
-     jumprelu_bandwidth: float
-     decoder_heuristic_init: bool
-     decoder_heuristic_init_norm: float
-     init_encoder_as_decoder_transpose: bool
-     scale_sparsity_penalty_by_decoder_norm: bool
-
-     @classmethod
-     def from_sae_runner_config(
-         cls, cfg: LanguageModelSAERunnerConfig
-     ) -> "TrainingSAEConfig":
-         return cls(
-             # base config
-             architecture=cfg.architecture,
-             d_in=cfg.d_in,
-             d_sae=cfg.d_sae,  # type: ignore
-             dtype=cfg.dtype,
-             device=cfg.device,
-             model_name=cfg.model_name,
-             hook_name=cfg.hook_name,
-             hook_layer=cfg.hook_layer,
-             hook_head_index=cfg.hook_head_index,
-             activation_fn_str=cfg.activation_fn,
-             activation_fn_kwargs=cfg.activation_fn_kwargs,
-             apply_b_dec_to_input=cfg.apply_b_dec_to_input,
-             finetuning_scaling_factor=cfg.finetuning_method is not None,
-             sae_lens_training_version=cfg.sae_lens_training_version,
-             context_size=cfg.context_size,
-             dataset_path=cfg.dataset_path,
-             prepend_bos=cfg.prepend_bos,
-             seqpos_slice=cfg.seqpos_slice,
-             # Training cfg
-             l1_coefficient=cfg.l1_coefficient,
-             lp_norm=cfg.lp_norm,
-             use_ghost_grads=cfg.use_ghost_grads,
-             normalize_sae_decoder=cfg.normalize_sae_decoder,
-             noise_scale=cfg.noise_scale,
-             decoder_orthogonal_init=cfg.decoder_orthogonal_init,
-             mse_loss_normalization=cfg.mse_loss_normalization,
-             decoder_heuristic_init=cfg.decoder_heuristic_init,
-             decoder_heuristic_init_norm=cfg.decoder_heuristic_init_norm,
-             init_encoder_as_decoder_transpose=cfg.init_encoder_as_decoder_transpose,
-             scale_sparsity_penalty_by_decoder_norm=cfg.scale_sparsity_penalty_by_decoder_norm,
-             normalize_activations=cfg.normalize_activations,
-             dataset_trust_remote_code=cfg.dataset_trust_remote_code,
-             model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {},
-             jumprelu_init_threshold=cfg.jumprelu_init_threshold,
-             jumprelu_bandwidth=cfg.jumprelu_bandwidth,
-         )
-
-     @classmethod
-     def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig":
-         # Remove any keys that are not fields of this dataclass, since we
-         # sometimes enhance the config with the whole LM runner config.
-         valid_field_names = {field.name for field in fields(cls)}
-         valid_config_dict = {
-             key: val for key, val in config_dict.items() if key in valid_field_names
-         }
-
-         # Ensure seqpos_slice is a tuple
-         if "seqpos_slice" in valid_config_dict:
-             if isinstance(valid_config_dict["seqpos_slice"], list):
-                 valid_config_dict["seqpos_slice"] = tuple(
-                     valid_config_dict["seqpos_slice"]
-                 )
-             elif not isinstance(valid_config_dict["seqpos_slice"], tuple):
-                 valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],)
-
-         return TrainingSAEConfig(**valid_config_dict)
-
-     def to_dict(self) -> dict[str, Any]:
-         return {
-             **super().to_dict(),
-             "l1_coefficient": self.l1_coefficient,
-             "lp_norm": self.lp_norm,
-             "use_ghost_grads": self.use_ghost_grads,
-             "normalize_sae_decoder": self.normalize_sae_decoder,
-             "noise_scale": self.noise_scale,
-             "decoder_orthogonal_init": self.decoder_orthogonal_init,
-             "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
-             "mse_loss_normalization": self.mse_loss_normalization,
-             "decoder_heuristic_init": self.decoder_heuristic_init,
-             "decoder_heuristic_init_norm": self.decoder_heuristic_init_norm,
-             "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
-             "normalize_activations": self.normalize_activations,
-             "jumprelu_init_threshold": self.jumprelu_init_threshold,
-             "jumprelu_bandwidth": self.jumprelu_bandwidth,
-         }
-
-     # This needs to exist so we can initialize the parent SAE cfg without the
-     # training-specific parameters. Maybe there's a cleaner way to do this.
-     def get_base_sae_cfg_dict(self) -> dict[str, Any]:
-         return {
-             "architecture": self.architecture,
-             "d_in": self.d_in,
-             "d_sae": self.d_sae,
-             "activation_fn_str": self.activation_fn_str,
-             "activation_fn_kwargs": self.activation_fn_kwargs,
-             "apply_b_dec_to_input": self.apply_b_dec_to_input,
-             "dtype": self.dtype,
-             "model_name": self.model_name,
-             "hook_name": self.hook_name,
-             "hook_layer": self.hook_layer,
-             "hook_head_index": self.hook_head_index,
-             "device": self.device,
-             "context_size": self.context_size,
-             "prepend_bos": self.prepend_bos,
-             "finetuning_scaling_factor": self.finetuning_scaling_factor,
-             "normalize_activations": self.normalize_activations,
-             "dataset_path": self.dataset_path,
-             "dataset_trust_remote_code": self.dataset_trust_remote_code,
-             "sae_lens_training_version": self.sae_lens_training_version,
-         }
-
-
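As an illustration of the key filtering in `from_dict`: unknown keys (for example extra LM-runner fields) are silently dropped, so a config dict enriched with runner-level entries still loads. A hedged sketch, assuming `cfg` is an existing `TrainingSAEConfig` (the extra key below is hypothetical):

    config_dict = cfg.to_dict()
    config_dict["training_tokens"] = 1_000_000  # hypothetical runner-only key
    restored = TrainingSAEConfig.from_dict(config_dict)
    # restored should match cfg field-for-field; the unknown key is filtered out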
- class TrainingSAE(SAE):
-     """
-     A SAE used for training. This class provides a `training_forward_pass` method which calculates
-     losses used for training.
-     """
-
-     cfg: TrainingSAEConfig
-     use_error_term: bool
-     dtype: torch.dtype
-     device: torch.device
-
-     def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
-         base_sae_cfg = SAEConfig.from_dict(cfg.get_base_sae_cfg_dict())
-         super().__init__(base_sae_cfg)
-         self.cfg = cfg  # type: ignore
-
-         if cfg.architecture == "standard" or cfg.architecture == "topk":
-             self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre
-         elif cfg.architecture == "gated":
-             self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_gated
-         elif cfg.architecture == "jumprelu":
-             self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_jumprelu
-             self.bandwidth = cfg.jumprelu_bandwidth
-             self.log_threshold.data = torch.ones(
-                 self.cfg.d_sae, dtype=self.dtype, device=self.device
-             ) * np.log(cfg.jumprelu_init_threshold)
-         else:
-             raise ValueError(f"Unknown architecture: {cfg.architecture}")
-
-         self.check_cfg_compatibility()
-
-         self.use_error_term = use_error_term
-
-         self.initialize_weights_complex()
-
-         # The training SAE will assume that the activation store handles
-         # reshaping.
-         self.turn_off_forward_pass_hook_z_reshaping()
-
-         self.mse_loss_fn = self._get_mse_loss_fn()
-
-     def initialize_weights_jumprelu(self):
-         # same as the superclass, except we use a log_threshold parameter instead of threshold
-         self.log_threshold = nn.Parameter(
-             torch.empty(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-         )
-         self.initialize_weights_basic()
-
-     @property
-     def threshold(self) -> torch.Tensor:
-         if self.cfg.architecture != "jumprelu":
-             raise ValueError("Threshold is only defined for Jumprelu SAEs")
-         return torch.exp(self.log_threshold)
-
-     @classmethod
-     def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAE":
-         return cls(TrainingSAEConfig.from_dict(config_dict))
-
-     def check_cfg_compatibility(self):
-         if self.cfg.architecture != "standard" and self.cfg.use_ghost_grads:
-             raise ValueError(f"{self.cfg.architecture} SAEs do not support ghost grads")
-         if self.cfg.architecture == "gated" and self.use_error_term:
-             raise ValueError("Gated SAEs do not support error terms")
-
-     def encode_standard(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> Float[torch.Tensor, "... d_sae"]:
-         """
-         Calculate SAE features from inputs
-         """
-         feature_acts, _ = self.encode_with_hidden_pre_fn(x)
-         return feature_acts
-
-     def encode_with_hidden_pre_jumprelu(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
-         sae_in = self.process_sae_in(x)
-
-         hidden_pre = sae_in @ self.W_enc + self.b_enc
-
-         if self.training:
-             hidden_pre = (
-                 hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
-             )
-
-         threshold = torch.exp(self.log_threshold)
-
-         feature_acts = JumpReLU.apply(hidden_pre, threshold, self.bandwidth)
-
-         return feature_acts, hidden_pre  # type: ignore
-
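Stripped of the SAE plumbing, the JumpReLU activation used above is just `x * (x > threshold)`, with `threshold = exp(log_threshold)` guaranteeing positivity. A standalone numeric sketch:

    import torch

    hidden_pre = torch.tensor([[0.2, 0.4, 2.0]])
    threshold = torch.exp(torch.tensor([-0.7, -0.7, -0.7]))  # ~0.5, always positive
    feature_acts = hidden_pre * (hidden_pre > threshold)
    print(feature_acts)  # tensor([[0., 0., 2.]])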
-     def encode_with_hidden_pre(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
-         sae_in = self.process_sae_in(x)
-
-         # "... d_in, d_in d_sae -> ... d_sae"
-         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
-         hidden_pre_noised = hidden_pre + (
-             torch.randn_like(hidden_pre) * self.cfg.noise_scale * self.training
-         )
-         feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
-
-         return feature_acts, hidden_pre_noised
-
-     def encode_with_hidden_pre_gated(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
-         sae_in = self.process_sae_in(x)
-
-         # Gating path with Heaviside step function
-         gating_pre_activation = sae_in @ self.W_enc + self.b_gate
-         active_features = (gating_pre_activation > 0).to(self.dtype)
-
-         # Magnitude path with weight sharing. Note: unlike the standard path,
-         # no training noise is added to the magnitude pre-activations here.
-         magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
-         feature_magnitudes = self.activation_fn(magnitude_pre_activation)
-
-         # Return both the gated feature activations and the magnitude pre-activations
-         return active_features * feature_magnitudes, magnitude_pre_activation
-
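In the gated path, one set of weights decides which features fire (via a Heaviside step), while a rescaled copy of the same weights (`W_enc * exp(r_mag)`) sets their magnitudes. A toy sketch of that computation with hypothetical stand-in tensors:

    import torch

    sae_in = torch.randn(2, 4)                       # (batch, d_in)
    W_enc, r_mag = torch.randn(4, 3), torch.zeros(3)
    b_gate, b_mag = torch.zeros(3), torch.zeros(3)

    active = ((sae_in @ W_enc + b_gate) > 0).float()                 # gate: on/off
    magnitude = torch.relu(sae_in @ (W_enc * r_mag.exp()) + b_mag)   # value if on
    feature_acts = active * magnitude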
-     def forward(
-         self,
-         x: Float[torch.Tensor, "... d_in"],
-     ) -> Float[torch.Tensor, "... d_in"]:
-         feature_acts, _ = self.encode_with_hidden_pre_fn(x)
-         return self.decode(feature_acts)
-
-     def training_forward_pass(
-         self,
-         sae_in: torch.Tensor,
-         current_l1_coefficient: float,
-         dead_neuron_mask: torch.Tensor | None = None,
-     ) -> TrainStepOutput:
-         # Do a forward pass to get the SAE output, keeping the hidden
-         # pre-activations as well, since the sparsity losses below need them.
-         feature_acts, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
-         sae_out = self.decode(feature_acts)
-
-         # MSE LOSS
-         per_item_mse_loss = self.mse_loss_fn(sae_out, sae_in)
-         mse_loss = per_item_mse_loss.sum(dim=-1).mean()
-
-         losses: dict[str, float | torch.Tensor] = {}
-
-         if self.cfg.architecture == "gated":
-             # Gated SAE Loss Calculation
-
-             # Shared variables
-             sae_in_centered = (
-                 self.reshape_fn_in(sae_in) - self.b_dec * self.cfg.apply_b_dec_to_input
-             )
-             pi_gate = sae_in_centered @ self.W_enc + self.b_gate
-             pi_gate_act = torch.relu(pi_gate)
-
-             # SFN sparsity loss - summed over the feature dimension and averaged over the batch
-             l1_loss = (
-                 current_l1_coefficient
-                 * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
-             )
-
-             # Auxiliary reconstruction loss - summed over the feature dimension and averaged over the batch
-             via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec
-             aux_reconstruction_loss = torch.sum(
-                 (via_gate_reconstruction - sae_in) ** 2, dim=-1
-             ).mean()
-             loss = mse_loss + l1_loss + aux_reconstruction_loss
-             losses["auxiliary_reconstruction_loss"] = aux_reconstruction_loss
-             losses["l1_loss"] = l1_loss
-         elif self.cfg.architecture == "jumprelu":
-             threshold = torch.exp(self.log_threshold)
-             l0 = torch.sum(Step.apply(hidden_pre, threshold, self.bandwidth), dim=-1)  # type: ignore
-             l0_loss = (current_l1_coefficient * l0).mean()
-             loss = mse_loss + l0_loss
-             losses["l0_loss"] = l0_loss
-         elif self.cfg.architecture == "topk":
-             topk_loss = self.calculate_topk_aux_loss(
-                 sae_in=sae_in,
-                 sae_out=sae_out,
-                 hidden_pre=hidden_pre,
-                 dead_neuron_mask=dead_neuron_mask,
-             )
-             losses["auxiliary_reconstruction_loss"] = topk_loss
-             loss = mse_loss + topk_loss
-         else:
-             # default SAE sparsity loss
-             weighted_feature_acts = feature_acts
-             if self.cfg.scale_sparsity_penalty_by_decoder_norm:
-                 weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
-             sparsity = weighted_feature_acts.norm(
-                 p=self.cfg.lp_norm, dim=-1
-             )  # lp norm over the feature dimension
-
-             l1_loss = (current_l1_coefficient * sparsity).mean()
-             loss = mse_loss + l1_loss
-             if (
-                 self.cfg.use_ghost_grads
-                 and self.training
-                 and dead_neuron_mask is not None
-             ):
-                 ghost_grad_loss = self.calculate_ghost_grad_loss(
-                     x=sae_in,
-                     sae_out=sae_out,
-                     per_item_mse_loss=per_item_mse_loss,
-                     hidden_pre=hidden_pre,
-                     dead_neuron_mask=dead_neuron_mask,
-                 )
-                 losses["ghost_grad_loss"] = ghost_grad_loss
-                 loss = loss + ghost_grad_loss
-             losses["l1_loss"] = l1_loss
-
-         losses["mse_loss"] = mse_loss
-
-         return TrainStepOutput(
-             sae_in=sae_in,
-             sae_out=sae_out,
-             feature_acts=feature_acts,
-             hidden_pre=hidden_pre,
-             loss=loss,
-             losses=losses,
-         )
-
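A hedged sketch of how `training_forward_pass` slots into an optimization step (assuming `sae` is a `TrainingSAE` and `activations` is a `(batch, d_in)` tensor from an activations store; the coefficient value is arbitrary):

    import torch

    optimizer = torch.optim.Adam(sae.parameters(), lr=3e-4)

    output = sae.training_forward_pass(
        sae_in=activations,
        current_l1_coefficient=5.0,  # typically warmed up over training
    )
    output.loss.backward()
    if sae.cfg.normalize_sae_decoder:
        sae.remove_gradient_parallel_to_decoder_directions()
    optimizer.step()
    optimizer.zero_grad()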
-     def calculate_topk_aux_loss(
-         self,
-         sae_in: torch.Tensor,
-         sae_out: torch.Tensor,
-         hidden_pre: torch.Tensor,
-         dead_neuron_mask: torch.Tensor | None,
-     ) -> torch.Tensor:
-         # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
-         # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
-         if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
-             return sae_out.new_tensor(0.0)
-         residual = (sae_in - sae_out).detach()
-
-         # Heuristic from Appendix B.1 in the paper
-         k_aux = sae_in.shape[-1] // 2
-
-         # Reduce the scale of the loss if there are a small number of dead latents
-         scale = min(num_dead / k_aux, 1.0)
-         k_aux = min(k_aux, num_dead)
-
-         auxk_acts = _calculate_topk_aux_acts(
-             k_aux=k_aux,
-             hidden_pre=hidden_pre,
-             dead_neuron_mask=dead_neuron_mask,
-         )
-
-         # Encourage the top ~50% of dead latents to predict the residual of the
-         # top k living latents
-         recons = self.decode(auxk_acts)
-         auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
-         return scale * auxk_loss
-
-     def calculate_ghost_grad_loss(
-         self,
-         x: torch.Tensor,
-         sae_out: torch.Tensor,
-         per_item_mse_loss: torch.Tensor,
-         hidden_pre: torch.Tensor,
-         dead_neuron_mask: torch.Tensor,
-     ) -> torch.Tensor:
-         # 1. Compute the residual and its norm.
-         residual = x - sae_out
-         l2_norm_residual = torch.norm(residual, dim=-1)
-
-         # 2. Ghost grads use an exponential activation function, ignoring whatever
-         # the activation function is in the SAE. The forward pass uses the dead neurons only.
-         feature_acts_dead_neurons_only = torch.exp(hidden_pre[:, dead_neuron_mask])
-         ghost_out = feature_acts_dead_neurons_only @ self.W_dec[dead_neuron_mask, :]
-         l2_norm_ghost_out = torch.norm(ghost_out, dim=-1)
-         norm_scaling_factor = l2_norm_residual / (1e-6 + l2_norm_ghost_out * 2)
-         ghost_out = ghost_out * norm_scaling_factor[:, None].detach()
-
-         # 3. Rescale the ghost-grad loss so it is comparable to the original loss;
-         # since it is only calculated for the dead neurons, it would otherwise be
-         # on a very different scale.
-         # There have been methodological improvements that are not implemented here yet,
-         # see: https://www.lesswrong.com/posts/C5KAZQib3bzzpeyrg/full-post-progress-update-1-from-the-gdm-mech-interp-team#Improving_ghost_grads
-         per_item_mse_loss_ghost_resid = self.mse_loss_fn(ghost_out, residual.detach())
-         mse_rescaling_factor = (
-             per_item_mse_loss / (per_item_mse_loss_ghost_resid + 1e-6)
-         ).detach()
-         per_item_mse_loss_ghost_resid = (
-             mse_rescaling_factor * per_item_mse_loss_ghost_resid
-         )
-
-         return per_item_mse_loss_ghost_resid.mean()
-
-     @torch.no_grad()
-     def _get_mse_loss_fn(self) -> Any:
-         def standard_mse_loss_fn(
-             preds: torch.Tensor, target: torch.Tensor
-         ) -> torch.Tensor:
-             return torch.nn.functional.mse_loss(preds, target, reduction="none")
-
-         def batch_norm_mse_loss_fn(
-             preds: torch.Tensor, target: torch.Tensor
-         ) -> torch.Tensor:
-             target_centered = target - target.mean(dim=0, keepdim=True)
-             normalization = target_centered.norm(dim=-1, keepdim=True)
-             return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
-                 normalization + 1e-6
-             )
-
-         if self.cfg.mse_loss_normalization == "dense_batch":
-             return batch_norm_mse_loss_fn
-         return standard_mse_loss_fn
-
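The two loss variants differ only in a per-item normalizer; a standalone sketch comparing them on toy tensors:

    import torch

    preds, target = torch.randn(8, 16), torch.randn(8, 16)

    standard = torch.nn.functional.mse_loss(preds, target, reduction="none")

    # "dense_batch" divides by the norm of the batch-centered target
    target_centered = target - target.mean(dim=0, keepdim=True)
    normalization = target_centered.norm(dim=-1, keepdim=True)
    batch_norm = standard / (normalization + 1e-6)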
-     def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
-         if self.cfg.architecture == "jumprelu" and "log_threshold" in state_dict:
-             threshold = torch.exp(state_dict["log_threshold"]).detach().contiguous()
-             del state_dict["log_threshold"]
-             state_dict["threshold"] = threshold
-
-     def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
-         if self.cfg.architecture == "jumprelu" and "threshold" in state_dict:
-             threshold = state_dict["threshold"]
-             del state_dict["threshold"]
-             state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()
-
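Saving stores `threshold = exp(log_threshold)` and loading inverts it, so the round trip is lossless up to floating-point error; a quick standalone check:

    import torch

    log_threshold = torch.tensor([-4.6, -2.3])
    threshold = torch.exp(log_threshold)   # what gets written to disk
    restored = torch.log(threshold)        # what loading reconstructs
    torch.testing.assert_close(restored, log_threshold)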
-     @classmethod
-     @deprecated("Use load_from_disk instead")
-     def load_from_pretrained(
-         cls, path: str, device: str = "cpu", dtype: str | None = None
-     ) -> "TrainingSAE":
-         return cls.load_from_disk(path, device, dtype)
-
-     @classmethod
-     def load_from_disk(
-         cls,
-         path: str,
-         device: str = "cpu",
-         dtype: str | None = None,
-         converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
-     ) -> "TrainingSAE":
-         overrides = {"dtype": dtype} if dtype is not None else None
-         cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
-         cfg_dict = handle_config_defaulting(cfg_dict)
-         sae_cfg = TrainingSAEConfig.from_dict(cfg_dict)
-         sae = cls(sae_cfg)
-         sae.process_state_dict_for_loading(state_dict)
-         sae.load_state_dict(state_dict)
-         return sae
-
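Typical usage (the checkpoint path below is a placeholder for a directory containing `cfg.json` and `sae_weights.safetensors`):

    sae = TrainingSAE.load_from_disk("path/to/checkpoint", device="cpu")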
-     def initialize_weights_complex(self):
-         """Initialize the encoder and decoder weights according to the config."""
-
-         if self.cfg.decoder_orthogonal_init:
-             self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
-
-         elif self.cfg.decoder_heuristic_init:
-             self.W_dec = nn.Parameter(
-                 torch.randn(
-                     self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-                 )
-             )
-             self.initialize_decoder_norm_constant_norm(
-                 self.cfg.decoder_heuristic_init_norm
-             )
-
-         # Then we initialize the encoder weights (either as the transpose of decoder or not)
-         if self.cfg.init_encoder_as_decoder_transpose:
-             self.W_enc.data = self.W_dec.data.T.clone().contiguous()
-         else:
-             self.W_enc = nn.Parameter(
-                 torch.nn.init.kaiming_uniform_(
-                     torch.empty(
-                         self.cfg.d_in,
-                         self.cfg.d_sae,
-                         dtype=self.dtype,
-                         device=self.device,
-                     )
-                 )
-             )
-
-         if self.cfg.normalize_sae_decoder:
-             with torch.no_grad():
-                 # Anthropic normalizes this to have unit columns
-                 self.set_decoder_norm_to_unit_norm()
-
-     @torch.no_grad()
-     def fold_W_dec_norm(self):
-         # Need to deal with the jumprelu having a log_threshold in training:
-         # folding the decoder norms rescales the pre-activations, so the
-         # threshold must be rescaled by the same factor.
-         if self.cfg.architecture == "jumprelu":
-             cur_threshold = self.threshold.clone()
-             W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
-             super().fold_W_dec_norm()
-             self.log_threshold.data = torch.log(cur_threshold * W_dec_norms.squeeze())
-         else:
-             super().fold_W_dec_norm()
-
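Folding multiplies each feature's pre-activations by its decoder-row norm and divides the decoder row by the same factor, so the JumpReLU threshold must be scaled up to match. A scalar sketch of why the computed function is unchanged:

    import torch

    x, t = torch.tensor(0.8), torch.tensor(0.5)  # pre-activation and threshold
    c = torch.tensor(2.0)                        # stand-in for a decoder row norm

    before = x * (x > t)
    after = (c * x) * ((c * x) > (c * t)) / c    # scaled pre-act and threshold, unscaled decoder
    torch.testing.assert_close(before, after)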
-     ## Initialization Methods
-     @torch.no_grad()
-     def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
-         out = torch.tensor(origin, dtype=self.dtype, device=self.device)
-         self.b_dec.data = out
-
-     @torch.no_grad()
-     def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):
-         previous_b_dec = self.b_dec.clone().cpu()
-         out = all_activations.mean(dim=0)
-
-         previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
-         distances = torch.norm(all_activations - out, dim=-1)
-
-         logger.info("Reinitializing b_dec with mean of activations")
-         logger.debug(
-             f"Previous distances: {previous_distances.median(0).values.mean().item()}"
-         )
-         logger.debug(f"New distances: {distances.median(0).values.mean().item()}")
-
-         self.b_dec.data = out.to(self.dtype).to(self.device)
-
-     ## Training Utils
-     @torch.no_grad()
-     def set_decoder_norm_to_unit_norm(self):
-         self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-
-     @torch.no_grad()
-     def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
-         """
-         A heuristic procedure inspired by:
-         https://transformer-circuits.pub/2024/april-update/index.html#training-saes
-         """
-         # TODO: Parameterise this as a function of m and n
-
-         # Ensure W_dec rows are unit norm, then rescale to the target norm
-         self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-         self.W_dec.data *= norm  # will break tests but do this for now.
-
-     @torch.no_grad()
-     def remove_gradient_parallel_to_decoder_directions(self):
-         """
-         Update the gradient of W_dec (shape (d_sae, d_in)) so that the component
-         parallel to each decoder direction is removed.
-         """
-         assert self.W_dec.grad is not None  # keep pyright happy
-
-         parallel_component = einops.einsum(
-             self.W_dec.grad,
-             self.W_dec.data,
-             "d_sae d_in, d_sae d_in -> d_sae",
-         )
-         self.W_dec.grad -= einops.einsum(
-             parallel_component,
-             self.W_dec.data,
-             "d_sae, d_sae d_in -> d_sae d_in",
-         )
-
-
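Because the projection above does not divide by the squared row norms, it removes the parallel component exactly only when the decoder rows are unit norm (which `set_decoder_norm_to_unit_norm` maintains). A standalone check:

    import einops
    import torch

    W_dec = torch.randn(4, 8)
    W_dec /= W_dec.norm(dim=1, keepdim=True)  # unit-norm rows, as in training
    grad = torch.randn(4, 8)

    parallel = einops.einsum(grad, W_dec, "d_sae d_in, d_sae d_in -> d_sae")
    grad = grad - einops.einsum(parallel, W_dec, "d_sae, d_sae d_in -> d_sae d_in")

    # each gradient row is now orthogonal to its decoder direction (dot products ~0)
    print(einops.einsum(grad, W_dec, "d_sae d_in, d_sae d_in -> d_sae"))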
- def _calculate_topk_aux_acts(
-     k_aux: int,
-     hidden_pre: torch.Tensor,
-     dead_neuron_mask: torch.Tensor,
- ) -> torch.Tensor:
-     # Don't include living latents in this loss
-     auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
-     # Top-k dead latents
-     auxk_topk = auxk_latents.topk(k_aux, sorted=False)
-     # Set the activations to zero for all but the top k_aux dead latents
-     auxk_acts = torch.zeros_like(hidden_pre)
-     auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
-     return auxk_acts
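A toy run of the helper above, showing that only the strongest `k_aux` dead latents keep their pre-activations:

    import torch

    hidden_pre = torch.tensor([[0.5, 2.0, -1.0, 3.0]])
    dead_neuron_mask = torch.tensor([True, False, True, False])

    auxk_acts = _calculate_topk_aux_acts(
        k_aux=1, hidden_pre=hidden_pre, dead_neuron_mask=dead_neuron_mask
    )
    print(auxk_acts)  # tensor([[0.5000, 0.0000, 0.0000, 0.0000]])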