mi-crow 1.0.0__py3-none-any.whl → 1.0.0.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,6 +10,7 @@ import torch
  from torch import nn

  from mi_crow.mechanistic.sae.concepts.concept_models import NeuronText
+ from mi_crow.mechanistic.sae.concepts.text_heap import TextHeap
  from mi_crow.mechanistic.sae.autoencoder_context import AutoencoderContext
  from mi_crow.utils import get_logger

@@ -28,12 +29,11 @@ class AutoencoderConcepts:
          self._n_size = context.n_latents
          self.dictionary: ConceptDictionary | None = None

-         # Concept manipulation parameters
          self.multiplication = nn.Parameter(torch.ones(self._n_size))
          self.bias = nn.Parameter(torch.ones(self._n_size))

-         # Top texts tracking
-         self._top_texts_heaps: list[list[tuple[float, tuple[float, str, int]]]] | None = None
+         self._text_heaps_positive: list[TextHeap] | None = None
+         self._text_heaps_negative: list[TextHeap] | None = None
          self._text_tracking_k: int = 5
          self._text_tracking_negative: bool = False

@@ -81,7 +81,7 @@ class AutoencoderConcepts:

      def generate_concepts_with_llm(self, llm_provider: str | None = None):
          """Generate concepts using LLM based on current top texts"""
-         if self._top_texts_heaps is None:
+         if self._text_heaps_positive is None:
              raise ValueError("No top texts available. Enable text tracking and run inference first.")

          from mi_crow.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
@@ -96,8 +96,10 @@ class AutoencoderConcepts:

      def _ensure_heaps(self, n_neurons: int) -> None:
          """Ensure heaps are initialized for the given number of neurons."""
-         if self._top_texts_heaps is None:
-             self._top_texts_heaps = [[] for _ in range(n_neurons)]
+         if self._text_heaps_positive is None:
+             self._text_heaps_positive = [TextHeap(self._text_tracking_k) for _ in range(n_neurons)]
+         if self._text_tracking_negative and self._text_heaps_negative is None:
+             self._text_heaps_negative = [TextHeap(self._text_tracking_k) for _ in range(n_neurons)]

      def _decode_token(self, text: str, token_idx: int) -> str:
          """
@@ -148,6 +150,11 @@ class AutoencoderConcepts:
          """
          Update top texts heaps from latents and texts.

+         Optimized version that:
+         - Only processes active neurons (non-zero activations)
+         - Vectorizes argmax/argmin operations
+         - Eliminates per-neuron tensor slicing
+
          Args:
              latents: Latent activations tensor, shape [B*T, n_latents] or [B, n_latents] (already flattened)
              texts: List of texts corresponding to the batch
@@ -173,100 +180,120 @@ class AutoencoderConcepts:
              # Use the actual number of texts as batch size
              B = original_B
              T = BT // B if B > 0 else 1
-             # Create token indices: [0, 1, 2, ..., T-1, 0, 1, 2, ..., T-1, ...]
-             token_indices = torch.arange(T, device='cpu').unsqueeze(0).expand(B, T).contiguous().view(B * T)
          else:
              # Original was [B, D], latents are [B, n_latents]
-             # All tokens are at index 0
+             B = original_B
              T = 1
-             token_indices = torch.zeros(BT, dtype=torch.long, device='cpu')
-
-         # For each neuron, find the maximum activation per text
-         # This ensures we only track the best activation for each text, not every token position
-         for j in range(n_neurons):
-             heap = self._top_texts_heaps[j]
-
-             # For each text in the batch, find the max activation and its token position
-             texts_processed = 0
-             texts_added = 0
-             texts_updated = 0
-             texts_skipped_duplicate = 0
+
+         # OPTIMIZATION 1: Find active neurons (have any non-zero activation across batch)
+         # Shape: [n_neurons] - boolean mask
+         active_neurons_mask = (latents.abs().sum(dim=0) > 0)
+         active_neuron_indices = torch.nonzero(active_neurons_mask, as_tuple=False).flatten().tolist()
+
+         if not active_neuron_indices:
+             return  # No active neurons, skip
+
+         # OPTIMIZATION 2: Vectorize argmax/argmin for all neurons at once
+         if original_shape is not None and len(original_shape) == 3:
+             # Reshape to [B, T, n_neurons]
+             latents_3d = latents.view(B, T, n_neurons)
+             # For each text, find max/min across tokens for each neuron
+             # Shape: [B, n_neurons] - max activation per text per neuron
+             max_activations, max_token_indices_3d = latents_3d.max(dim=1)  # [B, n_neurons]
+             min_activations, min_token_indices_3d = latents_3d.min(dim=1)  # [B, n_neurons]
+             # max_token_indices_3d is already the token index (0 to T-1)
+             max_token_indices = max_token_indices_3d
+             min_token_indices = min_token_indices_3d
+         else:
+             # Shape: [B, n_neurons]
+             latents_2d = latents.view(B, n_neurons)
+             max_activations = latents_2d  # [B, n_neurons]
+             max_token_indices = torch.zeros(B, n_neurons, dtype=torch.long, device=latents.device)
+             min_activations = latents_2d
+             min_token_indices = torch.zeros(B, n_neurons, dtype=torch.long, device=latents.device)
+
+         # Convert to numpy for faster CPU access (already on CPU from l1_sae.py)
+         max_activations_np = max_activations.cpu().numpy()
+         min_activations_np = min_activations.cpu().numpy()
+         max_token_indices_np = max_token_indices.cpu().numpy()
+         min_token_indices_np = min_token_indices.cpu().numpy()
+
+         # OPTIMIZATION 3: Only process active neurons
+         for j in active_neuron_indices:
+             heap_positive = self._text_heaps_positive[j]
+             heap_negative = self._text_heaps_negative[j] if self._text_tracking_negative else None
+
+             # OPTIMIZATION 4: Batch process all texts for this neuron
              for batch_idx in range(original_B):
                  if batch_idx >= len(texts):
                      continue

                  text = texts[batch_idx]
-                 texts_processed += 1
-
-                 # Get activations for this text (all token positions)
-                 if original_shape is not None and len(original_shape) == 3:
-                     # 3D case: [B, T, D] -> get slice for this batch
-                     start_idx = batch_idx * T
-                     end_idx = start_idx + T
-                     text_activations = latents[start_idx:end_idx, j]  # [T]
-                     text_token_indices = token_indices[start_idx:end_idx]  # [T]
-                 else:
-                     # 2D case: [B, D] -> single token
-                     text_activations = latents[batch_idx:batch_idx + 1, j]  # [1]
-                     text_token_indices = token_indices[batch_idx:batch_idx + 1]  # [1]
-
-                 # Find the maximum activation (or minimum if tracking negative)
-                 if self._text_tracking_negative:
-                     # For negative tracking, find the most negative (minimum) value
-                     max_idx = torch.argmin(text_activations)
-                     max_score = float(text_activations[max_idx].item())
-                     adj = -max_score  # Negate for heap ordering
-                 else:
-                     # For positive tracking, find the maximum value
-                     max_idx = torch.argmax(text_activations)
-                     max_score = float(text_activations[max_idx].item())
-                     adj = max_score
-
-                 # Skip if score is zero (no activation)
-                 if max_score == 0.0:
-                     continue
-
-                 token_idx = int(text_token_indices[max_idx].item())
-
-                 # Check if we already have this text in the heap
-                 # If so, only update if this activation is better
-                 existing_entry = None
-                 heap_texts = []
-                 for heap_idx, (heap_adj, (heap_score, heap_text, heap_token_idx)) in enumerate(heap):
-                     heap_texts.append(heap_text[:50] if len(heap_text) > 50 else heap_text)
-                     if heap_text == text:
-                         existing_entry = (heap_idx, heap_adj, heap_score, heap_token_idx)
-                         break
-
-                 if existing_entry is not None:
-                     # Update existing entry if this activation is better
-                     heap_idx, heap_adj, heap_score, heap_token_idx = existing_entry
-                     if adj > heap_adj:
-                         # Replace with better activation
-                         heap[heap_idx] = (adj, (max_score, text, token_idx))
-                         heapq.heapify(heap)  # Re-heapify after modification
-                         texts_updated += 1
-                     else:
-                         texts_skipped_duplicate += 1
-                 else:
-                     # New text, add to heap
-                     if len(heap) < self._text_tracking_k:
-                         heapq.heappush(heap, (adj, (max_score, text, token_idx)))
-                         texts_added += 1
-                     else:
-                         # Compare with smallest adjusted score; replace if better
-                         if adj > heap[0][0]:
-                             heapq.heapreplace(heap, (adj, (max_score, text, token_idx)))
-                             texts_added += 1
+
+                 # Use pre-computed max/min (no tensor slicing needed!)
+                 max_score_positive = float(max_activations_np[batch_idx, j])
+                 token_idx_positive = int(max_token_indices_np[batch_idx, j])
+
+                 if max_score_positive > 0.0:
+                     heap_positive.update(text, max_score_positive, token_idx_positive)
+
+                 if self._text_tracking_negative and heap_negative is not None:
+                     min_score_negative = float(min_activations_np[batch_idx, j])
+                     if min_score_negative != 0.0:
+                         token_idx_negative = int(min_token_indices_np[batch_idx, j])
+                         heap_negative.update(text, min_score_negative, token_idx_negative, adjusted_score=-min_score_negative)
+
+     def _extract_activations(
+         self,
+         latents: torch.Tensor,
+         token_indices: torch.Tensor,
+         batch_idx: int,
+         neuron_idx: int,
+         original_shape: tuple[int, ...] | None,
+         T: int
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Extract activations for a specific batch item and neuron.
+
+         Returns:
+             Tuple of (text_activations, text_token_indices)
+         """
+         if original_shape is not None and len(original_shape) == 3:
+             start_idx = batch_idx * T
+             end_idx = start_idx + T
+             text_activations = latents[start_idx:end_idx, neuron_idx]
+             text_token_indices = token_indices[start_idx:end_idx]
+         else:
+             text_activations = latents[batch_idx:batch_idx + 1, neuron_idx]
+             text_token_indices = token_indices[batch_idx:batch_idx + 1]
+
+         return text_activations, text_token_indices

      def get_top_texts_for_neuron(self, neuron_idx: int, top_m: int | None = None) -> list[NeuronText]:
-         """Get top texts for a specific neuron."""
-         if self._top_texts_heaps is None or neuron_idx < 0 or neuron_idx >= len(self._top_texts_heaps):
+         """Get top texts for a specific neuron (positive activations)."""
+         if self._text_heaps_positive is None or neuron_idx < 0 or neuron_idx >= len(self._text_heaps_positive):
+             return []
+         heap = self._text_heaps_positive[neuron_idx]
+         items = heap.get_items()
+         items_sorted = sorted(items, key=lambda s_t: s_t[0], reverse=True)
+         if top_m is not None:
+             items_sorted = items_sorted[: top_m]
+
+         neuron_texts = []
+         for score, text, token_idx in items_sorted:
+             token_str = self._decode_token(text, token_idx)
+             neuron_texts.append(NeuronText(score=score, text=text, token_idx=token_idx, token_str=token_str))
+         return neuron_texts
+
+     def get_bottom_texts_for_neuron(self, neuron_idx: int, top_m: int | None = None) -> list[NeuronText]:
+         """Get bottom texts for a specific neuron (negative activations)."""
+         if not self._text_tracking_negative:
+             return []
+         if self._text_heaps_negative is None or neuron_idx < 0 or neuron_idx >= len(self._text_heaps_negative):
              return []
-         heap = self._top_texts_heaps[neuron_idx]
-         items = [val for (_, val) in heap]
-         reverse = not self._text_tracking_negative
-         items_sorted = sorted(items, key=lambda s_t: s_t[0], reverse=reverse)
+         heap = self._text_heaps_negative[neuron_idx]
+         items = heap.get_items()
+         items_sorted = sorted(items, key=lambda s_t: s_t[0], reverse=False)
          if top_m is not None:
              items_sorted = items_sorted[: top_m]

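Note: the core of the rewrite above is that `Tensor.max(dim=1)` / `Tensor.min(dim=1)` on a `[B, T, n_neurons]` view return both the per-text peak activation and the token index that produced it, which replaces the old per-neuron slicing and `argmax`/`argmin` calls. A minimal sketch with toy shapes and hypothetical values:

    import torch

    # B=2 texts, T=3 tokens, n_neurons=2
    latents_3d = torch.tensor([
        [[0.0, 1.0], [2.0, 0.0], [1.0, 0.5]],   # text 0
        [[0.5, 0.1], [0.0, -3.0], [4.0, 0.0]],  # text 1
    ])

    # Per-text peak per neuron, plus the token position that produced it
    max_activations, max_token_indices = latents_3d.max(dim=1)
    print(max_activations)    # [[2.0, 1.0], [4.0, 0.1]]
    print(max_token_indices)  # [[1, 0], [2, 0]]

    # Active-neuron mask, as in OPTIMIZATION 1
    flat = latents_3d.view(-1, 2)
    print(flat.abs().sum(dim=0) > 0)  # [True, True]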
@@ -277,17 +304,25 @@ class AutoencoderConcepts:
          return neuron_texts

      def get_all_top_texts(self) -> list[list[NeuronText]]:
-         """Get top texts for all neurons."""
-         if self._top_texts_heaps is None:
+         """Get top texts for all neurons (positive activations)."""
+         if self._text_heaps_positive is None:
              return []
-         return [self.get_top_texts_for_neuron(i) for i in range(len(self._top_texts_heaps))]
+         return [self.get_top_texts_for_neuron(i) for i in range(len(self._text_heaps_positive))]
+
+     def get_all_bottom_texts(self) -> list[list[NeuronText]]:
+         """Get bottom texts for all neurons (negative activations)."""
+         if not self._text_tracking_negative or self._text_heaps_negative is None:
+             return []
+         return [self.get_bottom_texts_for_neuron(i) for i in range(len(self._text_heaps_negative))]

      def reset_top_texts(self) -> None:
          """Reset all tracked top texts."""
-         self._top_texts_heaps = None
+         self._text_heaps_positive = None
+         self._text_heaps_negative = None

      def export_top_texts_to_json(self, filepath: Path | str) -> Path:
-         if self._top_texts_heaps is None:
+         """Export top texts (positive activations) to JSON file."""
+         if self._text_heaps_positive is None:
              raise ValueError("No top texts available. Enable text tracking and run inference first.")

          filepath = Path(filepath)
@@ -312,8 +347,35 @@ class AutoencoderConcepts:

          return filepath

+     def export_bottom_texts_to_json(self, filepath: Path | str) -> Path:
+         """Export bottom texts (negative activations) to JSON file."""
+         if not self._text_tracking_negative or self._text_heaps_negative is None:
+             raise ValueError("No bottom texts available. Enable negative text tracking and run inference first.")
+
+         filepath = Path(filepath)
+         filepath.parent.mkdir(parents=True, exist_ok=True)
+
+         all_texts = self.get_all_bottom_texts()
+         export_data = {}
+
+         for neuron_idx, neuron_texts in enumerate(all_texts):
+             export_data[neuron_idx] = [
+                 {
+                     "text": nt.text,
+                     "score": nt.score,
+                     "token_str": nt.token_str,
+                     "token_idx": nt.token_idx
+                 }
+                 for nt in neuron_texts
+             ]
+
+         with filepath.open("w", encoding="utf-8") as f:
+             json.dump(export_data, f, ensure_ascii=False, indent=2)
+
+         return filepath
+
      def export_top_texts_to_csv(self, filepath: Path | str) -> Path:
-         if self._top_texts_heaps is None:
+         if self._text_heaps_positive is None:
              raise ValueError("No top texts available. Enable text tracking and run inference first.")

          filepath = Path(filepath)
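Note: taken together, these hunks split text tracking into parallel positive/negative surfaces. A hedged usage sketch (only the method names shown in the hunks are real; the `concepts` object, its tracking configuration, and the inference step are assumed):

    # assumes text tracking is enabled, negative tracking is turned on,
    # and inference has already populated the heaps
    top = concepts.get_top_texts_for_neuron(42, top_m=3)         # strongest positive activations
    bottom = concepts.get_bottom_texts_for_neuron(42, top_m=3)   # most negative activations

    concepts.export_top_texts_to_json("out/top_texts.json")
    concepts.export_bottom_texts_to_json("out/bottom_texts.json")
    concepts.reset_top_texts()  # clears both heap families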
@@ -135,8 +135,18 @@ class ConceptDictionary:
          with json_path.open("r", encoding="utf-8") as f:
              data = json.load(f)

-         for neuron_idx_str, concepts in data.items():
-             neuron_idx = int(neuron_idx_str)
+         if isinstance(data, dict) and "concepts" in data:
+             concepts_data = data["concepts"]
+             if "n_size" in data:
+                 concept_dict.n_size = int(data["n_size"])
+         else:
+             concepts_data = data
+
+         for neuron_idx_str, concepts in concepts_data.items():
+             try:
+                 neuron_idx = int(neuron_idx_str)
+             except ValueError:
+                 continue

              # Handle both old format (list) and new format (single dict)
              if isinstance(concepts, list):
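Note: the loader now tolerates two payload shapes and skips keys that are not neuron indices. A sketch mirroring the new branch (the JSON data here is hypothetical):

    import json

    wrapped = json.loads('{"n_size": 2, "concepts": {"0": ["dogs"], "meta": "skipped"}}')
    legacy = json.loads('{"0": ["dogs"], "1": ["negation"]}')

    for data in (wrapped, legacy):
        # unwrap if a "concepts" key is present, as in from_json
        concepts_data = data["concepts"] if isinstance(data, dict) and "concepts" in data else data
        for neuron_idx_str, concepts in concepts_data.items():
            try:
                neuron_idx = int(neuron_idx_str)  # "meta" fails here and is skipped
            except ValueError:
                continue
            print(neuron_idx, concepts)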
@@ -0,0 +1,161 @@
+ from __future__ import annotations
+
+ import heapq
+
+
+ class TextHeap:
+     """
+     Efficient heap for tracking top texts with O(1) duplicate lookup.
+
+     Optimized with incremental index updates and correct heap operations.
+     Maintains a min-heap of size k and a dictionary for fast text lookup.
+     """
+
+     def __init__(self, max_size: int):
+         """
+         Initialize TextHeap.
+
+         Args:
+             max_size: Maximum number of items to keep in the heap
+         """
+         self._max_size = max_size
+         self._heap: list[tuple[float, tuple[float, str, int]]] = []
+         self._text_to_index: dict[str, int] = {}
+
+     def update(self, text: str, score: float, token_idx: int, adjusted_score: float | None = None) -> None:
+         """
+         Update heap with a new text entry.
+
+         Args:
+             text: Text string
+             score: Activation score (actual value to store)
+             token_idx: Token index within the text
+             adjusted_score: Optional adjusted score for heap ordering (defaults to score)
+         """
+         if adjusted_score is None:
+             adjusted_score = score
+         heap_idx = self._text_to_index.get(text)
+
+         if heap_idx is not None:
+             self._update_existing(heap_idx, text, adjusted_score, score, token_idx)
+         else:
+             self._add_new(text, adjusted_score, score, token_idx)
+
+     def _update_existing(
+         self,
+         heap_idx: int,
+         text: str,
+         adjusted_score: float,
+         score: float,
+         token_idx: int
+     ) -> None:
+         """Update an existing entry in the heap."""
+         current_adj = self._heap[heap_idx][0]
+         if adjusted_score > current_adj:
+             self._heap[heap_idx] = (adjusted_score, (score, text, token_idx))
+             self._text_to_index[text] = heap_idx
+             self._siftdown_with_tracking(heap_idx)
+
+     def _add_new(
+         self,
+         text: str,
+         adjusted_score: float,
+         score: float,
+         token_idx: int
+     ) -> None:
+         """Add a new entry to the heap."""
+         if len(self._heap) < self._max_size:
+             self._heap.append((adjusted_score, (score, text, token_idx)))
+             new_idx = len(self._heap) - 1
+             self._text_to_index[text] = new_idx
+             self._siftup_with_tracking(new_idx)
+         else:
+             if adjusted_score > self._heap[0][0]:
+                 self._replace_minimum(text, adjusted_score, score, token_idx)
+
+     def _replace_minimum(
+         self,
+         text: str,
+         adjusted_score: float,
+         score: float,
+         token_idx: int
+     ) -> None:
+         """Replace the minimum element in the heap."""
+         old_text = self._heap[0][1][1]
+         if old_text in self._text_to_index:
+             del self._text_to_index[old_text]
+
+         self._heap[0] = (adjusted_score, (score, text, token_idx))
+         self._text_to_index[text] = 0
+         self._siftdown_with_tracking(0)
+
+     def _siftup_with_tracking(self, pos: int) -> None:
+         """
+         Sift element up in heap (toward root) and update text-to-index map incrementally.
+
+         Used when value decreases - compares with parent and moves up.
+         Only updates indices that actually change during the sift operation.
+         """
+         startpos = pos
+         newitem = self._heap[pos]
+         newitem_text = newitem[1][1]
+
+         while pos > 0:
+             parentpos = (pos - 1) >> 1
+             parent = self._heap[parentpos]
+             if newitem[0] >= parent[0]:
+                 break
+             parent_text = parent[1][1]
+             self._heap[pos] = parent
+             self._text_to_index[parent_text] = pos
+             pos = parentpos
+
+         self._heap[pos] = newitem
+         if pos != startpos:
+             self._text_to_index[newitem_text] = pos
+
+     def _siftdown_with_tracking(self, pos: int) -> None:
+         """
+         Sift element down in heap and update text-to-index map incrementally.
+
+         Only updates indices that actually change during the sift operation.
+         """
+         endpos = len(self._heap)
+         startpos = pos
+         newitem = self._heap[pos]
+         newitem_text = newitem[1][1]
+
+         childpos = 2 * pos + 1
+         while childpos < endpos:
+             rightpos = childpos + 1
+             if rightpos < endpos and self._heap[rightpos][0] < self._heap[childpos][0]:
+                 childpos = rightpos
+             if newitem[0] < self._heap[childpos][0]:
+                 break
+             child_text = self._heap[childpos][1][1]
+             self._heap[pos] = self._heap[childpos]
+             self._text_to_index[child_text] = pos
+             pos = childpos
+             childpos = 2 * pos + 1
+
+         self._heap[pos] = newitem
+         if pos != startpos:
+             self._text_to_index[newitem_text] = pos
+
+     def get_items(self) -> list[tuple[float, str, int]]:
+         """
+         Get all items from the heap, sorted by score (descending).
+
+         Returns:
+             List of (score, text, token_idx) tuples
+         """
+         return [val for (_, val) in self._heap]
+
+     def clear(self) -> None:
+         """Clear the heap and text mapping."""
+         self._heap.clear()
+         self._text_to_index.clear()
+
+     def __len__(self) -> int:
+         """Return the number of items in the heap."""
+         return len(self._heap)
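Note: a standalone sketch of the new TextHeap semantics (import path taken from the hunk that adds it to the concepts module; texts and scores are hypothetical). Duplicates are upgraded in place, the minimum is evicted once the heap is full, and `adjusted_score` lets negative tracking order by magnitude while storing the raw score:

    from mi_crow.mechanistic.sae.concepts.text_heap import TextHeap

    heap = TextHeap(max_size=2)
    heap.update("the cat sat", score=1.5, token_idx=1)
    heap.update("a dog ran", score=0.7, token_idx=2)
    heap.update("the cat sat", score=2.0, token_idx=0)  # duplicate text: upgraded in place
    heap.update("it rained", score=1.0, token_idx=0)    # evicts "a dog ran" (current minimum)

    print(sorted(heap.get_items(), reverse=True))
    # [(2.0, 'the cat sat', 0), (1.0, 'it rained', 0)]

    # Negative tracking: store -3.0 but rank it as 3.0, so large-magnitude
    # negatives survive eviction in the min-heap
    neg = TextHeap(max_size=5)
    neg.update("bad example", score=-3.0, token_idx=0, adjusted_score=3.0)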
@@ -72,13 +72,14 @@ class TopKSae(Sae):
          A temporary default k=1 is used for engine initialization and will be
          overridden with the actual k value from config during training.
          """
-         # Set temporary default k for engine initialization (base class calls _initialize_sae_engine)
-         # This will be overridden with the actual k from config during training
-         self.k: int = 1
          super().__init__(n_latents, n_inputs, hook_id, device, store, *args, **kwargs)

-     def _initialize_sae_engine(self) -> OvercompleteSAE:
-         """Initialize the SAE engine with the current k value.
+     def _initialize_sae_engine(self, k: int = 1) -> OvercompleteSAE:
+         """
+         Initialize the SAE engine with the specified k value.
+
+         Args:
+             k: Number of top activations to keep (default: 1 for initialization)

          Note:
              k should be set from TopKSaeTrainingConfig during training.
@@ -87,7 +88,7 @@ class TopKSae(Sae):
          return OvercompleteTopkSAE(
              input_shape=self.context.n_inputs,
              nb_concepts=self.context.n_latents,
-             top_k=self.k,
+             top_k=k,
              device=self.context.device
          )

@@ -143,8 +144,7 @@ class TopKSae(Sae):
          Train TopKSAE using activations from a Store.

          This method delegates to the SaeTrainer composite class.
-         If k is provided in the config and differs from the current k, the SAE engine
-         will be reinitialized with the config's k value.
+         The SAE engine will be reinitialized with the k value from config.

          Args:
              store: Store instance containing activations
@@ -170,15 +170,13 @@ class TopKSae(Sae):
                  "Example: TopKSaeTrainingConfig(k=10, epochs=100, ...)"
              )

-         # Set k from config and initialize/reinitialize the engine if needed
-         if self.k != config.k:
-             if self.k is None:
-                 logger.info(f"Initializing SAE engine with k={config.k}")
-             else:
-                 logger.info(f"Reinitializing SAE engine with k={config.k} (was k={self.k})")
-             self.k = config.k
-             # Initialize or reinitialize the SAE engine with k from config
-             self.sae_engine = self._initialize_sae_engine()
+         # Reinitialize engine with k from config
+         logger.info(f"Initializing SAE engine with k={config.k}")
+         self.sae_engine = self._initialize_sae_engine(k=config.k)
+         if hasattr(config, 'device') and config.device:
+             device = torch.device(config.device)
+             self.sae_engine.to(device)
+             self.context.device = str(device)

          return self.trainer.train(store, run_id, layer_signature, config, training_run_id)

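Note: the training path now rebuilds the engine unconditionally from `config.k` and optionally moves it via `config.device`. A hedged call sketch (argument names inferred from the `self.trainer.train(...)` delegation above; all other config fields and values are assumptions):

    config = TopKSaeTrainingConfig(k=10, epochs=100, device="cuda:0")
    sae.train(
        store=store,
        run_id="run-001",
        layer_signature="transformer.h.6",
        config=config,
        training_run_id="train-001",
    )
    # internally: sae_engine = _initialize_sae_engine(k=10), then
    # sae_engine.to(torch.device("cuda:0")) and context.device = "cuda:0"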
@@ -323,13 +321,14 @@ class TopKSae(Sae):
              result[0] = reconstructed
          return tuple(result)

-     def save(self, name: str, path: str | Path | None = None) -> None:
+     def save(self, name: str, path: str | Path | None = None, k: int | None = None) -> None:
          """
          Save model using overcomplete's state dict + our metadata.

          Args:
              name: Model name
              path: Directory path to save to (defaults to current directory)
+             k: Top-K value to save (if None, attempts to get from engine or raises error)
          """
          if path is None:
              path = Path.cwd()
@@ -340,6 +339,16 @@ class TopKSae(Sae):
          # Save overcomplete model state dict
          sae_state_dict = self.sae_engine.state_dict()

+         # Get k value - prefer parameter, then try to get from engine
+         if k is None:
+             if hasattr(self.sae_engine, 'top_k'):
+                 k = self.sae_engine.top_k
+             else:
+                 raise ValueError(
+                     "k parameter must be provided to save() method. "
+                     "The engine does not expose top_k attribute."
+                 )
+
          mi_crow_metadata = {
              "concepts_state": {
                  'multiplication': self.concepts.multiplication.data,
@@ -347,7 +356,7 @@ class TopKSae(Sae):
              },
              "n_latents": self.context.n_latents,
              "n_inputs": self.context.n_inputs,
-             "k": self.k,
+             "k": k,
              "device": self.context.device,
              "layer_signature": self.context.lm_layer_signature,
              "model_id": self.context.model_id,
@@ -403,9 +412,7 @@ class TopKSae(Sae):
              device=device
          )

-         # Set k from saved metadata and reinitialize engine with correct k
-         topk_sae.k = k
-         topk_sae.sae_engine = topk_sae._initialize_sae_engine()
+         topk_sae.sae_engine = topk_sae._initialize_sae_engine(k=k)

          # Load overcomplete model state dict
          if "sae_state_dict" in payload:
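Note: with `self.k` removed from the instance, persistence now threads k explicitly end to end. A round-trip sketch (the save signature matches the hunks above; the load call is a hypothetical stand-in):

    # save: pass k explicitly, or rely on the engine exposing .top_k
    sae.save("my_sae", path="checkpoints/", k=10)

    # load: k is read back from the saved metadata and passed straight into
    # _initialize_sae_engine(k=k) rather than being stored on the instance
    restored = TopKSae.load("checkpoints/my_sae")  # hypothetical signature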
@@ -69,7 +69,7 @@ class Sae(Controller, Detector, abc.ABC):
          """Set the LanguageModelContext for this hook and sync to AutoencoderContext.

          When the hook is registered, this method is called with the LanguageModelContext.
-         It automatically syncs relevant values to the AutoencoderContext.
+         It automatically syncs relevant values to the AutoencoderContext, including device.

          Args:
              context: The LanguageModelContext instance from the LanguageModel
@@ -84,6 +84,8 @@ class Sae(Controller, Detector, abc.ABC):
          self._autoencoder_context.store = context.store
          if self.layer_signature is not None:
              self._autoencoder_context.lm_layer_signature = self.layer_signature
+         if context.device is not None:
+             self._autoencoder_context.device = context.device

      @abc.abstractmethod
      def _initialize_sae_engine(self) -> OvercompleteSAE: