archscope 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
archscope/neurons.py ADDED
@@ -0,0 +1,118 @@
1
+ """Targeted Neuron Modulation via Contrastive Pair Search (2605.12290).
2
+
3
+ 5-step algorithm reimplemented from paper:
4
+ 1. Run harmful + benign prompts through model
5
+ 2. Record MLP activations at final token position
6
+ 3. Compute per-neuron mean activation difference
7
+ 4. Select top k (default 0.1%) neurons by |Δ|
8
+ 5. At inference: multiply selected neuron activations by scalar m
9
+ """
10
+ from __future__ import annotations
11
+ from dataclasses import dataclass
12
+ import torch
13
+ from .backends import Backend
14
+ from ._utils import resolve_layer_module
15
+
16
+
17
@dataclass
class NeuronEditConfig:
    """Hyperparameters for the contrastive neuron-selection step."""
    # Fraction of neurons kept per layer after ranking by |mean activation
    # difference|; top 0.1% by default.
    top_frac: float = 0.001
    # e.g., "mlp" to restrict to MLP neurons
    layer_filter: str | None = None
    # "scalar" (multiply by m) or "ablate" (m=0)
    mode: str = "scalar"
22
+
23
+
24
@dataclass
class NeuronEdit:
    """Result of contrastive search — which neurons to edit + how."""
    layer_to_indices: dict[str, torch.Tensor]  # layer_name → neuron indices
    layer_to_deltas: dict[str, torch.Tensor]  # layer_name → mean diff values
    config: NeuronEditConfig
    multiplier: float = 0.0  # 0 = ablate; >1 = amplify; <1 = dampen

    def apply_hook(self, model, backend: Backend | None = None):
        """Register forward hooks that modulate the selected neurons.

        Returns a context manager that auto-removes hooks on exit.
        """
        registered = []
        for layer_name, layer_indices in self.layer_to_indices.items():
            target = resolve_layer_module(model, layer_name)
            if target is None:
                # Layer not resolvable on this model — skip silently.
                continue

            # Bind the per-layer indices and multiplier as defaults so each
            # closure captures its own values (avoids late-binding bugs).
            def _modulate(mod, inputs, out, idxs=layer_indices, m=self.multiplier):
                # Clone and RETURN the modified output instead of mutating it:
                # some HF layers produce fresh tensors where in-place edits
                # would not propagate.
                tensor = out[0] if isinstance(out, tuple) else out
                scaled = tensor.clone()
                if scaled.dim() >= 2:
                    scaled[..., idxs] = scaled[..., idxs] * m
                if isinstance(out, tuple):
                    return (scaled,) + out[1:]
                return scaled

            registered.append(target.register_forward_hook(_modulate))

        return _HookContext(registered)
60
+
61
+
62
class _HookContext:
    """Auto-removes forward hooks when used as a `with` block."""

    def __init__(self, hooks):
        # Handles previously returned by register_forward_hook().
        self.hooks = hooks

    def __enter__(self):
        # Nothing to set up; hooks were registered before construction.
        return self

    def __exit__(self, *args):
        # Detach every hook regardless of whether an exception occurred.
        for handle in self.hooks:
            handle.remove()
74
+
75
+
76
def find_neurons(
    model,
    inputs_harmful: list,
    inputs_benign: list,
    config: NeuronEditConfig | None = None,
    backend_hint: str | None = None,
) -> NeuronEdit:
    """Algorithm 1 of Targeted Neuron Modulation paper.

    Runs both prompt classes through the model, computes the per-neuron mean
    activation difference at the final token, and keeps the top `top_frac`
    fraction of neurons per layer by absolute difference.

    Returns a NeuronEdit that can be applied as a hook during inference.
    """
    cfg = config or NeuronEditConfig()
    backend = Backend.for_model(model, hint=backend_hint)

    # All layers for now (MLP-only filtering would use cfg.layer_filter).
    layers = backend.layer_names()

    # Forward both classes and collect activations per layer.
    harmful_records = backend.extract(inputs_harmful, layers=layers)
    benign_records = backend.extract(inputs_benign, layers=layers)

    indices_by_layer: dict[str, torch.Tensor] = {}
    deltas_by_layer: dict[str, torch.Tensor] = {}
    for harm, ben in zip(harmful_records, benign_records):
        # Final-token slice: (batch, seq, hidden) -> (batch, hidden).
        harm_last = harm.activations[:, -1, :]
        ben_last = ben.activations[:, -1, :]
        # Per-neuron mean difference between the two classes.
        delta = harm_last.mean(dim=0) - ben_last.mean(dim=0)
        # Keep at least one neuron even for tiny hidden sizes.
        n_keep = max(1, int(cfg.top_frac * len(delta)))
        top = torch.topk(delta.abs(), k=n_keep)
        indices_by_layer[harm.layer_name] = top.indices
        deltas_by_layer[harm.layer_name] = delta[top.indices]

    return NeuronEdit(
        layer_to_indices=indices_by_layer,
        layer_to_deltas=deltas_by_layer,
        config=cfg,
        multiplier=0.0,
    )
