sae-lens 6.12.2__py3-none-any.whl → 6.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.12.2"
+__version__ = "6.13.0"
 
 import logging
 
sae_lens/evals.py CHANGED
@@ -466,6 +466,8 @@ def get_sparsity_and_variance_metrics(
         sae_out_scaled = sae.decode(sae_feature_activations).to(
             original_act_scaled.device
         )
+        if sae_feature_activations.is_sparse:
+            sae_feature_activations = sae_feature_activations.to_dense()
         del cache
 
         sae_out = activation_scaler.unscale(sae_out_scaled)
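Note: this guard exists because TopK SAEs can now emit feature activations as sparse COO tensors (see topk_sae.py below), while the metric code that follows indexes and reduces activations as ordinary strided tensors. A minimal sketch of the conversion, using a hypothetical activation tensor:

    import torch

    # Stand-in for sae_feature_activations returned by a sparse TopK encode
    dense = torch.tensor([[0.0, 1.5, 0.0], [2.0, 0.0, 0.0]])
    acts = dense.to_sparse_coo()

    if acts.is_sparse:
        acts = acts.to_dense()  # downstream eval metrics assume a strided layout
    assert torch.equal(acts, dense)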
sae_lens/loading/pretrained_sae_loaders.py CHANGED
@@ -233,6 +233,12 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
         "reshape_activations",
         "hook_z" if "hook_z" in new_cfg.get("hook_name", "") else "none",
     )
+    if (
+        new_cfg.get("activation_fn") == "topk"
+        and new_cfg.get("activation_fn_kwargs", {}).get("k") is not None
+    ):
+        new_cfg["architecture"] = "topk"
+        new_cfg["k"] = new_cfg["activation_fn_kwargs"]["k"]
 
     if "normalize_activations" in new_cfg and isinstance(
         new_cfg["normalize_activations"], bool
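This migration branch rewrites pre-6.0 configs, which expressed TopK through activation_fn/activation_fn_kwargs, into the current architecture-based layout, so old checkpoints load as proper TopK SAEs. A sketch of the effect, applied to a hypothetical pre-6.0 config dict:

    new_cfg = {"activation_fn": "topk", "activation_fn_kwargs": {"k": 64}}
    # Same logic as the hunk above:
    if (
        new_cfg.get("activation_fn") == "topk"
        and new_cfg.get("activation_fn_kwargs", {}).get("k") is not None
    ):
        new_cfg["architecture"] = "topk"
        new_cfg["k"] = new_cfg["activation_fn_kwargs"]["k"]
    assert new_cfg["architecture"] == "topk" and new_cfg["k"] == 64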
sae_lens/pretokenize_runner.py CHANGED
@@ -1,9 +1,10 @@
 import io
 import json
 import sys
+from collections.abc import Iterator
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Iterator, Literal, cast
+from typing import Literal, cast
 
 import torch
 from datasets import Dataset, DatasetDict, load_dataset
sae_lens/saes/sae.py CHANGED
@@ -14,7 +14,6 @@ from typing import (
     Generic,
     Literal,
     NamedTuple,
-    Type,
     TypeVar,
 )
 
@@ -534,7 +533,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     @classmethod
     @deprecated("Use load_from_disk instead")
     def load_from_pretrained(
-        cls: Type[T_SAE],
+        cls: type[T_SAE],
         path: str | Path,
         device: str = "cpu",
         dtype: str | None = None,
@@ -543,7 +542,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
     @classmethod
     def load_from_disk(
-        cls: Type[T_SAE],
+        cls: type[T_SAE],
         path: str | Path,
         device: str = "cpu",
         dtype: str | None = None,
@@ -564,7 +563,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
     @classmethod
     def from_pretrained(
-        cls: Type[T_SAE],
+        cls: type[T_SAE],
         release: str,
         sae_id: str,
         device: str = "cpu",
@@ -585,7 +584,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
     @classmethod
     def from_pretrained_with_cfg_and_sparsity(
-        cls: Type[T_SAE],
+        cls: type[T_SAE],
         release: str,
         sae_id: str,
         device: str = "cpu",
@@ -684,7 +683,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         return sae, cfg_dict, log_sparsities
 
     @classmethod
-    def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
+    def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
         """Create an SAE from a config dictionary."""
         sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
         sae_config_cls = cls.get_sae_config_class_for_architecture(
@@ -694,8 +693,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
     @classmethod
     def get_sae_class_for_architecture(
-        cls: Type[T_SAE], architecture: str
-    ) -> Type[T_SAE]:
+        cls: type[T_SAE], architecture: str
+    ) -> type[T_SAE]:
         """Get the SAE class for a given architecture."""
         sae_cls, _ = get_sae_class(architecture)
         if not issubclass(sae_cls, cls):
@@ -1000,8 +999,8 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
 
     @classmethod
     def get_sae_class_for_architecture(
-        cls: Type[T_TRAINING_SAE], architecture: str
-    ) -> Type[T_TRAINING_SAE]:
+        cls: type[T_TRAINING_SAE], architecture: str
+    ) -> type[T_TRAINING_SAE]:
         """Get the SAE class for a given architecture."""
        sae_cls, _ = get_sae_training_class(architecture)
        if not issubclass(sae_cls, cls):
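All of the Type[...] to type[...] edits in this file are mechanical PEP 585 cleanups: since Python 3.9 the builtin type is subscriptable, so typing.Type is redundant. A minimal illustration on a hypothetical class (not from sae_lens):

    from typing import TypeVar

    T = TypeVar("T", bound="Base")

    class Base:
        @classmethod
        def create(cls: type[T]) -> T:
            # builtin generic: no `from typing import Type` needed on 3.9+
            return cls()

    assert isinstance(Base.create(), Base)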
sae_lens/saes/topk_sae.py CHANGED
@@ -6,6 +6,7 @@ from typing import Callable
 import torch
 from jaxtyping import Float
 from torch import nn
+from transformer_lens.hook_points import HookPoint
 from typing_extensions import override
 
 from sae_lens.saes.sae import (
@@ -15,34 +16,102 @@ from sae_lens.saes.sae import (
     TrainingSAE,
     TrainingSAEConfig,
     TrainStepInput,
+    _disable_hooks,
 )
 
 
+class SparseHookPoint(HookPoint):
+    """
+    A HookPoint that takes in a sparse tensor.
+    Overrides TransformerLens's HookPoint.
+    """
+
+    def __init__(self, d_sae: int):
+        super().__init__()
+        self.d_sae = d_sae
+
+    @override
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        using_hooks = (
+            self._forward_hooks is not None and len(self._forward_hooks) > 0
+        ) or (self._backward_hooks is not None and len(self._backward_hooks) > 0)
+        if using_hooks and x.is_sparse:
+            return x.to_dense()
+        return x  # if no hooks are being used, use passthrough
+
+
 class TopK(nn.Module):
     """
     A simple TopK activation that zeroes out all but the top K elements along the last dimension,
     and applies ReLU to the top K elements.
     """
 
-    b_enc: nn.Parameter
+    use_sparse_activations: bool
 
     def __init__(
         self,
         k: int,
+        use_sparse_activations: bool = False,
     ):
         super().__init__()
         self.k = k
+        self.use_sparse_activations = use_sparse_activations
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
         """
         1) Select top K elements along the last dimension.
         2) Apply ReLU.
         3) Zero out all other entries.
         """
-        topk = torch.topk(x, k=self.k, dim=-1)
-        values = topk.values.relu()
+        topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1, sorted=False)
+        values = topk_values.relu()
+        if self.use_sparse_activations:
+            # Produce a COO sparse tensor (use sparse matrix multiply in decode)
+            original_shape = x.shape
+
+            # Create indices for all dimensions
+            # For each element in topk_indices, we need to map it back to the original tensor coordinates
+            batch_dims = original_shape[:-1]  # All dimensions except the last one
+            num_batch_elements = torch.prod(torch.tensor(batch_dims)).item()
+
+            # Create batch indices - each batch element repeated k times
+            batch_indices_flat = torch.arange(
+                num_batch_elements, device=x.device
+            ).repeat_interleave(self.k)
+
+            # Convert flat batch indices back to multi-dimensional indices
+            if len(batch_dims) == 1:
+                # 2D case: [batch, features]
+                sparse_indices = torch.stack(
+                    [
+                        batch_indices_flat,
+                        topk_indices.flatten(),
+                    ]
+                )
+            else:
+                # 3D+ case: need to unravel the batch indices
+                batch_indices_multi = []
+                remaining = batch_indices_flat
+                for dim_size in reversed(batch_dims):
+                    batch_indices_multi.append(remaining % dim_size)
+                    remaining = remaining // dim_size
+                batch_indices_multi.reverse()
+
+                sparse_indices = torch.stack(
+                    [
+                        *batch_indices_multi,
+                        topk_indices.flatten(),
+                    ]
+                )
+
+            return torch.sparse_coo_tensor(
+                sparse_indices, values.flatten(), original_shape
+            )
         result = torch.zeros_like(x)
-        result.scatter_(-1, topk.indices, values)
+        result.scatter_(-1, topk_indices, values)
         return result
 
 
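For the common 2D case, the sparse branch above amounts to pairing each row index (repeated k times) with that row's top-k column indices. A small worked example under that assumption, with batch 2, d_sae 4, k 2:

    import torch

    x = torch.tensor([[0.1, 0.9, -0.5, 0.7],
                      [0.8, -0.2, 0.3, 0.0]])
    k = 2
    values, indices = torch.topk(x, k=k, dim=-1, sorted=False)
    rows = torch.arange(x.shape[0]).repeat_interleave(k)
    coo = torch.sparse_coo_tensor(
        torch.stack([rows, indices.flatten()]),
        values.relu().flatten(),
        x.shape,
    )
    # Densifying recovers the dense TopK output: ReLU'd top-k values, zeros elsewhere
    dense = torch.zeros_like(x).scatter_(-1, indices, values.relu())
    assert torch.equal(coo.to_dense(), dense)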
@@ -60,6 +129,63 @@ class TopKSAEConfig(SAEConfig):
         return "topk"
 
 
+def _sparse_matmul_nd(
+    sparse_tensor: torch.Tensor, dense_matrix: torch.Tensor
+) -> torch.Tensor:
+    """
+    Multiply a sparse tensor of shape [..., d_sae] with a dense matrix of shape [d_sae, d_out]
+    to get a result of shape [..., d_out].
+
+    This function handles sparse tensors with arbitrary batch dimensions by flattening
+    the batch dimensions, performing 2D sparse matrix multiplication, and reshaping back.
+    """
+    original_shape = sparse_tensor.shape
+    batch_dims = original_shape[:-1]
+    d_sae = original_shape[-1]
+    d_out = dense_matrix.shape[-1]
+
+    if sparse_tensor.ndim == 2:
+        # Simple 2D case - use torch.sparse.mm directly
+        # sparse.mm errors with bfloat16 :(
+        with torch.autocast(device_type=sparse_tensor.device.type, enabled=False):
+            return torch.sparse.mm(sparse_tensor, dense_matrix)
+
+    # For 3D+ case, reshape to 2D, multiply, then reshape back
+    batch_size = int(torch.prod(torch.tensor(batch_dims)).item())
+
+    # Ensure tensor is coalesced for efficient access to indices/values
+    if not sparse_tensor.is_coalesced():
+        sparse_tensor = sparse_tensor.coalesce()
+
+    # Get indices and values
+    indices = sparse_tensor.indices()  # [ndim, nnz]
+    values = sparse_tensor.values()  # [nnz]
+
+    # Convert multi-dimensional batch indices to flat indices
+    flat_batch_indices = torch.zeros_like(indices[0])
+    multiplier = 1
+    for i in reversed(range(len(batch_dims))):
+        flat_batch_indices += indices[i] * multiplier
+        multiplier *= batch_dims[i]
+
+    # Create 2D sparse tensor indices [batch_flat, feature]
+    sparse_2d_indices = torch.stack([flat_batch_indices, indices[-1]])
+
+    # Create 2D sparse tensor
+    sparse_2d = torch.sparse_coo_tensor(
+        sparse_2d_indices, values, (batch_size, d_sae)
+    ).coalesce()
+
+    # sparse.mm errors with bfloat16 :(
+    with torch.autocast(device_type=sparse_tensor.device.type, enabled=False):
+        # Do the matrix multiplication
+        result_2d = torch.sparse.mm(sparse_2d, dense_matrix)  # [batch_size, d_out]
+
+    # Reshape back to original batch dimensions
+    result_shape = tuple(batch_dims) + (d_out,)
+    return result_2d.view(result_shape)
+
+
 class TopKSAE(SAE[TopKSAEConfig]):
     """
     An inference-only sparse autoencoder using a "topk" activation function.
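The helper's contract is that the sparse path agrees with a plain dense matmul; the autocast(..., enabled=False) blocks work around torch.sparse.mm rejecting bfloat16. A quick equivalence check for the 2D fast path, with illustrative dimensions:

    import torch

    batch, d_sae, d_out = 4, 16, 8
    dense_acts = torch.relu(torch.randn(batch, d_sae))
    W_dec = torch.randn(d_sae, d_out)

    out_sparse = torch.sparse.mm(dense_acts.to_sparse_coo(), W_dec)
    out_dense = dense_acts @ W_dec
    assert torch.allclose(out_sparse, out_dense, atol=1e-5)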
@@ -96,21 +222,26 @@ class TopKSAE(SAE[TopKSAEConfig]):
         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
 
     def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+        self,
+        feature_acts: Float[torch.Tensor, "... d_sae"],
     ) -> Float[torch.Tensor, "... d_in"]:
         """
         Reconstructs the input from topk feature activations.
         Applies optional finetuning scaling, hooking to recons, out normalization,
         and optional head reshaping.
         """
-        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        # Handle sparse tensors using efficient sparse matrix multiplication
+        if feature_acts.is_sparse:
+            sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
+        else:
+            sae_out_pre = feature_acts @ self.W_dec + self.b_dec
         sae_out_pre = self.hook_sae_recons(sae_out_pre)
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
         return self.reshape_fn_out(sae_out_pre, self.d_head)
 
     @override
     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
-        return TopK(self.cfg.k)
+        return TopK(self.cfg.k, use_sparse_activations=False)
 
     @override
     @torch.no_grad()
@@ -124,9 +255,43 @@
 class TopKTrainingSAEConfig(TrainingSAEConfig):
     """
     Configuration class for training a TopKTrainingSAE.
+
+    Args:
+        k (int): Number of top features to keep active. Only the top k features
+            with the highest pre-activations will be non-zero. Defaults to 100.
+        use_sparse_activations (bool): Whether to use sparse tensor representations
+            for activations during training. This can reduce memory usage and improve
+            performance when k is small relative to d_sae, but is only worthwhile if
+            using float32 and not using autocast. Defaults to False.
+        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
+            dead neurons to learn useful features. This loss helps prevent neuron death
+            in TopK SAEs by having dead neurons reconstruct the residual error from
+            live neurons. Defaults to 1.0.
+        decoder_init_norm (float | None): Norm to initialize decoder weights to.
+            0.1 corresponds to the "heuristic" initialization from Anthropic's April update.
+            Use None to disable. Inherited from TrainingSAEConfig. Defaults to 0.1.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
+        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
+            before encoding. Inherited from SAEConfig. Defaults to True.
+        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
+            Normalization strategy for input activations. Inherited from SAEConfig.
+            Defaults to "none".
+        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
+            (useful for attention head outputs). Inherited from SAEConfig.
+            Defaults to "none".
+        metadata (SAEMetadata): Metadata about the SAE training (model name, hook name, etc.).
+            Inherited from SAEConfig.
     """
 
     k: int = 100
+    use_sparse_activations: bool = False
     aux_loss_coefficient: float = 1.0
 
     @override
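With the new field, enabling sparse activations for training is a one-line config change. A hypothetical construction (dimensions are illustrative; the remaining fields take the defaults documented above):

    from sae_lens.saes.topk_sae import TopKTrainingSAEConfig

    cfg = TopKTrainingSAEConfig(
        d_in=768,
        d_sae=16384,
        k=64,
        # per the docstring: mainly worthwhile with float32 and no autocast
        use_sparse_activations=True,
    )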
@@ -144,6 +309,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
 
     def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
         super().__init__(cfg, use_error_term)
+        self.hook_sae_acts_post = SparseHookPoint(self.cfg.d_sae)
+        self.setup()
 
     @override
     def initialize_weights(self) -> None:
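Swapping hook_sae_acts_post for a SparseHookPoint keeps hook consumers working: the point densifies only when a hook is actually attached, and passes sparse tensors through untouched otherwise. A sketch of that behavior, assuming TransformerLens's add_hook API:

    import torch
    from sae_lens.saes.topk_sae import SparseHookPoint

    hp = SparseHookPoint(d_sae=4)
    acts = torch.tensor([[0.0, 1.0, 0.0, 2.0]]).to_sparse_coo()

    assert hp(acts).is_sparse  # no hooks attached: passthrough

    hp.add_hook(lambda tensor, hook: tensor)  # hooks expect dense tensors
    assert not hp(acts).is_sparse  # densified so the hook sees a normal tensor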
@@ -163,6 +330,41 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
         return feature_acts, hidden_pre
 
+    @override
+    def decode(
+        self,
+        feature_acts: Float[torch.Tensor, "... d_sae"],
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Decodes feature activations back into input space,
+        applying optional finetuning scale, hooking, out normalization, etc.
+        """
+        # Handle sparse tensors using efficient sparse matrix multiplication
+        if feature_acts.is_sparse:
+            sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
+        else:
+            sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+    @override
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass through the SAE."""
+        feature_acts = self.encode(x)
+        sae_out = self.decode(feature_acts)
+
+        if self.use_error_term:
+            with torch.no_grad():
+                # Recompute without hooks for true error term
+                with _disable_hooks(self):
+                    feature_acts_clean = self.encode(x)
+                    x_reconstruct_clean = self.decode(feature_acts_clean)
+                sae_error = self.hook_sae_error(x - x_reconstruct_clean)
+            sae_out = sae_out + sae_error
+
+        return self.hook_sae_output(sae_out)
+
     @override
     def calculate_aux_loss(
         self,
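The new forward override reproduces the base class's error-term logic on top of the sparse-aware encode/decode: with use_error_term set, the output is the hooked reconstruction plus the residual of a hook-free reconstruction, so the SAE is exact whenever no hook intervenes. A conceptual sketch of that invariant (not the method itself):

    import torch

    def forward_with_error_term(x, encode, decode):
        sae_out = decode(encode(x))  # hooks may edit activations here
        with torch.no_grad():
            sae_error = x - decode(encode(x))  # recomputed hook-free in the real code
        return sae_out + sae_error

    x = torch.randn(2, 8)
    lossy = lambda t: 0.5 * t  # stand-in for an imperfect reconstruction
    out = forward_with_error_term(x, lambda t: t, lossy)
    assert torch.allclose(out, x)  # the error term restores the input exactly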
@@ -189,7 +391,7 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
 
     @override
     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
-        return TopK(self.cfg.k)
+        return TopK(self.cfg.k, use_sparse_activations=self.cfg.use_sparse_activations)
 
     @override
     def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
sae_lens/training/sae_trainer.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Generator, Iterator
+from collections.abc import Generator, Iterator
 
 import torch
 
@@ -253,12 +253,14 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
         )
 
         with torch.no_grad():
-            did_fire = (train_step_output.feature_acts > 0).float().sum(-2) > 0
+            # calling .bool() should be equivalent to .abs() > 0, and work with coo tensors
+            firing_feats = train_step_output.feature_acts.bool().float()
+            did_fire = firing_feats.sum(-2).bool()
+            if did_fire.is_sparse:
+                did_fire = did_fire.to_dense()
             self.n_forward_passes_since_fired += 1
             self.n_forward_passes_since_fired[did_fire] = 0
-            self.act_freq_scores += (
-                (train_step_output.feature_acts.abs() > 0).float().sum(0)
-            )
+            self.act_freq_scores += firing_feats.sum(0)
             self.n_frac_active_samples += self.cfg.train_batch_size_samples
 
         # Grad scaler will rescale gradients if autocast is enabled
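The switch to .bool() matters because comparison operators like > 0 are not implemented for sparse COO tensors, while dtype casts are; on ReLU'd activations .bool() matches .abs() > 0, as the inline comment notes. A small check of that equivalence:

    import torch

    acts = torch.tensor([[0.0, 2.5, 0.0], [0.0, 0.0, 1.0]])
    sparse = acts.to_sparse_coo()

    fired_dense = (acts.abs() > 0).float()
    fired_sparse = sparse.bool().float()  # the cast works on COO tensors
    assert torch.equal(fired_sparse.to_dense(), fired_dense)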
@@ -310,7 +312,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
         loss = output.loss.item()
 
         # metrics for current acts
-        l0 = (feature_acts > 0).float().sum(-1).mean()
+        l0 = feature_acts.bool().float().sum(-1).to_dense().mean()
         current_learning_rate = self.optimizer.param_groups[0]["lr"]
 
         per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=-1).squeeze()
@@ -1,4 +1,4 @@
-from typing import Iterator
+from collections.abc import Iterator
 
 import torch
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.12.2
+Version: 6.13.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -1,39 +1,39 @@
-sae_lens/__init__.py,sha256=TptOdqP3B6E_TTQ4n6DXAIDA9c1_9LUUsDkoqyrSSBg,3589
+sae_lens/__init__.py,sha256=6cL-2l4CIzZJfgyRP5I90zu2Tty196wOgFg1JGlQd1c,3589
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=vRu6JseH1lZaEeILD5bEkQEQ1wYHHDcxD-f2olKmE9Y,14275
 sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
 sae_lens/cache_activations_runner.py,sha256=cNeAtp2JQ_vKbeddZVM-tcPLYyyfTWL8NDna5KQpkLI,12583
 sae_lens/config.py,sha256=IdRXSKPfYY3hwUovj-u83eep8z52gkJHII0mY0KseYY,28739
 sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
-sae_lens/evals.py,sha256=4hanbyG8qZLItWqft94F4ZjUoytPVB7fw5s0P4Oi0VE,39504
+sae_lens/evals.py,sha256=p4AOueeemhJXyfLx2TxOva8LXxXj63JSKe9Lnib3mHs,39623
 sae_lens/llm_sae_training_runner.py,sha256=sJTcDX1bUJJ_jZLUT88-8KUYIAPeUGoXktX68PsBqw0,15137
 sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
 sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/loading/pretrained_sae_loaders.py,sha256=CVzHntSUKR1X3_gAqn8K_Ajq8D85qBrmrgEgU93IV4A,49609
+sae_lens/loading/pretrained_sae_loaders.py,sha256=SM4aT8NM6ezYix5c2u7p72Fz2RfvTtf7gw5RdOSKXhc,49846
 sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
-sae_lens/pretokenize_runner.py,sha256=w0f6SfZLAxbp5eAAKnet8RqUB_DKofZ9RGsoJwFnYbA,7058
+sae_lens/pretokenize_runner.py,sha256=x-reJzVPFDS9iRFbZtrFYSzNguJYki9gd0pbHjYJ3r4,7085
 sae_lens/pretrained_saes.yaml,sha256=6ca3geEB6NyhULUrmdtPDK8ea0YdpLp8_au78vIFC5w,602553
 sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
 sae_lens/saes/__init__.py,sha256=jVwazK8Q6dW5J6_zFXPoNAuBvSxgziQ8eMOjGM3t-X8,1475
 sae_lens/saes/batchtopk_sae.py,sha256=GX_J0vH4vzeLqYxl0mkfsZQpFEoCEHMR4dIG8fz8N8w,3449
 sae_lens/saes/gated_sae.py,sha256=qcmM9JwBA8aZR8z_IRHV1_gQX-q_63tKewWXRnhdXuo,8986
 sae_lens/saes/jumprelu_sae.py,sha256=HHBF1sJ95lZvxwP5vwLSQFKdnJN2KKYK0WAEaLTrta0,13399
-sae_lens/saes/sae.py,sha256=McpF4pTh70r6SQUbHFm0YQ9X2c2qPULBUSd_YmnEk4Y,38284
+sae_lens/saes/sae.py,sha256=nuII6ZmaVtJWhPjyhasHQyiv_Wj-zdAtRQqJRYbVBQs,38274
 sae_lens/saes/standard_sae.py,sha256=9UqYyYtQuThYxXKNaDjYcyowpOx2-7cShG-TeUP6JCQ,5940
-sae_lens/saes/topk_sae.py,sha256=CXMBI6CFvI5829bOhoQ350VXR9d8uFHUDlULTIWHXoU,8686
+sae_lens/saes/topk_sae.py,sha256=pM26I9uDeh_ZWx0HXUyPVFfEV2pfuRJmAPNWR5pmRhY,17615
 sae_lens/saes/transcoder.py,sha256=BfLSbTYVNZh-ruGxseZiZJ_acEL6_7QyTdfqUr0lDOg,12156
-sae_lens/tokenization_and_batching.py,sha256=now7caLbU3p-iGokNwmqZDyIvxYoXgnG1uklhgiLZN4,4656
+sae_lens/tokenization_and_batching.py,sha256=jV7Rx5wHHcYMmexFhvbSk2q5R0gYBjtKoJKpowAgMEo,4665
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
 sae_lens/training/activations_store.py,sha256=2EUY2abqpT5El3T95sypM_JRDgiKL3VeT73U9SQIFGY,32903
 sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
 sae_lens/training/optim.py,sha256=TiI9nbffzXNsI8WjcIsqa2uheW6suxqL_KDDmWXobWI,5312
-sae_lens/training/sae_trainer.py,sha256=Jh5AyBGtfZjnprv9H3k0p_luWWnM7YFjlmHuO1W_J6U,15465
-sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
+sae_lens/training/sae_trainer.py,sha256=il4Evf-c4F3Uf2n_v-AOItCasX-uPxYTzn_sZLvLkl0,15633
+sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
 sae_lens/util.py,sha256=lW7fBn_b8quvRYlen9PUmB7km60YhKyjmuelB1f6KzQ,2253
-sae_lens-6.12.2.dist-info/METADATA,sha256=m8hF8tj-b70b5iAvN21ZDxOXzRHRxDmwpJclZzHqPw4,5318
-sae_lens-6.12.2.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
-sae_lens-6.12.2.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
-sae_lens-6.12.2.dist-info/RECORD,,
+sae_lens-6.13.0.dist-info/METADATA,sha256=rqSlR_xjf3fqZga4OHpNtrhKzaA4tIrobj-e6yq8sbA,5318
+sae_lens-6.13.0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+sae_lens-6.13.0.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.13.0.dist-info/RECORD,,