sae-lens 5.9.0__py3-none-any.whl → 6.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,705 +0,0 @@
- """Most of this is just copied over from Arthur's code and slightly simplified:
- https://github.com/ArthurConmy/sae/blob/main/sae/model.py
- """
-
- from dataclasses import dataclass, fields
- from typing import Any
-
- import einops
- import numpy as np
- import torch
- from jaxtyping import Float
- from torch import nn
- from typing_extensions import deprecated
-
- from sae_lens import logger
- from sae_lens.config import LanguageModelSAERunnerConfig
- from sae_lens.sae import SAE, SAEConfig
- from sae_lens.toolkit.pretrained_sae_loaders import (
-     PretrainedSaeDiskLoader,
-     handle_config_defaulting,
-     sae_lens_disk_loader,
- )
-
- SPARSITY_PATH = "sparsity.safetensors"
- SAE_WEIGHTS_PATH = "sae_weights.safetensors"
- SAE_CFG_PATH = "cfg.json"
-
-
- def rectangle(x: torch.Tensor) -> torch.Tensor:
-     return ((x > -0.5) & (x < 0.5)).to(x)
-
-
- class Step(torch.autograd.Function):
-     @staticmethod
-     def forward(
-         x: torch.Tensor,
-         threshold: torch.Tensor,
-         bandwidth: float,  # noqa: ARG004
-     ) -> torch.Tensor:
-         return (x > threshold).to(x)
-
-     @staticmethod
-     def setup_context(
-         ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
-     ) -> None:
-         x, threshold, bandwidth = inputs
-         del output
-         ctx.save_for_backward(x, threshold)
-         ctx.bandwidth = bandwidth
-
-     @staticmethod
-     def backward(  # type: ignore[override]
-         ctx: Any, grad_output: torch.Tensor
-     ) -> tuple[None, torch.Tensor, None]:
-         x, threshold = ctx.saved_tensors
-         bandwidth = ctx.bandwidth
-         threshold_grad = torch.sum(
-             -(1.0 / bandwidth) * rectangle((x - threshold) / bandwidth) * grad_output,
-             dim=0,
-         )
-         return None, threshold_grad, None
-
-
- class JumpReLU(torch.autograd.Function):
-     @staticmethod
-     def forward(
-         x: torch.Tensor,
-         threshold: torch.Tensor,
-         bandwidth: float,  # noqa: ARG004
-     ) -> torch.Tensor:
-         return (x * (x > threshold)).to(x)
-
-     @staticmethod
-     def setup_context(
-         ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
-     ) -> None:
-         x, threshold, bandwidth = inputs
-         del output
-         ctx.save_for_backward(x, threshold)
-         ctx.bandwidth = bandwidth
-
-     @staticmethod
-     def backward(  # type: ignore[override]
-         ctx: Any, grad_output: torch.Tensor
-     ) -> tuple[torch.Tensor, torch.Tensor, None]:
-         x, threshold = ctx.saved_tensors
-         bandwidth = ctx.bandwidth
-         x_grad = (x > threshold) * grad_output  # We don't apply STE to x input
-         threshold_grad = torch.sum(
-             -(threshold / bandwidth)
-             * rectangle((x - threshold) / bandwidth)
-             * grad_output,
-             dim=0,
-         )
-         return x_grad, threshold_grad, None
-
-
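The two autograd functions above give the hard step and jump a usable pseudo-gradient: the forward pass is exact, while `backward` routes gradient to `threshold` only inside the `rectangle` window of width `bandwidth`. A minimal sketch of that behaviour, assuming the `Step` and `JumpReLU` classes above are in scope (toy values, hypothetical bandwidth):

```python
# Minimal sketch of the straight-through behaviour of the autograd functions
# above; values and bandwidth are hypothetical.
import torch

x = torch.tensor([[0.2, 0.8, 1.5]], requires_grad=True)
threshold = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)
bandwidth = 0.001

out = JumpReLU.apply(x, threshold, bandwidth)   # tensor([[0.0, 0.8, 1.5]])
out.sum().backward()

# x receives the exact masked gradient: 0 below threshold, 1 above.
print(x.grad)          # tensor([[0., 1., 1.]])
# threshold only receives gradient where |x - threshold| < bandwidth / 2,
# so with this tiny bandwidth its pseudo-gradient is zero here.
print(threshold.grad)  # tensor([0., 0., 0.])

# Step produces the hard mask used for the L0 penalty during training.
print(Step.apply(x, threshold, bandwidth))  # tensor([[0., 1., 1.]])
```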
- @dataclass
- class TrainStepOutput:
-     sae_in: torch.Tensor
-     sae_out: torch.Tensor
-     feature_acts: torch.Tensor
-     hidden_pre: torch.Tensor
-     loss: torch.Tensor  # we need to call backward() on this
-     losses: dict[str, float | torch.Tensor]
-
-
- @dataclass(kw_only=True)
- class TrainingSAEConfig(SAEConfig):
-     # Sparsity Loss Calculations
-     l1_coefficient: float
-     lp_norm: float
-     use_ghost_grads: bool
-     normalize_sae_decoder: bool
-     noise_scale: float
-     decoder_orthogonal_init: bool
-     mse_loss_normalization: str | None
-     jumprelu_init_threshold: float
-     jumprelu_bandwidth: float
-     decoder_heuristic_init: bool
-     init_encoder_as_decoder_transpose: bool
-     scale_sparsity_penalty_by_decoder_norm: bool
-
-     @classmethod
-     def from_sae_runner_config(
-         cls, cfg: LanguageModelSAERunnerConfig
-     ) -> "TrainingSAEConfig":
-         return cls(
-             # base config
-             architecture=cfg.architecture,
-             d_in=cfg.d_in,
-             d_sae=cfg.d_sae,  # type: ignore
-             dtype=cfg.dtype,
-             device=cfg.device,
-             model_name=cfg.model_name,
-             hook_name=cfg.hook_name,
-             hook_layer=cfg.hook_layer,
-             hook_head_index=cfg.hook_head_index,
-             activation_fn_str=cfg.activation_fn,
-             activation_fn_kwargs=cfg.activation_fn_kwargs,
-             apply_b_dec_to_input=cfg.apply_b_dec_to_input,
-             finetuning_scaling_factor=cfg.finetuning_method is not None,
-             sae_lens_training_version=cfg.sae_lens_training_version,
-             context_size=cfg.context_size,
-             dataset_path=cfg.dataset_path,
-             prepend_bos=cfg.prepend_bos,
-             seqpos_slice=cfg.seqpos_slice,
-             # Training cfg
-             l1_coefficient=cfg.l1_coefficient,
-             lp_norm=cfg.lp_norm,
-             use_ghost_grads=cfg.use_ghost_grads,
-             normalize_sae_decoder=cfg.normalize_sae_decoder,
-             noise_scale=cfg.noise_scale,
-             decoder_orthogonal_init=cfg.decoder_orthogonal_init,
-             mse_loss_normalization=cfg.mse_loss_normalization,
-             decoder_heuristic_init=cfg.decoder_heuristic_init,
-             init_encoder_as_decoder_transpose=cfg.init_encoder_as_decoder_transpose,
-             scale_sparsity_penalty_by_decoder_norm=cfg.scale_sparsity_penalty_by_decoder_norm,
-             normalize_activations=cfg.normalize_activations,
-             dataset_trust_remote_code=cfg.dataset_trust_remote_code,
-             model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {},
-             jumprelu_init_threshold=cfg.jumprelu_init_threshold,
-             jumprelu_bandwidth=cfg.jumprelu_bandwidth,
-         )
-
-     @classmethod
-     def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig":
-         # remove any keys that are not in the dataclass
-         # since we sometimes enhance the config with the whole LM runner config
-         valid_field_names = {field.name for field in fields(cls)}
-         valid_config_dict = {
-             key: val for key, val in config_dict.items() if key in valid_field_names
-         }
-
-         # Ensure seqpos_slice is a tuple
-         if "seqpos_slice" in valid_config_dict:
-             if isinstance(valid_config_dict["seqpos_slice"], list):
-                 valid_config_dict["seqpos_slice"] = tuple(
-                     valid_config_dict["seqpos_slice"]
-                 )
-             elif not isinstance(valid_config_dict["seqpos_slice"], tuple):
-                 valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],)
-
-         return TrainingSAEConfig(**valid_config_dict)
-
-     def to_dict(self) -> dict[str, Any]:
-         return {
-             **super().to_dict(),
-             "l1_coefficient": self.l1_coefficient,
-             "lp_norm": self.lp_norm,
-             "use_ghost_grads": self.use_ghost_grads,
-             "normalize_sae_decoder": self.normalize_sae_decoder,
-             "noise_scale": self.noise_scale,
-             "decoder_orthogonal_init": self.decoder_orthogonal_init,
-             "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
-             "mse_loss_normalization": self.mse_loss_normalization,
-             "decoder_heuristic_init": self.decoder_heuristic_init,
-             "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
-             "normalize_activations": self.normalize_activations,
-             "jumprelu_init_threshold": self.jumprelu_init_threshold,
-             "jumprelu_bandwidth": self.jumprelu_bandwidth,
-         }
-
-     # This needs to exist so we can initialize the parent SAE config without the
-     # training-specific parameters. Maybe there's a cleaner way to do this.
-     def get_base_sae_cfg_dict(self) -> dict[str, Any]:
-         return {
-             "architecture": self.architecture,
-             "d_in": self.d_in,
-             "d_sae": self.d_sae,
-             "activation_fn_str": self.activation_fn_str,
-             "activation_fn_kwargs": self.activation_fn_kwargs,
-             "apply_b_dec_to_input": self.apply_b_dec_to_input,
-             "dtype": self.dtype,
-             "model_name": self.model_name,
-             "hook_name": self.hook_name,
-             "hook_layer": self.hook_layer,
-             "hook_head_index": self.hook_head_index,
-             "device": self.device,
-             "context_size": self.context_size,
-             "prepend_bos": self.prepend_bos,
-             "finetuning_scaling_factor": self.finetuning_scaling_factor,
-             "normalize_activations": self.normalize_activations,
-             "dataset_path": self.dataset_path,
-             "dataset_trust_remote_code": self.dataset_trust_remote_code,
-             "sae_lens_training_version": self.sae_lens_training_version,
-         }
-
-
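The key-filtering and `seqpos_slice` coercion in `from_dict` above can be illustrated in isolation; a minimal standalone sketch using a hypothetical `ToyConfig` in place of the much larger `TrainingSAEConfig`:

```python
# Standalone sketch of the from_dict idiom above: drop unknown keys, then
# coerce a JSON-decoded list back into a tuple. ToyConfig is hypothetical.
from dataclasses import dataclass, fields
from typing import Any


@dataclass(kw_only=True)
class ToyConfig:
    d_in: int
    seqpos_slice: tuple[int | None, ...] = (None,)

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "ToyConfig":
        # Drop keys the dataclass doesn't know about, as TrainingSAEConfig does
        # when handed the whole LM runner config.
        valid = {f.name for f in fields(cls)}
        kwargs = {k: v for k, v in config_dict.items() if k in valid}
        # JSON round trips turn tuples into lists; coerce them back.
        if isinstance(kwargs.get("seqpos_slice"), list):
            kwargs["seqpos_slice"] = tuple(kwargs["seqpos_slice"])
        return cls(**kwargs)


cfg = ToyConfig.from_dict({"d_in": 512, "seqpos_slice": [0, 128], "extra_runner_key": 1})
print(cfg)  # ToyConfig(d_in=512, seqpos_slice=(0, 128))
```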
- class TrainingSAE(SAE):
-     """
-     A SAE used for training. This class provides a `training_forward_pass` method which calculates
-     losses used for training.
-     """
-
-     cfg: TrainingSAEConfig
-     use_error_term: bool
-     dtype: torch.dtype
-     device: torch.device
-
-     def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
-         base_sae_cfg = SAEConfig.from_dict(cfg.get_base_sae_cfg_dict())
-         super().__init__(base_sae_cfg)
-         self.cfg = cfg  # type: ignore
-
-         if cfg.architecture == "standard" or cfg.architecture == "topk":
-             self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre
-         elif cfg.architecture == "gated":
-             self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_gated
-         elif cfg.architecture == "jumprelu":
-             self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_jumprelu
-             self.bandwidth = cfg.jumprelu_bandwidth
-             self.log_threshold.data = torch.ones(
-                 self.cfg.d_sae, dtype=self.dtype, device=self.device
-             ) * np.log(cfg.jumprelu_init_threshold)
-
-         else:
-             raise ValueError(f"Unknown architecture: {cfg.architecture}")
-
-         self.check_cfg_compatibility()
-
-         self.use_error_term = use_error_term
-
-         self.initialize_weights_complex()
-
-         # The training SAE will assume that the activation store handles
-         # reshaping.
-         self.turn_off_forward_pass_hook_z_reshaping()
-
-         self.mse_loss_fn = self._get_mse_loss_fn()
-
-     def initialize_weights_jumprelu(self):
-         # same as the superclass, except we use a log_threshold parameter instead of threshold
-         self.log_threshold = nn.Parameter(
-             torch.empty(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-         )
-         self.initialize_weights_basic()
-
-     @property
-     def threshold(self) -> torch.Tensor:
-         if self.cfg.architecture != "jumprelu":
-             raise ValueError("Threshold is only defined for Jumprelu SAEs")
-         return torch.exp(self.log_threshold)
-
-     @classmethod
-     def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAE":
-         return cls(TrainingSAEConfig.from_dict(config_dict))
-
-     def check_cfg_compatibility(self):
-         if self.cfg.architecture != "standard" and self.cfg.use_ghost_grads:
-             raise ValueError(f"{self.cfg.architecture} SAEs do not support ghost grads")
-         if self.cfg.architecture == "gated" and self.use_error_term:
-             raise ValueError("Gated SAEs do not support error terms")
-
-     def encode_standard(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> Float[torch.Tensor, "... d_sae"]:
-         """
-         Calculate SAE features from inputs.
-         """
-         feature_acts, _ = self.encode_with_hidden_pre_fn(x)
-         return feature_acts
-
-     def encode_with_hidden_pre_jumprelu(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
-         sae_in = self.process_sae_in(x)
-
-         hidden_pre = sae_in @ self.W_enc + self.b_enc
-
-         if self.training:
-             hidden_pre = (
-                 hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
-             )
-
-         threshold = torch.exp(self.log_threshold)
-
-         feature_acts = JumpReLU.apply(hidden_pre, threshold, self.bandwidth)
-
-         return feature_acts, hidden_pre  # type: ignore
-
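For intuition, the JumpReLU encode path reduces to a thresholded linear map; a minimal pure-torch sketch with toy shapes and hypothetical weights (eval mode assumed, so the noise term is skipped):

```python
# Pure-torch sketch of the JumpReLU encode path above; shapes and values are
# hypothetical, and eval mode is assumed (no training noise).
import torch

d_in, d_sae = 4, 8
sae_in = torch.randn(2, d_in)
W_enc = torch.randn(d_in, d_sae)
b_enc = torch.zeros(d_sae)
log_threshold = torch.zeros(d_sae)  # threshold = exp(0) = 1.0 per latent

hidden_pre = sae_in @ W_enc + b_enc
threshold = torch.exp(log_threshold)
# Forward pass of JumpReLU.apply: keep pre-activations strictly above threshold.
feature_acts = hidden_pre * (hidden_pre > threshold)
print(feature_acts.shape)  # torch.Size([2, 8])
```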
-     def encode_with_hidden_pre(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
-         sae_in = self.process_sae_in(x)
-
-         # "... d_in, d_in d_sae -> ... d_sae",
-         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
-         hidden_pre_noised = hidden_pre + (
-             torch.randn_like(hidden_pre) * self.cfg.noise_scale * self.training
-         )
-         feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
-
-         return feature_acts, hidden_pre_noised
-
-     def encode_with_hidden_pre_gated(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
-         sae_in = self.process_sae_in(x)
-
-         # Gating path with Heaviside step function
-         gating_pre_activation = sae_in @ self.W_enc + self.b_gate
-         active_features = (gating_pre_activation > 0).to(self.dtype)
-
-         # Magnitude path with weight sharing
-         magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
-         # magnitude_pre_activation_noised = magnitude_pre_activation + (
-         #     torch.randn_like(magnitude_pre_activation) * self.cfg.noise_scale * self.training
-         # )
-         feature_magnitudes = self.activation_fn(
-             magnitude_pre_activation
-         )  # magnitude_pre_activation_noised)
-
-         # Return both the gated feature activations and the magnitude pre-activations
-         return (
-             active_features * feature_magnitudes,
-             magnitude_pre_activation,
-         )  # magnitude_pre_activation_noised
-
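The gated encoder runs two paths over shared weights: a binary gate decides which latents fire, and a rescaled magnitude path decides how strongly. A minimal pure-torch sketch with toy, hypothetical parameters:

```python
# Pure-torch sketch of the gated encode path above; shapes and parameter
# values are hypothetical.
import torch

d_in, d_sae = 4, 8
sae_in = torch.randn(2, d_in)
W_enc = torch.randn(d_in, d_sae)
b_gate = torch.zeros(d_sae)
b_mag = torch.zeros(d_sae)
r_mag = torch.zeros(d_sae)  # exp(0) = 1, so both paths share W_enc exactly

# Gating path: Heaviside step on its own pre-activation.
active = (sae_in @ W_enc + b_gate > 0).float()
# Magnitude path: weight-shared, rescaled by exp(r_mag), then ReLU.
magnitudes = torch.relu(sae_in @ (W_enc * r_mag.exp()) + b_mag)

feature_acts = active * magnitudes
print(feature_acts.shape)  # torch.Size([2, 8])
```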
-     def forward(
-         self,
-         x: Float[torch.Tensor, "... d_in"],
-     ) -> Float[torch.Tensor, "... d_in"]:
-         feature_acts, _ = self.encode_with_hidden_pre_fn(x)
-         return self.decode(feature_acts)
-
-     def training_forward_pass(
-         self,
-         sae_in: torch.Tensor,
-         current_l1_coefficient: float,
-         dead_neuron_mask: torch.Tensor | None = None,
-     ) -> TrainStepOutput:
-         # do a forward pass to get SAE out, but we also need the
-         # hidden pre.
-         feature_acts, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
-         sae_out = self.decode(feature_acts)
-
-         # MSE LOSS
-         per_item_mse_loss = self.mse_loss_fn(sae_out, sae_in)
-         mse_loss = per_item_mse_loss.sum(dim=-1).mean()
-
-         losses: dict[str, float | torch.Tensor] = {}
-
-         if self.cfg.architecture == "gated":
-             # Gated SAE Loss Calculation
-
-             # Shared variables
-             sae_in_centered = (
-                 self.reshape_fn_in(sae_in) - self.b_dec * self.cfg.apply_b_dec_to_input
-             )
-             pi_gate = sae_in_centered @ self.W_enc + self.b_gate
-             pi_gate_act = torch.relu(pi_gate)
-
-             # SFN sparsity loss - summed over the feature dimension and averaged over the batch
-             l1_loss = (
-                 current_l1_coefficient
-                 * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
-             )
-
-             # Auxiliary reconstruction loss - summed over the feature dimension and averaged over the batch
-             via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec
-             aux_reconstruction_loss = torch.sum(
-                 (via_gate_reconstruction - sae_in) ** 2, dim=-1
-             ).mean()
-             loss = mse_loss + l1_loss + aux_reconstruction_loss
-             losses["auxiliary_reconstruction_loss"] = aux_reconstruction_loss
-             losses["l1_loss"] = l1_loss
-         elif self.cfg.architecture == "jumprelu":
-             threshold = torch.exp(self.log_threshold)
-             l0 = torch.sum(Step.apply(hidden_pre, threshold, self.bandwidth), dim=-1)  # type: ignore
-             l0_loss = (current_l1_coefficient * l0).mean()
-             loss = mse_loss + l0_loss
-             losses["l0_loss"] = l0_loss
-         elif self.cfg.architecture == "topk":
-             topk_loss = self.calculate_topk_aux_loss(
-                 sae_in=sae_in,
-                 sae_out=sae_out,
-                 hidden_pre=hidden_pre,
-                 dead_neuron_mask=dead_neuron_mask,
-             )
-             losses["auxiliary_reconstruction_loss"] = topk_loss
-             loss = mse_loss + topk_loss
-         else:
-             # default SAE sparsity loss
-             weighted_feature_acts = feature_acts
-             if self.cfg.scale_sparsity_penalty_by_decoder_norm:
-                 weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
-             sparsity = weighted_feature_acts.norm(
-                 p=self.cfg.lp_norm, dim=-1
-             )  # p-norm over the feature dimension
-
-             l1_loss = (current_l1_coefficient * sparsity).mean()
-             loss = mse_loss + l1_loss
-             if (
-                 self.cfg.use_ghost_grads
-                 and self.training
-                 and dead_neuron_mask is not None
-             ):
-                 ghost_grad_loss = self.calculate_ghost_grad_loss(
-                     x=sae_in,
-                     sae_out=sae_out,
-                     per_item_mse_loss=per_item_mse_loss,
-                     hidden_pre=hidden_pre,
-                     dead_neuron_mask=dead_neuron_mask,
-                 )
-                 losses["ghost_grad_loss"] = ghost_grad_loss
-                 loss = loss + ghost_grad_loss
-             losses["l1_loss"] = l1_loss
-
-         losses["mse_loss"] = mse_loss
-
-         return TrainStepOutput(
-             sae_in=sae_in,
-             sae_out=sae_out,
-             feature_acts=feature_acts,
-             hidden_pre=hidden_pre,
-             loss=loss,
-             losses=losses,
-         )
-
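In the default (`standard`) branch above, the loss is just the summed-then-averaged MSE plus an Lp penalty on the (optionally decoder-norm-weighted) feature activations. A minimal pure-torch sketch with hypothetical toy tensors and coefficient:

```python
# Pure-torch sketch of the default sparsity branch above; all tensors and the
# l1 coefficient are hypothetical toy values.
import torch

feature_acts = torch.rand(16, 8)   # (batch, d_sae) activations from the encoder
sae_in = torch.randn(16, 4)        # (batch, d_in)
sae_out = torch.randn(16, 4)       # reconstruction from the decoder
W_dec = torch.randn(8, 4)          # (d_sae, d_in)
current_l1_coefficient = 5e-4      # hypothetical warmup-scheduled coefficient
lp_norm = 1.0

per_item_mse_loss = torch.nn.functional.mse_loss(sae_out, sae_in, reduction="none")
mse_loss = per_item_mse_loss.sum(dim=-1).mean()

# Optionally weight each latent by the norm of its decoder direction.
weighted_feature_acts = feature_acts * W_dec.norm(dim=1)
sparsity = weighted_feature_acts.norm(p=lp_norm, dim=-1)
l1_loss = (current_l1_coefficient * sparsity).mean()

loss = mse_loss + l1_loss
print(float(mse_loss), float(l1_loss), float(loss))
```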
-     def calculate_topk_aux_loss(
-         self,
-         sae_in: torch.Tensor,
-         sae_out: torch.Tensor,
-         hidden_pre: torch.Tensor,
-         dead_neuron_mask: torch.Tensor | None,
-     ) -> torch.Tensor:
-         # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
-         # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
-         if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
-             return sae_out.new_tensor(0.0)
-         residual = (sae_in - sae_out).detach()
-
-         # Heuristic from Appendix B.1 in the paper
-         k_aux = sae_in.shape[-1] // 2
-
-         # Reduce the scale of the loss if there are a small number of dead latents
-         scale = min(num_dead / k_aux, 1.0)
-         k_aux = min(k_aux, num_dead)
-
-         auxk_acts = _calculate_topk_aux_acts(
-             k_aux=k_aux,
-             hidden_pre=hidden_pre,
-             dead_neuron_mask=dead_neuron_mask,
-         )
-
-         # Encourage the top ~50% of dead latents to predict the residual of the
-         # top k living latents
-         recons = self.decode(auxk_acts)
-         auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
-         return scale * auxk_loss
-
-     def calculate_ghost_grad_loss(
-         self,
-         x: torch.Tensor,
-         sae_out: torch.Tensor,
-         per_item_mse_loss: torch.Tensor,
-         hidden_pre: torch.Tensor,
-         dead_neuron_mask: torch.Tensor,
-     ) -> torch.Tensor:
-         # 1. residual and its norm
-         residual = x - sae_out
-         l2_norm_residual = torch.norm(residual, dim=-1)
-
-         # 2. ghost grads use an exponential activation function, ignoring whatever
-         # the activation function is in the SAE. The forward pass uses the dead neurons only.
-         feature_acts_dead_neurons_only = torch.exp(hidden_pre[:, dead_neuron_mask])
-         ghost_out = feature_acts_dead_neurons_only @ self.W_dec[dead_neuron_mask, :]
-         l2_norm_ghost_out = torch.norm(ghost_out, dim=-1)
-         norm_scaling_factor = l2_norm_residual / (1e-6 + l2_norm_ghost_out * 2)
-         ghost_out = ghost_out * norm_scaling_factor[:, None].detach()
-
-         # 3. Because the ghost grads are only calculated for the dead neurons,
-         # the loss is rescaled here to keep it comparable to the original loss.
-         # There have been methodological improvements that are not implemented here yet,
-         # see: https://www.lesswrong.com/posts/C5KAZQib3bzzpeyrg/full-post-progress-update-1-from-the-gdm-mech-interp-team#Improving_ghost_grads
-         per_item_mse_loss_ghost_resid = self.mse_loss_fn(ghost_out, residual.detach())
-         mse_rescaling_factor = (
-             per_item_mse_loss / (per_item_mse_loss_ghost_resid + 1e-6)
-         ).detach()
-         per_item_mse_loss_ghost_resid = (
-             mse_rescaling_factor * per_item_mse_loss_ghost_resid
-         )
-
-         return per_item_mse_loss_ghost_resid.mean()
-
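The norm-matching step (2.) can be seen in isolation: the dead-neuron reconstruction is rescaled so its norm lands at half the residual's norm. A minimal pure-torch sketch with hypothetical toy tensors:

```python
# Sketch of the ghost-grads norm matching above; tensors are hypothetical.
import torch

residual = torch.randn(16, 4)   # x - sae_out
ghost_out = torch.randn(16, 4)  # decode of exp(hidden_pre) over dead latents

l2_norm_residual = torch.norm(residual, dim=-1)
l2_norm_ghost_out = torch.norm(ghost_out, dim=-1)
norm_scaling_factor = l2_norm_residual / (1e-6 + l2_norm_ghost_out * 2)
ghost_out = ghost_out * norm_scaling_factor[:, None].detach()

# After rescaling, ||ghost_out|| ≈ ||residual|| / 2 per item.
print(torch.allclose(ghost_out.norm(dim=-1), l2_norm_residual / 2, atol=1e-4))
```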
-     @torch.no_grad()
-     def _get_mse_loss_fn(self) -> Any:
-         def standard_mse_loss_fn(
-             preds: torch.Tensor, target: torch.Tensor
-         ) -> torch.Tensor:
-             return torch.nn.functional.mse_loss(preds, target, reduction="none")
-
-         def batch_norm_mse_loss_fn(
-             preds: torch.Tensor, target: torch.Tensor
-         ) -> torch.Tensor:
-             target_centered = target - target.mean(dim=0, keepdim=True)
-             normalization = target_centered.norm(dim=-1, keepdim=True)
-             return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
-                 normalization + 1e-6
-             )
-
-         if self.cfg.mse_loss_normalization == "dense_batch":
-             return batch_norm_mse_loss_fn
-         return standard_mse_loss_fn
-
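A quick sketch contrasting the two MSE variants above, on hypothetical toy tensors:

```python
# Sketch contrasting the two MSE loss functions above; toy tensors only.
import torch

preds = torch.randn(16, 4)
target = torch.randn(16, 4)

# Standard: unreduced elementwise squared error.
standard = torch.nn.functional.mse_loss(preds, target, reduction="none")

# "dense_batch": divide by the norm of the batch-centered target, making the
# loss insensitive to the overall scale of the activations.
target_centered = target - target.mean(dim=0, keepdim=True)
normalization = target_centered.norm(dim=-1, keepdim=True)
dense_batch = standard / (normalization + 1e-6)

print(standard.shape, dense_batch.shape)  # both torch.Size([16, 4])
```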
-     def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
-         if self.cfg.architecture == "jumprelu" and "log_threshold" in state_dict:
-             threshold = torch.exp(state_dict["log_threshold"]).detach().contiguous()
-             del state_dict["log_threshold"]
-             state_dict["threshold"] = threshold
-
-     def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
-         if self.cfg.architecture == "jumprelu" and "threshold" in state_dict:
-             threshold = state_dict["threshold"]
-             del state_dict["threshold"]
-             state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()
-
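The save/load hooks above are exact inverses: training holds `log_threshold`, checkpoints hold `threshold`, and `exp`/`log` convert between them. A minimal standalone sketch:

```python
# Standalone sketch of the threshold round trip above; values are hypothetical.
import torch

state_dict = {"log_threshold": torch.tensor([-2.3026, 0.0, 1.0])}

# Saving: replace log_threshold with threshold = exp(log_threshold).
state_dict["threshold"] = torch.exp(state_dict.pop("log_threshold"))

# Loading: replace threshold with log_threshold = log(threshold).
state_dict["log_threshold"] = torch.log(state_dict.pop("threshold"))

print(state_dict["log_threshold"])  # tensor([-2.3026,  0.0000,  1.0000])
```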
-     @classmethod
-     @deprecated("Use load_from_disk instead")
-     def load_from_pretrained(
-         cls, path: str, device: str = "cpu", dtype: str | None = None
-     ) -> "TrainingSAE":
-         return cls.load_from_disk(path, device, dtype)
-
-     @classmethod
-     def load_from_disk(
-         cls,
-         path: str,
-         device: str = "cpu",
-         dtype: str | None = None,
-         converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
-     ) -> "TrainingSAE":
-         overrides = {"dtype": dtype} if dtype is not None else None
-         cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
-         cfg_dict = handle_config_defaulting(cfg_dict)
-         sae_cfg = TrainingSAEConfig.from_dict(cfg_dict)
-         sae = cls(sae_cfg)
-         sae.process_state_dict_for_loading(state_dict)
-         sae.load_state_dict(state_dict)
-         return sae
-
-     def initialize_weights_complex(self):
-         """Initialize encoder and decoder weights according to the configured init scheme."""
-
-         if self.cfg.decoder_orthogonal_init:
-             self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
-
-         elif self.cfg.decoder_heuristic_init:
-             self.W_dec = nn.Parameter(
-                 torch.rand(
-                     self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-                 )
-             )
-             self.initialize_decoder_norm_constant_norm()
-
-         # Then we initialize the encoder weights (either as the transpose of decoder or not)
-         if self.cfg.init_encoder_as_decoder_transpose:
-             self.W_enc.data = self.W_dec.data.T.clone().contiguous()
-         else:
-             self.W_enc = nn.Parameter(
-                 torch.nn.init.kaiming_uniform_(
-                     torch.empty(
-                         self.cfg.d_in,
-                         self.cfg.d_sae,
-                         dtype=self.dtype,
-                         device=self.device,
-                     )
-                 )
-             )
-
-         if self.cfg.normalize_sae_decoder:
-             with torch.no_grad():
-                 # Anthropic normalizes this to have unit columns
-                 self.set_decoder_norm_to_unit_norm()
-
-     @torch.no_grad()
-     def fold_W_dec_norm(self):
-         # need to deal with the jumprelu having a log_threshold in training
-         if self.cfg.architecture == "jumprelu":
-             cur_threshold = self.threshold.clone()
-             W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
-             super().fold_W_dec_norm()
-             self.log_threshold.data = torch.log(cur_threshold * W_dec_norms.squeeze())
-         else:
-             super().fold_W_dec_norm()
-
-     ## Initialization Methods
-     @torch.no_grad()
-     def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
-         out = torch.tensor(origin, dtype=self.dtype, device=self.device)
-         self.b_dec.data = out
-
-     @torch.no_grad()
-     def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):
-         previous_b_dec = self.b_dec.clone().cpu()
-         out = all_activations.mean(dim=0)
-
-         previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
-         distances = torch.norm(all_activations - out, dim=-1)
-
-         logger.info("Reinitializing b_dec with mean of activations")
-         logger.debug(
-             f"Previous distances: {previous_distances.median(0).values.mean().item()}"
-         )
-         logger.debug(f"New distances: {distances.median(0).values.mean().item()}")
-
-         self.b_dec.data = out.to(self.dtype).to(self.device)
-
-     ## Training Utils
-     @torch.no_grad()
-     def set_decoder_norm_to_unit_norm(self):
-         self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-
-     @torch.no_grad()
-     def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
-         """
-         A heuristic procedure inspired by:
-         https://transformer-circuits.pub/2024/april-update/index.html#training-saes
-         """
-         # TODO: Parameterise this as a function of m and n
-
-         # ensure W_dec rows are at unit norm first
-         self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-         self.W_dec.data *= norm  # will break tests but do this for now.
-
-     @torch.no_grad()
-     def remove_gradient_parallel_to_decoder_directions(self):
-         """
-         Update grads so that they remove the component parallel to the decoder
-         directions (W_dec has shape (d_sae, d_in)).
-         """
-         assert self.W_dec.grad is not None  # keep pyright happy
-
-         parallel_component = einops.einsum(
-             self.W_dec.grad,
-             self.W_dec.data,
-             "d_sae d_in, d_sae d_in -> d_sae",
-         )
-         self.W_dec.grad -= einops.einsum(
-             parallel_component,
-             self.W_dec.data,
-             "d_sae, d_sae d_in -> d_sae d_in",
-         )
-
-
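The projection above is per-row Gram-Schmidt against the decoder directions, which the trainer keeps at unit norm. A minimal einops sketch verifying the orthogonality on toy tensors:

```python
# Sketch of the per-row gradient projection above, on toy tensors. Rows are
# normalized first, as set_decoder_norm_to_unit_norm maintains during training.
import einops
import torch

W_dec = torch.randn(8, 4)  # (d_sae, d_in), rows are decoder directions
W_dec /= W_dec.norm(dim=1, keepdim=True)
grad = torch.randn(8, 4)

parallel_component = einops.einsum(grad, W_dec, "d_sae d_in, d_sae d_in -> d_sae")
grad -= einops.einsum(parallel_component, W_dec, "d_sae, d_sae d_in -> d_sae d_in")

# Each gradient row is now orthogonal to its decoder direction.
dots = einops.einsum(grad, W_dec, "d_sae d_in, d_sae d_in -> d_sae")
print(torch.allclose(dots, torch.zeros(8), atol=1e-5))  # True
```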
- def _calculate_topk_aux_acts(
-     k_aux: int,
-     hidden_pre: torch.Tensor,
-     dead_neuron_mask: torch.Tensor,
- ) -> torch.Tensor:
-     # Don't include living latents in this loss
-     auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
-     # Top-k dead latents
-     auxk_topk = auxk_latents.topk(k_aux, sorted=False)
-     # Set the activations to zero for all but the top k_aux dead latents
-     auxk_acts = torch.zeros_like(hidden_pre)
-     auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
-     return auxk_acts
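A quick numeric check of the helper, assuming it is in scope; values are hypothetical toy inputs:

```python
# Numeric sketch of _calculate_topk_aux_acts above, assuming it is in scope.
import torch

hidden_pre = torch.tensor([[3.0, 2.0, 5.0, 1.0]])
dead_neuron_mask = torch.tensor([True, True, False, True])

# Keep the top-2 *dead* latents (indices 0 and 1); the living latent at
# index 2 is excluded even though it has the largest pre-activation.
acts = _calculate_topk_aux_acts(
    k_aux=2, hidden_pre=hidden_pre, dead_neuron_mask=dead_neuron_mask
)
print(acts)  # tensor([[3., 2., 0., 0.]])
```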