mi-crow 0.1.2__py3-none-any.whl → 1.0.0.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- mi_crow/datasets/base_dataset.py +71 -1
- mi_crow/datasets/classification_dataset.py +136 -30
- mi_crow/datasets/text_dataset.py +165 -24
- mi_crow/hooks/controller.py +12 -7
- mi_crow/hooks/implementations/layer_activation_detector.py +30 -34
- mi_crow/hooks/implementations/model_input_detector.py +87 -87
- mi_crow/hooks/implementations/model_output_detector.py +43 -42
- mi_crow/hooks/utils.py +74 -0
- mi_crow/language_model/activations.py +174 -77
- mi_crow/language_model/device_manager.py +119 -0
- mi_crow/language_model/inference.py +18 -5
- mi_crow/language_model/initialization.py +10 -6
- mi_crow/language_model/language_model.py +67 -97
- mi_crow/language_model/layers.py +16 -13
- mi_crow/language_model/persistence.py +4 -2
- mi_crow/language_model/utils.py +5 -5
- mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py +157 -95
- mi_crow/mechanistic/sae/concepts/concept_dictionary.py +12 -2
- mi_crow/mechanistic/sae/concepts/text_heap.py +161 -0
- mi_crow/mechanistic/sae/modules/topk_sae.py +29 -22
- mi_crow/mechanistic/sae/sae.py +3 -1
- mi_crow/mechanistic/sae/sae_trainer.py +362 -29
- mi_crow/store/local_store.py +11 -5
- mi_crow/store/store.py +34 -1
- {mi_crow-0.1.2.dist-info → mi_crow-1.0.0.post1.dist-info}/METADATA +2 -1
- {mi_crow-0.1.2.dist-info → mi_crow-1.0.0.post1.dist-info}/RECORD +28 -26
- {mi_crow-0.1.2.dist-info → mi_crow-1.0.0.post1.dist-info}/WHEEL +1 -1
- {mi_crow-0.1.2.dist-info → mi_crow-1.0.0.post1.dist-info}/top_level.txt +0 -0
mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py
CHANGED

@@ -10,6 +10,7 @@ import torch
 from torch import nn
 
 from mi_crow.mechanistic.sae.concepts.concept_models import NeuronText
+from mi_crow.mechanistic.sae.concepts.text_heap import TextHeap
 from mi_crow.mechanistic.sae.autoencoder_context import AutoencoderContext
 from mi_crow.utils import get_logger
 
@@ -28,12 +29,11 @@ class AutoencoderConcepts:
         self._n_size = context.n_latents
         self.dictionary: ConceptDictionary | None = None
 
-        # Concept manipulation parameters
         self.multiplication = nn.Parameter(torch.ones(self._n_size))
         self.bias = nn.Parameter(torch.ones(self._n_size))
 
-
-        self.
+        self._text_heaps_positive: list[TextHeap] | None = None
+        self._text_heaps_negative: list[TextHeap] | None = None
         self._text_tracking_k: int = 5
         self._text_tracking_negative: bool = False
 
@@ -81,7 +81,7 @@ class AutoencoderConcepts:
 
     def generate_concepts_with_llm(self, llm_provider: str | None = None):
         """Generate concepts using LLM based on current top texts"""
-        if self.
+        if self._text_heaps_positive is None:
             raise ValueError("No top texts available. Enable text tracking and run inference first.")
 
         from mi_crow.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
@@ -96,8 +96,10 @@ class AutoencoderConcepts:
 
     def _ensure_heaps(self, n_neurons: int) -> None:
         """Ensure heaps are initialized for the given number of neurons."""
-        if self.
-            self.
+        if self._text_heaps_positive is None:
+            self._text_heaps_positive = [TextHeap(self._text_tracking_k) for _ in range(n_neurons)]
+        if self._text_tracking_negative and self._text_heaps_negative is None:
+            self._text_heaps_negative = [TextHeap(self._text_tracking_k) for _ in range(n_neurons)]
 
     def _decode_token(self, text: str, token_idx: int) -> str:
         """
@@ -148,6 +150,11 @@ class AutoencoderConcepts:
         """
         Update top texts heaps from latents and texts.
 
+        Optimized version that:
+        - Only processes active neurons (non-zero activations)
+        - Vectorizes argmax/argmin operations
+        - Eliminates per-neuron tensor slicing
+
         Args:
             latents: Latent activations tensor, shape [B*T, n_latents] or [B, n_latents] (already flattened)
             texts: List of texts corresponding to the batch
@@ -173,100 +180,120 @@ class AutoencoderConcepts:
             # Use the actual number of texts as batch size
             B = original_B
             T = BT // B if B > 0 else 1
-            # Create token indices: [0, 1, 2, ..., T-1, 0, 1, 2, ..., T-1, ...]
-            token_indices = torch.arange(T, device='cpu').unsqueeze(0).expand(B, T).contiguous().view(B * T)
         else:
             # Original was [B, D], latents are [B, n_latents]
-            …
+            B = original_B
             T = 1
-            …
+
+        # OPTIMIZATION 1: Find active neurons (have any non-zero activation across batch)
+        # Shape: [n_neurons] - boolean mask
+        active_neurons_mask = (latents.abs().sum(dim=0) > 0)
+        active_neuron_indices = torch.nonzero(active_neurons_mask, as_tuple=False).flatten().tolist()
+
+        if not active_neuron_indices:
+            return  # No active neurons, skip
+
+        # OPTIMIZATION 2: Vectorize argmax/argmin for all neurons at once
+        if original_shape is not None and len(original_shape) == 3:
+            # Reshape to [B, T, n_neurons]
+            latents_3d = latents.view(B, T, n_neurons)
+            # For each text, find max/min across tokens for each neuron
+            # Shape: [B, n_neurons] - max activation per text per neuron
+            max_activations, max_token_indices_3d = latents_3d.max(dim=1)  # [B, n_neurons]
+            min_activations, min_token_indices_3d = latents_3d.min(dim=1)  # [B, n_neurons]
+            # max_token_indices_3d is already the token index (0 to T-1)
+            max_token_indices = max_token_indices_3d
+            min_token_indices = min_token_indices_3d
+        else:
+            # Shape: [B, n_neurons]
+            latents_2d = latents.view(B, n_neurons)
+            max_activations = latents_2d  # [B, n_neurons]
+            max_token_indices = torch.zeros(B, n_neurons, dtype=torch.long, device=latents.device)
+            min_activations = latents_2d
+            min_token_indices = torch.zeros(B, n_neurons, dtype=torch.long, device=latents.device)
+
+        # Convert to numpy for faster CPU access (already on CPU from l1_sae.py)
+        max_activations_np = max_activations.cpu().numpy()
+        min_activations_np = min_activations.cpu().numpy()
+        max_token_indices_np = max_token_indices.cpu().numpy()
+        min_token_indices_np = min_token_indices.cpu().numpy()
+
+        # OPTIMIZATION 3: Only process active neurons
+        for j in active_neuron_indices:
+            heap_positive = self._text_heaps_positive[j]
+            heap_negative = self._text_heaps_negative[j] if self._text_tracking_negative else None
+
+            # OPTIMIZATION 4: Batch process all texts for this neuron
             for batch_idx in range(original_B):
                 if batch_idx >= len(texts):
                     continue
 
                 text = texts[batch_idx]
-                …
-                        existing_entry = (heap_idx, heap_adj, heap_score, heap_token_idx)
-                        break
-
-                if existing_entry is not None:
-                    # Update existing entry if this activation is better
-                    heap_idx, heap_adj, heap_score, heap_token_idx = existing_entry
-                    if adj > heap_adj:
-                        # Replace with better activation
-                        heap[heap_idx] = (adj, (max_score, text, token_idx))
-                        heapq.heapify(heap)  # Re-heapify after modification
-                        texts_updated += 1
-                    else:
-                        texts_skipped_duplicate += 1
-                else:
-                    # New text, add to heap
-                    if len(heap) < self._text_tracking_k:
-                        heapq.heappush(heap, (adj, (max_score, text, token_idx)))
-                        texts_added += 1
-                    else:
-                        # Compare with smallest adjusted score; replace if better
-                        if adj > heap[0][0]:
-                            heapq.heapreplace(heap, (adj, (max_score, text, token_idx)))
-                            texts_added += 1
+
+                # Use pre-computed max/min (no tensor slicing needed!)
+                max_score_positive = float(max_activations_np[batch_idx, j])
+                token_idx_positive = int(max_token_indices_np[batch_idx, j])
+
+                if max_score_positive > 0.0:
+                    heap_positive.update(text, max_score_positive, token_idx_positive)
+
+                if self._text_tracking_negative and heap_negative is not None:
+                    min_score_negative = float(min_activations_np[batch_idx, j])
+                    if min_score_negative != 0.0:
+                        token_idx_negative = int(min_token_indices_np[batch_idx, j])
+                        heap_negative.update(text, min_score_negative, token_idx_negative, adjusted_score=-min_score_negative)
+
+    def _extract_activations(
+        self,
+        latents: torch.Tensor,
+        token_indices: torch.Tensor,
+        batch_idx: int,
+        neuron_idx: int,
+        original_shape: tuple[int, ...] | None,
+        T: int
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Extract activations for a specific batch item and neuron.
+
+        Returns:
+            Tuple of (text_activations, text_token_indices)
+        """
+        if original_shape is not None and len(original_shape) == 3:
+            start_idx = batch_idx * T
+            end_idx = start_idx + T
+            text_activations = latents[start_idx:end_idx, neuron_idx]
+            text_token_indices = token_indices[start_idx:end_idx]
+        else:
+            text_activations = latents[batch_idx:batch_idx + 1, neuron_idx]
+            text_token_indices = token_indices[batch_idx:batch_idx + 1]
+
+        return text_activations, text_token_indices
 
     def get_top_texts_for_neuron(self, neuron_idx: int, top_m: int | None = None) -> list[NeuronText]:
-        """Get top texts for a specific neuron."""
-        if self.
+        """Get top texts for a specific neuron (positive activations)."""
+        if self._text_heaps_positive is None or neuron_idx < 0 or neuron_idx >= len(self._text_heaps_positive):
+            return []
+        heap = self._text_heaps_positive[neuron_idx]
+        items = heap.get_items()
+        items_sorted = sorted(items, key=lambda s_t: s_t[0], reverse=True)
+        if top_m is not None:
+            items_sorted = items_sorted[: top_m]
+
+        neuron_texts = []
+        for score, text, token_idx in items_sorted:
+            token_str = self._decode_token(text, token_idx)
+            neuron_texts.append(NeuronText(score=score, text=text, token_idx=token_idx, token_str=token_str))
+        return neuron_texts
+
+    def get_bottom_texts_for_neuron(self, neuron_idx: int, top_m: int | None = None) -> list[NeuronText]:
+        """Get bottom texts for a specific neuron (negative activations)."""
+        if not self._text_tracking_negative:
+            return []
+        if self._text_heaps_negative is None or neuron_idx < 0 or neuron_idx >= len(self._text_heaps_negative):
             return []
-        heap = self.
-        items =
-        …
-        items_sorted = sorted(items, key=lambda s_t: s_t[0], reverse=reverse)
+        heap = self._text_heaps_negative[neuron_idx]
+        items = heap.get_items()
+        items_sorted = sorted(items, key=lambda s_t: s_t[0], reverse=False)
         if top_m is not None:
             items_sorted = items_sorted[: top_m]
 
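The heart of the rewritten `_update_top_texts` is that one batched `max`/`min` over the token dimension replaces per-neuron slicing: `Tensor.max(dim=1)` returns both values and argmax positions, which directly supply each text's peak activation and the token where it occurred. A minimal standalone sketch of that step, with hypothetical sizes (2 texts, 4 tokens, 3 latents):

    import torch

    B, T, N = 2, 4, 3                    # hypothetical: texts, tokens, latents
    latents = torch.randn(B * T, N)      # flattened [B*T, N], as in the diff

    latents_3d = latents.view(B, T, N)
    # One call per direction yields, for every (text, neuron) pair, the best
    # activation and the token position where it occurred.
    max_acts, max_tok = latents_3d.max(dim=1)   # both [B, N]
    min_acts, min_tok = latents_3d.min(dim=1)   # both [B, N]

    # Neurons with no nonzero activation anywhere in the batch are skipped.
    active = torch.nonzero(latents.abs().sum(dim=0) > 0).flatten().tolist()
    print(active, max_acts.shape, int(max_tok[0, 0]))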
@@ -277,17 +304,25 @@ class AutoencoderConcepts:
         return neuron_texts
 
     def get_all_top_texts(self) -> list[list[NeuronText]]:
-        """Get top texts for all neurons."""
-        if self.
+        """Get top texts for all neurons (positive activations)."""
+        if self._text_heaps_positive is None:
             return []
-        return [self.get_top_texts_for_neuron(i) for i in range(len(self.
+        return [self.get_top_texts_for_neuron(i) for i in range(len(self._text_heaps_positive))]
+
+    def get_all_bottom_texts(self) -> list[list[NeuronText]]:
+        """Get bottom texts for all neurons (negative activations)."""
+        if not self._text_tracking_negative or self._text_heaps_negative is None:
+            return []
+        return [self.get_bottom_texts_for_neuron(i) for i in range(len(self._text_heaps_negative))]
 
     def reset_top_texts(self) -> None:
         """Reset all tracked top texts."""
-        self.
+        self._text_heaps_positive = None
+        self._text_heaps_negative = None
 
     def export_top_texts_to_json(self, filepath: Path | str) -> Path:
-        …
+        """Export top texts (positive activations) to JSON file."""
+        if self._text_heaps_positive is None:
             raise ValueError("No top texts available. Enable text tracking and run inference first.")
 
         filepath = Path(filepath)
@@ -312,8 +347,35 @@ class AutoencoderConcepts:
 
         return filepath
 
+    def export_bottom_texts_to_json(self, filepath: Path | str) -> Path:
+        """Export bottom texts (negative activations) to JSON file."""
+        if not self._text_tracking_negative or self._text_heaps_negative is None:
+            raise ValueError("No bottom texts available. Enable negative text tracking and run inference first.")
+
+        filepath = Path(filepath)
+        filepath.parent.mkdir(parents=True, exist_ok=True)
+
+        all_texts = self.get_all_bottom_texts()
+        export_data = {}
+
+        for neuron_idx, neuron_texts in enumerate(all_texts):
+            export_data[neuron_idx] = [
+                {
+                    "text": nt.text,
+                    "score": nt.score,
+                    "token_str": nt.token_str,
+                    "token_idx": nt.token_idx
+                }
+                for nt in neuron_texts
+            ]
+
+        with filepath.open("w", encoding="utf-8") as f:
+            json.dump(export_data, f, ensure_ascii=False, indent=2)
+
+        return filepath
+
     def export_top_texts_to_csv(self, filepath: Path | str) -> Path:
-        if self.
+        if self._text_heaps_positive is None:
             raise ValueError("No top texts available. Enable text tracking and run inference first.")
 
         filepath = Path(filepath)
mi_crow/mechanistic/sae/concepts/concept_dictionary.py
CHANGED

@@ -135,8 +135,18 @@ class ConceptDictionary:
         with json_path.open("r", encoding="utf-8") as f:
             data = json.load(f)
 
-        …
+        if isinstance(data, dict) and "concepts" in data:
+            concepts_data = data["concepts"]
+            if "n_size" in data:
+                concept_dict.n_size = int(data["n_size"])
+        else:
+            concepts_data = data
+
+        for neuron_idx_str, concepts in concepts_data.items():
+            try:
+                neuron_idx = int(neuron_idx_str)
+            except ValueError:
+                continue
 
             # Handle both old format (list) and new format (single dict)
             if isinstance(concepts, list):
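For reference, the two payload shapes the updated loader now accepts. The concept entries below are purely illustrative (the real schema lives in ConceptDictionary; per the comment above, values may be the old list form or the new single-dict form):

    # New wrapped format: the loader reads data["concepts"] and, when present,
    # restores n_size from the top-level metadata.
    wrapped = {
        "n_size": 1024,
        "concepts": {
            "0": [{"concept": "example"}],   # illustrative entry (old list form)
            "17": {"concept": "another"},    # illustrative entry (new dict form)
        },
    }

    # Old bare format: the top-level object is already the
    # neuron-index -> concepts mapping.
    bare = {"0": [{"concept": "example"}]}

    # Keys that fail int(neuron_idx_str) are silently skipped by the loader.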
mi_crow/mechanistic/sae/concepts/text_heap.py
ADDED

@@ -0,0 +1,161 @@
+from __future__ import annotations
+
+import heapq
+
+
+class TextHeap:
+    """
+    Efficient heap for tracking top texts with O(1) duplicate lookup.
+
+    Optimized with incremental index updates and correct heap operations.
+    Maintains a min-heap of size k and a dictionary for fast text lookup.
+    """
+
+    def __init__(self, max_size: int):
+        """
+        Initialize TextHeap.
+
+        Args:
+            max_size: Maximum number of items to keep in the heap
+        """
+        self._max_size = max_size
+        self._heap: list[tuple[float, tuple[float, str, int]]] = []
+        self._text_to_index: dict[str, int] = {}
+
+    def update(self, text: str, score: float, token_idx: int, adjusted_score: float | None = None) -> None:
+        """
+        Update heap with a new text entry.
+
+        Args:
+            text: Text string
+            score: Activation score (actual value to store)
+            token_idx: Token index within the text
+            adjusted_score: Optional adjusted score for heap ordering (defaults to score)
+        """
+        if adjusted_score is None:
+            adjusted_score = score
+        heap_idx = self._text_to_index.get(text)
+
+        if heap_idx is not None:
+            self._update_existing(heap_idx, text, adjusted_score, score, token_idx)
+        else:
+            self._add_new(text, adjusted_score, score, token_idx)
+
+    def _update_existing(
+        self,
+        heap_idx: int,
+        text: str,
+        adjusted_score: float,
+        score: float,
+        token_idx: int
+    ) -> None:
+        """Update an existing entry in the heap."""
+        current_adj = self._heap[heap_idx][0]
+        if adjusted_score > current_adj:
+            self._heap[heap_idx] = (adjusted_score, (score, text, token_idx))
+            self._text_to_index[text] = heap_idx
+            self._siftdown_with_tracking(heap_idx)
+
+    def _add_new(
+        self,
+        text: str,
+        adjusted_score: float,
+        score: float,
+        token_idx: int
+    ) -> None:
+        """Add a new entry to the heap."""
+        if len(self._heap) < self._max_size:
+            self._heap.append((adjusted_score, (score, text, token_idx)))
+            new_idx = len(self._heap) - 1
+            self._text_to_index[text] = new_idx
+            self._siftup_with_tracking(new_idx)
+        else:
+            if adjusted_score > self._heap[0][0]:
+                self._replace_minimum(text, adjusted_score, score, token_idx)
+
+    def _replace_minimum(
+        self,
+        text: str,
+        adjusted_score: float,
+        score: float,
+        token_idx: int
+    ) -> None:
+        """Replace the minimum element in the heap."""
+        old_text = self._heap[0][1][1]
+        if old_text in self._text_to_index:
+            del self._text_to_index[old_text]
+
+        self._heap[0] = (adjusted_score, (score, text, token_idx))
+        self._text_to_index[text] = 0
+        self._siftdown_with_tracking(0)
+
+    def _siftup_with_tracking(self, pos: int) -> None:
+        """
+        Sift element up in heap (toward root) and update text-to-index map incrementally.
+
+        Used when value decreases - compares with parent and moves up.
+        Only updates indices that actually change during the sift operation.
+        """
+        startpos = pos
+        newitem = self._heap[pos]
+        newitem_text = newitem[1][1]
+
+        while pos > 0:
+            parentpos = (pos - 1) >> 1
+            parent = self._heap[parentpos]
+            if newitem[0] >= parent[0]:
+                break
+            parent_text = parent[1][1]
+            self._heap[pos] = parent
+            self._text_to_index[parent_text] = pos
+            pos = parentpos
+
+        self._heap[pos] = newitem
+        if pos != startpos:
+            self._text_to_index[newitem_text] = pos
+
+    def _siftdown_with_tracking(self, pos: int) -> None:
+        """
+        Sift element down in heap and update text-to-index map incrementally.
+
+        Only updates indices that actually change during the sift operation.
+        """
+        endpos = len(self._heap)
+        startpos = pos
+        newitem = self._heap[pos]
+        newitem_text = newitem[1][1]
+
+        childpos = 2 * pos + 1
+        while childpos < endpos:
+            rightpos = childpos + 1
+            if rightpos < endpos and self._heap[rightpos][0] < self._heap[childpos][0]:
+                childpos = rightpos
+            if newitem[0] < self._heap[childpos][0]:
+                break
+            child_text = self._heap[childpos][1][1]
+            self._heap[pos] = self._heap[childpos]
+            self._text_to_index[child_text] = pos
+            pos = childpos
+            childpos = 2 * pos + 1
+
+        self._heap[pos] = newitem
+        if pos != startpos:
+            self._text_to_index[newitem_text] = pos
+
+    def get_items(self) -> list[tuple[float, str, int]]:
+        """
+        Get all items from the heap, sorted by score (descending).
+
+        Returns:
+            List of (score, text, token_idx) tuples
+        """
+        return [val for (_, val) in self._heap]
+
+    def clear(self) -> None:
+        """Clear the heap and text mapping."""
+        self._heap.clear()
+        self._text_to_index.clear()
+
+    def __len__(self) -> int:
+        """Return the number of items in the heap."""
+        return len(self._heap)
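A short usage sketch of `TextHeap`, using only the methods defined above. Entries are kept in a min-heap keyed on the adjusted score, duplicates are updated in place via the text-to-index map, and `get_items()` returns `(score, text, token_idx)` tuples in internal heap order (despite its docstring), so callers such as `get_top_texts_for_neuron` sort the result themselves:

    from mi_crow.mechanistic.sae.concepts.text_heap import TextHeap

    heap = TextHeap(max_size=2)
    heap.update("the cat sat", score=0.9, token_idx=1)
    heap.update("dogs bark", score=0.4, token_idx=0)
    heap.update("the cat sat", score=1.2, token_idx=2)  # duplicate: updated in place
    heap.update("fish swim", score=0.7, token_idx=0)    # evicts the 0.4 minimum

    items = sorted(heap.get_items(), key=lambda s: s[0], reverse=True)
    print(len(heap), items)
    # 2 [(1.2, 'the cat sat', 2), (0.7, 'fish swim', 0)]

For negative tracking, `_update_top_texts` passes `adjusted_score=-min_score`, so the same min-heap discipline keeps the most negative activations.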
mi_crow/mechanistic/sae/modules/topk_sae.py
CHANGED

@@ -72,13 +72,14 @@ class TopKSae(Sae):
         A temporary default k=1 is used for engine initialization and will be
         overridden with the actual k value from config during training.
         """
-        # Set temporary default k for engine initialization (base class calls _initialize_sae_engine)
-        # This will be overridden with the actual k from config during training
-        self.k: int = 1
         super().__init__(n_latents, n_inputs, hook_id, device, store, *args, **kwargs)
 
-    def _initialize_sae_engine(self) -> OvercompleteSAE:
-        """
+    def _initialize_sae_engine(self, k: int = 1) -> OvercompleteSAE:
+        """
+        Initialize the SAE engine with the specified k value.
+
+        Args:
+            k: Number of top activations to keep (default: 1 for initialization)
 
         Note:
             k should be set from TopKSaeTrainingConfig during training.
@@ -87,7 +88,7 @@ class TopKSae(Sae):
         return OvercompleteTopkSAE(
             input_shape=self.context.n_inputs,
             nb_concepts=self.context.n_latents,
-            top_k=
+            top_k=k,
             device=self.context.device
         )
 
@@ -143,8 +144,7 @@ class TopKSae(Sae):
         Train TopKSAE using activations from a Store.
 
         This method delegates to the SaeTrainer composite class.
-        …
-        will be reinitialized with the config's k value.
+        The SAE engine will be reinitialized with the k value from config.
 
         Args:
             store: Store instance containing activations
@@ -170,15 +170,13 @@ class TopKSae(Sae):
                 "Example: TopKSaeTrainingConfig(k=10, epochs=100, ...)"
             )
 
-        # …
-        …
-        self.
-        # Initialize or reinitialize the SAE engine with k from config
-        self.sae_engine = self._initialize_sae_engine()
+        # Reinitialize engine with k from config
+        logger.info(f"Initializing SAE engine with k={config.k}")
+        self.sae_engine = self._initialize_sae_engine(k=config.k)
+        if hasattr(config, 'device') and config.device:
+            device = torch.device(config.device)
+            self.sae_engine.to(device)
+            self.context.device = str(device)
 
         return self.trainer.train(store, run_id, layer_signature, config, training_run_id)
 
@@ -323,13 +321,14 @@ class TopKSae(Sae):
             result[0] = reconstructed
         return tuple(result)
 
-    def save(self, name: str, path: str | Path | None = None) -> None:
+    def save(self, name: str, path: str | Path | None = None, k: int | None = None) -> None:
         """
         Save model using overcomplete's state dict + our metadata.
 
         Args:
             name: Model name
             path: Directory path to save to (defaults to current directory)
+            k: Top-K value to save (if None, attempts to get from engine or raises error)
         """
         if path is None:
             path = Path.cwd()
@@ -340,6 +339,16 @@ class TopKSae(Sae):
         # Save overcomplete model state dict
         sae_state_dict = self.sae_engine.state_dict()
 
+        # Get k value - prefer parameter, then try to get from engine
+        if k is None:
+            if hasattr(self.sae_engine, 'top_k'):
+                k = self.sae_engine.top_k
+            else:
+                raise ValueError(
+                    "k parameter must be provided to save() method. "
+                    "The engine does not expose top_k attribute."
+                )
+
         mi_crow_metadata = {
             "concepts_state": {
                 'multiplication': self.concepts.multiplication.data,
@@ -347,7 +356,7 @@ class TopKSae(Sae):
             },
             "n_latents": self.context.n_latents,
             "n_inputs": self.context.n_inputs,
-            "k":
+            "k": k,
             "device": self.context.device,
             "layer_signature": self.context.lm_layer_signature,
             "model_id": self.context.model_id,
@@ -403,9 +412,7 @@ class TopKSae(Sae):
             device=device
         )
 
-        …
-        topk_sae.k = k
-        topk_sae.sae_engine = topk_sae._initialize_sae_engine()
+        topk_sae.sae_engine = topk_sae._initialize_sae_engine(k=k)
 
         # Load overcomplete model state dict
         if "sae_state_dict" in payload:
mi_crow/mechanistic/sae/sae.py
CHANGED

@@ -69,7 +69,7 @@ class Sae(Controller, Detector, abc.ABC):
         """Set the LanguageModelContext for this hook and sync to AutoencoderContext.
 
         When the hook is registered, this method is called with the LanguageModelContext.
-        It automatically syncs relevant values to the AutoencoderContext.
+        It automatically syncs relevant values to the AutoencoderContext, including device.
 
         Args:
             context: The LanguageModelContext instance from the LanguageModel
@@ -84,6 +84,8 @@ class Sae(Controller, Detector, abc.ABC):
         self._autoencoder_context.store = context.store
         if self.layer_signature is not None:
             self._autoencoder_context.lm_layer_signature = self.layer_signature
+        if context.device is not None:
+            self._autoencoder_context.device = context.device
 
     @abc.abstractmethod
     def _initialize_sae_engine(self) -> OvercompleteSAE: