mi-crow 0.1.1.post12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amber/__init__.py +15 -0
- amber/datasets/__init__.py +11 -0
- amber/datasets/base_dataset.py +640 -0
- amber/datasets/classification_dataset.py +566 -0
- amber/datasets/loading_strategy.py +29 -0
- amber/datasets/text_dataset.py +488 -0
- amber/hooks/__init__.py +20 -0
- amber/hooks/controller.py +171 -0
- amber/hooks/detector.py +95 -0
- amber/hooks/hook.py +218 -0
- amber/hooks/implementations/__init__.py +0 -0
- amber/hooks/implementations/function_controller.py +93 -0
- amber/hooks/implementations/layer_activation_detector.py +96 -0
- amber/hooks/implementations/model_input_detector.py +250 -0
- amber/hooks/implementations/model_output_detector.py +132 -0
- amber/hooks/utils.py +76 -0
- amber/language_model/__init__.py +0 -0
- amber/language_model/activations.py +479 -0
- amber/language_model/context.py +33 -0
- amber/language_model/contracts.py +13 -0
- amber/language_model/hook_metadata.py +38 -0
- amber/language_model/inference.py +525 -0
- amber/language_model/initialization.py +126 -0
- amber/language_model/language_model.py +390 -0
- amber/language_model/layers.py +460 -0
- amber/language_model/persistence.py +177 -0
- amber/language_model/tokenizer.py +203 -0
- amber/language_model/utils.py +97 -0
- amber/mechanistic/__init__.py +0 -0
- amber/mechanistic/sae/__init__.py +0 -0
- amber/mechanistic/sae/autoencoder_context.py +40 -0
- amber/mechanistic/sae/concepts/__init__.py +0 -0
- amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
- amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
- amber/mechanistic/sae/concepts/concept_models.py +9 -0
- amber/mechanistic/sae/concepts/input_tracker.py +68 -0
- amber/mechanistic/sae/modules/__init__.py +5 -0
- amber/mechanistic/sae/modules/l1_sae.py +409 -0
- amber/mechanistic/sae/modules/topk_sae.py +459 -0
- amber/mechanistic/sae/sae.py +166 -0
- amber/mechanistic/sae/sae_trainer.py +604 -0
- amber/mechanistic/sae/training/wandb_logger.py +222 -0
- amber/store/__init__.py +5 -0
- amber/store/local_store.py +437 -0
- amber/store/store.py +276 -0
- amber/store/store_dataloader.py +124 -0
- amber/utils.py +46 -0
- mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
- mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
- mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
- mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,488 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
|
|
5
|
+
|
|
6
|
+
from datasets import Dataset, IterableDataset, load_dataset
|
|
7
|
+
|
|
8
|
+
from amber.datasets.base_dataset import BaseDataset
|
|
9
|
+
from amber.datasets.loading_strategy import IndexLike, LoadingStrategy
|
|
10
|
+
from amber.store.store import Store
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TextDataset(BaseDataset):
|
|
14
|
+
"""
|
|
15
|
+
Text-only dataset with support for multiple sources and loading strategies.
|
|
16
|
+
Each item is a string (text snippet).
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
def __init__(
    self,
    ds: Dataset | IterableDataset,
    store: Store,
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    text_field: str = "text",
):
    """
    Build a text dataset around a HuggingFace dataset object.

    Map-style (non-iterable) datasets are checked for the text column,
    stripped of every other column, and normalized so the remaining
    column is named "text". Iterable datasets are passed through without
    any column validation.

    Args:
        ds: HuggingFace Dataset or IterableDataset
        store: Store instance
        loading_strategy: Loading strategy
        text_field: Name of the column containing text

    Raises:
        ValueError: If text_field is empty or not found in dataset
    """
    self._validate_text_field(text_field)

    if not isinstance(ds, IterableDataset):
        available = ds.column_names
        if text_field not in available:
            raise ValueError(f"Dataset must have a '{text_field}' column; got columns: {ds.column_names}")

        # Drop everything except the text column to keep memory usage low.
        extra_columns = [name for name in available if name != text_field]
        if extra_columns:
            ds = ds.remove_columns(extra_columns)

        # Downstream indexing always reads the "text" column.
        if text_field != "text":
            ds = ds.rename_column(text_field, "text")
        ds.set_format("python", columns=["text"])

    # The caller's original field name is kept for row-wise extraction.
    self._text_field = text_field
    super().__init__(ds, store=store, loading_strategy=loading_strategy)
|
|
55
|
+
|
|
56
|
+
def _validate_text_field(self, text_field: str) -> None:
|
|
57
|
+
"""Validate text_field parameter.
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
text_field: Text field name to validate
|
|
61
|
+
|
|
62
|
+
Raises:
|
|
63
|
+
ValueError: If text_field is empty or not a string
|
|
64
|
+
"""
|
|
65
|
+
if not text_field or not isinstance(text_field, str) or not text_field.strip():
|
|
66
|
+
raise ValueError(f"text_field must be a non-empty string, got: {text_field!r}")
|
|
67
|
+
|
|
68
|
+
def _extract_text_from_row(self, row: Dict[str, Any]) -> Optional[str]:
|
|
69
|
+
"""Extract text from a dataset row.
|
|
70
|
+
|
|
71
|
+
Args:
|
|
72
|
+
row: Dataset row dictionary
|
|
73
|
+
|
|
74
|
+
Returns:
|
|
75
|
+
Text string from the row
|
|
76
|
+
|
|
77
|
+
Raises:
|
|
78
|
+
ValueError: If text field is not found in row
|
|
79
|
+
"""
|
|
80
|
+
if self._text_field in row:
|
|
81
|
+
text = row[self._text_field]
|
|
82
|
+
elif "text" in row:
|
|
83
|
+
text = row["text"]
|
|
84
|
+
else:
|
|
85
|
+
raise ValueError(
|
|
86
|
+
f"Text field '{self._text_field}' or 'text' not found in dataset row. "
|
|
87
|
+
f"Available fields: {list(row.keys())}"
|
|
88
|
+
)
|
|
89
|
+
return text
|
|
90
|
+
|
|
91
|
+
def __len__(self) -> int:
    """
    Report how many rows the underlying dataset holds.

    Returns:
        The dataset's ``num_rows`` value

    Raises:
        NotImplementedError: If loading_strategy is STREAMING
    """
    if self._loading_strategy != LoadingStrategy.STREAMING:
        return self._ds.num_rows
    raise NotImplementedError("len() not supported for STREAMING datasets")
|
|
101
|
+
|
|
102
|
+
def __getitem__(self, idx: IndexLike) -> Union[Optional[str], List[Optional[str]]]:
    """
    Get text item(s) by index.

    Args:
        idx: Index (int; negative values count from the end), slice, or
            sequence of non-negative indices

    Returns:
        Single text string for an int index, otherwise a list of text strings

    Raises:
        NotImplementedError: If loading_strategy is STREAMING
        IndexError: If index is out of bounds
        ValueError: If dataset is empty
        TypeError: If idx is not an int, slice, or sequence
    """
    if self._loading_strategy == LoadingStrategy.STREAMING:
        raise NotImplementedError(
            "Indexing not supported for STREAMING datasets. Use iter_items or iter_batches."
        )

    dataset_len = len(self)
    if dataset_len == 0:
        raise ValueError("Cannot index into empty dataset")

    if isinstance(idx, int):
        # Normalize into a separate name so the error message reports the
        # index the caller actually passed, not the shifted value.
        position = idx + dataset_len if idx < 0 else idx
        if not 0 <= position < dataset_len:
            raise IndexError(f"Index {idx} out of bounds for dataset of length {dataset_len}")
        return self._ds[position]["text"]

    if isinstance(idx, slice):
        start, stop, step = idx.indices(dataset_len)
        # Unit-step slices are passed as a range; other steps are
        # materialized into an explicit index list.
        indices = range(start, stop) if step == 1 else list(range(start, stop, step))
        return list(self._ds.select(indices)["text"])

    if isinstance(idx, Sequence):
        # Sequences must contain absolute (non-negative) in-bounds indices.
        invalid_indices = [i for i in idx if not (0 <= i < dataset_len)]
        if invalid_indices:
            raise IndexError(f"Indices out of bounds: {invalid_indices} (dataset length: {dataset_len})")
        return list(self._ds.select(list(idx))["text"])

    raise TypeError(f"Invalid index type: {type(idx)}")
|
|
151
|
+
|
|
152
|
+
def iter_items(self) -> Iterator[Optional[str]]:
|
|
153
|
+
"""
|
|
154
|
+
Iterate over text items one by one.
|
|
155
|
+
|
|
156
|
+
Yields:
|
|
157
|
+
Text strings from the dataset
|
|
158
|
+
|
|
159
|
+
Raises:
|
|
160
|
+
ValueError: If text field is not found in any row
|
|
161
|
+
"""
|
|
162
|
+
for row in self._ds:
|
|
163
|
+
yield self._extract_text_from_row(row)
|
|
164
|
+
|
|
165
|
+
def iter_batches(self, batch_size: int) -> Iterator[List[Optional[str]]]:
|
|
166
|
+
"""
|
|
167
|
+
Iterate over text items in batches.
|
|
168
|
+
|
|
169
|
+
Args:
|
|
170
|
+
batch_size: Number of items per batch
|
|
171
|
+
|
|
172
|
+
Yields:
|
|
173
|
+
Lists of text strings (batches)
|
|
174
|
+
|
|
175
|
+
Raises:
|
|
176
|
+
ValueError: If batch_size <= 0 or text field is not found in any row
|
|
177
|
+
"""
|
|
178
|
+
if batch_size <= 0:
|
|
179
|
+
raise ValueError(f"batch_size must be > 0, got: {batch_size}")
|
|
180
|
+
|
|
181
|
+
if self._loading_strategy == LoadingStrategy.STREAMING:
|
|
182
|
+
batch = []
|
|
183
|
+
for row in self._ds:
|
|
184
|
+
batch.append(self._extract_text_from_row(row))
|
|
185
|
+
if len(batch) >= batch_size:
|
|
186
|
+
yield batch
|
|
187
|
+
batch = []
|
|
188
|
+
if batch:
|
|
189
|
+
yield batch
|
|
190
|
+
else:
|
|
191
|
+
for batch in self._ds.iter(batch_size=batch_size):
|
|
192
|
+
yield list(batch["text"])
|
|
193
|
+
|
|
194
|
+
def extract_texts_from_batch(self, batch: List[Optional[str]]) -> List[Optional[str]]:
    """Return the texts contained in a batch.

    Batches produced by this dataset are already plain lists of text
    values, so no conversion is performed.

    Args:
        batch: List of text strings

    Returns:
        The very same list object that was passed in
    """
    return batch
|
|
206
|
+
|
|
207
|
+
def get_all_texts(self) -> List[Optional[str]]:
    """Collect every text value of the dataset into a list.

    For STREAMING datasets the stream is consumed row by row via
    ``iter_items``; otherwise the "text" column is read directly.

    Returns:
        List of all text values
    """
    is_streaming = self._loading_strategy == LoadingStrategy.STREAMING
    texts = self.iter_items() if is_streaming else self._ds["text"]
    return list(texts)
|
|
219
|
+
|
|
220
|
+
@classmethod
def from_huggingface(
    cls,
    repo_id: str,
    store: Store,
    *,
    split: str = "train",
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    revision: Optional[str] = None,
    text_field: str = "text",
    filters: Optional[Dict[str, Any]] = None,
    limit: Optional[int] = None,
    stratify_by: Optional[str] = None,
    stratify_seed: Optional[int] = None,
    streaming: Optional[bool] = None,
    drop_na: bool = False,
    **kwargs,
) -> "TextDataset":
    """
    Load text dataset from HuggingFace Hub.

    Args:
        repo_id: HuggingFace dataset repository ID
        store: Store instance
        split: Dataset split
        loading_strategy: Loading strategy
        revision: Optional git revision
        text_field: Name of the column containing text
        filters: Optional filters to apply (dict of column: value)
        limit: Optional limit on number of rows
        stratify_by: Optional column used for stratified sampling (non-streaming only)
        stratify_seed: Optional RNG seed for deterministic stratification
        streaming: Optional override for streaming; when None, streaming is
            used iff loading_strategy is STREAMING
        drop_na: Whether to drop rows with None/empty text
        **kwargs: Additional arguments for load_dataset

    Returns:
        TextDataset instance

    Raises:
        NotImplementedError: If stratify_by, drop_na, filters or limit are
            combined with streaming
        ValueError: If parameters are invalid
        RuntimeError: If dataset loading or post-processing fails
    """
    use_streaming = streaming if streaming is not None else (loading_strategy == LoadingStrategy.STREAMING)

    # Reject unsupported option combinations before loading anything, and
    # outside the try block below so the NotImplementedError reaches the
    # caller instead of being re-wrapped as RuntimeError.
    if use_streaming:
        if stratify_by or drop_na:
            raise NotImplementedError(
                "Stratification and drop_na are not supported for streaming datasets. Use MEMORY or DISK."
            )
        if filters or limit:
            raise NotImplementedError(
                "filters and limit are not supported when streaming datasets. Choose MEMORY or DISK."
            )

    try:
        ds = load_dataset(
            path=repo_id,
            split=split,
            revision=revision,
            streaming=use_streaming,
            **kwargs,
        )

        if not use_streaming:
            drop_na_columns = [text_field] if drop_na else None
            ds = cls._postprocess_non_streaming_dataset(
                ds,
                filters=filters,
                limit=limit,
                stratify_by=stratify_by,
                stratify_seed=stratify_seed,
                drop_na_columns=drop_na_columns,
            )
    except Exception as e:
        raise RuntimeError(
            f"Failed to load text dataset from HuggingFace Hub: "
            f"repo_id={repo_id!r}, split={split!r}, text_field={text_field!r}. "
            f"Error: {e}"
        ) from e

    return cls(ds, store=store, loading_strategy=loading_strategy, text_field=text_field)
|
|
302
|
+
|
|
303
|
+
@classmethod
def from_csv(
    cls,
    source: Union[str, Path],
    store: Store,
    *,
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    text_field: str = "text",
    delimiter: str = ",",
    stratify_by: Optional[str] = None,
    stratify_seed: Optional[int] = None,
    drop_na: bool = False,
    **kwargs,
) -> "TextDataset":
    """
    Create a text dataset from a CSV file.

    Loading is delegated to the parent class's ``from_csv``; the dataset it
    holds is then wrapped in a new instance of this class.

    Args:
        source: Path to CSV file
        store: Store instance
        loading_strategy: Loading strategy
        text_field: Name of the column containing text
        delimiter: CSV delimiter (default: comma)
        stratify_by: Optional column to use for stratified sampling
        stratify_seed: Optional RNG seed for stratified sampling
        drop_na: Whether to drop rows with None/empty text; forwarded to the
            parent loader as ``drop_na_columns=[text_field]``
        **kwargs: Additional arguments forwarded to the parent loader

    Returns:
        TextDataset instance

    Raises:
        FileNotFoundError: If CSV file doesn't exist (presumably raised by
            the parent loader)
        RuntimeError: If dataset loading fails (presumably raised by the
            parent loader)
    """
    loaded = super().from_csv(
        source,
        store=store,
        loading_strategy=loading_strategy,
        text_field=text_field,
        delimiter=delimiter,
        stratify_by=stratify_by,
        stratify_seed=stratify_seed,
        drop_na_columns=[text_field] if drop_na else None,
        **kwargs,
    )
    # NOTE(review): if the parent loader already builds its result through
    # this class's __init__ (which renames the text column to "text"), the
    # re-wrap below would fail the column check for a custom text_field --
    # confirm against the parent implementation.
    return cls(loaded._ds, store=store, loading_strategy=loading_strategy, text_field=text_field)
|
|
356
|
+
|
|
357
|
+
@classmethod
def from_json(
    cls,
    source: Union[str, Path],
    store: Store,
    *,
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    text_field: str = "text",
    stratify_by: Optional[str] = None,
    stratify_seed: Optional[int] = None,
    drop_na: bool = False,
    **kwargs,
) -> "TextDataset":
    """
    Create a text dataset from a JSON/JSONL file.

    Loading is delegated to the parent class's ``from_json``; the dataset it
    holds is then wrapped in a new instance of this class.

    Args:
        source: Path to JSON or JSONL file
        store: Store instance
        loading_strategy: Loading strategy
        text_field: Name of the field containing text
        stratify_by: Optional column to use for stratified sampling
        stratify_seed: Optional RNG seed for stratified sampling
        drop_na: Whether to drop rows with None/empty text; forwarded to the
            parent loader as ``drop_na_columns=[text_field]``
        **kwargs: Additional arguments forwarded to the parent loader

    Returns:
        TextDataset instance

    Raises:
        FileNotFoundError: If JSON file doesn't exist (presumably raised by
            the parent loader)
        RuntimeError: If dataset loading fails (presumably raised by the
            parent loader)
    """
    loaded = super().from_json(
        source,
        store=store,
        loading_strategy=loading_strategy,
        text_field=text_field,
        stratify_by=stratify_by,
        stratify_seed=stratify_seed,
        drop_na_columns=[text_field] if drop_na else None,
        **kwargs,
    )
    # NOTE(review): if the parent loader already builds its result through
    # this class's __init__ (which renames the text column to "text"), the
    # re-wrap below would fail the column check for a custom text_field --
    # confirm against the parent implementation.
    return cls(loaded._ds, store=store, loading_strategy=loading_strategy, text_field=text_field)
|
|
408
|
+
|
|
409
|
+
@classmethod
def from_local(
    cls,
    source: Union[str, Path],
    store: Store,
    *,
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    text_field: str = "text",
    recursive: bool = True,
) -> "TextDataset":
    """
    Load from a local directory or file(s).

    Supported:
      - Directory of .txt files (each file becomes one example)
      - JSONL/JSON/CSV/TSV files with a text column

    Args:
        source: Path to directory or file
        store: Store instance
        loading_strategy: Loading strategy
        text_field: Name of the column/field containing text. For a
            directory of .txt files this is the name given to the
            generated column.
        recursive: Whether to recursively search directories for .txt files

    Returns:
        TextDataset instance

    Raises:
        FileNotFoundError: If source path doesn't exist
        ValueError: If source is invalid or unsupported file type
        RuntimeError: If file operations fail
    """
    p = Path(source)
    if not p.exists():
        raise FileNotFoundError(f"Source path does not exist: {source}")

    if not p.is_dir():
        suffix = p.suffix.lower()
        if suffix in {".jsonl", ".json"}:
            return cls.from_json(
                source,
                store=store,
                loading_strategy=loading_strategy,
                text_field=text_field,
            )
        if suffix == ".csv":
            return cls.from_csv(
                source,
                store=store,
                loading_strategy=loading_strategy,
                text_field=text_field,
            )
        if suffix == ".tsv":
            return cls.from_csv(
                source,
                store=store,
                loading_strategy=loading_strategy,
                text_field=text_field,
                delimiter="\t",
            )
        raise ValueError(
            f"Unsupported file type: {suffix} for source: {source}. "
            f"Use directory of .txt, or JSON/JSONL/CSV/TSV."
        )

    pattern = "**/*.txt" if recursive else "*.txt"
    try:
        # Sorted for a deterministic example order. Non-file matches (e.g. a
        # directory whose name ends in ".txt") are skipped. Undecodable
        # bytes are dropped because of errors="ignore".
        txts: List[str] = [
            fp.read_text(encoding="utf-8", errors="ignore")
            for fp in sorted(p.glob(pattern))
            if fp.is_file()
        ]
    except OSError as e:
        raise RuntimeError(f"Failed to read text files from directory {source}. Error: {e}") from e

    if not txts:
        raise ValueError(f"No .txt files found in directory: {source} (recursive={recursive})")

    # Build the column under the requested name: the constructor checks for
    # text_field and renames it to "text" itself, so a hard-coded "text"
    # column would be rejected for any custom text_field.
    ds = Dataset.from_dict({text_field: txts})
    return cls(ds, store=store, loading_strategy=loading_strategy, text_field=text_field)
|
amber/hooks/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from amber.hooks.hook import Hook, HookType, HookError
|
|
2
|
+
from amber.hooks.detector import Detector
|
|
3
|
+
from amber.hooks.controller import Controller
|
|
4
|
+
from amber.hooks.implementations.layer_activation_detector import LayerActivationDetector
|
|
5
|
+
from amber.hooks.implementations.model_input_detector import ModelInputDetector
|
|
6
|
+
from amber.hooks.implementations.model_output_detector import ModelOutputDetector
|
|
7
|
+
from amber.hooks.implementations.function_controller import FunctionController
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"Hook",
|
|
11
|
+
"HookType",
|
|
12
|
+
"HookError",
|
|
13
|
+
"Detector",
|
|
14
|
+
"Controller",
|
|
15
|
+
"LayerActivationDetector",
|
|
16
|
+
"ModelInputDetector",
|
|
17
|
+
"ModelOutputDetector",
|
|
18
|
+
"FunctionController",
|
|
19
|
+
]
|
|
20
|
+
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
|
|
9
|
+
from amber.hooks.hook import Hook, HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
|
|
10
|
+
from amber.hooks.utils import extract_tensor_from_input, extract_tensor_from_output
|
|
11
|
+
from amber.utils import get_logger
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
pass
|
|
15
|
+
|
|
16
|
+
logger = get_logger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Controller(Hook):
|
|
20
|
+
"""
|
|
21
|
+
Abstract base class for controller hooks that modify activations during inference.
|
|
22
|
+
|
|
23
|
+
Controllers can modify inputs (pre_forward) or outputs (forward) of layers.
|
|
24
|
+
They are designed to actively change the behavior of the model during inference.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
def __init__(
    self,
    hook_type: HookType | str = HookType.FORWARD,
    hook_id: str | None = None,
    layer_signature: str | int | None = None,
):
    """
    Set up a controller hook by forwarding all arguments to the base Hook.

    Args:
        hook_type: Type of hook (HookType.FORWARD or HookType.PRE_FORWARD)
        hook_id: Unique identifier
        layer_signature: Layer to attach to (optional, for compatibility)
    """
    super().__init__(
        hook_id=hook_id,
        hook_type=hook_type,
        layer_signature=layer_signature,
    )
