sae-lens 6.32.1__py3-none-any.whl → 6.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +1 -1
- sae_lens/loading/pretrained_saes_directory.py +0 -22
- sae_lens/saes/sae.py +0 -31
- sae_lens/synthetic/__init__.py +13 -0
- sae_lens/synthetic/correlation.py +12 -14
- sae_lens/synthetic/stats.py +205 -0
- {sae_lens-6.32.1.dist-info → sae_lens-6.33.0.dist-info}/METADATA +1 -1
- {sae_lens-6.32.1.dist-info → sae_lens-6.33.0.dist-info}/RECORD +10 -9
- {sae_lens-6.32.1.dist-info → sae_lens-6.33.0.dist-info}/WHEEL +0 -0
- {sae_lens-6.32.1.dist-info → sae_lens-6.33.0.dist-info}/licenses/LICENSE +0 -0
sae_lens/loading/pretrained_saes_directory.py
CHANGED
```diff
@@ -57,28 +57,6 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
     return directory
 
 
-def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
-    """
-    Retrieve the norm_scaling_factor for a specific SAE if it exists.
-
-    Args:
-        release (str): The release name of the SAE.
-        sae_id (str): The ID of the specific SAE.
-
-    Returns:
-        float | None: The norm_scaling_factor if it exists, None otherwise.
-    """
-    package = "sae_lens"
-    yaml_file = files(package).joinpath("pretrained_saes.yaml")
-    with yaml_file.open("r") as file:
-        data = yaml.safe_load(file)
-    if release in data:
-        for sae_info in data[release]["saes"]:
-            if sae_info["id"] == sae_id:
-                return sae_info.get("norm_scaling_factor")
-    return None
-
-
 def get_repo_id_and_folder_name(release: str, sae_id: str) -> tuple[str, str]:
     saes_directory = get_pretrained_saes_directory()
     sae_info = saes_directory.get(release, None)
```
sae_lens/saes/sae.py
CHANGED
```diff
@@ -45,7 +45,6 @@ from sae_lens.loading.pretrained_sae_loaders import (
 )
 from sae_lens.loading.pretrained_saes_directory import (
     get_config_overrides,
-    get_norm_scaling_factor,
     get_pretrained_saes_directory,
     get_releases_for_repo_id,
     get_repo_id_and_folder_name,
@@ -638,24 +637,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
                 stacklevel=2,
             )
         elif sae_id not in sae_directory[release].saes_map:
-            # Handle special cases like Gemma Scope
-            if (
-                "gemma-scope" in release
-                and "canonical" not in release
-                and f"{release}-canonical" in sae_directory
-            ):
-                canonical_ids = list(
-                    sae_directory[release + "-canonical"].saes_map.keys()
-                )
-                # Shorten the lengthy string of valid IDs
-                if len(canonical_ids) > 5:
-                    str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
-                else:
-                    str_canonical_ids = str(canonical_ids)
-                value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
-            else:
-                value_suffix = ""
-
             valid_ids = list(sae_directory[release].saes_map.keys())
             # Shorten the lengthy string of valid IDs
             if len(valid_ids) > 5:
@@ -665,7 +646,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
             raise ValueError(
                 f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
-                + value_suffix
             )
 
         conversion_loader = (
@@ -702,17 +682,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         sae.process_state_dict_for_loading(state_dict)
         sae.load_state_dict(state_dict, assign=True)
 
-        # Apply normalization if needed
-        if cfg_dict.get("normalize_activations") == "expected_average_only_in":
-            norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
-            if norm_scaling_factor is not None:
-                sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
-                cfg_dict["normalize_activations"] = "none"
-            else:
-                warnings.warn(
-                    f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
-                )
-
         # the loaders should already handle the dtype / device conversion
         # but this is a fallback to guarantee the SAE is on the correct device and dtype
         return (
```
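Note on the behavioral change above: loading a pretrained SAE no longer looks up a `norm_scaling_factor` in `pretrained_saes.yaml` and folds it into the weights automatically. A minimal sketch of re-applying the fold by hand, assuming `SAE.from_pretrained` returns the SAE directly and that `fold_activation_norm_scaling_factor` remains available (this diff removes only its call site); the release name, SAE ID, and scaling factor below are hypothetical:

```python
# Hedged sketch, not the library's documented workflow: manually folding a
# known norm scaling factor after loading, mirroring the removed code above.
from sae_lens import SAE

sae = SAE.from_pretrained("release-name", "sae-id")  # hypothetical identifiers

# "expected_average_only_in" is the config value the removed code checked for.
if sae.cfg.normalize_activations == "expected_average_only_in":
    norm_scaling_factor = 1.0  # hypothetical; previously read from pretrained_saes.yaml
    sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
```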
sae_lens/synthetic/__init__.py
CHANGED
```diff
@@ -50,6 +50,13 @@ from sae_lens.synthetic.plotting import (
     find_best_feature_ordering_from_sae,
     plot_sae_feature_similarity,
 )
+from sae_lens.synthetic.stats import (
+    CorrelationMatrixStats,
+    SuperpositionStats,
+    compute_correlation_matrix_stats,
+    compute_low_rank_correlation_matrix_stats,
+    compute_superposition_stats,
+)
 from sae_lens.synthetic.training import (
     SyntheticActivationIterator,
     train_toy_sae,
@@ -80,6 +87,12 @@ __all__ = [
     "orthogonal_initializer",
     "FeatureDictionaryInitializer",
     "cosine_similarities",
+    # Statistics
+    "compute_correlation_matrix_stats",
+    "compute_low_rank_correlation_matrix_stats",
+    "compute_superposition_stats",
+    "CorrelationMatrixStats",
+    "SuperpositionStats",
     # Training utilities
     "SyntheticActivationIterator",
     "SyntheticDataEvalResult",
```
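The five new names re-exported here come from the new `stats.py` module (shown further below). A quick sketch of the import path; the `SimpleNamespace` stand-in is hypothetical and works only because `compute_superposition_stats` reads nothing but the `feature_vectors` attribute of its argument:

```python
# Sketch of the new top-level imports from sae_lens.synthetic. A real
# FeatureDictionary would normally be passed; this stand-in is hypothetical.
from types import SimpleNamespace

import torch

from sae_lens.synthetic import compute_superposition_stats

dictionary = SimpleNamespace(feature_vectors=torch.randn(64, 16))
stats = compute_superposition_stats(dictionary, percentiles=[90, 99])
print(stats.mean_max_abs_cos_sim, stats.mean_percentile_abs_cos_sim[90])
```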
sae_lens/synthetic/correlation.py
CHANGED
```diff
@@ -3,6 +3,7 @@ from typing import NamedTuple
 
 import torch
 
+from sae_lens import logger
 from sae_lens.util import str_to_dtype
 
 
@@ -268,7 +269,7 @@ def generate_random_correlation_matrix(
 def generate_random_low_rank_correlation_matrix(
     num_features: int,
     rank: int,
-    correlation_scale: float = 0.
+    correlation_scale: float = 0.075,
     seed: int | None = None,
     device: torch.device | str = "cpu",
     dtype: torch.dtype | str = torch.float32,
@@ -331,20 +332,17 @@ def generate_random_low_rank_correlation_matrix(
     factor_sq_sum = (factor**2).sum(dim=1)
     diag_term = 1 - factor_sq_sum
 
-    #
-
-
-
-
-
-
-
-
-        /
+    # alternatively, we can rescale each row independently to ensure the diagonal is 1
+    mask = diag_term < _MIN_DIAG
+    factor[mask, :] *= torch.sqrt((1 - _MIN_DIAG) / factor_sq_sum[mask].unsqueeze(1))
+    factor_sq_sum = (factor**2).sum(dim=1)
+    diag_term = 1 - factor_sq_sum
+
+    total_rescaled = mask.sum().item()
+    if total_rescaled > 0:
+        logger.warning(
+            f"{total_rescaled} / {num_features} rows were capped. Either reduce the rank or reduce the correlation_scale to avoid rescaling."
         )
-    factor = factor * scale
-    factor_sq_sum = (factor**2).sum(dim=1)
-    diag_term = 1 - factor_sq_sum
 
     return LowRankCorrelationMatrix(
         correlation_factor=factor, correlation_diag=diag_term
```
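The substantive change in this file: instead of rescaling the whole factor matrix globally, rows whose squared factor norm would push the implied diagonal term below `_MIN_DIAG` are now capped individually, with a warning naming how many rows were touched. A small sketch of the invariant the capping preserves; the argument values are arbitrary illustrations, not recommended settings:

```python
# Sketch: regardless of how many rows get capped, the implied dense matrix
# F @ F.T + diag(diag_term) keeps a unit diagonal.
import torch

from sae_lens.synthetic.correlation import generate_random_low_rank_correlation_matrix

low_rank = generate_random_low_rank_correlation_matrix(
    num_features=256,
    rank=32,
    correlation_scale=0.2,  # deliberately large; may trigger the new capping warning
    seed=0,
)
factor, diag_term = low_rank.correlation_factor, low_rank.correlation_diag
implied_diag = (factor**2).sum(dim=1) + diag_term
assert torch.allclose(implied_diag, torch.ones_like(implied_diag))
```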
sae_lens/synthetic/stats.py
ADDED
```diff
@@ -0,0 +1,205 @@
+from dataclasses import dataclass
+
+import torch
+
+from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
+from sae_lens.synthetic.feature_dictionary import FeatureDictionary
+
+
+@dataclass
+class CorrelationMatrixStats:
+    """Statistics computed from a correlation matrix."""
+
+    rms_correlation: float  # Root mean square of off-diagonal correlations
+    mean_correlation: float  # Mean of off-diagonal correlations (not absolute)
+    num_features: int
+
+
+@torch.no_grad()
+def compute_correlation_matrix_stats(
+    correlation_matrix: torch.Tensor,
+) -> CorrelationMatrixStats:
+    """Compute correlation statistics from a dense correlation matrix.
+
+    Args:
+        correlation_matrix: Dense correlation matrix of shape (n, n)
+
+    Returns:
+        CorrelationMatrixStats with correlation statistics
+    """
+    num_features = correlation_matrix.shape[0]
+
+    # Extract off-diagonal elements
+    mask = ~torch.eye(num_features, dtype=torch.bool, device=correlation_matrix.device)
+    off_diag = correlation_matrix[mask]
+
+    rms_correlation = (off_diag**2).mean().sqrt().item()
+    mean_correlation = off_diag.mean().item()
+
+    return CorrelationMatrixStats(
+        rms_correlation=rms_correlation,
+        mean_correlation=mean_correlation,
+        num_features=num_features,
+    )
+
+
+@torch.no_grad()
+def compute_low_rank_correlation_matrix_stats(
+    correlation_matrix: LowRankCorrelationMatrix,
+) -> CorrelationMatrixStats:
+    """Compute correlation statistics from a LowRankCorrelationMatrix.
+
+    The correlation matrix is represented as:
+        correlation = factor @ factor.T + diag(diag_term)
+
+    The off-diagonal elements are simply factor @ factor.T (the diagonal term
+    only affects the diagonal).
+
+    All statistics are computed efficiently in O(n*r²) time and O(r²) memory
+    without materializing the full n×n correlation matrix.
+
+    Args:
+        correlation_matrix: Low-rank correlation matrix
+
+    Returns:
+        CorrelationMatrixStats with correlation statistics
+    """
+
+    factor = correlation_matrix.correlation_factor
+    num_features = factor.shape[0]
+    num_off_diag = num_features * (num_features - 1)
+
+    # RMS correlation: uses ||F @ F.T||_F² = ||F.T @ F||_F²
+    # This avoids computing the (num_features, num_features) matrix
+    G = factor.T @ factor  # (rank, rank) - small!
+    frobenius_sq = (G**2).sum()
+    row_norms_sq = (factor**2).sum(dim=1)  # ||F[i]||² for each row
+    diag_sq_sum = (row_norms_sq**2).sum()  # Σᵢ ||F[i]||⁴
+    off_diag_sq_sum = frobenius_sq - diag_sq_sum
+    rms_correlation = (off_diag_sq_sum / num_off_diag).sqrt().item()
+
+    # Mean correlation (not absolute): sum(C) = ||col_sums(F)||², trace(C) = Σ||F[i]||²
+    col_sums = factor.sum(dim=0)  # (rank,)
+    sum_all = (col_sums**2).sum()  # 1ᵀ C 1
+    trace_C = row_norms_sq.sum()
+    mean_correlation = ((sum_all - trace_C) / num_off_diag).item()
+
+    return CorrelationMatrixStats(
+        rms_correlation=rms_correlation,
+        mean_correlation=mean_correlation,
+        num_features=num_features,
+    )
+
+
+@dataclass
+class SuperpositionStats:
+    """Statistics measuring superposition in a feature dictionary."""
+
+    # Per-latent statistics: for each latent, max and percentile of |cos_sim| with others
+    max_abs_cos_sims: torch.Tensor  # Shape: (num_features,)
+    percentile_abs_cos_sims: dict[int, torch.Tensor]  # {percentile: (num_features,)}
+
+    # Summary statistics (means of the per-latent values)
+    mean_max_abs_cos_sim: float
+    mean_percentile_abs_cos_sim: dict[int, float]
+    mean_abs_cos_sim: float  # Mean |cos_sim| across all pairs
+
+    # Metadata
+    num_features: int
+    hidden_dim: int
+
+
+@torch.no_grad()
+def compute_superposition_stats(
+    feature_dictionary: FeatureDictionary,
+    batch_size: int = 1024,
+    device: str | torch.device | None = None,
+    percentiles: list[int] | None = None,
+) -> SuperpositionStats:
+    """Compute superposition statistics for a feature dictionary.
+
+    Computes pairwise cosine similarities in batches to handle large dictionaries.
+
+    For each latent i, computes:
+
+    - max |cos_sim(i, j)| over all j != i
+    - kth percentile of |cos_sim(i, j)| over all j != i (for each k in percentiles)
+
+    Args:
+        feature_dictionary: FeatureDictionary containing the feature vectors
+        batch_size: Number of features to process per batch
+        device: Device for computation (defaults to feature dictionary's device)
+        percentiles: List of percentiles to compute per latent (default: [95, 99])
+
+    Returns:
+        SuperpositionStats with superposition metrics
+    """
+    if percentiles is None:
+        percentiles = [95, 99]
+
+    feature_vectors = feature_dictionary.feature_vectors
+    num_features, hidden_dim = feature_vectors.shape
+
+    if num_features < 2:
+        raise ValueError("Need at least 2 features to compute superposition stats")
+    if device is None:
+        device = feature_vectors.device
+
+    # Normalize features to unit norm for cosine similarity
+    features_normalized = feature_vectors.to(device).float()
+    norms = torch.linalg.norm(features_normalized, dim=1, keepdim=True)
+    features_normalized = features_normalized / norms.clamp(min=1e-8)
+
+    # Track per-latent statistics
+    max_abs_cos_sims = torch.zeros(num_features, device=device)
+    percentile_abs_cos_sims = {
+        p: torch.zeros(num_features, device=device) for p in percentiles
+    }
+    sum_abs_cos_sim = 0.0
+    n_pairs = 0
+
+    # Process in batches: for each batch of features, compute similarities with all others
+    for i in range(0, num_features, batch_size):
+        batch_end = min(i + batch_size, num_features)
+        batch = features_normalized[i:batch_end]  # (batch_size, hidden_dim)
+
+        # Compute cosine similarities with all features: (batch_size, num_features)
+        cos_sims = batch @ features_normalized.T
+
+        # Absolute cosine similarities
+        abs_cos_sims = cos_sims.abs()
+
+        # Process each latent in the batch
+        for j, idx in enumerate(range(i, batch_end)):
+            # Get similarities with all other features (exclude self)
+            row = abs_cos_sims[j].clone()
+            row[idx] = 0.0  # Exclude self for max
+            max_abs_cos_sims[idx] = row.max()
+
+            # For percentiles, exclude self and compute
+            other_sims = torch.cat([abs_cos_sims[j, :idx], abs_cos_sims[j, idx + 1 :]])
+            for p in percentiles:
+                percentile_abs_cos_sims[p][idx] = torch.quantile(other_sims, p / 100.0)
+
+            # Sum for mean computation (only count pairs once - with features after this one)
+            sum_abs_cos_sim += abs_cos_sims[j, idx + 1 :].sum().item()
+            n_pairs += num_features - idx - 1
+
+    # Compute summary statistics
+    mean_max_abs_cos_sim = max_abs_cos_sims.mean().item()
+    mean_percentile_abs_cos_sim = {
+        p: percentile_abs_cos_sims[p].mean().item() for p in percentiles
+    }
+    mean_abs_cos_sim = sum_abs_cos_sim / n_pairs if n_pairs > 0 else 0.0
+
+    return SuperpositionStats(
+        max_abs_cos_sims=max_abs_cos_sims.cpu(),
+        percentile_abs_cos_sims={
+            p: v.cpu() for p, v in percentile_abs_cos_sims.items()
+        },
+        mean_max_abs_cos_sim=mean_max_abs_cos_sim,
+        mean_percentile_abs_cos_sim=mean_percentile_abs_cos_sim,
+        mean_abs_cos_sim=mean_abs_cos_sim,
+        num_features=num_features,
+        hidden_dim=hidden_dim,
+    )
```
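The low-rank docstring leans on two identities: ||F @ F.T||_F² = ||F.T @ F||_F² for the RMS, and sum(C) = ||col_sums(F)||² for the mean. A sketch checking the low-rank path against the dense one, assuming a `LowRankCorrelationMatrix` can be constructed directly from its two fields, as in the `return` statement in `correlation.py`:

```python
# Sketch: the O(n*r^2) statistics should match the dense computation, since
# diag_term never affects off-diagonal entries.
import torch

from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
from sae_lens.synthetic.stats import (
    compute_correlation_matrix_stats,
    compute_low_rank_correlation_matrix_stats,
)

torch.manual_seed(0)
factor = 0.1 * torch.randn(100, 8)  # (num_features, rank)
diag_term = 1 - (factor**2).sum(dim=1)
low_rank = LowRankCorrelationMatrix(
    correlation_factor=factor, correlation_diag=diag_term
)
dense = factor @ factor.T + torch.diag(diag_term)

dense_stats = compute_correlation_matrix_stats(dense)
lr_stats = compute_low_rank_correlation_matrix_stats(low_rank)
assert abs(dense_stats.rms_correlation - lr_stats.rms_correlation) < 1e-5
assert abs(dense_stats.mean_correlation - lr_stats.mean_correlation) < 1e-5
```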
{sae_lens-6.32.1.dist-info → sae_lens-6.33.0.dist-info}/RECORD
CHANGED
```diff
@@ -1,4 +1,4 @@
-sae_lens/__init__.py,sha256=
+sae_lens/__init__.py,sha256=gHaxlySzLskrAUg2oUZ3aOpnI3U_AVIHce-agGJL9rI,5168
 sae_lens/analysis/__init__.py,sha256=FZExlMviNwWR7OGUSGRbd0l-yUDGSp80gglI_ivILrY,412
 sae_lens/analysis/compat.py,sha256=cgE3nhFcJTcuhppxbL71VanJS7YqVEOefuneB5eOaPw,538
 sae_lens/analysis/hooked_sae_transformer.py,sha256=LpnjxSAcItqqXA4SJyZuxY4Ki0UOuWV683wg9laYAsY,14050
@@ -12,7 +12,7 @@ sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHF
 sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
 sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/loading/pretrained_sae_loaders.py,sha256=kshvA0NivOc7B3sL19lHr_zrC_DDfW2T6YWb5j0hgAk,63930
-sae_lens/loading/pretrained_saes_directory.py,sha256=
+sae_lens/loading/pretrained_saes_directory.py,sha256=lSnHl77IO5dd7iO21ynCzZNMrzuJAT8Za4W5THNq0qw,3554
 sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
 sae_lens/pretrained_saes.yaml,sha256=IVBLLR8_XNllJ1O-kVv9ED4u0u44Yn8UOL9R-f8Idp4,1511936
 sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
@@ -22,20 +22,21 @@ sae_lens/saes/gated_sae.py,sha256=V_2ZNlV4gRD-rX5JSx1xqY7idT8ChfdQ5yxWDdu_6hg,88
 sae_lens/saes/jumprelu_sae.py,sha256=miiF-xI_yXdV9EkKjwAbU9zSMsx9KtKCz5YdXEzkN8g,13313
 sae_lens/saes/matching_pursuit_sae.py,sha256=08_G9p1YMLnE5qZVCPp6gll-iG6nHRbMMASf4_bkFt8,13207
 sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
-sae_lens/saes/sae.py,sha256=
+sae_lens/saes/sae.py,sha256=wkwqzNragj-1189cV52S3_XeRtEgBd2ZNwvL2EsKkWw,39429
 sae_lens/saes/standard_sae.py,sha256=_hldNZkFPAf9VGrxouR1-tN8T2OEk8IkWBcXoatrC1o,5749
 sae_lens/saes/temporal_sae.py,sha256=S44sPddVj2xujA02CC8gT1tG0in7c_CSAhspu9FHbaA,13273
 sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
 sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
-sae_lens/synthetic/__init__.py,sha256=
+sae_lens/synthetic/__init__.py,sha256=hRRA3xhEQUacGyFbJXkLVYg_8A1bbSYYWlVovb0g4KU,3503
 sae_lens/synthetic/activation_generator.py,sha256=8L9nwC4jFRv_wg3QN-n1sFwX8w1NqwJMysWaJ41lLlY,15197
-sae_lens/synthetic/correlation.py,sha256=
+sae_lens/synthetic/correlation.py,sha256=tD8J9abWfuFtGZrEbbFn4P8FeTcNKF2V5JhBLwDUmkg,13146
 sae_lens/synthetic/evals.py,sha256=Nhi314ZnRgLfhBj-3tm_zzI-pGyFTcwllDXbIpPFXeU,4584
 sae_lens/synthetic/feature_dictionary.py,sha256=Nd4xjSTxKMnKilZ3uYi8Gv5SS5D4bv4wHiSL1uGB69E,6933
 sae_lens/synthetic/firing_probabilities.py,sha256=yclz1pWl5gE1r8LAxFvzQS88Lxwk5-3r8BCX9HLVejA,3370
 sae_lens/synthetic/hierarchy.py,sha256=nm7nwnTswktVJeKUsRZ0hLOdXcFWGbxnA1b6lefHm-4,33592
 sae_lens/synthetic/initialization.py,sha256=orMGW-786wRDHIS2W7bEH0HmlVFQ4g2z4bnnwdv5w4s,1386
 sae_lens/synthetic/plotting.py,sha256=5lFrej1QOkGAcImFNo5-o-8mI_rUVqvEI57KzUQPPtQ,8208
+sae_lens/synthetic/stats.py,sha256=BoDPKDx8pgFF5Ko_IaBRZTczm7-ANUIRjjF5W5Qh3Lk,7441
 sae_lens/synthetic/training.py,sha256=fHcX2cZ6nDupr71GX0Gk17f1NvQ0SKIVXIA6IuAb2dw,5692
 sae_lens/tokenization_and_batching.py,sha256=uoHtAs9z3XqG0Fh-iQVYVlrbyB_E3kFFhrKU30BosCo,5438
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -48,7 +49,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
 sae_lens/util.py,sha256=oIMoeyEP2IzcPFmRbKUzOAycgEyMcOasGeO_BGVZbc4,4846
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
+sae_lens-6.33.0.dist-info/METADATA,sha256=X6XqngWTNEsfdaPPWXxtF8Kvdp8fAk8i68sfRtDb2xo,6566
+sae_lens-6.33.0.dist-info/WHEEL,sha256=3ny-bZhpXrU6vSQ1UPG34FoxZBp3lVcvK0LkgUz6VLk,88
+sae_lens-6.33.0.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.33.0.dist-info/RECORD,,
```
{sae_lens-6.32.1.dist-info → sae_lens-6.33.0.dist-info}/WHEEL
File without changes

{sae_lens-6.32.1.dist-info → sae_lens-6.33.0.dist-info}/licenses/LICENSE
File without changes