mi_crow-0.1.1.post12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. amber/__init__.py +15 -0
  2. amber/datasets/__init__.py +11 -0
  3. amber/datasets/base_dataset.py +640 -0
  4. amber/datasets/classification_dataset.py +566 -0
  5. amber/datasets/loading_strategy.py +29 -0
  6. amber/datasets/text_dataset.py +488 -0
  7. amber/hooks/__init__.py +20 -0
  8. amber/hooks/controller.py +171 -0
  9. amber/hooks/detector.py +95 -0
  10. amber/hooks/hook.py +218 -0
  11. amber/hooks/implementations/__init__.py +0 -0
  12. amber/hooks/implementations/function_controller.py +93 -0
  13. amber/hooks/implementations/layer_activation_detector.py +96 -0
  14. amber/hooks/implementations/model_input_detector.py +250 -0
  15. amber/hooks/implementations/model_output_detector.py +132 -0
  16. amber/hooks/utils.py +76 -0
  17. amber/language_model/__init__.py +0 -0
  18. amber/language_model/activations.py +479 -0
  19. amber/language_model/context.py +33 -0
  20. amber/language_model/contracts.py +13 -0
  21. amber/language_model/hook_metadata.py +38 -0
  22. amber/language_model/inference.py +525 -0
  23. amber/language_model/initialization.py +126 -0
  24. amber/language_model/language_model.py +390 -0
  25. amber/language_model/layers.py +460 -0
  26. amber/language_model/persistence.py +177 -0
  27. amber/language_model/tokenizer.py +203 -0
  28. amber/language_model/utils.py +97 -0
  29. amber/mechanistic/__init__.py +0 -0
  30. amber/mechanistic/sae/__init__.py +0 -0
  31. amber/mechanistic/sae/autoencoder_context.py +40 -0
  32. amber/mechanistic/sae/concepts/__init__.py +0 -0
  33. amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
  34. amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
  35. amber/mechanistic/sae/concepts/concept_models.py +9 -0
  36. amber/mechanistic/sae/concepts/input_tracker.py +68 -0
  37. amber/mechanistic/sae/modules/__init__.py +5 -0
  38. amber/mechanistic/sae/modules/l1_sae.py +409 -0
  39. amber/mechanistic/sae/modules/topk_sae.py +459 -0
  40. amber/mechanistic/sae/sae.py +166 -0
  41. amber/mechanistic/sae/sae_trainer.py +604 -0
  42. amber/mechanistic/sae/training/wandb_logger.py +222 -0
  43. amber/store/__init__.py +5 -0
  44. amber/store/local_store.py +437 -0
  45. amber/store/store.py +276 -0
  46. amber/store/store_dataloader.py +124 -0
  47. amber/utils.py +46 -0
  48. mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
  49. mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
  50. mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
  51. mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
amber/language_model/tokenizer.py
@@ -0,0 +1,203 @@
+from typing import Dict, List, Callable, Sequence, Any, TYPE_CHECKING, Union
+
+from torch import nn
+from transformers import AutoTokenizer
+
+if TYPE_CHECKING:
+    from amber.language_model.context import LanguageModelContext
+
+
+class LanguageModelTokenizer:
+    """Handles tokenization for LanguageModel."""
+
+    def __init__(
+        self,
+        context: "LanguageModelContext"
+    ):
+        """
+        Initialize LanguageModelTokenizer.
+
+        Args:
+            context: LanguageModelContext instance
+        """
+        self.context = context
+
+    def _setup_pad_token(self, tokenizer: Any, model: Any) -> None:
+        """
+        Setup pad token for tokenizer if not already set.
+
+        Args:
+            tokenizer: Tokenizer instance
+            model: Model instance
+        """
+        eos_token = getattr(tokenizer, "eos_token", None)
+        if eos_token is not None:
+            tokenizer.pad_token = eos_token
+            if hasattr(model, "config"):
+                model.config.pad_token_id = getattr(tokenizer, "eos_token_id", None)
+        else:
+            if hasattr(tokenizer, "add_special_tokens"):
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+            if hasattr(model, "resize_token_embeddings"):
+                model.resize_token_embeddings(len(tokenizer))
+            if hasattr(model, "config"):
+                model.config.pad_token_id = getattr(tokenizer, "pad_token_id", None)
+
+    def split_to_tokens(
+        self,
+        text: Union[str, Sequence[str]],
+        add_special_tokens: bool = False
+    ) -> Union[List[str], List[List[str]]]:
+        """
+        Split text into token strings.
+
+        Args:
+            text: Single string or sequence of strings to tokenize
+            add_special_tokens: Whether to add special tokens (e.g., BOS, EOS)
+
+        Returns:
+            For a single string: list of token strings
+            For a sequence of strings: list of lists of token strings
+        """
+        if isinstance(text, str):
+            return self._split_single_text_to_tokens(text, add_special_tokens)
+
+        return [self._split_single_text_to_tokens(t, add_special_tokens) for t in text]
+
+    def _try_tokenize_with_method(
+        self,
+        tokenizer: Any,
+        text: str,
+        add_special_tokens: bool,
+        method_name: str,
+        fallback_method: str | None = None
+    ) -> List[str] | None:
+        """
+        Try tokenizing using a specific tokenizer method.
+
+        Args:
+            tokenizer: Tokenizer instance
+            text: Text to tokenize
+            add_special_tokens: Whether to add special tokens
+            method_name: Primary method to try (e.g., "tokenize", "encode")
+            fallback_method: Optional fallback method (e.g., "convert_ids_to_tokens")
+
+        Returns:
+            List of token strings or None if method fails
+        """
+        if not hasattr(tokenizer, method_name):
+            return None
+
+        try:
+            if method_name == "tokenize":
+                return tokenizer.tokenize(text, add_special_tokens=add_special_tokens)
+            elif method_name == "encode":
+                if fallback_method and hasattr(tokenizer, fallback_method):
+                    token_ids = tokenizer.encode(text, add_special_tokens=add_special_tokens)
+                    return tokenizer.convert_ids_to_tokens(token_ids)
+            elif method_name == "encode_plus":
+                if fallback_method and hasattr(tokenizer, fallback_method):
+                    encoded = tokenizer.encode_plus(text, add_special_tokens=add_special_tokens)
+                    if isinstance(encoded, dict) and "input_ids" in encoded:
+                        token_ids = encoded["input_ids"]
+                        return tokenizer.convert_ids_to_tokens(token_ids)
+        except (TypeError, ValueError, AttributeError):
+            pass
+
+        return None
+
+    def _split_single_text_to_tokens(self, text: str, add_special_tokens: bool) -> List[str]:
+        """
+        Split a single text into token strings.
+
+        Uses the tokenizer from LanguageModelContext to split text into tokens.
+
+        Args:
+            text: Text string to tokenize
+            add_special_tokens: Whether to add special tokens
+
+        Returns:
+            List of token strings
+        """
+        tokenizer = self.context.tokenizer
+
+        if tokenizer is None:
+            return text.split()
+
+        if not isinstance(text, str):
+            raise TypeError(f"Expected str, got {type(text)}")
+
+        # Try different tokenization methods in order
+        tokens = self._try_tokenize_with_method(tokenizer, text, add_special_tokens, "tokenize")
+        if tokens is not None:
+            return tokens
+
+        tokens = self._try_tokenize_with_method(tokenizer, text, add_special_tokens, "encode", "convert_ids_to_tokens")
+        if tokens is not None:
+            return tokens
+
+        tokens = self._try_tokenize_with_method(tokenizer, text, add_special_tokens, "encode_plus", "convert_ids_to_tokens")
+        if tokens is not None:
+            return tokens
+
+        return text.split()
+
+    def tokenize(
+        self,
+        texts: Sequence[str],
+        padding: bool = False,
+        pad_token: str = "[PAD]",
+        **kwargs: Any
+    ) -> Any:
+        """
+        Robust batch tokenization that works across tokenizer variants.
+
+        Tries methods in order:
+        - callable tokenizer (most HF tokenizers)
+        - batch_encode_plus
+        - encode_plus per item + tokenizer.pad to collate
+
+        Args:
+            texts: Sequence of text strings to tokenize
+            padding: Whether to pad sequences
+            pad_token: Pad token string
+            **kwargs: Additional tokenizer arguments
+
+        Returns:
+            Tokenized encodings
+
+        Raises:
+            ValueError: If tokenizer is not initialized
+            TypeError: If tokenizer is not usable for batch tokenization
+        """
+        tokenizer = self.context.tokenizer
+        if tokenizer is None:
+            raise ValueError("Tokenizer must be initialized before tokenization")
+
+        model = self.context.model
+
+        if padding and pad_token and getattr(tokenizer, "pad_token", None) is None:
+            self._setup_pad_token(tokenizer, model)
+
+        kwargs["padding"] = padding
+
+        # Try callable tokenizer first (most common case)
+        if callable(tokenizer):
+            try:
+                return tokenizer(texts, **kwargs)
+            except TypeError:
+                pass
+
+        # Try batch_encode_plus
+        if hasattr(tokenizer, "batch_encode_plus"):
+            return tokenizer.batch_encode_plus(texts, **kwargs)
+
+        # Fallback to encode_plus per item
+        if hasattr(tokenizer, "encode_plus"):
+            encoded = [tokenizer.encode_plus(t, **kwargs) for t in texts]
+            if hasattr(tokenizer, "pad"):
+                rt = kwargs.get("return_tensors") or "pt"
+                return tokenizer.pad(encoded, return_tensors=rt)
+            return encoded
+
+        raise TypeError("Tokenizer object on LanguageModel is not usable for batch tokenization")
amber/language_model/utils.py
@@ -0,0 +1,97 @@
+"""Utility functions for language model operations."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import torch
+from torch import nn
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizerBase
+
+
+def extract_model_id(model: nn.Module, provided_model_id: str | None = None) -> str:
+    """
+    Extract model ID from model or use provided one.
+
+    Args:
+        model: PyTorch model module
+        provided_model_id: Optional model ID provided by user
+
+    Returns:
+        Model ID string
+
+    Raises:
+        ValueError: If model_id cannot be determined
+    """
+    if provided_model_id is not None:
+        return provided_model_id
+
+    if hasattr(model, 'config') and hasattr(model.config, 'name_or_path'):
+        return model.config.name_or_path.replace("/", "_")
+
+    return model.__class__.__name__
+
+
+def get_device_from_model(model: nn.Module) -> torch.device:
+    """
+    Get the device from model parameters.
+
+    Args:
+        model: PyTorch model module
+
+    Returns:
+        Device where model parameters are located, or CPU if no parameters
+    """
+    first_param = next(model.parameters(), None)
+    return first_param.device if first_param is not None else torch.device("cpu")
+
+
+def move_tensors_to_device(
+    tensors: dict[str, torch.Tensor],
+    device: torch.device
+) -> dict[str, torch.Tensor]:
+    """
+    Move dictionary of tensors to specified device.
+
+    Args:
+        tensors: Dictionary of tensor name to tensor
+        device: Target device
+
+    Returns:
+        Dictionary with tensors moved to device
+    """
+    device_type = str(device.type)
+    if device_type == "cuda":
+        return {k: v.to(device, non_blocking=True) for k, v in tensors.items()}
+    return {k: v.to(device) for k, v in tensors.items()}
+
+
+def extract_logits_from_output(output: Any) -> torch.Tensor:
+    """
+    Extract logits tensor from model output.
+
+    Handles various output formats:
+    - Objects with 'logits' attribute (e.g., HuggingFace model outputs)
+    - Tuples (takes first element)
+    - Direct tensors
+
+    Args:
+        output: Model output (various formats)
+
+    Returns:
+        Logits tensor
+
+    Raises:
+        ValueError: If logits cannot be extracted from output
+    """
+    if hasattr(output, 'logits'):
+        return output.logits
+    elif isinstance(output, tuple) and len(output) > 0:
+        return output[0]
+    elif isinstance(output, torch.Tensor):
+        return output
+    else:
+        raise ValueError(f"Unable to extract logits from output type: {type(output)}")
+
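A short sketch of how these helpers behave; the import path is assumed from the file list (amber/language_model/utils.py), and a bare nn.Linear stands in for a model.

import torch
from torch import nn

from amber.language_model.utils import (  # path assumed from the file list
    extract_logits_from_output,
    extract_model_id,
    get_device_from_model,
    move_tensors_to_device,
)

model = nn.Linear(4, 2)              # plain module without a .config attribute
print(extract_model_id(model))       # falls back to the class name: "Linear"

device = get_device_from_model(model)                             # cpu unless the module was moved
batch = move_tensors_to_device({"x": torch.randn(3, 4)}, device)  # dict of tensors onto that device

out = model(batch["x"])
print(extract_logits_from_output(out).shape)      # bare tensors pass through: torch.Size([3, 2])
print(extract_logits_from_output((out,)).shape)   # tuples: the first element is taken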
amber/mechanistic/__init__.py: File without changes (empty __init__.py)
amber/mechanistic/sae/__init__.py: File without changes (empty __init__.py)
amber/mechanistic/sae/autoencoder_context.py
@@ -0,0 +1,40 @@
+from dataclasses import dataclass
+from typing import Optional, TYPE_CHECKING
+
+from amber.store.store import Store
+
+if TYPE_CHECKING:
+    pass
+
+
+@dataclass
+class AutoencoderContext:
+    """Shared context for Autoencoder and its nested components."""
+
+    autoencoder: "Sae"
+
+    # Core SAE parameters
+    n_latents: int
+    n_inputs: int
+
+    # Language model parameters (shared across hierarchy)
+    lm: Optional["LanguageModel"] = None
+    lm_layer_signature: Optional[int | str] = None
+    model_id: Optional[str] = None
+
+    # Training/experiment metadata
+    device: str = 'cpu'
+    experiment_name: Optional[str] = None
+    run_id: Optional[str] = None
+
+    # Text tracking parameters
+    text_tracking_enabled: bool = False
+    text_tracking_k: int = 5
+    text_tracking_negative: bool = False
+
+    store: Optional[Store] = None
+
+    # Training parameters
+    tied: bool = False
+    bias_init: float = 0.0
+    init_method: str = "kaiming"
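A minimal construction sketch for the dataclass above. A real caller would pass an Sae instance (amber/mechanistic/sae/sae.py) as autoencoder; its constructor is not part of this excerpt, so None stands in here, which works because dataclass field annotations are not enforced at runtime. The widths are hypothetical.

from amber.mechanistic.sae.autoencoder_context import AutoencoderContext

ctx = AutoencoderContext(
    autoencoder=None,          # placeholder for an Sae instance (not shown in this excerpt)
    n_latents=16_384,          # hypothetical SAE dictionary size
    n_inputs=768,              # hypothetical activation width
    device="cpu",
    text_tracking_enabled=True,
    text_tracking_k=5,
)
print(ctx.tied, ctx.bias_init, ctx.init_method)   # defaults: False 0.0 kaiming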
amber/mechanistic/sae/concepts/__init__.py: File without changes (empty __init__.py)
amber/mechanistic/sae/concepts/autoencoder_concepts.py
@@ -0,0 +1,332 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Sequence
+import json
+import csv
+import heapq
+
+import torch
+from torch import nn
+
+from amber.mechanistic.sae.concepts.concept_models import NeuronText
+from amber.mechanistic.sae.autoencoder_context import AutoencoderContext
+from amber.utils import get_logger
+
+if TYPE_CHECKING:
+    from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
+
+logger = get_logger(__name__)
+
+
+class AutoencoderConcepts:
+    def __init__(
+        self,
+        context: AutoencoderContext
+    ):
+        self.context = context
+        self._n_size = context.n_latents
+        self.dictionary: ConceptDictionary | None = None
+
+        # Concept manipulation parameters
+        self.multiplication = nn.Parameter(torch.ones(self._n_size))
+        self.bias = nn.Parameter(torch.ones(self._n_size))
+
+        # Top texts tracking
+        self._top_texts_heaps: list[list[tuple[float, tuple[float, str, int]]]] | None = None
+        self._text_tracking_k: int = 5
+        self._text_tracking_negative: bool = False
+
+    def enable_text_tracking(self):
+        """Enable text tracking using context parameters."""
+        if self.context.lm is None:
+            raise ValueError("LanguageModel must be set in context to enable tracking")
+
+        # Store tracking parameters
+        self._text_tracking_k = self.context.text_tracking_k
+        self._text_tracking_negative = self.context.text_tracking_negative
+
+        # Ensure InputTracker singleton exists on LanguageModel and enable it
+        input_tracker = self.context.lm._ensure_input_tracker()
+        input_tracker.enable()
+
+        # Enable text tracking on the SAE instance
+        if hasattr(self.context.autoencoder, '_text_tracking_enabled'):
+            self.context.autoencoder._text_tracking_enabled = True
+
+    def disable_text_tracking(self):
+        self.context.autoencoder._text_tracking_enabled = False
+
+    def _ensure_dictionary(self):
+        if self.dictionary is None:
+            from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
+            self.dictionary = ConceptDictionary(self._n_size)
+        return self.dictionary
+
+    def load_concepts_from_csv(self, csv_filepath: str | Path):
+        from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
+        self.dictionary = ConceptDictionary.from_csv(
+            csv_filepath=csv_filepath,
+            n_size=self._n_size,
+            store=self.dictionary.store if self.dictionary else None
+        )
+
+    def load_concepts_from_json(self, json_filepath: str | Path):
+        from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
+        self.dictionary = ConceptDictionary.from_json(
+            json_filepath=json_filepath,
+            n_size=self._n_size,
+            store=self.dictionary.store if self.dictionary else None
+        )
+
+    def generate_concepts_with_llm(self, llm_provider: str | None = None):
+        """Generate concepts using LLM based on current top texts"""
+        if self._top_texts_heaps is None:
+            raise ValueError("No top texts available. Enable text tracking and run inference first.")
+
+        from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
+        neuron_texts = self.get_all_top_texts()
+
+        self.dictionary = ConceptDictionary.from_llm(
+            neuron_texts=neuron_texts,
+            n_size=self._n_size,
+            store=self.dictionary.store if self.dictionary else None,
+            llm_provider=llm_provider
+        )
+
+    def _ensure_heaps(self, n_neurons: int) -> None:
+        """Ensure heaps are initialized for the given number of neurons."""
+        if self._top_texts_heaps is None:
+            self._top_texts_heaps = [[] for _ in range(n_neurons)]
+
+    def _decode_token(self, text: str, token_idx: int) -> str:
+        """
+        Decode a specific token from the text using the language model's tokenizer.
+
+        The token_idx is relative to the sequence length T that the model saw during inference.
+        However, there's a mismatch: during inference, texts are tokenized with
+        add_special_tokens=True (which adds BOS/EOS), but the token_idx appears to be
+        calculated relative to the sequence without special tokens.
+
+        We tokenize the text the same way as _decode_token originally did (without special tokens)
+        to match the token_idx calculation, but we also account for truncation that may have
+        occurred during inference (max_length).
+        """
+        if self.context.lm is None:
+            return f"<token_{token_idx}>"
+
+        try:
+            if self.context.lm.tokenizer is None:
+                return f"<token_{token_idx}>"
+
+            # Use the raw tokenizer (not the wrapper) to encode and decode
+            tokenizer = self.context.lm.tokenizer
+
+            # Tokenize without special tokens (matching original behavior)
+            # This matches how token_idx was calculated in update_top_texts_from_latents
+            tokens = tokenizer.encode(text, add_special_tokens=False)
+
+            # Check if token_idx is valid
+            if 0 <= token_idx < len(tokens):
+                token_id = tokens[token_idx]
+                # Decode the specific token
+                token_str = tokenizer.decode([token_id])
+                return token_str
+            else:
+                return f"<token_{token_idx}_out_of_range>"
+        except Exception as e:
+            # If tokenization fails, return a placeholder
+            logger.debug(f"Token decode error for token_idx={token_idx} in text (len={len(text)}): {e}")
+            return f"<token_{token_idx}_decode_error>"
+
+    def update_top_texts_from_latents(
+        self,
+        latents: torch.Tensor,
+        texts: Sequence[str],
+        original_shape: tuple[int, ...] | None = None
+    ) -> None:
+        """
+        Update top texts heaps from latents and texts.
+
+        Args:
+            latents: Latent activations tensor, shape [B*T, n_latents] or [B, n_latents] (already flattened)
+            texts: List of texts corresponding to the batch
+            original_shape: Original shape before flattening, e.g., (B, T, D) or (B, D)
+        """
+        if not texts:
+            return
+
+        n_neurons = latents.shape[-1]
+        self._ensure_heaps(n_neurons)
+
+        # Calculate batch and token dimensions
+        original_B = len(texts)
+        BT = latents.shape[0]  # Total positions (B*T if 3D original, or B if 2D original)
+
+        # Determine if original was 3D or 2D
+        if original_shape is not None and len(original_shape) == 3:
+            # Original was [B, T, D], latents are [B*T, n_latents]
+            B, T, _ = original_shape
+            # Verify batch size matches
+            if B != original_B:
+                logger.warning(f"Batch size mismatch: original_shape has B={B}, but {original_B} texts provided")
+                # Use the actual number of texts as batch size
+                B = original_B
+                T = BT // B if B > 0 else 1
+            # Create token indices: [0, 1, 2, ..., T-1, 0, 1, 2, ..., T-1, ...]
+            token_indices = torch.arange(T, device='cpu').unsqueeze(0).expand(B, T).contiguous().view(B * T)
+        else:
+            # Original was [B, D], latents are [B, n_latents]
+            # All tokens are at index 0
+            T = 1
+            token_indices = torch.zeros(BT, dtype=torch.long, device='cpu')
+
+        # For each neuron, find the maximum activation per text
+        # This ensures we only track the best activation for each text, not every token position
+        for j in range(n_neurons):
+            heap = self._top_texts_heaps[j]
+
+            # For each text in the batch, find the max activation and its token position
+            texts_processed = 0
+            texts_added = 0
+            texts_updated = 0
+            texts_skipped_duplicate = 0
+            for batch_idx in range(original_B):
+                if batch_idx >= len(texts):
+                    continue
+
+                text = texts[batch_idx]
+                texts_processed += 1
+
+                # Get activations for this text (all token positions)
+                if original_shape is not None and len(original_shape) == 3:
+                    # 3D case: [B, T, D] -> get slice for this batch
+                    start_idx = batch_idx * T
+                    end_idx = start_idx + T
+                    text_activations = latents[start_idx:end_idx, j]  # [T]
+                    text_token_indices = token_indices[start_idx:end_idx]  # [T]
+                else:
+                    # 2D case: [B, D] -> single token
+                    text_activations = latents[batch_idx:batch_idx + 1, j]  # [1]
+                    text_token_indices = token_indices[batch_idx:batch_idx + 1]  # [1]
+
+                # Find the maximum activation (or minimum if tracking negative)
+                if self._text_tracking_negative:
+                    # For negative tracking, find the most negative (minimum) value
+                    max_idx = torch.argmin(text_activations)
+                    max_score = float(text_activations[max_idx].item())
+                    adj = -max_score  # Negate for heap ordering
+                else:
+                    # For positive tracking, find the maximum value
+                    max_idx = torch.argmax(text_activations)
+                    max_score = float(text_activations[max_idx].item())
+                    adj = max_score
+
+                # Skip if score is zero (no activation)
+                if max_score == 0.0:
+                    continue
+
+                token_idx = int(text_token_indices[max_idx].item())
+
+                # Check if we already have this text in the heap
+                # If so, only update if this activation is better
+                existing_entry = None
+                heap_texts = []
+                for heap_idx, (heap_adj, (heap_score, heap_text, heap_token_idx)) in enumerate(heap):
+                    heap_texts.append(heap_text[:50] if len(heap_text) > 50 else heap_text)
+                    if heap_text == text:
+                        existing_entry = (heap_idx, heap_adj, heap_score, heap_token_idx)
+                        break
+
+                if existing_entry is not None:
+                    # Update existing entry if this activation is better
+                    heap_idx, heap_adj, heap_score, heap_token_idx = existing_entry
+                    if adj > heap_adj:
+                        # Replace with better activation
+                        heap[heap_idx] = (adj, (max_score, text, token_idx))
+                        heapq.heapify(heap)  # Re-heapify after modification
+                        texts_updated += 1
+                    else:
+                        texts_skipped_duplicate += 1
+                else:
+                    # New text, add to heap
+                    if len(heap) < self._text_tracking_k:
+                        heapq.heappush(heap, (adj, (max_score, text, token_idx)))
+                        texts_added += 1
+                    else:
+                        # Compare with smallest adjusted score; replace if better
+                        if adj > heap[0][0]:
+                            heapq.heapreplace(heap, (adj, (max_score, text, token_idx)))
+                            texts_added += 1
+
+    def get_top_texts_for_neuron(self, neuron_idx: int, top_m: int | None = None) -> list[NeuronText]:
+        """Get top texts for a specific neuron."""
+        if self._top_texts_heaps is None or neuron_idx < 0 or neuron_idx >= len(self._top_texts_heaps):
+            return []
+        heap = self._top_texts_heaps[neuron_idx]
+        items = [val for (_, val) in heap]
+        reverse = not self._text_tracking_negative
+        items_sorted = sorted(items, key=lambda s_t: s_t[0], reverse=reverse)
+        if top_m is not None:
+            items_sorted = items_sorted[: top_m]
+
+        neuron_texts = []
+        for score, text, token_idx in items_sorted:
+            token_str = self._decode_token(text, token_idx)
+            neuron_texts.append(NeuronText(score=score, text=text, token_idx=token_idx, token_str=token_str))
+        return neuron_texts
+
+    def get_all_top_texts(self) -> list[list[NeuronText]]:
+        """Get top texts for all neurons."""
+        if self._top_texts_heaps is None:
+            return []
+        return [self.get_top_texts_for_neuron(i) for i in range(len(self._top_texts_heaps))]
+
+    def reset_top_texts(self) -> None:
+        """Reset all tracked top texts."""
+        self._top_texts_heaps = None
+
+    def export_top_texts_to_json(self, filepath: Path | str) -> Path:
+        if self._top_texts_heaps is None:
+            raise ValueError("No top texts available. Enable text tracking and run inference first.")
+
+        filepath = Path(filepath)
+        filepath.parent.mkdir(parents=True, exist_ok=True)
+
+        all_texts = self.get_all_top_texts()
+        export_data = {}
+
+        for neuron_idx, neuron_texts in enumerate(all_texts):
+            export_data[neuron_idx] = [
+                {
+                    "text": nt.text,
+                    "score": nt.score,
+                    "token_str": nt.token_str,
+                    "token_idx": nt.token_idx
+                }
+                for nt in neuron_texts
+            ]
+
+        with filepath.open("w", encoding="utf-8") as f:
+            json.dump(export_data, f, ensure_ascii=False, indent=2)
+
+        return filepath
+
+    def export_top_texts_to_csv(self, filepath: Path | str) -> Path:
+        if self._top_texts_heaps is None:
+            raise ValueError("No top texts available. Enable text tracking and run inference first.")
+
+        filepath = Path(filepath)
+        filepath.parent.mkdir(parents=True, exist_ok=True)
+
+        all_texts = self.get_all_top_texts()
+
+        with filepath.open("w", newline="", encoding="utf-8") as f:
+            writer = csv.writer(f)
+            writer.writerow(["neuron_idx", "text", "score", "token_str", "token_idx"])
+
+            for neuron_idx, neuron_texts in enumerate(all_texts):
+                for nt in neuron_texts:
+                    writer.writerow([neuron_idx, nt.text, nt.score, nt.token_str, nt.token_idx])
+
+        return filepath
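To illustrate the top-text bookkeeping above without a full SAE pipeline, here is a hedged sketch that feeds random latents through update_top_texts_from_latents. With no LanguageModel on the context, _decode_token falls back to the <token_i> placeholder; the shapes and texts are made up, the autoencoder field is again a None stand-in, and the import paths are assumed from the file list.

import torch

from amber.mechanistic.sae.autoencoder_context import AutoencoderContext
from amber.mechanistic.sae.concepts.autoencoder_concepts import AutoencoderConcepts

# Two texts, four token positions each, eight latent neurons (all hypothetical).
ctx = AutoencoderContext(autoencoder=None, n_latents=8, n_inputs=16)
concepts = AutoencoderConcepts(ctx)

latents = torch.rand(2 * 4, 8)                 # [B*T, n_latents], already flattened
concepts.update_top_texts_from_latents(
    latents,
    texts=["first example text", "second example text"],
    original_shape=(2, 4, 16),                 # original [B, T, D]
)

# At most one entry per distinct text is kept per neuron (k defaults to 5).
for nt in concepts.get_top_texts_for_neuron(0):
    print(nt.score, nt.token_idx, nt.token_str, nt.text)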