sae-lens 5.9.1__py3-none-any.whl → 6.0.0rc1__py3-none-any.whl

sae_lens/sae.py DELETED
@@ -1,747 +0,0 @@
1
- """Most of this is just copied over from Arthur's code and slightly simplified:
2
- https://github.com/ArthurConmy/sae/blob/main/sae/model.py
3
- """
4
-
5
- import json
6
- import warnings
7
- from contextlib import contextmanager
8
- from dataclasses import dataclass, field
9
- from pathlib import Path
10
- from typing import Any, Callable, Literal, TypeVar, overload
11
-
12
- import einops
13
- import torch
14
- from jaxtyping import Float
15
- from safetensors.torch import save_file
16
- from torch import nn
17
- from transformer_lens.hook_points import HookedRootModule, HookPoint
18
- from typing_extensions import deprecated
19
-
20
- from sae_lens.config import (
21
- DTYPE_MAP,
22
- SAE_CFG_FILENAME,
23
- SAE_WEIGHTS_FILENAME,
24
- SPARSITY_FILENAME,
25
- )
26
- from sae_lens.toolkit.pretrained_sae_loaders import (
27
- NAMED_PRETRAINED_SAE_LOADERS,
28
- PretrainedSaeDiskLoader,
29
- PretrainedSaeHuggingfaceLoader,
30
- get_conversion_loader_name,
31
- handle_config_defaulting,
32
- sae_lens_disk_loader,
33
- )
34
- from sae_lens.toolkit.pretrained_saes_directory import (
35
- get_config_overrides,
36
- get_norm_scaling_factor,
37
- get_pretrained_saes_directory,
38
- get_repo_id_and_folder_name,
39
- )
40
-
41
- T = TypeVar("T", bound="SAE")
42
-
43
-
44
- @dataclass
45
- class SAEConfig:
46
- # architecture details
47
- architecture: Literal["standard", "gated", "jumprelu", "topk"]
48
-
49
- # forward pass details.
50
- d_in: int
51
- d_sae: int
52
- activation_fn_str: str
53
- apply_b_dec_to_input: bool
54
- finetuning_scaling_factor: bool
55
-
56
- # dataset it was trained on details.
57
- context_size: int
58
- model_name: str
59
- hook_name: str
60
- hook_layer: int
61
- hook_head_index: int | None
62
- prepend_bos: bool
63
- dataset_path: str
64
- dataset_trust_remote_code: bool
65
- normalize_activations: str
66
-
67
- # misc
68
- dtype: str
69
- device: str
70
- sae_lens_training_version: str | None
71
- activation_fn_kwargs: dict[str, Any] = field(default_factory=dict)
72
- neuronpedia_id: str | None = None
73
- model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
74
- seqpos_slice: tuple[int | None, ...] = (None,)
75
-
76
- @classmethod
77
- def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":
78
- # rename dict:
79
- rename_dict = { # old : new
80
- "hook_point": "hook_name",
81
- "hook_point_head_index": "hook_head_index",
82
- "hook_point_layer": "hook_layer",
83
- "activation_fn": "activation_fn_str",
84
- }
85
- config_dict = {rename_dict.get(k, k): v for k, v in config_dict.items()}
86
-
87
- # use only config terms that are in the dataclass
88
- config_dict = {
89
- k: v
90
- for k, v in config_dict.items()
91
- if k in cls.__dataclass_fields__ # pylint: disable=no-member
92
- }
93
-
94
- if "seqpos_slice" in config_dict:
95
- config_dict["seqpos_slice"] = tuple(config_dict["seqpos_slice"])
96
-
97
- return cls(**config_dict)
98
-
99
- # def __post_init__(self):
100
-
101
- def to_dict(self) -> dict[str, Any]:
102
- return {
103
- "architecture": self.architecture,
104
- "d_in": self.d_in,
105
- "d_sae": self.d_sae,
106
- "dtype": self.dtype,
107
- "device": self.device,
108
- "model_name": self.model_name,
109
- "hook_name": self.hook_name,
110
- "hook_layer": self.hook_layer,
111
- "hook_head_index": self.hook_head_index,
112
- "activation_fn_str": self.activation_fn_str, # use string for serialization
113
- "activation_fn_kwargs": self.activation_fn_kwargs or {},
114
- "apply_b_dec_to_input": self.apply_b_dec_to_input,
115
- "finetuning_scaling_factor": self.finetuning_scaling_factor,
116
- "sae_lens_training_version": self.sae_lens_training_version,
117
- "prepend_bos": self.prepend_bos,
118
- "dataset_path": self.dataset_path,
119
- "dataset_trust_remote_code": self.dataset_trust_remote_code,
120
- "context_size": self.context_size,
121
- "normalize_activations": self.normalize_activations,
122
- "neuronpedia_id": self.neuronpedia_id,
123
- "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
124
- "seqpos_slice": self.seqpos_slice,
125
- }
126
-
127
-
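A minimal sketch of building an SAEConfig by hand and checking that to_dict/from_dict round-trip. The field values below (model name, hook name, dataset path, dimensions) are illustrative only, and SAE/SAEConfig are assumed to be importable from the package root as they are in 5.x. Later sketches in this file reuse the `sae` instance created here.

from sae_lens import SAE, SAEConfig

cfg = SAEConfig.from_dict(
    {
        # illustrative values; only the fields required by the dataclass are given
        "architecture": "standard",
        "d_in": 768,
        "d_sae": 16384,
        "activation_fn_str": "relu",
        "apply_b_dec_to_input": True,
        "finetuning_scaling_factor": False,
        "context_size": 128,
        "model_name": "gpt2",
        "hook_name": "blocks.8.hook_resid_pre",
        "hook_layer": 8,
        "hook_head_index": None,
        "prepend_bos": True,
        "dataset_path": "monology/pile-uncopyrighted",
        "dataset_trust_remote_code": False,
        "normalize_activations": "none",
        "dtype": "float32",
        "device": "cpu",
        "sae_lens_training_version": None,
    }
)

# to_dict()/from_dict() round-trip: the reconstructed config equals the original
assert SAEConfig.from_dict(cfg.to_dict()) == cfg

# an inference-only SAE with randomly initialised weights, reused in later sketches
sae = SAE(cfg)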
128
- class SAE(HookedRootModule):
129
- """
130
- Core Sparse Autoencoder (SAE) class used for inference. For training, see `TrainingSAE`.
131
- """
132
-
133
- cfg: SAEConfig
134
- dtype: torch.dtype
135
- device: torch.device
136
- x_norm_coeff: torch.Tensor
137
-
138
- # analysis
139
- use_error_term: bool
140
-
141
- def __init__(
142
- self,
143
- cfg: SAEConfig,
144
- use_error_term: bool = False,
145
- ):
146
- super().__init__()
147
-
148
- self.cfg = cfg
149
-
150
- if cfg.model_from_pretrained_kwargs:
151
- warnings.warn(
152
- "\nThis SAE has non-empty model_from_pretrained_kwargs. "
153
- "\nFor optimal performance, load the model like so:\n"
154
- "model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)",
155
- category=UserWarning,
156
- stacklevel=1,
157
- )
158
-
159
- self.activation_fn = get_activation_fn(
160
- cfg.activation_fn_str, **cfg.activation_fn_kwargs or {}
161
- )
162
- self.dtype = DTYPE_MAP[cfg.dtype]
163
- self.device = torch.device(cfg.device)
164
- self.use_error_term = use_error_term
165
-
166
- if self.cfg.architecture == "standard" or self.cfg.architecture == "topk":
167
- self.initialize_weights_basic()
168
- self.encode = self.encode_standard
169
- elif self.cfg.architecture == "gated":
170
- self.initialize_weights_gated()
171
- self.encode = self.encode_gated
172
- elif self.cfg.architecture == "jumprelu":
173
- self.initialize_weights_jumprelu()
174
- self.encode = self.encode_jumprelu
175
- else:
176
- raise ValueError(f"Invalid architecture: {self.cfg.architecture}")
177
-
178
- # handle presence / absence of scaling factor.
179
- if self.cfg.finetuning_scaling_factor:
180
- self.apply_finetuning_scaling_factor = (
181
- lambda x: x * self.finetuning_scaling_factor
182
- )
183
- else:
184
- self.apply_finetuning_scaling_factor = lambda x: x
185
-
186
- # set up hooks
187
- self.hook_sae_input = HookPoint()
188
- self.hook_sae_acts_pre = HookPoint()
189
- self.hook_sae_acts_post = HookPoint()
190
- self.hook_sae_output = HookPoint()
191
- self.hook_sae_recons = HookPoint()
192
- self.hook_sae_error = HookPoint()
193
-
194
- # handle hook_z reshaping if needed.
195
- # this is very cursed and should be refactored. it exists so that we can reshape out
196
- # the z activations for hook_z SAEs, but we don't know d_head once the forward pass
197
- # is split into separate encode and decode functions.
198
- # as a result, calling decode before encode will raise an error.
199
- if self.cfg.hook_name.endswith("_z"):
200
- self.turn_on_forward_pass_hook_z_reshaping()
201
- else:
202
- # need to default the reshape fns
203
- self.turn_off_forward_pass_hook_z_reshaping()
204
-
205
- # handle run time activation normalization if needed:
206
- if self.cfg.normalize_activations == "constant_norm_rescale":
207
- # we need to scale the norm of the input and store the scaling factor
208
- def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
209
- self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True)
210
- return x * self.x_norm_coeff
211
-
212
- def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
213
- x = x / self.x_norm_coeff
214
- del self.x_norm_coeff # prevents reusing
215
- return x
216
-
217
- self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
218
- self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
219
-
220
- elif self.cfg.normalize_activations == "layer_norm":
221
- # we need to scale the norm of the input and store the scaling factor
222
- def run_time_activation_ln_in(
223
- x: torch.Tensor, eps: float = 1e-5
224
- ) -> torch.Tensor:
225
- mu = x.mean(dim=-1, keepdim=True)
226
- x = x - mu
227
- std = x.std(dim=-1, keepdim=True)
228
- x = x / (std + eps)
229
- self.ln_mu = mu
230
- self.ln_std = std
231
- return x
232
-
233
- def run_time_activation_ln_out(
234
- x: torch.Tensor,
235
- eps: float = 1e-5, # noqa: ARG001
236
- ) -> torch.Tensor:
237
- return x * self.ln_std + self.ln_mu # type: ignore
238
-
239
- self.run_time_activation_norm_fn_in = run_time_activation_ln_in
240
- self.run_time_activation_norm_fn_out = run_time_activation_ln_out
241
- else:
242
- self.run_time_activation_norm_fn_in = lambda x: x
243
- self.run_time_activation_norm_fn_out = lambda x: x
244
-
245
- self.setup() # Required for `HookedRootModule`s
246
-
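The constant_norm_rescale branch above rescales each input vector to norm sqrt(d_in) before encoding and undoes the scaling on the way out. A standalone sketch of that arithmetic (d_in and the batch size are arbitrary illustrative values):

import torch

d_in = 768
x = torch.randn(4, d_in)
coeff = (d_in ** 0.5) / x.norm(dim=-1, keepdim=True)
x_scaled = x * coeff                                    # every row now has norm sqrt(d_in)
assert torch.allclose(x_scaled.norm(dim=-1), torch.full((4,), d_in ** 0.5))
assert torch.allclose(x_scaled / coeff, x)              # the output path inverts the scaling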
247
- def initialize_weights_basic(self):
248
- # the encoder bias is initialized to zeros; there is currently no config option to change this.
249
- self.b_enc = nn.Parameter(
250
- torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
251
- )
252
-
253
- # Start with the default init strategy:
254
- self.W_dec = nn.Parameter(
255
- torch.nn.init.kaiming_uniform_(
256
- torch.empty(
257
- self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
258
- )
259
- )
260
- )
261
-
262
- self.W_enc = nn.Parameter(
263
- torch.nn.init.kaiming_uniform_(
264
- torch.empty(
265
- self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
266
- )
267
- )
268
- )
269
-
270
- # methods which change b_dec as a function of the dataset are implemented after init.
271
- self.b_dec = nn.Parameter(
272
- torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
273
- )
274
-
275
- # scaling factor for fine-tuning (not to be used in initial training)
276
- # TODO: Make this optional and not included with all SAEs by default (but maintain backwards compatibility)
277
- if self.cfg.finetuning_scaling_factor:
278
- self.finetuning_scaling_factor = nn.Parameter(
279
- torch.ones(self.cfg.d_sae, dtype=self.dtype, device=self.device)
280
- )
281
-
282
- def initialize_weights_gated(self):
283
- # Initialize the weights and biases for the gated encoder
284
- self.W_enc = nn.Parameter(
285
- torch.nn.init.kaiming_uniform_(
286
- torch.empty(
287
- self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
288
- )
289
- )
290
- )
291
-
292
- self.b_gate = nn.Parameter(
293
- torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
294
- )
295
-
296
- self.r_mag = nn.Parameter(
297
- torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
298
- )
299
-
300
- self.b_mag = nn.Parameter(
301
- torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
302
- )
303
-
304
- self.W_dec = nn.Parameter(
305
- torch.nn.init.kaiming_uniform_(
306
- torch.empty(
307
- self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
308
- )
309
- )
310
- )
311
-
312
- self.b_dec = nn.Parameter(
313
- torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
314
- )
315
-
316
- def initialize_weights_jumprelu(self):
317
- # The params are identical to the standard SAE
318
- # except we use a threshold parameter too
319
- self.threshold = nn.Parameter(
320
- torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
321
- )
322
- self.initialize_weights_basic()
323
-
324
- @overload
325
- def to(
326
- self: T,
327
- device: torch.device | str | None = ...,
328
- dtype: torch.dtype | None = ...,
329
- non_blocking: bool = ...,
330
- ) -> T: ...
331
-
332
- @overload
333
- def to(self: T, dtype: torch.dtype, non_blocking: bool = ...) -> T: ...
334
-
335
- @overload
336
- def to(self: T, tensor: torch.Tensor, non_blocking: bool = ...) -> T: ...
337
-
338
- def to(self, *args: Any, **kwargs: Any) -> "SAE": # type: ignore
339
- device_arg = None
340
- dtype_arg = None
341
-
342
- # Check args
343
- for arg in args:
344
- if isinstance(arg, (torch.device, str)):
345
- device_arg = arg
346
- elif isinstance(arg, torch.dtype):
347
- dtype_arg = arg
348
- elif isinstance(arg, torch.Tensor):
349
- device_arg = arg.device
350
- dtype_arg = arg.dtype
351
-
352
- # Check kwargs
353
- device_arg = kwargs.get("device", device_arg)
354
- dtype_arg = kwargs.get("dtype", dtype_arg)
355
-
356
- if device_arg is not None:
357
- # Convert device to torch.device if it's a string
358
- device = (
359
- torch.device(device_arg) if isinstance(device_arg, str) else device_arg
360
- )
361
-
362
- # Update the cfg.device
363
- self.cfg.device = str(device)
364
-
365
- # Update the .device property
366
- self.device = device
367
-
368
- if dtype_arg is not None:
369
- # Update the cfg.dtype
370
- self.cfg.dtype = str(dtype_arg)
371
-
372
- # Update the .dtype property
373
- self.dtype = dtype_arg
374
-
375
- # Call the parent class's to() method to handle all cases (device, dtype, tensor)
376
- return super().to(*args, **kwargs)
377
-
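The to() override above keeps cfg.device and cfg.dtype in sync with the parameters. A small illustration, assuming `sae` is the instance built in the SAEConfig sketch earlier (the deepcopy keeps that instance untouched):

import copy

import torch

sae64 = copy.deepcopy(sae).to(torch.float64)
assert sae64.W_enc.dtype == torch.float64   # parameters were converted
assert sae64.cfg.dtype == "torch.float64"   # ...and the config followed along
assert sae.cfg.dtype == "float32"           # the original instance is unchanged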
378
- # Basic Forward Pass Functionality.
379
- def forward(
380
- self,
381
- x: torch.Tensor,
382
- ) -> torch.Tensor:
383
- feature_acts = self.encode(x)
384
- sae_out = self.decode(feature_acts)
385
-
386
- # TEMP
387
- if self.use_error_term:
388
- with torch.no_grad():
389
- # Recompute everything without hooks to get true error term
390
- # Otherwise, the output with error term will always equal input, even for causal interventions that affect x_reconstruct
391
- # This is in a no_grad context to detach the error, so we can compute SAE feature gradients (eg for attribution patching). See A.3 in https://arxiv.org/pdf/2403.19647.pdf for more detail
392
- # NOTE: we can't just use `sae_error = input - x_reconstruct.detach()` or something simpler, since this would mean intervening on features would mean ablating features still results in perfect reconstruction.
393
- with _disable_hooks(self):
394
- feature_acts_clean = self.encode(x)
395
- x_reconstruct_clean = self.decode(feature_acts_clean)
396
- sae_error = self.hook_sae_error(x - x_reconstruct_clean)
397
- sae_out = sae_out + sae_error
398
- return self.hook_sae_output(sae_out)
399
-
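A sketch of why use_error_term exists, assuming the `cfg` from the SAEConfig sketch earlier (standard architecture, no activation normalization). With the error term added, the clean forward pass returns the input unchanged, so any difference after an intervention is attributable to the ablated feature alone. The hook name matches the hook point defined above; the feature index is arbitrary.

import torch
from sae_lens import SAE

sae_err = SAE(cfg, use_error_term=True)
x = torch.randn(4, cfg.d_in)

def ablate_feature_0(acts, hook):
    acts[..., 0] = 0.0      # zero out feature 0's activation
    return acts

clean = sae_err(x)
patched = sae_err.run_with_hooks(
    x, fwd_hooks=[("hook_sae_acts_post", ablate_feature_0)]
)

assert torch.allclose(clean, x)   # the error term makes the clean pass exact
delta = patched - clean           # effect of removing feature 0 only:
# for this standard SAE, delta equals -feature_acts[..., 0:1] * W_dec[0]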
400
- def encode_gated(
401
- self, x: Float[torch.Tensor, "... d_in"]
402
- ) -> Float[torch.Tensor, "... d_sae"]:
403
- sae_in = self.process_sae_in(x)
404
-
405
- # Gating path
406
- gating_pre_activation = sae_in @ self.W_enc + self.b_gate
407
- active_features = (gating_pre_activation > 0).to(self.dtype)
408
-
409
- # Magnitude path with weight sharing
410
- magnitude_pre_activation = self.hook_sae_acts_pre(
411
- sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
412
- )
413
- feature_magnitudes = self.activation_fn(magnitude_pre_activation)
414
-
415
- return self.hook_sae_acts_post(active_features * feature_magnitudes)
416
-
417
- def encode_jumprelu(
418
- self, x: Float[torch.Tensor, "... d_in"]
419
- ) -> Float[torch.Tensor, "... d_sae"]:
420
- """
421
- Calculate SAE features from inputs
422
- """
423
- sae_in = self.process_sae_in(x)
424
-
425
- # "... d_in, d_in d_sae -> ... d_sae",
426
- hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
427
-
428
- return self.hook_sae_acts_post(
429
- self.activation_fn(hidden_pre) * (hidden_pre > self.threshold)
430
- )
431
-
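The JumpReLU encode above keeps a unit only when its pre-activation exceeds a learned per-feature threshold. A tiny numeric illustration of activation_fn(hidden_pre) * (hidden_pre > threshold) with made-up numbers:

import torch

hidden_pre = torch.tensor([[-0.5, 0.2, 1.5, 0.05]])
threshold = torch.tensor([0.0, 0.5, 0.5, 0.0])
acts = torch.relu(hidden_pre) * (hidden_pre > threshold)
# 0.2 is positive but below its threshold of 0.5, so it is zeroed:
assert torch.equal(acts, torch.tensor([[0.0, 0.0, 1.5, 0.05]]))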
432
- def encode_standard(
433
- self, x: Float[torch.Tensor, "... d_in"]
434
- ) -> Float[torch.Tensor, "... d_sae"]:
435
- """
436
- Calculate SAE features from inputs
437
- """
438
- sae_in = self.process_sae_in(x)
439
-
440
- # "... d_in, d_in d_sae -> ... d_sae",
441
- hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
442
- return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
443
-
444
- def process_sae_in(
445
- self, sae_in: Float[torch.Tensor, "... d_in"]
446
- ) -> Float[torch.Tensor, "... d_in"]:
447
- sae_in = sae_in.to(self.dtype)
448
- sae_in = self.reshape_fn_in(sae_in)
449
- sae_in = self.hook_sae_input(sae_in)
450
- sae_in = self.run_time_activation_norm_fn_in(sae_in)
451
- return sae_in - (self.b_dec * self.cfg.apply_b_dec_to_input)
452
-
453
- def decode(
454
- self, feature_acts: Float[torch.Tensor, "... d_sae"]
455
- ) -> Float[torch.Tensor, "... d_in"]:
456
- """Decodes SAE feature activation tensor into a reconstructed input activation tensor."""
457
- # "... d_sae, d_sae d_in -> ... d_in",
458
- sae_out = self.hook_sae_recons(
459
- self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec + self.b_dec
460
- )
461
-
462
- # handle run time activation normalization if needed
463
- # calling this twice without an encode in between will fail, since encode's stored norm coefficient is consumed here.
464
- sae_out = self.run_time_activation_norm_fn_out(sae_out)
465
-
466
- # handle hook z reshaping if needed.
467
- return self.reshape_fn_out(sae_out, self.d_head) # type: ignore
468
-
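Basic encode/decode usage, assuming the randomly initialised `sae` from the SAEConfig sketch earlier (the numbers are meaningless for random weights; the point is the API):

import torch

x = torch.randn(8, sae.cfg.d_in)
feature_acts = sae.encode(x)        # shape [8, d_sae]
recon = sae.decode(feature_acts)    # shape [8, d_in]

l0 = (feature_acts > 0).float().sum(dim=-1).mean()
mse = (recon - x).pow(2).mean()
print(f"mean L0: {l0.item():.1f}, reconstruction MSE: {mse.item():.4f}")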
469
- @torch.no_grad()
470
- def fold_W_dec_norm(self):
471
- W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
472
- self.W_dec.data = self.W_dec.data / W_dec_norms
473
- self.W_enc.data = self.W_enc.data * W_dec_norms.T
474
- if self.cfg.architecture == "gated":
475
- self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
476
- self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
477
- self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
478
- elif self.cfg.architecture == "jumprelu":
479
- self.threshold.data = self.threshold.data * W_dec_norms.squeeze()
480
- self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()
481
- else:
482
- self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()
483
-
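For a standard ReLU SAE, fold_W_dec_norm rescales features without changing the end-to-end reconstruction, because ReLU commutes with positive per-feature scaling. A sketch of that invariance check, run on a deep copy of the `sae` from the earlier sketch:

import copy

import torch

folded = copy.deepcopy(sae)
x = torch.randn(4, sae.cfg.d_in)

out_before = folded(x)
folded.fold_W_dec_norm()
out_after = folded(x)

assert torch.allclose(out_before, out_after, atol=1e-4, rtol=1e-4)              # outputs preserved
assert torch.allclose(folded.W_dec.norm(dim=-1), torch.ones(folded.cfg.d_sae))  # unit-norm decoder rows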
484
- @torch.no_grad()
485
- def fold_activation_norm_scaling_factor(
486
- self, activation_norm_scaling_factor: float
487
- ):
488
- self.W_enc.data = self.W_enc.data * activation_norm_scaling_factor
489
- # previously weren't doing this.
490
- self.W_dec.data = self.W_dec.data / activation_norm_scaling_factor
491
- self.b_dec.data = self.b_dec.data / activation_norm_scaling_factor
492
-
493
- # once we normalize, we shouldn't need to scale activations.
494
- self.cfg.normalize_activations = "none"
495
-
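fold_activation_norm_scaling_factor bakes an "expected_average_only_in" scaling factor into the weights: after folding, feeding raw activations reproduces what the unfolded SAE computes on pre-scaled activations, divided back by the same factor. A sketch of that identity for the standard `sae` from the earlier sketch (the factor 2.5 is arbitrary):

import copy

import torch

scaling_factor = 2.5
x = torch.randn(4, sae.cfg.d_in)

folded = copy.deepcopy(sae)
folded.fold_activation_norm_scaling_factor(scaling_factor)

assert torch.allclose(folded(x), sae(scaling_factor * x) / scaling_factor, atol=1e-4)
assert folded.cfg.normalize_activations == "none"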
496
- @overload
497
- def save_model(self, path: str | Path) -> tuple[Path, Path]: ...
498
-
499
- @overload
500
- def save_model(
501
- self, path: str | Path, sparsity: torch.Tensor
502
- ) -> tuple[Path, Path, Path]: ...
503
-
504
- def save_model(self, path: str | Path, sparsity: torch.Tensor | None = None):
505
- path = Path(path)
506
-
507
- if not path.exists():
508
- path.mkdir(parents=True)
509
-
510
- # generate the weights
511
- state_dict = self.state_dict()
512
- self.process_state_dict_for_saving(state_dict)
513
- model_weights_path = path / SAE_WEIGHTS_FILENAME
514
- save_file(state_dict, model_weights_path)
515
-
516
- # save the config
517
- config = self.cfg.to_dict()
518
-
519
- cfg_path = path / SAE_CFG_FILENAME
520
- with open(cfg_path, "w") as f:
521
- json.dump(config, f)
522
-
523
- if sparsity is not None:
524
- sparsity_in_dict = {"sparsity": sparsity}
525
- sparsity_path = path / SPARSITY_FILENAME
526
- save_file(sparsity_in_dict, sparsity_path)
527
- return model_weights_path, cfg_path, sparsity_path
528
-
529
- return model_weights_path, cfg_path
530
-
531
- # override this in subclasses to modify the state_dict in-place before saving
532
- def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
533
- pass
534
-
535
- # override this in subclasses to modify the state_dict in-place after loading
536
- def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
537
- pass
538
-
539
- @classmethod
540
- @deprecated("Use load_from_disk instead")
541
- def load_from_pretrained(
542
- cls, path: str, device: str = "cpu", dtype: str | None = None
543
- ) -> "SAE":
544
- sae = cls.load_from_disk(path, device)
545
- if dtype is not None:
546
- sae.cfg.dtype = dtype
547
- sae = sae.to(dtype)
548
- return sae
549
-
550
- @classmethod
551
- def load_from_disk(
552
- cls,
553
- path: str,
554
- device: str = "cpu",
555
- dtype: str | None = None,
556
- converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
557
- ) -> "SAE":
558
- overrides = {"dtype": dtype} if dtype is not None else None
559
- cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
560
- cfg_dict = handle_config_defaulting(cfg_dict)
561
- sae_cfg = SAEConfig.from_dict(cfg_dict)
562
- sae = cls(sae_cfg)
563
- sae.process_state_dict_for_loading(state_dict)
564
- sae.load_state_dict(state_dict)
565
- return sae
566
-
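A save/load round-trip sketch using save_model and load_from_disk, again with the `sae` from the earlier sketch, written to a temporary directory:

import tempfile

import torch
from sae_lens import SAE

with tempfile.TemporaryDirectory() as tmp_dir:
    weights_path, cfg_path = sae.save_model(tmp_dir)
    loaded = SAE.load_from_disk(tmp_dir, device="cpu")

assert loaded.cfg.d_sae == sae.cfg.d_sae
assert torch.allclose(loaded.W_enc, sae.W_enc)   # weights survive the safetensors round trip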
567
- @classmethod
568
- def from_pretrained(
569
- cls,
570
- release: str,
571
- sae_id: str,
572
- device: str = "cpu",
573
- force_download: bool = False,
574
- converter: PretrainedSaeHuggingfaceLoader | None = None,
575
- ) -> tuple["SAE", dict[str, Any], torch.Tensor | None]:
576
- """
577
- Load a pretrained SAE from the Hugging Face model hub.
578
-
579
- Args:
580
- release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
581
- sae_id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
582
- device: The device to load the SAE on.
583
- force_download: If True, re-download the SAE files from the Hugging Face hub even if they are cached locally.
584
- """
585
-
586
- # get sae directory
587
- sae_directory = get_pretrained_saes_directory()
588
-
589
- # get the repo id and path to the SAE
590
- if release not in sae_directory:
591
- if "/" not in release:
592
- raise ValueError(
593
- f"Release {release} not found in pretrained SAEs directory, and is not a valid huggingface repo."
594
- )
595
- elif sae_id not in sae_directory[release].saes_map:
596
- # If using Gemma Scope and not the canonical release, give a hint to use it
597
- if (
598
- "gemma-scope" in release
599
- and "canonical" not in release
600
- and f"{release}-canonical" in sae_directory
601
- ):
602
- canonical_ids = list(
603
- sae_directory[release + "-canonical"].saes_map.keys()
604
- )
605
- # Shorten the lengthy string of valid IDs
606
- if len(canonical_ids) > 5:
607
- str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
608
- else:
609
- str_canonical_ids = str(canonical_ids)
610
- value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
611
- else:
612
- value_suffix = ""
613
-
614
- valid_ids = list(sae_directory[release].saes_map.keys())
615
- # Shorten the lengthy string of valid IDs
616
- if len(valid_ids) > 5:
617
- str_valid_ids = str(valid_ids[:5])[:-1] + ", ...]"
618
- else:
619
- str_valid_ids = str(valid_ids)
620
-
621
- raise ValueError(
622
- f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
623
- + value_suffix
624
- )
625
-
626
- conversion_loader = (
627
- converter
628
- or NAMED_PRETRAINED_SAE_LOADERS[get_conversion_loader_name(release)]
629
- )
630
- repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
631
- config_overrides = get_config_overrides(release, sae_id)
632
- config_overrides["device"] = device
633
-
634
- cfg_dict, state_dict, log_sparsities = conversion_loader(
635
- repo_id=repo_id,
636
- folder_name=folder_name,
637
- device=device,
638
- force_download=force_download,
639
- cfg_overrides=config_overrides,
640
- )
641
- cfg_dict = handle_config_defaulting(cfg_dict)
642
-
643
- sae = cls(SAEConfig.from_dict(cfg_dict))
644
- sae.process_state_dict_for_loading(state_dict)
645
- sae.load_state_dict(state_dict)
646
-
647
- # Check if normalization is 'expected_average_only_in'
648
- if cfg_dict.get("normalize_activations") == "expected_average_only_in":
649
- norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
650
- if norm_scaling_factor is not None:
651
- sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
652
- cfg_dict["normalize_activations"] = "none"
653
- else:
654
- warnings.warn(
655
- f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
656
- )
657
-
658
- return sae, cfg_dict, log_sparsities
659
-
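Typical from_pretrained usage (this downloads from the Hugging Face hub; the release and id below are the ones used in the SAE Lens tutorials, but any pair listed in pretrained_saes.yaml works):

from sae_lens import SAE

sae_gpt2, cfg_dict, log_sparsities = SAE.from_pretrained(
    release="gpt2-small-res-jb",          # maps to a HF repo id via pretrained_saes.yaml
    sae_id="blocks.8.hook_resid_pre",     # maps to a folder within that repo
    device="cpu",
)
print(sae_gpt2.cfg.d_in, sae_gpt2.cfg.d_sae)   # inspect the loaded dimensions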
660
- def get_name(self):
661
- return f"sae_{self.cfg.model_name}_{self.cfg.hook_name}_{self.cfg.d_sae}"
662
-
663
- @classmethod
664
- def from_dict(cls, config_dict: dict[str, Any]) -> "SAE":
665
- return cls(SAEConfig.from_dict(config_dict))
666
-
667
- def turn_on_forward_pass_hook_z_reshaping(self):
668
- if not self.cfg.hook_name.endswith("_z"):
669
- raise ValueError("This method should only be called for hook_z SAEs.")
670
-
671
- def reshape_fn_in(x: torch.Tensor):
672
- self.d_head = x.shape[-1] # type: ignore
673
- self.reshape_fn_in = lambda x: einops.rearrange(
674
- x, "... n_heads d_head -> ... (n_heads d_head)"
675
- )
676
- return einops.rearrange(x, "... n_heads d_head -> ... (n_heads d_head)")
677
-
678
- self.reshape_fn_in = reshape_fn_in
679
-
680
- self.reshape_fn_out = lambda x, d_head: einops.rearrange(
681
- x, "... (n_heads d_head) -> ... n_heads d_head", d_head=d_head
682
- )
683
- self.hook_z_reshaping_mode = True
684
-
685
- def turn_off_forward_pass_hook_z_reshaping(self):
686
- self.reshape_fn_in = lambda x: x
687
- self.reshape_fn_out = lambda x, d_head: x # noqa: ARG005
688
- self.d_head = None
689
- self.hook_z_reshaping_mode = False
690
-
691
-
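What the hook_z reshaping functions above actually do: flatten the per-head axis on the way into the SAE and restore it on the way out. A standalone illustration with made-up head dimensions:

import einops
import torch

n_heads, d_head = 12, 64
z = torch.randn(2, 5, n_heads, d_head)   # [batch, pos, n_heads, d_head], as at a hook_z

flat = einops.rearrange(z, "... n_heads d_head -> ... (n_heads d_head)")
assert flat.shape == (2, 5, n_heads * d_head)   # this flattened size is the SAE's d_in

restored = einops.rearrange(
    flat, "... (n_heads d_head) -> ... n_heads d_head", d_head=d_head
)
assert torch.equal(restored, z)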
692
- class TopK(nn.Module):
693
- def __init__(
694
- self, k: int, postact_fn: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU()
695
- ):
696
- super().__init__()
697
- self.k = k
698
- self.postact_fn = postact_fn
699
-
700
- # TODO: Use a fused kernel to speed up topk decoding like https://github.com/EleutherAI/sae/blob/main/sae/kernels.py
701
- def forward(self, x: torch.Tensor) -> torch.Tensor:
702
- topk = torch.topk(x, k=self.k, dim=-1)
703
- values = self.postact_fn(topk.values)
704
- result = torch.zeros_like(x)
705
- result.scatter_(-1, topk.indices, values)
706
- return result
707
-
708
-
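A small numeric check of the TopK module above: only the k largest pre-activations survive, and the post-activation (ReLU by default) is applied to those survivors.

import torch

topk = TopK(k=2)   # TopK as defined above
x = torch.tensor([[0.1, -2.0, 3.0, 0.5]])
out = topk(x)
assert torch.equal(out, torch.tensor([[0.0, 0.0, 3.0, 0.5]]))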
709
- def get_activation_fn(
710
- activation_fn: str, **kwargs: Any
711
- ) -> Callable[[torch.Tensor], torch.Tensor]:
712
- if activation_fn == "relu":
713
- return torch.nn.ReLU()
714
- if activation_fn == "tanh-relu":
715
-
716
- def tanh_relu(input: torch.Tensor) -> torch.Tensor:
717
- input = torch.relu(input)
718
- return torch.tanh(input)
719
-
720
- return tanh_relu
721
- if activation_fn == "topk":
722
- if "k" not in kwargs:
723
- raise ValueError("TopK activation function requires a k value.")
724
- k = kwargs.get("k", 1) # Default k to 1 if not provided
725
- postact_fn = kwargs.get(
726
- "postact_fn", nn.ReLU()
727
- ) # Default post-activation to ReLU if not provided
728
-
729
- return TopK(k, postact_fn)
730
- raise ValueError(f"Unknown activation function: {activation_fn}")
731
-
732
-
733
- _blank_hook = nn.Identity()
734
-
735
-
736
- @contextmanager
737
- def _disable_hooks(sae: SAE):
738
- """
739
- Temporarily disable hooks for the SAE. Swaps out every hook point with a dummy module that does nothing.
740
- """
741
- try:
742
- for hook_name in sae.hook_dict:
743
- setattr(sae, hook_name, _blank_hook)
744
- yield
745
- finally:
746
- for hook_name, hook in sae.hook_dict.items():
747
- setattr(sae, hook_name, hook)
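Finally, a sketch of what _disable_hooks buys the error-term computation in forward: while the context manager is active, the hook points are swapped for nn.Identity, so registered hooks do not fire. This assumes the `sae` from the earlier sketch (built with use_error_term=False):

import torch

fired = []

def record(acts, hook):
    fired.append(hook.name)
    return acts

sae.add_hook("hook_sae_acts_post", record)
x = torch.randn(2, sae.cfg.d_in)

with _disable_hooks(sae):
    sae(x)
assert fired == []                         # the identity stand-ins carry no hooks

sae(x)
assert fired == ["hook_sae_acts_post"]     # the restored hook point fires again
sae.reset_hooks()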