sae-lens 6.28.1__py3-none-any.whl → 6.29.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,18 +9,35 @@ from typing import Callable
 
 import torch
 from torch import nn
-from tqdm import tqdm
+from tqdm.auto import tqdm
 
 FeatureDictionaryInitializer = Callable[["FeatureDictionary"], None]
 
 
 def orthogonalize_embeddings(
     embeddings: torch.Tensor,
-    target_cos_sim: float = 0,
     num_steps: int = 200,
     lr: float = 0.01,
     show_progress: bool = False,
+    chunk_size: int = 1024,
 ) -> torch.Tensor:
+    """
+    Orthogonalize embeddings using gradient descent with chunked computation.
+
+    Uses chunked computation to avoid O(n²) memory usage when computing pairwise
+    dot products. Memory usage is O(chunk_size × n) instead of O(n²).
+
+    Args:
+        embeddings: Tensor of shape [num_vectors, hidden_dim]
+        num_steps: Number of optimization steps
+        lr: Learning rate for Adam optimizer
+        show_progress: Whether to show progress bar
+        chunk_size: Number of vectors to process at once. Smaller values use less
+            memory but may be slower.
+
+    Returns:
+        Orthogonalized embeddings of the same shape, normalized to unit length.
+    """
     num_vectors = embeddings.shape[0]
     # Create a detached copy and normalize, then enable gradients
     embeddings = embeddings.detach().clone()
@@ -29,24 +46,37 @@ def orthogonalize_embeddings(
 
     optimizer = torch.optim.Adam([embeddings], lr=lr)  # type: ignore[list-item]
 
-    # Create a mask to zero out diagonal elements (avoid in-place operations)
-    off_diagonal_mask = ~torch.eye(
-        num_vectors, dtype=torch.bool, device=embeddings.device
-    )
-
     pbar = tqdm(
         range(num_steps), desc="Orthogonalizing vectors", disable=not show_progress
     )
     for _ in pbar:
         optimizer.zero_grad()
 
-        dot_products = embeddings @ embeddings.T
-        diff = dot_products - target_cos_sim
-        # Use masking instead of in-place fill_diagonal_
-        off_diagonal_diff = diff * off_diagonal_mask.float()
-        loss = off_diagonal_diff.pow(2).sum()
-        loss = loss + num_vectors * (dot_products.diag() - 1).pow(2).sum()
+        off_diag_loss = torch.tensor(0.0, device=embeddings.device)
+        diag_loss = torch.tensor(0.0, device=embeddings.device)
+
+        for i in range(0, num_vectors, chunk_size):
+            end_i = min(i + chunk_size, num_vectors)
+            chunk = embeddings[i:end_i]
+            chunk_dots = chunk @ embeddings.T  # [chunk_size, num_vectors]
 
+            # Create mask to zero out diagonal elements for this chunk
+            # Diagonal of full matrix: position (i+k, i+k) → in chunk_dots: (k, i+k)
+            chunk_len = end_i - i
+            row_indices = torch.arange(chunk_len, device=embeddings.device)
+            col_indices = i + row_indices  # column indices in full matrix
+
+            # Boolean mask: True for off-diagonal elements we want to include
+            off_diag_mask = torch.ones_like(chunk_dots, dtype=torch.bool)
+            off_diag_mask[row_indices, col_indices] = False
+
+            off_diag_loss = off_diag_loss + chunk_dots[off_diag_mask].pow(2).sum()
+
+            # Diagonal loss: keep self-dot-products at 1
+            diag_vals = chunk_dots[row_indices, col_indices]
+            diag_loss = diag_loss + (diag_vals - 1).pow(2).sum()
+
+        loss = off_diag_loss + num_vectors * diag_loss
         loss.backward()
         optimizer.step()
         pbar.set_description(f"loss: {loss.item():.3f}")
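
As a concrete illustration of the memory claim in the new docstring, here is a small back-of-the-envelope sketch; the sizes are arbitrary and chosen only for this example, not taken from the package:

# Pairwise dot products in fp32 (4 bytes per element).
num_vectors, chunk_size = 50_000, 1024

full_matrix_gb = num_vectors * num_vectors * 4 / 1e9  # ~10 GB for the full [n, n] matrix
chunk_block_gb = chunk_size * num_vectors * 4 / 1e9   # ~0.2 GB for one [chunk_size, n] block

print(f"full matrix: {full_matrix_gb:.1f} GB, per-chunk block: {chunk_block_gb:.2f} GB")
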
@@ -59,7 +89,10 @@ def orthogonalize_embeddings(
 
 
 def orthogonal_initializer(
-    num_steps: int = 200, lr: float = 0.01, show_progress: bool = False
+    num_steps: int = 200,
+    lr: float = 0.01,
+    show_progress: bool = False,
+    chunk_size: int = 1024,
 ) -> FeatureDictionaryInitializer:
     def initializer(feature_dict: "FeatureDictionary") -> None:
         feature_dict.feature_vectors.data = orthogonalize_embeddings(
@@ -67,6 +100,7 @@ def orthogonal_initializer(
             num_steps=num_steps,
            lr=lr,
             show_progress=show_progress,
+            chunk_size=chunk_size,
         )
 
     return initializer
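
A hedged usage sketch of the new chunk_size argument flowing from orthogonal_initializer into orthogonalize_embeddings; the import path is a guess for illustration, since the diff does not name the module file:

# Hypothetical import path; adjust to wherever sae-lens actually exposes these symbols.
from sae_lens.feature_dictionary import FeatureDictionary, orthogonal_initializer

# A smaller chunk_size lowers peak memory during the orthogonalization pass,
# at the cost of more (smaller) matmuls per optimization step.
init = orthogonal_initializer(num_steps=200, lr=0.01, chunk_size=512)

feature_dict = FeatureDictionary(
    num_features=16_384,
    hidden_dim=768,
    initializer=init,
)
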
@@ -97,6 +131,7 @@ class FeatureDictionary(nn.Module):
         hidden_dim: int,
         bias: bool = False,
         initializer: FeatureDictionaryInitializer | None = orthogonal_initializer(),
+        device: str | torch.device = "cpu",
     ):
         """
         Create a new FeatureDictionary.
@@ -106,20 +141,23 @@ class FeatureDictionary(nn.Module):
             hidden_dim: Dimensionality of the hidden space
             bias: Whether to include a bias term in the embedding
             initializer: Initializer function to use. If None, the embeddings are initialized to random unit vectors. By default will orthogonalize embeddings.
+            device: Device to use for the feature dictionary.
         """
         super().__init__()
         self.num_features = num_features
         self.hidden_dim = hidden_dim
 
         # Initialize feature vectors as unit vectors
-        embeddings = torch.randn(num_features, hidden_dim)
+        embeddings = torch.randn(num_features, hidden_dim, device=device)
         embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
             min=1e-8
         )
         self.feature_vectors = nn.Parameter(embeddings)
 
         # Initialize bias (zeros if not using bias, but still a parameter for consistent API)
-        self.bias = nn.Parameter(torch.zeros(hidden_dim), requires_grad=bias)
+        self.bias = nn.Parameter(
+            torch.zeros(hidden_dim, device=device), requires_grad=bias
+        )
 
         if initializer is not None:
             initializer(self)
@@ -130,9 +168,18 @@ class FeatureDictionary(nn.Module):
 
         Args:
             feature_activations: Tensor of shape [batch, num_features] containing
-                sparse feature activation values
+                sparse feature activation values. Can be dense or sparse COO.
 
         Returns:
             Tensor of shape [batch, hidden_dim] containing dense hidden activations
         """
+        if feature_activations.is_sparse:
+            # autocast is disabled here because sparse matmul is not supported with bfloat16
+            with torch.autocast(
+                device_type=feature_activations.device.type, enabled=False
+            ):
+                return (
+                    torch.sparse.mm(feature_activations, self.feature_vectors)
+                    + self.bias
+                )
         return feature_activations @ self.feature_vectors + self.bias
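
Finally, a sketch of the other two additions in this release: the device argument on the constructor and the sparse-COO path in the method that maps feature activations back to hidden activations. The import path is again hypothetical, and the example assumes the method shown above is forward, so the module can be called directly:

import torch

# Hypothetical import path; the diff does not name the module file.
from sae_lens.feature_dictionary import FeatureDictionary

device = "cuda" if torch.cuda.is_available() else "cpu"

# initializer=None skips the orthogonalization pass so the sketch stays cheap;
# parameters are created directly on the requested device via the new argument.
fd = FeatureDictionary(num_features=4096, hidden_dim=512, initializer=None, device=device)

dense_acts = torch.relu(torch.randn(8, 4096, device=device))  # roughly half zeros
hidden = fd(dense_acts)  # dense path: plain matmul, shape [8, 512]

sparse_acts = dense_acts.to_sparse()  # COO tensor with the same values
hidden_sparse = fd(sparse_acts)  # sparse path: torch.sparse.mm with autocast disabled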