117
+
118
+
archscope/probes.py ADDED
@@ -0,0 +1,160 @@
1
+ """Probe-Filtered RL — linear/MLP probes over hidden states (Drop the Act, 2605.11467).
2
+
3
+ Core idea: train a frozen-base probe over residual/hidden states that predicts
4
+ some property (faithfulness, refusal, deception). Use probe scores to filter
5
+ trajectories during RL training, or as inference-time signals.
6
+
7
+ Reimplemented from paper (code anonymized for review).
8
+ """
9
+ from __future__ import annotations
10
+ from dataclasses import dataclass
11
+ import torch
12
+ import torch.nn as nn
13
+ from .backends import Backend
14
+
15
+
16
@dataclass
class ProbeConfig:
    """Configuration for a single probe."""
    # Name of the model layer whose hidden states feed the probe.
    layer_name: str
    # "linear" or "mlp"
    probe_type: str = "linear"
    # only used if mlp
    hidden_dim: int = 64
    # Property the probe predicts; label semantics are defined by the caller.
    target: str = "performativity"
23
+
24
+
25
class Probe(nn.Module):
    """Linear or MLP probe over a layer's hidden states.

    Trained to predict a scalar property from activations at one layer.
    Frozen base model: probe is the only trainable component.
    """

    def __init__(self, input_dim: int, config: ProbeConfig):
        super().__init__()
        self.config = config
        # Validate first, then build the requested head.
        if config.probe_type not in ("linear", "mlp"):
            raise ValueError(f"Unknown probe_type: {config.probe_type}")
        if config.probe_type == "linear":
            head = nn.Linear(input_dim, 1)
        else:
            # Two-layer MLP with a GELU nonlinearity.
            head = nn.Sequential(
                nn.Linear(input_dim, config.hidden_dim),
                nn.GELU(),
                nn.Linear(config.hidden_dim, 1),
            )
        self.net = head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply probe. Accepts (N, hidden_dim) or (N, seq, hidden_dim); returns
        same leading dims with the final hidden_dim collapsed to scalar logits."""
        logits = self.net(x)
        return logits.squeeze(-1)
50
+
51
+
52
class ProbeFit:
    """Fit a probe given (activations, labels) pairs."""

    def __init__(self, config: ProbeConfig, input_dim: int, device: str = "cpu"):
        self.config = config
        self.probe = Probe(input_dim, config).to(device)
        self.device = device

    def train(
        self,
        activations: torch.Tensor,  # (N, hidden_dim) or (N, seq_len, hidden_dim) pooled
        labels: torch.Tensor,  # (N,)
        epochs: int = 50,
        lr: float = 1e-3,
        batch_size: int = 64,
        val_split: float = 0.2,
    ) -> dict:
        """Standard supervised fit. Returns train/val AUROC + final loss."""
        if activations.dim() == 3:
            # Mean-pool the sequence dimension (could be configurable).
            activations = activations.mean(dim=1)
        activations = activations.to(self.device)
        labels = labels.float().to(self.device)

        # Random train/val split.
        total = len(activations)
        shuffled = torch.randperm(total)
        n_val = int(total * val_split)
        val_idx = shuffled[:n_val]
        train_idx = shuffled[n_val:]

        optimizer = torch.optim.AdamW(self.probe.parameters(), lr=lr)
        criterion = nn.BCEWithLogitsLoss()

        for _ in range(epochs):
            self.probe.train()
            # Fresh minibatch order every epoch.
            order = train_idx[torch.randperm(len(train_idx))]
            for start in range(0, len(order), batch_size):
                sel = order[start:start + batch_size]
                out = self.probe(activations[sel])
                batch_loss = criterion(out, labels[sel])
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

        self.probe.eval()
        with torch.no_grad():
            logits_train = self.probe(activations[train_idx])
            logits_val = self.probe(activations[val_idx])
            return {
                "train_auroc": _auroc(logits_train, labels[train_idx]),
                "val_auroc": _auroc(logits_val, labels[val_idx]),
                "train_loss": criterion(logits_train, labels[train_idx]).item(),
            }

    def score(self, activations: torch.Tensor) -> torch.Tensor:
        """Apply probe — returns per-token (or per-example if pooled) scores."""
        self.probe.eval()
        with torch.no_grad():
            return torch.sigmoid(self.probe(activations.to(self.device)))
110
+
111
+
112
def _auroc(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Simple AUROC from logits + binary labels."""
    from sklearn.metrics import roc_auc_score
    probs = torch.sigmoid(logits).cpu().numpy()
    targets = labels.cpu().numpy()
    try:
        return float(roc_auc_score(targets, probs))
    except ValueError:
        # AUROC is undefined when only one class is present.
        return float("nan")
