sae-lens 6.33.0__py3-none-any.whl → 6.34.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.33.0"
+ __version__ = "6.34.1"
 
  import logging
 
sae_lens/analysis/hooked_sae_transformer.py CHANGED
@@ -1,13 +1,14 @@
- import logging
  from contextlib import contextmanager
  from typing import Any, Callable
 
  import torch
+ from torch import nn
  from transformer_lens.ActivationCache import ActivationCache
  from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
  from transformer_lens.hook_points import HookPoint  # Hooking utilities
  from transformer_lens.HookedTransformer import HookedTransformer
 
+ from sae_lens import logger
  from sae_lens.saes.sae import SAE
 
  SingleLoss = torch.Tensor  # Type alias for a single element tensor
@@ -15,6 +16,72 @@ LossPerToken = torch.Tensor
  Loss = SingleLoss | LossPerToken
 
 
+ class _SAEWrapper(nn.Module):
+     """Wrapper for SAE/Transcoder that handles error term and hook coordination.
+
+     For SAEs (input_hook == output_hook), _captured_input stays None and we use
+     the forward argument directly. For transcoders, _captured_input is set at
+     the input hook via capture_input().
+
+     Implementation Note:
+         The SAE is stored in __dict__ directly rather than as a registered submodule.
+         This is intentional: PyTorch's module registration would add a ".sae." prefix
+         to all hook names in the cache (e.g., "blocks.0.hook_mlp_out.sae.hook_sae_input"
+         instead of "blocks.0.hook_mlp_out.hook_sae_input"). By storing in __dict__ and
+         copying hooks directly to the wrapper, we preserve the expected cache paths
+         for backwards compatibility.
+     """
+
+     def __init__(self, sae: SAE[Any], use_error_term: bool = False):
+         super().__init__()
+         # Store SAE in __dict__ to avoid registering as submodule. This keeps cache
+         # paths clean by avoiding a ".sae." prefix on hook names. See class docstring.
+         self.__dict__["_sae"] = sae
+         # Copy SAE's hooks directly to wrapper so they appear at the right path
+         for name, hook in sae.hook_dict.items():
+             setattr(self, name, hook)
+         self.use_error_term = use_error_term
+         self._captured_input: torch.Tensor | None = None
+
+     @property
+     def sae(self) -> SAE[Any]:
+         return self.__dict__["_sae"]
+
+     def capture_input(self, x: torch.Tensor) -> None:
+         """Capture input at input hook (for transcoders).
+
+         Note: We don't clone the tensor here - the input should not be modified
+         in-place between capture and use, and avoiding clone preserves memory.
+         """
+         self._captured_input = x
+
+     def forward(self, original_output: torch.Tensor) -> torch.Tensor:
+         """Run SAE/transcoder at output hook location."""
+         # For SAE: use original_output as input (same hook for input/output)
+         # For transcoder: use captured input from earlier hook
+         sae_input = (
+             self._captured_input
+             if self._captured_input is not None
+             else original_output
+         )
+
+         # Temporarily disable SAE's internal use_error_term - we handle it here
+         # Use _use_error_term directly to avoid triggering deprecation warning
+         sae_use_error_term = self.sae._use_error_term
+         self.sae._use_error_term = False
+         try:
+             sae_out = self.sae(sae_input)
+
+             if self.use_error_term:
+                 error = original_output - sae_out.detach()
+                 sae_out = sae_out + error
+
+             return sae_out
+         finally:
+             self.sae._use_error_term = sae_use_error_term
+             self._captured_input = None
+
+
  def get_deep_attr(obj: Any, path: str):
      """Helper function to get a nested attribute from a object.
      In practice used to access HookedTransformer HookPoints (eg model.blocks[0].attn.hook_z)
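
The error-term trick in `_SAEWrapper.forward` is easiest to see in isolation. Below is a minimal standalone sketch (not from the diff; an `nn.Linear` stands in for the SAE reconstruction): the returned tensor equals the original activation numerically, while gradients still flow through the reconstruction because the error is detached.

    import torch
    from torch import nn

    recon = nn.Linear(16, 16)                    # stand-in for an SAE's encode/decode
    original_output = torch.randn(4, 16)

    sae_out = recon(original_output)
    error = original_output - sae_out.detach()   # detach: no gradient through the error
    patched = sae_out + error                    # numerically equal to original_output

    assert torch.allclose(patched, original_output)  # forward pass is unchanged
    patched.sum().backward()                         # gradients still reach the SAE path
    assert recon.weight.grad is not None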
@@ -78,88 +145,130 @@ class HookedSAETransformer(HookedTransformer):
              add_hook_in_to_mlp(block.mlp)  # type: ignore
          self.setup()
 
-         self.acts_to_saes: dict[str, SAE] = {}  # type: ignore
+         self._acts_to_saes: dict[str, _SAEWrapper] = {}
+         # Track output hooks used by transcoders for cleanup
+         self._transcoder_output_hooks: dict[str, str] = {}
+
+     @property
+     def acts_to_saes(self) -> dict[str, SAE[Any]]:
+         """Returns a dict mapping hook names to attached SAEs."""
+         return {name: wrapper.sae for name, wrapper in self._acts_to_saes.items()}
 
      def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None):
-         """Attaches an SAE to the model
+         """Attaches an SAE or Transcoder to the model.
 
-         WARNING: This sae will be permanantly attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.
+         WARNING: This SAE will be permanently attached until you remove it with
+         reset_saes. This function will also overwrite any existing SAE attached
+         to the same hook point.
 
          Args:
-             sae: SparseAutoencoderBase. The SAE to attach to the model
-             use_error_term: (bool | None) If provided, will set the use_error_term attribute of the SAE to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
+             sae: The SAE or Transcoder to attach to the model.
+             use_error_term: If True, computes error term so output matches what the
+                 model would have produced without the SAE. This works for both SAEs
+                 (where input==output hook) and transcoders (where they differ).
+                 Defaults to None (uses SAE's existing setting).
          """
-         act_name = sae.cfg.metadata.hook_name
-         if (act_name not in self.acts_to_saes) and (act_name not in self.hook_dict):
-             logging.warning(
-                 f"No hook found for {act_name}. Skipping. Check model.hook_dict for available hooks."
+         input_hook = sae.cfg.metadata.hook_name
+         output_hook = sae.cfg.metadata.hook_name_out or input_hook
+
+         if (input_hook not in self._acts_to_saes) and (
+             input_hook not in self.hook_dict
+         ):
+             logger.warning(
+                 f"No hook found for {input_hook}. Skipping. Check model.hook_dict for available hooks."
              )
              return
 
-         if use_error_term is not None:
-             if not hasattr(sae, "_original_use_error_term"):
-                 sae._original_use_error_term = sae.use_error_term  # type: ignore
-             sae.use_error_term = use_error_term
-         self.acts_to_saes[act_name] = sae
-         set_deep_attr(self, act_name, sae)
+         # Check if output hook exists (either as hook_dict entry or already has SAE attached)
+         output_hook_exists = (
+             output_hook in self.hook_dict
+             or output_hook in self._acts_to_saes
+             or any(v == output_hook for v in self._transcoder_output_hooks.values())
+         )
+         if not output_hook_exists:
+             logger.warning(f"No hook found for output {output_hook}. Skipping.")
+             return
+
+         # Always use wrapper - it handles both SAEs and transcoders uniformly
+         # If use_error_term not specified, respect SAE's existing setting
+         effective_use_error_term = (
+             use_error_term if use_error_term is not None else sae.use_error_term
+         )
+         wrapper = _SAEWrapper(sae, use_error_term=effective_use_error_term)
+
+         # For transcoders (input != output), capture input at input hook
+         if input_hook != output_hook:
+             input_hook_point = get_deep_attr(self, input_hook)
+             if isinstance(input_hook_point, HookPoint):
+                 input_hook_point.add_hook(
+                     lambda tensor, hook: (wrapper.capture_input(tensor), tensor)[1],  # noqa: ARG005
+                     dir="fwd",
+                     is_permanent=True,
+                 )
+             self._transcoder_output_hooks[input_hook] = output_hook
+
+         # Store wrapper in _acts_to_saes and at output hook
+         self._acts_to_saes[input_hook] = wrapper
+         set_deep_attr(self, output_hook, wrapper)
          self.setup()
 
-     def _reset_sae(self, act_name: str, prev_sae: SAE[Any] | None = None):
-         """Resets an SAE that was attached to the model
+     def _reset_sae(
+         self, act_name: str, prev_wrapper: _SAEWrapper | None = None
+     ) -> None:
+         """Resets an SAE that was attached to the model.
 
          By default will remove the SAE from that hook_point.
-         If prev_sae is provided, will replace the current SAE with the provided one.
-         This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes)
+         If prev_wrapper is provided, will restore that wrapper's SAE with its settings.
 
          Args:
-             act_name: str. The hook_name of the SAE to reset
-             prev_sae: SAE | None. The SAE to replace the current one with. If None, will just remove the SAE from this hook point. Defaults to None
+             act_name: The hook_name of the SAE to reset.
+             prev_wrapper: The previous wrapper to restore. If None, will just
+                 remove the SAE from this hook point. Defaults to None.
          """
-         if act_name not in self.acts_to_saes:
-             logging.warning(
+         if act_name not in self._acts_to_saes:
+             logger.warning(
                  f"No SAE is attached to {act_name}. There's nothing to reset."
              )
              return
 
-         current_sae = self.acts_to_saes[act_name]
-         if hasattr(current_sae, "_original_use_error_term"):
-             current_sae.use_error_term = current_sae._original_use_error_term  # type: ignore
-             delattr(current_sae, "_original_use_error_term")
+         # Determine output hook location (different from input for transcoders)
+         output_hook = self._transcoder_output_hooks.pop(act_name, act_name)
+
+         # For transcoders, clear permanent hooks from input hook point
+         if output_hook != act_name:
+             input_hook_point = get_deep_attr(self, act_name)
+             if isinstance(input_hook_point, HookPoint):
+                 input_hook_point.remove_hooks(dir="fwd", including_permanent=True)
+
+         # Reset output hook location
+         set_deep_attr(self, output_hook, HookPoint())
+         del self._acts_to_saes[act_name]
 
-         if prev_sae is not None:
-             set_deep_attr(self, act_name, prev_sae)
-             self.acts_to_saes[act_name] = prev_sae
-         else:
-             set_deep_attr(self, act_name, HookPoint())
-             del self.acts_to_saes[act_name]
+         if prev_wrapper is not None:
+             # Rebuild hook_dict before adding new SAE
+             self.setup()
+             self.add_sae(prev_wrapper.sae, use_error_term=prev_wrapper.use_error_term)
 
      def reset_saes(
          self,
          act_names: str | list[str] | None = None,
-         prev_saes: list[SAE[Any] | None] | None = None,
-     ):
-         """Reset the SAEs attached to the model
+     ) -> None:
+         """Reset the SAEs attached to the model.
 
-         If act_names are provided will just reset SAEs attached to those hooks. Otherwise will reset all SAEs attached to the model.
-         Optionally can provide a list of prev_saes to reset to. This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes).
+         If act_names are provided will just reset SAEs attached to those hooks.
+         Otherwise will reset all SAEs attached to the model.
 
          Args:
-             act_names (str | list[str] | None): The act_names of the SAEs to reset. If None, will reset all SAEs attached to the model. Defaults to None.
-             prev_saes (list[SAE | None] | None): List of SAEs to replace the current ones with. If None, will just remove the SAEs. Defaults to None.
+             act_names: The act_names of the SAEs to reset. If None, will reset
+                 all SAEs attached to the model. Defaults to None.
          """
          if isinstance(act_names, str):
              act_names = [act_names]
          elif act_names is None:
-             act_names = list(self.acts_to_saes.keys())
-
-         if prev_saes:
-             if len(act_names) != len(prev_saes):
-                 raise ValueError("act_names and prev_saes must have the same length")
-         else:
-             prev_saes = [None] * len(act_names)  # type: ignore
+             act_names = list(self._acts_to_saes.keys())
 
-         for act_name, prev_sae in zip(act_names, prev_saes):  # type: ignore
-             self._reset_sae(act_name, prev_sae)
+         for act_name in act_names:
+             self._reset_sae(act_name)
 
          self.setup()
 
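
For reference, a hedged usage sketch of the new attach/reset flow. The model name, release, and SAE id below are placeholders, and `SAE.from_pretrained` is assumed here to return the SAE directly:

    from sae_lens import SAE, HookedSAETransformer

    model = HookedSAETransformer.from_pretrained("gpt2")  # placeholder model
    sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")  # placeholder release/id

    model.add_sae(sae, use_error_term=True)  # preferred over mutating sae.use_error_term
    print(model.acts_to_saes)  # property unwraps _SAEWrapper -> {hook_name: SAE}
    model.reset_saes()         # detaches everything; the prev_saes parameter is gone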
@@ -269,41 +378,32 @@ class HookedSAETransformer(HookedTransformer):
          reset_saes_end: bool = True,
          use_error_term: bool | None = None,
      ):
-         """
-         A context manager for adding temporary SAEs to the model.
-         See HookedTransformer.hooks for a similar context manager for hooks.
-         By default will keep track of previously attached SAEs, and restore them when the context manager exits.
-
-         Example:
-
-         .. code-block:: python
-
-             from transformer_lens import HookedSAETransformer
-             from sae_lens.saes.sae import SAE
-
-             model = HookedSAETransformer.from_pretrained('gpt2-small')
-             sae_cfg = SAEConfig(...)
-             sae = SAE(sae_cfg)
-             with model.saes(saes=[sae]):
-                 spliced_logits = model(text)
+         """A context manager for adding temporary SAEs to the model.
 
+         See HookedTransformer.hooks for a similar context manager for hooks.
+         By default will keep track of previously attached SAEs, and restore
+         them when the context manager exits.
 
          Args:
-             saes (SAE | list[SAE]): SAEs to be attached.
-             reset_saes_end (bool): If True, removes all SAEs added by this context manager when the context manager exits, returning previously attached SAEs to their original state.
-             use_error_term (bool | None): If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.
+             saes: SAEs to be attached.
+             reset_saes_end: If True, removes all SAEs added by this context
+                 manager when the context manager exits, returning previously
+                 attached SAEs to their original state.
+             use_error_term: If provided, will set the use_error_term attribute
+                 of all SAEs attached during this run to this value.
          """
-         act_names_to_reset = []
-         prev_saes = []
+         saes_to_restore: list[tuple[str, _SAEWrapper | None]] = []
          if isinstance(saes, SAE):
              saes = [saes]
          try:
              for sae in saes:
-                 act_names_to_reset.append(sae.cfg.metadata.hook_name)
-                 prev_sae = self.acts_to_saes.get(sae.cfg.metadata.hook_name, None)
-                 prev_saes.append(prev_sae)
+                 act_name = sae.cfg.metadata.hook_name
+                 prev_wrapper = self._acts_to_saes.get(act_name, None)
+                 saes_to_restore.append((act_name, prev_wrapper))
                  self.add_sae(sae, use_error_term=use_error_term)
              yield self
          finally:
              if reset_saes_end:
-                 self.reset_saes(act_names_to_reset, prev_saes)
+                 for act_name, prev_wrapper in saes_to_restore:
+                     self._reset_sae(act_name, prev_wrapper)
+                 self.setup()
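
Since this hunk drops the old RST docstring example, here is a short sketch of the temporary-attach context manager in its new form, continuing the placeholder `model` and `sae` from the sketch above:

    # Attach temporarily; previously attached SAEs (and their settings) are
    # restored when the block exits because reset_saes_end defaults to True.
    with model.saes(saes=[sae], use_error_term=True):
        spliced_logits = model("Hello world")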
sae_lens/analysis/sae_transformer_bridge.py CHANGED
@@ -8,7 +8,11 @@ from transformer_lens.hook_points import HookPoint
  from transformer_lens.model_bridge import TransformerBridge
 
  from sae_lens import logger
- from sae_lens.analysis.hooked_sae_transformer import set_deep_attr
+ from sae_lens.analysis.hooked_sae_transformer import (
+     _SAEWrapper,
+     get_deep_attr,
+     set_deep_attr,
+ )
  from sae_lens.saes.sae import SAE
 
  SingleLoss = torch.Tensor  # Type alias for a single element tensor
@@ -30,11 +34,18 @@ class SAETransformerBridge(TransformerBridge):  # type: ignore[misc,no-untyped-c
      useful for models not natively supported by HookedTransformer, such as Gemma 3.
      """
 
-     acts_to_saes: dict[str, SAE[Any]]
+     _acts_to_saes: dict[str, _SAEWrapper]
 
      def __init__(self, *args: Any, **kwargs: Any):
          super().__init__(*args, **kwargs)
-         self.acts_to_saes = {}
+         self._acts_to_saes = {}
+         # Track output hooks used by transcoders for cleanup
+         self._transcoder_output_hooks: dict[str, str] = {}
+
+     @property
+     def acts_to_saes(self) -> dict[str, SAE[Any]]:
+         """Returns a dict mapping hook names to attached SAEs."""
+         return {name: wrapper.sae for name, wrapper in self._acts_to_saes.items()}
 
      @classmethod
      def boot_transformers(  # type: ignore[override]
@@ -56,7 +67,8 @@ class SAETransformerBridge(TransformerBridge):  # type: ignore[misc,no-untyped-c
          # Convert to our class
          # NOTE: this is super hacky and scary, but I don't know how else to achieve this given TLens' internal code
          bridge.__class__ = cls
-         bridge.acts_to_saes = {}  # type: ignore[attr-defined]
+         bridge._acts_to_saes = {}  # type: ignore[attr-defined]
+         bridge._transcoder_output_hooks = {}  # type: ignore[attr-defined]
          return bridge  # type: ignore[return-value]
 
      def _resolve_hook_name(self, hook_name: str) -> str:
@@ -75,110 +87,129 @@ class SAETransformerBridge(TransformerBridge):  # type: ignore[misc,no-untyped-c
          return resolved if isinstance(resolved, str) else hook_name
 
      def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None) -> None:
-         """Attaches an SAE to the model.
+         """Attaches an SAE or Transcoder to the model.
 
          WARNING: This SAE will be permanently attached until you remove it with
          reset_saes. This function will also overwrite any existing SAE attached
          to the same hook point.
 
          Args:
-             sae: The SAE to attach to the model
-             use_error_term: If provided, will set the use_error_term attribute of
-                 the SAE to this value. Determines whether the SAE returns input
-                 or reconstruction. Defaults to None.
+             sae: The SAE or Transcoder to attach to the model.
+             use_error_term: If True, computes error term so output matches what the
+                 model would have produced without the SAE. This works for both SAEs
+                 (where input==output hook) and transcoders (where they differ).
+                 Defaults to None (uses SAE's existing setting).
          """
-         alias_name = sae.cfg.metadata.hook_name
-         actual_name = self._resolve_hook_name(alias_name)
-
-         # Check if hook exists (either as alias or actual name)
-         if (alias_name not in self.acts_to_saes) and (
-             actual_name not in self._hook_registry
+         input_hook_alias = sae.cfg.metadata.hook_name
+         output_hook_alias = sae.cfg.metadata.hook_name_out or input_hook_alias
+         input_hook_actual = self._resolve_hook_name(input_hook_alias)
+         output_hook_actual = self._resolve_hook_name(output_hook_alias)
+
+         # Check if hooks exist
+         if (input_hook_alias not in self._acts_to_saes) and (
+             input_hook_actual not in self._hook_registry
          ):
              logger.warning(
-                 f"No hook found for {alias_name}. Skipping. "
+                 f"No hook found for {input_hook_alias}. Skipping. "
                  f"Check model._hook_registry for available hooks."
              )
              return
 
-         if use_error_term is not None:
-             if not hasattr(sae, "_original_use_error_term"):
-                 sae._original_use_error_term = sae.use_error_term  # type: ignore[attr-defined]
-             sae.use_error_term = use_error_term
-
-         # Replace hook and update registry
-         set_deep_attr(self, actual_name, sae)
-         self._hook_registry[actual_name] = sae  # type: ignore[assignment]
-         self.acts_to_saes[alias_name] = sae
+         # Check if output hook exists (either as registry entry or already has SAE attached)
+         output_hook_exists = (
+             output_hook_actual in self._hook_registry
+             or input_hook_alias in self._acts_to_saes
+             or any(
+                 v == output_hook_actual for v in self._transcoder_output_hooks.values()
+             )
+         )
+         if not output_hook_exists:
+             logger.warning(f"No hook found for output {output_hook_alias}. Skipping.")
+             return
 
-     def _reset_sae(self, act_name: str, prev_sae: SAE[Any] | None = None) -> None:
+         # Always use wrapper - it handles both SAEs and transcoders uniformly
+         # If use_error_term not specified, respect SAE's existing setting
+         effective_use_error_term = (
+             use_error_term if use_error_term is not None else sae.use_error_term
+         )
+         wrapper = _SAEWrapper(sae, use_error_term=effective_use_error_term)
+
+         # For transcoders (input != output), capture input at input hook
+         if input_hook_alias != output_hook_alias:
+             input_hook_point = get_deep_attr(self, input_hook_actual)
+             if isinstance(input_hook_point, HookPoint):
+                 input_hook_point.add_hook(
+                     lambda tensor, hook: (wrapper.capture_input(tensor), tensor)[1],  # noqa: ARG005
+                     dir="fwd",
+                     is_permanent=True,
+                 )
+             self._transcoder_output_hooks[input_hook_alias] = output_hook_actual
+
+         # Store wrapper in _acts_to_saes and at output hook
+         set_deep_attr(self, output_hook_actual, wrapper)
+         self._hook_registry[output_hook_actual] = wrapper  # type: ignore[assignment]
+         self._acts_to_saes[input_hook_alias] = wrapper
+
+     def _reset_sae(
+         self, act_name: str, prev_wrapper: _SAEWrapper | None = None
+     ) -> None:
          """Resets an SAE that was attached to the model.
 
          By default will remove the SAE from that hook_point.
-         If prev_sae is provided, will replace the current SAE with the provided one.
-         This is mainly used to restore previously attached SAEs after temporarily
-         running with different SAEs (e.g., with run_with_saes).
+         If prev_wrapper is provided, will restore that wrapper's SAE with its settings.
 
          Args:
              act_name: The hook_name of the SAE to reset
-             prev_sae: The SAE to replace the current one with. If None, will just
+             prev_wrapper: The previous wrapper to restore. If None, will just
                  remove the SAE from this hook point. Defaults to None.
          """
-         if act_name not in self.acts_to_saes:
+         if act_name not in self._acts_to_saes:
              logger.warning(
                  f"No SAE is attached to {act_name}. There's nothing to reset."
              )
             return
 
          actual_name = self._resolve_hook_name(act_name)
-         current_sae = self.acts_to_saes[act_name]
-
-         if hasattr(current_sae, "_original_use_error_term"):
-             current_sae.use_error_term = current_sae._original_use_error_term  # type: ignore[attr-defined]
-             delattr(current_sae, "_original_use_error_term")
-
-         if prev_sae is not None:
-             set_deep_attr(self, actual_name, prev_sae)
-             self._hook_registry[actual_name] = prev_sae  # type: ignore[assignment]
-             self.acts_to_saes[act_name] = prev_sae
-         else:
-             new_hook = HookPoint()
-             new_hook.name = actual_name
-             set_deep_attr(self, actual_name, new_hook)
-             self._hook_registry[actual_name] = new_hook
-             del self.acts_to_saes[act_name]
+
+         # Determine output hook location (different from input for transcoders)
+         output_hook = self._transcoder_output_hooks.pop(act_name, actual_name)
+
+         # For transcoders, clear permanent hooks from input hook point
+         if output_hook != actual_name:
+             input_hook_point = get_deep_attr(self, actual_name)
+             if isinstance(input_hook_point, HookPoint):
+                 input_hook_point.remove_hooks(dir="fwd", including_permanent=True)
+
+         # Reset output hook location
+         new_hook = HookPoint()
+         new_hook.name = output_hook
+         set_deep_attr(self, output_hook, new_hook)
+         self._hook_registry[output_hook] = new_hook
+         del self._acts_to_saes[act_name]
+
+         if prev_wrapper is not None:
+             self.add_sae(prev_wrapper.sae, use_error_term=prev_wrapper.use_error_term)
 
      def reset_saes(
          self,
          act_names: str | list[str] | None = None,
-         prev_saes: list[SAE[Any] | None] | None = None,
      ) -> None:
          """Reset the SAEs attached to the model.
 
          If act_names are provided will just reset SAEs attached to those hooks.
          Otherwise will reset all SAEs attached to the model.
-         Optionally can provide a list of prev_saes to reset to. This is mainly
-         used to restore previously attached SAEs after temporarily running with
-         different SAEs (e.g., with run_with_saes).
 
          Args:
              act_names: The act_names of the SAEs to reset. If None, will reset all
                  SAEs attached to the model. Defaults to None.
-             prev_saes: List of SAEs to replace the current ones with. If None, will
-                 just remove the SAEs. Defaults to None.
          """
          if isinstance(act_names, str):
              act_names = [act_names]
          elif act_names is None:
-             act_names = list(self.acts_to_saes.keys())
-
-         if prev_saes:
-             if len(act_names) != len(prev_saes):
-                 raise ValueError("act_names and prev_saes must have the same length")
-         else:
-             prev_saes = [None] * len(act_names)  # type: ignore[assignment]
+             act_names = list(self._acts_to_saes.keys())
 
-         for act_name, prev_sae in zip(act_names, prev_saes):  # type: ignore[arg-type]
-             self._reset_sae(act_name, prev_sae)
+         for act_name in act_names:
+             self._reset_sae(act_name)
 
      def run_with_saes(
          self,
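
The transcoder branch above is driven entirely by the SAE's metadata. A hedged sketch of that case (the hook names below are made-up examples, not values from the diff):

    # Hypothetical transcoder metadata: distinct input and output hooks.
    sae.cfg.metadata.hook_name = "blocks.0.ln2.hook_normalized"  # read input here
    sae.cfg.metadata.hook_name_out = "blocks.0.hook_mlp_out"     # write output here

    model.add_sae(sae)
    # A permanent forward hook at hook_name calls wrapper.capture_input(tensor);
    # the wrapper itself sits at hook_name_out and replaces that activation.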
@@ -310,20 +341,20 @@ class SAETransformerBridge(TransformerBridge):  # type: ignore[misc,no-untyped-c
              use_error_term: If provided, will set the use_error_term attribute of
                  all SAEs attached during this run to this value. Defaults to None.
          """
-         act_names_to_reset: list[str] = []
-         prev_saes: list[SAE[Any] | None] = []
+         saes_to_restore: list[tuple[str, _SAEWrapper | None]] = []
          if isinstance(saes, SAE):
              saes = [saes]
          try:
              for sae in saes:
-                 act_names_to_reset.append(sae.cfg.metadata.hook_name)
-                 prev_sae = self.acts_to_saes.get(sae.cfg.metadata.hook_name, None)
-                 prev_saes.append(prev_sae)
+                 act_name = sae.cfg.metadata.hook_name
+                 prev_wrapper = self._acts_to_saes.get(act_name, None)
+                 saes_to_restore.append((act_name, prev_wrapper))
                  self.add_sae(sae, use_error_term=use_error_term)
              yield self
          finally:
              if reset_saes_end:
-                 self.reset_saes(act_names_to_reset, prev_saes)
+                 for act_name, prev_wrapper in saes_to_restore:
+                     self._reset_sae(act_name, prev_wrapper)
 
      @property
      def hook_dict(self) -> dict[str, HookPoint]:
@@ -337,9 +368,9 @@ class SAETransformerBridge(TransformerBridge):  # type: ignore[misc,no-untyped-c
          hooks: dict[str, HookPoint] = {}
 
          for name, hook_or_sae in self._hook_registry.items():
-             if isinstance(hook_or_sae, SAE):
+             if isinstance(hook_or_sae, _SAEWrapper):
                  # Include SAE's internal hooks with full path names
-                 for sae_hook_name, sae_hook in hook_or_sae.hook_dict.items():
+                 for sae_hook_name, sae_hook in hook_or_sae.sae.hook_dict.items():
                      full_name = f"{name}.{sae_hook_name}"
                      hooks[full_name] = sae_hook
              else:
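
The flat hook naming this preserves shows up in cached activations. A sketch, assuming the placeholder setup from earlier (the exact hook path depends on which SAE is attached):

    _, cache = model.run_with_cache("Hello")
    # Hook names stay flat because _SAEWrapper copies the SAE's hooks onto itself:
    acts = cache["blocks.8.hook_resid_pre.hook_sae_acts_post"]  # not "...sae.hook_sae_acts_post"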
sae_lens/saes/sae.py CHANGED
@@ -229,7 +229,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
      cfg: T_SAE_CONFIG
      dtype: torch.dtype
      device: torch.device
-     use_error_term: bool
+     _use_error_term: bool
 
      # For type checking only - don't provide default values
      # These will be initialized by subclasses
@@ -254,7 +254,9 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
          self.dtype = str_to_dtype(cfg.dtype)
          self.device = torch.device(cfg.device)
-         self.use_error_term = use_error_term
+         self._use_error_term = False  # Set directly to avoid warning during init
+         if use_error_term:
+             self.use_error_term = True  # Use property setter to trigger warning
 
          # Set up activation function
          self.activation_fn = self.get_activation_fn()
@@ -281,6 +283,22 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
          self.setup()  # Required for HookedRootModule
 
+     @property
+     def use_error_term(self) -> bool:
+         return self._use_error_term
+
+     @use_error_term.setter
+     def use_error_term(self, value: bool) -> None:
+         if value and not self._use_error_term:
+             warnings.warn(
+                 "Setting use_error_term directly on SAE is deprecated. "
+                 "Use HookedSAETransformer.add_sae(sae, use_error_term=True) instead. "
+                 "This will be removed in a future version.",
+                 DeprecationWarning,
+                 stacklevel=2,
+             )
+         self._use_error_term = value
+
      @torch.no_grad()
      def fold_activation_norm_scaling_factor(self, scaling_factor: float):
          self.W_enc.data *= scaling_factor  # type: ignore
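
The deprecation path added above can be observed directly. A minimal sketch, assuming `sae` is any constructed SAE instance with `use_error_term` currently False:

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        sae.use_error_term = True  # now routed through the new property setter
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)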
{sae_lens-6.33.0.dist-info → sae_lens-6.34.1.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sae-lens
- Version: 6.33.0
+ Version: 6.34.1
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
  License: MIT
  License-File: LICENSE
{sae_lens-6.33.0.dist-info → sae_lens-6.34.1.dist-info}/RECORD RENAMED
@@ -1,9 +1,9 @@
- sae_lens/__init__.py,sha256=gHaxlySzLskrAUg2oUZ3aOpnI3U_AVIHce-agGJL9rI,5168
+ sae_lens/__init__.py,sha256=U5I6dWw2_x6ZXwL8BF72vB0HGPEUt-_BzHq-btLq2pw,5168
  sae_lens/analysis/__init__.py,sha256=FZExlMviNwWR7OGUSGRbd0l-yUDGSp80gglI_ivILrY,412
  sae_lens/analysis/compat.py,sha256=cgE3nhFcJTcuhppxbL71VanJS7YqVEOefuneB5eOaPw,538
- sae_lens/analysis/hooked_sae_transformer.py,sha256=LpnjxSAcItqqXA4SJyZuxY4Ki0UOuWV683wg9laYAsY,14050
+ sae_lens/analysis/hooked_sae_transformer.py,sha256=-LY9CKYEziSTt-H7MeLdTx6ErfvMqgNEF709wV7tEs4,17826
  sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
- sae_lens/analysis/sae_transformer_bridge.py,sha256=xpJRRcB0g47EOQcmNCwMyrJJsbqMsGxVViDrV6C3upU,14916
+ sae_lens/analysis/sae_transformer_bridge.py,sha256=y_ZdvaxuUM_-7ywSCeFl6f6cq1_FiqRstMAo8mxZtYE,16191
  sae_lens/cache_activations_runner.py,sha256=TjqNWIc46Nw09jHWFjzQzgzG5wdu_87Ahe-iFjI5_0Q,13117
  sae_lens/config.py,sha256=V0BXV8rvpbm5YuVukow9FURPpdyE4HSflbdymAo0Ycg,31205
  sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
@@ -22,7 +22,7 @@ sae_lens/saes/gated_sae.py,sha256=V_2ZNlV4gRD-rX5JSx1xqY7idT8ChfdQ5yxWDdu_6hg,88
  sae_lens/saes/jumprelu_sae.py,sha256=miiF-xI_yXdV9EkKjwAbU9zSMsx9KtKCz5YdXEzkN8g,13313
  sae_lens/saes/matching_pursuit_sae.py,sha256=08_G9p1YMLnE5qZVCPp6gll-iG6nHRbMMASf4_bkFt8,13207
  sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
- sae_lens/saes/sae.py,sha256=wkwqzNragj-1189cV52S3_XeRtEgBd2ZNwvL2EsKkWw,39429
+ sae_lens/saes/sae.py,sha256=2FHhLoTZYOJUGjkWH7eF2EAc1fPxQsF4KYr5TZ8IUIU,40155
  sae_lens/saes/standard_sae.py,sha256=_hldNZkFPAf9VGrxouR1-tN8T2OEk8IkWBcXoatrC1o,5749
  sae_lens/saes/temporal_sae.py,sha256=S44sPddVj2xujA02CC8gT1tG0in7c_CSAhspu9FHbaA,13273
  sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
@@ -49,7 +49,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
  sae_lens/util.py,sha256=oIMoeyEP2IzcPFmRbKUzOAycgEyMcOasGeO_BGVZbc4,4846
- sae_lens-6.33.0.dist-info/METADATA,sha256=X6XqngWTNEsfdaPPWXxtF8Kvdp8fAk8i68sfRtDb2xo,6566
- sae_lens-6.33.0.dist-info/WHEEL,sha256=3ny-bZhpXrU6vSQ1UPG34FoxZBp3lVcvK0LkgUz6VLk,88
- sae_lens-6.33.0.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
- sae_lens-6.33.0.dist-info/RECORD,,
+ sae_lens-6.34.1.dist-info/METADATA,sha256=vAlbWT90NggKoxjII6dnEj7wlC-PmuBn8zzqL6r8dRg,6566
+ sae_lens-6.34.1.dist-info/WHEEL,sha256=kJCRJT_g0adfAJzTx2GUMmS80rTJIVHRCfG0DQgLq3o,88
+ sae_lens-6.34.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+ sae_lens-6.34.1.dist-info/RECORD,,
{sae_lens-6.33.0.dist-info → sae_lens-6.34.1.dist-info}/WHEEL RENAMED
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: poetry-core 2.3.0
+ Generator: poetry-core 2.3.1
  Root-Is-Purelib: true
  Tag: py3-none-any