archscope 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- archscope/__init__.py +30 -0
- archscope/_utils.py +113 -0
- archscope/attribute.py +201 -0
- archscope/backends.py +236 -0
- archscope/bench.py +262 -0
- archscope/circuits.py +255 -0
- archscope/cli.py +120 -0
- archscope/diff.py +212 -0
- archscope/kazdov_backend.py +141 -0
- archscope/lens.py +304 -0
- archscope/neurons.py +118 -0
- archscope/probes.py +160 -0
- archscope/sae.py +127 -0
- archscope/transfer.py +188 -0
- archscope-0.2.2.dist-info/METADATA +324 -0
- archscope-0.2.2.dist-info/RECORD +20 -0
- archscope-0.2.2.dist-info/WHEEL +5 -0
- archscope-0.2.2.dist-info/entry_points.txt +2 -0
- archscope-0.2.2.dist-info/licenses/LICENSE +17 -0
- archscope-0.2.2.dist-info/top_level.txt +1 -0
archscope/neurons.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""Targeted Neuron Modulation via Contrastive Pair Search (2605.12290).
|
|
2
|
+
|
|
3
|
+
5-step algorithm reimplemented from paper:
|
|
4
|
+
1. Run harmful + benign prompts through model
|
|
5
|
+
2. Record MLP activations at final token position
|
|
6
|
+
3. Compute per-neuron mean activation difference
|
|
7
|
+
4. Select top k (default 0.1%) neurons by |Δ|
|
|
8
|
+
5. At inference: multiply selected neuron activations by scalar m
|
|
9
|
+
"""
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
import torch
|
|
13
|
+
from .backends import Backend
|
|
14
|
+
from ._utils import resolve_layer_module
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
|
|
18
|
+
class NeuronEditConfig:
|
|
19
|
+
top_frac: float = 0.001 # top 0.1% by default
|
|
20
|
+
layer_filter: str | None = None # e.g., "mlp" to restrict to MLP neurons
|
|
21
|
+
mode: str = "scalar" # "scalar" (multiply by m) or "ablate" (m=0)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class NeuronEdit:
    """Result of contrastive search — which neurons to edit + how.

    Produced by `find_neurons`; applied at inference time via `apply_hook`,
    which returns a context manager wrapping the registered hooks.
    """

    # layer name → tensor of selected neuron indices for that layer
    layer_to_indices: dict[str, torch.Tensor]  # layer_name → neuron indices
    # layer name → mean activation differences at the selected indices
    layer_to_deltas: dict[str, torch.Tensor]  # layer_name → mean diff values
    config: NeuronEditConfig
    multiplier: float = 0.0  # 0 = ablate; >1 = amplify; <1 = dampen

    def apply_hook(self, model, backend: Backend | None = None):
        """Register forward hooks that modulate the selected neurons.

        Returns a context manager that auto-removes hooks on exit.

        NOTE(review): the `backend` parameter is accepted but never used in
        this method body.
        """
        hooks = []
        for layer_name, indices in self.layer_to_indices.items():
            module = resolve_layer_module(model, layer_name)
            if module is None:
                # Layer name did not resolve to a module on this model; skip.
                continue

            indices_local = indices

            # `idxs` and `m` are bound as default arguments so each hook
            # captures the current loop values (avoids the classic
            # late-binding-closure pitfall).
            def hook(module, input, output, idxs=indices_local, m=self.multiplier):
                # Return MODIFIED output (don't rely on in-place — some HF layers
                # produce fresh tensors that won't propagate in-place changes).
                if isinstance(output, tuple):
                    # Tuple output: first element is assumed to be the hidden
                    # states — modulate it, pass the rest through unchanged.
                    h = output[0].clone()
                    if h.dim() >= 2:
                        # Scale the selected neurons along the last (feature) dim.
                        h[..., idxs] = h[..., idxs] * m
                    return (h,) + output[1:]
                h = output.clone()
                if h.dim() >= 2:
                    h[..., idxs] = h[..., idxs] * m
                return h
            hooks.append(module.register_forward_hook(hook))

        return _HookContext(hooks)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class _HookContext:
|
|
63
|
+
"""Auto-removes forward hooks when used as a `with` block."""
|
|
64
|
+
|
|
65
|
+
def __init__(self, hooks):
|
|
66
|
+
self.hooks = hooks
|
|
67
|
+
|
|
68
|
+
def __enter__(self):
|
|
69
|
+
return self
|
|
70
|
+
|
|
71
|
+
def __exit__(self, *args):
|
|
72
|
+
for h in self.hooks:
|
|
73
|
+
h.remove()
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def find_neurons(
    model,
    inputs_harmful: list,
    inputs_benign: list,
    config: NeuronEditConfig | None = None,
    backend_hint: str | None = None,
) -> NeuronEdit:
    """Algorithm 1 of Targeted Neuron Modulation paper.

    Runs both prompt sets through the model, computes per-neuron mean
    activation differences at the final token position, and selects the top
    `config.top_frac` fraction of neurons per layer by |Δ|.

    Returns a NeuronEdit that can be applied as a hook during inference
    (multiplier defaults to 0.0, i.e. ablation).

    FIX: `config.layer_filter` was documented but never actually applied;
    it now restricts candidate layers by substring match (e.g. "mlp").
    Default of None keeps the previous all-layers behavior.
    """
    config = config or NeuronEditConfig()
    backend = Backend.for_model(model, hint=backend_hint)

    all_layers = backend.layer_names()
    if config.layer_filter:
        # Keep only layers whose name contains the filter substring.
        all_layers = [name for name in all_layers if config.layer_filter in name]

    # Forward both classes, collect activations for the selected layers.
    # NOTE(review): assumes `extract` yields one record per layer in the same
    # order for both calls — confirm against the Backend contract.
    harm_acts = backend.extract(inputs_harmful, layers=all_layers)
    ben_acts = backend.extract(inputs_benign, layers=all_layers)

    layer_to_indices = {}
    layer_to_deltas = {}
    for h_rec, b_rec in zip(harm_acts, ben_acts):
        # Final token position: (batch, seq, hidden_dim) -> (batch, hidden_dim)
        h_final = h_rec.activations[:, -1, :]
        b_final = b_rec.activations[:, -1, :]
        # Per-neuron mean activation difference (harmful - benign).
        delta = h_final.mean(dim=0) - b_final.mean(dim=0)
        # Top k by absolute value; always keep at least one neuron.
        k = max(1, int(config.top_frac * len(delta)))
        topk = torch.topk(delta.abs(), k=k)
        layer_to_indices[h_rec.layer_name] = topk.indices
        layer_to_deltas[h_rec.layer_name] = delta[topk.indices]

    return NeuronEdit(
        layer_to_indices=layer_to_indices,
        layer_to_deltas=layer_to_deltas,
        config=config,
        multiplier=0.0,  # default to full ablation; caller may override
    )
|
|
117
|
+
|
|
118
|
+
|
archscope/probes.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""Probe-Filtered RL — linear/MLP probes over hidden states (Drop the Act, 2605.11467).
|
|
2
|
+
|
|
3
|
+
Core idea: train a frozen-base probe over residual/hidden states that predicts
|
|
4
|
+
some property (faithfulness, refusal, deception). Use probe scores to filter
|
|
5
|
+
trajectories during RL training, or as inference-time signals.
|
|
6
|
+
|
|
7
|
+
Reimplemented from paper (code anonymized for review).
|
|
8
|
+
"""
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn as nn
|
|
13
|
+
from .backends import Backend
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class ProbeConfig:
    """Settings for one probe attached to a single layer."""

    layer_name: str
    # Architecture of the probe head: "linear" or "mlp".
    probe_type: str = "linear"
    # Width of the hidden layer; only consulted for the "mlp" variant.
    hidden_dim: int = 64
    # Label/property the probe is trained to predict.
    target: str = "performativity"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Probe(nn.Module):
    """Linear or MLP probe over a layer's hidden states.

    A small trainable head mapping activations to a single logit per
    position. The base model stays frozen; this probe is the only
    trainable component.
    """

    def __init__(self, input_dim: int, config: ProbeConfig):
        super().__init__()
        self.config = config
        kind = config.probe_type
        if kind == "linear":
            head = nn.Linear(input_dim, 1)
        elif kind == "mlp":
            head = nn.Sequential(
                nn.Linear(input_dim, config.hidden_dim),
                nn.GELU(),
                nn.Linear(config.hidden_dim, 1),
            )
        else:
            raise ValueError(f"Unknown probe_type: {config.probe_type}")
        self.net = head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Collapse the trailing hidden_dim to a scalar logit.

        Accepts (N, hidden_dim) or (N, seq, hidden_dim); the leading
        dimensions are preserved in the returned logits.
        """
        return self.net(x).squeeze(-1)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class ProbeFit:
    """Fit a probe given (activations, labels) pairs.

    Wraps a `Probe` plus its supervised training loop. The base model is not
    involved here — callers pass pre-extracted activations.
    """

    def __init__(self, config: ProbeConfig, input_dim: int, device: str = "cpu"):
        # `input_dim` must match the hidden size of the layer the
        # activations were extracted from.
        self.config = config
        self.probe = Probe(input_dim, config).to(device)
        self.device = device

    def train(
        self,
        activations: torch.Tensor,  # (N, hidden_dim) or (N, seq_len, hidden_dim) pooled
        labels: torch.Tensor,  # (N,)
        epochs: int = 50,
        lr: float = 1e-3,
        batch_size: int = 64,
        val_split: float = 0.2,
    ) -> dict:
        """Standard supervised fit. Returns train/val AUROC + final loss.

        Trains with BCE-with-logits on a random train/val split; only the
        probe's parameters are optimized.

        NOTE(review): for small N, `n_val` may round down to 0, leaving
        `val_idx` empty and making the validation AUROC come back NaN.
        """
        if activations.dim() == 3:
            # Pool seq dim by mean (could be configurable)
            activations = activations.mean(dim=1)
        activations = activations.to(self.device)
        labels = labels.float().to(self.device)

        # Random train/validation split over examples.
        n = len(activations)
        idx = torch.randperm(n)
        n_val = int(n * val_split)
        train_idx, val_idx = idx[n_val:], idx[:n_val]

        opt = torch.optim.AdamW(self.probe.parameters(), lr=lr)
        loss_fn = nn.BCEWithLogitsLoss()

        for _ in range(epochs):
            # nn.Module.train() (mode switch) — not a recursive call to this method.
            self.probe.train()
            # Reshuffle the training subset every epoch.
            perm = train_idx[torch.randperm(len(train_idx))]
            for b in range(0, len(perm), batch_size):
                batch = perm[b:b+batch_size]
                logits = self.probe(activations[batch])
                loss = loss_fn(logits, labels[batch])
                opt.zero_grad()
                loss.backward()
                opt.step()

        # Final metrics on both splits with gradients disabled.
        self.probe.eval()
        with torch.no_grad():
            train_logits = self.probe(activations[train_idx])
            val_logits = self.probe(activations[val_idx])
        return {
            "train_auroc": _auroc(train_logits, labels[train_idx]),
            "val_auroc": _auroc(val_logits, labels[val_idx]),
            "train_loss": loss_fn(train_logits, labels[train_idx]).item(),
        }

    def score(self, activations: torch.Tensor) -> torch.Tensor:
        """Apply probe — returns per-token (or per-example if pooled) scores.

        Logits are passed through a sigmoid, so scores lie in (0, 1).
        """
        self.probe.eval()
        with torch.no_grad():
            return torch.sigmoid(self.probe(activations.to(self.device)))
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _auroc(logits: torch.Tensor, labels: torch.Tensor) -> float:
|
|
113
|
+
"""Simple AUROC from logits + binary labels."""
|
|
114
|
+
from sklearn.metrics import roc_auc_score
|
|
115
|
+
scores = torch.sigmoid(logits).cpu().numpy()
|
|
116
|
+
y = labels.cpu().numpy()
|
|
117
|
+
try:
|
|
118
|
+
return float(roc_auc_score(y, scores))
|
|
119
|
+
except ValueError:
|
|
120
|
+
return float("nan") # happens when only one class present
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
# High-level API matching paper
|
|
124
|
+
|
|
125
|
+
def fit_probe(
    model,
    inputs_pos: list,  # examples where target=1 (e.g., faithful)
    inputs_neg: list,  # examples where target=0 (e.g., reasoning theater)
    layer_name: str,
    backend_hint: str | None = None,
    config: ProbeConfig | None = None,
    device: str = "cpu",
) -> ProbeFit:
    """End-to-end: extract activations from model, fit probe.

    The base model is frozen; activations are detached before the probe is
    trained on them. Returns the fitted ProbeFit with a `.metrics` dict
    (train/val AUROC, final train loss) attached.
    """
    backend = Backend.for_model(model, hint=backend_hint)
    config = config or ProbeConfig(layer_name=layer_name)

    def _pool(acts: torch.Tensor) -> torch.Tensor:
        # Collapse the sequence dimension by mean-pooling when present, so
        # padding differences across batches don't matter.
        return acts.mean(dim=1) if acts.dim() == 3 else acts

    # Extract detached activations for each class (probe trains separately).
    with torch.no_grad():
        acts_pos = _pool(backend.extract(inputs_pos, layers=[layer_name])[0].activations.detach())
        acts_neg = _pool(backend.extract(inputs_neg, layers=[layer_name])[0].activations.detach())

    # Stack both classes with 1/0 labels.
    all_acts = torch.cat([acts_pos, acts_neg], dim=0)
    all_labels = torch.cat([torch.ones(len(acts_pos)), torch.zeros(len(acts_neg))])

    probe_fit = ProbeFit(config, backend.hidden_dim(layer_name), device=device)
    probe_fit.metrics = probe_fit.train(all_acts, all_labels)
    return probe_fit
|
archscope/sae.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""Sparse Autoencoders for hidden states (WriteSAE, 2605.12770).
|
|
2
|
+
|
|
3
|
+
Two SAE variants:
|
|
4
|
+
- DenseSAE: standard SAE with L1 sparsity penalty
|
|
5
|
+
- Rank1FactoredSAE: WriteSAE's contribution — atoms = v_i w_i^T (rank-1 outer
|
|
6
|
+
product), applied to recurrent cache writes specifically.
|
|
7
|
+
|
|
8
|
+
Both work on transformer residual streams AND recurrent hidden states.
|
|
9
|
+
"""
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
import torch
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
import torch.nn.functional as F
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class SAEConfig:
    """Hyperparameters shared by both SAE variants."""

    input_dim: int              # dimensionality of the activations being modeled
    n_features: int             # dictionary size
    sparsity: float = 1e-3      # L1 coefficient
    sae_type: str = "dense"     # "dense" or "rank1"
    learning_rate: float = 1e-3
    batch_size: int = 64
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class DenseSAE(nn.Module):
    """Vanilla sparse autoencoder: ReLU encoder, linear decoder, L1 penalty."""

    def __init__(self, config: SAEConfig):
        super().__init__()
        self.config = config
        self.encoder = nn.Linear(config.input_dim, config.n_features)
        self.decoder = nn.Linear(config.n_features, config.input_dim, bias=False)
        # Tie the initial decoder to the transposed encoder weights — a common
        # SAE initialization.
        with torch.no_grad():
            self.decoder.weight.data = self.encoder.weight.data.T.clone()

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu(self.encoder(x))

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)

    def forward(self, x: torch.Tensor):
        code = self.encode(x)
        return self.decode(code), code

    def loss(self, x: torch.Tensor):
        """Reconstruction MSE + L1 sparsity; second return is a metrics dict."""
        x_hat, code = self.forward(x)
        recon = F.mse_loss(x_hat, x)
        l1 = code.abs().mean() * self.config.sparsity
        stats = {
            "recon": recon.item(),
            "l1": l1.item(),
            "n_active": (code > 0).float().mean().item(),
        }
        return recon + l1, stats
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class Rank1FactoredSAE(nn.Module):
    """WriteSAE rank-1 factored atoms: each feature i = v_i w_i^T outer product.

    Designed for recurrent CACHE WRITES specifically — atoms substitute for
    native write contributions at matched Frobenius norm.

    Reference: Eq. 3-factor closed form Δℓ ≈ G·⟨w_i,q_t⟩·⟨v_i,W_u[tok]⟩
    """

    def __init__(self, config: "SAEConfig"):
        super().__init__()
        assert config.sae_type == "rank1"
        self.config = config
        # Atoms parameterized as v (output) and w (input) vectors.
        self.v = nn.Parameter(torch.randn(config.n_features, config.input_dim) * 0.02)
        self.w = nn.Parameter(torch.randn(config.n_features, config.input_dim) * 0.02)

    def fire(self, x: torch.Tensor) -> torch.Tensor:
        """Firing strengths per atom: <w_i, x>, kept positive via ReLU."""
        # x: (..., input_dim), w: (n_features, input_dim) -> (..., n_features)
        return F.relu(x @ self.w.T)

    def reconstruct(self, x: torch.Tensor) -> torch.Tensor:
        """Reconstruct x as sum_i fire_i(x) * v_i."""
        return self.fire(x) @ self.v

    def forward(self, x: torch.Tensor):
        """Return (reconstruction, firing code).

        FIX: compute `fire(x)` once and reuse it — the original called it
        twice (inside `reconstruct` and again for z), doubling the matmul
        work for an identical result.
        """
        z = self.fire(x)
        x_hat = z @ self.v
        return x_hat, z

    def loss(self, x: torch.Tensor):
        """Reconstruction MSE + L1 sparsity; second return is a metrics dict."""
        x_hat, z = self.forward(x)
        recon = F.mse_loss(x_hat, x)
        l1 = z.abs().mean() * self.config.sparsity
        return recon + l1, {"recon": recon.item(), "l1": l1.item(), "n_active": (z > 0).float().mean().item()}
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def build_sae(config: SAEConfig):
    """Instantiate the SAE variant named by ``config.sae_type``.

    Raises ValueError for an unrecognized type.
    """
    builders = {"dense": DenseSAE, "rank1": Rank1FactoredSAE}
    builder = builders.get(config.sae_type)
    if builder is None:
        raise ValueError(f"Unknown sae_type: {config.sae_type}")
    return builder(config)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def fit_sae(
    activations: torch.Tensor,  # (N, input_dim) flattened
    config: SAEConfig,
    epochs: int = 100,
    device: str = "cpu",
) -> nn.Module:
    """Train SAE on flattened activations.

    Mini-batch AdamW over shuffled rows; the metrics of the last batch are
    attached to the returned module as ``last_metrics``.
    """
    sae = build_sae(config).to(device)
    optimizer = torch.optim.AdamW(sae.parameters(), lr=config.learning_rate)
    activations = activations.to(device)

    num_rows = len(activations)
    final_stats: dict = {}
    for _ in range(epochs):
        # Fresh shuffle each epoch.
        shuffle = torch.randperm(num_rows)
        for start in range(0, num_rows, config.batch_size):
            minibatch = activations[shuffle[start:start + config.batch_size]]
            loss, stats = sae.loss(minibatch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            final_stats = stats
    sae.last_metrics = final_stats
    return sae
|
archscope/transfer.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""Cross-architecture probe transfer via paired-activation linear alignment.
|
|
2
|
+
|
|
3
|
+
Setup:
|
|
4
|
+
- Two models with different hidden dims (e.g., Pythia 768 vs Kazdov 512)
|
|
5
|
+
- Train a linear projection M: kazdov_space (512) → pythia_space (768) using
|
|
6
|
+
paired (text → activation pair) data
|
|
7
|
+
- A Pythia probe with weights w_py ∈ R^768 transferred to kazdov-space becomes
|
|
8
|
+
w_kz = M^T · w_py ∈ R^512
|
|
9
|
+
|
|
10
|
+
This tests: do probe directions from one architecture transfer to another?
|
|
11
|
+
"""
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
import torch
|
|
14
|
+
import torch.nn as nn
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class TransferResult:
    """Metrics from a single cross-architecture probe-transfer experiment."""

    source_arch: str  # e.g., "pythia"
    target_arch: str  # e.g., "kazdov"
    source_layer: str
    target_layer: str
    # Number of paired texts used to fit the alignment matrix M.
    n_align_pairs: int

    baseline_source_auroc: float  # source probe on source data (no transfer)
    baseline_target_auroc: float  # target probe on target data (in-arch reference)
    transfer_auroc: float  # source probe via M applied to target data
    transfer_drop: float  # baseline_target - transfer (how much we lose)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def learn_alignment(
    src_acts: torch.Tensor,  # (N, d_src)
    tgt_acts: torch.Tensor,  # (N, d_tgt)
    ridge: float = 1e-3,
) -> torch.Tensor:
    """Fit a ridge-regularized linear map M with src_act ≈ M @ tgt_act.

    Solves the least-squares problem X @ M.T ≈ Y (X = target activations,
    Y = source activations) via the normal equations:

        M.T = (X.T X + λI)^-1 X.T Y

    The λI ridge term keeps the solve well-conditioned and prevents
    overfitting on small N. Returns M of shape (d_src, d_tgt).
    """
    assert src_acts.shape[0] == tgt_acts.shape[0]
    design = tgt_acts.float()    # X: (N, d_tgt)
    response = src_acts.float()  # Y: (N, d_src)
    gram = design.T @ design
    regularized = gram + ridge * torch.eye(gram.shape[0], device=gram.device)
    m_transposed = torch.linalg.solve(regularized, design.T @ response)
    return m_transposed.T  # (d_src, d_tgt)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def transfer_probe(
    probe_weights_source: torch.Tensor,  # (d_src,) — Pythia probe direction
    probe_bias_source: torch.Tensor,  # scalar
    alignment: torch.Tensor,  # M: (d_src, d_tgt)
) -> tuple[torch.Tensor, torch.Tensor]:
    """Transport a probe from source arch to target arch via alignment matrix.

    With a source probe f(x_src) = w_src · x_src + b and the alignment
    x_src ≈ M @ x_tgt, the induced target-space probe direction is
    w_tgt = M.T @ w_src; the bias carries over unchanged.
    """
    w_target = alignment.T @ probe_weights_source
    b_target = probe_bias_source
    return w_target, b_target
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def auroc_from_scores(scores: torch.Tensor, labels: torch.Tensor) -> float:
    """AUROC from raw scores and binary {0, 1} labels.

    FIX: computed natively with the rank-based (Mann-Whitney U) formula,
    using average ranks for tied scores, instead of importing sklearn at
    call time — removes a heavyweight runtime dependency while producing
    the same statistic. Returns NaN when only one class is present,
    matching the previous ValueError fallback.
    """
    s = scores.detach().float().cpu()
    y = labels.detach().float().cpu()
    pos = y == 1
    n_pos = int(pos.sum())
    n_neg = s.numel() - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")  # only one class present — AUROC undefined

    # Assign 1-based ranks, averaging ranks across runs of tied scores.
    sorted_s, order = torch.sort(s)
    ranks = torch.zeros_like(s)
    total = s.numel()
    i = 0
    while i < total:
        j = i
        while j + 1 < total and sorted_s[j + 1] == sorted_s[i]:
            j += 1
        ranks[order[i:j + 1]] = (i + j) / 2 + 1.0
        i = j + 1

    # Mann-Whitney U statistic normalized to [0, 1].
    return (float(ranks[pos].sum()) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def evaluate_transfer(
    source_model, target_model,
    source_backend, target_backend,
    source_tokenize, target_tokenize,
    align_texts: list[str],  # texts for learning M (no labels needed)
    train_pos: list[str], train_neg: list[str],
    test_pos: list[str], test_neg: list[str],
    source_layer: str,
    target_layer: str,
    source_arch_name: str = "source",
    target_arch_name: str = "target",
) -> TransferResult:
    """Full transfer experiment.

    1. Get paired source/target activations on `align_texts` → learn alignment M
    2. Train source probe on source activations of train_pos/neg
    3. Train target probe on target activations of train_pos/neg (in-arch baseline)
    4. Transfer source probe via M, apply to target activations of test_pos/neg
    5. Compare AUROCs.

    NOTE(review): `source_model`/`target_model` are only passed through to
    `probes.fit_probe`; all direct extraction goes through the backends.
    """
    def _extract_pooled(backend, tokenize, texts: list[str], layer: str) -> torch.Tensor:
        """Extract activations and pool to (N, hidden_dim) by averaging seq dim."""
        with torch.no_grad():
            rec = backend.extract(tokenize(texts), layers=[layer])[0]
            acts = rec.activations.detach()
            return acts.mean(dim=1) if acts.dim() == 3 else acts

    # -- 1. Learn alignment from paired activations on the same texts.
    src_align = _extract_pooled(source_backend, source_tokenize, align_texts, source_layer)
    tgt_align = _extract_pooled(target_backend, target_tokenize, align_texts, target_layer)
    M = learn_alignment(src_align, tgt_align)

    # -- 2. Train source probe (Pythia-style) on source activations
    # Imported lazily here; presumably to avoid a module-level import cycle —
    # TODO confirm.
    from . import probes
    src_train_inputs_pos = source_tokenize(train_pos)
    src_train_inputs_neg = source_tokenize(train_neg)
    pf_src = probes.fit_probe(
        source_model, src_train_inputs_pos, src_train_inputs_neg,
        layer_name=source_layer, backend_hint=source_arch_name,
    )

    # Extract source probe weights (linear case)
    src_probe_net = pf_src.probe.net
    if not isinstance(src_probe_net, nn.Linear):
        raise NotImplementedError("transfer only works for linear probes for now")
    # nn.Linear(input_dim, 1) weight is (1, d_src); squeeze to a direction (d_src,).
    w_src = src_probe_net.weight.data.squeeze(0)
    b_src = src_probe_net.bias.data

    # -- 3. Train target probe (in-arch baseline)
    tgt_train_inputs_pos = target_tokenize(train_pos)
    tgt_train_inputs_neg = target_tokenize(train_neg)
    pf_tgt = probes.fit_probe(
        target_model, tgt_train_inputs_pos, tgt_train_inputs_neg,
        layer_name=target_layer, backend_hint=target_arch_name,
    )

    # -- 4. Transfer: project w_src into target space via M
    w_transferred, b_transferred = transfer_probe(w_src, b_src, M)

    # -- 5. Evaluate on the test set.
    with torch.no_grad():
        # Source baseline: source probe on source test data.
        src_test_acts_pos = _extract_pooled(source_backend, source_tokenize, test_pos, source_layer)
        src_test_acts_neg = _extract_pooled(source_backend, source_tokenize, test_neg, source_layer)
        src_scores = torch.cat([
            src_probe_net(src_test_acts_pos).squeeze(-1),
            src_probe_net(src_test_acts_neg).squeeze(-1),
        ])
        # Labels: positives first, then negatives (matches the cat order above).
        src_labels = torch.cat([
            torch.ones(len(src_test_acts_pos)),
            torch.zeros(len(src_test_acts_neg)),
        ])
        baseline_source = auroc_from_scores(src_scores, src_labels)

        # Target baseline: target probe on target test data.
        tgt_test_acts_pos = _extract_pooled(target_backend, target_tokenize, test_pos, target_layer)
        tgt_test_acts_neg = _extract_pooled(target_backend, target_tokenize, test_neg, target_layer)
        tgt_probe_net = pf_tgt.probe.net
        tgt_scores = torch.cat([
            tgt_probe_net(tgt_test_acts_pos).squeeze(-1),
            tgt_probe_net(tgt_test_acts_neg).squeeze(-1),
        ])
        tgt_labels = torch.cat([
            torch.ones(len(tgt_test_acts_pos)),
            torch.zeros(len(tgt_test_acts_neg)),
        ])
        baseline_target = auroc_from_scores(tgt_scores, tgt_labels)

        # Transfer: transferred probe on target activations
        # (plain dot product — the transferred probe is just a direction + bias).
        transfer_scores = torch.cat([
            (tgt_test_acts_pos @ w_transferred + b_transferred),
            (tgt_test_acts_neg @ w_transferred + b_transferred),
        ])
        # tgt_labels is reused: same pos/neg ordering as the target baseline.
        transfer_auroc = auroc_from_scores(transfer_scores, tgt_labels)

    return TransferResult(
        source_arch=source_arch_name,
        target_arch=target_arch_name,
        source_layer=source_layer,
        target_layer=target_layer,
        n_align_pairs=len(align_texts),
        baseline_source_auroc=baseline_source,
        baseline_target_auroc=baseline_target,
        transfer_auroc=transfer_auroc,
        transfer_drop=baseline_target - transfer_auroc,
    )
|