sae-lens 6.31.0__tar.gz → 6.34.0__tar.gz

Files changed (56)
  1. {sae_lens-6.31.0 → sae_lens-6.34.0}/PKG-INFO +2 -2
  2. {sae_lens-6.31.0 → sae_lens-6.34.0}/pyproject.toml +3 -2
  3. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/__init__.py +14 -1
  4. sae_lens-6.34.0/sae_lens/analysis/__init__.py +15 -0
  5. sae_lens-6.34.0/sae_lens/analysis/compat.py +16 -0
  6. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/analysis/hooked_sae_transformer.py +175 -76
  7. sae_lens-6.34.0/sae_lens/analysis/sae_transformer_bridge.py +379 -0
  8. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/loading/pretrained_sae_loaders.py +2 -1
  9. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/loading/pretrained_saes_directory.py +0 -22
  10. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/sae.py +20 -33
  11. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/__init__.py +13 -0
  12. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/correlation.py +12 -14
  13. sae_lens-6.34.0/sae_lens/synthetic/stats.py +205 -0
  14. sae_lens-6.31.0/sae_lens/training/__init__.py +0 -0
  15. {sae_lens-6.31.0 → sae_lens-6.34.0}/LICENSE +0 -0
  16. {sae_lens-6.31.0 → sae_lens-6.34.0}/README.md +0 -0
  17. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  18. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/cache_activations_runner.py +0 -0
  19. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/config.py +0 -0
  20. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/constants.py +0 -0
  21. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/evals.py +0 -0
  22. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/llm_sae_training_runner.py +0 -0
  23. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/load_model.py +0 -0
  24. {sae_lens-6.31.0/sae_lens/analysis → sae_lens-6.34.0/sae_lens/loading}/__init__.py +0 -0
  25. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/pretokenize_runner.py +0 -0
  26. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/pretrained_saes.yaml +0 -0
  27. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/registry.py +0 -0
  28. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/__init__.py +0 -0
  29. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/batchtopk_sae.py +0 -0
  30. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/gated_sae.py +0 -0
  31. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/jumprelu_sae.py +0 -0
  32. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/matching_pursuit_sae.py +0 -0
  33. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
  34. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/standard_sae.py +0 -0
  35. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/temporal_sae.py +0 -0
  36. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/topk_sae.py +0 -0
  37. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/transcoder.py +0 -0
  38. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/activation_generator.py +0 -0
  39. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/evals.py +0 -0
  40. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/feature_dictionary.py +0 -0
  41. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/firing_probabilities.py +0 -0
  42. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/hierarchy.py +0 -0
  43. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/initialization.py +0 -0
  44. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/plotting.py +0 -0
  45. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/training.py +0 -0
  46. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/tokenization_and_batching.py +0 -0
  47. {sae_lens-6.31.0/sae_lens/loading → sae_lens-6.34.0/sae_lens/training}/__init__.py +0 -0
  48. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/training/activation_scaler.py +0 -0
  49. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/training/activations_store.py +0 -0
  50. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/training/mixing_buffer.py +0 -0
  51. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/training/optim.py +0 -0
  52. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/training/sae_trainer.py +0 -0
  53. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/training/types.py +0 -0
  54. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  55. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/tutorial/tsea.py +0 -0
  56. {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/util.py +0 -0
{sae_lens-6.31.0 → sae_lens-6.34.0}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sae-lens
- Version: 6.31.0
+ Version: 6.34.0
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
  License: MIT
  License-File: LICENSE

@@ -27,7 +27,7 @@ Requires-Dist: pyyaml (>=6.0.1,<7.0.0)
  Requires-Dist: safetensors (>=0.4.2,<1.0.0)
  Requires-Dist: simple-parsing (>=0.1.6,<0.2.0)
  Requires-Dist: tenacity (>=9.0.0)
- Requires-Dist: transformer-lens (>=2.16.1,<3.0.0)
+ Requires-Dist: transformer-lens (>=2.16.1)
  Requires-Dist: transformers (>=4.38.1,<5.0.0)
  Requires-Dist: typing-extensions (>=4.10.0,<5.0.0)
  Project-URL: Homepage, https://decoderesearch.github.io/SAELens

{sae_lens-6.31.0 → sae_lens-6.34.0}/pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "sae-lens"
- version = "6.31.0"
+ version = "6.34.0"
  description = "Training and Analyzing Sparse Autoencoders (SAEs)"
  authors = ["Joseph Bloom"]
  readme = "README.md"

@@ -19,7 +19,7 @@ classifiers = ["Topic :: Scientific/Engineering :: Artificial Intelligence"]

  [tool.poetry.dependencies]
  python = "^3.10"
- transformer-lens = "^2.16.1"
+ transformer-lens = ">=2.16.1"
  transformers = "^4.38.1"
  plotly = ">=5.19.0"
  plotly-express = ">=0.4.1"

@@ -59,6 +59,7 @@ mike = "^2.0.0"
  trio = "^0.30.0"
  dictionary-learning = "^0.1.0"
  kaleido = "^1.2.0"
+ transformer-lens = { version = "3.0.0b1", allow-prereleases = true }

  [tool.poetry.extras]
  mamba = ["mamba-lens"]

{sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/__init__.py
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.31.0"
+ __version__ = "6.34.0"

  import logging


@@ -125,6 +125,19 @@ __all__ = [
      "MatchingPursuitTrainingSAEConfig",
  ]

+ # Conditional export for SAETransformerBridge (requires transformer-lens v3+)
+ try:
+     from sae_lens.analysis.compat import has_transformer_bridge
+
+     if has_transformer_bridge():
+         from sae_lens.analysis.sae_transformer_bridge import (  # noqa: F401
+             SAETransformerBridge,
+         )
+
+         __all__.append("SAETransformerBridge")
+ except ImportError:
+     pass
+

  register_sae_class("standard", StandardSAE, StandardSAEConfig)
  register_sae_training_class("standard", StandardTrainingSAE, StandardTrainingSAEConfig)
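
With this conditional export, downstream code can feature-detect the bridge
instead of pinning a transformer-lens major version. A minimal sketch of the
consumer side under that assumption ("gpt2" is just an illustrative model name):

    import sae_lens

    # SAETransformerBridge is only exported on transformer-lens v3+ installs.
    bridge_cls = getattr(sae_lens, "SAETransformerBridge", None)
    if bridge_cls is None:
        # v2 install: fall back to the classic HookedSAETransformer
        model = sae_lens.HookedSAETransformer.from_pretrained("gpt2")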

sae_lens-6.34.0/sae_lens/analysis/__init__.py (new file)
@@ -0,0 +1,15 @@
+ from sae_lens.analysis.hooked_sae_transformer import HookedSAETransformer
+
+ __all__ = ["HookedSAETransformer"]
+
+ try:
+     from sae_lens.analysis.compat import has_transformer_bridge
+
+     if has_transformer_bridge():
+         from sae_lens.analysis.sae_transformer_bridge import (  # noqa: F401
+             SAETransformerBridge,
+         )
+
+         __all__.append("SAETransformerBridge")
+ except ImportError:
+     pass
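
The subpackage mirrors the try/except gate at the package root, so both import
paths behave the same. A quick sketch of importing through the subpackage
(names exactly as exported above):

    from sae_lens.analysis import HookedSAETransformer  # always available

    try:
        from sae_lens.analysis import SAETransformerBridge  # transformer-lens v3+ only
    except ImportError:
        SAETransformerBridge = None  # running on a v2 install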

sae_lens-6.34.0/sae_lens/analysis/compat.py (new file)
@@ -0,0 +1,16 @@
+ import importlib.metadata
+
+ from packaging.version import parse as parse_version
+
+
+ def get_transformer_lens_version() -> tuple[int, int, int]:
+     """Get transformer-lens version as (major, minor, patch)."""
+     version_str = importlib.metadata.version("transformer-lens")
+     version = parse_version(version_str)
+     return (version.major, version.minor, version.micro)
+
+
+ def has_transformer_bridge() -> bool:
+     """Check if TransformerBridge is available (v3+)."""
+     major, _, _ = get_transformer_lens_version()
+     return major >= 3
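
One subtlety worth noting: packaging's version parser treats pre-releases as
members of their major version, so the 3.0.0b1 beta pinned in the dev group
already passes the v3 gate. A small sketch of that behaviour (packaging is
already imported by this module):

    from packaging.version import parse as parse_version

    v = parse_version("3.0.0b1")
    assert (v.major, v.minor, v.micro) == (3, 0, 0)
    assert v.is_prerelease  # a beta still counts as v3 for has_transformer_bridge()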

{sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/analysis/hooked_sae_transformer.py
@@ -1,13 +1,14 @@
- import logging
  from contextlib import contextmanager
  from typing import Any, Callable

  import torch
+ from torch import nn
  from transformer_lens.ActivationCache import ActivationCache
  from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
  from transformer_lens.hook_points import HookPoint  # Hooking utilities
  from transformer_lens.HookedTransformer import HookedTransformer

+ from sae_lens import logger
  from sae_lens.saes.sae import SAE

  SingleLoss = torch.Tensor  # Type alias for a single element tensor

@@ -15,6 +16,71 @@ LossPerToken = torch.Tensor
  Loss = SingleLoss | LossPerToken


+ class _SAEWrapper(nn.Module):
+     """Wrapper for SAE/Transcoder that handles error term and hook coordination.
+
+     For SAEs (input_hook == output_hook), _captured_input stays None and we use
+     the forward argument directly. For transcoders, _captured_input is set at
+     the input hook via capture_input().
+
+     Implementation Note:
+         The SAE is stored in __dict__ directly rather than as a registered submodule.
+         This is intentional: PyTorch's module registration would add a ".sae." prefix
+         to all hook names in the cache (e.g., "blocks.0.hook_mlp_out.sae.hook_sae_input"
+         instead of "blocks.0.hook_mlp_out.hook_sae_input"). By storing in __dict__ and
+         copying hooks directly to the wrapper, we preserve the expected cache paths
+         for backwards compatibility.
+     """
+
+     def __init__(self, sae: SAE[Any], use_error_term: bool = False):
+         super().__init__()
+         # Store SAE in __dict__ to avoid registering as submodule. This keeps cache
+         # paths clean by avoiding a ".sae." prefix on hook names. See class docstring.
+         self.__dict__["_sae"] = sae
+         # Copy SAE's hooks directly to wrapper so they appear at the right path
+         for name, hook in sae.hook_dict.items():
+             setattr(self, name, hook)
+         self.use_error_term = use_error_term
+         self._captured_input: torch.Tensor | None = None
+
+     @property
+     def sae(self) -> SAE[Any]:
+         return self.__dict__["_sae"]
+
+     def capture_input(self, x: torch.Tensor) -> None:
+         """Capture input at input hook (for transcoders).
+
+         Note: We don't clone the tensor here - the input should not be modified
+         in-place between capture and use, and avoiding clone preserves memory.
+         """
+         self._captured_input = x
+
+     def forward(self, original_output: torch.Tensor) -> torch.Tensor:
+         """Run SAE/transcoder at output hook location."""
+         # For SAE: use original_output as input (same hook for input/output)
+         # For transcoder: use captured input from earlier hook
+         sae_input = (
+             self._captured_input
+             if self._captured_input is not None
+             else original_output
+         )
+
+         # Temporarily disable SAE's internal use_error_term - we handle it here
+         sae_use_error_term = self.sae.use_error_term
+         self.sae.use_error_term = False
+         try:
+             sae_out = self.sae(sae_input)
+
+             if self.use_error_term:
+                 error = original_output - sae_out.detach()
+                 sae_out = sae_out + error
+
+             return sae_out
+         finally:
+             self.sae.use_error_term = sae_use_error_term
+             self._captured_input = None
+
+

  def get_deep_attr(obj: Any, path: str):
      """Helper function to get a nested attribute from a object.
      In practice used to access HookedTransformer HookPoints (eg model.blocks[0].attn.hook_z)
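
The __dict__ assignment in __init__ above is the crux of the wrapper: assigning
an nn.Module attribute normally registers it as a submodule, which would prefix
every cached hook name. A self-contained sketch of that effect (the Holder
class is hypothetical, not part of sae_lens):

    from torch import nn

    class Holder(nn.Module):
        def __init__(self, child: nn.Module, register: bool):
            super().__init__()
            if register:
                self.child = child  # registered: submodule names gain a "child." prefix
            else:
                self.__dict__["child"] = child  # bypasses nn.Module registration

    inner = nn.Linear(4, 4)
    print([name for name, _ in Holder(inner, register=True).named_modules()])
    # ['', 'child']
    print([name for name, _ in Holder(inner, register=False).named_modules()])
    # ['']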

@@ -78,88 +144,130 @@ class HookedSAETransformer(HookedTransformer):
              add_hook_in_to_mlp(block.mlp)  # type: ignore
          self.setup()

-         self.acts_to_saes: dict[str, SAE] = {}  # type: ignore
+         self._acts_to_saes: dict[str, _SAEWrapper] = {}
+         # Track output hooks used by transcoders for cleanup
+         self._transcoder_output_hooks: dict[str, str] = {}
+
+     @property
+     def acts_to_saes(self) -> dict[str, SAE[Any]]:
+         """Returns a dict mapping hook names to attached SAEs."""
+         return {name: wrapper.sae for name, wrapper in self._acts_to_saes.items()}

      def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None):
-         """Attaches an SAE to the model
+         """Attaches an SAE or Transcoder to the model.

-         WARNING: This sae will be permanantly attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.
+         WARNING: This SAE will be permanently attached until you remove it with
+         reset_saes. This function will also overwrite any existing SAE attached
+         to the same hook point.

          Args:
-             sae: SparseAutoencoderBase. The SAE to attach to the model
-             use_error_term: (bool | None) If provided, will set the use_error_term attribute of the SAE to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
+             sae: The SAE or Transcoder to attach to the model.
+             use_error_term: If True, computes error term so output matches what the
+                 model would have produced without the SAE. This works for both SAEs
+                 (where input==output hook) and transcoders (where they differ).
+                 Defaults to None (uses SAE's existing setting).
          """
-         act_name = sae.cfg.metadata.hook_name
-         if (act_name not in self.acts_to_saes) and (act_name not in self.hook_dict):
-             logging.warning(
-                 f"No hook found for {act_name}. Skipping. Check model.hook_dict for available hooks."
+         input_hook = sae.cfg.metadata.hook_name
+         output_hook = sae.cfg.metadata.hook_name_out or input_hook
+
+         if (input_hook not in self._acts_to_saes) and (
+             input_hook not in self.hook_dict
+         ):
+             logger.warning(
+                 f"No hook found for {input_hook}. Skipping. Check model.hook_dict for available hooks."
              )
              return

-         if use_error_term is not None:
-             if not hasattr(sae, "_original_use_error_term"):
-                 sae._original_use_error_term = sae.use_error_term  # type: ignore
-             sae.use_error_term = use_error_term
-         self.acts_to_saes[act_name] = sae
-         set_deep_attr(self, act_name, sae)
+         # Check if output hook exists (either as hook_dict entry or already has SAE attached)
+         output_hook_exists = (
+             output_hook in self.hook_dict
+             or output_hook in self._acts_to_saes
+             or any(v == output_hook for v in self._transcoder_output_hooks.values())
+         )
+         if not output_hook_exists:
+             logger.warning(f"No hook found for output {output_hook}. Skipping.")
+             return
+
+         # Always use wrapper - it handles both SAEs and transcoders uniformly
+         # If use_error_term not specified, respect SAE's existing setting
+         effective_use_error_term = (
+             use_error_term if use_error_term is not None else sae.use_error_term
+         )
+         wrapper = _SAEWrapper(sae, use_error_term=effective_use_error_term)
+
+         # For transcoders (input != output), capture input at input hook
+         if input_hook != output_hook:
+             input_hook_point = get_deep_attr(self, input_hook)
+             if isinstance(input_hook_point, HookPoint):
+                 input_hook_point.add_hook(
+                     lambda tensor, hook: (wrapper.capture_input(tensor), tensor)[1],  # noqa: ARG005
+                     dir="fwd",
+                     is_permanent=True,
+                 )
+             self._transcoder_output_hooks[input_hook] = output_hook
+
+         # Store wrapper in _acts_to_saes and at output hook
+         self._acts_to_saes[input_hook] = wrapper
+         set_deep_attr(self, output_hook, wrapper)
          self.setup()

-     def _reset_sae(self, act_name: str, prev_sae: SAE[Any] | None = None):
-         """Resets an SAE that was attached to the model
+     def _reset_sae(
+         self, act_name: str, prev_wrapper: _SAEWrapper | None = None
+     ) -> None:
+         """Resets an SAE that was attached to the model.

          By default will remove the SAE from that hook_point.
-         If prev_sae is provided, will replace the current SAE with the provided one.
-         This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes)
+         If prev_wrapper is provided, will restore that wrapper's SAE with its settings.

          Args:
-             act_name: str. The hook_name of the SAE to reset
-             prev_sae: SAE | None. The SAE to replace the current one with. If None, will just remove the SAE from this hook point. Defaults to None
+             act_name: The hook_name of the SAE to reset.
+             prev_wrapper: The previous wrapper to restore. If None, will just
+                 remove the SAE from this hook point. Defaults to None.
          """
-         if act_name not in self.acts_to_saes:
-             logging.warning(
+         if act_name not in self._acts_to_saes:
+             logger.warning(
                  f"No SAE is attached to {act_name}. There's nothing to reset."
              )
              return

-         current_sae = self.acts_to_saes[act_name]
-         if hasattr(current_sae, "_original_use_error_term"):
-             current_sae.use_error_term = current_sae._original_use_error_term  # type: ignore
-             delattr(current_sae, "_original_use_error_term")
+         # Determine output hook location (different from input for transcoders)
+         output_hook = self._transcoder_output_hooks.pop(act_name, act_name)
+
+         # For transcoders, clear permanent hooks from input hook point
+         if output_hook != act_name:
+             input_hook_point = get_deep_attr(self, act_name)
+             if isinstance(input_hook_point, HookPoint):
+                 input_hook_point.remove_hooks(dir="fwd", including_permanent=True)
+
+         # Reset output hook location
+         set_deep_attr(self, output_hook, HookPoint())
+         del self._acts_to_saes[act_name]

-         if prev_sae:
-             set_deep_attr(self, act_name, prev_sae)
-             self.acts_to_saes[act_name] = prev_sae
-         else:
-             set_deep_attr(self, act_name, HookPoint())
-             del self.acts_to_saes[act_name]
+         if prev_wrapper is not None:
+             # Rebuild hook_dict before adding new SAE
+             self.setup()
+             self.add_sae(prev_wrapper.sae, use_error_term=prev_wrapper.use_error_term)

      def reset_saes(
          self,
          act_names: str | list[str] | None = None,
-         prev_saes: list[SAE[Any] | None] | None = None,
-     ):
-         """Reset the SAEs attached to the model
+     ) -> None:
+         """Reset the SAEs attached to the model.

-         If act_names are provided will just reset SAEs attached to those hooks. Otherwise will reset all SAEs attached to the model.
-         Optionally can provide a list of prev_saes to reset to. This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes).
+         If act_names are provided will just reset SAEs attached to those hooks.
+         Otherwise will reset all SAEs attached to the model.

          Args:
-             act_names (str | list[str] | None): The act_names of the SAEs to reset. If None, will reset all SAEs attached to the model. Defaults to None.
-             prev_saes (list[SAE | None] | None): List of SAEs to replace the current ones with. If None, will just remove the SAEs. Defaults to None.
+             act_names: The act_names of the SAEs to reset. If None, will reset
+                 all SAEs attached to the model. Defaults to None.
          """
          if isinstance(act_names, str):
              act_names = [act_names]
          elif act_names is None:
-             act_names = list(self.acts_to_saes.keys())
-
-         if prev_saes:
-             if len(act_names) != len(prev_saes):
-                 raise ValueError("act_names and prev_saes must have the same length")
-         else:
-             prev_saes = [None] * len(act_names)  # type: ignore
+             act_names = list(self._acts_to_saes.keys())

-         for act_name, prev_sae in zip(act_names, prev_saes):  # type: ignore
-             self._reset_sae(act_name, prev_sae)
+         for act_name in act_names:
+             self._reset_sae(act_name)

          self.setup()

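
A quick way to see why the error term described in add_sae leaves the forward
pass unchanged while keeping gradients flowing through the SAE: the wrapper
returns sae_out + (original_output - sae_out.detach()), which equals
original_output numerically. A toy check with plain tensors (no SAE needed):

    import torch

    original = torch.randn(8)
    sae_out = torch.randn(8, requires_grad=True)  # stand-in for the reconstruction

    patched = sae_out + (original - sae_out.detach())
    assert torch.allclose(patched, original)            # values are unchanged
    patched.sum().backward()
    assert torch.allclose(sae_out.grad, torch.ones(8))  # grads flow through sae_out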

@@ -269,41 +377,32 @@
          reset_saes_end: bool = True,
          use_error_term: bool | None = None,
      ):
-         """
-         A context manager for adding temporary SAEs to the model.
-         See HookedTransformer.hooks for a similar context manager for hooks.
-         By default will keep track of previously attached SAEs, and restore them when the context manager exits.
-
-         Example:
-
-         .. code-block:: python
-
-             from transformer_lens import HookedSAETransformer
-             from sae_lens.saes.sae import SAE
-
-             model = HookedSAETransformer.from_pretrained('gpt2-small')
-             sae_cfg = SAEConfig(...)
-             sae = SAE(sae_cfg)
-             with model.saes(saes=[sae]):
-                 spliced_logits = model(text)
+         """A context manager for adding temporary SAEs to the model.

+         See HookedTransformer.hooks for a similar context manager for hooks.
+         By default will keep track of previously attached SAEs, and restore
+         them when the context manager exits.

          Args:
-             saes (SAE | list[SAE]): SAEs to be attached.
-             reset_saes_end (bool): If True, removes all SAEs added by this context manager when the context manager exits, returning previously attached SAEs to their original state.
-             use_error_term (bool | None): If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.
+             saes: SAEs to be attached.
+             reset_saes_end: If True, removes all SAEs added by this context
+                 manager when the context manager exits, returning previously
+                 attached SAEs to their original state.
+             use_error_term: If provided, will set the use_error_term attribute
+                 of all SAEs attached during this run to this value.
          """
-         act_names_to_reset = []
-         prev_saes = []
+         saes_to_restore: list[tuple[str, _SAEWrapper | None]] = []
          if isinstance(saes, SAE):
              saes = [saes]
          try:
              for sae in saes:
-                 act_names_to_reset.append(sae.cfg.metadata.hook_name)
-                 prev_sae = self.acts_to_saes.get(sae.cfg.metadata.hook_name, None)
-                 prev_saes.append(prev_sae)
+                 act_name = sae.cfg.metadata.hook_name
+                 prev_wrapper = self._acts_to_saes.get(act_name, None)
+                 saes_to_restore.append((act_name, prev_wrapper))
                  self.add_sae(sae, use_error_term=use_error_term)
              yield self
          finally:
              if reset_saes_end:
-                 self.reset_saes(act_names_to_reset, prev_saes)
+                 for act_name, prev_wrapper in saes_to_restore:
+                     self._reset_sae(act_name, prev_wrapper)
+                 self.setup()
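
Usage still matches the docstring example removed above; a hedged sketch of
splicing an SAE in temporarily (the release and SAE ids passed to
SAE.from_pretrained, and the prompt, are illustrative assumptions):

    from sae_lens import SAE, HookedSAETransformer

    model = HookedSAETransformer.from_pretrained("gpt2-small")
    sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

    # SAEs added here are detached on exit, and any previously attached
    # SAEs are restored via _reset_sae(act_name, prev_wrapper).
    with model.saes(saes=[sae], use_error_term=True):
        spliced_logits = model("Hello world")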