sae-lens 6.32.1-py3-none-any.whl → 6.34.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.32.1"
2
+ __version__ = "6.34.1"
3
3
 
4
4
  import logging
5
5
 
sae_lens/analysis/hooked_sae_transformer.py CHANGED
@@ -1,13 +1,14 @@
1
- import logging
2
1
  from contextlib import contextmanager
3
2
  from typing import Any, Callable
4
3
 
5
4
  import torch
5
+ from torch import nn
6
6
  from transformer_lens.ActivationCache import ActivationCache
7
7
  from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
8
8
  from transformer_lens.hook_points import HookPoint # Hooking utilities
9
9
  from transformer_lens.HookedTransformer import HookedTransformer
10
10
 
11
+ from sae_lens import logger
11
12
  from sae_lens.saes.sae import SAE
12
13
 
13
14
  SingleLoss = torch.Tensor # Type alias for a single element tensor
@@ -15,6 +16,72 @@ LossPerToken = torch.Tensor
15
16
  Loss = SingleLoss | LossPerToken
16
17
 
17
18
 
19
+ class _SAEWrapper(nn.Module):
20
+ """Wrapper for SAE/Transcoder that handles error term and hook coordination.
21
+
22
+ For SAEs (input_hook == output_hook), _captured_input stays None and we use
23
+ the forward argument directly. For transcoders, _captured_input is set at
24
+ the input hook via capture_input().
25
+
26
+ Implementation Note:
27
+ The SAE is stored in __dict__ directly rather than as a registered submodule.
28
+ This is intentional: PyTorch's module registration would add a ".sae." prefix
29
+ to all hook names in the cache (e.g., "blocks.0.hook_mlp_out.sae.hook_sae_input"
30
+ instead of "blocks.0.hook_mlp_out.hook_sae_input"). By storing in __dict__ and
31
+ copying hooks directly to the wrapper, we preserve the expected cache paths
32
+ for backwards compatibility.
33
+ """
34
+
35
+ def __init__(self, sae: SAE[Any], use_error_term: bool = False):
36
+ super().__init__()
37
+ # Store SAE in __dict__ to avoid registering as submodule. This keeps cache
38
+ # paths clean by avoiding a ".sae." prefix on hook names. See class docstring.
39
+ self.__dict__["_sae"] = sae
40
+ # Copy SAE's hooks directly to wrapper so they appear at the right path
41
+ for name, hook in sae.hook_dict.items():
42
+ setattr(self, name, hook)
43
+ self.use_error_term = use_error_term
44
+ self._captured_input: torch.Tensor | None = None
45
+
46
+ @property
47
+ def sae(self) -> SAE[Any]:
48
+ return self.__dict__["_sae"]
49
+
50
+ def capture_input(self, x: torch.Tensor) -> None:
51
+ """Capture input at input hook (for transcoders).
52
+
53
+ Note: We don't clone the tensor here - the input should not be modified
54
+ in-place between capture and use, and avoiding clone preserves memory.
55
+ """
56
+ self._captured_input = x
57
+
58
+ def forward(self, original_output: torch.Tensor) -> torch.Tensor:
59
+ """Run SAE/transcoder at output hook location."""
60
+ # For SAE: use original_output as input (same hook for input/output)
61
+ # For transcoder: use captured input from earlier hook
62
+ sae_input = (
63
+ self._captured_input
64
+ if self._captured_input is not None
65
+ else original_output
66
+ )
67
+
68
+ # Temporarily disable SAE's internal use_error_term - we handle it here
69
+ # Use _use_error_term directly to avoid triggering deprecation warning
70
+ sae_use_error_term = self.sae._use_error_term
71
+ self.sae._use_error_term = False
72
+ try:
73
+ sae_out = self.sae(sae_input)
74
+
75
+ if self.use_error_term:
76
+ error = original_output - sae_out.detach()
77
+ sae_out = sae_out + error
78
+
79
+ return sae_out
80
+ finally:
81
+ self.sae._use_error_term = sae_use_error_term
82
+ self._captured_input = None
83
+
84
+
18
85
  def get_deep_attr(obj: Any, path: str):
19
86
  """Helper function to get a nested attribute from a object.
20
87
  In practice used to access HookedTransformer HookPoints (eg model.blocks[0].attn.hook_z)
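A minimal, self-contained sketch (not part of the diff) of the naming behaviour the wrapper's docstring describes: a normally registered submodule prefixes its children's names, while storing the module in __dict__ and copying its hooks onto the wrapper keeps the original paths. Inner and nn.Identity below are stand-ins for an SAE and its HookPoints.

from torch import nn

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.hook_sae_input = nn.Identity()  # stand-in for a HookPoint

class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        self.sae = Inner()  # normal registration adds a "sae." prefix to child names

class DictStored(nn.Module):
    def __init__(self):
        super().__init__()
        self.__dict__["_sae"] = Inner()  # hidden from nn.Module registration
        self.hook_sae_input = self.__dict__["_sae"].hook_sae_input  # hook copied onto the wrapper

print([n for n, _ in Registered().named_modules()])  # ['', 'sae', 'sae.hook_sae_input']
print([n for n, _ in DictStored().named_modules()])  # ['', 'hook_sae_input']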
@@ -78,88 +145,130 @@ class HookedSAETransformer(HookedTransformer):
78
145
  add_hook_in_to_mlp(block.mlp) # type: ignore
79
146
  self.setup()
80
147
 
81
- self.acts_to_saes: dict[str, SAE] = {} # type: ignore
148
+ self._acts_to_saes: dict[str, _SAEWrapper] = {}
149
+ # Track output hooks used by transcoders for cleanup
150
+ self._transcoder_output_hooks: dict[str, str] = {}
151
+
152
+ @property
153
+ def acts_to_saes(self) -> dict[str, SAE[Any]]:
154
+ """Returns a dict mapping hook names to attached SAEs."""
155
+ return {name: wrapper.sae for name, wrapper in self._acts_to_saes.items()}
82
156
 
83
157
  def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None):
84
- """Attaches an SAE to the model
158
+ """Attaches an SAE or Transcoder to the model.
85
159
 
86
- WARNING: This sae will be permanantly attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.
160
+ WARNING: This SAE will be permanently attached until you remove it with
161
+ reset_saes. This function will also overwrite any existing SAE attached
162
+ to the same hook point.
87
163
 
88
164
  Args:
89
- sae: SparseAutoencoderBase. The SAE to attach to the model
90
- use_error_term: (bool | None) If provided, will set the use_error_term attribute of the SAE to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
165
+ sae: The SAE or Transcoder to attach to the model.
166
+ use_error_term: If True, computes error term so output matches what the
167
+ model would have produced without the SAE. This works for both SAEs
168
+ (where input==output hook) and transcoders (where they differ).
169
+ Defaults to None (uses SAE's existing setting).
91
170
  """
92
- act_name = sae.cfg.metadata.hook_name
93
- if (act_name not in self.acts_to_saes) and (act_name not in self.hook_dict):
94
- logging.warning(
95
- f"No hook found for {act_name}. Skipping. Check model.hook_dict for available hooks."
171
+ input_hook = sae.cfg.metadata.hook_name
172
+ output_hook = sae.cfg.metadata.hook_name_out or input_hook
173
+
174
+ if (input_hook not in self._acts_to_saes) and (
175
+ input_hook not in self.hook_dict
176
+ ):
177
+ logger.warning(
178
+ f"No hook found for {input_hook}. Skipping. Check model.hook_dict for available hooks."
96
179
  )
97
180
  return
98
181
 
99
- if use_error_term is not None:
100
- if not hasattr(sae, "_original_use_error_term"):
101
- sae._original_use_error_term = sae.use_error_term # type: ignore
102
- sae.use_error_term = use_error_term
103
- self.acts_to_saes[act_name] = sae
104
- set_deep_attr(self, act_name, sae)
182
+ # Check if output hook exists (either as hook_dict entry or already has SAE attached)
183
+ output_hook_exists = (
184
+ output_hook in self.hook_dict
185
+ or output_hook in self._acts_to_saes
186
+ or any(v == output_hook for v in self._transcoder_output_hooks.values())
187
+ )
188
+ if not output_hook_exists:
189
+ logger.warning(f"No hook found for output {output_hook}. Skipping.")
190
+ return
191
+
192
+ # Always use wrapper - it handles both SAEs and transcoders uniformly
193
+ # If use_error_term not specified, respect SAE's existing setting
194
+ effective_use_error_term = (
195
+ use_error_term if use_error_term is not None else sae.use_error_term
196
+ )
197
+ wrapper = _SAEWrapper(sae, use_error_term=effective_use_error_term)
198
+
199
+ # For transcoders (input != output), capture input at input hook
200
+ if input_hook != output_hook:
201
+ input_hook_point = get_deep_attr(self, input_hook)
202
+ if isinstance(input_hook_point, HookPoint):
203
+ input_hook_point.add_hook(
204
+ lambda tensor, hook: (wrapper.capture_input(tensor), tensor)[1], # noqa: ARG005
205
+ dir="fwd",
206
+ is_permanent=True,
207
+ )
208
+ self._transcoder_output_hooks[input_hook] = output_hook
209
+
210
+ # Store wrapper in _acts_to_saes and at output hook
211
+ self._acts_to_saes[input_hook] = wrapper
212
+ set_deep_attr(self, output_hook, wrapper)
105
213
  self.setup()
106
214
 
107
- def _reset_sae(self, act_name: str, prev_sae: SAE[Any] | None = None):
108
- """Resets an SAE that was attached to the model
215
+ def _reset_sae(
216
+ self, act_name: str, prev_wrapper: _SAEWrapper | None = None
217
+ ) -> None:
218
+ """Resets an SAE that was attached to the model.
109
219
 
110
220
  By default will remove the SAE from that hook_point.
111
- If prev_sae is provided, will replace the current SAE with the provided one.
112
- This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes)
221
+ If prev_wrapper is provided, will restore that wrapper's SAE with its settings.
113
222
 
114
223
  Args:
115
- act_name: str. The hook_name of the SAE to reset
116
- prev_sae: SAE | None. The SAE to replace the current one with. If None, will just remove the SAE from this hook point. Defaults to None
224
+ act_name: The hook_name of the SAE to reset.
225
+ prev_wrapper: The previous wrapper to restore. If None, will just
226
+ remove the SAE from this hook point. Defaults to None.
117
227
  """
118
- if act_name not in self.acts_to_saes:
119
- logging.warning(
228
+ if act_name not in self._acts_to_saes:
229
+ logger.warning(
120
230
  f"No SAE is attached to {act_name}. There's nothing to reset."
121
231
  )
122
232
  return
123
233
 
124
- current_sae = self.acts_to_saes[act_name]
125
- if hasattr(current_sae, "_original_use_error_term"):
126
- current_sae.use_error_term = current_sae._original_use_error_term # type: ignore
127
- delattr(current_sae, "_original_use_error_term")
234
+ # Determine output hook location (different from input for transcoders)
235
+ output_hook = self._transcoder_output_hooks.pop(act_name, act_name)
236
+
237
+ # For transcoders, clear permanent hooks from input hook point
238
+ if output_hook != act_name:
239
+ input_hook_point = get_deep_attr(self, act_name)
240
+ if isinstance(input_hook_point, HookPoint):
241
+ input_hook_point.remove_hooks(dir="fwd", including_permanent=True)
242
+
243
+ # Reset output hook location
244
+ set_deep_attr(self, output_hook, HookPoint())
245
+ del self._acts_to_saes[act_name]
128
246
 
129
- if prev_sae is not None:
130
- set_deep_attr(self, act_name, prev_sae)
131
- self.acts_to_saes[act_name] = prev_sae
132
- else:
133
- set_deep_attr(self, act_name, HookPoint())
134
- del self.acts_to_saes[act_name]
247
+ if prev_wrapper is not None:
248
+ # Rebuild hook_dict before adding new SAE
249
+ self.setup()
250
+ self.add_sae(prev_wrapper.sae, use_error_term=prev_wrapper.use_error_term)
135
251
 
136
252
  def reset_saes(
137
253
  self,
138
254
  act_names: str | list[str] | None = None,
139
- prev_saes: list[SAE[Any] | None] | None = None,
140
- ):
141
- """Reset the SAEs attached to the model
255
+ ) -> None:
256
+ """Reset the SAEs attached to the model.
142
257
 
143
- If act_names are provided will just reset SAEs attached to those hooks. Otherwise will reset all SAEs attached to the model.
144
- Optionally can provide a list of prev_saes to reset to. This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes).
258
+ If act_names are provided will just reset SAEs attached to those hooks.
259
+ Otherwise will reset all SAEs attached to the model.
145
260
 
146
261
  Args:
147
- act_names (str | list[str] | None): The act_names of the SAEs to reset. If None, will reset all SAEs attached to the model. Defaults to None.
148
- prev_saes (list[SAE | None] | None): List of SAEs to replace the current ones with. If None, will just remove the SAEs. Defaults to None.
262
+ act_names: The act_names of the SAEs to reset. If None, will reset
263
+ all SAEs attached to the model. Defaults to None.
149
264
  """
150
265
  if isinstance(act_names, str):
151
266
  act_names = [act_names]
152
267
  elif act_names is None:
153
- act_names = list(self.acts_to_saes.keys())
154
-
155
- if prev_saes:
156
- if len(act_names) != len(prev_saes):
157
- raise ValueError("act_names and prev_saes must have the same length")
158
- else:
159
- prev_saes = [None] * len(act_names) # type: ignore
268
+ act_names = list(self._acts_to_saes.keys())
160
269
 
161
- for act_name, prev_sae in zip(act_names, prev_saes): # type: ignore
162
- self._reset_sae(act_name, prev_sae)
270
+ for act_name in act_names:
271
+ self._reset_sae(act_name)
163
272
 
164
273
  self.setup()
165
274
 
@@ -269,41 +378,32 @@ class HookedSAETransformer(HookedTransformer):
269
378
  reset_saes_end: bool = True,
270
379
  use_error_term: bool | None = None,
271
380
  ):
272
- """
273
- A context manager for adding temporary SAEs to the model.
274
- See HookedTransformer.hooks for a similar context manager for hooks.
275
- By default will keep track of previously attached SAEs, and restore them when the context manager exits.
276
-
277
- Example:
278
-
279
- .. code-block:: python
280
-
281
- from transformer_lens import HookedSAETransformer
282
- from sae_lens.saes.sae import SAE
283
-
284
- model = HookedSAETransformer.from_pretrained('gpt2-small')
285
- sae_cfg = SAEConfig(...)
286
- sae = SAE(sae_cfg)
287
- with model.saes(saes=[sae]):
288
- spliced_logits = model(text)
381
+ """A context manager for adding temporary SAEs to the model.
289
382
 
383
+ See HookedTransformer.hooks for a similar context manager for hooks.
384
+ By default will keep track of previously attached SAEs, and restore
385
+ them when the context manager exits.
290
386
 
291
387
  Args:
292
- saes (SAE | list[SAE]): SAEs to be attached.
293
- reset_saes_end (bool): If True, removes all SAEs added by this context manager when the context manager exits, returning previously attached SAEs to their original state.
294
- use_error_term (bool | None): If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.
388
+ saes: SAEs to be attached.
389
+ reset_saes_end: If True, removes all SAEs added by this context
390
+ manager when the context manager exits, returning previously
391
+ attached SAEs to their original state.
392
+ use_error_term: If provided, will set the use_error_term attribute
393
+ of all SAEs attached during this run to this value.
295
394
  """
296
- act_names_to_reset = []
297
- prev_saes = []
395
+ saes_to_restore: list[tuple[str, _SAEWrapper | None]] = []
298
396
  if isinstance(saes, SAE):
299
397
  saes = [saes]
300
398
  try:
301
399
  for sae in saes:
302
- act_names_to_reset.append(sae.cfg.metadata.hook_name)
303
- prev_sae = self.acts_to_saes.get(sae.cfg.metadata.hook_name, None)
304
- prev_saes.append(prev_sae)
400
+ act_name = sae.cfg.metadata.hook_name
401
+ prev_wrapper = self._acts_to_saes.get(act_name, None)
402
+ saes_to_restore.append((act_name, prev_wrapper))
305
403
  self.add_sae(sae, use_error_term=use_error_term)
306
404
  yield self
307
405
  finally:
308
406
  if reset_saes_end:
309
- self.reset_saes(act_names_to_reset, prev_saes)
407
+ for act_name, prev_wrapper in saes_to_restore:
408
+ self._reset_sae(act_name, prev_wrapper)
409
+ self.setup()
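A hedged usage sketch of the attach / restore API changed in this file (not part of the package). The model name, release and SAE id are placeholders, and SAE.from_pretrained's exact signature may differ across sae_lens versions; the last few lines restore, in hedged form, the docstring example removed above.

from sae_lens.analysis.hooked_sae_transformer import HookedSAETransformer
from sae_lens.saes.sae import SAE

model = HookedSAETransformer.from_pretrained("gpt2")      # placeholder model name
sae = SAE.from_pretrained("some-release", "some-sae-id")  # placeholder release / id

# permanent attachment: the wrapper computes the error term at the output hook
model.add_sae(sae, use_error_term=True)
logits, cache = model.run_with_cache("Hello world")
# SAE hooks keep their un-prefixed paths, e.g. f"{sae.cfg.metadata.hook_name}.hook_sae_input"
model.reset_saes()

# temporary attachment: previously attached SAEs are restored when the context exits
with model.saes(saes=[sae], use_error_term=True):
    spliced_logits = model("some text")
clean_logits = model("some text")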
sae_lens/analysis/sae_transformer_bridge.py CHANGED
@@ -8,7 +8,11 @@ from transformer_lens.hook_points import HookPoint
8
8
  from transformer_lens.model_bridge import TransformerBridge
9
9
 
10
10
  from sae_lens import logger
11
- from sae_lens.analysis.hooked_sae_transformer import set_deep_attr
11
+ from sae_lens.analysis.hooked_sae_transformer import (
12
+ _SAEWrapper,
13
+ get_deep_attr,
14
+ set_deep_attr,
15
+ )
12
16
  from sae_lens.saes.sae import SAE
13
17
 
14
18
  SingleLoss = torch.Tensor # Type alias for a single element tensor
@@ -30,11 +34,18 @@ class SAETransformerBridge(TransformerBridge): # type: ignore[misc,no-untyped-c
30
34
  useful for models not natively supported by HookedTransformer, such as Gemma 3.
31
35
  """
32
36
 
33
- acts_to_saes: dict[str, SAE[Any]]
37
+ _acts_to_saes: dict[str, _SAEWrapper]
34
38
 
35
39
  def __init__(self, *args: Any, **kwargs: Any):
36
40
  super().__init__(*args, **kwargs)
37
- self.acts_to_saes = {}
41
+ self._acts_to_saes = {}
42
+ # Track output hooks used by transcoders for cleanup
43
+ self._transcoder_output_hooks: dict[str, str] = {}
44
+
45
+ @property
46
+ def acts_to_saes(self) -> dict[str, SAE[Any]]:
47
+ """Returns a dict mapping hook names to attached SAEs."""
48
+ return {name: wrapper.sae for name, wrapper in self._acts_to_saes.items()}
38
49
 
39
50
  @classmethod
40
51
  def boot_transformers( # type: ignore[override]
@@ -56,7 +67,8 @@ class SAETransformerBridge(TransformerBridge): # type: ignore[misc,no-untyped-c
56
67
  # Convert to our class
57
68
  # NOTE: this is super hacky and scary, but I don't know how else to achieve this given TLens' internal code
58
69
  bridge.__class__ = cls
59
- bridge.acts_to_saes = {} # type: ignore[attr-defined]
70
+ bridge._acts_to_saes = {} # type: ignore[attr-defined]
71
+ bridge._transcoder_output_hooks = {} # type: ignore[attr-defined]
60
72
  return bridge # type: ignore[return-value]
61
73
 
62
74
  def _resolve_hook_name(self, hook_name: str) -> str:
@@ -75,110 +87,129 @@ class SAETransformerBridge(TransformerBridge): # type: ignore[misc,no-untyped-c
75
87
  return resolved if isinstance(resolved, str) else hook_name
76
88
 
77
89
  def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None) -> None:
78
- """Attaches an SAE to the model.
90
+ """Attaches an SAE or Transcoder to the model.
79
91
 
80
92
  WARNING: This SAE will be permanently attached until you remove it with
81
93
  reset_saes. This function will also overwrite any existing SAE attached
82
94
  to the same hook point.
83
95
 
84
96
  Args:
85
- sae: The SAE to attach to the model
86
- use_error_term: If provided, will set the use_error_term attribute of
87
- the SAE to this value. Determines whether the SAE returns input
88
- or reconstruction. Defaults to None.
97
+ sae: The SAE or Transcoder to attach to the model.
98
+ use_error_term: If True, computes error term so output matches what the
99
+ model would have produced without the SAE. This works for both SAEs
100
+ (where input==output hook) and transcoders (where they differ).
101
+ Defaults to None (uses SAE's existing setting).
89
102
  """
90
- alias_name = sae.cfg.metadata.hook_name
91
- actual_name = self._resolve_hook_name(alias_name)
92
-
93
- # Check if hook exists (either as alias or actual name)
94
- if (alias_name not in self.acts_to_saes) and (
95
- actual_name not in self._hook_registry
103
+ input_hook_alias = sae.cfg.metadata.hook_name
104
+ output_hook_alias = sae.cfg.metadata.hook_name_out or input_hook_alias
105
+ input_hook_actual = self._resolve_hook_name(input_hook_alias)
106
+ output_hook_actual = self._resolve_hook_name(output_hook_alias)
107
+
108
+ # Check if hooks exist
109
+ if (input_hook_alias not in self._acts_to_saes) and (
110
+ input_hook_actual not in self._hook_registry
96
111
  ):
97
112
  logger.warning(
98
- f"No hook found for {alias_name}. Skipping. "
113
+ f"No hook found for {input_hook_alias}. Skipping. "
99
114
  f"Check model._hook_registry for available hooks."
100
115
  )
101
116
  return
102
117
 
103
- if use_error_term is not None:
104
- if not hasattr(sae, "_original_use_error_term"):
105
- sae._original_use_error_term = sae.use_error_term # type: ignore[attr-defined]
106
- sae.use_error_term = use_error_term
107
-
108
- # Replace hook and update registry
109
- set_deep_attr(self, actual_name, sae)
110
- self._hook_registry[actual_name] = sae # type: ignore[assignment]
111
- self.acts_to_saes[alias_name] = sae
118
+ # Check if output hook exists (either as registry entry or already has SAE attached)
119
+ output_hook_exists = (
120
+ output_hook_actual in self._hook_registry
121
+ or input_hook_alias in self._acts_to_saes
122
+ or any(
123
+ v == output_hook_actual for v in self._transcoder_output_hooks.values()
124
+ )
125
+ )
126
+ if not output_hook_exists:
127
+ logger.warning(f"No hook found for output {output_hook_alias}. Skipping.")
128
+ return
112
129
 
113
- def _reset_sae(self, act_name: str, prev_sae: SAE[Any] | None = None) -> None:
130
+ # Always use wrapper - it handles both SAEs and transcoders uniformly
131
+ # If use_error_term not specified, respect SAE's existing setting
132
+ effective_use_error_term = (
133
+ use_error_term if use_error_term is not None else sae.use_error_term
134
+ )
135
+ wrapper = _SAEWrapper(sae, use_error_term=effective_use_error_term)
136
+
137
+ # For transcoders (input != output), capture input at input hook
138
+ if input_hook_alias != output_hook_alias:
139
+ input_hook_point = get_deep_attr(self, input_hook_actual)
140
+ if isinstance(input_hook_point, HookPoint):
141
+ input_hook_point.add_hook(
142
+ lambda tensor, hook: (wrapper.capture_input(tensor), tensor)[1], # noqa: ARG005
143
+ dir="fwd",
144
+ is_permanent=True,
145
+ )
146
+ self._transcoder_output_hooks[input_hook_alias] = output_hook_actual
147
+
148
+ # Store wrapper in _acts_to_saes and at output hook
149
+ set_deep_attr(self, output_hook_actual, wrapper)
150
+ self._hook_registry[output_hook_actual] = wrapper # type: ignore[assignment]
151
+ self._acts_to_saes[input_hook_alias] = wrapper
152
+
153
+ def _reset_sae(
154
+ self, act_name: str, prev_wrapper: _SAEWrapper | None = None
155
+ ) -> None:
114
156
  """Resets an SAE that was attached to the model.
115
157
 
116
158
  By default will remove the SAE from that hook_point.
117
- If prev_sae is provided, will replace the current SAE with the provided one.
118
- This is mainly used to restore previously attached SAEs after temporarily
119
- running with different SAEs (e.g., with run_with_saes).
159
+ If prev_wrapper is provided, will restore that wrapper's SAE with its settings.
120
160
 
121
161
  Args:
122
162
  act_name: The hook_name of the SAE to reset
123
- prev_sae: The SAE to replace the current one with. If None, will just
163
+ prev_wrapper: The previous wrapper to restore. If None, will just
124
164
  remove the SAE from this hook point. Defaults to None.
125
165
  """
126
- if act_name not in self.acts_to_saes:
166
+ if act_name not in self._acts_to_saes:
127
167
  logger.warning(
128
168
  f"No SAE is attached to {act_name}. There's nothing to reset."
129
169
  )
130
170
  return
131
171
 
132
172
  actual_name = self._resolve_hook_name(act_name)
133
- current_sae = self.acts_to_saes[act_name]
134
-
135
- if hasattr(current_sae, "_original_use_error_term"):
136
- current_sae.use_error_term = current_sae._original_use_error_term # type: ignore[attr-defined]
137
- delattr(current_sae, "_original_use_error_term")
138
-
139
- if prev_sae is not None:
140
- set_deep_attr(self, actual_name, prev_sae)
141
- self._hook_registry[actual_name] = prev_sae # type: ignore[assignment]
142
- self.acts_to_saes[act_name] = prev_sae
143
- else:
144
- new_hook = HookPoint()
145
- new_hook.name = actual_name
146
- set_deep_attr(self, actual_name, new_hook)
147
- self._hook_registry[actual_name] = new_hook
148
- del self.acts_to_saes[act_name]
173
+
174
+ # Determine output hook location (different from input for transcoders)
175
+ output_hook = self._transcoder_output_hooks.pop(act_name, actual_name)
176
+
177
+ # For transcoders, clear permanent hooks from input hook point
178
+ if output_hook != actual_name:
179
+ input_hook_point = get_deep_attr(self, actual_name)
180
+ if isinstance(input_hook_point, HookPoint):
181
+ input_hook_point.remove_hooks(dir="fwd", including_permanent=True)
182
+
183
+ # Reset output hook location
184
+ new_hook = HookPoint()
185
+ new_hook.name = output_hook
186
+ set_deep_attr(self, output_hook, new_hook)
187
+ self._hook_registry[output_hook] = new_hook
188
+ del self._acts_to_saes[act_name]
189
+
190
+ if prev_wrapper is not None:
191
+ self.add_sae(prev_wrapper.sae, use_error_term=prev_wrapper.use_error_term)
149
192
 
150
193
  def reset_saes(
151
194
  self,
152
195
  act_names: str | list[str] | None = None,
153
- prev_saes: list[SAE[Any] | None] | None = None,
154
196
  ) -> None:
155
197
  """Reset the SAEs attached to the model.
156
198
 
157
199
  If act_names are provided will just reset SAEs attached to those hooks.
158
200
  Otherwise will reset all SAEs attached to the model.
159
- Optionally can provide a list of prev_saes to reset to. This is mainly
160
- used to restore previously attached SAEs after temporarily running with
161
- different SAEs (e.g., with run_with_saes).
162
201
 
163
202
  Args:
164
203
  act_names: The act_names of the SAEs to reset. If None, will reset all
165
204
  SAEs attached to the model. Defaults to None.
166
- prev_saes: List of SAEs to replace the current ones with. If None, will
167
- just remove the SAEs. Defaults to None.
168
205
  """
169
206
  if isinstance(act_names, str):
170
207
  act_names = [act_names]
171
208
  elif act_names is None:
172
- act_names = list(self.acts_to_saes.keys())
173
-
174
- if prev_saes:
175
- if len(act_names) != len(prev_saes):
176
- raise ValueError("act_names and prev_saes must have the same length")
177
- else:
178
- prev_saes = [None] * len(act_names) # type: ignore[assignment]
209
+ act_names = list(self._acts_to_saes.keys())
179
210
 
180
- for act_name, prev_sae in zip(act_names, prev_saes): # type: ignore[arg-type]
181
- self._reset_sae(act_name, prev_sae)
211
+ for act_name in act_names:
212
+ self._reset_sae(act_name)
182
213
 
183
214
  def run_with_saes(
184
215
  self,
@@ -310,20 +341,20 @@ class SAETransformerBridge(TransformerBridge): # type: ignore[misc,no-untyped-c
310
341
  use_error_term: If provided, will set the use_error_term attribute of
311
342
  all SAEs attached during this run to this value. Defaults to None.
312
343
  """
313
- act_names_to_reset: list[str] = []
314
- prev_saes: list[SAE[Any] | None] = []
344
+ saes_to_restore: list[tuple[str, _SAEWrapper | None]] = []
315
345
  if isinstance(saes, SAE):
316
346
  saes = [saes]
317
347
  try:
318
348
  for sae in saes:
319
- act_names_to_reset.append(sae.cfg.metadata.hook_name)
320
- prev_sae = self.acts_to_saes.get(sae.cfg.metadata.hook_name, None)
321
- prev_saes.append(prev_sae)
349
+ act_name = sae.cfg.metadata.hook_name
350
+ prev_wrapper = self._acts_to_saes.get(act_name, None)
351
+ saes_to_restore.append((act_name, prev_wrapper))
322
352
  self.add_sae(sae, use_error_term=use_error_term)
323
353
  yield self
324
354
  finally:
325
355
  if reset_saes_end:
326
- self.reset_saes(act_names_to_reset, prev_saes)
356
+ for act_name, prev_wrapper in saes_to_restore:
357
+ self._reset_sae(act_name, prev_wrapper)
327
358
 
328
359
  @property
329
360
  def hook_dict(self) -> dict[str, HookPoint]:
@@ -337,9 +368,9 @@ class SAETransformerBridge(TransformerBridge): # type: ignore[misc,no-untyped-c
337
368
  hooks: dict[str, HookPoint] = {}
338
369
 
339
370
  for name, hook_or_sae in self._hook_registry.items():
340
- if isinstance(hook_or_sae, SAE):
371
+ if isinstance(hook_or_sae, _SAEWrapper):
341
372
  # Include SAE's internal hooks with full path names
342
- for sae_hook_name, sae_hook in hook_or_sae.hook_dict.items():
373
+ for sae_hook_name, sae_hook in hook_or_sae.sae.hook_dict.items():
343
374
  full_name = f"{name}.{sae_hook_name}"
344
375
  hooks[full_name] = sae_hook
345
376
  else:
sae_lens/loading/pretrained_saes_directory.py CHANGED
@@ -57,28 +57,6 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
57
57
  return directory
58
58
 
59
59
 
60
- def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
61
- """
62
- Retrieve the norm_scaling_factor for a specific SAE if it exists.
63
-
64
- Args:
65
- release (str): The release name of the SAE.
66
- sae_id (str): The ID of the specific SAE.
67
-
68
- Returns:
69
- float | None: The norm_scaling_factor if it exists, None otherwise.
70
- """
71
- package = "sae_lens"
72
- yaml_file = files(package).joinpath("pretrained_saes.yaml")
73
- with yaml_file.open("r") as file:
74
- data = yaml.safe_load(file)
75
- if release in data:
76
- for sae_info in data[release]["saes"]:
77
- if sae_info["id"] == sae_id:
78
- return sae_info.get("norm_scaling_factor")
79
- return None
80
-
81
-
82
60
  def get_repo_id_and_folder_name(release: str, sae_id: str) -> tuple[str, str]:
83
61
  saes_directory = get_pretrained_saes_directory()
84
62
  sae_info = saes_directory.get(release, None)
sae_lens/saes/sae.py CHANGED
@@ -45,7 +45,6 @@ from sae_lens.loading.pretrained_sae_loaders import (
45
45
  )
46
46
  from sae_lens.loading.pretrained_saes_directory import (
47
47
  get_config_overrides,
48
- get_norm_scaling_factor,
49
48
  get_pretrained_saes_directory,
50
49
  get_releases_for_repo_id,
51
50
  get_repo_id_and_folder_name,
@@ -230,7 +229,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
230
229
  cfg: T_SAE_CONFIG
231
230
  dtype: torch.dtype
232
231
  device: torch.device
233
- use_error_term: bool
232
+ _use_error_term: bool
234
233
 
235
234
  # For type checking only - don't provide default values
236
235
  # These will be initialized by subclasses
@@ -255,7 +254,9 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
255
254
 
256
255
  self.dtype = str_to_dtype(cfg.dtype)
257
256
  self.device = torch.device(cfg.device)
258
- self.use_error_term = use_error_term
257
+ self._use_error_term = False # Set directly to avoid warning during init
258
+ if use_error_term:
259
+ self.use_error_term = True # Use property setter to trigger warning
259
260
 
260
261
  # Set up activation function
261
262
  self.activation_fn = self.get_activation_fn()
@@ -282,6 +283,22 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
282
283
 
283
284
  self.setup() # Required for HookedRootModule
284
285
 
286
+ @property
287
+ def use_error_term(self) -> bool:
288
+ return self._use_error_term
289
+
290
+ @use_error_term.setter
291
+ def use_error_term(self, value: bool) -> None:
292
+ if value and not self._use_error_term:
293
+ warnings.warn(
294
+ "Setting use_error_term directly on SAE is deprecated. "
295
+ "Use HookedSAETransformer.add_sae(sae, use_error_term=True) instead. "
296
+ "This will be removed in a future version.",
297
+ DeprecationWarning,
298
+ stacklevel=2,
299
+ )
300
+ self._use_error_term = value
301
+
285
302
  @torch.no_grad()
286
303
  def fold_activation_norm_scaling_factor(self, scaling_factor: float):
287
304
  self.W_enc.data *= scaling_factor # type: ignore
@@ -638,24 +655,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
638
655
  stacklevel=2,
639
656
  )
640
657
  elif sae_id not in sae_directory[release].saes_map:
641
- # Handle special cases like Gemma Scope
642
- if (
643
- "gemma-scope" in release
644
- and "canonical" not in release
645
- and f"{release}-canonical" in sae_directory
646
- ):
647
- canonical_ids = list(
648
- sae_directory[release + "-canonical"].saes_map.keys()
649
- )
650
- # Shorten the lengthy string of valid IDs
651
- if len(canonical_ids) > 5:
652
- str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
653
- else:
654
- str_canonical_ids = str(canonical_ids)
655
- value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
656
- else:
657
- value_suffix = ""
658
-
659
658
  valid_ids = list(sae_directory[release].saes_map.keys())
660
659
  # Shorten the lengthy string of valid IDs
661
660
  if len(valid_ids) > 5:
@@ -665,7 +664,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
665
664
 
666
665
  raise ValueError(
667
666
  f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
668
- + value_suffix
669
667
  )
670
668
 
671
669
  conversion_loader = (
@@ -702,17 +700,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
702
700
  sae.process_state_dict_for_loading(state_dict)
703
701
  sae.load_state_dict(state_dict, assign=True)
704
702
 
705
- # Apply normalization if needed
706
- if cfg_dict.get("normalize_activations") == "expected_average_only_in":
707
- norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
708
- if norm_scaling_factor is not None:
709
- sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
710
- cfg_dict["normalize_activations"] = "none"
711
- else:
712
- warnings.warn(
713
- f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
714
- )
715
-
716
703
  # the loaders should already handle the dtype / device conversion
717
704
  # but this is a fallback to guarantee the SAE is on the correct device and dtype
718
705
  return (
sae_lens/synthetic/__init__.py CHANGED
@@ -50,6 +50,13 @@ from sae_lens.synthetic.plotting import (
50
50
  find_best_feature_ordering_from_sae,
51
51
  plot_sae_feature_similarity,
52
52
  )
53
+ from sae_lens.synthetic.stats import (
54
+ CorrelationMatrixStats,
55
+ SuperpositionStats,
56
+ compute_correlation_matrix_stats,
57
+ compute_low_rank_correlation_matrix_stats,
58
+ compute_superposition_stats,
59
+ )
53
60
  from sae_lens.synthetic.training import (
54
61
  SyntheticActivationIterator,
55
62
  train_toy_sae,
@@ -80,6 +87,12 @@ __all__ = [
80
87
  "orthogonal_initializer",
81
88
  "FeatureDictionaryInitializer",
82
89
  "cosine_similarities",
90
+ # Statistics
91
+ "compute_correlation_matrix_stats",
92
+ "compute_low_rank_correlation_matrix_stats",
93
+ "compute_superposition_stats",
94
+ "CorrelationMatrixStats",
95
+ "SuperpositionStats",
83
96
  # Training utilities
84
97
  "SyntheticActivationIterator",
85
98
  "SyntheticDataEvalResult",
sae_lens/synthetic/correlation.py CHANGED
@@ -3,6 +3,7 @@ from typing import NamedTuple
3
3
 
4
4
  import torch
5
5
 
6
+ from sae_lens import logger
6
7
  from sae_lens.util import str_to_dtype
7
8
 
8
9
 
@@ -268,7 +269,7 @@ def generate_random_correlation_matrix(
268
269
  def generate_random_low_rank_correlation_matrix(
269
270
  num_features: int,
270
271
  rank: int,
271
- correlation_scale: float = 0.1,
272
+ correlation_scale: float = 0.075,
272
273
  seed: int | None = None,
273
274
  device: torch.device | str = "cpu",
274
275
  dtype: torch.dtype | str = torch.float32,
@@ -331,20 +332,17 @@ def generate_random_low_rank_correlation_matrix(
331
332
  factor_sq_sum = (factor**2).sum(dim=1)
332
333
  diag_term = 1 - factor_sq_sum
333
334
 
334
- # Ensure diagonal terms are at least _MIN_DIAG for numerical stability
335
- # If any diagonal term is too small, scale down the factor matrix
336
- if torch.any(diag_term < _MIN_DIAG):
337
- # Scale factor so max row norm squared is at most (1 - _MIN_DIAG)
338
- # This ensures all diagonal terms are >= _MIN_DIAG
339
- max_factor_contribution = 1 - _MIN_DIAG
340
- max_sq_sum = factor_sq_sum.max()
341
- scale = torch.sqrt(
342
- torch.tensor(max_factor_contribution, device=device, dtype=dtype)
343
- / max_sq_sum
335
+ # alternatively, we can rescale each row independently to ensure the diagonal is 1
336
+ mask = diag_term < _MIN_DIAG
337
+ factor[mask, :] *= torch.sqrt((1 - _MIN_DIAG) / factor_sq_sum[mask].unsqueeze(1))
338
+ factor_sq_sum = (factor**2).sum(dim=1)
339
+ diag_term = 1 - factor_sq_sum
340
+
341
+ total_rescaled = mask.sum().item()
342
+ if total_rescaled > 0:
343
+ logger.warning(
344
+ f"{total_rescaled} / {num_features} rows were capped. Either reduce the rank or reduce the correlation_scale to avoid rescaling."
344
345
  )
345
- factor = factor * scale
346
- factor_sq_sum = (factor**2).sum(dim=1)
347
- diag_term = 1 - factor_sq_sum
348
346
 
349
347
  return LowRankCorrelationMatrix(
350
348
  correlation_factor=factor, correlation_diag=diag_term
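A self-contained sketch of the per-row capping applied above: any row of the factor matrix whose squared norm would push the diagonal term below _MIN_DIAG is rescaled so the diagonal lands exactly at _MIN_DIAG. The threshold value here is a placeholder, not necessarily the library's constant.

import torch

_MIN_DIAG = 0.01                          # placeholder threshold
factor = torch.tensor([[0.8, 0.7],        # row norm² = 1.13 -> diagonal would go negative
                       [0.3, 0.2]])       # row norm² = 0.13 -> left untouched
row_sq = (factor ** 2).sum(dim=1)
mask = (1 - row_sq) < _MIN_DIAG
factor[mask] *= torch.sqrt((1 - _MIN_DIAG) / row_sq[mask]).unsqueeze(1)
print((factor ** 2).sum(dim=1))           # tensor([0.9900, 0.1300]); diagonal = 1 minus this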
sae_lens/synthetic/stats.py ADDED
@@ -0,0 +1,205 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+ from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
6
+ from sae_lens.synthetic.feature_dictionary import FeatureDictionary
7
+
8
+
9
+ @dataclass
10
+ class CorrelationMatrixStats:
11
+ """Statistics computed from a correlation matrix."""
12
+
13
+ rms_correlation: float # Root mean square of off-diagonal correlations
14
+ mean_correlation: float # Mean of off-diagonal correlations (not absolute)
15
+ num_features: int
16
+
17
+
18
+ @torch.no_grad()
19
+ def compute_correlation_matrix_stats(
20
+ correlation_matrix: torch.Tensor,
21
+ ) -> CorrelationMatrixStats:
22
+ """Compute correlation statistics from a dense correlation matrix.
23
+
24
+ Args:
25
+ correlation_matrix: Dense correlation matrix of shape (n, n)
26
+
27
+ Returns:
28
+ CorrelationMatrixStats with correlation statistics
29
+ """
30
+ num_features = correlation_matrix.shape[0]
31
+
32
+ # Extract off-diagonal elements
33
+ mask = ~torch.eye(num_features, dtype=torch.bool, device=correlation_matrix.device)
34
+ off_diag = correlation_matrix[mask]
35
+
36
+ rms_correlation = (off_diag**2).mean().sqrt().item()
37
+ mean_correlation = off_diag.mean().item()
38
+
39
+ return CorrelationMatrixStats(
40
+ rms_correlation=rms_correlation,
41
+ mean_correlation=mean_correlation,
42
+ num_features=num_features,
43
+ )
44
+
45
+
46
+ @torch.no_grad()
47
+ def compute_low_rank_correlation_matrix_stats(
48
+ correlation_matrix: LowRankCorrelationMatrix,
49
+ ) -> CorrelationMatrixStats:
50
+ """Compute correlation statistics from a LowRankCorrelationMatrix.
51
+
52
+ The correlation matrix is represented as:
53
+ correlation = factor @ factor.T + diag(diag_term)
54
+
55
+ The off-diagonal elements are simply factor @ factor.T (the diagonal term
56
+ only affects the diagonal).
57
+
58
+ All statistics are computed efficiently in O(n*r²) time and O(r²) memory
59
+ without materializing the full n×n correlation matrix.
60
+
61
+ Args:
62
+ correlation_matrix: Low-rank correlation matrix
63
+
64
+ Returns:
65
+ CorrelationMatrixStats with correlation statistics
66
+ """
67
+
68
+ factor = correlation_matrix.correlation_factor
69
+ num_features = factor.shape[0]
70
+ num_off_diag = num_features * (num_features - 1)
71
+
72
+ # RMS correlation: uses ||F @ F.T||_F² = ||F.T @ F||_F²
73
+ # This avoids computing the (num_features, num_features) matrix
74
+ G = factor.T @ factor # (rank, rank) - small!
75
+ frobenius_sq = (G**2).sum()
76
+ row_norms_sq = (factor**2).sum(dim=1) # ||F[i]||² for each row
77
+ diag_sq_sum = (row_norms_sq**2).sum() # Σᵢ ||F[i]||⁴
78
+ off_diag_sq_sum = frobenius_sq - diag_sq_sum
79
+ rms_correlation = (off_diag_sq_sum / num_off_diag).sqrt().item()
80
+
81
+ # Mean correlation (not absolute): sum(C) = ||col_sums(F)||², trace(C) = Σ||F[i]||²
82
+ col_sums = factor.sum(dim=0) # (rank,)
83
+ sum_all = (col_sums**2).sum() # 1ᵀ C 1
84
+ trace_C = row_norms_sq.sum()
85
+ mean_correlation = ((sum_all - trace_C) / num_off_diag).item()
86
+
87
+ return CorrelationMatrixStats(
88
+ rms_correlation=rms_correlation,
89
+ mean_correlation=mean_correlation,
90
+ num_features=num_features,
91
+ )
92
+
93
+
94
+ @dataclass
95
+ class SuperpositionStats:
96
+ """Statistics measuring superposition in a feature dictionary."""
97
+
98
+ # Per-latent statistics: for each latent, max and percentile of |cos_sim| with others
99
+ max_abs_cos_sims: torch.Tensor # Shape: (num_features,)
100
+ percentile_abs_cos_sims: dict[int, torch.Tensor] # {percentile: (num_features,)}
101
+
102
+ # Summary statistics (means of the per-latent values)
103
+ mean_max_abs_cos_sim: float
104
+ mean_percentile_abs_cos_sim: dict[int, float]
105
+ mean_abs_cos_sim: float # Mean |cos_sim| across all pairs
106
+
107
+ # Metadata
108
+ num_features: int
109
+ hidden_dim: int
110
+
111
+
112
+ @torch.no_grad()
113
+ def compute_superposition_stats(
114
+ feature_dictionary: FeatureDictionary,
115
+ batch_size: int = 1024,
116
+ device: str | torch.device | None = None,
117
+ percentiles: list[int] | None = None,
118
+ ) -> SuperpositionStats:
119
+ """Compute superposition statistics for a feature dictionary.
120
+
121
+ Computes pairwise cosine similarities in batches to handle large dictionaries.
122
+
123
+ For each latent i, computes:
124
+
125
+ - max |cos_sim(i, j)| over all j != i
126
+ - kth percentile of |cos_sim(i, j)| over all j != i (for each k in percentiles)
127
+
128
+ Args:
129
+ feature_dictionary: FeatureDictionary containing the feature vectors
130
+ batch_size: Number of features to process per batch
131
+ device: Device for computation (defaults to feature dictionary's device)
132
+ percentiles: List of percentiles to compute per latent (default: [95, 99])
133
+
134
+ Returns:
135
+ SuperpositionStats with superposition metrics
136
+ """
137
+ if percentiles is None:
138
+ percentiles = [95, 99]
139
+
140
+ feature_vectors = feature_dictionary.feature_vectors
141
+ num_features, hidden_dim = feature_vectors.shape
142
+
143
+ if num_features < 2:
144
+ raise ValueError("Need at least 2 features to compute superposition stats")
145
+ if device is None:
146
+ device = feature_vectors.device
147
+
148
+ # Normalize features to unit norm for cosine similarity
149
+ features_normalized = feature_vectors.to(device).float()
150
+ norms = torch.linalg.norm(features_normalized, dim=1, keepdim=True)
151
+ features_normalized = features_normalized / norms.clamp(min=1e-8)
152
+
153
+ # Track per-latent statistics
154
+ max_abs_cos_sims = torch.zeros(num_features, device=device)
155
+ percentile_abs_cos_sims = {
156
+ p: torch.zeros(num_features, device=device) for p in percentiles
157
+ }
158
+ sum_abs_cos_sim = 0.0
159
+ n_pairs = 0
160
+
161
+ # Process in batches: for each batch of features, compute similarities with all others
162
+ for i in range(0, num_features, batch_size):
163
+ batch_end = min(i + batch_size, num_features)
164
+ batch = features_normalized[i:batch_end] # (batch_size, hidden_dim)
165
+
166
+ # Compute cosine similarities with all features: (batch_size, num_features)
167
+ cos_sims = batch @ features_normalized.T
168
+
169
+ # Absolute cosine similarities
170
+ abs_cos_sims = cos_sims.abs()
171
+
172
+ # Process each latent in the batch
173
+ for j, idx in enumerate(range(i, batch_end)):
174
+ # Get similarities with all other features (exclude self)
175
+ row = abs_cos_sims[j].clone()
176
+ row[idx] = 0.0 # Exclude self for max
177
+ max_abs_cos_sims[idx] = row.max()
178
+
179
+ # For percentiles, exclude self and compute
180
+ other_sims = torch.cat([abs_cos_sims[j, :idx], abs_cos_sims[j, idx + 1 :]])
181
+ for p in percentiles:
182
+ percentile_abs_cos_sims[p][idx] = torch.quantile(other_sims, p / 100.0)
183
+
184
+ # Sum for mean computation (only count pairs once - with features after this one)
185
+ sum_abs_cos_sim += abs_cos_sims[j, idx + 1 :].sum().item()
186
+ n_pairs += num_features - idx - 1
187
+
188
+ # Compute summary statistics
189
+ mean_max_abs_cos_sim = max_abs_cos_sims.mean().item()
190
+ mean_percentile_abs_cos_sim = {
191
+ p: percentile_abs_cos_sims[p].mean().item() for p in percentiles
192
+ }
193
+ mean_abs_cos_sim = sum_abs_cos_sim / n_pairs if n_pairs > 0 else 0.0
194
+
195
+ return SuperpositionStats(
196
+ max_abs_cos_sims=max_abs_cos_sims.cpu(),
197
+ percentile_abs_cos_sims={
198
+ p: v.cpu() for p, v in percentile_abs_cos_sims.items()
199
+ },
200
+ mean_max_abs_cos_sim=mean_max_abs_cos_sim,
201
+ mean_percentile_abs_cos_sim=mean_percentile_abs_cos_sim,
202
+ mean_abs_cos_sim=mean_abs_cos_sim,
203
+ num_features=num_features,
204
+ hidden_dim=hidden_dim,
205
+ )
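A hedged check (not part of the package) that the O(n·r²) identities used in compute_low_rank_correlation_matrix_stats agree with the dense computation; the sizes and seed are arbitrary.

import torch

from sae_lens.synthetic.correlation import generate_random_low_rank_correlation_matrix
from sae_lens.synthetic.stats import (
    compute_correlation_matrix_stats,
    compute_low_rank_correlation_matrix_stats,
)

low_rank = generate_random_low_rank_correlation_matrix(num_features=64, rank=4, seed=0)
factor, diag = low_rank.correlation_factor, low_rank.correlation_diag
dense = factor @ factor.T + torch.diag(diag)   # materialise the full matrix for comparison

fast = compute_low_rank_correlation_matrix_stats(low_rank)
slow = compute_correlation_matrix_stats(dense)
assert abs(fast.rms_correlation - slow.rms_correlation) < 1e-5
assert abs(fast.mean_correlation - slow.mean_correlation) < 1e-5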
sae_lens-6.32.1.dist-info/METADATA → sae_lens-6.34.1.dist-info/METADATA
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.32.1
3
+ Version: 6.34.1
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
sae_lens-6.32.1.dist-info/RECORD → sae_lens-6.34.1.dist-info/RECORD
@@ -1,9 +1,9 @@
1
- sae_lens/__init__.py,sha256=Y_TVKGehpnTvQw8tvIn0fjo8uAw-XAYi7carZS_cRjQ,5168
1
+ sae_lens/__init__.py,sha256=U5I6dWw2_x6ZXwL8BF72vB0HGPEUt-_BzHq-btLq2pw,5168
2
2
  sae_lens/analysis/__init__.py,sha256=FZExlMviNwWR7OGUSGRbd0l-yUDGSp80gglI_ivILrY,412
3
3
  sae_lens/analysis/compat.py,sha256=cgE3nhFcJTcuhppxbL71VanJS7YqVEOefuneB5eOaPw,538
4
- sae_lens/analysis/hooked_sae_transformer.py,sha256=LpnjxSAcItqqXA4SJyZuxY4Ki0UOuWV683wg9laYAsY,14050
4
+ sae_lens/analysis/hooked_sae_transformer.py,sha256=-LY9CKYEziSTt-H7MeLdTx6ErfvMqgNEF709wV7tEs4,17826
5
5
  sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
6
- sae_lens/analysis/sae_transformer_bridge.py,sha256=xpJRRcB0g47EOQcmNCwMyrJJsbqMsGxVViDrV6C3upU,14916
6
+ sae_lens/analysis/sae_transformer_bridge.py,sha256=y_ZdvaxuUM_-7ywSCeFl6f6cq1_FiqRstMAo8mxZtYE,16191
7
7
  sae_lens/cache_activations_runner.py,sha256=TjqNWIc46Nw09jHWFjzQzgzG5wdu_87Ahe-iFjI5_0Q,13117
8
8
  sae_lens/config.py,sha256=V0BXV8rvpbm5YuVukow9FURPpdyE4HSflbdymAo0Ycg,31205
9
9
  sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
@@ -12,7 +12,7 @@ sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHF
12
12
  sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
13
13
  sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  sae_lens/loading/pretrained_sae_loaders.py,sha256=kshvA0NivOc7B3sL19lHr_zrC_DDfW2T6YWb5j0hgAk,63930
15
- sae_lens/loading/pretrained_saes_directory.py,sha256=1at_aQbD8WFywchQCKuwfP-yvCq_Z2aUYrpKDnSN5Nc,4283
15
+ sae_lens/loading/pretrained_saes_directory.py,sha256=lSnHl77IO5dd7iO21ynCzZNMrzuJAT8Za4W5THNq0qw,3554
16
16
  sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
17
17
  sae_lens/pretrained_saes.yaml,sha256=IVBLLR8_XNllJ1O-kVv9ED4u0u44Yn8UOL9R-f8Idp4,1511936
18
18
  sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
@@ -22,20 +22,21 @@ sae_lens/saes/gated_sae.py,sha256=V_2ZNlV4gRD-rX5JSx1xqY7idT8ChfdQ5yxWDdu_6hg,88
22
22
  sae_lens/saes/jumprelu_sae.py,sha256=miiF-xI_yXdV9EkKjwAbU9zSMsx9KtKCz5YdXEzkN8g,13313
23
23
  sae_lens/saes/matching_pursuit_sae.py,sha256=08_G9p1YMLnE5qZVCPp6gll-iG6nHRbMMASf4_bkFt8,13207
24
24
  sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
25
- sae_lens/saes/sae.py,sha256=xRmgiLuaFlDCv8SyLbL-5TwdrWHpNLqSGe8mC1L6WcI,40942
25
+ sae_lens/saes/sae.py,sha256=2FHhLoTZYOJUGjkWH7eF2EAc1fPxQsF4KYr5TZ8IUIU,40155
26
26
  sae_lens/saes/standard_sae.py,sha256=_hldNZkFPAf9VGrxouR1-tN8T2OEk8IkWBcXoatrC1o,5749
27
27
  sae_lens/saes/temporal_sae.py,sha256=S44sPddVj2xujA02CC8gT1tG0in7c_CSAhspu9FHbaA,13273
28
28
  sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
29
29
  sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
30
- sae_lens/synthetic/__init__.py,sha256=MtTnGkTfHV2WjkIgs7zZyx10EK9U5fjOHXy69Aq3uKw,3095
30
+ sae_lens/synthetic/__init__.py,sha256=hRRA3xhEQUacGyFbJXkLVYg_8A1bbSYYWlVovb0g4KU,3503
31
31
  sae_lens/synthetic/activation_generator.py,sha256=8L9nwC4jFRv_wg3QN-n1sFwX8w1NqwJMysWaJ41lLlY,15197
32
- sae_lens/synthetic/correlation.py,sha256=tMTLo9fBfDpeXwqhyUgFqnTipj9x2W0t4oEtNxB7AG0,13256
32
+ sae_lens/synthetic/correlation.py,sha256=tD8J9abWfuFtGZrEbbFn4P8FeTcNKF2V5JhBLwDUmkg,13146
33
33
  sae_lens/synthetic/evals.py,sha256=Nhi314ZnRgLfhBj-3tm_zzI-pGyFTcwllDXbIpPFXeU,4584
34
34
  sae_lens/synthetic/feature_dictionary.py,sha256=Nd4xjSTxKMnKilZ3uYi8Gv5SS5D4bv4wHiSL1uGB69E,6933
35
35
  sae_lens/synthetic/firing_probabilities.py,sha256=yclz1pWl5gE1r8LAxFvzQS88Lxwk5-3r8BCX9HLVejA,3370
36
36
  sae_lens/synthetic/hierarchy.py,sha256=nm7nwnTswktVJeKUsRZ0hLOdXcFWGbxnA1b6lefHm-4,33592
37
37
  sae_lens/synthetic/initialization.py,sha256=orMGW-786wRDHIS2W7bEH0HmlVFQ4g2z4bnnwdv5w4s,1386
38
38
  sae_lens/synthetic/plotting.py,sha256=5lFrej1QOkGAcImFNo5-o-8mI_rUVqvEI57KzUQPPtQ,8208
39
+ sae_lens/synthetic/stats.py,sha256=BoDPKDx8pgFF5Ko_IaBRZTczm7-ANUIRjjF5W5Qh3Lk,7441
39
40
  sae_lens/synthetic/training.py,sha256=fHcX2cZ6nDupr71GX0Gk17f1NvQ0SKIVXIA6IuAb2dw,5692
40
41
  sae_lens/tokenization_and_batching.py,sha256=uoHtAs9z3XqG0Fh-iQVYVlrbyB_E3kFFhrKU30BosCo,5438
41
42
  sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -48,7 +49,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
48
49
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
49
50
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
50
51
  sae_lens/util.py,sha256=oIMoeyEP2IzcPFmRbKUzOAycgEyMcOasGeO_BGVZbc4,4846
51
- sae_lens-6.32.1.dist-info/METADATA,sha256=TcO6hFEXKdbLp32UTiVluHcMXFetfYJDqTHNCsx9PRw,6566
52
- sae_lens-6.32.1.dist-info/WHEEL,sha256=3ny-bZhpXrU6vSQ1UPG34FoxZBp3lVcvK0LkgUz6VLk,88
53
- sae_lens-6.32.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
54
- sae_lens-6.32.1.dist-info/RECORD,,
52
+ sae_lens-6.34.1.dist-info/METADATA,sha256=vAlbWT90NggKoxjII6dnEj7wlC-PmuBn8zzqL6r8dRg,6566
53
+ sae_lens-6.34.1.dist-info/WHEEL,sha256=kJCRJT_g0adfAJzTx2GUMmS80rTJIVHRCfG0DQgLq3o,88
54
+ sae_lens-6.34.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
55
+ sae_lens-6.34.1.dist-info/RECORD,,
sae_lens-6.32.1.dist-info/WHEEL → sae_lens-6.34.1.dist-info/WHEEL
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 2.3.0
2
+ Generator: poetry-core 2.3.1
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any