121
+
122
+
123
+ # High-level API matching paper
124
+
125
def fit_probe(
    model,
    inputs_pos: list,  # examples where target=1 (e.g., faithful)
    inputs_neg: list,  # examples where target=0 (e.g., reasoning theater)
    layer_name: str,
    backend_hint: str | None = None,
    config: ProbeConfig | None = None,
    device: str = "cpu",
) -> ProbeFit:
    """End-to-end: extract activations from model, fit probe."""
    backend = Backend.for_model(model, hint=backend_hint)
    cfg = config or ProbeConfig(layer_name=layer_name)

    # Base model is frozen and the probe trains separately, so activations
    # are detached from the autograd graph.
    with torch.no_grad():
        pos_acts = backend.extract(inputs_pos, layers=[layer_name])[0].activations.detach()
        neg_acts = backend.extract(inputs_neg, layers=[layer_name])[0].activations.detach()

    # Mean-pool the sequence dimension so independently padded batches align.
    if pos_acts.dim() == 3:
        pos_acts = pos_acts.mean(dim=1)  # (batch, hidden)
    if neg_acts.dim() == 3:
        neg_acts = neg_acts.mean(dim=1)

    # Stack both classes with 1/0 labels.
    features = torch.cat([pos_acts, neg_acts], dim=0)
    targets = torch.cat([torch.ones(len(pos_acts)), torch.zeros(len(neg_acts))])

    probe_fit = ProbeFit(cfg, backend.hidden_dim(layer_name), device=device)
    probe_fit.metrics = probe_fit.train(features, targets)
    return probe_fit
archscope/sae.py ADDED
@@ -0,0 +1,127 @@
1
+ """Sparse Autoencoders for hidden states (WriteSAE, 2605.12770).
2
+
3
+ Two SAE variants:
4
+ - DenseSAE: standard SAE with L1 sparsity penalty
5
+ - Rank1FactoredSAE: WriteSAE's contribution — atoms = v_i w_i^T (rank-1 outer
6
+ product), applied to recurrent cache writes specifically.
7
+
8
+ Both work on transformer residual streams AND recurrent hidden states.
9
+ """
10
+ from __future__ import annotations
11
+ from dataclasses import dataclass
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+
17
@dataclass
class SAEConfig:
    """Hyperparameters shared by both SAE variants."""
    # Width of the activations being reconstructed.
    input_dim: int
    # dictionary size
    n_features: int
    # L1 coefficient
    sparsity: float = 1e-3
    # "dense" or "rank1"
    sae_type: str = "dense"
    learning_rate: float = 1e-3
    batch_size: int = 64
