mi-crow 0.1.1.post12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amber/__init__.py +15 -0
- amber/datasets/__init__.py +11 -0
- amber/datasets/base_dataset.py +640 -0
- amber/datasets/classification_dataset.py +566 -0
- amber/datasets/loading_strategy.py +29 -0
- amber/datasets/text_dataset.py +488 -0
- amber/hooks/__init__.py +20 -0
- amber/hooks/controller.py +171 -0
- amber/hooks/detector.py +95 -0
- amber/hooks/hook.py +218 -0
- amber/hooks/implementations/__init__.py +0 -0
- amber/hooks/implementations/function_controller.py +93 -0
- amber/hooks/implementations/layer_activation_detector.py +96 -0
- amber/hooks/implementations/model_input_detector.py +250 -0
- amber/hooks/implementations/model_output_detector.py +132 -0
- amber/hooks/utils.py +76 -0
- amber/language_model/__init__.py +0 -0
- amber/language_model/activations.py +479 -0
- amber/language_model/context.py +33 -0
- amber/language_model/contracts.py +13 -0
- amber/language_model/hook_metadata.py +38 -0
- amber/language_model/inference.py +525 -0
- amber/language_model/initialization.py +126 -0
- amber/language_model/language_model.py +390 -0
- amber/language_model/layers.py +460 -0
- amber/language_model/persistence.py +177 -0
- amber/language_model/tokenizer.py +203 -0
- amber/language_model/utils.py +97 -0
- amber/mechanistic/__init__.py +0 -0
- amber/mechanistic/sae/__init__.py +0 -0
- amber/mechanistic/sae/autoencoder_context.py +40 -0
- amber/mechanistic/sae/concepts/__init__.py +0 -0
- amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
- amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
- amber/mechanistic/sae/concepts/concept_models.py +9 -0
- amber/mechanistic/sae/concepts/input_tracker.py +68 -0
- amber/mechanistic/sae/modules/__init__.py +5 -0
- amber/mechanistic/sae/modules/l1_sae.py +409 -0
- amber/mechanistic/sae/modules/topk_sae.py +459 -0
- amber/mechanistic/sae/sae.py +166 -0
- amber/mechanistic/sae/sae_trainer.py +604 -0
- amber/mechanistic/sae/training/wandb_logger.py +222 -0
- amber/store/__init__.py +5 -0
- amber/store/local_store.py +437 -0
- amber/store/store.py +276 -0
- amber/store/store_dataloader.py +124 -0
- amber/utils.py +46 -0
- mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
- mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
- mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
- mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
amber/language_model/tokenizer.py
@@ -0,0 +1,203 @@
from typing import Dict, List, Callable, Sequence, Any, TYPE_CHECKING, Union

from torch import nn
from transformers import AutoTokenizer

if TYPE_CHECKING:
    from amber.language_model.context import LanguageModelContext


class LanguageModelTokenizer:
    """Handles tokenization for LanguageModel."""

    def __init__(
        self,
        context: "LanguageModelContext"
    ):
        """
        Initialize LanguageModelTokenizer.

        Args:
            context: LanguageModelContext instance
        """
        self.context = context

    def _setup_pad_token(self, tokenizer: Any, model: Any) -> None:
        """
        Setup pad token for tokenizer if not already set.

        Args:
            tokenizer: Tokenizer instance
            model: Model instance
        """
        eos_token = getattr(tokenizer, "eos_token", None)
        if eos_token is not None:
            tokenizer.pad_token = eos_token
            if hasattr(model, "config"):
                model.config.pad_token_id = getattr(tokenizer, "eos_token_id", None)
        else:
            if hasattr(tokenizer, "add_special_tokens"):
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
                if hasattr(model, "resize_token_embeddings"):
                    model.resize_token_embeddings(len(tokenizer))
                if hasattr(model, "config"):
                    model.config.pad_token_id = getattr(tokenizer, "pad_token_id", None)

    def split_to_tokens(
        self,
        text: Union[str, Sequence[str]],
        add_special_tokens: bool = False
    ) -> Union[List[str], List[List[str]]]:
        """
        Split text into token strings.

        Args:
            text: Single string or sequence of strings to tokenize
            add_special_tokens: Whether to add special tokens (e.g., BOS, EOS)

        Returns:
            For a single string: list of token strings
            For a sequence of strings: list of lists of token strings
        """
        if isinstance(text, str):
            return self._split_single_text_to_tokens(text, add_special_tokens)

        return [self._split_single_text_to_tokens(t, add_special_tokens) for t in text]

    def _try_tokenize_with_method(
        self,
        tokenizer: Any,
        text: str,
        add_special_tokens: bool,
        method_name: str,
        fallback_method: str | None = None
    ) -> List[str] | None:
        """
        Try tokenizing using a specific tokenizer method.

        Args:
            tokenizer: Tokenizer instance
            text: Text to tokenize
            add_special_tokens: Whether to add special tokens
            method_name: Primary method to try (e.g., "tokenize", "encode")
            fallback_method: Optional fallback method (e.g., "convert_ids_to_tokens")

        Returns:
            List of token strings or None if method fails
        """
        if not hasattr(tokenizer, method_name):
            return None

        try:
            if method_name == "tokenize":
                return tokenizer.tokenize(text, add_special_tokens=add_special_tokens)
            elif method_name == "encode":
                if fallback_method and hasattr(tokenizer, fallback_method):
                    token_ids = tokenizer.encode(text, add_special_tokens=add_special_tokens)
                    return tokenizer.convert_ids_to_tokens(token_ids)
            elif method_name == "encode_plus":
                if fallback_method and hasattr(tokenizer, fallback_method):
                    encoded = tokenizer.encode_plus(text, add_special_tokens=add_special_tokens)
                    if isinstance(encoded, dict) and "input_ids" in encoded:
                        token_ids = encoded["input_ids"]
                        return tokenizer.convert_ids_to_tokens(token_ids)
        except (TypeError, ValueError, AttributeError):
            pass

        return None

    def _split_single_text_to_tokens(self, text: str, add_special_tokens: bool) -> List[str]:
        """
        Split a single text into token strings.

        Uses the tokenizer from LanguageModelContext to split text into tokens.

        Args:
            text: Text string to tokenize
            add_special_tokens: Whether to add special tokens

        Returns:
            List of token strings
        """
        tokenizer = self.context.tokenizer

        if tokenizer is None:
            return text.split()

        if not isinstance(text, str):
            raise TypeError(f"Expected str, got {type(text)}")

        # Try different tokenization methods in order
        tokens = self._try_tokenize_with_method(tokenizer, text, add_special_tokens, "tokenize")
        if tokens is not None:
            return tokens

        tokens = self._try_tokenize_with_method(tokenizer, text, add_special_tokens, "encode", "convert_ids_to_tokens")
        if tokens is not None:
            return tokens

        tokens = self._try_tokenize_with_method(tokenizer, text, add_special_tokens, "encode_plus", "convert_ids_to_tokens")
        if tokens is not None:
            return tokens

        return text.split()

    def tokenize(
        self,
        texts: Sequence[str],
        padding: bool = False,
        pad_token: str = "[PAD]",
        **kwargs: Any
    ) -> Any:
        """
        Robust batch tokenization that works across tokenizer variants.

        Tries methods in order:
        - callable tokenizer (most HF tokenizers)
        - batch_encode_plus
        - encode_plus per item + tokenizer.pad to collate

        Args:
            texts: Sequence of text strings to tokenize
            padding: Whether to pad sequences
            pad_token: Pad token string
            **kwargs: Additional tokenizer arguments

        Returns:
            Tokenized encodings

        Raises:
            ValueError: If tokenizer is not initialized
            TypeError: If tokenizer is not usable for batch tokenization
        """
        tokenizer = self.context.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer must be initialized before tokenization")

        model = self.context.model

        if padding and pad_token and getattr(tokenizer, "pad_token", None) is None:
            self._setup_pad_token(tokenizer, model)

        kwargs["padding"] = padding

        # Try callable tokenizer first (most common case)
        if callable(tokenizer):
            try:
                return tokenizer(texts, **kwargs)
            except TypeError:
                pass

        # Try batch_encode_plus
        if hasattr(tokenizer, "batch_encode_plus"):
            return tokenizer.batch_encode_plus(texts, **kwargs)

        # Fallback to encode_plus per item
        if hasattr(tokenizer, "encode_plus"):
            encoded = [tokenizer.encode_plus(t, **kwargs) for t in texts]
            if hasattr(tokenizer, "pad"):
                rt = kwargs.get("return_tensors") or "pt"
                return tokenizer.pad(encoded, return_tensors=rt)
            return encoded

        raise TypeError("Tokenizer object on LanguageModel is not usable for batch tokenization")
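For orientation, a minimal sketch of how this wrapper can be exercised follows. The SimpleNamespace stand-in for LanguageModelContext, the gpt2 checkpoint, and model=None are illustrative assumptions; the wrapper itself only reads context.tokenizer and context.model.

# Illustrative only: a SimpleNamespace stands in for LanguageModelContext,
# which normally also carries the loaded model.
from types import SimpleNamespace

from transformers import AutoTokenizer

from amber.language_model.tokenizer import LanguageModelTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
context = SimpleNamespace(tokenizer=hf_tokenizer, model=None)
wrapper = LanguageModelTokenizer(context)

# Per-text token strings via the tokenize/encode/encode_plus fallback chain.
print(wrapper.split_to_tokens("Hello world"))

# Batch tokenization; GPT-2 has no pad token, so _setup_pad_token reuses EOS.
batch = wrapper.tokenize(["Hello world", "Hi"], padding=True, return_tensors="pt")
print(batch["input_ids"].shape)  # padded [batch, seq] tensor of token ids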
amber/language_model/utils.py
@@ -0,0 +1,97 @@
"""Utility functions for language model operations."""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch import nn

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizerBase


def extract_model_id(model: nn.Module, provided_model_id: str | None = None) -> str:
    """
    Extract model ID from model or use provided one.

    Args:
        model: PyTorch model module
        provided_model_id: Optional model ID provided by user

    Returns:
        Model ID string

    Raises:
        ValueError: If model_id cannot be determined
    """
    if provided_model_id is not None:
        return provided_model_id

    if hasattr(model, 'config') and hasattr(model.config, 'name_or_path'):
        return model.config.name_or_path.replace("/", "_")

    return model.__class__.__name__


def get_device_from_model(model: nn.Module) -> torch.device:
    """
    Get the device from model parameters.

    Args:
        model: PyTorch model module

    Returns:
        Device where model parameters are located, or CPU if no parameters
    """
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")


def move_tensors_to_device(
    tensors: dict[str, torch.Tensor],
    device: torch.device
) -> dict[str, torch.Tensor]:
    """
    Move dictionary of tensors to specified device.

    Args:
        tensors: Dictionary of tensor name to tensor
        device: Target device

    Returns:
        Dictionary with tensors moved to device
    """
    device_type = str(device.type)
    if device_type == "cuda":
        return {k: v.to(device, non_blocking=True) for k, v in tensors.items()}
    return {k: v.to(device) for k, v in tensors.items()}


def extract_logits_from_output(output: any) -> torch.Tensor:
    """
    Extract logits tensor from model output.

    Handles various output formats:
    - Objects with 'logits' attribute (e.g., HuggingFace model outputs)
    - Tuples (takes first element)
    - Direct tensors

    Args:
        output: Model output (various formats)

    Returns:
        Logits tensor

    Raises:
        ValueError: If logits cannot be extracted from output
    """
    if hasattr(output, 'logits'):
        return output.logits
    elif isinstance(output, tuple) and len(output) > 0:
        return output[0]
    elif isinstance(output, torch.Tensor):
        return output
    else:
        raise ValueError(f"Unable to extract logits from output type: {type(output)}")
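A short usage sketch for these helpers, using a toy nn.Linear so nothing model-specific is assumed; the import path follows the file listing above.

# Illustrative sketch on a toy module; no pretrained weights are needed.
import torch
from torch import nn

from amber.language_model.utils import (
    extract_logits_from_output,
    extract_model_id,
    get_device_from_model,
    move_tensors_to_device,
)

model = nn.Linear(4, 2)
print(extract_model_id(model))         # falls back to the class name: "Linear"

device = get_device_from_model(model)  # cpu unless the module was moved
batch = move_tensors_to_device({"x": torch.randn(3, 4)}, device)

logits = extract_logits_from_output(model(batch["x"]))  # plain tensors pass through
print(logits.shape)                    # torch.Size([3, 2])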
amber/mechanistic/__init__.py
File without changes
amber/mechanistic/sae/__init__.py
File without changes
amber/mechanistic/sae/autoencoder_context.py
@@ -0,0 +1,40 @@
from dataclasses import dataclass
from typing import Optional, TYPE_CHECKING

from amber.store.store import Store

if TYPE_CHECKING:
    pass


@dataclass
class AutoencoderContext:
    """Shared context for Autoencoder and its nested components."""

    autoencoder: "Sae"

    # Core SAE parameters
    n_latents: int
    n_inputs: int

    # Language model parameters (shared across hierarchy)
    lm: Optional["LanguageModel"] = None
    lm_layer_signature: Optional[int | str] = None
    model_id: Optional[str] = None

    # Training/experiment metadata
    device: str = 'cpu'
    experiment_name: Optional[str] = None
    run_id: Optional[str] = None

    # Text tracking parameters
    text_tracking_enabled: bool = False
    text_tracking_k: int = 5
    text_tracking_negative: bool = False

    store: Optional[Store] = None

    # Training parameters
    tied: bool = False
    bias_init: float = 0.0
    init_method: str = "kaiming"
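A hypothetical construction sketch follows. Field names and defaults are taken from the dataclass above; the concrete values and the None placeholder for the Sae instance are illustrative only.

# Illustrative values; autoencoder would normally be an Sae instance.
from amber.mechanistic.sae.autoencoder_context import AutoencoderContext

ctx = AutoencoderContext(
    autoencoder=None,        # placeholder for an Sae instance
    n_latents=16_384,        # SAE dictionary size (example value)
    n_inputs=768,            # width of the hooked hidden state (example value)
    device="cpu",
    text_tracking_enabled=True,
    text_tracking_k=5,       # keep the 5 strongest texts per latent
)
print(ctx.init_method, ctx.tied)  # defaults: kaiming False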
amber/mechanistic/sae/concepts/__init__.py
File without changes
amber/mechanistic/sae/concepts/autoencoder_concepts.py
@@ -0,0 +1,332 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Sequence
import json
import csv
import heapq

import torch
from torch import nn

from amber.mechanistic.sae.concepts.concept_models import NeuronText
from amber.mechanistic.sae.autoencoder_context import AutoencoderContext
from amber.utils import get_logger

if TYPE_CHECKING:
    from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary

logger = get_logger(__name__)


class AutoencoderConcepts:
    def __init__(
        self,
        context: AutoencoderContext
    ):
        self.context = context
        self._n_size = context.n_latents
        self.dictionary: ConceptDictionary | None = None

        # Concept manipulation parameters
        self.multiplication = nn.Parameter(torch.ones(self._n_size))
        self.bias = nn.Parameter(torch.ones(self._n_size))

        # Top texts tracking
        self._top_texts_heaps: list[list[tuple[float, tuple[float, str, int]]]] | None = None
        self._text_tracking_k: int = 5
        self._text_tracking_negative: bool = False

    def enable_text_tracking(self):
        """Enable text tracking using context parameters."""
        if self.context.lm is None:
            raise ValueError("LanguageModel must be set in context to enable tracking")

        # Store tracking parameters
        self._text_tracking_k = self.context.text_tracking_k
        self._text_tracking_negative = self.context.text_tracking_negative

        # Ensure InputTracker singleton exists on LanguageModel and enable it
        input_tracker = self.context.lm._ensure_input_tracker()
        input_tracker.enable()

        # Enable text tracking on the SAE instance
        if hasattr(self.context.autoencoder, '_text_tracking_enabled'):
            self.context.autoencoder._text_tracking_enabled = True

    def disable_text_tracking(self):
        self.context.autoencoder._text_tracking_enabled = False

    def _ensure_dictionary(self):
        if self.dictionary is None:
            from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
            self.dictionary = ConceptDictionary(self._n_size)
        return self.dictionary

    def load_concepts_from_csv(self, csv_filepath: str | Path):
        from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
        self.dictionary = ConceptDictionary.from_csv(
            csv_filepath=csv_filepath,
            n_size=self._n_size,
            store=self.dictionary.store if self.dictionary else None
        )

    def load_concepts_from_json(self, json_filepath: str | Path):
        from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
        self.dictionary = ConceptDictionary.from_json(
            json_filepath=json_filepath,
            n_size=self._n_size,
            store=self.dictionary.store if self.dictionary else None
        )

    def generate_concepts_with_llm(self, llm_provider: str | None = None):
        """Generate concepts using LLM based on current top texts"""
        if self._top_texts_heaps is None:
            raise ValueError("No top texts available. Enable text tracking and run inference first.")

        from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
        neuron_texts = self.get_all_top_texts()

        self.dictionary = ConceptDictionary.from_llm(
            neuron_texts=neuron_texts,
            n_size=self._n_size,
            store=self.dictionary.store if self.dictionary else None,
            llm_provider=llm_provider
        )

    def _ensure_heaps(self, n_neurons: int) -> None:
        """Ensure heaps are initialized for the given number of neurons."""
        if self._top_texts_heaps is None:
            self._top_texts_heaps = [[] for _ in range(n_neurons)]

    def _decode_token(self, text: str, token_idx: int) -> str:
        """
        Decode a specific token from the text using the language model's tokenizer.

        The token_idx is relative to the sequence length T that the model saw during inference.
        However, there's a mismatch: during inference, texts are tokenized with
        add_special_tokens=True (which adds BOS/EOS), but the token_idx appears to be
        calculated relative to the sequence without special tokens.

        We tokenize the text the same way as _decode_token originally did (without special tokens)
        to match the token_idx calculation, but we also account for truncation that may have
        occurred during inference (max_length).
        """
        if self.context.lm is None:
            return f"<token_{token_idx}>"

        try:
            if self.context.lm.tokenizer is None:
                return f"<token_{token_idx}>"

            # Use the raw tokenizer (not the wrapper) to encode and decode
            tokenizer = self.context.lm.tokenizer

            # Tokenize without special tokens (matching original behavior)
            # This matches how token_idx was calculated in update_top_texts_from_latents
            tokens = tokenizer.encode(text, add_special_tokens=False)

            # Check if token_idx is valid
            if 0 <= token_idx < len(tokens):
                token_id = tokens[token_idx]
                # Decode the specific token
                token_str = tokenizer.decode([token_id])
                return token_str
            else:
                return f"<token_{token_idx}_out_of_range>"
        except Exception as e:
            # If tokenization fails, return a placeholder
            logger.debug(f"Token decode error for token_idx={token_idx} in text (len={len(text)}): {e}")
            return f"<token_{token_idx}_decode_error>"

    def update_top_texts_from_latents(
        self,
        latents: torch.Tensor,
        texts: Sequence[str],
        original_shape: tuple[int, ...] | None = None
    ) -> None:
        """
        Update top texts heaps from latents and texts.

        Args:
            latents: Latent activations tensor, shape [B*T, n_latents] or [B, n_latents] (already flattened)
            texts: List of texts corresponding to the batch
            original_shape: Original shape before flattening, e.g., (B, T, D) or (B, D)
        """
        if not texts:
            return

        n_neurons = latents.shape[-1]
        self._ensure_heaps(n_neurons)

        # Calculate batch and token dimensions
        original_B = len(texts)
        BT = latents.shape[0]  # Total positions (B*T if 3D original, or B if 2D original)

        # Determine if original was 3D or 2D
        if original_shape is not None and len(original_shape) == 3:
            # Original was [B, T, D], latents are [B*T, n_latents]
            B, T, _ = original_shape
            # Verify batch size matches
            if B != original_B:
                logger.warning(f"Batch size mismatch: original_shape has B={B}, but {original_B} texts provided")
                # Use the actual number of texts as batch size
                B = original_B
                T = BT // B if B > 0 else 1
            # Create token indices: [0, 1, 2, ..., T-1, 0, 1, 2, ..., T-1, ...]
            token_indices = torch.arange(T, device='cpu').unsqueeze(0).expand(B, T).contiguous().view(B * T)
        else:
            # Original was [B, D], latents are [B, n_latents]
            # All tokens are at index 0
            T = 1
            token_indices = torch.zeros(BT, dtype=torch.long, device='cpu')

        # For each neuron, find the maximum activation per text
        # This ensures we only track the best activation for each text, not every token position
        for j in range(n_neurons):
            heap = self._top_texts_heaps[j]

            # For each text in the batch, find the max activation and its token position
            texts_processed = 0
            texts_added = 0
            texts_updated = 0
            texts_skipped_duplicate = 0
            for batch_idx in range(original_B):
                if batch_idx >= len(texts):
                    continue

                text = texts[batch_idx]
                texts_processed += 1

                # Get activations for this text (all token positions)
                if original_shape is not None and len(original_shape) == 3:
                    # 3D case: [B, T, D] -> get slice for this batch
                    start_idx = batch_idx * T
                    end_idx = start_idx + T
                    text_activations = latents[start_idx:end_idx, j]  # [T]
                    text_token_indices = token_indices[start_idx:end_idx]  # [T]
                else:
                    # 2D case: [B, D] -> single token
                    text_activations = latents[batch_idx:batch_idx + 1, j]  # [1]
                    text_token_indices = token_indices[batch_idx:batch_idx + 1]  # [1]

                # Find the maximum activation (or minimum if tracking negative)
                if self._text_tracking_negative:
                    # For negative tracking, find the most negative (minimum) value
                    max_idx = torch.argmin(text_activations)
                    max_score = float(text_activations[max_idx].item())
                    adj = -max_score  # Negate for heap ordering
                else:
                    # For positive tracking, find the maximum value
                    max_idx = torch.argmax(text_activations)
                    max_score = float(text_activations[max_idx].item())
                    adj = max_score

                # Skip if score is zero (no activation)
                if max_score == 0.0:
                    continue

                token_idx = int(text_token_indices[max_idx].item())

                # Check if we already have this text in the heap
                # If so, only update if this activation is better
                existing_entry = None
                heap_texts = []
                for heap_idx, (heap_adj, (heap_score, heap_text, heap_token_idx)) in enumerate(heap):
                    heap_texts.append(heap_text[:50] if len(heap_text) > 50 else heap_text)
                    if heap_text == text:
                        existing_entry = (heap_idx, heap_adj, heap_score, heap_token_idx)
                        break

                if existing_entry is not None:
                    # Update existing entry if this activation is better
                    heap_idx, heap_adj, heap_score, heap_token_idx = existing_entry
                    if adj > heap_adj:
                        # Replace with better activation
                        heap[heap_idx] = (adj, (max_score, text, token_idx))
                        heapq.heapify(heap)  # Re-heapify after modification
                        texts_updated += 1
                    else:
                        texts_skipped_duplicate += 1
                else:
                    # New text, add to heap
                    if len(heap) < self._text_tracking_k:
                        heapq.heappush(heap, (adj, (max_score, text, token_idx)))
                        texts_added += 1
                    else:
                        # Compare with smallest adjusted score; replace if better
                        if adj > heap[0][0]:
                            heapq.heapreplace(heap, (adj, (max_score, text, token_idx)))
                            texts_added += 1

    def get_top_texts_for_neuron(self, neuron_idx: int, top_m: int | None = None) -> list[NeuronText]:
        """Get top texts for a specific neuron."""
        if self._top_texts_heaps is None or neuron_idx < 0 or neuron_idx >= len(self._top_texts_heaps):
            return []
        heap = self._top_texts_heaps[neuron_idx]
        items = [val for (_, val) in heap]
        reverse = not self._text_tracking_negative
        items_sorted = sorted(items, key=lambda s_t: s_t[0], reverse=reverse)
        if top_m is not None:
            items_sorted = items_sorted[: top_m]

        neuron_texts = []
        for score, text, token_idx in items_sorted:
            token_str = self._decode_token(text, token_idx)
            neuron_texts.append(NeuronText(score=score, text=text, token_idx=token_idx, token_str=token_str))
        return neuron_texts

    def get_all_top_texts(self) -> list[list[NeuronText]]:
        """Get top texts for all neurons."""
        if self._top_texts_heaps is None:
            return []
        return [self.get_top_texts_for_neuron(i) for i in range(len(self._top_texts_heaps))]

    def reset_top_texts(self) -> None:
        """Reset all tracked top texts."""
        self._top_texts_heaps = None

    def export_top_texts_to_json(self, filepath: Path | str) -> Path:
        if self._top_texts_heaps is None:
            raise ValueError("No top texts available. Enable text tracking and run inference first.")

        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        all_texts = self.get_all_top_texts()
        export_data = {}

        for neuron_idx, neuron_texts in enumerate(all_texts):
            export_data[neuron_idx] = [
                {
                    "text": nt.text,
                    "score": nt.score,
                    "token_str": nt.token_str,
                    "token_idx": nt.token_idx
                }
                for nt in neuron_texts
            ]

        with filepath.open("w", encoding="utf-8") as f:
            json.dump(export_data, f, ensure_ascii=False, indent=2)

        return filepath

    def export_top_texts_to_csv(self, filepath: Path | str) -> Path:
        if self._top_texts_heaps is None:
            raise ValueError("No top texts available. Enable text tracking and run inference first.")

        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        all_texts = self.get_all_top_texts()

        with filepath.open("w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["neuron_idx", "text", "score", "token_str", "token_idx"])

            for neuron_idx, neuron_texts in enumerate(all_texts):
                for nt in neuron_texts:
                    writer.writerow([neuron_idx, nt.text, nt.score, nt.token_str, nt.token_idx])

        return filepath
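To make the tracking flow concrete, here is a hedged end-to-end sketch that runs without a language model: with context.lm left unset, _decode_token falls back to <token_i> placeholders, so the heap bookkeeping and exports can be exercised on random activations. All values are illustrative.

# Illustrative sketch: random activations stand in for real SAE latents.
import torch

from amber.mechanistic.sae.autoencoder_context import AutoencoderContext
from amber.mechanistic.sae.concepts.autoencoder_concepts import AutoencoderConcepts

ctx = AutoencoderContext(autoencoder=None, n_latents=8, n_inputs=16)
concepts = AutoencoderConcepts(ctx)

texts = ["the cat sat", "dogs bark loudly"]
B, T = len(texts), 4
latents = torch.rand(B * T, ctx.n_latents)  # flattened [B*T, n_latents] activations

concepts.update_top_texts_from_latents(latents, texts, original_shape=(B, T, 16))

# Strongest texts for latent 0, with placeholder token strings (no lm in context).
for nt in concepts.get_top_texts_for_neuron(0, top_m=2):
    print(nt.score, nt.token_idx, nt.text)

concepts.export_top_texts_to_csv("top_texts.csv")  # one row per (neuron, text)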