archscope-0.2.2-py3-none-any.whl

archscope/diff.py ADDED
@@ -0,0 +1,212 @@
+ """Model Diff — compare base model vs fine-tuned version, find what changed.
+
+ This is the practical question every fine-tuner asks:
+ "I fine-tuned X on my data. What did it change in the model's brain?"
+
+ Given (base_model, fine_tuned_model) with the same architecture, archscope.diff
+ returns a structured ModelDiff with:
+ - Per-layer activation drift (how much each layer's residual stream moved)
+ - Top-K shifted neurons per layer
+ - Circuit score deltas (induction, copy, etc.)
+ - Optional: probe direction drift (do trained probes still work?)
+
+ Requirements:
+ - base and fine_tuned must share architecture + tokenizer
+ - A small calibration_texts list (16-100 short texts) to measure drift on
+
+ Usage:
+     from archscope.diff import compare
+     result = compare(base, finetuned, tokenizer, calibration_texts, backend_hint="transformer")
+     print(result.to_markdown())
+ """
+ from __future__ import annotations
+ from dataclasses import dataclass, field
+ import torch
+
+ from .backends import Backend
+ from . import circuits
+
+
+ @dataclass
+ class LayerDrift:
+     """How much one layer's residual stream changed under fine-tuning."""
+     layer: int
+     layer_name: str
+     mean_l2_delta: float      # mean L2 norm of (a_ft - a_base) per token
+     relative_drift: float     # mean_l2_delta / mean ||a_base||
+     cosine_similarity: float  # mean cos(a_base, a_ft) — closer to 1 = less change
+     top_shifted_neurons: list[tuple[int, float]] = field(default_factory=list)
+
+
+ @dataclass
+ class CircuitDelta:
+     """How a single circuit score changed."""
+     name: str
+     base_score: float
+     fine_tuned_score: float
+     delta: float            # ft - base
+     relative_change: float  # delta / (|base| + eps); unstable when base is near 0
+
+
+ @dataclass
+ class ModelDiff:
+     """Full diff between base and fine-tuned model."""
+     arch_family: str
+     n_layers: int
+     n_calibration_texts: int
+     layer_drift: list[LayerDrift] = field(default_factory=list)
+     circuit_deltas: list[CircuitDelta] = field(default_factory=list)
+     notes: list[str] = field(default_factory=list)
+
+     def top_changed_layers(self, k: int = 3) -> list[LayerDrift]:
+         """Return the k layers with the highest relative drift."""
+         return sorted(self.layer_drift, key=lambda d: -d.relative_drift)[:k]
+
+     def to_markdown(self) -> str:
+         lines = [f"# Model Diff — {self.arch_family} ({self.n_layers} layers)\n"]
+         lines.append(f"Calibration texts: {self.n_calibration_texts}\n")
+
+         # Layer drift table
+         lines.append("## Per-layer residual drift\n")
+         lines.append("| Layer | mean ‖Δa‖ | relative | cosine sim |")
+         lines.append("|------:|---------:|---------:|----------:|")
+         for d in self.layer_drift:
+             lines.append(
+                 f"| {d.layer:>5} | {d.mean_l2_delta:>8.3f} | {d.relative_drift:>7.3%} | {d.cosine_similarity:>9.3f} |"
+             )
+
+         # Top changed layers
+         lines.append("\n## Top-3 most-changed layers\n")
+         top = self.top_changed_layers(3)
+         for d in top:
+             lines.append(f"- **layer {d.layer}**: relative drift {d.relative_drift:.1%}, cosine {d.cosine_similarity:.3f}")
+             if d.top_shifted_neurons:
+                 idxs = ", ".join(f"#{i}({delta:+.2f})" for i, delta in d.top_shifted_neurons[:5])
+                 lines.append(f"  - top-5 shifted neurons: {idxs}")
+
+         # Circuit deltas
+         if self.circuit_deltas:
+             lines.append("\n## Circuit deltas (induction, copy, concentration)\n")
+             lines.append("| Circuit | base | fine-tuned | Δ | Δ% |")
+             lines.append("|---------|-----:|-----------:|--:|---:|")
+             for c in self.circuit_deltas:
+                 pct = f"{c.relative_change:+.1%}" if abs(c.base_score) > 1e-9 else "—"
+                 lines.append(
+                     f"| {c.name} | {c.base_score:.3f} | {c.fine_tuned_score:.3f} | {c.delta:+.3f} | {pct} |"
+                 )
+
+         if self.notes:
+             lines.append("\n## Notes\n")
+             for n in self.notes:
+                 lines.append(f"- {n}")
+         return "\n".join(lines)
+
+
+ def _activation_drift(
+     base_acts: torch.Tensor,
+     ft_acts: torch.Tensor,
+     top_k_neurons: int = 10,
+ ) -> tuple[float, float, float, list[tuple[int, float]]]:
+     """Compute drift stats between two activation tensors of shape (B, T, H) or (B, H).
+
+     Returns (mean_l2_delta, relative_drift, cosine_sim, top_k_shifted_neurons).
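+
+     Worked example (illustrative numbers, not from a real model): for a single
+     token with a_base = [3, 4] and a_ft = [3, 5], we get Δa = [0, 1], so
+     mean_l2_delta = 1.0, mean ||a_base|| = 5.0, relative_drift = 0.2,
+     and cosine_sim = 29 / (5 · √34) ≈ 0.995.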
+     """
+     if base_acts.shape != ft_acts.shape:
+         raise ValueError(f"Shape mismatch: {tuple(base_acts.shape)} vs {tuple(ft_acts.shape)}")
+     # Flatten leading dims; keep hidden dim
+     H = base_acts.shape[-1]
+     a = base_acts.reshape(-1, H)
+     b = ft_acts.reshape(-1, H)
+     diff = (b - a).float()
+     l2_per_token = diff.norm(dim=-1)
+     mean_l2 = float(l2_per_token.mean().item())
+     base_l2 = float(a.float().norm(dim=-1).mean().item()) + 1e-9
+     relative = mean_l2 / base_l2
+     # Cosine sim (per-token, then averaged)
+     cos = torch.nn.functional.cosine_similarity(a.float(), b.float(), dim=-1)
+     cosine_sim = float(cos.mean().item())
+     # Per-neuron drift: mean absolute delta per channel
+     per_neuron_delta = diff.abs().mean(dim=0)  # (H,)
+     topk = torch.topk(per_neuron_delta, k=min(top_k_neurons, H))
+     top_shifted = [(int(i), float(v)) for i, v in zip(topk.indices.tolist(), topk.values.tolist())]
+     return mean_l2, relative, cosine_sim, top_shifted
+
+
+ def compare(
+     base_model,
+     fine_tuned_model,
+     tokenizer,
+     calibration_texts: list[str],
+     backend_hint: str | None = None,
+     run_circuits: bool = True,
+     max_length: int = 32,
+     device: str = "cpu",
+ ) -> ModelDiff:
+     """Compare base vs fine-tuned model on calibration texts.
+
+     Both models must share the same architecture and tokenizer.
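+
+     Example sketch (model names are illustrative placeholders, assuming two
+     HF checkpoints that share a tokenizer):
+         from transformers import AutoModelForCausalLM, AutoTokenizer
+         base = AutoModelForCausalLM.from_pretrained("gpt2")
+         ft = AutoModelForCausalLM.from_pretrained("./gpt2-finetuned")
+         tok = AutoTokenizer.from_pretrained("gpt2")
+         tok.pad_token = tok.eos_token  # GPT-2 has no pad token by default
+         d = compare(base, ft, tok, ["The quick brown fox."] * 16)
+         print(d.to_markdown())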
+     """
+     base_backend = Backend.for_model(base_model, hint=backend_hint)
+     ft_backend = Backend.for_model(fine_tuned_model, hint=backend_hint)
+
+     layer_names = [n for n in base_backend.layer_names() if ".residual" in n]
+     ft_layer_names = [n for n in ft_backend.layer_names() if ".residual" in n]
+     if layer_names != ft_layer_names:
+         raise ValueError("base and fine_tuned have different layer structure — "
+                          "they must share the same architecture")
+
+     # Tokenize calibration texts
+     enc = tokenizer(calibration_texts, return_tensors="pt", padding=True,
+                     truncation=True, max_length=max_length)
+     inputs = {"input_ids": enc["input_ids"].to(device)}
+     if "attention_mask" in enc:
+         inputs["attention_mask"] = enc["attention_mask"].to(device)
+
+     # The kazdov backend expects attention_mask as bool
+     if backend_hint == "kazdov" and "attention_mask" in inputs:
+         inputs["attention_mask"] = inputs["attention_mask"].bool()
+
+     with torch.no_grad():
+         base_records = base_backend.extract(inputs, layers=layer_names)
+         ft_records = ft_backend.extract(inputs, layers=layer_names)
+
+     diff = ModelDiff(
+         arch_family=backend_hint or "auto",
+         n_layers=len(layer_names),
+         n_calibration_texts=len(calibration_texts),
+     )
+
+     # Drop pad positions so the drift stats cover only real tokens.
+     mask = inputs.get("attention_mask")
+     for base_rec, ft_rec in zip(base_records, ft_records):
+         a, b = base_rec.activations, ft_rec.activations
+         if mask is not None and a.dim() == 3:
+             keep = mask.reshape(-1).bool()
+             a = a.reshape(-1, a.shape[-1])[keep]
+             b = b.reshape(-1, b.shape[-1])[keep]
+         mean_l2, rel, cos_sim, top_neurons = _activation_drift(a, b)
+         idx = int(base_rec.layer_name.split("_")[1].split(".")[0])
+         diff.layer_drift.append(LayerDrift(
+             layer=idx, layer_name=base_rec.layer_name,
+             mean_l2_delta=mean_l2,
+             relative_drift=rel,
+             cosine_similarity=cos_sim,
+             top_shifted_neurons=top_neurons,
+         ))
+
+     # Circuit deltas
+     if run_circuits:
+         try:
+             base_circs = circuits.run_all_circuits(base_model, tokenizer=tokenizer, device=device)
+             ft_circs = circuits.run_all_circuits(fine_tuned_model, tokenizer=tokenizer, device=device)
+             for name in base_circs:
+                 if name not in ft_circs:
+                     continue
+                 bs = base_circs[name].score
+                 fs = ft_circs[name].score
+                 delta = fs - bs
+                 rel = delta / (abs(bs) + 1e-9)
+                 diff.circuit_deltas.append(CircuitDelta(
+                     name=name,
+                     base_score=bs, fine_tuned_score=fs,
+                     delta=delta, relative_change=rel,
+                 ))
+         except Exception as e:
+             diff.notes.append(f"circuit comparison error: {str(e)[:80]}")
+
+     return diff
@@ -0,0 +1,141 @@
+ """Backend for kazdov-α (and related Kazdov-family models).
+
+ Kazdov-α is a transformer-style decoder LM with hybrid attention (MoBE-BCN
+ mixture of bilinear experts + standard MHA in parallel). Architecturally it is
+ closer to a standard transformer than to a pure RNN/SSM — but the BCN attention
+ branch makes it a distinct architecture family for cross-arch interp.
+
+ Differences from an HF transformer:
+ - No HF AutoModelForCausalLM interface (custom forward signature)
+ - Layers exposed as `model.blocks` (ModuleList)
+ - No `output_hidden_states=True` argument — we capture via forward hooks
+ - Forward signature: (input_ids, attention_mask=None, labels=None)
+ """
+ from __future__ import annotations
+ import json
+ import sys
+ from pathlib import Path
+ import torch
+
+ from .backends import Backend, ActivationRecord
+
+
+ KAZDOV_REPO = Path.home() / "code" / "OriginalKazdov" / "kazdov"
+
+
+ def _ensure_kazdov_importable():
+     """Add the kazdov repo to sys.path so we can import KazdovLM."""
+     p = str(KAZDOV_REPO)
+     if p not in sys.path:
+         sys.path.insert(0, p)
+
+
+ def load_kazdov_checkpoint(checkpoint_path: str | Path, device: str = "cpu"):
+     """Load kazdov-α from a checkpoint directory.
+
+     Expects: config.json + final.pt (or latest.pt) in the directory.
+     Returns: (model in eval mode, tokenizer wrapper).
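+
+     Example (the checkpoint path is an illustrative placeholder):
+         model, tok = load_kazdov_checkpoint("runs/kazdov-alpha", device="cpu")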
+     """
+     _ensure_kazdov_importable()
+     from kazdov.kazdov_lm import KazdovLM
+
+     ckpt_dir = Path(checkpoint_path)
+     config = json.loads((ckpt_dir / "config.json").read_text())
+     model_cfg = config["model_cfg"]
+
+     model = KazdovLM(
+         vocab_size=model_cfg["vocab_size"],
+         d_model=model_cfg["d_model"],
+         n_layers=model_cfg["n_layers"],
+         n_heads=model_cfg["n_heads"],
+         rank=model_cfg["rank"],
+         mlp_dim=model_cfg.get("mlp_dim"),
+         max_len=model_cfg.get("max_len", 256),
+         use_trilinear=model_cfg.get("use_trilinear", False),
+         use_bi_bcn=model_cfg.get("use_bi_bcn", False),
+         use_hybrid_mha=model_cfg.get("use_hybrid_mha", True),
+         use_mobe=model_cfg.get("use_mobe", False),
+         n_experts=model_cfg.get("n_experts", 1),
+     )
+
+     # Try final.pt, then latest.pt
+     for fname in ("final.pt", "latest.pt"):
+         f = ckpt_dir / fname
+         if f.exists():
+             state = torch.load(f, map_location=device, weights_only=False)
+             if isinstance(state, dict) and "model" in state:
+                 state = state["model"]
+             model.load_state_dict(state, strict=False)
+             break
+     else:
+         raise FileNotFoundError(f"No final.pt or latest.pt in {ckpt_dir}")
+
+     model.to(device).eval()
+
+     # Tokenizer: the kazdov training setup used the GPT-2 tokenizer
+     from transformers import GPT2Tokenizer
+     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     return model, tokenizer
+
+
+ @Backend.register("kazdov")
+ class KazdovBackend(Backend):
+     """Backend for kazdov-family models (KazdovLM, MoBE-BCN variants).
+
+     Uses forward hooks to capture the residual stream after each KazdovBlock,
+     since the model doesn't expose output_hidden_states.
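+
+     Sketch of intended use (checkpoint path is an illustrative placeholder):
+         model, tok = load_kazdov_checkpoint("runs/kazdov-alpha")
+         backend = Backend.for_model(model, hint="kazdov")
+         enc = tok("hello world", return_tensors="pt")
+         recs = backend.extract({"input_ids": enc["input_ids"]},
+                                layers=["layer_0.residual"])
+         recs[0].activations.shape  # (1, n_tokens, d_model)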
+     """
+
+     def layer_names(self) -> list[str]:
+         n_layers = len(self.model.blocks)
+         return [f"layer_{i}.residual" for i in range(n_layers)]
+
+     def extract(self, inputs, layers=None):
+         layers = layers or self.layer_names()
+         captures: dict[str, torch.Tensor] = {}
+
+         # Register a forward hook on each requested block.
+         hooks = []
+         for layer_name in layers:
+             idx = int(layer_name.split("_")[1].split(".")[0])
+             if idx >= len(self.model.blocks):
+                 continue
+             block = self.model.blocks[idx]
+
+             def make_hook(name):
+                 def hook(module, inp, out):
+                     tensor = out if isinstance(out, torch.Tensor) else out[0]
+                     captures[name] = tensor.detach()
+                 return hook
+             hooks.append(block.register_forward_hook(make_hook(layer_name)))
+
+         try:
+             # Kazdov forward signature: model(input_ids, attention_mask=None)
+             with torch.no_grad():
+                 if isinstance(inputs, dict):
+                     input_ids = inputs["input_ids"]
+                     attn = inputs.get("attention_mask")
+                 else:
+                     input_ids = inputs
+                     attn = None
+                 self.model(input_ids, attention_mask=attn)
+         finally:
+             for h in hooks:
+                 h.remove()
+
+         records = []
+         for layer_name in layers:
+             if layer_name not in captures:
+                 continue
+             records.append(ActivationRecord(
+                 layer_name=layer_name,
+                 activations=captures[layer_name],
+                 meta={"kind": "residual", "arch": "kazdov-mobe-bcn"},
+             ))
+         return records
+
+     def hidden_dim(self, layer_name: str) -> int:
+         return self.model.d_model
archscope/lens.py ADDED
@@ -0,0 +1,304 @@
+ """Logit lens and Tuned lens — project intermediate residual stream into vocab space.
+
+ The logit lens (Nostalgebraist, 2020) applies the model's own final norm + unembedding
+ to every layer's residual stream, revealing "what the model would predict if forced
+ to commit at this layer." It tells you which layers do the work of forming the final
+ prediction.
+
+ The tuned lens (Belrose et al., 2023) learns per-layer affine transformations that
+ correct for representation drift, producing a more faithful intermediate decoding.
+
+ Both work on any architecture exposing the residual stream via Backend.extract — i.e.
+ transformer, mamba, kazdov, custom recurrent.
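+
+ In symbols, with residual stream h_l at layer l, final norm LN_f, and unembedding
+ W_U: the logit lens reads off softmax(W_U · LN_f(h_l)), while the tuned lens reads
+ off softmax(W_U · LN_f(A_l h_l + b_l)), where each affine map (A_l, b_l) is trained
+ to match the model's final-layer distribution (see TunedLens below).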
+ """
+ from __future__ import annotations
+ from dataclasses import dataclass, field
+ import torch
+ import torch.nn as nn
+
+ from .backends import Backend
+ from ._utils import resolve_unembedding, resolve_final_norm
+
+
+ @dataclass
+ class LayerPrediction:
+     """What the lens says at one layer."""
+     layer: int
+     layer_name: str
+     top_tokens: list[tuple[int, str, float]]  # (token_id, token_str, prob)
+     target_prob: float | None = None  # if target_token provided, prob at this layer
+     target_rank: int | None = None    # rank of target token in this layer's distribution
+     entropy: float = 0.0              # confidence (low = peaked, high = uncertain)
+
+
+ @dataclass
+ class LensResult:
+     """Output of a lens pass over multiple layers."""
+     prompt: str
+     target_token: str | None = None
+     target_token_id: int | None = None
+     layers: list[LayerPrediction] = field(default_factory=list)
+     method: str = "logit_lens"
+
+     def to_markdown(self) -> str:
+         """Quick formatted display."""
+         lines = [f"### {self.method} on `{self.prompt}`"]
+         if self.target_token:
+             lines.append(f"Target: `{self.target_token}` (id={self.target_token_id})")
+         lines.append("")
+         lines.append("| Layer | top-1 token | top-1 prob | target prob | rank | entropy |")
+         lines.append("|-------|-------------|-----------:|------------:|-----:|--------:|")
+         for lp in self.layers:
+             if lp.top_tokens:
+                 top_id, top_str, top_p = lp.top_tokens[0]
+             else:
+                 top_str, top_p = "?", float("nan")
+             tgt_p = f"{lp.target_prob:.3f}" if lp.target_prob is not None else "—"
+             tgt_r = f"{lp.target_rank}" if lp.target_rank is not None else "—"
+             lines.append(
+                 f"| {lp.layer:>2} | `{top_str[:20]}` | {top_p:.3f} | {tgt_p} | {tgt_r} | {lp.entropy:.3f} |"
+             )
+         return "\n".join(lines)
+
+
+ def _logits_to_pred(
+     logits: torch.Tensor,
+     tokenizer,
+     target_id: int | None = None,
+     top_k: int = 5,
+ ) -> tuple[list[tuple[int, str, float]], float | None, int | None, float]:
+     """Convert a logit vector → (top_k_tokens, target_prob, target_rank, entropy)."""
+     probs = torch.softmax(logits.float(), dim=-1)
+     top_p, top_i = probs.topk(top_k)
+     top_tokens = []
+     for p, i in zip(top_p.tolist(), top_i.tolist()):
+         try:
+             tok_str = tokenizer.decode([i])
+         except Exception:
+             tok_str = f"<id={i}>"
+         top_tokens.append((i, tok_str, p))
+     if target_id is not None:
+         target_prob = float(probs[target_id].item())
+         target_rank = int((probs > probs[target_id]).sum().item())
+     else:
+         target_prob = None
+         target_rank = None
+     # Shannon entropy in nats
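+     # (for scale: a uniform distribution over GPT-2's 50257-token vocabulary
+     #  would give ln(50257) ≈ 10.8 nats; a one-hot distribution gives 0)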
+     entropy = float(-(probs.clamp(min=1e-12) * probs.clamp(min=1e-12).log()).sum().item())
+     return top_tokens, target_prob, target_rank, entropy
+
+
+ def logit_lens(
+     model,
+     tokenizer,
+     prompt: str,
+     target_token: str | None = None,
+     layers: list[int] | None = None,
+     backend_hint: str | None = None,
+     top_k: int = 5,
+     device: str = "cpu",
+ ) -> LensResult:
+     """Apply the model's own final-norm + unembedding to each layer's residual stream.
+
+     Returns per-layer predictions: top-k tokens + entropy + (optional) target rank/prob.
+
+     The classic question this answers: "at which layer does the model commit to the
+     final token?" Typically entropy falls and the target token's rank improves
+     (drops toward 0) as you move up the layers.
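+
+     Example (illustrative, assuming an HF causal LM and matching tokenizer):
+         res = logit_lens(model, tok, "The Eiffel Tower is in", target_token=" Paris")
+         print(res.to_markdown())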
+     """
+     backend = Backend.for_model(model, hint=backend_hint)
+     norm = resolve_final_norm(model)
+     unembed = resolve_unembedding(model)
+     if unembed is None:
+         raise ValueError("Could not locate model's lm_head / unembedding module.")
+
+     # Tokenize prompt (handle both HF tokenizer + kazdov-style inputs)
+     enc = tokenizer(prompt, return_tensors="pt")
+     if hasattr(enc, "input_ids"):
+         inputs = {"input_ids": enc.input_ids.to(device)}
+     else:
+         inputs = {"input_ids": enc["input_ids"].to(device)}
+
+     target_id = None
+     if target_token is not None:
+         target_ids = tokenizer(target_token, add_special_tokens=False).input_ids
+         target_id = target_ids[0] if target_ids else None
+
+     # Pick layers
+     all_layer_names = [n for n in backend.layer_names() if ".residual" in n]
+     if layers is None:
+         layer_names = all_layer_names
+     else:
+         layer_names = [f"layer_{i}.residual" for i in layers if f"layer_{i}.residual" in all_layer_names]
+
+     # Extract residual streams
+     with torch.no_grad():
+         records = backend.extract(inputs, layers=layer_names)
+
+     result = LensResult(
+         prompt=prompt,
+         target_token=target_token,
+         target_token_id=target_id,
+         method="logit_lens",
+     )
+
+     for rec in records:
+         # rec.activations shape: (B, T, hidden) — take last-token residual
+         last = rec.activations[:, -1, :]  # (B, hidden)
+         # Apply final norm if present (transformer family). For kazdov it's `model.ln_f`.
+         if norm is not None:
+             try:
+                 last = norm(last)
+             except Exception:
+                 pass
+         # Project to vocab
+         logits = unembed(last)[0]  # (vocab,)
+         top_tokens, tgt_p, tgt_r, ent = _logits_to_pred(logits, tokenizer, target_id, top_k)
+         idx = int(rec.layer_name.split("_")[1].split(".")[0])
+         result.layers.append(LayerPrediction(
+             layer=idx, layer_name=rec.layer_name,
+             top_tokens=top_tokens,
+             target_prob=tgt_p, target_rank=tgt_r,
+             entropy=ent,
+         ))
+
+     return result
+
+
+ # -------------------- TUNED LENS --------------------
+
+ class TunedLens(nn.Module):
+     """Learned per-layer affine transformations into the unembedding space.
+
+     For each layer, learn a transformation residual → adjusted_residual that, when
+     fed through the model's final_norm + unembedding, better matches the final-layer
+     distribution. Following Belrose et al., 2023.
+
+     Usage:
+         tl = TunedLens.fit(model, tokenizer, calibration_texts, backend_hint="transformer")
+         result = tl.predict(model, tokenizer, prompt="...", backend_hint="transformer")
+     """
+
+     def __init__(self, model, n_layers: int, hidden_dim: int, backend_hint: str | None = None):
+         super().__init__()
+         self.n_layers = n_layers
+         self.hidden_dim = hidden_dim
+         self.backend_hint = backend_hint
+         # One Linear per layer; identity initialization (so an untrained tuned lens ≈ logit lens)
+         self.translators = nn.ModuleList([
+             nn.Linear(hidden_dim, hidden_dim, bias=True)
+             for _ in range(n_layers)
+         ])
+         for t in self.translators:
+             nn.init.eye_(t.weight)
+             nn.init.zeros_(t.bias)
+         self._model_ref = [model]  # avoid registering the model as a submodule
+
+     @classmethod
+     def fit(
+         cls,
+         model,
+         tokenizer,
+         calibration_texts: list[str],
+         backend_hint: str | None = None,
+         epochs: int = 30,
+         lr: float = 1e-3,
+         max_len: int = 64,
+         device: str = "cpu",
+     ) -> "TunedLens":
+         """Train per-layer translators to match the model's own final distribution."""
+         backend = Backend.for_model(model, hint=backend_hint)
+         layer_names = [n for n in backend.layer_names() if ".residual" in n]
+         n_layers = len(layer_names)
+         # Get hidden dim from the first layer
+         hidden_dim = backend.hidden_dim(layer_names[0])
+
+         tl = cls(model, n_layers=n_layers, hidden_dim=hidden_dim, backend_hint=backend_hint).to(device)
+         norm = resolve_final_norm(model)
+         unembed = resolve_unembedding(model)
+
+         opt = torch.optim.AdamW(tl.translators.parameters(), lr=lr)
+
+         # Pre-extract all activations + target logits once
+         enc = tokenizer(calibration_texts, return_tensors="pt", padding=True,
+                         truncation=True, max_length=max_len)
+         inputs = {"input_ids": enc["input_ids"].to(device)}
+         # Index each sequence's last non-pad token (assumes right padding), so
+         # padded batches don't train the translators against pad-position targets.
+         if "attention_mask" in enc:
+             last_pos = enc["attention_mask"].to(device).sum(dim=1) - 1
+         else:
+             last_pos = torch.full((enc["input_ids"].shape[0],),
+                                   enc["input_ids"].shape[1] - 1, device=device)
+         rows = torch.arange(enc["input_ids"].shape[0], device=device)
+         with torch.no_grad():
+             records = backend.extract(inputs, layers=layer_names)
+             # Target: model's actual final logits at the last real position
+             final_residual = records[-1].activations[rows, last_pos, :]
+             if norm is not None:
+                 final_residual = norm(final_residual)
+             target_logits = unembed(final_residual).detach()  # (B, vocab)
+             target_log_probs = torch.log_softmax(target_logits.float(), dim=-1)
+
+         for epoch in range(epochs):
+             opt.zero_grad()
+             total_loss = 0.0
+             for i, rec in enumerate(records):
+                 last = rec.activations[rows, last_pos, :].detach()
+                 translated = tl.translators[i](last)
+                 if norm is not None:
+                     translated = norm(translated)
+                 pred_logits = unembed(translated)
+                 pred_log_probs = torch.log_softmax(pred_logits.float(), dim=-1)
+                 # KL(target || pred), both given as log-probs
+                 loss = torch.nn.functional.kl_div(pred_log_probs, target_log_probs,
+                                                   log_target=True, reduction="batchmean")
+                 total_loss = total_loss + loss
+             total_loss.backward()
+             opt.step()
+
+         tl.last_loss = float(total_loss.item() / max(n_layers, 1))
+         return tl
+
+     def predict(
+         self,
+         model,
+         tokenizer,
+         prompt: str,
+         target_token: str | None = None,
+         backend_hint: str | None = None,
+         layers: list[int] | None = None,
+         top_k: int = 5,
+         device: str = "cpu",
+     ) -> LensResult:
+         """Apply the trained tuned lens to a new prompt."""
+         backend = Backend.for_model(model, hint=backend_hint or self.backend_hint)
+         norm = resolve_final_norm(model)
+         unembed = resolve_unembedding(model)
+
+         enc = tokenizer(prompt, return_tensors="pt")
+         inputs = {"input_ids": enc["input_ids"].to(device)}
+         target_id = None
+         if target_token is not None:
+             ids = tokenizer(target_token, add_special_tokens=False).input_ids
+             target_id = ids[0] if ids else None
+
+         all_layer_names = [n for n in backend.layer_names() if ".residual" in n]
+         if layers is None:
+             layer_names = all_layer_names
+         else:
+             layer_names = [f"layer_{i}.residual" for i in layers if f"layer_{i}.residual" in all_layer_names]
+
+         with torch.no_grad():
+             records = backend.extract(inputs, layers=layer_names)
+
+         result = LensResult(prompt=prompt, target_token=target_token,
+                             target_token_id=target_id, method="tuned_lens")
+
+         for rec in records:
+             idx = int(rec.layer_name.split("_")[1].split(".")[0])
+             last = rec.activations[:, -1, :]
+             with torch.no_grad():
+                 translated = self.translators[idx](last)
+                 if norm is not None:
+                     translated = norm(translated)
+                 logits = unembed(translated)[0]
+             top_tokens, tgt_p, tgt_r, ent = _logits_to_pred(logits, tokenizer, target_id, top_k)
+             result.layers.append(LayerPrediction(
+                 layer=idx, layer_name=rec.layer_name,
+                 top_tokens=top_tokens,
+                 target_prob=tgt_p, target_rank=tgt_r,
+                 entropy=ent,
+             ))
+         return result