mi-crow 0.1.1.post12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amber/__init__.py +15 -0
- amber/datasets/__init__.py +11 -0
- amber/datasets/base_dataset.py +640 -0
- amber/datasets/classification_dataset.py +566 -0
- amber/datasets/loading_strategy.py +29 -0
- amber/datasets/text_dataset.py +488 -0
- amber/hooks/__init__.py +20 -0
- amber/hooks/controller.py +171 -0
- amber/hooks/detector.py +95 -0
- amber/hooks/hook.py +218 -0
- amber/hooks/implementations/__init__.py +0 -0
- amber/hooks/implementations/function_controller.py +93 -0
- amber/hooks/implementations/layer_activation_detector.py +96 -0
- amber/hooks/implementations/model_input_detector.py +250 -0
- amber/hooks/implementations/model_output_detector.py +132 -0
- amber/hooks/utils.py +76 -0
- amber/language_model/__init__.py +0 -0
- amber/language_model/activations.py +479 -0
- amber/language_model/context.py +33 -0
- amber/language_model/contracts.py +13 -0
- amber/language_model/hook_metadata.py +38 -0
- amber/language_model/inference.py +525 -0
- amber/language_model/initialization.py +126 -0
- amber/language_model/language_model.py +390 -0
- amber/language_model/layers.py +460 -0
- amber/language_model/persistence.py +177 -0
- amber/language_model/tokenizer.py +203 -0
- amber/language_model/utils.py +97 -0
- amber/mechanistic/__init__.py +0 -0
- amber/mechanistic/sae/__init__.py +0 -0
- amber/mechanistic/sae/autoencoder_context.py +40 -0
- amber/mechanistic/sae/concepts/__init__.py +0 -0
- amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
- amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
- amber/mechanistic/sae/concepts/concept_models.py +9 -0
- amber/mechanistic/sae/concepts/input_tracker.py +68 -0
- amber/mechanistic/sae/modules/__init__.py +5 -0
- amber/mechanistic/sae/modules/l1_sae.py +409 -0
- amber/mechanistic/sae/modules/topk_sae.py +459 -0
- amber/mechanistic/sae/sae.py +166 -0
- amber/mechanistic/sae/sae_trainer.py +604 -0
- amber/mechanistic/sae/training/wandb_logger.py +222 -0
- amber/store/__init__.py +5 -0
- amber/store/local_store.py +437 -0
- amber/store/store.py +276 -0
- amber/store/store_dataloader.py +124 -0
- amber/utils.py +46 -0
- mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
- mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
- mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
- mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
amber/store/store.py
ADDED
@@ -0,0 +1,276 @@
from __future__ import annotations

import abc
from pathlib import Path
from typing import Dict, Any, List, Iterator

import torch


TensorMetadata = Dict[str, Dict[str, torch.Tensor]]


class Store(abc.ABC):
    """Abstract store optimized for tensor batches grouped by run_id.

    This interface intentionally excludes generic bytes/JSON APIs.
    Implementations should focus on efficient safetensors-backed IO.

    The store organizes data hierarchically:
    - Runs: Top-level grouping by run_id
    - Batches: Within each run, data is organized by batch_index
    - Layers: Within each batch, tensors are organized by layer_signature
    - Keys: Within each layer, tensors are identified by key (e.g., "activations")
    """

    def __init__(
        self,
        base_path: Path | str = "",
        runs_prefix: str = "runs",
        dataset_prefix: str = "datasets",
        model_prefix: str = "models",
    ):
        """Initialize Store.

        Args:
            base_path: Base directory path for the store
            runs_prefix: Prefix for runs directory (default: "runs")
            dataset_prefix: Prefix for datasets directory (default: "datasets")
            model_prefix: Prefix for models directory (default: "models")
        """
        self.runs_prefix = runs_prefix
        self.dataset_prefix = dataset_prefix
        self.model_prefix = model_prefix
        self.base_path = Path(base_path)

    def _run_key(self, run_id: str) -> Path:
        """Get path for a run directory.

        Args:
            run_id: Run identifier

        Returns:
            Path to run directory
        """
        return self.base_path / self.runs_prefix / run_id

    def _run_batch_key(self, run_id: str, batch_index: int) -> Path:
        """Get path for a batch directory within a run.

        Args:
            run_id: Run identifier
            batch_index: Batch index

        Returns:
            Path to batch directory
        """
        return self._run_key(run_id) / f"batch_{batch_index}"

    def _run_metadata_key(self, run_id: str) -> Path:
        """Get path for run metadata file.

        Args:
            run_id: Run identifier

        Returns:
            Path to metadata JSON file
        """
        return self._run_key(run_id) / "meta.json"

    @abc.abstractmethod
    def put_run_batch(self, run_id: str, batch_index: int,
                      tensors: List[torch.Tensor] | Dict[str, torch.Tensor]) -> str:
        raise NotImplementedError

    @abc.abstractmethod
    def get_run_batch(self, run_id: str, batch_index: int) -> List[torch.Tensor] | Dict[
        str, torch.Tensor]:
        raise NotImplementedError

    @abc.abstractmethod
    def list_run_batches(self, run_id: str) -> List[int]:
        raise NotImplementedError

    def iter_run_batches(self, run_id: str) -> Iterator[List[torch.Tensor] | Dict[str, torch.Tensor]]:
        for idx in self.list_run_batches(run_id):
            yield self.get_run_batch(run_id, idx)

    def iter_run_batch_range(
        self,
        run_id: str,
        *,
        start: int = 0,
        stop: int | None = None,
        step: int = 1,
    ) -> Iterator[List[torch.Tensor] | Dict[str, torch.Tensor]]:
        """Iterate run batches for indices in range(start, stop, step).

        If stop is None, it will be set to max(list_run_batches(run_id)) + 1 (or 0 if none).
        Raises ValueError if step == 0 or start < 0.
        """
        if step == 0:
            raise ValueError("step must not be 0")
        if start < 0:
            raise ValueError("start must be >= 0")
        indices = self.list_run_batches(run_id)
        if not indices:
            return
        max_idx = max(indices)
        if stop is None:
            stop = max_idx + 1
        for idx in range(start, stop, step):
            try:
                yield self.get_run_batch(run_id, idx)
            except FileNotFoundError:
                continue

    @abc.abstractmethod
    def delete_run(self, run_id: str) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def put_run_metadata(self, run_id: str, meta: Dict[str, Any]) -> str:
        """Persist metadata for a run (e.g., dataset/model identifiers).

        Args:
            run_id: Run identifier
            meta: Metadata dictionary to save (must be JSON-serializable)

        Returns:
            String path/key where metadata was stored (e.g., "runs/{run_id}/meta.json")

        Raises:
            ValueError: If run_id is invalid or meta is not JSON-serializable
            OSError: If file system operations fail

        Note:
            Implementations should store JSON at a stable location, e.g., runs/{run_id}/meta.json.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_run_metadata(self, run_id: str) -> Dict[str, Any]:
        """Load metadata for a run.

        Args:
            run_id: Run identifier

        Returns:
            Metadata dictionary, or empty dict if not found

        Raises:
            ValueError: If run_id is invalid
            json.JSONDecodeError: If metadata file exists but contains invalid JSON
        """
        raise NotImplementedError

    @abc.abstractmethod
    def put_detector_metadata(
        self,
        run_id: str,
        batch_index: int,
        metadata: Dict[str, Any],
        tensor_metadata: TensorMetadata
    ) -> str:
        """Save detector metadata with separate JSON and tensor store.

        Args:
            run_id: Run identifier
            batch_index: Batch index (must be non-negative)
            metadata: JSON-serializable metadata dictionary (aggregated from all detectors)
            tensor_metadata: Dictionary mapping layer_signature to dict of tensor_key -> tensor
                (from all detectors)

        Returns:
            Full path key used for store (e.g., "runs/{run_id}/batch_{batch_index}")

        Raises:
            ValueError: If parameters are invalid or metadata is not JSON-serializable
            OSError: If file system operations fail
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_detector_metadata(
        self,
        run_id: str,
        batch_index: int
    ) -> tuple[Dict[str, Any], TensorMetadata]:
        """Load detector metadata with separate JSON and tensor store.

        Args:
            run_id: Run identifier
            batch_index: Batch index

        Returns:
            Tuple of (metadata dict, tensor_metadata dict). Returns empty dicts if not found.

        Raises:
            ValueError: If parameters are invalid or metadata format is invalid
            json.JSONDecodeError: If metadata file exists but contains invalid JSON
            OSError: If tensor files exist but cannot be loaded
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_detector_metadata_by_layer_by_key(
        self,
        run_id: str,
        batch_index: int,
        layer: str,
        key: str
    ) -> torch.Tensor:
        """Get a specific tensor from detector metadata by layer and key.

        Args:
            run_id: Run identifier
            batch_index: Batch index
            layer: Layer signature
            key: Tensor key (e.g., "activations")

        Returns:
            The requested tensor

        Raises:
            ValueError: If parameters are invalid
            FileNotFoundError: If the tensor doesn't exist
            OSError: If tensor file exists but cannot be loaded
        """
        raise NotImplementedError

    # --- Unified detector metadata for whole runs ---
    @abc.abstractmethod
    def put_run_detector_metadata(
        self,
        run_id: str,
        metadata: Dict[str, Any],
        tensor_metadata: TensorMetadata,
    ) -> str:
        """
        Save detector metadata for a whole run in a unified location.

        This differs from ``put_detector_metadata`` which organises data
        per-batch under ``runs/{run_id}/batch_{batch_index}``.

        ``put_run_detector_metadata`` instead stores everything under
        ``runs/{run_id}/detectors``. Implementations are expected to
        support being called multiple times for the same ``run_id`` and
        append / aggregate new metadata rather than overwrite it.

        Args:
            run_id: Run identifier
            metadata: JSON-serialisable metadata dictionary aggregated
                from all detectors for the current chunk / batch.
            tensor_metadata: Dictionary mapping layer_signature to dict
                of tensor_key -> tensor (from all detectors).

        Returns:
            String path/key where metadata was stored
            (e.g. ``runs/{run_id}/detectors``).

        Raises:
            ValueError: If parameters are invalid or metadata is not
                JSON-serialisable.
            OSError: If file system operations fail.
        """
        raise NotImplementedError
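The `Store` base class above only fixes the contract; the shipped implementation lives in `amber/store/local_store.py` and backs it with safetensors files. As a rough illustration of how the pieces fit together, here is a minimal, hypothetical dict-backed subclass (not part of the package) that satisfies the abstract methods and exercises the inherited `iter_run_batch_range` helper:

```python
# Hypothetical sketch, not shipped with mi-crow: a dict-backed Store used only
# to illustrate the abstract contract defined in amber/store/store.py.
from typing import Any, Dict, List

import torch

from amber.store.store import Store


class InMemoryStore(Store):
    """Toy Store keeping batches and metadata in plain dictionaries."""

    def __init__(self) -> None:
        super().__init__(base_path="memory")
        self._batches: Dict[str, Dict[int, Dict[str, torch.Tensor]]] = {}
        self._run_meta: Dict[str, Dict[str, Any]] = {}
        self._detectors: Dict[Any, Any] = {}

    def put_run_batch(self, run_id, batch_index, tensors):
        if isinstance(tensors, list):  # normalise the list form to a dict
            tensors = {str(i): t for i, t in enumerate(tensors)}
        self._batches.setdefault(run_id, {})[batch_index] = tensors
        return str(self._run_batch_key(run_id, batch_index))

    def get_run_batch(self, run_id, batch_index):
        try:
            return self._batches[run_id][batch_index]
        except KeyError as exc:
            raise FileNotFoundError(f"no batch {batch_index} for run {run_id}") from exc

    def list_run_batches(self, run_id) -> List[int]:
        return sorted(self._batches.get(run_id, {}))

    def delete_run(self, run_id) -> None:
        self._batches.pop(run_id, None)
        self._run_meta.pop(run_id, None)

    def put_run_metadata(self, run_id, meta):
        self._run_meta[run_id] = dict(meta)
        return str(self._run_metadata_key(run_id))

    def get_run_metadata(self, run_id):
        return self._run_meta.get(run_id, {})

    def put_detector_metadata(self, run_id, batch_index, metadata, tensor_metadata):
        self._detectors[(run_id, batch_index)] = (metadata, tensor_metadata)
        return str(self._run_batch_key(run_id, batch_index))

    def get_detector_metadata(self, run_id, batch_index):
        return self._detectors.get((run_id, batch_index), ({}, {}))

    def get_detector_metadata_by_layer_by_key(self, run_id, batch_index, layer, key):
        _, tensors = self.get_detector_metadata(run_id, batch_index)
        try:
            return tensors[layer][key]
        except KeyError as exc:
            raise FileNotFoundError(f"{layer}/{key} missing in batch {batch_index}") from exc

    def put_run_detector_metadata(self, run_id, metadata, tensor_metadata):
        self._detectors[(run_id, "run")] = (metadata, tensor_metadata)
        return str(self._run_key(run_id) / "detectors")


store = InMemoryStore()
store.put_run_batch("run-1", 0, {"activations": torch.randn(4, 8)})
for batch in store.iter_run_batch_range("run-1"):  # inherited helper
    print(batch["activations"].shape)  # torch.Size([4, 8])
```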
amber/store/store_dataloader.py
ADDED
@@ -0,0 +1,124 @@
from typing import Optional, Iterator, TYPE_CHECKING

import torch

from amber.utils import get_logger

logger = get_logger(__name__)

if TYPE_CHECKING:
    from amber.store.store import Store


class StoreDataloader:
    """
    A reusable DataLoader-like class that can be iterated multiple times.

    This is needed because overcomplete's train_sae iterates over the dataloader
    once per epoch, so we need a dataloader that can be iterated multiple times.
    """

    def __init__(
        self,
        store: "Store",
        run_id: str,
        layer: str,
        key: str = "activations",
        batch_size: int = 32,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        max_batches: Optional[int] = None,
        logger_instance=None
    ):
        """
        Initialize StoreDataloader.

        Args:
            store: Store instance containing activations
            run_id: Run ID to iterate over
            layer: Layer signature to load activations from
            key: Tensor key to load (default: "activations")
            batch_size: Mini-batch size
            dtype: Optional dtype to cast activations to
            device: Optional device to move tensors to
            max_batches: Optional limit on number of batches per epoch
            logger_instance: Optional logger instance for debug messages
        """
        self.store = store
        self.run_id = run_id
        self.layer = layer
        self.key = key
        self.batch_size = batch_size
        self.dtype = dtype
        self.device = device
        self.max_batches = max_batches
        self.logger = logger_instance or logger

    def __iter__(self) -> Iterator[torch.Tensor]:
        """
        Create a new iterator for each epoch.

        This allows the dataloader to be iterated multiple times,
        which is required for multiple epochs.
        """
        batches_yielded = 0

        # Get list of batch indices
        batch_indices = self.store.list_run_batches(self.run_id)

        for batch_index in batch_indices:
            if self.max_batches is not None and batches_yielded >= self.max_batches:
                break

            acts = None
            try:
                # Try to load from detector metadata first
                acts = self.store.get_detector_metadata_by_layer_by_key(
                    self.run_id,
                    batch_index,
                    self.layer,
                    self.key
                )
            except FileNotFoundError:
                # Fall back to traditional batch files
                try:
                    batch = self.store.get_run_batch(self.run_id, batch_index)
                    if isinstance(batch, dict) and self.key in batch:
                        acts = batch[self.key]
                    elif isinstance(batch, dict) and "activations" in batch:
                        # For backward compatibility, use "activations" if key not found
                        acts = batch["activations"]
                except Exception:
                    pass

            if acts is None:
                if self.logger.isEnabledFor(self.logger.level):
                    self.logger.debug(
                        f"Skipping batch {batch_index}: tensor not found "
                        f"(run_id={self.run_id}, layer={self.layer}, key={self.key})"
                    )
                continue

            # Ensure 2D [N, D]
            if acts.dim() > 2:
                d = acts.shape[-1]
                acts = acts.view(-1, d)
            elif acts.dim() == 1:
                acts = acts.view(1, -1)

            # dtype handling
            if self.dtype is not None:
                acts = acts.to(self.dtype)

            # device handling
            if self.device is not None:
                acts = acts.to(self.device)

            # Yield mini-batches
            bs = max(1, int(self.batch_size))
            n = acts.shape[0]
            for start in range(0, n, bs):
                if self.max_batches is not None and batches_yielded >= self.max_batches:
                    return
                yield acts[start:start + bs]
                batches_yielded += 1
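Because `__iter__` rebuilds its state on every call, the same `StoreDataloader` instance can be handed to a trainer that loops over it once per epoch. A minimal usage sketch, continuing the hypothetical `InMemoryStore` from the previous section; the run id, layer signature and sizes are placeholder values:

```python
# Continues the hypothetical InMemoryStore sketch above; "run-1" and the layer
# signature are placeholders, not anything the package prescribes.
import torch

from amber.store.store_dataloader import StoreDataloader

store = InMemoryStore()
store.put_run_batch("run-1", 0, {"activations": torch.randn(1024, 64)})

loader = StoreDataloader(
    store=store,
    run_id="run-1",
    layer="model.layers.0",  # placeholder layer signature
    key="activations",
    batch_size=256,
    dtype=torch.float32,
)

# The loader falls back to get_run_batch when no detector metadata exists,
# then yields 2D [N, D] mini-batches of size 256, epoch after epoch.
for epoch in range(2):
    for mini_batch in loader:
        assert mini_batch.shape == (256, 64)
```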
amber/utils.py
ADDED
@@ -0,0 +1,46 @@
from __future__ import annotations

import logging
import os
import random
from typing import Optional

import numpy as np
import torch


def set_seed(seed: int, deterministic: bool = True) -> None:
    """Set seeds for python, numpy, and torch.

    Args:
        seed: Seed value.
        deterministic: If True, tries to make torch deterministic where possible.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.use_deterministic_algorithms(True, warn_only=True)
        torch.backends.cudnn.deterministic = True  # type: ignore[attr-defined]
        torch.backends.cudnn.benchmark = False  # type: ignore[attr-defined]


def get_logger(name: str = "amber", level: int | str = logging.INFO) -> logging.Logger:
    """Get a configured logger with a simple format. Idempotent."""
    logger = logging.getLogger(name)
    if isinstance(level, str):
        level = logging.getLevelName(level)
    logger.setLevel(level)
    if not logger.handlers:
        handler = logging.StreamHandler()
        fmt = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
        handler.setFormatter(fmt)
        logger.addHandler(handler)
        # Allow propagation so pytest's caplog can capture logs
        logger.propagate = True
    else:
        # Even if a handler exists (e.g., configured by the app), ensure propagation is enabled
        logger.propagate = True
    return logger
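A short usage sketch for the two helpers above; the logger name is arbitrary:

```python
from amber.utils import get_logger, set_seed

# Seed python's random, numpy and torch, and request deterministic torch ops.
set_seed(42)

# `level` accepts either an int or a level name string.
log = get_logger("amber.example", level="DEBUG")
log.debug("seeded and ready")
```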
mi_crow-0.1.1.post12.dist-info/METADATA
ADDED
@@ -0,0 +1,124 @@
Metadata-Version: 2.4
Name: mi-crow
Version: 0.1.1.post12
Summary: Engineer Thesis: Explaining and modifying LLM responses using SAE and concepts.
Author-email: Hubert Kowalski <your.email@example.com>, Adam Kaniasty <adam.kaniasty@gmail.com>
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: accelerate>=1.10.1
Requires-Dist: datasets>=4.0.0
Requires-Dist: notebook>=7.4.7
Requires-Dist: overcomplete
Requires-Dist: pre-commit>=4.3.0
Requires-Dist: setuptools-scm>=8
Requires-Dist: torch>=2.8.0
Requires-Dist: tqdm>=4.67.1
Requires-Dist: transformers>=4.56.1
Requires-Dist: safetensors>=0.4.5
Requires-Dist: wandb>=0.22.1
Requires-Dist: pytest>=8.4.2
Requires-Dist: pytest-xdist>=3.8.0
Requires-Dist: seaborn>=0.13.2
Provides-Extra: dev
Requires-Dist: pre-commit>=4.3.0; extra == "dev"
Requires-Dist: ruff>=0.13.2; extra == "dev"
Requires-Dist: setuptools-scm>=8; extra == "dev"
Requires-Dist: wheel>=0.45.1; extra == "dev"
Requires-Dist: pytest>=8.0.0; extra == "dev"
Requires-Dist: pytest-cov>=5.0.0; extra == "dev"
Provides-Extra: docs
Requires-Dist: mkdocs>=1.6; extra == "docs"
Requires-Dist: mkdocs-material>=9.5; extra == "docs"
Requires-Dist: mkdocstrings[python]>=0.26; extra == "docs"
Requires-Dist: mkdocs-section-index>=0.3; extra == "docs"
Requires-Dist: mkdocs-redirects>=1.2; extra == "docs"
Requires-Dist: mkdocs-literate-nav>=0.6; extra == "docs"
Requires-Dist: mkdocs-gen-files>=0.5; extra == "docs"
Requires-Dist: mike>=2.1; extra == "docs"


[](https://adamkaniasty.github.io/Inzynierka/)

## Running Tests

The project uses pytest for testing. Tests are organized into unit tests and end-to-end tests.

### Running All Tests

```bash
pytest
```

### Running Specific Test Suites

Run only unit tests:
```bash
pytest --unit -q
```

Run only end-to-end tests:
```bash
pytest --e2e -q
```

You can also use pytest markers:
```bash
pytest -m unit -q
pytest -m e2e -q
```

Or specify the test directory directly:
```bash
pytest tests/unit -q
pytest tests/e2e -q
```

### Test Coverage

The test suite is configured to require at least 85% code coverage. Coverage reports are generated in both terminal and XML formats.

## Backend (FastAPI) quickstart

Install server-only dependencies (kept out of the core library) with uv:
```bash
uv sync --group server
```

Run the API:
```bash
uv run --group server uvicorn server.main:app --reload
```

Smoke-test the server endpoints:
```bash
uv run --group server pytest tests/server/test_api.py --cov=server --cov-fail-under=0
```

### SAE API usage

- Configure artifact location (optional): `export SERVER_ARTIFACT_BASE_PATH=/path/to/amber_artifacts` (defaults to `~/.cache/amber_server`)
- Load a model: `curl -X POST http://localhost:8000/models/load -H "Content-Type: application/json" -d '{"model_id":"bielik"}'`
- Save activations from dataset (stored in `LocalStore` under `activations/<model>/<run_id>`):
  - HF dataset: `{"dataset":{"type":"hf","name":"ag_news","split":"train","text_field":"text"}}`
  - Local files: `{"dataset":{"type":"local","paths":["/path/to/file.txt"]}}`
  - Example: `curl -X POST http://localhost:8000/sae/activations/save -H "Content-Type: application/json" -d '{"model_id":"bielik","layers":["dummy_root"],"dataset":{"type":"local","paths":["/tmp/data.txt"]},"sample_limit":100,"batch_size":4,"shard_size":64}'` → returns a manifest path, run_id, token counts, and batch metadata.
- List activation runs: `curl "http://localhost:8000/sae/activations?model_id=bielik"`
- Start SAE training (async job, uses `SaeTrainer`): `curl -X POST http://localhost:8000/sae/train -H "Content-Type: application/json" -d '{"model_id":"bielik","activations_path":"/path/to/manifest.json","layer":"<layer_name>","sae_class":"TopKSae","hyperparams":{"epochs":1,"batch_size":256}}'` → returns `job_id`
- Check job status: `curl http://localhost:8000/sae/train/status/<job_id>` (returns `sae_id`, `sae_path`, `metadata_path`, progress, and logs)
- Cancel a job (best-effort): `curl -X POST http://localhost:8000/sae/train/cancel/<job_id>`
- Load an SAE: `curl -X POST http://localhost:8000/sae/load -H "Content-Type: application/json" -d '{"model_id":"bielik","sae_path":"/path/to/sae.json"}'`
- List SAEs: `curl "http://localhost:8000/sae/saes?model_id=bielik"`
- Run SAE inference (optionally save top texts and apply concept config): `curl -X POST http://localhost:8000/sae/infer -H "Content-Type: application/json" -d '{"model_id":"bielik","sae_id":"<sae_id>","save_top_texts":true,"top_k_neurons":5,"concept_config_path":"/path/to/concepts.json","inputs":[{"prompt":"hi"}]}'` → returns outputs, top neuron summary, sae metadata, and saved top-texts path when requested.
  - Per-token latents: add `"return_token_latents": true` (default off) to include top-k neuron activations per token.
- List concepts: `curl "http://localhost:8000/sae/concepts?model_id=bielik&sae_id=<sae_id>"`
- Load concepts from a file (validated against SAE latents): `curl -X POST http://localhost:8000/sae/concepts/load -H "Content-Type: application/json" -d '{"model_id":"bielik","sae_id":"<sae_id>","source_path":"/path/to/concepts.json"}'`
- Manipulate concepts (saves a config file for inference-time scaling): `curl -X POST http://localhost:8000/sae/concepts/manipulate -H "Content-Type: application/json" -d '{"model_id":"bielik","sae_id":"<sae_id>","edits":{"0":1.2}}'`
- List concept configs: `curl "http://localhost:8000/sae/concepts/configs?model_id=bielik&sae_id=<sae_id>"`
- Preview concept config (validate without saving): `curl -X POST http://localhost:8000/sae/concepts/preview -H "Content-Type: application/json" -d '{"model_id":"bielik","sae_id":"<sae_id>","edits":{"0":1.2}}'`
- Delete activation run or SAE (requires API key if set): `curl -X DELETE "http://localhost:8000/sae/activations/<run_id>?model_id=bielik" -H "X-API-Key: <key>"` and `curl -X DELETE "http://localhost:8000/sae/saes/<sae_id>?model_id=bielik" -H "X-API-Key: <key>"`
- Health/metrics summary: `curl http://localhost:8000/health/metrics` (in-memory job counts; no persistence, no auth)

Notes:
- Job manager is in-memory/lightweight: jobs disappear on process restart; idempotency is best-effort via payload key.
- Training/inference currently run in-process threads; add your own resource guards when running heavy models.
- Optional API key protection: set `SERVER_API_KEY=<value>` to require `X-API-Key` on protected endpoints (delete).
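The curl examples in the README above translate directly into a small Python client. A hedged sketch, assuming the FastAPI server from the quickstart is running on localhost:8000 and that `requests` is installed (it is not a dependency of mi-crow itself); the `sae_id` placeholder mirrors the one in the curl examples:

```python
# Illustrative Python equivalents of two curl examples from the README above.
import requests

BASE = "http://localhost:8000"

# Load a model (mirrors: POST /models/load with {"model_id": "bielik"}).
requests.post(f"{BASE}/models/load", json={"model_id": "bielik"}).raise_for_status()

# Run SAE inference on a prompt (mirrors the /sae/infer curl example).
resp = requests.post(
    f"{BASE}/sae/infer",
    json={
        "model_id": "bielik",
        "sae_id": "<sae_id>",  # placeholder, as in the curl example
        "save_top_texts": True,
        "top_k_neurons": 5,
        "inputs": [{"prompt": "hi"}],
    },
)
resp.raise_for_status()
print(resp.json())  # outputs, top neuron summary and SAE metadata
```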
mi_crow-0.1.1.post12.dist-info/RECORD
ADDED
@@ -0,0 +1,51 @@
amber/__init__.py,sha256=5nh0D8qvFgOhBEQj00Rm06T1iY5VcSiifAg9SoY1LLA,483
amber/utils.py,sha256=oER2LA_alUjaIk_xCAyP2V54ywjqsg00I4KvitYnJPc,1547
amber/datasets/__init__.py,sha256=zhqgbm5zMBsRbmPNfjlYNJwGWOLuCNf5jEj0P8aopRU,341
amber/datasets/base_dataset.py,sha256=X2wt3GdjgAOY24_vOqrD5gVFxGplSRMCb69CoQtj0xw,22508
amber/datasets/classification_dataset.py,sha256=x_ZQ4dMzoY3Nn8V1I01xvzJK_IcHcDcm8dIxYeXzV5g,21700
amber/datasets/loading_strategy.py,sha256=17VM3Td8lqDllGIx9DHI6WiXmSKKQHDHbfe4ZeM8ATA,1206
amber/datasets/text_dataset.py,sha256=ly0GHCS28Rg5ZluaafjavhcbvSD9-6ryovd_Y1ZIMms,16775
amber/hooks/__init__.py,sha256=9H08ZVoTK6TzYJXjEP2aqdHfoyLfdXvg6eOv3K1zNps,679
amber/hooks/controller.py,sha256=hc8FrrDosFYLrEGsEZmx1KsJ77F4p_gMKcF2WzHiURY,6057
amber/hooks/detector.py,sha256=5drJFrdrjseVjRNT-cq-U8XCt8AXV04YY2YrkQz4eFk,3110
amber/hooks/hook.py,sha256=-Qi-GJqRuIskXMuHUzp9_ESbbZi5tLSAFMWoBmrP3io,7540
amber/hooks/utils.py,sha256=wtsrjsMt-bXR3NshkwyZmfLre3IE3S4E5EoKppQrYOo,2022
amber/hooks/implementations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
amber/hooks/implementations/function_controller.py,sha256=66FFx_7sU7b0_FFQFFkAOqQm3aGsiyIRuZ64hIv-0w8,3088
amber/hooks/implementations/layer_activation_detector.py,sha256=bzoW6V8NNDNgRASs1YN_1TjEXaK3ahoNWiZ-ODfjB6I,3161
amber/hooks/implementations/model_input_detector.py,sha256=cYRVfyBEHi-1qg6F-4Q0vKEae6gYtq_3g1j3rOOCQdA,10074
amber/hooks/implementations/model_output_detector.py,sha256=iN-twt7Chc9ODmj-iei7_Ah7GqvE-knTVWi4C9kNye4,4879
amber/language_model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
amber/language_model/activations.py,sha256=7VRDlCM-IctF1ee0H68_Dd7xeXHyUSQe1SOPb28Ej18,16971
amber/language_model/context.py,sha256=koslpikCcu9Svsopboa1wd7Fv0R2-4sI2whOCvVWvT8,1118
amber/language_model/contracts.py,sha256=6ij7rzJcpSAKgYx-fiefg0Fi8TsFugaM5IlihEm_2VM,268
amber/language_model/hook_metadata.py,sha256=9Xyfiu4ekCZj79zG4gZfLk-850AO2iKDE24FDXe7q7s,1392
amber/language_model/inference.py,sha256=l8BASS8E9B4VWJHucEqF_G_zqOzlKeG4KEvGameBbMw,19068
amber/language_model/initialization.py,sha256=hfrKdI_fsmaxk0p9q4wN7EFxq_lSXs8BXGlxwKJ21Qw,3698
amber/language_model/language_model.py,sha256=MXoaXYbNBUxHw4sxtWkVFLbXiw9tFEx9GI78sCiESuQ,13943
amber/language_model/layers.py,sha256=Ob7QZl8i236ALLklY9o_xtjDZSt6FD8sqdmFy_YLgN0,15906
amber/language_model/persistence.py,sha256=i2ibDH1OABM5-ZNNLh7h4rOYWPsg3aaeYhmB_xWYDZw,5867
amber/language_model/tokenizer.py,sha256=9eKNOHvUjIJhJbj7M-tN7jWU5lWhOeCY_cssa4exQ1g,7377
amber/language_model/utils.py,sha256=5Y7scRvvudUjKDV8QPhC3HAc2S-dCuqbm6xEjRr0fRM,2630
amber/mechanistic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
amber/mechanistic/sae/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
amber/mechanistic/sae/autoencoder_context.py,sha256=cn0mv9COqT3jNcvXBfce70ankVEmw9kNE3Mu-knugoc,945
amber/mechanistic/sae/sae.py,sha256=ha6rXGsOXE59E_ohTH0vJh6M4rQh3Xw0GfmCkSgeYS4,6035
amber/mechanistic/sae/sae_trainer.py,sha256=GMrPz9SpSuANA0tJt3IkyIOOqVr2k7apE0w4CqL92gM,26311
amber/mechanistic/sae/concepts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
amber/mechanistic/sae/concepts/autoencoder_concepts.py,sha256=dTB9v5zK5lcKc4_zbYtWoCy8fA6ycj1-FJid4-7hQ04,14371
amber/mechanistic/sae/concepts/concept_dictionary.py,sha256=9px845gODuufW2koym-5426q0ijdpxvS7avH1vu8-Ls,7671
amber/mechanistic/sae/concepts/concept_models.py,sha256=HGyPoMSmj8CAg9joIa6fV_FfWAY2XN4LFPIkNQiJWbY,133
amber/mechanistic/sae/concepts/input_tracker.py,sha256=81FrOv9AAC7ejhryOWDTZ7Hlt3B2WoANx-wiO0KLr24,1886
amber/mechanistic/sae/modules/__init__.py,sha256=xpoz0HtPWoJD4dPj1qHaxtXDr7J0ERn30CX3m1dz21s,239
amber/mechanistic/sae/modules/l1_sae.py,sha256=_BebvpB9iUCTDjSiYNzDBM4l9sU_wadAdypwXaSb4ww,15352
amber/mechanistic/sae/modules/topk_sae.py,sha256=GQA8hYb6Fw7U1e5ExZjzZBjAZRhGk-VwqqYREVCQ_u8,17275
amber/mechanistic/sae/training/wandb_logger.py,sha256=d3vVBIQrnsJurX5HNVu7OYW4DqNgv18UZDpV8ddfN9k,8554
amber/store/__init__.py,sha256=UW4Hqyu-_qgnZ-gN_mk97OaWSrlPERcNi5YjnXMKeOU,119
amber/store/local_store.py,sha256=1pJbizZKrzNt_IQFnCFYjApPXs9ot-G1H8adeR7Qi50,17214
amber/store/store.py,sha256=VuDe9Git0glND3TTHh0zhDJNxdQY3dCp0cURhApYQbU,9334
amber/store/store_dataloader.py,sha256=QyYHSgOos8e-yzaEE_rySSVlGKaRNybURSDCgNrTIVM,4337
mi_crow-0.1.1.post12.dist-info/METADATA,sha256=1_SPBcnK8j_scOIWwnPqQGjlUpNnXxGGtvkYRZHeDQ8,6580
mi_crow-0.1.1.post12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
mi_crow-0.1.1.post12.dist-info/top_level.txt,sha256=FNP1x_ePvcW9Jsr7J9gCBARdDC-gqxIYtWF6HGNxtnI,6
mi_crow-0.1.1.post12.dist-info/RECORD,,
mi_crow-0.1.1.post12.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
amber