archscope-0.2.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- archscope/__init__.py +30 -0
- archscope/_utils.py +113 -0
- archscope/attribute.py +201 -0
- archscope/backends.py +236 -0
- archscope/bench.py +262 -0
- archscope/circuits.py +255 -0
- archscope/cli.py +120 -0
- archscope/diff.py +212 -0
- archscope/kazdov_backend.py +141 -0
- archscope/lens.py +304 -0
- archscope/neurons.py +118 -0
- archscope/probes.py +160 -0
- archscope/sae.py +127 -0
- archscope/transfer.py +188 -0
- archscope-0.2.2.dist-info/METADATA +324 -0
- archscope-0.2.2.dist-info/RECORD +20 -0
- archscope-0.2.2.dist-info/WHEEL +5 -0
- archscope-0.2.2.dist-info/entry_points.txt +2 -0
- archscope-0.2.2.dist-info/licenses/LICENSE +17 -0
- archscope-0.2.2.dist-info/top_level.txt +1 -0
archscope/diff.py
ADDED
@@ -0,0 +1,212 @@
"""Model Diff — compare base model vs fine-tuned version, find what changed.

This is the practical question every fine-tuner asks:
    "I fine-tuned X on my data. What did it change in the model's brain?"

Given (base_model, fine_tuned_model) with the same architecture, archscope.diff
returns a structured ModelDiff with:
- Per-layer activation drift (how much each layer's residual stream moved)
- Top-K shifted neurons per layer
- Circuit score deltas (induction, copy, etc.)
- Optional: probe direction drift (do trained probes still work?)

Requirements:
- base and fine_tuned must share architecture + tokenizer
- A small calibration_texts list (16-100 short texts) to measure drift on

Usage:
    from archscope.diff import compare
    result = compare(base, finetuned, tokenizer, calibration_texts, backend_hint="transformer")
    print(result.to_markdown())
"""
from __future__ import annotations
from dataclasses import dataclass, field
import torch

from .backends import Backend
from . import circuits


@dataclass
class LayerDrift:
    """How much one layer's residual stream changed under fine-tuning."""
    layer: int
    layer_name: str
    mean_l2_delta: float      # mean L2 norm of (a_ft - a_base) per token
    relative_drift: float     # mean_l2_delta / mean ||a_base||
    cosine_similarity: float  # mean cos(a_base, a_ft) — closer to 1 = less change
    top_shifted_neurons: list[tuple[int, float]] = field(default_factory=list)


@dataclass
class CircuitDelta:
    """How a single circuit score changed."""
    name: str
    base_score: float
    fine_tuned_score: float
    delta: float            # ft - base
    relative_change: float  # delta / |base| (capped at large values)


@dataclass
class ModelDiff:
    """Full diff between base and fine-tuned model."""
    arch_family: str
    n_layers: int
    n_calibration_texts: int
    layer_drift: list[LayerDrift] = field(default_factory=list)
    circuit_deltas: list[CircuitDelta] = field(default_factory=list)
    notes: list[str] = field(default_factory=list)

    def top_changed_layers(self, k: int = 3) -> list[LayerDrift]:
        """Return the k layers with highest relative drift."""
        return sorted(self.layer_drift, key=lambda d: -d.relative_drift)[:k]

    def to_markdown(self) -> str:
        lines = [f"# Model Diff — {self.arch_family} ({self.n_layers} layers)\n"]
        lines.append(f"Calibration texts: {self.n_calibration_texts}\n")

        # Layer drift table
        lines.append("## Per-layer residual drift\n")
        lines.append("| Layer | mean ‖Δa‖ | relative | cosine sim |")
        lines.append("|------:|---------:|---------:|----------:|")
        for d in self.layer_drift:
            lines.append(
                f"| {d.layer:>5} | {d.mean_l2_delta:>8.3f} | {d.relative_drift:>7.3%} | {d.cosine_similarity:>9.3f} |"
            )

        # Top changed layers
        lines.append("\n## Top-3 most-changed layers\n")
        top = self.top_changed_layers(3)
        for d in top:
            lines.append(f"- **layer {d.layer}**: relative drift {d.relative_drift:.1%}, cosine {d.cosine_similarity:.3f}")
            if d.top_shifted_neurons:
                idxs = ", ".join(f"#{i}({delta:+.2f})" for i, delta in d.top_shifted_neurons[:5])
                lines.append(f"  - top-5 neurons shifted: {idxs}")

        # Circuit deltas
        if self.circuit_deltas:
            lines.append("\n## Circuit deltas (induction, copy, concentration)\n")
            lines.append("| Circuit | base | fine-tuned | Δ | Δ% |")
            lines.append("|---------|-----:|-----------:|--:|---:|")
            for c in self.circuit_deltas:
                pct = f"{c.relative_change:+.1%}" if abs(c.base_score) > 1e-9 else "—"
                lines.append(
                    f"| {c.name} | {c.base_score:.3f} | {c.fine_tuned_score:.3f} | {c.delta:+.3f} | {pct} |"
                )

        if self.notes:
            lines.append("\n## Notes\n")
            for n in self.notes:
                lines.append(f"- {n}")
        return "\n".join(lines)


def _activation_drift(
    base_acts: torch.Tensor,
    ft_acts: torch.Tensor,
    top_k_neurons: int = 10,
) -> tuple[float, float, float, list[tuple[int, float]]]:
    """Compute drift stats between two activation tensors of shape (B, T, H) or (B, H).

    Returns (mean_l2_delta, relative_drift, cosine_sim, top_k_shifted_neurons).
    """
    if base_acts.shape != ft_acts.shape:
        raise ValueError(f"Shape mismatch: {tuple(base_acts.shape)} vs {tuple(ft_acts.shape)}")
    # Flatten leading dims; keep hidden dim
    H = base_acts.shape[-1]
    a = base_acts.reshape(-1, H)
    b = ft_acts.reshape(-1, H)
    diff = (b - a).float()
    l2_per_token = diff.norm(dim=-1)
    mean_l2 = float(l2_per_token.mean().item())
    base_l2 = float(a.float().norm(dim=-1).mean().item()) + 1e-9
    relative = mean_l2 / base_l2
    # Cosine sim (per-token, then averaged)
    cos = torch.nn.functional.cosine_similarity(a.float(), b.float(), dim=-1)
    cosine_sim = float(cos.mean().item())
    # Per-neuron drift: mean absolute delta per channel
    per_neuron_delta = diff.abs().mean(dim=0)  # (H,)
    topk = torch.topk(per_neuron_delta, k=min(top_k_neurons, H))
    top_shifted = [(int(i), float(v)) for i, v in zip(topk.indices.tolist(), topk.values.tolist())]
    return mean_l2, relative, cosine_sim, top_shifted


def compare(
    base_model,
    fine_tuned_model,
    tokenizer,
    calibration_texts: list[str],
    backend_hint: str | None = None,
    run_circuits: bool = True,
    max_length: int = 32,
    device: str = "cpu",
) -> ModelDiff:
    """Compare base vs fine-tuned model on calibration texts.

    Both models must share the same architecture and tokenizer.
    """
    base_backend = Backend.for_model(base_model, hint=backend_hint)
    ft_backend = Backend.for_model(fine_tuned_model, hint=backend_hint)

    layer_names = [n for n in base_backend.layer_names() if ".residual" in n]
    ft_layer_names = [n for n in ft_backend.layer_names() if ".residual" in n]
    if layer_names != ft_layer_names:
        raise ValueError("base and fine_tuned have different layer structure — "
                         "they must share architecture")

    # Tokenize calibration
    enc = tokenizer(calibration_texts, return_tensors="pt", padding=True,
                    truncation=True, max_length=max_length)
    inputs = {"input_ids": enc["input_ids"].to(device)}
    if "attention_mask" in enc:
        inputs["attention_mask"] = enc["attention_mask"].to(device)

    # Need attention_mask as bool for kazdov backend
    if backend_hint == "kazdov" and "attention_mask" in inputs:
        inputs["attention_mask"] = inputs["attention_mask"].bool()

    with torch.no_grad():
        base_records = base_backend.extract(inputs, layers=layer_names)
        ft_records = ft_backend.extract(inputs, layers=layer_names)

    diff = ModelDiff(
        arch_family=backend_hint or "auto",
        n_layers=len(layer_names),
        n_calibration_texts=len(calibration_texts),
    )

    for base_rec, ft_rec in zip(base_records, ft_records):
        mean_l2, rel, cos_sim, top_neurons = _activation_drift(
            base_rec.activations, ft_rec.activations,
        )
        idx = int(base_rec.layer_name.split("_")[1].split(".")[0])
        diff.layer_drift.append(LayerDrift(
            layer=idx, layer_name=base_rec.layer_name,
            mean_l2_delta=mean_l2,
            relative_drift=rel,
            cosine_similarity=cos_sim,
            top_shifted_neurons=top_neurons,
        ))

    # Circuit deltas
    if run_circuits:
        try:
            base_circs = circuits.run_all_circuits(base_model, tokenizer=tokenizer, device=device)
            ft_circs = circuits.run_all_circuits(fine_tuned_model, tokenizer=tokenizer, device=device)
            for name in base_circs:
                if name not in ft_circs:
                    continue
                bs = base_circs[name].score
                fs = ft_circs[name].score
                delta = fs - bs
                rel = delta / (abs(bs) + 1e-9)
                diff.circuit_deltas.append(CircuitDelta(
                    name=name,
                    base_score=bs, fine_tuned_score=fs,
                    delta=delta, relative_change=rel,
                ))
        except Exception as e:
            diff.notes.append(f"circuit comparison error: {str(e)[:80]}")

    return diff
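
Usage sketch (not part of the wheel): a minimal way to drive compare() above, assuming an HF GPT-2 pair and the "transformer" backend hint shown in the module docstring. The fine-tune path and calibration texts are illustrative placeholders.

from transformers import AutoModelForCausalLM, AutoTokenizer
from archscope.diff import compare

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
base = AutoModelForCausalLM.from_pretrained("gpt2").eval()
finetuned = AutoModelForCausalLM.from_pretrained("./my-gpt2-finetune").eval()  # hypothetical local fine-tune

calibration_texts = [  # the docstring suggests 16-100 short texts
    "The capital of France is",
    "def fibonacci(n):",
    "Once upon a time,",
]

result = compare(base, finetuned, tokenizer, calibration_texts,
                 backend_hint="transformer", device="cpu")
print(result.to_markdown())  # per-layer drift table + circuit deltas
for d in result.top_changed_layers(3):
    print(d.layer, f"{d.relative_drift:.1%}", f"{d.cosine_similarity:.3f}")
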
archscope/kazdov_backend.py
ADDED
@@ -0,0 +1,141 @@
"""Backend for kazdov-α (and related Kazdov family models).

Kazdov-α is a transformer-style decoder LM with hybrid attention (MoBE-BCN
mixture of bilinear experts + standard MHA in parallel). Architecturally
closer to a standard transformer than to pure RNN/SSM — but the BCN attention
branch makes it a distinct architecture family for cross-arch interp.

Differences from HF transformer:
- No HF AutoModelForCausalLM interface (custom forward signature)
- Layers exposed as `model.blocks` (ModuleList)
- No `output_hidden_states=True` argument — we capture via forward hooks
- Forward signature: (input_ids, attention_mask=None, labels=None)
"""
from __future__ import annotations
import sys
from pathlib import Path
import torch

from .backends import Backend, ActivationRecord


KAZDOV_REPO = Path.home() / "code" / "OriginalKazdov" / "kazdov"


def _ensure_kazdov_importable():
    """Add kazdov repo to sys.path so we can import KazdovLM."""
    p = str(KAZDOV_REPO)
    if p not in sys.path:
        sys.path.insert(0, p)


def load_kazdov_checkpoint(checkpoint_path: str | Path, device: str = "cpu"):
    """Load kazdov-α from a checkpoint directory.

    Expects: config.json + final.pt (or latest.pt) in the directory.
    Returns: (model in eval mode, tokenizer wrapper).
    """
    _ensure_kazdov_importable()
    from kazdov.kazdov_lm import KazdovLM
    import json

    ckpt_dir = Path(checkpoint_path)
    config = json.loads((ckpt_dir / "config.json").read_text())
    model_cfg = config["model_cfg"]

    model = KazdovLM(
        vocab_size=model_cfg["vocab_size"],
        d_model=model_cfg["d_model"],
        n_layers=model_cfg["n_layers"],
        n_heads=model_cfg["n_heads"],
        rank=model_cfg["rank"],
        mlp_dim=model_cfg.get("mlp_dim"),
        max_len=model_cfg.get("max_len", 256),
        use_trilinear=model_cfg.get("use_trilinear", False),
        use_bi_bcn=model_cfg.get("use_bi_bcn", False),
        use_hybrid_mha=model_cfg.get("use_hybrid_mha", True),
        use_mobe=model_cfg.get("use_mobe", False),
        n_experts=model_cfg.get("n_experts", 1),
    )

    # Try final.pt then latest.pt
    for fname in ("final.pt", "latest.pt"):
        f = ckpt_dir / fname
        if f.exists():
            state = torch.load(f, map_location=device, weights_only=False)
            if isinstance(state, dict) and "model" in state:
                state = state["model"]
            model.load_state_dict(state, strict=False)
            break
    else:
        raise FileNotFoundError(f"No final.pt or latest.pt in {ckpt_dir}")

    model.to(device).eval()

    # Tokenizer: kazdov was trained with the GPT-2 tokenizer
    from transformers import GPT2Tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer


@Backend.register("kazdov")
class KazdovBackend(Backend):
    """Backend for kazdov-family models (KazdovLM, MoBE-BCN variants).

    Uses forward hooks to capture residual stream after each KazdovBlock,
    since the model doesn't expose output_hidden_states.
    """

    def layer_names(self) -> list[str]:
        n_layers = len(self.model.blocks)
        return [f"layer_{i}.residual" for i in range(n_layers)]

    def extract(self, inputs, layers=None):
        layers = layers or self.layer_names()
        captures: dict[str, torch.Tensor] = {}

        # Register a forward hook on each requested block.
        hooks = []
        for layer_name in layers:
            idx = int(layer_name.split("_")[1].split(".")[0])
            if idx >= len(self.model.blocks):
                continue
            block = self.model.blocks[idx]

            def make_hook(name):
                def hook(module, inp, out):
                    tensor = out if isinstance(out, torch.Tensor) else out[0]
                    captures[name] = tensor.detach()
                return hook
            hooks.append(block.register_forward_hook(make_hook(layer_name)))

        try:
            # Kazdov forward signature: model(input_ids, attention_mask=None)
            with torch.no_grad():
                if isinstance(inputs, dict):
                    input_ids = inputs["input_ids"]
                    attn = inputs.get("attention_mask")
                else:
                    input_ids = inputs
                    attn = None
                self.model(input_ids, attention_mask=attn)
        finally:
            for h in hooks:
                h.remove()

        records = []
        for layer_name in layers:
            if layer_name not in captures:
                continue
            records.append(ActivationRecord(
                layer_name=layer_name,
                activations=captures[layer_name],
                meta={"kind": "residual", "arch": "kazdov-mobe-bcn"},
            ))
        return records

    def hidden_dim(self, layer_name: str) -> int:
        return self.model.d_model
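
Usage sketch (not part of the wheel): loading a kazdov checkpoint and pulling residual activations through the backend above. The checkpoint directory is hypothetical; Backend.for_model(..., hint="kazdov") mirrors how archscope.diff.compare() selects backends, and the bool attention mask follows the note in diff.py.

from archscope.kazdov_backend import load_kazdov_checkpoint
from archscope.backends import Backend

model, tokenizer = load_kazdov_checkpoint("checkpoints/kazdov-alpha")  # hypothetical dir with config.json + final.pt
backend = Backend.for_model(model, hint="kazdov")

enc = tokenizer(["The quick brown fox jumps"], return_tensors="pt", padding=True)
inputs = {
    "input_ids": enc["input_ids"],
    "attention_mask": enc["attention_mask"].bool(),  # kazdov expects a bool mask
}
records = backend.extract(inputs, layers=backend.layer_names()[:4])
for rec in records:
    print(rec.layer_name, tuple(rec.activations.shape))  # (B, T, d_model)
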
archscope/lens.py
ADDED
@@ -0,0 +1,304 @@
"""Logit lens and Tuned lens — project intermediate residual stream into vocab space.

The logit lens (Nostalgebraist, 2020) applies the model's own final norm + unembedding
to every layer's residual stream, revealing "what the model would predict if forced
to commit at this layer." It tells you which layers do the work of forming the final
prediction.

The tuned lens (Belrose et al., 2023) learns per-layer affine transformations that
correct for representation drift, producing a more faithful intermediate decoding.

Both work on any architecture exposing residual stream via Backend.extract — i.e.
transformer, mamba, kazdov, custom recurrent.
"""
from __future__ import annotations
from dataclasses import dataclass, field
import torch
import torch.nn as nn

from .backends import Backend
from ._utils import resolve_unembedding, resolve_final_norm


@dataclass
class LayerPrediction:
    """What the lens says at one layer."""
    layer: int
    layer_name: str
    top_tokens: list[tuple[int, str, float]]  # (token_id, token_str, prob)
    target_prob: float | None = None  # if target_token provided, prob at this layer
    target_rank: int | None = None    # rank of target token in this layer's distribution
    entropy: float = 0.0              # confidence (low = peaked, high = uncertain)


@dataclass
class LensResult:
    """Output of a lens pass over multiple layers."""
    prompt: str
    target_token: str | None = None
    target_token_id: int | None = None
    layers: list[LayerPrediction] = field(default_factory=list)
    method: str = "logit_lens"

    def to_markdown(self) -> str:
        """Quick formatted display."""
        lines = [f"### {self.method} on `{self.prompt}`"]
        if self.target_token:
            lines.append(f"Target: `{self.target_token}` (id={self.target_token_id})")
        lines.append("")
        lines.append("| Layer | top-1 token | top-1 prob | target prob | rank | entropy |")
        lines.append("|-------|-------------|-----------:|------------:|-----:|--------:|")
        for lp in self.layers:
            if lp.top_tokens:
                top_id, top_str, top_p = lp.top_tokens[0]
            else:
                top_str, top_p = "?", float("nan")
            tgt_p = f"{lp.target_prob:.3f}" if lp.target_prob is not None else "—"
            tgt_r = f"{lp.target_rank}" if lp.target_rank is not None else "—"
            lines.append(
                f"| {lp.layer:>2} | `{top_str[:20]}` | {top_p:.3f} | {tgt_p} | {tgt_r} | {lp.entropy:.3f} |"
            )
        return "\n".join(lines)


def _logits_to_pred(
    logits: torch.Tensor,
    tokenizer,
    target_id: int | None = None,
    top_k: int = 5,
) -> tuple[list[tuple[int, str, float]], float | None, int | None, float]:
    """Convert logit vector → (top_k_tokens, target_prob, target_rank, entropy)."""
    probs = torch.softmax(logits.float(), dim=-1)
    top_p, top_i = probs.topk(top_k)
    top_tokens = []
    for p, i in zip(top_p.tolist(), top_i.tolist()):
        try:
            tok_str = tokenizer.decode([i])
        except Exception:
            tok_str = f"<id={i}>"
        top_tokens.append((i, tok_str, p))
    if target_id is not None:
        target_prob = float(probs[target_id].item())
        target_rank = int((probs > probs[target_id]).sum().item())
    else:
        target_prob = None
        target_rank = None
    # Shannon entropy in nats
    entropy = float(-(probs.clamp(min=1e-12) * probs.clamp(min=1e-12).log()).sum().item())
    return top_tokens, target_prob, target_rank, entropy


def logit_lens(
    model,
    tokenizer,
    prompt: str,
    target_token: str | None = None,
    layers: list[int] | None = None,
    backend_hint: str | None = None,
    top_k: int = 5,
    device: str = "cpu",
) -> LensResult:
    """Apply model's own final-norm + unembedding to each layer's residual stream.

    Returns per-layer predictions: top-k tokens + entropy + (optional) target rank/prob.

    The classic question this answers: "at which layer does the model commit to the
    final token?" Often you'll see entropy fall and the target token's rank improve across layers.
    """
    backend = Backend.for_model(model, hint=backend_hint)
    norm = resolve_final_norm(model)
    unembed = resolve_unembedding(model)
    if unembed is None:
        raise ValueError("Could not locate model's lm_head / unembedding module.")

    # Tokenize prompt (handle both HF tokenizer + kazdov-style inputs)
    enc = tokenizer(prompt, return_tensors="pt")
    if hasattr(enc, "input_ids"):
        inputs = {"input_ids": enc.input_ids.to(device)}
    else:
        inputs = {"input_ids": enc["input_ids"].to(device)}

    target_id = None
    if target_token is not None:
        target_ids = tokenizer(target_token, add_special_tokens=False).input_ids
        target_id = target_ids[0] if target_ids else None

    # Pick layers
    all_layer_names = [n for n in backend.layer_names() if ".residual" in n]
    if layers is None:
        layer_names = all_layer_names
    else:
        layer_names = [f"layer_{i}.residual" for i in layers if f"layer_{i}.residual" in all_layer_names]

    # Extract residual streams
    with torch.no_grad():
        records = backend.extract(inputs, layers=layer_names)

    result = LensResult(
        prompt=prompt,
        target_token=target_token,
        target_token_id=target_id,
        method="logit_lens",
    )

    for rec in records:
        # rec.activations shape: (B, T, hidden) — take last-token residual
        last = rec.activations[:, -1, :]  # (B, hidden)
        # Apply final norm if present (transformer family). For kazdov it's `model.ln_f`.
        if norm is not None:
            try:
                last = norm(last)
            except Exception:
                pass
        # Project to vocab
        logits = unembed(last)[0]  # (vocab,)
        top_tokens, tgt_p, tgt_r, ent = _logits_to_pred(logits, tokenizer, target_id, top_k)
        idx = int(rec.layer_name.split("_")[1].split(".")[0])
        result.layers.append(LayerPrediction(
            layer=idx, layer_name=rec.layer_name,
            top_tokens=top_tokens,
            target_prob=tgt_p, target_rank=tgt_r,
            entropy=ent,
        ))

    return result


# -------------------- TUNED LENS --------------------

class TunedLens(nn.Module):
    """Learned per-layer affine transformations into the unembedding space.

    For each layer, learn a transformation: residual → adjusted_residual that, when
    fed through the model's final_norm + unembedding, better matches the final-layer
    distribution. Following Belrose et al. 2023.

    Usage:
        tl = TunedLens.fit(model, tokenizer, calibration_texts, backend_hint="transformer")
        result = tl.predict(model, tokenizer, prompt="...", backend_hint="transformer")
    """

    def __init__(self, model, n_layers: int, hidden_dim: int, backend_hint: str | None = None):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.backend_hint = backend_hint
        # One Linear per layer; identity initialization (so untrained tuned lens ≈ logit lens)
        self.translators = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim, bias=True)
            for _ in range(n_layers)
        ])
        for t in self.translators:
            nn.init.eye_(t.weight)
            nn.init.zeros_(t.bias)
        self._model_ref = [model]  # avoid registering as submodule

    @classmethod
    def fit(
        cls,
        model,
        tokenizer,
        calibration_texts: list[str],
        backend_hint: str | None = None,
        epochs: int = 30,
        lr: float = 1e-3,
        max_len: int = 64,
        device: str = "cpu",
    ) -> "TunedLens":
        """Train per-layer translators to match the model's own final distribution."""
        backend = Backend.for_model(model, hint=backend_hint)
        layer_names = [n for n in backend.layer_names() if ".residual" in n]
        n_layers = len(layer_names)
        # Get hidden dim from first layer
        hidden_dim = backend.hidden_dim(layer_names[0])

        tl = cls(model, n_layers=n_layers, hidden_dim=hidden_dim, backend_hint=backend_hint).to(device)
        norm = resolve_final_norm(model)
        unembed = resolve_unembedding(model)

        opt = torch.optim.AdamW(tl.translators.parameters(), lr=lr)

        # Pre-extract all activations + target logits once
        enc = tokenizer(calibration_texts, return_tensors="pt", padding=True,
                        truncation=True, max_length=max_len)
        inputs = {"input_ids": enc["input_ids"].to(device)}
        with torch.no_grad():
            records = backend.extract(inputs, layers=layer_names)
        # Target: model's actual final logits at last position
        final_residual = records[-1].activations[:, -1, :]
        if norm is not None:
            final_residual = norm(final_residual)
        target_logits = unembed(final_residual).detach()  # (B, vocab)
        target_log_probs = torch.log_softmax(target_logits.float(), dim=-1)

        for epoch in range(epochs):
            opt.zero_grad()
            total_loss = 0.0
            for i, rec in enumerate(records):
                last = rec.activations[:, -1, :].detach()
                translated = tl.translators[i](last)
                if norm is not None:
                    translated = norm(translated)
                pred_logits = unembed(translated)
                pred_log_probs = torch.log_softmax(pred_logits.float(), dim=-1)
                # KL divergence target_log_probs || pred_log_probs
                loss = torch.nn.functional.kl_div(pred_log_probs, target_log_probs,
                                                  log_target=True, reduction="batchmean")
                total_loss = total_loss + loss
            total_loss.backward()
            opt.step()

        tl.last_loss = float(total_loss.item() / max(n_layers, 1))
        return tl

    def predict(
        self,
        model,
        tokenizer,
        prompt: str,
        target_token: str | None = None,
        backend_hint: str | None = None,
        layers: list[int] | None = None,
        top_k: int = 5,
        device: str = "cpu",
    ) -> LensResult:
        """Apply the trained tuned lens to a new prompt."""
        backend = Backend.for_model(model, hint=backend_hint or self.backend_hint)
        norm = resolve_final_norm(model)
        unembed = resolve_unembedding(model)

        enc = tokenizer(prompt, return_tensors="pt")
        inputs = {"input_ids": enc["input_ids"].to(device)}
        target_id = None
        if target_token is not None:
            ids = tokenizer(target_token, add_special_tokens=False).input_ids
            target_id = ids[0] if ids else None

        all_layer_names = [n for n in backend.layer_names() if ".residual" in n]
        if layers is None:
            layer_names = all_layer_names
        else:
            layer_names = [f"layer_{i}.residual" for i in layers if f"layer_{i}.residual" in all_layer_names]

        with torch.no_grad():
            records = backend.extract(inputs, layers=layer_names)

        result = LensResult(prompt=prompt, target_token=target_token,
                            target_token_id=target_id, method="tuned_lens")

        for rec in records:
            idx = int(rec.layer_name.split("_")[1].split(".")[0])
            last = rec.activations[:, -1, :]
            with torch.no_grad():
                translated = self.translators[idx](last)
                if norm is not None:
                    translated = norm(translated)
                logits = unembed(translated)[0]
            top_tokens, tgt_p, tgt_r, ent = _logits_to_pred(logits, tokenizer, target_id, top_k)
            result.layers.append(LayerPrediction(
                layer=idx, layer_name=rec.layer_name,
                top_tokens=top_tokens,
                target_prob=tgt_p, target_rank=tgt_r,
                entropy=ent,
            ))
        return result
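
Usage sketch (not part of the wheel): both lenses above on an HF GPT-2, following the docstrings' "transformer" hint. The prompt, target token, and calibration texts are illustrative; logit_lens needs no training, while TunedLens.fit learns one affine translator per layer before predict() can be reused on new prompts.

from transformers import AutoModelForCausalLM, AutoTokenizer
from archscope.lens import logit_lens, TunedLens

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

# Zero-shot lens: project each layer's residual through ln_f + lm_head.
res = logit_lens(model, tokenizer, prompt="The Eiffel Tower is in",
                 target_token=" Paris", backend_hint="transformer")
print(res.to_markdown())

# Trained lens: fit translators once on calibration texts, then reuse.
tokenizer.pad_token = tokenizer.eos_token  # padding needed for batched calibration
texts = ["The sky is", "Water boils at", "Paris is the capital of"]  # toy list; more is better
tl = TunedLens.fit(model, tokenizer, texts, backend_hint="transformer", epochs=30)
res2 = tl.predict(model, tokenizer, prompt="The Eiffel Tower is in", target_token=" Paris")
print(res2.to_markdown())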