mi_crow-0.1.1.post12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. amber/__init__.py +15 -0
  2. amber/datasets/__init__.py +11 -0
  3. amber/datasets/base_dataset.py +640 -0
  4. amber/datasets/classification_dataset.py +566 -0
  5. amber/datasets/loading_strategy.py +29 -0
  6. amber/datasets/text_dataset.py +488 -0
  7. amber/hooks/__init__.py +20 -0
  8. amber/hooks/controller.py +171 -0
  9. amber/hooks/detector.py +95 -0
  10. amber/hooks/hook.py +218 -0
  11. amber/hooks/implementations/__init__.py +0 -0
  12. amber/hooks/implementations/function_controller.py +93 -0
  13. amber/hooks/implementations/layer_activation_detector.py +96 -0
  14. amber/hooks/implementations/model_input_detector.py +250 -0
  15. amber/hooks/implementations/model_output_detector.py +132 -0
  16. amber/hooks/utils.py +76 -0
  17. amber/language_model/__init__.py +0 -0
  18. amber/language_model/activations.py +479 -0
  19. amber/language_model/context.py +33 -0
  20. amber/language_model/contracts.py +13 -0
  21. amber/language_model/hook_metadata.py +38 -0
  22. amber/language_model/inference.py +525 -0
  23. amber/language_model/initialization.py +126 -0
  24. amber/language_model/language_model.py +390 -0
  25. amber/language_model/layers.py +460 -0
  26. amber/language_model/persistence.py +177 -0
  27. amber/language_model/tokenizer.py +203 -0
  28. amber/language_model/utils.py +97 -0
  29. amber/mechanistic/__init__.py +0 -0
  30. amber/mechanistic/sae/__init__.py +0 -0
  31. amber/mechanistic/sae/autoencoder_context.py +40 -0
  32. amber/mechanistic/sae/concepts/__init__.py +0 -0
  33. amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
  34. amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
  35. amber/mechanistic/sae/concepts/concept_models.py +9 -0
  36. amber/mechanistic/sae/concepts/input_tracker.py +68 -0
  37. amber/mechanistic/sae/modules/__init__.py +5 -0
  38. amber/mechanistic/sae/modules/l1_sae.py +409 -0
  39. amber/mechanistic/sae/modules/topk_sae.py +459 -0
  40. amber/mechanistic/sae/sae.py +166 -0
  41. amber/mechanistic/sae/sae_trainer.py +604 -0
  42. amber/mechanistic/sae/training/wandb_logger.py +222 -0
  43. amber/store/__init__.py +5 -0
  44. amber/store/local_store.py +437 -0
  45. amber/store/store.py +276 -0
  46. amber/store/store_dataloader.py +124 -0
  47. amber/utils.py +46 -0
  48. mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
  49. mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
  50. mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
  51. mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
amber/datasets/text_dataset.py
@@ -0,0 +1,488 @@
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
+
+ from datasets import Dataset, IterableDataset, load_dataset
+
+ from amber.datasets.base_dataset import BaseDataset
+ from amber.datasets.loading_strategy import IndexLike, LoadingStrategy
+ from amber.store.store import Store
+
+
+ class TextDataset(BaseDataset):
+     """
+     Text-only dataset with support for multiple sources and loading strategies.
+     Each item is a string (text snippet).
+     """
+
+     def __init__(
+         self,
+         ds: Dataset | IterableDataset,
+         store: Store,
+         loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+         text_field: str = "text",
+     ):
+         """
+         Initialize text dataset.
+
+         Args:
+             ds: HuggingFace Dataset or IterableDataset
+             store: Store instance
+             loading_strategy: Loading strategy
+             text_field: Name of the column containing text
+
+         Raises:
+             ValueError: If text_field is empty or not found in dataset
+         """
+         self._validate_text_field(text_field)
+
+         # Validate and prepare dataset
+         is_iterable = isinstance(ds, IterableDataset)
+         if not is_iterable:
+             if text_field not in ds.column_names:
+                 raise ValueError(f"Dataset must have a '{text_field}' column; got columns: {ds.column_names}")
+             # Keep only text column for memory efficiency
+             columns_to_remove = [c for c in ds.column_names if c != text_field]
+             if columns_to_remove:
+                 ds = ds.remove_columns(columns_to_remove)
+             if text_field != "text":
+                 ds = ds.rename_column(text_field, "text")
+             ds.set_format("python", columns=["text"])
+
+         self._text_field = text_field
+         super().__init__(ds, store=store, loading_strategy=loading_strategy)
+
+     def _validate_text_field(self, text_field: str) -> None:
+         """Validate text_field parameter.
+
+         Args:
+             text_field: Text field name to validate
+
+         Raises:
+             ValueError: If text_field is empty or not a string
+         """
+         if not text_field or not isinstance(text_field, str) or not text_field.strip():
+             raise ValueError(f"text_field must be a non-empty string, got: {text_field!r}")
+
+     def _extract_text_from_row(self, row: Dict[str, Any]) -> Optional[str]:
+         """Extract text from a dataset row.
+
+         Args:
+             row: Dataset row dictionary
+
+         Returns:
+             Text string from the row
+
+         Raises:
+             ValueError: If text field is not found in row
+         """
+         if self._text_field in row:
+             text = row[self._text_field]
+         elif "text" in row:
+             text = row["text"]
+         else:
+             raise ValueError(
+                 f"Text field '{self._text_field}' or 'text' not found in dataset row. "
+                 f"Available fields: {list(row.keys())}"
+             )
+         return text
+
+     def __len__(self) -> int:
+         """
+         Return the number of items in the dataset.
+
+         Raises:
+             NotImplementedError: If loading_strategy is STREAMING
+         """
+         if self._loading_strategy == LoadingStrategy.STREAMING:
+             raise NotImplementedError("len() not supported for STREAMING datasets")
+         return self._ds.num_rows
+
+     def __getitem__(self, idx: IndexLike) -> Union[Optional[str], List[Optional[str]]]:
+         """
+         Get text item(s) by index.
+
+         Args:
+             idx: Index (int), slice, or sequence of indices
+
+         Returns:
+             Single text string or list of text strings
+
+         Raises:
+             NotImplementedError: If loading_strategy is STREAMING
+             IndexError: If index is out of bounds
+             ValueError: If dataset is empty
+         """
+         if self._loading_strategy == LoadingStrategy.STREAMING:
+             raise NotImplementedError(
+                 "Indexing not supported for STREAMING datasets. Use iter_items or iter_batches."
+             )
+
+         dataset_len = len(self)
+         if dataset_len == 0:
+             raise ValueError("Cannot index into empty dataset")
+
+         if isinstance(idx, int):
+             if idx < 0:
+                 idx = dataset_len + idx
+             if idx < 0 or idx >= dataset_len:
+                 raise IndexError(f"Index {idx} out of bounds for dataset of length {dataset_len}")
+             return self._ds[idx]["text"]
+
+         if isinstance(idx, slice):
+             start, stop, step = idx.indices(dataset_len)
+             if step != 1:
+                 indices = list(range(start, stop, step))
+                 out = self._ds.select(indices)["text"]
+             else:
+                 out = self._ds.select(range(start, stop))["text"]
+             return list(out)
+
+         if isinstance(idx, Sequence):
+             # Validate all indices are in bounds
+             invalid_indices = [i for i in idx if not (0 <= i < dataset_len)]
+             if invalid_indices:
+                 raise IndexError(f"Indices out of bounds: {invalid_indices} (dataset length: {dataset_len})")
+             out = self._ds.select(list(idx))["text"]
+             return list(out)
+
+         raise TypeError(f"Invalid index type: {type(idx)}")
+
+     def iter_items(self) -> Iterator[Optional[str]]:
+         """
+         Iterate over text items one by one.
+
+         Yields:
+             Text strings from the dataset
+
+         Raises:
+             ValueError: If text field is not found in any row
+         """
+         for row in self._ds:
+             yield self._extract_text_from_row(row)
+
+     def iter_batches(self, batch_size: int) -> Iterator[List[Optional[str]]]:
+         """
+         Iterate over text items in batches.
+
+         Args:
+             batch_size: Number of items per batch
+
+         Yields:
+             Lists of text strings (batches)
+
+         Raises:
+             ValueError: If batch_size <= 0 or text field is not found in any row
+         """
+         if batch_size <= 0:
+             raise ValueError(f"batch_size must be > 0, got: {batch_size}")
+
+         if self._loading_strategy == LoadingStrategy.STREAMING:
+             batch = []
+             for row in self._ds:
+                 batch.append(self._extract_text_from_row(row))
+                 if len(batch) >= batch_size:
+                     yield batch
+                     batch = []
+             if batch:
+                 yield batch
+         else:
+             for batch in self._ds.iter(batch_size=batch_size):
+                 yield list(batch["text"])
+
+     def extract_texts_from_batch(self, batch: List[Optional[str]]) -> List[Optional[str]]:
+         """Extract text strings from a batch.
+
+         For TextDataset, batch items are already strings, so return as-is.
+
+         Args:
+             batch: List of text strings
+
+         Returns:
+             List of text strings (same as input)
+         """
+         return batch
+
+     def get_all_texts(self) -> List[Optional[str]]:
+         """Get all texts from the dataset.
+
+         Returns:
+             List of all text strings
+
+         Raises:
+             NotImplementedError: If loading_strategy is STREAMING
+         """
+         if self._loading_strategy == LoadingStrategy.STREAMING:
+             return list(self.iter_items())
+         return list(self._ds["text"])
+
+     @classmethod
+     def from_huggingface(
+         cls,
+         repo_id: str,
+         store: Store,
+         *,
+         split: str = "train",
+         loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+         revision: Optional[str] = None,
+         text_field: str = "text",
+         filters: Optional[Dict[str, Any]] = None,
+         limit: Optional[int] = None,
+         stratify_by: Optional[str] = None,
+         stratify_seed: Optional[int] = None,
+         streaming: Optional[bool] = None,
+         drop_na: bool = False,
+         **kwargs,
+     ) -> "TextDataset":
+         """
+         Load text dataset from HuggingFace Hub.
+
+         Args:
+             repo_id: HuggingFace dataset repository ID
+             store: Store instance
+             split: Dataset split
+             loading_strategy: Loading strategy
+             revision: Optional git revision
+             text_field: Name of the column containing text
+             filters: Optional filters to apply (dict of column: value)
+             limit: Optional limit on number of rows
+             stratify_by: Optional column used for stratified sampling (non-streaming only)
+             stratify_seed: Optional RNG seed for deterministic stratification
+             streaming: Optional override for streaming
+             drop_na: Whether to drop rows with None/empty text
+             **kwargs: Additional arguments for load_dataset
+
+         Returns:
+             TextDataset instance
+
+         Raises:
+             ValueError: If parameters are invalid
+             RuntimeError: If dataset loading fails
+         """
+         use_streaming = streaming if streaming is not None else (loading_strategy == LoadingStrategy.STREAMING)
+
+         if (stratify_by or drop_na) and use_streaming:
+             raise NotImplementedError(
+                 "Stratification and drop_na are not supported for streaming datasets. Use MEMORY or DISK."
+             )
+
+         try:
+             ds = load_dataset(
+                 path=repo_id,
+                 split=split,
+                 revision=revision,
+                 streaming=use_streaming,
+                 **kwargs,
+             )
+
+             if use_streaming:
+                 if filters or limit:
+                     raise NotImplementedError(
+                         "filters and limit are not supported when streaming datasets. Choose MEMORY or DISK."
+                     )
+             else:
+                 drop_na_columns = [text_field] if drop_na else None
+                 ds = cls._postprocess_non_streaming_dataset(
+                     ds,
+                     filters=filters,
+                     limit=limit,
+                     stratify_by=stratify_by,
+                     stratify_seed=stratify_seed,
+                     drop_na_columns=drop_na_columns,
+                 )
+         except Exception as e:
+             raise RuntimeError(
+                 f"Failed to load text dataset from HuggingFace Hub: "
+                 f"repo_id={repo_id!r}, split={split!r}, text_field={text_field!r}. "
+                 f"Error: {e}"
+             ) from e
+
+         return cls(ds, store=store, loading_strategy=loading_strategy, text_field=text_field)
+
+     @classmethod
+     def from_csv(
+         cls,
+         source: Union[str, Path],
+         store: Store,
+         *,
+         loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+         text_field: str = "text",
+         delimiter: str = ",",
+         stratify_by: Optional[str] = None,
+         stratify_seed: Optional[int] = None,
+         drop_na: bool = False,
+         **kwargs,
+     ) -> "TextDataset":
+         """
+         Load text dataset from CSV file.
+
+         Args:
+             source: Path to CSV file
+             store: Store instance
+             loading_strategy: Loading strategy
+             text_field: Name of the column containing text
+             delimiter: CSV delimiter (default: comma)
+             stratify_by: Optional column to use for stratified sampling
+             stratify_seed: Optional RNG seed for stratified sampling
+             drop_na: Whether to drop rows with None/empty text
+             **kwargs: Additional arguments for load_dataset
+
+         Returns:
+             TextDataset instance
+
+         Raises:
+             FileNotFoundError: If CSV file doesn't exist
+             RuntimeError: If dataset loading fails
+         """
+         drop_na_columns = [text_field] if drop_na else None
+         dataset = super().from_csv(
+             source,
+             store=store,
+             loading_strategy=loading_strategy,
+             text_field=text_field,
+             delimiter=delimiter,
+             stratify_by=stratify_by,
+             stratify_seed=stratify_seed,
+             drop_na_columns=drop_na_columns,
+             **kwargs,
+         )
+         return cls(
+             dataset._ds,
+             store=store,
+             loading_strategy=loading_strategy,
+             text_field=text_field,
+         )
+
+     @classmethod
+     def from_json(
+         cls,
+         source: Union[str, Path],
+         store: Store,
+         *,
+         loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+         text_field: str = "text",
+         stratify_by: Optional[str] = None,
+         stratify_seed: Optional[int] = None,
+         drop_na: bool = False,
+         **kwargs,
+     ) -> "TextDataset":
+         """
+         Load text dataset from JSON/JSONL file.
+
+         Args:
+             source: Path to JSON or JSONL file
+             store: Store instance
+             loading_strategy: Loading strategy
+             text_field: Name of the field containing text
+             stratify_by: Optional column to use for stratified sampling
+             stratify_seed: Optional RNG seed for stratified sampling
+             drop_na: Whether to drop rows with None/empty text
+             **kwargs: Additional arguments for load_dataset
+
+         Returns:
+             TextDataset instance
+
+         Raises:
+             FileNotFoundError: If JSON file doesn't exist
+             RuntimeError: If dataset loading fails
+         """
+         drop_na_columns = [text_field] if drop_na else None
+         dataset = super().from_json(
+             source,
+             store=store,
+             loading_strategy=loading_strategy,
+             text_field=text_field,
+             stratify_by=stratify_by,
+             stratify_seed=stratify_seed,
+             drop_na_columns=drop_na_columns,
+             **kwargs,
+         )
+         # Re-initialize with text_field
+         return cls(
+             dataset._ds,
+             store=store,
+             loading_strategy=loading_strategy,
+             text_field=text_field,
+         )
+
+     @classmethod
+     def from_local(
+         cls,
+         source: Union[str, Path],
+         store: Store,
+         *,
+         loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+         text_field: str = "text",
+         recursive: bool = True,
+     ) -> "TextDataset":
+         """
+         Load from a local directory or file(s).
+
+         Supported:
+             - Directory of .txt files (each file becomes one example)
+             - JSONL/JSON/CSV/TSV files with a text column
+
+         Args:
+             source: Path to directory or file
+             store: Store instance
+             loading_strategy: Loading strategy
+             text_field: Name of the column/field containing text
+             recursive: Whether to recursively search directories for .txt files
+
+         Returns:
+             TextDataset instance
+
+         Raises:
+             FileNotFoundError: If source path doesn't exist
+             ValueError: If source is invalid or unsupported file type
+             RuntimeError: If file operations fail
+         """
+         p = Path(source)
+         if not p.exists():
+             raise FileNotFoundError(f"Source path does not exist: {source}")
+
+         if p.is_dir():
+             txts: List[str] = []
+             pattern = "**/*.txt" if recursive else "*.txt"
+             try:
+                 for fp in sorted(p.glob(pattern)):
+                     txts.append(fp.read_text(encoding="utf-8", errors="ignore"))
+             except OSError as e:
+                 raise RuntimeError(f"Failed to read text files from directory {source}. Error: {e}") from e
+
+             if not txts:
+                 raise ValueError(f"No .txt files found in directory: {source} (recursive={recursive})")
+
+             ds = Dataset.from_dict({"text": txts})
+         else:
+             suffix = p.suffix.lower()
+             if suffix in {".jsonl", ".json"}:
+                 return cls.from_json(
+                     source,
+                     store=store,
+                     loading_strategy=loading_strategy,
+                     text_field=text_field,
+                 )
+             elif suffix in {".csv"}:
+                 return cls.from_csv(
+                     source,
+                     store=store,
+                     loading_strategy=loading_strategy,
+                     text_field=text_field,
+                 )
+             elif suffix in {".tsv"}:
+                 return cls.from_csv(
+                     source,
+                     store=store,
+                     loading_strategy=loading_strategy,
+                     text_field=text_field,
+                     delimiter="\t",
+                 )
+             else:
+                 raise ValueError(
+                     f"Unsupported file type: {suffix} for source: {source}. "
+                     f"Use directory of .txt, or JSON/JSONL/CSV/TSV."
+                 )
+
+         return cls(ds, store=store, loading_strategy=loading_strategy, text_field=text_field)
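
For orientation, the sketch below shows how the TextDataset API above might be used. It is illustrative only: the LocalStore class name and its constructor are assumptions based on amber/store/local_store.py appearing in the file list, while the TextDataset calls follow the signatures shown in the hunk above.

    from amber.datasets.loading_strategy import LoadingStrategy
    from amber.datasets.text_dataset import TextDataset
    from amber.store.local_store import LocalStore  # assumed class name, not confirmed by this diff

    store = LocalStore("./artifacts")  # assumed constructor signature

    # Load a text dataset from the HuggingFace Hub fully into memory.
    dataset = TextDataset.from_huggingface(
        "imdb",                      # any HF dataset repo with a "text" column
        store,
        split="train",
        loading_strategy=LoadingStrategy.MEMORY,
        text_field="text",
        limit=1000,
    )

    print(len(dataset))              # row count (raises NotImplementedError when STREAMING)
    print(dataset[0])                # single text string
    for batch in dataset.iter_batches(batch_size=32):
        pass                         # each batch is a List[Optional[str]]
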
amber/hooks/__init__.py
@@ -0,0 +1,20 @@
+ from amber.hooks.hook import Hook, HookType, HookError
+ from amber.hooks.detector import Detector
+ from amber.hooks.controller import Controller
+ from amber.hooks.implementations.layer_activation_detector import LayerActivationDetector
+ from amber.hooks.implementations.model_input_detector import ModelInputDetector
+ from amber.hooks.implementations.model_output_detector import ModelOutputDetector
+ from amber.hooks.implementations.function_controller import FunctionController
+
+ __all__ = [
+     "Hook",
+     "HookType",
+     "HookError",
+     "Detector",
+     "Controller",
+     "LayerActivationDetector",
+     "ModelInputDetector",
+     "ModelOutputDetector",
+     "FunctionController",
+ ]
+
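
These package-level re-exports let downstream code import the hook classes from the package root rather than from the implementation modules; a one-line example using only names listed in __all__:

    from amber.hooks import Controller, Detector, HookType, LayerActivationDetector
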
amber/hooks/controller.py
@@ -0,0 +1,171 @@
+ from __future__ import annotations
+
+ import abc
+ from typing import TYPE_CHECKING
+
+ import torch
+ import torch.nn as nn
+
+ from amber.hooks.hook import Hook, HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
+ from amber.hooks.utils import extract_tensor_from_input, extract_tensor_from_output
+ from amber.utils import get_logger
+
+ if TYPE_CHECKING:
+     pass
+
+ logger = get_logger(__name__)
+
+
+ class Controller(Hook):
+     """
+     Abstract base class for controller hooks that modify activations during inference.
+
+     Controllers can modify inputs (pre_forward) or outputs (forward) of layers.
+     They are designed to actively change the behavior of the model during inference.
+     """
+
+     def __init__(
+         self,
+         hook_type: HookType | str = HookType.FORWARD,
+         hook_id: str | None = None,
+         layer_signature: str | int | None = None
+     ):
+         """
+         Initialize a controller hook.
+
+         Args:
+             hook_type: Type of hook (HookType.FORWARD or HookType.PRE_FORWARD)
+             hook_id: Unique identifier
+             layer_signature: Layer to attach to (optional, for compatibility)
+         """
+         super().__init__(layer_signature=layer_signature, hook_type=hook_type, hook_id=hook_id)
+
+     def _handle_pre_forward(
+         self,
+         module: torch.nn.Module,
+         input: HOOK_FUNCTION_INPUT
+     ) -> HOOK_FUNCTION_INPUT | None:
+         """Handle pre-forward hook execution.
+
+         Args:
+             module: The PyTorch module being hooked
+             input: Tuple of input tensors to the module
+
+         Returns:
+             Modified input tuple or None to keep original
+         """
+         input_tensor = extract_tensor_from_input(input)
+
+         if input_tensor is None:
+             return None
+
+         modified_tensor = self.modify_activations(module, input_tensor, input_tensor)
+
+         if modified_tensor is not None and isinstance(modified_tensor, torch.Tensor):
+             result = list(input)
+             if len(result) > 0:
+                 result[0] = modified_tensor
+                 return tuple(result)
+         return None
+
+     def _handle_forward(
+         self,
+         module: torch.nn.Module,
+         input: HOOK_FUNCTION_INPUT,
+         output: HOOK_FUNCTION_OUTPUT
+     ) -> None:
+         """Handle forward hook execution.
+
+         Args:
+             module: The PyTorch module being hooked
+             input: Tuple of input tensors to the module
+             output: Output tensor(s) from the module
+         """
+         output_tensor = extract_tensor_from_output(output)
+
+         if output_tensor is None:
+             return
+
+         # Extract input tensor if available for modify_activations
+         input_tensor = extract_tensor_from_input(input)
+
+         # Note: forward hooks can't modify output in PyTorch, but we call modify_activations
+         # for consistency. The actual modification happens via the hook mechanism.
+         # We still call it so controllers can capture/process activations.
+         self.modify_activations(module, input_tensor, output_tensor)
+
+     def _hook_fn(
+         self,
+         module: torch.nn.Module,
+         input: HOOK_FUNCTION_INPUT,
+         output: HOOK_FUNCTION_OUTPUT
+     ) -> None | HOOK_FUNCTION_INPUT:
+         """
+         Internal hook function that modifies activations.
+
+         If the instance also inherits from Detector, first processes activations
+         as a Detector (saves metadata), then modifies activations as a Controller.
+
+         Args:
+             module: The PyTorch module being hooked
+             input: Tuple of input tensors to the module
+             output: Output tensor(s) from the module
+
+         Returns:
+             For pre_forward hooks: modified inputs (tuple) or None to keep original
+             For forward hooks: None (forward hooks cannot modify output in PyTorch)
+
+         Raises:
+             RuntimeError: If modify_activations raises an exception
+         """
+         if not self._enabled:
+             return None
+
+         # Check if this instance also inherits from Detector
+         if self._is_both_controller_and_detector():
+             # First, process activations as a Detector (save metadata)
+             try:
+                 self.process_activations(module, input, output)
+             except Exception as e:
+                 logger.warning(
+                     f"Error in {self.__class__.__name__} detector process_activations: {e}",
+                     exc_info=True
+                 )
+
+         try:
+             if self.hook_type == HookType.PRE_FORWARD:
+                 return self._handle_pre_forward(module, input)
+             else:
+                 self._handle_forward(module, input, output)
+                 return None
+         except Exception as e:
+             raise RuntimeError(
+                 f"Error in controller {self.id} modify_activations: {e}"
+             ) from e
+
+     @abc.abstractmethod
+     def modify_activations(
+         self,
+         module: nn.Module,
+         inputs: torch.Tensor | None,
+         output: torch.Tensor | None
+     ) -> torch.Tensor | None:
+         """
+         Modify activations from the hooked layer.
+
+         For pre_forward hooks: receives input tensor, should return modified input tensor.
+         For forward hooks: receives input and output tensors, should return modified output tensor.
+
+         Args:
+             module: The PyTorch module being hooked
+             inputs: Input tensor (None for forward hooks if not available)
+             output: Output tensor (None for pre_forward hooks)
+
+         Returns:
+             Modified input tensor (for pre_forward) or modified output tensor (for forward).
+             Return None to keep original tensor unchanged.
+
+         Raises:
+             Exception: Subclasses may raise exceptions for invalid inputs or modification errors
+         """
+         raise NotImplementedError("modify_activations must be implemented by subclasses")
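
To make the abstract interface concrete, here is an illustrative Controller subclass (not part of the package) that scales a layer's input activations via a PRE_FORWARD hook. The constructor arguments and HookType values mirror the code above; how the hook is attached to a model is defined in amber/hooks/hook.py, which is outside this section, so attachment is omitted.

    from __future__ import annotations

    import torch
    import torch.nn as nn

    from amber.hooks import Controller, HookType


    class ScaleController(Controller):
        """Multiply the hooked layer's input activations by a constant factor."""

        def __init__(self, scale: float, layer_signature: str | int | None = None):
            super().__init__(hook_type=HookType.PRE_FORWARD, layer_signature=layer_signature)
            self.scale = scale

        def modify_activations(
            self,
            module: nn.Module,
            inputs: torch.Tensor | None,
            output: torch.Tensor | None,
        ) -> torch.Tensor | None:
            # For PRE_FORWARD hooks, `inputs` carries the tensor to modify;
            # returning None keeps the original activations unchanged.
            if inputs is None:
                return None
            return inputs * self.scale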