25
+
26
+
27
class DenseSAE(nn.Module):
    """Standard SAE: ReLU encoder + linear decoder, L1 penalty on the code."""

    def __init__(self, config: SAEConfig):
        super().__init__()
        self.config = config
        self.encoder = nn.Linear(config.input_dim, config.n_features)
        self.decoder = nn.Linear(config.n_features, config.input_dim, bias=False)
        # Tie the decoder's initial weights to the encoder transpose — a
        # common SAE initialization.
        with torch.no_grad():
            self.decoder.weight.data = self.encoder.weight.data.T.clone()

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Non-negative feature code for `x`."""
        pre = self.encoder(x)
        return F.relu(pre)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Linear reconstruction from a feature code."""
        return self.decoder(z)

    def forward(self, x: torch.Tensor):
        code = self.encode(x)
        return self.decode(code), code

    def loss(self, x: torch.Tensor):
        """Total loss plus a dict of scalar diagnostics."""
        x_hat, code = self.forward(x)
        recon = F.mse_loss(x_hat, x)
        l1 = code.abs().mean() * self.config.sparsity
        stats = {
            "recon": recon.item(),
            "l1": l1.item(),
            "n_active": (code > 0).float().mean().item(),
        }
        return recon + l1, stats
55
+
56
+
57
class Rank1FactoredSAE(nn.Module):
    """WriteSAE rank-1 factored atoms: each feature i = v_i w_i^T outer product.

    Designed for recurrent CACHE WRITES specifically — atoms substitute for
    native write contributions at matched Frobenius norm.

    Reference: Eq. 3-factor closed form Δℓ ≈ G·⟨w_i,q_t⟩·⟨v_i,W_u[tok]⟩
    """

    def __init__(self, config: SAEConfig):
        super().__init__()
        assert config.sae_type == "rank1"
        self.config = config
        # Atoms parameterized as v (output) and w (input) vectors.
        self.v = nn.Parameter(torch.randn(config.n_features, config.input_dim) * 0.02)
        self.w = nn.Parameter(torch.randn(config.n_features, config.input_dim) * 0.02)

    def fire(self, x: torch.Tensor) -> torch.Tensor:
        """Firing strengths per atom: relu(<w_i, x>), kept non-negative."""
        # x: (..., input_dim), w: (n_features, input_dim)
        return F.relu(x @ self.w.T)

    def reconstruct(self, x: torch.Tensor) -> torch.Tensor:
        """Reconstruction: sum_i fire_i(x) * v_i."""
        return self.fire(x) @ self.v

    def forward(self, x: torch.Tensor):
        # Fix: compute firing strengths once and reuse them for the
        # reconstruction — the original called fire() twice per forward
        # (once inside reconstruct() and once for z). Values are identical.
        z = self.fire(x)
        x_hat = z @ self.v
        return x_hat, z

    def loss(self, x: torch.Tensor):
        """MSE reconstruction + L1 sparsity; returns (loss, diagnostics)."""
        x_hat, z = self.forward(x)
        recon = F.mse_loss(x_hat, x)
        l1 = z.abs().mean() * self.config.sparsity
        return recon + l1, {"recon": recon.item(), "l1": l1.item(), "n_active": (z > 0).float().mean().item()}
94
+
95
+
96
def build_sae(config: SAEConfig):
    """Factory: instantiate the SAE variant named by config.sae_type."""
    sae_cls = {"dense": DenseSAE, "rank1": Rank1FactoredSAE}.get(config.sae_type)
    if sae_cls is None:
        raise ValueError(f"Unknown sae_type: {config.sae_type}")
    return sae_cls(config)