|
|
42
|
+
|
|
43
|
+
def _handle_pre_forward(
    self,
    module: torch.nn.Module,
    input: HOOK_FUNCTION_INPUT
) -> HOOK_FUNCTION_INPUT | None:
    """Run the controller on a module's inputs before its forward pass.

    The tensor extracted from ``input`` is handed to ``modify_activations``
    as both the ``inputs`` and the ``output`` argument. If a tensor comes
    back, it replaces the first element of the input tuple.

    Args:
        module: The PyTorch module being hooked
        input: Tuple of input tensors to the module

    Returns:
        Modified input tuple, or None to keep the original inputs (no tensor
        could be extracted, the controller returned a non-tensor, or the
        input tuple is empty)
    """
    tensor = extract_tensor_from_input(input)
    if tensor is None:
        return None

    replacement = self.modify_activations(module, tensor, tensor)
    if not isinstance(replacement, torch.Tensor):
        return None

    patched = list(input)
    if not patched:
        return None

    # NOTE(review): assumes the extracted tensor was the first positional
    # input -- confirm against extract_tensor_from_input.
    patched[0] = replacement
    return tuple(patched)
|
|
70
|
+
|
|
71
|
+
def _handle_forward(
    self,
    module: torch.nn.Module,
    input: HOOK_FUNCTION_INPUT,
    output: HOOK_FUNCTION_OUTPUT
) -> None:
    """Run the controller on a module's output after its forward pass.

    Args:
        module: The PyTorch module being hooked
        input: Tuple of input tensors to the module
        output: Output tensor(s) from the module
    """
    output_tensor = extract_tensor_from_output(output)
    if output_tensor is not None:
        # The input tensor is passed along when one can be extracted
        # (it may be None).
        # NOTE(review): the value returned by modify_activations is
        # discarded here, so only in-place changes (or captured state) take
        # effect for forward hooks -- confirm this is intended.
        self.modify_activations(module, extract_tensor_from_input(input), output_tensor)
