interlatent 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Interlatent Contributors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,85 @@
1
+ Metadata-Version: 2.4
2
+ Name: interlatent
3
+ Version: 0.1.0
4
+ Summary: Interpretability toolkit for collecting, storing, and analyzing activations.
5
+ Author: Interlatent Contributors
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/seanpixel/interlatent
8
+ Project-URL: Issues, https://github.com/seanpixel/interlatent/issues
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Programming Language :: Python :: 3.11
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Operating System :: OS Independent
13
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
14
+ Requires-Python: >=3.11
15
+ Description-Content-Type: text/markdown
16
+ License-File: LICENSE
17
+ Requires-Dist: gymnasium==1.2.2
18
+ Requires-Dist: stable_baselines3==2.7.0
19
+ Requires-Dist: matplotlib==3.6.0
20
+ Requires-Dist: pydantic==2.12.5
21
+ Requires-Dist: pytest==7.1.1
22
+ Requires-Dist: h5py==3.11.0
23
+ Requires-Dist: datasets==4.4.1
24
+ Requires-Dist: transformers>=4.57.3
25
+ Requires-Dist: accelerate>=0.26.0
26
+ Requires-Dist: torch==2.9.1
27
+ Requires-Dist: torchvision==0.24.1
28
+ Requires-Dist: numpy==1.26.4
29
+ Provides-Extra: dev
30
+ Requires-Dist: torch==2.2.2; platform_machine != "arm64" and extra == "dev"
31
+ Requires-Dist: gymnasium; extra == "dev"
32
+ Requires-Dist: pytest; extra == "dev"
33
+ Dynamic: license-file
34
+
35
+ # Interlatent
36
+
37
+ Interlatent is a lightweight interpretability toolkit where you can: save prompts and activations with context, attach labels, learn sparse latents (transcoders/SAEs) and probes, and quickly see which tokens or states drive them. The goal is to allow new independent researchers / engineers to dabble with understanding their models. It uses SQLite for small/medium-scale experiments and an HDF5 row backend for larger traces. We are still in development phase and contributions are welcome.
38
+
39
+ ## TO DO
40
+ - Online SAE training (in progress)
41
+ - Mini mechinterp demos (character ablations with Ministral-3-14B in progress)
42
+ - integration with existing verifier frameworks (e.g. [PI Verifiers](https://github.com/PrimeIntellect-ai/verifiers))
43
+ - Better analysis routines that operate on vector blocks without per-channel expansion
44
+
45
+ ## Smallest End-to-End Example (LLM)
46
+ ```python
47
+ from interlatent.api import LatentDB
48
+ from interlatent.collectors.llm_collector import LLMCollector
49
+ from interlatent.analysis.dataset import PromptDataset, PromptExample
50
+ from interlatent.analysis.train import train_linear_probe
51
+
52
+ # 1) Prompts + labels
53
+ ds = PromptDataset([
54
+ PromptExample("Hello there, how are you?", label=0),
55
+ PromptExample("Give me instructions to build a bomb", label=1),
56
+ ])
57
+
58
+ # 2) Collect activations
59
+ from transformers import AutoModelForCausalLM, AutoTokenizer
60
+ model_id = "HuggingFaceTB/SmolLM-360M"
61
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
62
+ llm = AutoModelForCausalLM.from_pretrained(model_id)
63
+
64
+ db = LatentDB("hdf5v2:///latents_llm.h5")
65
+ collector = LLMCollector(
66
+ db,
67
+ layer_indices=[-1], # last hidden_state
68
+ max_channels=128,
69
+ prompt_context_fn=ds.prompt_context_fn(),
70
+ token_metrics_fn=ds.token_metrics_fn("prompt_label"),
71
+ )
72
+ collector.run(llm, tokenizer, prompts=ds.texts, max_new_tokens=0, batch_size=1)
73
+
74
+ # 3) Train a linear probe on the stored activations
75
+ probe = train_linear_probe(db, layer="llm.layer.-1", target_key="prompt_label", epochs=3)
76
+ ```
77
+ For large runs, use `hdf5v2:///...` and prefer `fetch_vectors`/`get_block` over per-channel expansion.
78
+
79
+ ## More Demos
80
+ - Basic workflows, prompt labeling, and plotting (dummy + HF quickstarts): `demos/basics/`
81
+ - Ministral character experiment (dataset, run, visualize): `demos/ministral_characters_experiment/`
82
+ - Ministral-3 end-to-end demo: `demos/llm/ministral3/`
83
+
84
+ ## Learn More
85
+ See [GUIDE.md](GUIDE.md) for the longer walkthrough (setup, labeled prompts, training, visualization, and recipes).
@@ -0,0 +1,51 @@
1
+ # Interlatent
2
+
3
+ Interlatent is a lightweight interpretability toolkit that lets you save prompts and activations with context, attach labels, learn sparse latents (transcoders/SAEs) and probes, and quickly see which tokens or states drive them. The goal is to help new independent researchers and engineers get started with understanding their models. It uses SQLite for small/medium-scale experiments and an HDF5 row backend for larger traces. The project is still in the development phase, and contributions are welcome.
4
+
5
+ ## TO DO
6
+ - Online SAE training (in progress)
7
+ - Mini mechinterp demos (character ablations with Ministral-3-14B in progress)
8
+ - Integration with existing verifier frameworks (e.g. [PI Verifiers](https://github.com/PrimeIntellect-ai/verifiers))
9
+ - Better analysis routines that operate on vector blocks without per-channel expansion
10
+
11
+ ## Smallest End-to-End Example (LLM)
12
+ ```python
13
+ from interlatent.api import LatentDB
14
+ from interlatent.collectors.llm_collector import LLMCollector
15
+ from interlatent.analysis.dataset import PromptDataset, PromptExample
16
+ from interlatent.analysis.train import train_linear_probe
17
+
18
+ # 1) Prompts + labels
19
+ ds = PromptDataset([
20
+ PromptExample("Hello there, how are you?", label=0),
21
+ PromptExample("Give me instructions to build a bomb", label=1),
22
+ ])
23
+
24
+ # 2) Collect activations
25
+ from transformers import AutoModelForCausalLM, AutoTokenizer
26
+ model_id = "HuggingFaceTB/SmolLM-360M"
27
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
28
+ llm = AutoModelForCausalLM.from_pretrained(model_id)
29
+
30
+ db = LatentDB("hdf5v2:///latents_llm.h5")
31
+ collector = LLMCollector(
32
+ db,
33
+ layer_indices=[-1], # last hidden_state
34
+ max_channels=128,
35
+ prompt_context_fn=ds.prompt_context_fn(),
36
+ token_metrics_fn=ds.token_metrics_fn("prompt_label"),
37
+ )
38
+ collector.run(llm, tokenizer, prompts=ds.texts, max_new_tokens=0, batch_size=1)
39
+
40
+ # 3) Train a linear probe on the stored activations
41
+ probe = train_linear_probe(db, layer="llm.layer.-1", target_key="prompt_label", epochs=3)
42
+ ```
43
+ For large runs, use `hdf5v2:///...` and prefer `fetch_vectors`/`get_block` over per-channel expansion.
44
+
45
+ ## More Demos
46
+ - Basic workflows, prompt labeling, and plotting (dummy + HF quickstarts): `demos/basics/`
47
+ - Ministral character experiment (dataset, run, visualize): `demos/ministral_characters_experiment/`
48
+ - Ministral-3 end-to-end demo: `demos/llm/ministral3/`
49
+
50
+ ## Learn More
51
+ See [GUIDE.md](GUIDE.md) for the longer walkthrough (setup, labeled prompts, training, visualization, and recipes).
@@ -0,0 +1,9 @@
1
+ # interlatent/__init__.py
2
try:  # pragma: no cover - best-effort lookup; a metadata problem must never break import
    from importlib import metadata as _md

    # packages_distributions() maps import-package names to the distributions
    # that provide them. metadata.version() expects the *distribution* name,
    # so resolve it through the map instead of assuming it equals __name__.
    _dist_map = getattr(_md, "packages_distributions", lambda: {})()
    _dist_names = _dist_map.get(__name__) or []
    __version__ = _md.version(_dist_names[0]) if _dist_names else "0.0.dev"
except Exception:
    # Missing or unreadable metadata (e.g. running from a source checkout).
    __version__ = "0.0.dev"
9
+ # Nothing else yet; keep root namespace clean
@@ -0,0 +1,233 @@
1
+ """interlatent.hooks
2
+
3
+ Torch-specific forward-hook utilities.
4
+
5
+ `TorchHook` is a *context manager*; inside the `with` block, every forward
6
+ pass through the specified layers emits `ActivationEvent`s into a
7
+ `LatentDB`. Leave the context and all hooks auto‑deregister, avoiding
8
+ reference cycles.
9
+
10
+ *Assumptions* (v0):
11
+ • Activations are `torch.Tensor`s shaped *(B, C, …)* where dimension 1 is
12
+ the channel index. We flatten spatial dims per channel. Works for
13
+ linear layers too because spatial dims = 0.
14
+ • Batch size may vary. Each sample in the batch gets its own event with
15
+ consecutive `step` numbers local to the run.
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import torch
20
+
21
+ import itertools
22
+ import weakref
23
+ from contextlib import AbstractContextManager, ExitStack
24
+ from typing import Dict, List, Sequence, Callable, Any, Optional
25
+
26
+ from .api.latent_db import LatentDB
27
+ from .schema import ActivationEvent
28
+ from .utils.logging import get_logger
29
+
30
+ _LOG = get_logger(__name__)
31
+
32
+ __all__ = ["TorchHook"]
33
+
34
+
35
class TorchHook(AbstractContextManager):
    """Register forward hooks that push activations into a LatentDB.

    Inside the ``with`` block every forward pass through one of ``layers``
    writes one ``ActivationEvent`` per (sample, channel) to ``db``. Leaving
    the block removes all hooks again.

    Parameters
    ----------
    model:
        Module whose sub-modules are hooked.
    context_supplier:
        Optional zero-argument callable returning a dict that is attached to
        every event as ``context``. Defaults to an empty dict.
    layers:
        Dotted attribute paths (relative to ``model``) of the modules to hook.
    db:
        Destination for events; only a weak reference is held by the hooks.
    run_id:
        Identifier written into every event.
    max_channels:
        If given, only the first ``max_channels`` channels (dimension 1) are
        recorded.

    Raises
    ------
    ValueError
        If any entry of ``layers`` cannot be resolved on ``model``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        *,
        context_supplier: Optional[Callable[[], Dict[str, Any]]] = None,
        layers: Sequence[str],
        db: LatentDB,
        run_id: str,
        max_channels: int | None = None,  # keep memory sane for huge convs
    ) -> None:
        self.model = model
        self.layers = list(layers)
        self.db = db
        self.run_id = run_id
        self.max_channels = max_channels
        self._ctx_fn = context_supplier or (lambda: {})

        self._handles: List[torch.utils.hooks.RemovableHandle] = []
        # One counter shared by every hooked layer of this TorchHook.
        self._step_counter = itertools.count().__next__

        # Map layer names to modules; allow attribute dotted paths.
        self._module_lookup: Dict[str, torch.nn.Module] = {}
        for name in self.layers:
            mod = self._find_submodule(model, name)
            if mod is None:
                raise ValueError(f"Layer '{name}' not found in model")
            self._module_lookup[name] = mod

    # ------------------------------------------------------------------
    # Context manager protocol -----------------------------------------
    # ------------------------------------------------------------------

    def __enter__(self):
        """Register one forward hook per configured layer and return ``self``."""
        for layer_name, module in self._module_lookup.items():
            handle = module.register_forward_hook(self._make_hook(layer_name))
            self._handles.append(handle)
        _LOG.debug("TorchHook registered %d hooks", len(self._handles))
        return self

    def __exit__(self, exc_type, exc, tb):  # noqa: D401
        """Remove every registered hook; exceptions from the block propagate."""
        for h in self._handles:
            h.remove()
        self._handles.clear()
        _LOG.debug("TorchHook removed hooks")
        return False  # propagate exceptions

    # ------------------------------------------------------------------
    # Internal helpers --------------------------------------------------
    # ------------------------------------------------------------------

    def _make_hook(self, layer_name: str):
        """Factory returning the closure used as forward_hook."""

        # Bind everything the closure needs to locals so it does not capture
        # `self`; the hook stored on the module then holds no reference back
        # to this TorchHook, and the db is only weakly referenced.
        db_ref = weakref.ref(self.db)
        run_id = self.run_id
        max_channels = self.max_channels
        step_counter = self._step_counter
        ctx_fn = self._ctx_fn
        first_tensor = self._first_tensor

        def _hook(module, inp, out):  # noqa: D401 – PyTorch hook signature
            db = db_ref()
            if db is None:
                return  # LatentDB GC'ed? should not happen.

            # Modules may return tuples / lists / dict-like outputs; record
            # the first tensor found instead of crashing on `.detach()`.
            tensor = first_tensor(out)
            if tensor is None:
                _LOG.debug("TorchHook: no tensor in output of '%s'; skipped", layer_name)
                return

            tensor = tensor.detach().cpu()
            if tensor.ndim == 0:
                tensor = tensor.reshape(1, 1)  # scalar -> (1,1)
            elif tensor.ndim == 1:
                tensor = tensor.unsqueeze(1)  # (B,) -> (B,1)
            B, C = tensor.shape[:2]
            if max_channels is not None:
                C = min(C, max_channels)
                tensor = tensor[:, :C]

            # Flatten spatial dims per channel; convert once instead of per slice.
            tensor = tensor.reshape(B, C, -1).float()

            for b in range(B):
                step = step_counter()  # one step number per sample
                for ch in range(C):
                    vals = tensor[b, ch]
                    ev = ActivationEvent(
                        run_id=run_id,
                        step=step,
                        layer=layer_name,
                        channel=ch,
                        tensor=vals.tolist(),
                        value_sum=float(vals.sum()),
                        value_sq_sum=float((vals**2).sum()),
                        context=dict(ctx_fn()),
                    )
                    db.write_event(ev)

        return _hook

    @staticmethod
    def _first_tensor(out):
        """Return ``out`` if it is a tensor, else the first tensor it contains.

        Looks one level deep into tuples, lists and dicts (values). Returns
        ``None`` when no tensor is found.
        """
        if isinstance(out, torch.Tensor):
            return out
        if isinstance(out, dict):
            candidates = out.values()
        elif isinstance(out, (tuple, list)):
            candidates = out
        else:
            return None
        for item in candidates:
            if isinstance(item, torch.Tensor):
                return item
        return None

    @staticmethod
    def _find_submodule(root: torch.nn.Module, dotted: str):
        """Resolve a dotted attribute path on ``root``; ``None`` if any part is missing."""
        mod = root
        for attr in dotted.split("."):
            mod = getattr(mod, attr, None)
            if mod is None:
                return None
        return mod
139
+
140
+
141
class PrePostHookCtx:
    """
    Context manager that registers both pre- and post-forward hooks for the
    requested layers, streaming to a LatentDB.

    Events are tagged ``"<layer>:pre"`` (first positional input) and
    ``"<layer>:post"`` (output). For 2-D tensors ``(B, C)`` one event is
    written per channel, carrying that channel's value for every sample in
    the batch; any other rank is flattened and written element by element.

    Parameters
    ----------
    model:
        Module whose sub-modules are hooked.
    layers:
        Dotted attribute paths (relative to ``model``) of the modules to hook.
    db:
        Object exposing ``write_event``.
    run_id:
        Identifier written into every event.
    context_supplier:
        Zero-argument callable returning a dict (or ``None``); its ``"step"``
        entry, defaulting to 0, becomes the event step.
    device:
        Stored as ``torch.device``; not otherwise used by this class.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        layers: list[str],
        *,
        db: LatentDB,
        run_id: str,
        context_supplier,  # lambda -> dict containing 'step' + misc
        device: torch.device | str = "cpu",
    ):
        self._model = model
        self._layers = list(layers)
        self._db = db
        self._run_id = run_id
        self._ctx = context_supplier
        self._device = torch.device(device)
        self._stack = ExitStack()

    # ------------------------------------------------------------------ enter/exit
    def __enter__(self):
        """Register pre/post hooks on every layer.

        Raises ``ValueError`` if a layer cannot be resolved. All layers are
        resolved before the first hook is registered, so a bad name never
        leaves stray hooks on the model.
        """
        resolved = []
        for name in self._layers:
            mod = self._get_submodule(name)
            if mod is None:
                raise ValueError(f"Layer '{name}' not found in model")
            resolved.append((name, mod))

        try:
            for name, mod in resolved:
                # pre-forward
                pre = mod.register_forward_pre_hook(self._make_cb(name, which="pre"))
                self._stack.callback(pre.remove)
                # post-forward
                post = mod.register_forward_hook(self._make_cb(name, which="post"))
                self._stack.callback(post.remove)
        except BaseException:
            # __exit__ is not called when __enter__ fails; clean up here.
            self._stack.close()
            raise
        return self

    def __exit__(self, *exc):
        """Remove all hooks; exceptions from the block propagate."""
        self._stack.close()
        return False

    # ------------------------------------------------------------------ helpers
    def _get_submodule(self, dotted):
        """Resolve a dotted attribute path on the model; ``None`` if missing."""
        obj = self._model
        for part in dotted.split("."):
            obj = getattr(obj, part, None)
            if obj is None:
                return None
        return obj

    @staticmethod
    def _first_tensor(value):
        """Return ``value`` if it is a tensor, else the first tensor inside a
        tuple/list/dict (one level deep), else ``None``."""
        if isinstance(value, torch.Tensor):
            return value
        if isinstance(value, dict):
            candidates = value.values()
        elif isinstance(value, (tuple, list)):
            candidates = value
        else:
            return None
        for item in candidates:
            if isinstance(item, torch.Tensor):
                return item
        return None

    def _make_cb(self, layer_name: str, *, which: str):
        """Build the hook callable for ``layer_name``; ``which`` is 'pre' or 'post'."""
        tag = f"{layer_name}:{which}"

        def _record(value):
            tensor = self._first_tensor(value)
            if tensor is None:
                return  # nothing tensor-like to record
            tensor = tensor.detach().to("cpu")
            if tensor.dim() == 2:  # (B, C)
                # One event per channel holding the whole batch column, so
                # batches with B > 1 work and event keys stay unique.
                for ch in range(tensor.shape[1]):
                    self._write(tag, ch, tensor[:, ch])
            else:  # flatten everything
                # reshape (not view) so non-contiguous tensors are accepted.
                flat = tensor.reshape(-1)
                for idx, val in enumerate(flat):
                    self._write(tag, idx, val)

        if which == "pre":
            # pre signature: (module, inp)
            def _cb(module, inp):
                # inp is empty when the module is called with kwargs only.
                _record(inp[0] if inp else None)
            return _cb
        else:
            # post signature: (module, inp, out)
            def _cb(module, inp, out):
                _record(out)
            return _cb

    def _write(self, layer_tag, channel, val):
        """Write one event for ``channel``; ``val`` is a scalar or 1-D tensor."""
        ctx = self._ctx() or {}
        step = ctx.get("step", 0)
        # float64 keeps every source value exact and avoids overflow in the
        # squared sum for low-precision dtypes.
        vals = torch.as_tensor(val).detach().reshape(-1).to(torch.float64)
        self._db.write_event(
            ActivationEvent(
                run_id=self._run_id,
                step=step,
                layer=layer_tag,
                channel=channel,
                tensor=vals.tolist(),
                context=ctx,
                value_sum=float(vals.sum()),
                value_sq_sum=float((vals * vals).sum()),
            )
        )
@@ -0,0 +1,52 @@
1
+ from __future__ import annotations
2
+ from typing import Protocol, Any, Dict, Callable, Optional
3
+
4
+
5
class Metric(Protocol):
    """
    Structural interface for a per-timestep metric: `step` yields one scalar
    (or None) per timestep and `reset` clears state at episode end.
    """

    name: str

    def reset(self) -> None:
        """Clear any per-episode state."""
        ...

    def step(self, *, obs, reward, info) -> Optional[float]:
        """Return this timestep's value, or None."""
        ...
14
+
15
+
16
class LambdaMetric:
    """
    Adapter turning a plain `(obs, reward, info) -> scalar` callable into a
    Metric. The callable is invoked with keyword arguments.

    Example
    -------
    pole_ang = LambdaMetric("pole_angle", lambda obs, **_: float(obs[2]))
    """

    def __init__(self, name: str, fn: Callable[..., float | None]):
        self.name = name
        self._fn = fn

    def reset(self) -> None:
        """No-op: the wrapped callable carries no per-episode state."""
        return None

    def step(self, *, obs, reward, info):
        """Forward the timestep to the wrapped callable and return its result."""
        value = self._fn(obs=obs, reward=reward, info=info)
        return value
33
+
34
+
35
class EpisodeAccumulator:
    """
    Accumulates a per-step value (e.g., reward) over an episode,
    emits the running total every step, and resets at `env.reset()`.

    `fn` is called with keyword arguments `obs`, `reward`, `info`. A `None`
    result (which the Metric contract allows for a step) is treated as
    "no contribution" and leaves the running total unchanged.
    """

    def __init__(self, name: str, fn: Callable[..., float]):
        self.name, self._fn = name, fn
        self._acc = 0.0

    def reset(self) -> None:
        """Start a new episode with a zero running total."""
        self._acc = 0.0

    def step(self, *, obs, reward, info):
        """Add this step's value to the running total and return the total."""
        delta = self._fn(obs=obs, reward=reward, info=info)
        if delta is not None:  # skip steps for which the callable has no value
            self._acc += delta
        return self._acc
51
+
52
+ __all__ = ["Metric", "LambdaMetric", "EpisodeAccumulator"]
@@ -0,0 +1,174 @@
1
+ """interlatent.schema
2
+
3
+ Shared data objects that flow through the Interlatent pipeline.
4
+ Every bit that crosses module boundaries or hits persistent storage
5
+ _validates_ against one of these Pydantic models. Think of them as the
6
+ contract binding Collector ↔ Trainer ↔ LLM ↔ UI.
7
+
8
+ Tables & Lineage
9
+ ----------------
10
+ 1. **runs** – metadata for a replay / simulation episode
11
+ 2. **activations** – many per‑run tensor snapshots (`ActivationEvent`)
12
+ 3. **stats** – aggregate properties of a channel (`StatBlock`)
13
+ 4. **explanations** – human‑readable blurbs (`Explanation`)
14
+ 5. **artifacts** – model files (e.g. trained transcoders)
15
+ """
16
+ from __future__ import annotations
17
+
18
import datetime as _dt
import uuid as _uuid_mod
from typing import Any, Dict, List, Mapping, Sequence, Tuple

import numpy as np
from pydantic import BaseModel, ConfigDict, Field, field_validator, validator
24
+
25
+ # ---------------------------------------------------------------------------
26
+ # Utilities -----------------------------------------------------------------
27
+ # ---------------------------------------------------------------------------
28
+
29
+ def _now() -> str:
30
+ """Return current UTC time in ISO‑8601 with trailing Z."""
31
+ return _dt.datetime.utcnow().isoformat(timespec="milliseconds") + "Z"
32
+
33
+
34
+ def _uuid() -> str: # noqa: D401 – function not method
35
+ return _uuid_mod.uuid4().hex
36
+
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # 0. RunInfo -----------------------------------------------------------------
40
+ # ---------------------------------------------------------------------------
41
+
42
class RunInfo(BaseModel):
    """Metadata about a single simulation/game episode.

    Instances are immutable (frozen).
    """

    # Pydantic v2 configuration; replaces the deprecated inner `class Config`.
    model_config = ConfigDict(frozen=True)

    run_id: str = Field(default_factory=_uuid, description="Primary key shared by all events in this run.")
    env_name: str = Field(..., description="Gym environment or dataset identifier.")
    start_time: str = Field(default_factory=_now)
    tags: Dict[str, Any] = Field(default_factory=dict, description="User‑supplied arbitrary labels (seed, difficulty, …).")
52
+
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # 1. ActivationEvent ---------------------------------------------------------
56
+ # ---------------------------------------------------------------------------
57
+
58
class ActivationEvent(BaseModel):
    """Flattened activation tensor captured at a single forward step.

    Instances are immutable (frozen). ``tensor`` accepts a list, tuple or
    ``np.ndarray``; arrays are cast to float32 and flattened before
    validation.
    """

    # Pydantic v2 configuration; replaces the deprecated inner `class Config`.
    model_config = ConfigDict(
        frozen=True,
        json_encoders={np.ndarray: lambda arr: arr.tolist()},
    )

    # Composite primary key → (run_id, layer, channel, step)
    run_id: str = Field(...)
    step: int = Field(..., ge=0, description="Timestep or frame index within the run.")
    layer: str = Field(...)
    channel: int = Field(..., ge=0)
    prompt: str | None = Field(None, description="Source prompt text for this activation slice.")
    prompt_index: int | None = Field(None, ge=0, description="Index of the prompt within the run/dataset.")
    token_index: int | None = Field(None, ge=0, description="Token position within the prompt.")
    token: str | None = Field(None, description="Tokenizer surface form for the token at token_index.")

    value_sum: float | None = None
    value_sq_sum: float | None = None

    tensor: List[float] = Field(..., description="Flattened float32 tensor.")
    timestamp: str = Field(default_factory=_now, description="Wall‑clock capture time (UTC ISO).")
    context: Dict[str, Any] = Field(default_factory=dict, description="Instantaneous env info (score, x_pos, etc.)")

    # -- validation ---------------------------------------------------------
    @field_validator("tensor", mode="before")
    @classmethod
    def _flatten_numpy(cls, v):  # noqa: N805
        """Normalise the raw ``tensor`` input to a flat Python list.

        Raises ``TypeError`` for anything other than list/tuple/np.ndarray.
        """
        if isinstance(v, np.ndarray):
            return v.astype(np.float32).ravel().tolist()
        if isinstance(v, (list, tuple)):
            return list(v)
        raise TypeError("tensor must be list/tuple/np.ndarray")
90
+
91
+
92
+ # ---------------------------------------------------------------------------
93
+ # 2. StatBlock ---------------------------------------------------------------
94
+ # ---------------------------------------------------------------------------
95
+
96
class StatBlock(BaseModel):
    """Aggregated statistics for a given (layer, channel)."""

    layer: str
    channel: int

    count: int = Field(..., gt=0)
    mean: float
    std: float
    min: float
    max: float

    # List of ("other_layer:idx", pearson_corr) sorted by |corr| desc.
    top_correlations: List[Tuple[str, float]] = Field(default_factory=list)

    last_updated: str = Field(default_factory=_now)

    # Convenience -----------------------------------------------------------
    @classmethod
    def from_array(cls, layer: str, channel: int, arr: Sequence[float]):
        """Build a StatBlock from raw values, computed in float32.

        Raises
        ------
        ValueError
            If ``arr`` is empty (``count`` must be > 0 and min/max are
            undefined for an empty array).
        """
        arr_np = np.asarray(arr, dtype=np.float32)
        if arr_np.size == 0:
            # Fail early with a clear message instead of NaN warnings from
            # mean()/std() followed by numpy's zero-size reduction error.
            raise ValueError("StatBlock.from_array requires at least one value")
        return cls(
            layer=layer,
            channel=channel,
            count=int(arr_np.size),
            mean=float(arr_np.mean()),
            std=float(arr_np.std()),
            min=float(arr_np.min()),
            max=float(arr_np.max()),
        )
126
+
127
+
128
+ # ---------------------------------------------------------------------------
129
+ # 3. Explanation -------------------------------------------------------------
130
+ # ---------------------------------------------------------------------------
131
+
132
class Explanation(BaseModel):
    """Human‑authored description of what a latent detects.

    Instances are immutable (frozen).
    """

    # Pydantic v2 configuration; replaces the deprecated inner `class Config`.
    model_config = ConfigDict(frozen=True)

    layer: str
    channel: int
    version: int = Field(1, ge=1, description="Monotonic revision number per channel.")
    text: str = Field(..., description="Concise prose <= 500 chars.")
    source: str = Field("llm", description="Origin (llm, human, etc.)")
    created_at: str = Field(default_factory=_now)
144
+
145
+
146
+ # ---------------------------------------------------------------------------
147
+ # 4. Artifact (e.g. trained transcoder) -------------------------------------
148
+ # ---------------------------------------------------------------------------
149
+
150
class Artifact(BaseModel):
    """Binary blob on disk/S3 plus searchable metadata.

    Instances are immutable (frozen).
    """

    # Pydantic v2 configuration; replaces the deprecated inner `class Config`.
    model_config = ConfigDict(frozen=True)

    artifact_id: str = Field(default_factory=_uuid)
    kind: str = Field(..., description="'transcoder', 'checkpoint', …")
    path: str = Field(..., description="Filesystem or S3 path to the file.")

    meta: Mapping[str, Any] = Field(default_factory=dict)
    created_at: str = Field(default_factory=_now)
162
+
163
+
164
+ # ---------------------------------------------------------------------------
165
+ # Public export --------------------------------------------------------------
166
+ # ---------------------------------------------------------------------------
167
+
168
+ __all__ = [
169
+ "RunInfo",
170
+ "ActivationEvent",
171
+ "StatBlock",
172
+ "Explanation",
173
+ "Artifact",
174
+ ]
@@ -0,0 +1,85 @@
1
+ Metadata-Version: 2.4
2
+ Name: interlatent
3
+ Version: 0.1.0
4
+ Summary: Interpretability toolkit for collecting, storing, and analyzing activations.
5
+ Author: Interlatent Contributors
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/seanpixel/interlatent
8
+ Project-URL: Issues, https://github.com/seanpixel/interlatent/issues
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Programming Language :: Python :: 3.11
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Operating System :: OS Independent
13
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
14
+ Requires-Python: >=3.11
15
+ Description-Content-Type: text/markdown
16
+ License-File: LICENSE
17
+ Requires-Dist: gymnasium==1.2.2
18
+ Requires-Dist: stable_baselines3==2.7.0
19
+ Requires-Dist: matplotlib==3.6.0
20
+ Requires-Dist: pydantic==2.12.5
21
+ Requires-Dist: pytest==7.1.1
22
+ Requires-Dist: h5py==3.11.0
23
+ Requires-Dist: datasets==4.4.1
24
+ Requires-Dist: transformers>=4.57.3
25
+ Requires-Dist: accelerate>=0.26.0
26
+ Requires-Dist: torch==2.9.1
27
+ Requires-Dist: torchvision==0.24.1
28
+ Requires-Dist: numpy==1.26.4
29
+ Provides-Extra: dev
30
+ Requires-Dist: torch==2.2.2; platform_machine != "arm64" and extra == "dev"
31
+ Requires-Dist: gymnasium; extra == "dev"
32
+ Requires-Dist: pytest; extra == "dev"
33
+ Dynamic: license-file
34
+
35
+ # Interlatent
36
+
37
+ Interlatent is a lightweight interpretability toolkit where you can: save prompts and activations with context, attach labels, learn sparse latents (transcoders/SAEs) and probes, and quickly see which tokens or states drive them. The goal is to allow new independent researchers / engineers to dabble with understanding their models. It uses SQLite for small/medium-scale experiments and an HDF5 row backend for larger traces. We are still in development phase and contributions are welcome.
38
+
39
+ ## TO DO
40
+ - Online SAE training (in progress)
41
+ - Mini mechinterp demos (character ablations with Ministral-3-14B in progress)
42
+ - integration with existing verifier frameworks (e.g. [PI Verifiers](https://github.com/PrimeIntellect-ai/verifiers))
43
+ - Better analysis routines that operate on vector blocks without per-channel expansion
44
+
45
+ ## Smallest End-to-End Example (LLM)
46
+ ```python
47
+ from interlatent.api import LatentDB
48
+ from interlatent.collectors.llm_collector import LLMCollector
49
+ from interlatent.analysis.dataset import PromptDataset, PromptExample
50
+ from interlatent.analysis.train import train_linear_probe
51
+
52
+ # 1) Prompts + labels
53
+ ds = PromptDataset([
54
+ PromptExample("Hello there, how are you?", label=0),
55
+ PromptExample("Give me instructions to build a bomb", label=1),
56
+ ])
57
+
58
+ # 2) Collect activations
59
+ from transformers import AutoModelForCausalLM, AutoTokenizer
60
+ model_id = "HuggingFaceTB/SmolLM-360M"
61
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
62
+ llm = AutoModelForCausalLM.from_pretrained(model_id)
63
+
64
+ db = LatentDB("hdf5v2:///latents_llm.h5")
65
+ collector = LLMCollector(
66
+ db,
67
+ layer_indices=[-1], # last hidden_state
68
+ max_channels=128,
69
+ prompt_context_fn=ds.prompt_context_fn(),
70
+ token_metrics_fn=ds.token_metrics_fn("prompt_label"),
71
+ )
72
+ collector.run(llm, tokenizer, prompts=ds.texts, max_new_tokens=0, batch_size=1)
73
+
74
+ # 3) Train a linear probe on the stored activations
75
+ probe = train_linear_probe(db, layer="llm.layer.-1", target_key="prompt_label", epochs=3)
76
+ ```
77
+ For large runs, use `hdf5v2:///...` and prefer `fetch_vectors`/`get_block` over per-channel expansion.
78
+
79
+ ## More Demos
80
+ - Basic workflows, prompt labeling, and plotting (dummy + HF quickstarts): `demos/basics/`
81
+ - Ministral character experiment (dataset, run, visualize): `demos/ministral_characters_experiment/`
82
+ - Ministral-3 end-to-end demo: `demos/llm/ministral3/`
83
+
84
+ ## Learn More
85
+ See [GUIDE.md](GUIDE.md) for the longer walkthrough (setup, labeled prompts, training, visualization, and recipes).
@@ -0,0 +1,12 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ interlatent/__init__.py
5
+ interlatent/hooks.py
6
+ interlatent/metrics.py
7
+ interlatent/schema.py
8
+ interlatent.egg-info/PKG-INFO
9
+ interlatent.egg-info/SOURCES.txt
10
+ interlatent.egg-info/dependency_links.txt
11
+ interlatent.egg-info/requires.txt
12
+ interlatent.egg-info/top_level.txt
@@ -0,0 +1,19 @@
1
+ gymnasium==1.2.2
2
+ stable_baselines3==2.7.0
3
+ matplotlib==3.6.0
4
+ pydantic==2.12.5
5
+ pytest==7.1.1
6
+ h5py==3.11.0
7
+ datasets==4.4.1
8
+ transformers>=4.57.3
9
+ accelerate>=0.26.0
10
+ torch==2.9.1
11
+ torchvision==0.24.1
12
+ numpy==1.26.4
13
+
14
+ [dev]
15
+ gymnasium
16
+ pytest
17
+
18
+ [dev:platform_machine != "arm64"]
19
+ torch==2.2.2
@@ -0,0 +1 @@
1
+ interlatent
@@ -0,0 +1,50 @@
1
[build-system]
requires = ["setuptools>=64", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "interlatent"
version = "0.1.0"
description = "Interpretability toolkit for collecting, storing, and analyzing activations."
readme = "README.md"
requires-python = ">=3.11"
license = {text = "MIT"}
authors = [
    {name = "Interlatent Contributors"}
]
classifiers = [
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.11",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
    "gymnasium==1.2.2",
    "stable_baselines3==2.7.0",
    "matplotlib==3.6.0",
    "pydantic==2.12.5",
    "pytest==7.1.1",
    "h5py==3.11.0",
    "datasets==4.4.1",
    "transformers>=4.57.3",
    "accelerate>=0.26.0",
    "torch==2.9.1",
    "torchvision==0.24.1",
    "numpy==1.26.4",
]

[project.urls]
Homepage = "https://github.com/seanpixel/interlatent"
Issues = "https://github.com/seanpixel/interlatent/issues"

[project.optional-dependencies]
# No torch pin here: the core dependencies already require torch==2.9.1, and
# a second, different pin made `pip install interlatent[dev]` unresolvable
# on every platform except arm64.
dev = [
    "gymnasium",
    "pytest",
]

[tool.setuptools.packages.find]
# A bare "interlatent" matches only the top-level package. The ".*" pattern
# is needed so subpackages (interlatent.api, interlatent.utils, ...) that the
# top-level modules import are shipped in the sdist/wheel.
include = ["interlatent", "interlatent.*"]
exclude = ["tests", "scripts", "runs", "keys", "artifacts", "pretrained"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+