mi_crow-0.1.1.post12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. amber/__init__.py +15 -0
  2. amber/datasets/__init__.py +11 -0
  3. amber/datasets/base_dataset.py +640 -0
  4. amber/datasets/classification_dataset.py +566 -0
  5. amber/datasets/loading_strategy.py +29 -0
  6. amber/datasets/text_dataset.py +488 -0
  7. amber/hooks/__init__.py +20 -0
  8. amber/hooks/controller.py +171 -0
  9. amber/hooks/detector.py +95 -0
  10. amber/hooks/hook.py +218 -0
  11. amber/hooks/implementations/__init__.py +0 -0
  12. amber/hooks/implementations/function_controller.py +93 -0
  13. amber/hooks/implementations/layer_activation_detector.py +96 -0
  14. amber/hooks/implementations/model_input_detector.py +250 -0
  15. amber/hooks/implementations/model_output_detector.py +132 -0
  16. amber/hooks/utils.py +76 -0
  17. amber/language_model/__init__.py +0 -0
  18. amber/language_model/activations.py +479 -0
  19. amber/language_model/context.py +33 -0
  20. amber/language_model/contracts.py +13 -0
  21. amber/language_model/hook_metadata.py +38 -0
  22. amber/language_model/inference.py +525 -0
  23. amber/language_model/initialization.py +126 -0
  24. amber/language_model/language_model.py +390 -0
  25. amber/language_model/layers.py +460 -0
  26. amber/language_model/persistence.py +177 -0
  27. amber/language_model/tokenizer.py +203 -0
  28. amber/language_model/utils.py +97 -0
  29. amber/mechanistic/__init__.py +0 -0
  30. amber/mechanistic/sae/__init__.py +0 -0
  31. amber/mechanistic/sae/autoencoder_context.py +40 -0
  32. amber/mechanistic/sae/concepts/__init__.py +0 -0
  33. amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
  34. amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
  35. amber/mechanistic/sae/concepts/concept_models.py +9 -0
  36. amber/mechanistic/sae/concepts/input_tracker.py +68 -0
  37. amber/mechanistic/sae/modules/__init__.py +5 -0
  38. amber/mechanistic/sae/modules/l1_sae.py +409 -0
  39. amber/mechanistic/sae/modules/topk_sae.py +459 -0
  40. amber/mechanistic/sae/sae.py +166 -0
  41. amber/mechanistic/sae/sae_trainer.py +604 -0
  42. amber/mechanistic/sae/training/wandb_logger.py +222 -0
  43. amber/store/__init__.py +5 -0
  44. amber/store/local_store.py +437 -0
  45. amber/store/store.py +276 -0
  46. amber/store/store_dataloader.py +124 -0
  47. amber/utils.py +46 -0
  48. mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
  49. mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
  50. mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
  51. mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
amber/store/store.py ADDED
@@ -0,0 +1,276 @@
+ from __future__ import annotations
+
+ import abc
+ from pathlib import Path
+ from typing import Dict, Any, List, Iterator
+
+ import torch
+
+
+ TensorMetadata = Dict[str, Dict[str, torch.Tensor]]
+
+
+ class Store(abc.ABC):
+     """Abstract store optimized for tensor batches grouped by run_id.
+
+     This interface intentionally excludes generic bytes/JSON APIs.
+     Implementations should focus on efficient safetensors-backed IO.
+
+     The store organizes data hierarchically:
+     - Runs: Top-level grouping by run_id
+     - Batches: Within each run, data is organized by batch_index
+     - Layers: Within each batch, tensors are organized by layer_signature
+     - Keys: Within each layer, tensors are identified by key (e.g., "activations")
+     """
+
+     def __init__(
+         self,
+         base_path: Path | str = "",
+         runs_prefix: str = "runs",
+         dataset_prefix: str = "datasets",
+         model_prefix: str = "models",
+     ):
+         """Initialize Store.
+
+         Args:
+             base_path: Base directory path for the store
+             runs_prefix: Prefix for runs directory (default: "runs")
+             dataset_prefix: Prefix for datasets directory (default: "datasets")
+             model_prefix: Prefix for models directory (default: "models")
+         """
+         self.runs_prefix = runs_prefix
+         self.dataset_prefix = dataset_prefix
+         self.model_prefix = model_prefix
+         self.base_path = Path(base_path)
+
+     def _run_key(self, run_id: str) -> Path:
+         """Get path for a run directory.
+
+         Args:
+             run_id: Run identifier
+
+         Returns:
+             Path to run directory
+         """
+         return self.base_path / self.runs_prefix / run_id
+
+     def _run_batch_key(self, run_id: str, batch_index: int) -> Path:
+         """Get path for a batch directory within a run.
+
+         Args:
+             run_id: Run identifier
+             batch_index: Batch index
+
+         Returns:
+             Path to batch directory
+         """
+         return self._run_key(run_id) / f"batch_{batch_index}"
+
+     def _run_metadata_key(self, run_id: str) -> Path:
+         """Get path for run metadata file.
+
+         Args:
+             run_id: Run identifier
+
+         Returns:
+             Path to metadata JSON file
+         """
+         return self._run_key(run_id) / "meta.json"
+
+     @abc.abstractmethod
+     def put_run_batch(self, run_id: str, batch_index: int,
+                       tensors: List[torch.Tensor] | Dict[str, torch.Tensor]) -> str:
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def get_run_batch(self, run_id: str, batch_index: int) -> List[torch.Tensor] | Dict[
+             str, torch.Tensor]:
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def list_run_batches(self, run_id: str) -> List[int]:
+         raise NotImplementedError
+
+     def iter_run_batches(self, run_id: str) -> Iterator[List[torch.Tensor] | Dict[str, torch.Tensor]]:
+         for idx in self.list_run_batches(run_id):
+             yield self.get_run_batch(run_id, idx)
+
+     def iter_run_batch_range(
+         self,
+         run_id: str,
+         *,
+         start: int = 0,
+         stop: int | None = None,
+         step: int = 1,
+     ) -> Iterator[List[torch.Tensor] | Dict[str, torch.Tensor]]:
+         """Iterate run batches for indices in range(start, stop, step).
+
+         If stop is None, it is set to max(list_run_batches(run_id)) + 1 (or 0 if there are none).
+         Raises ValueError if step == 0 or start < 0.
+         """
+         if step == 0:
+             raise ValueError("step must not be 0")
+         if start < 0:
+             raise ValueError("start must be >= 0")
+         indices = self.list_run_batches(run_id)
+         if not indices:
+             return
+         max_idx = max(indices)
+         if stop is None:
+             stop = max_idx + 1
+         for idx in range(start, stop, step):
+             try:
+                 yield self.get_run_batch(run_id, idx)
+             except FileNotFoundError:
+                 continue
+
+     @abc.abstractmethod
+     def delete_run(self, run_id: str) -> None:
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def put_run_metadata(self, run_id: str, meta: Dict[str, Any]) -> str:
+         """Persist metadata for a run (e.g., dataset/model identifiers).
+
+         Args:
+             run_id: Run identifier
+             meta: Metadata dictionary to save (must be JSON-serializable)
+
+         Returns:
+             String path/key where metadata was stored (e.g., "runs/{run_id}/meta.json")
+
+         Raises:
+             ValueError: If run_id is invalid or meta is not JSON-serializable
+             OSError: If file system operations fail
+
+         Note:
+             Implementations should store JSON at a stable location, e.g., runs/{run_id}/meta.json.
+         """
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def get_run_metadata(self, run_id: str) -> Dict[str, Any]:
+         """Load metadata for a run.
+
+         Args:
+             run_id: Run identifier
+
+         Returns:
+             Metadata dictionary, or empty dict if not found
+
+         Raises:
+             ValueError: If run_id is invalid
+             json.JSONDecodeError: If metadata file exists but contains invalid JSON
+         """
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def put_detector_metadata(
+         self,
+         run_id: str,
+         batch_index: int,
+         metadata: Dict[str, Any],
+         tensor_metadata: TensorMetadata
+     ) -> str:
+         """Save detector metadata with separate JSON and tensor storage.
+
+         Args:
+             run_id: Run identifier
+             batch_index: Batch index (must be non-negative)
+             metadata: JSON-serializable metadata dictionary (aggregated from all detectors)
+             tensor_metadata: Dictionary mapping layer_signature to dict of tensor_key -> tensor
+                 (from all detectors)
+
+         Returns:
+             Full path key used for storage (e.g., "runs/{run_id}/batch_{batch_index}")
+
+         Raises:
+             ValueError: If parameters are invalid or metadata is not JSON-serializable
+             OSError: If file system operations fail
+         """
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def get_detector_metadata(
+         self,
+         run_id: str,
+         batch_index: int
+     ) -> tuple[Dict[str, Any], TensorMetadata]:
+         """Load detector metadata with separate JSON and tensor storage.
+
+         Args:
+             run_id: Run identifier
+             batch_index: Batch index
+
+         Returns:
+             Tuple of (metadata dict, tensor_metadata dict). Returns empty dicts if not found.
+
+         Raises:
+             ValueError: If parameters are invalid or metadata format is invalid
+             json.JSONDecodeError: If metadata file exists but contains invalid JSON
+             OSError: If tensor files exist but cannot be loaded
+         """
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def get_detector_metadata_by_layer_by_key(
+         self,
+         run_id: str,
+         batch_index: int,
+         layer: str,
+         key: str
+     ) -> torch.Tensor:
+         """Get a specific tensor from detector metadata by layer and key.
+
+         Args:
+             run_id: Run identifier
+             batch_index: Batch index
+             layer: Layer signature
+             key: Tensor key (e.g., "activations")
+
+         Returns:
+             The requested tensor
+
+         Raises:
+             ValueError: If parameters are invalid
+             FileNotFoundError: If the tensor doesn't exist
+             OSError: If tensor file exists but cannot be loaded
+         """
+         raise NotImplementedError
+
+     # --- Unified detector metadata for whole runs ---
+     @abc.abstractmethod
+     def put_run_detector_metadata(
+         self,
+         run_id: str,
+         metadata: Dict[str, Any],
+         tensor_metadata: TensorMetadata,
+     ) -> str:
+         """
+         Save detector metadata for a whole run in a unified location.
+
+         This differs from ``put_detector_metadata``, which organises data
+         per-batch under ``runs/{run_id}/batch_{batch_index}``.
+
+         ``put_run_detector_metadata`` instead stores everything under
+         ``runs/{run_id}/detectors``. Implementations are expected to
+         support being called multiple times for the same ``run_id`` and
+         to append/aggregate new metadata rather than overwrite it.
+
+         Args:
+             run_id: Run identifier
+             metadata: JSON-serialisable metadata dictionary aggregated
+                 from all detectors for the current chunk/batch.
+             tensor_metadata: Dictionary mapping layer_signature to dict
+                 of tensor_key -> tensor (from all detectors).
+
+         Returns:
+             String path/key where metadata was stored
+             (e.g. ``runs/{run_id}/detectors``).
+
+         Raises:
+             ValueError: If parameters are invalid or metadata is not
+                 JSON-serialisable.
+             OSError: If file system operations fail.
+         """
+         raise NotImplementedError
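The `Store` base class above keeps run data under `{base_path}/runs/{run_id}/batch_{batch_index}` and offers both full and range-limited batch iteration. A minimal consumption sketch follows; it assumes the concrete `LocalStore` (shipped in `amber/store/local_store.py` but not shown in this section) accepts the same constructor arguments as the `Store` base class, and the run id is a placeholder:

```python
# Hedged sketch: LocalStore's exact constructor is not shown in this diff;
# it is assumed here to mirror Store.__init__ (base_path, runs_prefix, ...).
from amber.store.local_store import LocalStore

store = LocalStore(base_path="./amber_artifacts")

# Iterate every stored batch of a run, in the order given by list_run_batches().
for batch in store.iter_run_batches("run-2024-01"):
    ...  # batch is a list of tensors or a dict keyed by tensor name

# Or walk a sub-range; per the docstring, missing batch files are skipped.
for batch in store.iter_run_batch_range("run-2024-01", start=0, stop=100, step=10):
    ...
```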
amber/store/store_dataloader.py ADDED
@@ -0,0 +1,124 @@
+ from typing import Optional, Iterator, TYPE_CHECKING
+
+ import torch
+
+ from amber.utils import get_logger
+
+ logger = get_logger(__name__)
+
+ if TYPE_CHECKING:
+     from amber.store.store import Store
+
+
+ class StoreDataloader:
+     """
+     A reusable DataLoader-like class that can be iterated multiple times.
+
+     This is needed because overcomplete's train_sae iterates over the dataloader
+     once per epoch, so we need a dataloader that can be iterated multiple times.
+     """
+
+     def __init__(
+         self,
+         store: "Store",
+         run_id: str,
+         layer: str,
+         key: str = "activations",
+         batch_size: int = 32,
+         dtype: Optional[torch.dtype] = None,
+         device: Optional[torch.device] = None,
+         max_batches: Optional[int] = None,
+         logger_instance=None
+     ):
+         """
+         Initialize StoreDataloader.
+
+         Args:
+             store: Store instance containing activations
+             run_id: Run ID to iterate over
+             layer: Layer signature to load activations from
+             key: Tensor key to load (default: "activations")
+             batch_size: Mini-batch size
+             dtype: Optional dtype to cast activations to
+             device: Optional device to move tensors to
+             max_batches: Optional limit on number of batches per epoch
+             logger_instance: Optional logger instance for debug messages
+         """
+         self.store = store
+         self.run_id = run_id
+         self.layer = layer
+         self.key = key
+         self.batch_size = batch_size
+         self.dtype = dtype
+         self.device = device
+         self.max_batches = max_batches
+         self.logger = logger_instance or logger
+
+     def __iter__(self) -> Iterator[torch.Tensor]:
+         """
+         Create a new iterator for each epoch.
+
+         This allows the dataloader to be iterated multiple times,
+         which is required for multiple epochs.
+         """
+         batches_yielded = 0
+
+         # Get list of batch indices
+         batch_indices = self.store.list_run_batches(self.run_id)
+
+         for batch_index in batch_indices:
+             if self.max_batches is not None and batches_yielded >= self.max_batches:
+                 break
+
+             acts = None
+             try:
+                 # Try to load from detector metadata first
+                 acts = self.store.get_detector_metadata_by_layer_by_key(
+                     self.run_id,
+                     batch_index,
+                     self.layer,
+                     self.key
+                 )
+             except FileNotFoundError:
+                 # Fall back to traditional batch files
+                 try:
+                     batch = self.store.get_run_batch(self.run_id, batch_index)
+                     if isinstance(batch, dict) and self.key in batch:
+                         acts = batch[self.key]
+                     elif isinstance(batch, dict) and "activations" in batch:
+                         # For backward compatibility, use "activations" if key not found
+                         acts = batch["activations"]
+                 except Exception:
+                     pass
+
+             if acts is None:
+                 if self.logger.isEnabledFor(self.logger.level):
+                     self.logger.debug(
+                         f"Skipping batch {batch_index}: tensor not found "
+                         f"(run_id={self.run_id}, layer={self.layer}, key={self.key})"
+                     )
+                 continue
+
+             # Ensure 2D [N, D]
+             if acts.dim() > 2:
+                 d = acts.shape[-1]
+                 acts = acts.view(-1, d)
+             elif acts.dim() == 1:
+                 acts = acts.view(1, -1)
+
+             # dtype handling
+             if self.dtype is not None:
+                 acts = acts.to(self.dtype)
+
+             # device handling
+             if self.device is not None:
+                 acts = acts.to(self.device)
+
+             # Yield mini-batches
+             bs = max(1, int(self.batch_size))
+             n = acts.shape[0]
+             for start in range(0, n, bs):
+                 if self.max_batches is not None and batches_yielded >= self.max_batches:
+                     return
+                 yield acts[start:start + bs]
+                 batches_yielded += 1
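Because `__iter__` rebuilds its state on every call, one `StoreDataloader` instance can feed a multi-epoch training loop. A short usage sketch based on the constructor shown above; the store constructor is assumed to mirror the `Store` base class, and the run id and layer signature are placeholders:

```python
import torch

from amber.store.local_store import LocalStore          # concrete store from this package
from amber.store.store_dataloader import StoreDataloader

store = LocalStore(base_path="./amber_artifacts")       # assumed constructor, see earlier note
loader = StoreDataloader(
    store=store,
    run_id="run-2024-01",                                # placeholder run id
    layer="model.layers.10.mlp",                         # placeholder layer signature
    key="activations",
    batch_size=256,
    dtype=torch.float32,
)

for epoch in range(3):      # re-iteration works because __iter__ restarts each epoch
    for acts in loader:     # acts has shape [batch_size, hidden_dim]
        ...
```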
amber/utils.py ADDED
@@ -0,0 +1,46 @@
+ from __future__ import annotations
+
+ import logging
+ import os
+ import random
+ from typing import Optional
+
+ import numpy as np
+ import torch
+
+
+ def set_seed(seed: int, deterministic: bool = True) -> None:
+     """Set seeds for python, numpy, and torch.
+
+     Args:
+         seed: Seed value.
+         deterministic: If True, tries to make torch deterministic where possible.
+     """
+     random.seed(seed)
+     os.environ["PYTHONHASHSEED"] = str(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     if deterministic:
+         torch.use_deterministic_algorithms(True, warn_only=True)
+         torch.backends.cudnn.deterministic = True  # type: ignore[attr-defined]
+         torch.backends.cudnn.benchmark = False  # type: ignore[attr-defined]
+
+
+ def get_logger(name: str = "amber", level: int | str = logging.INFO) -> logging.Logger:
+     """Get a configured logger with a simple format. Idempotent."""
+     logger = logging.getLogger(name)
+     if isinstance(level, str):
+         level = logging.getLevelName(level)
+     logger.setLevel(level)
+     if not logger.handlers:
+         handler = logging.StreamHandler()
+         fmt = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
+         handler.setFormatter(fmt)
+         logger.addHandler(handler)
+         # Allow propagation so pytest's caplog can capture logs
+         logger.propagate = True
+     else:
+         # Even if a handler exists (e.g., configured by the app), ensure propagation is enabled
+         logger.propagate = True
+     return logger
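A short usage sketch of the two helpers above; the logger name is illustrative:

```python
from amber.utils import get_logger, set_seed

set_seed(42)                                   # seeds python/numpy/torch, requests deterministic kernels
log = get_logger("amber.example", level="DEBUG")  # string levels are resolved via getLevelName
log.info("seeded and ready")
```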
mi_crow-0.1.1.post12.dist-info/METADATA ADDED
@@ -0,0 +1,124 @@
+ Metadata-Version: 2.4
+ Name: mi-crow
+ Version: 0.1.1.post12
+ Summary: Engineer Thesis: Explaining and modifying LLM responses using SAE and concepts.
+ Author-email: Hubert Kowalski <your.email@example.com>, Adam Kaniasty <adam.kaniasty@gmail.com>
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ Requires-Dist: accelerate>=1.10.1
+ Requires-Dist: datasets>=4.0.0
+ Requires-Dist: notebook>=7.4.7
+ Requires-Dist: overcomplete
+ Requires-Dist: pre-commit>=4.3.0
+ Requires-Dist: setuptools-scm>=8
+ Requires-Dist: torch>=2.8.0
+ Requires-Dist: tqdm>=4.67.1
+ Requires-Dist: transformers>=4.56.1
+ Requires-Dist: safetensors>=0.4.5
+ Requires-Dist: wandb>=0.22.1
+ Requires-Dist: pytest>=8.4.2
+ Requires-Dist: pytest-xdist>=3.8.0
+ Requires-Dist: seaborn>=0.13.2
+ Provides-Extra: dev
+ Requires-Dist: pre-commit>=4.3.0; extra == "dev"
+ Requires-Dist: ruff>=0.13.2; extra == "dev"
+ Requires-Dist: setuptools-scm>=8; extra == "dev"
+ Requires-Dist: wheel>=0.45.1; extra == "dev"
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
+ Requires-Dist: pytest-cov>=5.0.0; extra == "dev"
+ Provides-Extra: docs
+ Requires-Dist: mkdocs>=1.6; extra == "docs"
+ Requires-Dist: mkdocs-material>=9.5; extra == "docs"
+ Requires-Dist: mkdocstrings[python]>=0.26; extra == "docs"
+ Requires-Dist: mkdocs-section-index>=0.3; extra == "docs"
+ Requires-Dist: mkdocs-redirects>=1.2; extra == "docs"
+ Requires-Dist: mkdocs-literate-nav>=0.6; extra == "docs"
+ Requires-Dist: mkdocs-gen-files>=0.5; extra == "docs"
+ Requires-Dist: mike>=2.1; extra == "docs"
+
+ ![CI](https://github.com/AdamKaniasty/Inzynierka/actions/workflows/tests.yml/badge.svg?branch=main)
+ [![Docs](https://img.shields.io/badge/docs-gh--pages-blue)](https://adamkaniasty.github.io/Inzynierka/)
+
+ ## Running Tests
+
+ The project uses pytest for testing. Tests are organized into unit tests and end-to-end tests.
+
+ ### Running All Tests
+
+ ```bash
+ pytest
+ ```
+
+ ### Running Specific Test Suites
+
+ Run only unit tests:
+ ```bash
+ pytest --unit -q
+ ```
+
+ Run only end-to-end tests:
+ ```bash
+ pytest --e2e -q
+ ```
+
+ You can also use pytest markers:
+ ```bash
+ pytest -m unit -q
+ pytest -m e2e -q
+ ```
+
+ Or specify the test directory directly:
+ ```bash
+ pytest tests/unit -q
+ pytest tests/e2e -q
+ ```
+
+ ### Test Coverage
+
+ The test suite is configured to require at least 85% code coverage. Coverage reports are generated in both terminal and XML formats.
+
+ ## Backend (FastAPI) quickstart
+
+ Install server-only dependencies (kept out of the core library) with uv:
+ ```bash
+ uv sync --group server
+ ```
+
+ Run the API:
+ ```bash
+ uv run --group server uvicorn server.main:app --reload
+ ```
+
+ Smoke-test the server endpoints:
+ ```bash
+ uv run --group server pytest tests/server/test_api.py --cov=server --cov-fail-under=0
+ ```
+
+ ### SAE API usage
+
+ - Configure artifact location (optional): `export SERVER_ARTIFACT_BASE_PATH=/path/to/amber_artifacts` (defaults to `~/.cache/amber_server`)
+ - Load a model: `curl -X POST http://localhost:8000/models/load -H "Content-Type: application/json" -d '{"model_id":"bielik"}'`
+ - Save activations from dataset (stored in `LocalStore` under `activations/<model>/<run_id>`):
+   - HF dataset: `{"dataset":{"type":"hf","name":"ag_news","split":"train","text_field":"text"}}`
+   - Local files: `{"dataset":{"type":"local","paths":["/path/to/file.txt"]}}`
+   - Example: `curl -X POST http://localhost:8000/sae/activations/save -H "Content-Type: application/json" -d '{"model_id":"bielik","layers":["dummy_root"],"dataset":{"type":"local","paths":["/tmp/data.txt"]},"sample_limit":100,"batch_size":4,"shard_size":64}'` → returns a manifest path, run_id, token counts, and batch metadata.
+ - List activation runs: `curl "http://localhost:8000/sae/activations?model_id=bielik"`
+ - Start SAE training (async job, uses `SaeTrainer`): `curl -X POST http://localhost:8000/sae/train -H "Content-Type: application/json" -d '{"model_id":"bielik","activations_path":"/path/to/manifest.json","layer":"<layer_name>","sae_class":"TopKSae","hyperparams":{"epochs":1,"batch_size":256}}'` → returns `job_id`
+ - Check job status: `curl http://localhost:8000/sae/train/status/<job_id>` (returns `sae_id`, `sae_path`, `metadata_path`, progress, and logs)
+ - Cancel a job (best-effort): `curl -X POST http://localhost:8000/sae/train/cancel/<job_id>`
+ - Load an SAE: `curl -X POST http://localhost:8000/sae/load -H "Content-Type: application/json" -d '{"model_id":"bielik","sae_path":"/path/to/sae.json"}'`
+ - List SAEs: `curl "http://localhost:8000/sae/saes?model_id=bielik"`
+ - Run SAE inference (optionally save top texts and apply concept config): `curl -X POST http://localhost:8000/sae/infer -H "Content-Type: application/json" -d '{"model_id":"bielik","sae_id":"<sae_id>","save_top_texts":true,"top_k_neurons":5,"concept_config_path":"/path/to/concepts.json","inputs":[{"prompt":"hi"}]}'` → returns outputs, top neuron summary, sae metadata, and saved top-texts path when requested.
+ - Per-token latents: add `"return_token_latents": true` (default off) to include top-k neuron activations per token.
+ - List concepts: `curl "http://localhost:8000/sae/concepts?model_id=bielik&sae_id=<sae_id>"`
+ - Load concepts from a file (validated against SAE latents): `curl -X POST http://localhost:8000/sae/concepts/load -H "Content-Type: application/json" -d '{"model_id":"bielik","sae_id":"<sae_id>","source_path":"/path/to/concepts.json"}'`
+ - Manipulate concepts (saves a config file for inference-time scaling): `curl -X POST http://localhost:8000/sae/concepts/manipulate -H "Content-Type: application/json" -d '{"model_id":"bielik","sae_id":"<sae_id>","edits":{"0":1.2}}'`
+ - List concept configs: `curl "http://localhost:8000/sae/concepts/configs?model_id=bielik&sae_id=<sae_id>"`
+ - Preview concept config (validate without saving): `curl -X POST http://localhost:8000/sae/concepts/preview -H "Content-Type: application/json" -d '{"model_id":"bielik","sae_id":"<sae_id>","edits":{"0":1.2}}'`
+ - Delete activation run or SAE (requires API key if set): `curl -X DELETE "http://localhost:8000/sae/activations/<run_id>?model_id=bielik" -H "X-API-Key: <key>"` and `curl -X DELETE "http://localhost:8000/sae/saes/<sae_id>?model_id=bielik" -H "X-API-Key: <key>"`
+ - Health/metrics summary: `curl http://localhost:8000/health/metrics` (in-memory job counts; no persistence, no auth)
+
+ Notes:
+ - Job manager is in-memory/lightweight: jobs disappear on process restart; idempotency is best-effort via payload key.
+ - Training/inference currently run in-process threads; add your own resource guards when running heavy models.
+ - Optional API key protection: set `SERVER_API_KEY=<value>` to require `X-API-Key` on protected endpoints (delete).
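Taken together, the endpoints above form a load → capture → train → poll workflow. A hedged Python sketch chaining the documented calls with only the standard library; the endpoint paths and the `job_id` field follow the README descriptions above, while the manifest path and layer name are placeholders:

```python
# Hedged sketch of the documented load -> train -> poll flow.
import json
import urllib.request

BASE = "http://localhost:8000"

def post(path: str, payload: dict) -> dict:
    # Send a JSON POST and decode the JSON response.
    req = urllib.request.Request(
        BASE + path,
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())

post("/models/load", {"model_id": "bielik"})
job = post("/sae/train", {
    "model_id": "bielik",
    "activations_path": "/path/to/manifest.json",  # placeholder manifest path
    "layer": "<layer_name>",                        # placeholder layer signature
    "sae_class": "TopKSae",
    "hyperparams": {"epochs": 1, "batch_size": 256},
})
with urllib.request.urlopen(f"{BASE}/sae/train/status/{job['job_id']}") as resp:
    print(json.loads(resp.read()))
```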
mi_crow-0.1.1.post12.dist-info/RECORD ADDED
@@ -0,0 +1,51 @@
+ amber/__init__.py,sha256=5nh0D8qvFgOhBEQj00Rm06T1iY5VcSiifAg9SoY1LLA,483
+ amber/utils.py,sha256=oER2LA_alUjaIk_xCAyP2V54ywjqsg00I4KvitYnJPc,1547
+ amber/datasets/__init__.py,sha256=zhqgbm5zMBsRbmPNfjlYNJwGWOLuCNf5jEj0P8aopRU,341
+ amber/datasets/base_dataset.py,sha256=X2wt3GdjgAOY24_vOqrD5gVFxGplSRMCb69CoQtj0xw,22508
+ amber/datasets/classification_dataset.py,sha256=x_ZQ4dMzoY3Nn8V1I01xvzJK_IcHcDcm8dIxYeXzV5g,21700
+ amber/datasets/loading_strategy.py,sha256=17VM3Td8lqDllGIx9DHI6WiXmSKKQHDHbfe4ZeM8ATA,1206
+ amber/datasets/text_dataset.py,sha256=ly0GHCS28Rg5ZluaafjavhcbvSD9-6ryovd_Y1ZIMms,16775
+ amber/hooks/__init__.py,sha256=9H08ZVoTK6TzYJXjEP2aqdHfoyLfdXvg6eOv3K1zNps,679
+ amber/hooks/controller.py,sha256=hc8FrrDosFYLrEGsEZmx1KsJ77F4p_gMKcF2WzHiURY,6057
+ amber/hooks/detector.py,sha256=5drJFrdrjseVjRNT-cq-U8XCt8AXV04YY2YrkQz4eFk,3110
+ amber/hooks/hook.py,sha256=-Qi-GJqRuIskXMuHUzp9_ESbbZi5tLSAFMWoBmrP3io,7540
+ amber/hooks/utils.py,sha256=wtsrjsMt-bXR3NshkwyZmfLre3IE3S4E5EoKppQrYOo,2022
+ amber/hooks/implementations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ amber/hooks/implementations/function_controller.py,sha256=66FFx_7sU7b0_FFQFFkAOqQm3aGsiyIRuZ64hIv-0w8,3088
+ amber/hooks/implementations/layer_activation_detector.py,sha256=bzoW6V8NNDNgRASs1YN_1TjEXaK3ahoNWiZ-ODfjB6I,3161
+ amber/hooks/implementations/model_input_detector.py,sha256=cYRVfyBEHi-1qg6F-4Q0vKEae6gYtq_3g1j3rOOCQdA,10074
+ amber/hooks/implementations/model_output_detector.py,sha256=iN-twt7Chc9ODmj-iei7_Ah7GqvE-knTVWi4C9kNye4,4879
+ amber/language_model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ amber/language_model/activations.py,sha256=7VRDlCM-IctF1ee0H68_Dd7xeXHyUSQe1SOPb28Ej18,16971
+ amber/language_model/context.py,sha256=koslpikCcu9Svsopboa1wd7Fv0R2-4sI2whOCvVWvT8,1118
+ amber/language_model/contracts.py,sha256=6ij7rzJcpSAKgYx-fiefg0Fi8TsFugaM5IlihEm_2VM,268
+ amber/language_model/hook_metadata.py,sha256=9Xyfiu4ekCZj79zG4gZfLk-850AO2iKDE24FDXe7q7s,1392
+ amber/language_model/inference.py,sha256=l8BASS8E9B4VWJHucEqF_G_zqOzlKeG4KEvGameBbMw,19068
+ amber/language_model/initialization.py,sha256=hfrKdI_fsmaxk0p9q4wN7EFxq_lSXs8BXGlxwKJ21Qw,3698
+ amber/language_model/language_model.py,sha256=MXoaXYbNBUxHw4sxtWkVFLbXiw9tFEx9GI78sCiESuQ,13943
+ amber/language_model/layers.py,sha256=Ob7QZl8i236ALLklY9o_xtjDZSt6FD8sqdmFy_YLgN0,15906
+ amber/language_model/persistence.py,sha256=i2ibDH1OABM5-ZNNLh7h4rOYWPsg3aaeYhmB_xWYDZw,5867
+ amber/language_model/tokenizer.py,sha256=9eKNOHvUjIJhJbj7M-tN7jWU5lWhOeCY_cssa4exQ1g,7377
+ amber/language_model/utils.py,sha256=5Y7scRvvudUjKDV8QPhC3HAc2S-dCuqbm6xEjRr0fRM,2630
+ amber/mechanistic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ amber/mechanistic/sae/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ amber/mechanistic/sae/autoencoder_context.py,sha256=cn0mv9COqT3jNcvXBfce70ankVEmw9kNE3Mu-knugoc,945
+ amber/mechanistic/sae/sae.py,sha256=ha6rXGsOXE59E_ohTH0vJh6M4rQh3Xw0GfmCkSgeYS4,6035
+ amber/mechanistic/sae/sae_trainer.py,sha256=GMrPz9SpSuANA0tJt3IkyIOOqVr2k7apE0w4CqL92gM,26311
+ amber/mechanistic/sae/concepts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ amber/mechanistic/sae/concepts/autoencoder_concepts.py,sha256=dTB9v5zK5lcKc4_zbYtWoCy8fA6ycj1-FJid4-7hQ04,14371
+ amber/mechanistic/sae/concepts/concept_dictionary.py,sha256=9px845gODuufW2koym-5426q0ijdpxvS7avH1vu8-Ls,7671
+ amber/mechanistic/sae/concepts/concept_models.py,sha256=HGyPoMSmj8CAg9joIa6fV_FfWAY2XN4LFPIkNQiJWbY,133
+ amber/mechanistic/sae/concepts/input_tracker.py,sha256=81FrOv9AAC7ejhryOWDTZ7Hlt3B2WoANx-wiO0KLr24,1886
+ amber/mechanistic/sae/modules/__init__.py,sha256=xpoz0HtPWoJD4dPj1qHaxtXDr7J0ERn30CX3m1dz21s,239
+ amber/mechanistic/sae/modules/l1_sae.py,sha256=_BebvpB9iUCTDjSiYNzDBM4l9sU_wadAdypwXaSb4ww,15352
+ amber/mechanistic/sae/modules/topk_sae.py,sha256=GQA8hYb6Fw7U1e5ExZjzZBjAZRhGk-VwqqYREVCQ_u8,17275
+ amber/mechanistic/sae/training/wandb_logger.py,sha256=d3vVBIQrnsJurX5HNVu7OYW4DqNgv18UZDpV8ddfN9k,8554
+ amber/store/__init__.py,sha256=UW4Hqyu-_qgnZ-gN_mk97OaWSrlPERcNi5YjnXMKeOU,119
+ amber/store/local_store.py,sha256=1pJbizZKrzNt_IQFnCFYjApPXs9ot-G1H8adeR7Qi50,17214
+ amber/store/store.py,sha256=VuDe9Git0glND3TTHh0zhDJNxdQY3dCp0cURhApYQbU,9334
+ amber/store/store_dataloader.py,sha256=QyYHSgOos8e-yzaEE_rySSVlGKaRNybURSDCgNrTIVM,4337
+ mi_crow-0.1.1.post12.dist-info/METADATA,sha256=1_SPBcnK8j_scOIWwnPqQGjlUpNnXxGGtvkYRZHeDQ8,6580
+ mi_crow-0.1.1.post12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ mi_crow-0.1.1.post12.dist-info/top_level.txt,sha256=FNP1x_ePvcW9Jsr7J9gCBARdDC-gqxIYtWF6HGNxtnI,6
+ mi_crow-0.1.1.post12.dist-info/RECORD,,
mi_crow-0.1.1.post12.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (80.9.0)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
mi_crow-0.1.1.post12.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ amber