|
|
96
|
+
|
|
97
|
+
def _hook_fn(
    self,
    module: torch.nn.Module,
    input: HOOK_FUNCTION_INPUT,
    output: HOOK_FUNCTION_OUTPUT
) -> None | HOOK_FUNCTION_INPUT:
    """
    Entry point invoked for the hooked module; dispatches on the hook type.

    When the instance is both a Controller and a Detector, the detector
    step (``process_activations``) runs first. A failure there is logged
    as a warning and does not stop the controller step.

    Args:
        module: The PyTorch module being hooked
        input: Tuple of input tensors to the module
        output: Output tensor(s) from the module

    Returns:
        For pre_forward hooks: modified inputs (tuple) or None to keep original.
        For every other hook type, and whenever the hook is disabled: None.

    Raises:
        RuntimeError: If the controller step raises any exception
    """
    if not self._enabled:
        return None

    if self._is_both_controller_and_detector():
        # Detector side first: record activations before they are changed.
        try:
            self.process_activations(module, input, output)
        except Exception as e:
            detector_message = f"Error in {self.__class__.__name__} detector process_activations: {e}"
            logger.warning(detector_message, exc_info=True)

    try:
        if self.hook_type != HookType.PRE_FORWARD:
            self._handle_forward(module, input, output)
            return None
        return self._handle_pre_forward(module, input)
    except Exception as e:
        raise RuntimeError(
            f"Error in controller {self.id} modify_activations: {e}"
        ) from e
|
|
145
|
+
|
|
146
|
+
@abc.abstractmethod
def modify_activations(
    self,
    module: nn.Module,
    inputs: torch.Tensor | None,
    output: torch.Tensor | None
) -> torch.Tensor | None:
    """
    Transform the activations of the hooked layer (to be overridden).

    For pre_forward hooks: receives the input tensor and should return the
    modified input tensor.
    For forward hooks: receives input and output tensors and should return
    the modified output tensor.

    Args:
        module: The PyTorch module being hooked
        inputs: Input tensor (None for forward hooks if not available)
        output: Output tensor (None for pre_forward hooks)

    Returns:
        Modified input tensor (for pre_forward) or modified output tensor
        (for forward). Return None to keep the original tensor unchanged.

    Raises:
        NotImplementedError: Always, in this base implementation
        Exception: Subclasses may raise exceptions for invalid inputs or
            modification errors
    """
    raise NotImplementedError("modify_activations must be implemented by subclasses")
|