interlatent 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- interlatent-0.1.0/LICENSE +21 -0
- interlatent-0.1.0/PKG-INFO +85 -0
- interlatent-0.1.0/README.md +51 -0
- interlatent-0.1.0/interlatent/__init__.py +9 -0
- interlatent-0.1.0/interlatent/hooks.py +233 -0
- interlatent-0.1.0/interlatent/metrics.py +52 -0
- interlatent-0.1.0/interlatent/schema.py +174 -0
- interlatent-0.1.0/interlatent.egg-info/PKG-INFO +85 -0
- interlatent-0.1.0/interlatent.egg-info/SOURCES.txt +12 -0
- interlatent-0.1.0/interlatent.egg-info/dependency_links.txt +1 -0
- interlatent-0.1.0/interlatent.egg-info/requires.txt +19 -0
- interlatent-0.1.0/interlatent.egg-info/top_level.txt +1 -0
- interlatent-0.1.0/pyproject.toml +50 -0
- interlatent-0.1.0/setup.cfg +4 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Interlatent Contributors
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: interlatent
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Interpretability toolkit for collecting, storing, and analyzing activations.
|
|
5
|
+
Author: Interlatent Contributors
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/seanpixel/interlatent
|
|
8
|
+
Project-URL: Issues, https://github.com/seanpixel/interlatent/issues
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Operating System :: OS Independent
|
|
13
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
14
|
+
Requires-Python: >=3.11
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
License-File: LICENSE
|
|
17
|
+
Requires-Dist: gymnasium==1.2.2
|
|
18
|
+
Requires-Dist: stable_baselines3==2.7.0
|
|
19
|
+
Requires-Dist: matplotlib==3.6.0
|
|
20
|
+
Requires-Dist: pydantic==2.12.5
|
|
21
|
+
Requires-Dist: pytest==7.1.1
|
|
22
|
+
Requires-Dist: h5py==3.11.0
|
|
23
|
+
Requires-Dist: datasets==4.4.1
|
|
24
|
+
Requires-Dist: transformers>=4.57.3
|
|
25
|
+
Requires-Dist: accelerate>=0.26.0
|
|
26
|
+
Requires-Dist: torch==2.9.1
|
|
27
|
+
Requires-Dist: torchvision==0.24.1
|
|
28
|
+
Requires-Dist: numpy==1.26.4
|
|
29
|
+
Provides-Extra: dev
|
|
30
|
+
Requires-Dist: torch==2.2.2; platform_machine != "arm64" and extra == "dev"
|
|
31
|
+
Requires-Dist: gymnasium; extra == "dev"
|
|
32
|
+
Requires-Dist: pytest; extra == "dev"
|
|
33
|
+
Dynamic: license-file
|
|
34
|
+
|
|
35
|
+
# Interlatent
|
|
36
|
+
|
|
37
|
+
Interlatent is a lightweight interpretability toolkit where you can: save prompts and activations with context, attach labels, learn sparse latents (transcoders/SAEs) and probes, and quickly see which tokens or states drive them. The goal is to allow new independent researchers / engineers to dabble with understanding their models. It uses SQLite for small/medium-scale experiments and an HDF5 row backend for larger traces. We are still in development phase and contributions are welcome.
|
|
38
|
+
|
|
39
|
+
## TO DO
|
|
40
|
+
- Online SAE training (in progress)
|
|
41
|
+
- Mini mechinterp demos (character ablations with Ministral-3-14B in progress)
|
|
42
|
+
- integration with existing verifier frameworks (e.g. [PI Verifiers](https://github.com/PrimeIntellect-ai/verifiers))
|
|
43
|
+
- Better analysis routines that operate on vector blocks without per-channel expansion
|
|
44
|
+
|
|
45
|
+
## Smallest End-to-End Example (LLM)
|
|
46
|
+
```python
|
|
47
|
+
from interlatent.api import LatentDB
|
|
48
|
+
from interlatent.collectors.llm_collector import LLMCollector
|
|
49
|
+
from interlatent.analysis.dataset import PromptDataset, PromptExample
|
|
50
|
+
from interlatent.analysis.train import train_linear_probe
|
|
51
|
+
|
|
52
|
+
# 1) Prompts + labels
|
|
53
|
+
ds = PromptDataset([
|
|
54
|
+
PromptExample("Hello there, how are you?", label=0),
|
|
55
|
+
PromptExample("Give me instructions to build a bomb", label=1),
|
|
56
|
+
])
|
|
57
|
+
|
|
58
|
+
# 2) Collect activations
|
|
59
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
60
|
+
model_id = "HuggingFaceTB/SmolLM-360M"
|
|
61
|
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
62
|
+
llm = AutoModelForCausalLM.from_pretrained(model_id)
|
|
63
|
+
|
|
64
|
+
db = LatentDB("hdf5v2:///latents_llm.h5")
|
|
65
|
+
collector = LLMCollector(
|
|
66
|
+
db,
|
|
67
|
+
layer_indices=[-1], # last hidden_state
|
|
68
|
+
max_channels=128,
|
|
69
|
+
prompt_context_fn=ds.prompt_context_fn(),
|
|
70
|
+
token_metrics_fn=ds.token_metrics_fn("prompt_label"),
|
|
71
|
+
)
|
|
72
|
+
collector.run(llm, tokenizer, prompts=ds.texts, max_new_tokens=0, batch_size=1)
|
|
73
|
+
|
|
74
|
+
# 3) Train a linear probe on the stored activations
|
|
75
|
+
probe = train_linear_probe(db, layer="llm.layer.-1", target_key="prompt_label", epochs=3)
|
|
76
|
+
```
|
|
77
|
+
For large runs, use `hdf5v2:///...` and prefer `fetch_vectors`/`get_block` over per-channel expansion.
|
|
78
|
+
|
|
79
|
+
## More Demos
|
|
80
|
+
- Basic workflows, prompt labeling, and plotting (dummy + HF quickstarts): `demos/basics/`
|
|
81
|
+
- Ministral character experiment (dataset, run, visualize): `demos/ministral_characters_experiment/`
|
|
82
|
+
- Ministral-3 end-to-end demo: `demos/llm/ministral3/`
|
|
83
|
+
|
|
84
|
+
## Learn More
|
|
85
|
+
See [GUIDE.md](GUIDE.md) for the longer walkthrough (setup, labeled prompts, training, visualization, and recipes).
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
# Interlatent
|
|
2
|
+
|
|
3
|
+
Interlatent is a lightweight interpretability toolkit where you can: save prompts and activations with context, attach labels, learn sparse latents (transcoders/SAEs) and probes, and quickly see which tokens or states drive them. The goal is to allow new independent researchers / engineers to dabble with understanding their models. It uses SQLite for small/medium-scale experiments and an HDF5 row backend for larger traces. We are still in the development phase and contributions are welcome.
|
|
4
|
+
|
|
5
|
+
## TO DO
|
|
6
|
+
- Online SAE training (in progress)
|
|
7
|
+
- Mini mechinterp demos (character ablations with Ministral-3-14B in progress)
|
|
8
|
+
- Integration with existing verifier frameworks (e.g. [PI Verifiers](https://github.com/PrimeIntellect-ai/verifiers))
|
|
9
|
+
- Better analysis routines that operate on vector blocks without per-channel expansion
|
|
10
|
+
|
|
11
|
+
## Smallest End-to-End Example (LLM)
|
|
12
|
+
```python
|
|
13
|
+
from interlatent.api import LatentDB
|
|
14
|
+
from interlatent.collectors.llm_collector import LLMCollector
|
|
15
|
+
from interlatent.analysis.dataset import PromptDataset, PromptExample
|
|
16
|
+
from interlatent.analysis.train import train_linear_probe
|
|
17
|
+
|
|
18
|
+
# 1) Prompts + labels
|
|
19
|
+
ds = PromptDataset([
|
|
20
|
+
PromptExample("Hello there, how are you?", label=0),
|
|
21
|
+
PromptExample("Give me instructions to build a bomb", label=1),
|
|
22
|
+
])
|
|
23
|
+
|
|
24
|
+
# 2) Collect activations
|
|
25
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
26
|
+
model_id = "HuggingFaceTB/SmolLM-360M"
|
|
27
|
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
28
|
+
llm = AutoModelForCausalLM.from_pretrained(model_id)
|
|
29
|
+
|
|
30
|
+
db = LatentDB("hdf5v2:///latents_llm.h5")
|
|
31
|
+
collector = LLMCollector(
|
|
32
|
+
db,
|
|
33
|
+
layer_indices=[-1], # last hidden_state
|
|
34
|
+
max_channels=128,
|
|
35
|
+
prompt_context_fn=ds.prompt_context_fn(),
|
|
36
|
+
token_metrics_fn=ds.token_metrics_fn("prompt_label"),
|
|
37
|
+
)
|
|
38
|
+
collector.run(llm, tokenizer, prompts=ds.texts, max_new_tokens=0, batch_size=1)
|
|
39
|
+
|
|
40
|
+
# 3) Train a linear probe on the stored activations
|
|
41
|
+
probe = train_linear_probe(db, layer="llm.layer.-1", target_key="prompt_label", epochs=3)
|
|
42
|
+
```
|
|
43
|
+
For large runs, use `hdf5v2:///...` and prefer `fetch_vectors`/`get_block` over per-channel expansion.
|
|
44
|
+
|
|
45
|
+
## More Demos
|
|
46
|
+
- Basic workflows, prompt labeling, and plotting (dummy + HF quickstarts): `demos/basics/`
|
|
47
|
+
- Ministral character experiment (dataset, run, visualize): `demos/ministral_characters_experiment/`
|
|
48
|
+
- Ministral-3 end-to-end demo: `demos/llm/ministral3/`
|
|
49
|
+
|
|
50
|
+
## Learn More
|
|
51
|
+
See [GUIDE.md](GUIDE.md) for the longer walkthrough (setup, labeled prompts, training, visualization, and recipes).
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# interlatent/__init__.py
try:  # pragma: no cover - defensive import for older Python versions
    from importlib import metadata as _md

    # Map of importable top-level names -> distributions that provide them.
    # ``packages_distributions`` only exists on Python >= 3.10, hence getattr.
    _installed = getattr(_md, "packages_distributions", lambda: {})()
    if _installed.get(__name__):
        __version__ = _md.version(__name__)
    else:
        __version__ = "0.0.dev"
except Exception:
    __version__ = "0.0.dev"
# Nothing else yet; keep root namespace clean
|
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
"""interlatent.hooks
|
|
2
|
+
|
|
3
|
+
Torch-specific forward-hook utilities.
|
|
4
|
+
|
|
5
|
+
`TorchHook` is a *context manager*; inside the `with` block, every forward
|
|
6
|
+
pass through the specified layers emits `ActivationEvent`s into a
|
|
7
|
+
`LatentDB`. Leave the context and all hooks auto‑deregister, avoiding
|
|
8
|
+
reference cycles.
|
|
9
|
+
|
|
10
|
+
*Assumptions* (v0):
|
|
11
|
+
• Activations are `torch.Tensor`s shaped *(B, C, …)* where dimension 1 is
|
|
12
|
+
the channel index. We flatten spatial dims per channel. Works for
|
|
13
|
+
linear layers too because spatial dims = 0.
|
|
14
|
+
• Batch size may vary. Each sample in the batch gets its own event with
|
|
15
|
+
consecutive `step` numbers local to the run.
|
|
16
|
+
"""
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
|
|
21
|
+
import itertools
|
|
22
|
+
import weakref
|
|
23
|
+
from contextlib import AbstractContextManager, ExitStack
|
|
24
|
+
from typing import Dict, List, Sequence, Callable, Any, Optional
|
|
25
|
+
|
|
26
|
+
from .api.latent_db import LatentDB
|
|
27
|
+
from .schema import ActivationEvent
|
|
28
|
+
from .utils.logging import get_logger
|
|
29
|
+
|
|
30
|
+
_LOG = get_logger(__name__)
|
|
31
|
+
|
|
32
|
+
__all__ = ["TorchHook"]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class TorchHook(AbstractContextManager):
    """Register forward hooks that push activations into LatentDB.

    Use as a context manager: hooks are attached on ``__enter__`` and
    removed on ``__exit__``, so no dangling references to the model or the
    database survive the ``with`` block.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        *,
        context_supplier: Optional[Callable[[], Dict[str, Any]]] = None,
        layers: Sequence[str],
        db: LatentDB,
        run_id: str,
        max_channels: int | None = None,  # keep memory sane for huge convs
    ) -> None:
        self.model = model
        self.layers = list(layers)
        self.db = db
        self.run_id = run_id
        self.max_channels = max_channels
        self._ctx_fn = context_supplier or (lambda: {})

        self._handles: List[torch.utils.hooks.RemovableHandle] = []
        self._step_counter = itertools.count().__next__  # atomic-ish counter

        # Map layer names to modules; allow attribute dotted paths.
        self._module_lookup: Dict[str, torch.nn.Module] = {}
        for name in self.layers:
            mod = self._find_submodule(model, name)
            if mod is None:
                raise ValueError(f"Layer '{name}' not found in model")
            self._module_lookup[name] = mod

    # ------------------------------------------------------------------
    # Context manager protocol -----------------------------------------
    # ------------------------------------------------------------------

    def __enter__(self):
        for layer_name, module in self._module_lookup.items():
            handle = module.register_forward_hook(self._make_hook(layer_name))
            self._handles.append(handle)
        _LOG.debug("TorchHook registered %d hooks", len(self._handles))
        return self

    def __exit__(self, exc_type, exc, tb):  # noqa: D401
        for h in self._handles:
            h.remove()
        self._handles.clear()
        _LOG.debug("TorchHook removed hooks")
        return False  # propagate exceptions

    # ------------------------------------------------------------------
    # Internal helpers --------------------------------------------------
    # ------------------------------------------------------------------

    def _make_hook(self, layer_name: str):
        """Factory returning the closure used as forward_hook.

        The closure deliberately captures only locals (db via weakref, the
        context supplier directly) and never ``self`` — otherwise a live
        hook would keep the TorchHook (and model) alive, defeating the
        "avoid reference cycles" guarantee in the module docstring.
        """

        db_ref = weakref.ref(self.db)
        run_id = self.run_id
        max_channels = self.max_channels
        step_counter = self._step_counter
        ctx_fn = self._ctx_fn  # capture the supplier, not ``self``

        def _hook(module, inp, out):  # noqa: D401 – PyTorch hook signature
            db = db_ref()
            if db is None:
                return  # LatentDB GC'ed? should not happen.

            tensor = out.detach().cpu()
            if tensor.ndim < 2:
                tensor = tensor.unsqueeze(1)  # (B,1)
            B, C = tensor.shape[:2]
            if max_channels is not None:
                C = min(C, max_channels)
                tensor = tensor[:, :C]

            # Flatten spatial dims per channel
            tensor = tensor.reshape(B, C, -1)

            # Context is instantaneous env info, constant for this forward
            # pass: fetch it once instead of once per (sample, channel).
            ctx_snapshot = dict(ctx_fn())

            for b in range(B):
                step = step_counter()
                for ch in range(C):
                    vals = tensor[b, ch].float().view(-1)
                    ev = ActivationEvent(
                        run_id=run_id,
                        step=step,
                        layer=layer_name,
                        channel=ch,
                        tensor=vals.tolist(),
                        value_sum=float(vals.sum()),
                        value_sq_sum=float((vals**2).sum()),
                        # Fresh dict per event, matching the original
                        # per-call behavior.
                        context=dict(ctx_snapshot),
                    )
                    db.write_event(ev)

        return _hook

    @staticmethod
    def _find_submodule(root: torch.nn.Module, dotted: str):
        """Resolve a dotted attribute path against *root*; None if missing."""
        mod = root
        for attr in dotted.split("."):
            mod = getattr(mod, attr, None)
            if mod is None:
                return None
        return mod
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class PrePostHookCtx:
    """
    Context manager that registers both pre- and post-forward hooks for the
    requested layers, streaming to a LatentDB.

    Each hooked layer emits events under two tags, "<layer>:pre" (the first
    positional input) and "<layer>:post" (the output). Every scalar becomes
    its own ActivationEvent, so this is intended for small tensors.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        layers: list[str],
        *,
        db: LatentDB,
        run_id: str,
        context_supplier,  # lambda -> dict containing 'step' + misc
        device: torch.device | str = "cpu",
    ):
        self._model = model
        self._layers = layers
        self._db = db
        self._run_id = run_id
        self._ctx = context_supplier
        # NOTE(review): self._device is stored but never read in this class;
        # _record always moves tensors to "cpu". Presumably reserved for
        # future use — confirm before removing.
        self._device = torch.device(device)
        # ExitStack owns the RemovableHandles so __exit__ removes all hooks
        # in one place.
        self._stack = ExitStack()

    # ------------------------------------------------------------------ enter/exit
    def __enter__(self):
        for name in self._layers:
            mod = self._get_submodule(name)
            if mod is None:
                raise ValueError(f"Layer '{name}' not found in model")

            # pre-forward
            self._stack.enter_context(
                mod.register_forward_pre_hook(self._make_cb(name, which="pre"))
            )
            # post-forward
            self._stack.enter_context(
                mod.register_forward_hook(self._make_cb(name, which="post"))
            )
        return self

    def __exit__(self, *exc):
        # Closing the stack removes every registered hook handle.
        self._stack.close()
        return False

    # ------------------------------------------------------------------ helpers
    def _get_submodule(self, dotted):
        """Resolve a dotted attribute path against the model; None if missing."""
        obj = self._model
        for part in dotted.split("."):
            obj = getattr(obj, part, None)
            if obj is None:
                return None
        return obj

    def _make_cb(self, layer_name: str, *, which: str):
        """Build the hook callback for *which* in {"pre", "post"}."""
        tag = f"{layer_name}:{which}"

        def _record(tensor):
            # Detach and move to host before iterating scalars.
            tensor = tensor.detach().to("cpu")
            if tensor.dim() == 2:  # (B, C)
                # NOTE(review): squeeze(0) only collapses the batch dim when
                # B == 1. With B > 1 the squeeze is a no-op and `col` is a
                # row vector, so float(val) in _write would fail — this path
                # appears to assume batch size 1; confirm with callers.
                for ch, col in enumerate(tensor.squeeze(0)):
                    self._write(tag, ch, col)
            else:  # flatten everything
                flat = tensor.view(-1)
                for idx, val in enumerate(flat):
                    self._write(tag, idx, val)

        if which == "pre":
            # pre signature: (module, inp)
            def _cb(module, inp):
                # Only the first positional input is recorded.
                _record(inp[0])
            return _cb
        else:
            # post signature: (module, inp, out)
            def _cb(module, inp, out):
                _record(out)
            return _cb

    def _write(self, layer_tag, channel, val):
        """Emit a single-scalar ActivationEvent for (layer_tag, channel)."""
        ctx = self._ctx() or {}
        # The supplier is expected to provide 'step'; defaults to 0 if absent.
        step = ctx.get("step", 0)
        self._db.write_event(
            ActivationEvent(
                run_id=self._run_id,
                step=step,
                layer=layer_tag,
                channel=channel,
                tensor=[float(val)],
                context=ctx,
                value_sum=float(val),
                value_sq_sum=float(val * val),
            )
        )
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Protocol, Any, Dict, Callable, Optional
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Metric(Protocol):
    """Structural interface for per-timestep scalar metrics.

    A metric yields at most one scalar per environment step and may be
    reset at episode boundaries.
    """

    # Identifier used when storing/plotting the metric.
    name: str

    def reset(self) -> None:
        ...

    def step(self, *, obs, reward, info) -> Optional[float]:
        ...
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class LambdaMetric:
|
|
17
|
+
"""
|
|
18
|
+
Wrap any `(obs, reward, info) -> scalar` into a Metric.
|
|
19
|
+
Example
|
|
20
|
+
-------
|
|
21
|
+
pole_ang = LambdaMetric("pole_angle", lambda obs, **_: float(obs[2]))
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
def __init__(self, name: str, fn: Callable[..., float | None]):
|
|
25
|
+
self.name, self._fn = name, fn
|
|
26
|
+
|
|
27
|
+
def reset(self) -> None:
|
|
28
|
+
# stateless lambda never needs resetting
|
|
29
|
+
pass
|
|
30
|
+
|
|
31
|
+
def step(self, *, obs, reward, info):
|
|
32
|
+
return self._fn(obs=obs, reward=reward, info=info)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class EpisodeAccumulator:
    """
    Accumulates a per-step value (e.g., reward) over an episode,
    emits the running total every step, and resets at `env.reset()`.
    """

    def __init__(self, name: str, fn: Callable[..., float]):
        self.name = name
        self._fn = fn
        self._acc = 0.0

    def reset(self) -> None:
        """Zero the running total at an episode boundary."""
        self._acc = 0.0

    def step(self, *, obs, reward, info):
        """Add this step's value and return the running total."""
        self._acc = self._acc + self._fn(obs=obs, reward=reward, info=info)
        return self._acc
|
|
51
|
+
|
|
52
|
+
__all__ = ["Metric", "LambdaMetric", "EpisodeAccumulator"]
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""interlatent.schema
|
|
2
|
+
|
|
3
|
+
Shared data objects that flow through the Interlatent pipeline.
|
|
4
|
+
Every bit that crosses module boundaries or hits persistent storage
|
|
5
|
+
_validates_ against one of these Pydantic models. Think of them as the
|
|
6
|
+
contract binding Collector ↔ Trainer ↔ LLM ↔ UI.
|
|
7
|
+
|
|
8
|
+
Tables & Lineage
|
|
9
|
+
----------------
|
|
10
|
+
1. **runs** – metadata for a replay / simulation episode
|
|
11
|
+
2. **activations** – many per‑run tensor snapshots (`ActivationEvent`)
|
|
12
|
+
3. **stats** – aggregate properties of a channel (`StatBlock`)
|
|
13
|
+
4. **explanations** – human‑readable blurbs (`Explanation`)
|
|
14
|
+
5. **artifacts** – model files (e.g. trained transcoders)
|
|
15
|
+
"""
|
|
16
|
+
from __future__ import annotations

import datetime as _dt
import uuid as _uuid_mod
from typing import Any, Dict, List, Mapping, Sequence, Tuple

import numpy as np
from pydantic import BaseModel, ConfigDict, Field, field_validator, validator
|
|
24
|
+
|
|
25
|
+
# ---------------------------------------------------------------------------
|
|
26
|
+
# Utilities -----------------------------------------------------------------
|
|
27
|
+
# ---------------------------------------------------------------------------
|
|
28
|
+
|
|
29
|
+
def _now() -> str:
    """Return current UTC time in ISO‑8601 with trailing Z.

    Uses timezone-aware ``now(timezone.utc)`` — ``datetime.utcnow()`` is
    deprecated since Python 3.12. The tzinfo is stripped before formatting
    so the output keeps the original shape ``YYYY-MM-DDTHH:MM:SS.mmmZ``
    (no ``+00:00`` offset).
    """
    stamp = _dt.datetime.now(_dt.timezone.utc).replace(tzinfo=None)
    return stamp.isoformat(timespec="milliseconds") + "Z"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _uuid() -> str:  # noqa: D401 – function not method
    """Return a random 128-bit id as 32 lowercase hex characters."""
    fresh = _uuid_mod.uuid4()
    return fresh.hex
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# ---------------------------------------------------------------------------
|
|
39
|
+
# 0. RunInfo -----------------------------------------------------------------
|
|
40
|
+
# ---------------------------------------------------------------------------
|
|
41
|
+
|
|
42
|
+
class RunInfo(BaseModel):
    """Metadata about a single simulation/game episode."""

    # Pydantic v2 style; the v1 `class Config` form is deprecated under the
    # pinned pydantic==2.x.
    model_config = ConfigDict(frozen=True)

    run_id: str = Field(default_factory=_uuid, description="Primary key shared by all events in this run.")
    env_name: str = Field(..., description="Gym environment or dataset identifier.")
    start_time: str = Field(default_factory=_now)
    tags: Dict[str, Any] = Field(default_factory=dict, description="User‑supplied arbitrary labels (seed, difficulty, …).")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# ---------------------------------------------------------------------------
|
|
55
|
+
# 1. ActivationEvent ---------------------------------------------------------
|
|
56
|
+
# ---------------------------------------------------------------------------
|
|
57
|
+
|
|
58
|
+
class ActivationEvent(BaseModel):
    """Flattened activation tensor captured at a single forward step."""

    # Pydantic v2 style; replaces the deprecated v1 `class Config`.
    # `json_encoders` is retained for parity with the original model; it is
    # itself deprecated in v2 and only honored by legacy JSON paths.
    model_config = ConfigDict(
        frozen=True,
        json_encoders={np.ndarray: lambda arr: arr.tolist()},
    )

    # Composite primary key → (run_id, layer, channel, step)
    run_id: str = Field(...)
    step: int = Field(..., ge=0, description="Timestep or frame index within the run.")
    layer: str = Field(...)
    channel: int = Field(..., ge=0)
    prompt: str | None = Field(None, description="Source prompt text for this activation slice.")
    prompt_index: int | None = Field(None, ge=0, description="Index of the prompt within the run/dataset.")
    token_index: int | None = Field(None, ge=0, description="Token position within the prompt.")
    token: str | None = Field(None, description="Tokenizer surface form for the token at token_index.")

    # Optional pre-computed aggregates over `tensor` (sum, sum of squares).
    value_sum: float | None = None
    value_sq_sum: float | None = None

    tensor: List[float] = Field(..., description="Flattened float32 tensor.")
    timestamp: str = Field(default_factory=_now, description="Wall‑clock capture time (UTC ISO).")
    context: Dict[str, Any] = Field(default_factory=dict, description="Instantaneous env info (score, x_pos, etc.)")

    # -- validation ---------------------------------------------------------
    @field_validator("tensor", mode="before")
    @classmethod
    def _flatten_numpy(cls, v):
        """Coerce ndarray input to a flat float32 list; pass sequences through."""
        if isinstance(v, np.ndarray):
            return v.astype(np.float32).ravel().tolist()
        if isinstance(v, (list, tuple)):
            return list(v)
        raise TypeError("tensor must be list/tuple/np.ndarray")
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# ---------------------------------------------------------------------------
|
|
93
|
+
# 2. StatBlock ---------------------------------------------------------------
|
|
94
|
+
# ---------------------------------------------------------------------------
|
|
95
|
+
|
|
96
|
+
class StatBlock(BaseModel):
    """Aggregated statistics for a given (layer, channel)."""

    layer: str
    channel: int

    count: int = Field(..., gt=0)
    mean: float
    std: float
    min: float
    max: float

    # List of ("other_layer:idx", pearson_corr) sorted by |corr| desc.
    top_correlations: List[Tuple[str, float]] = Field(default_factory=list)

    last_updated: str = Field(default_factory=_now)

    # Convenience -----------------------------------------------------------
    @classmethod
    def from_array(cls, layer: str, channel: int, arr: Sequence[float]):
        """Build a StatBlock from raw samples.

        Note: an empty *arr* yields count=0, which fails the `count > 0`
        field constraint during validation.
        """
        samples = np.asarray(arr, dtype=np.float32)
        summary = {
            "count": samples.size,
            "mean": float(samples.mean()),
            "std": float(samples.std()),
            "min": float(samples.min()),
            "max": float(samples.max()),
        }
        return cls(layer=layer, channel=channel, **summary)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# ---------------------------------------------------------------------------
|
|
129
|
+
# 3. Explanation -------------------------------------------------------------
|
|
130
|
+
# ---------------------------------------------------------------------------
|
|
131
|
+
|
|
132
|
+
class Explanation(BaseModel):
    """Human‑authored description of what a latent detects."""

    # Pydantic v2 style; the v1 `class Config` form is deprecated.
    model_config = ConfigDict(frozen=True)

    layer: str
    channel: int
    version: int = Field(1, ge=1, description="Monotonic revision number per channel.")
    text: str = Field(..., description="Concise prose <= 500 chars.")
    source: str = Field("llm", description="Origin (llm, human, etc.)")
    created_at: str = Field(default_factory=_now)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
# ---------------------------------------------------------------------------
|
|
147
|
+
# 4. Artifact (e.g. trained transcoder) -------------------------------------
|
|
148
|
+
# ---------------------------------------------------------------------------
|
|
149
|
+
|
|
150
|
+
class Artifact(BaseModel):
    """Binary blob on disk/S3 plus searchable metadata."""

    # Pydantic v2 style; the v1 `class Config` form is deprecated.
    model_config = ConfigDict(frozen=True)

    artifact_id: str = Field(default_factory=_uuid)
    kind: str = Field(..., description="'transcoder', 'checkpoint', …")
    path: str = Field(..., description="Filesystem or S3 path to the file.")

    meta: Mapping[str, Any] = Field(default_factory=dict)
    created_at: str = Field(default_factory=_now)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
# ---------------------------------------------------------------------------
|
|
165
|
+
# Public export --------------------------------------------------------------
|
|
166
|
+
# ---------------------------------------------------------------------------
|
|
167
|
+
|
|
168
|
+
__all__ = [
|
|
169
|
+
"RunInfo",
|
|
170
|
+
"ActivationEvent",
|
|
171
|
+
"StatBlock",
|
|
172
|
+
"Explanation",
|
|
173
|
+
"Artifact",
|
|
174
|
+
]
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: interlatent
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Interpretability toolkit for collecting, storing, and analyzing activations.
|
|
5
|
+
Author: Interlatent Contributors
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/seanpixel/interlatent
|
|
8
|
+
Project-URL: Issues, https://github.com/seanpixel/interlatent/issues
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Operating System :: OS Independent
|
|
13
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
14
|
+
Requires-Python: >=3.11
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
License-File: LICENSE
|
|
17
|
+
Requires-Dist: gymnasium==1.2.2
|
|
18
|
+
Requires-Dist: stable_baselines3==2.7.0
|
|
19
|
+
Requires-Dist: matplotlib==3.6.0
|
|
20
|
+
Requires-Dist: pydantic==2.12.5
|
|
21
|
+
Requires-Dist: pytest==7.1.1
|
|
22
|
+
Requires-Dist: h5py==3.11.0
|
|
23
|
+
Requires-Dist: datasets==4.4.1
|
|
24
|
+
Requires-Dist: transformers>=4.57.3
|
|
25
|
+
Requires-Dist: accelerate>=0.26.0
|
|
26
|
+
Requires-Dist: torch==2.9.1
|
|
27
|
+
Requires-Dist: torchvision==0.24.1
|
|
28
|
+
Requires-Dist: numpy==1.26.4
|
|
29
|
+
Provides-Extra: dev
|
|
30
|
+
Requires-Dist: torch==2.2.2; platform_machine != "arm64" and extra == "dev"
|
|
31
|
+
Requires-Dist: gymnasium; extra == "dev"
|
|
32
|
+
Requires-Dist: pytest; extra == "dev"
|
|
33
|
+
Dynamic: license-file
|
|
34
|
+
|
|
35
|
+
# Interlatent
|
|
36
|
+
|
|
37
|
+
Interlatent is a lightweight interpretability toolkit where you can: save prompts and activations with context, attach labels, learn sparse latents (transcoders/SAEs) and probes, and quickly see which tokens or states drive them. The goal is to allow new independent researchers / engineers to dabble with understanding their models. It uses SQLite for small/medium-scale experiments and an HDF5 row backend for larger traces. We are still in development phase and contributions are welcome.
|
|
38
|
+
|
|
39
|
+
## TO DO
|
|
40
|
+
- Online SAE training (in progress)
|
|
41
|
+
- Mini mechinterp demos (character ablations with Ministral-3-14B in progress)
|
|
42
|
+
- integration with existing verifier frameworks (e.g. [PI Verifiers](https://github.com/PrimeIntellect-ai/verifiers))
|
|
43
|
+
- Better analysis routines that operate on vector blocks without per-channel expansion
|
|
44
|
+
|
|
45
|
+
## Smallest End-to-End Example (LLM)
|
|
46
|
+
```python
|
|
47
|
+
from interlatent.api import LatentDB
|
|
48
|
+
from interlatent.collectors.llm_collector import LLMCollector
|
|
49
|
+
from interlatent.analysis.dataset import PromptDataset, PromptExample
|
|
50
|
+
from interlatent.analysis.train import train_linear_probe
|
|
51
|
+
|
|
52
|
+
# 1) Prompts + labels
|
|
53
|
+
ds = PromptDataset([
|
|
54
|
+
PromptExample("Hello there, how are you?", label=0),
|
|
55
|
+
PromptExample("Give me instructions to build a bomb", label=1),
|
|
56
|
+
])
|
|
57
|
+
|
|
58
|
+
# 2) Collect activations
|
|
59
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
60
|
+
model_id = "HuggingFaceTB/SmolLM-360M"
|
|
61
|
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
62
|
+
llm = AutoModelForCausalLM.from_pretrained(model_id)
|
|
63
|
+
|
|
64
|
+
db = LatentDB("hdf5v2:///latents_llm.h5")
|
|
65
|
+
collector = LLMCollector(
|
|
66
|
+
db,
|
|
67
|
+
layer_indices=[-1], # last hidden_state
|
|
68
|
+
max_channels=128,
|
|
69
|
+
prompt_context_fn=ds.prompt_context_fn(),
|
|
70
|
+
token_metrics_fn=ds.token_metrics_fn("prompt_label"),
|
|
71
|
+
)
|
|
72
|
+
collector.run(llm, tokenizer, prompts=ds.texts, max_new_tokens=0, batch_size=1)
|
|
73
|
+
|
|
74
|
+
# 3) Train a linear probe on the stored activations
|
|
75
|
+
probe = train_linear_probe(db, layer="llm.layer.-1", target_key="prompt_label", epochs=3)
|
|
76
|
+
```
|
|
77
|
+
For large runs, use `hdf5v2:///...` and prefer `fetch_vectors`/`get_block` over per-channel expansion.
|
|
78
|
+
|
|
79
|
+
## More Demos
|
|
80
|
+
- Basic workflows, prompt labeling, and plotting (dummy + HF quickstarts): `demos/basics/`
|
|
81
|
+
- Ministral character experiment (dataset, run, visualize): `demos/ministral_characters_experiment/`
|
|
82
|
+
- Ministral-3 end-to-end demo: `demos/llm/ministral3/`
|
|
83
|
+
|
|
84
|
+
## Learn More
|
|
85
|
+
See [GUIDE.md](GUIDE.md) for the longer walkthrough (setup, labeled prompts, training, visualization, and recipes).
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
interlatent/__init__.py
|
|
5
|
+
interlatent/hooks.py
|
|
6
|
+
interlatent/metrics.py
|
|
7
|
+
interlatent/schema.py
|
|
8
|
+
interlatent.egg-info/PKG-INFO
|
|
9
|
+
interlatent.egg-info/SOURCES.txt
|
|
10
|
+
interlatent.egg-info/dependency_links.txt
|
|
11
|
+
interlatent.egg-info/requires.txt
|
|
12
|
+
interlatent.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
gymnasium==1.2.2
|
|
2
|
+
stable_baselines3==2.7.0
|
|
3
|
+
matplotlib==3.6.0
|
|
4
|
+
pydantic==2.12.5
|
|
5
|
+
pytest==7.1.1
|
|
6
|
+
h5py==3.11.0
|
|
7
|
+
datasets==4.4.1
|
|
8
|
+
transformers>=4.57.3
|
|
9
|
+
accelerate>=0.26.0
|
|
10
|
+
torch==2.9.1
|
|
11
|
+
torchvision==0.24.1
|
|
12
|
+
numpy==1.26.4
|
|
13
|
+
|
|
14
|
+
[dev]
|
|
15
|
+
gymnasium
|
|
16
|
+
pytest
|
|
17
|
+
|
|
18
|
+
[dev:platform_machine != "arm64"]
|
|
19
|
+
torch==2.2.2
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
interlatent
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=64", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "interlatent"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Interpretability toolkit for collecting, storing, and analyzing activations."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.11"
|
|
11
|
+
license = {text = "MIT"}
|
|
12
|
+
authors = [
|
|
13
|
+
{name = "Interlatent Contributors"}
|
|
14
|
+
]
|
|
15
|
+
classifiers = [
|
|
16
|
+
"Programming Language :: Python :: 3",
|
|
17
|
+
"Programming Language :: Python :: 3.11",
|
|
18
|
+
"License :: OSI Approved :: MIT License",
|
|
19
|
+
"Operating System :: OS Independent",
|
|
20
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
21
|
+
]
|
|
22
|
+
dependencies = [
|
|
23
|
+
"gymnasium==1.2.2",
|
|
24
|
+
"stable_baselines3==2.7.0",
|
|
25
|
+
"matplotlib==3.6.0",
|
|
26
|
+
"pydantic==2.12.5",
|
|
27
|
+
"pytest==7.1.1",
|
|
28
|
+
"h5py==3.11.0",
|
|
29
|
+
"datasets==4.4.1",
|
|
30
|
+
"transformers>=4.57.3",
|
|
31
|
+
"accelerate>=0.26.0",
|
|
32
|
+
"torch==2.9.1",
|
|
33
|
+
"torchvision==0.24.1",
|
|
34
|
+
"numpy==1.26.4",
|
|
35
|
+
]
|
|
36
|
+
|
|
37
|
+
[project.urls]
|
|
38
|
+
Homepage = "https://github.com/seanpixel/interlatent"
|
|
39
|
+
Issues = "https://github.com/seanpixel/interlatent/issues"
|
|
40
|
+
|
|
41
|
+
[project.optional-dependencies]
|
|
42
|
+
dev = [
|
|
43
|
+
"torch==2.2.2 ; platform_machine != 'arm64'",
|
|
44
|
+
"gymnasium",
|
|
45
|
+
"pytest",
|
|
46
|
+
]
|
|
47
|
+
|
|
48
|
+
[tool.setuptools.packages.find]
|
|
49
|
+
include = ["interlatent"]
|
|
50
|
+
exclude = ["tests", "scripts", "runs", "keys", "artifacts", "pretrained"]
|