102
+
103
+
104
def fit_sae(
    activations: torch.Tensor,  # (N, input_dim) flattened
    config: SAEConfig,
    epochs: int = 100,
    device: str = "cpu",
) -> nn.Module:
    """Train SAE on flattened activations.

    Shuffled minibatch AdamW; the final minibatch's diagnostics are stashed
    on the returned module as `last_metrics`.
    """
    sae = build_sae(config).to(device)
    optimizer = torch.optim.AdamW(sae.parameters(), lr=config.learning_rate)
    data = activations.to(device)

    n_samples = len(data)
    final_stats: dict = {}
    for _ in range(epochs):
        # Fresh shuffle every epoch.
        order = torch.randperm(n_samples)
        for start in range(0, n_samples, config.batch_size):
            minibatch = data[order[start:start + config.batch_size]]
            loss, stats = sae.loss(minibatch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            final_stats = stats
    sae.last_metrics = final_stats
    return sae
archscope/transfer.py ADDED
@@ -0,0 +1,188 @@
1
+ """Cross-architecture probe transfer via paired-activation linear alignment.
2
+
3
+ Setup:
4
+ - Two models with different hidden dims (e.g., Pythia 768 vs Kazdov 512)
5
+ - Train a linear projection M: kazdov_space (512) → pythia_space (768) using
6
+ paired (text → activation pair) data
7
+ - A Pythia probe with weights w_py ∈ R^768 transferred to kazdov-space becomes
8
+ w_kz = M^T · w_py ∈ R^512
9
+
10
+ This tests: do probe directions from one architecture transfer to another?
11
+ """
12
+ from __future__ import annotations
13
+ import torch
14
+ import torch.nn as nn
15
+ from dataclasses import dataclass
16
+
17
+
18
@dataclass
class TransferResult:
    """One transfer experiment's metrics."""
    # e.g., "pythia"
    source_arch: str
    # e.g., "kazdov"
    target_arch: str
    source_layer: str
    target_layer: str
    # Number of paired texts used to fit the alignment matrix M.
    n_align_pairs: int

    # source probe on source data (no transfer)
    baseline_source_auroc: float
    # target probe on target data (in-arch reference)
    baseline_target_auroc: float
    # source probe via M applied to target data
    transfer_auroc: float
    # baseline_target - transfer (how much we lose)
    transfer_drop: float
32
+
33
def learn_alignment(
    src_acts: torch.Tensor,  # (N, d_src)
    tgt_acts: torch.Tensor,  # (N, d_tgt)
    ridge: float = 1e-3,
) -> torch.Tensor:
    """Ridge regression: learn M such that src_act ≈ M @ tgt_act.

    M has shape (d_src, d_tgt). The ridge term prevents overfitting on
    small N.
    """
    assert src_acts.shape[0] == tgt_acts.shape[0]
    features = tgt_acts.float()  # X: (N, d_tgt)
    targets = src_acts.float()   # Y: (N, d_src)
    # Normal equations with Tikhonov damping: we want X @ M.T ≈ Y, so
    #   M.T = (X.T X + λI)^-1 X.T Y
    gram = features.T @ features
    damping = ridge * torch.eye(gram.shape[0], device=gram.device)
    m_transposed = torch.linalg.solve(gram + damping, features.T @ targets)
    # Transpose back to the (d_src, d_tgt) convention.
    return m_transposed.T
58
+
59
+
60
def transfer_probe(
    probe_weights_source: torch.Tensor,  # (d_src,) — Pythia probe direction
    probe_bias_source: torch.Tensor,  # scalar
    alignment: torch.Tensor,  # M: (d_src, d_tgt)
) -> tuple[torch.Tensor, torch.Tensor]:
    """Transport a probe from source arch to target arch via alignment matrix.

    Since x_src ≈ M @ x_tgt, we have w_src · x_src ≈ (M.T @ w_src) · x_tgt,
    so the transported direction is M.T @ w_src; the bias carries over as-is.
    """
    transported = alignment.T @ probe_weights_source
    return transported, probe_bias_source
72
+
73
+
74
def auroc_from_scores(scores: torch.Tensor, labels: torch.Tensor) -> float:
    """Quick AUROC."""
    from sklearn.metrics import roc_auc_score
    y_true = labels.cpu().numpy()
    y_score = scores.cpu().numpy()
    try:
        return float(roc_auc_score(y_true, y_score))
    except ValueError:
        # Undefined when the labels contain only one class.
        return float("nan")
81
+
82
+
83
def evaluate_transfer(
    source_model, target_model,
    source_backend, target_backend,
    source_tokenize, target_tokenize,
    align_texts: list[str],  # texts for learning M (no labels needed)
    train_pos: list[str], train_neg: list[str],
    test_pos: list[str], test_neg: list[str],
    source_layer: str,
    target_layer: str,
    source_arch_name: str = "source",
    target_arch_name: str = "target",
) -> TransferResult:
    """Full transfer experiment.

    1. Get paired source/target activations on `align_texts` → learn alignment M
    2. Train source probe on source activations of train_pos/neg
    3. Train target probe on target activations of train_pos/neg (in-arch baseline)
    4. Transfer source probe via M, apply to target activations of test_pos/neg
    5. Compare AUROCs.
    """
    def _extract_pooled(backend, tokenize, texts: list[str], layer: str) -> torch.Tensor:
        """Extract activations and pool to (N, hidden_dim) by averaging seq dim."""
        # no_grad: both models are frozen for this experiment.
        with torch.no_grad():
            rec = backend.extract(tokenize(texts), layers=[layer])[0]
            acts = rec.activations.detach()
            # assumes 3-D means (batch, seq, hidden) — TODO confirm with backend
            return acts.mean(dim=1) if acts.dim() == 3 else acts

    # -- 1. Learn alignment from paired activations on the same texts.
    src_align = _extract_pooled(source_backend, source_tokenize, align_texts, source_layer)
    tgt_align = _extract_pooled(target_backend, target_tokenize, align_texts, target_layer)
    M = learn_alignment(src_align, tgt_align)

    # -- 2. Train source probe (Pythia-style) on source activations
    from . import probes
    src_train_inputs_pos = source_tokenize(train_pos)
    src_train_inputs_neg = source_tokenize(train_neg)
    pf_src = probes.fit_probe(
        source_model, src_train_inputs_pos, src_train_inputs_neg,
        layer_name=source_layer, backend_hint=source_arch_name,
    )

    # Extract source probe weights (linear case)
    src_probe_net = pf_src.probe.net
    if not isinstance(src_probe_net, nn.Linear):
        raise NotImplementedError("transfer only works for linear probes for now")
    # Linear head is (1, d_src); squeeze to a direction vector (d_src,).
    w_src = src_probe_net.weight.data.squeeze(0)
    b_src = src_probe_net.bias.data

    # -- 3. Train target probe (in-arch baseline)
    tgt_train_inputs_pos = target_tokenize(train_pos)
    tgt_train_inputs_neg = target_tokenize(train_neg)
    pf_tgt = probes.fit_probe(
        target_model, tgt_train_inputs_pos, tgt_train_inputs_neg,
        layer_name=target_layer, backend_hint=target_arch_name,
    )

    # -- 4. Transfer: project w_src into target space via M
    w_transferred, b_transferred = transfer_probe(w_src, b_src, M)

    # -- 5. Evaluate on the test set.
    with torch.no_grad():
        # Source baseline: source probe on source test data.
        src_test_acts_pos = _extract_pooled(source_backend, source_tokenize, test_pos, source_layer)
        src_test_acts_neg = _extract_pooled(source_backend, source_tokenize, test_neg, source_layer)
        src_scores = torch.cat([
            src_probe_net(src_test_acts_pos).squeeze(-1),
            src_probe_net(src_test_acts_neg).squeeze(-1),
        ])
        # Labels: 1 for positives, 0 for negatives (concatenated in that order).
        src_labels = torch.cat([
            torch.ones(len(src_test_acts_pos)),
            torch.zeros(len(src_test_acts_neg)),
        ])
        baseline_source = auroc_from_scores(src_scores, src_labels)

        # Target baseline: target probe on target test data.
        tgt_test_acts_pos = _extract_pooled(target_backend, target_tokenize, test_pos, target_layer)
        tgt_test_acts_neg = _extract_pooled(target_backend, target_tokenize, test_neg, target_layer)
        tgt_probe_net = pf_tgt.probe.net
        tgt_scores = torch.cat([
            tgt_probe_net(tgt_test_acts_pos).squeeze(-1),
            tgt_probe_net(tgt_test_acts_neg).squeeze(-1),
        ])
        tgt_labels = torch.cat([
            torch.ones(len(tgt_test_acts_pos)),
            torch.zeros(len(tgt_test_acts_neg)),
        ])
        baseline_target = auroc_from_scores(tgt_scores, tgt_labels)

        # Transfer: transferred probe on target activations
        # (tgt_labels is reused since test order is identical).
        transfer_scores = torch.cat([
            (tgt_test_acts_pos @ w_transferred + b_transferred),
            (tgt_test_acts_neg @ w_transferred + b_transferred),
        ])
        transfer_auroc = auroc_from_scores(transfer_scores, tgt_labels)

    return TransferResult(
        source_arch=source_arch_name,
        target_arch=target_arch_name,
        source_layer=source_layer,
        target_layer=target_layer,
        n_align_pairs=len(align_texts),
        baseline_source_auroc=baseline_source,
        baseline_target_auroc=baseline_target,
        transfer_auroc=transfer_auroc,
        transfer_drop=baseline_target - transfer_auroc,
    )