archscope 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- archscope/__init__.py +30 -0
- archscope/_utils.py +113 -0
- archscope/attribute.py +201 -0
- archscope/backends.py +236 -0
- archscope/bench.py +262 -0
- archscope/circuits.py +255 -0
- archscope/cli.py +120 -0
- archscope/diff.py +212 -0
- archscope/kazdov_backend.py +141 -0
- archscope/lens.py +304 -0
- archscope/neurons.py +118 -0
- archscope/probes.py +160 -0
- archscope/sae.py +127 -0
- archscope/transfer.py +188 -0
- archscope-0.2.2.dist-info/METADATA +324 -0
- archscope-0.2.2.dist-info/RECORD +20 -0
- archscope-0.2.2.dist-info/WHEEL +5 -0
- archscope-0.2.2.dist-info/entry_points.txt +2 -0
- archscope-0.2.2.dist-info/licenses/LICENSE +17 -0
- archscope-0.2.2.dist-info/top_level.txt +1 -0
archscope/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""archscope: unified mech interp toolkit across small + RNN + transformer.
|
|
2
|
+
|
|
3
|
+
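Four core methods unified under a single API (the package also ships the
supporting modules backends, circuits, transfer, bench, lens, and diff):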
|
|
4
|
+
- probes: linear/MLP probes over hidden states (Drop the Act inspired)
|
|
5
|
+
- sae: sparse autoencoders for residual + recurrent state (WriteSAE)
|
|
6
|
+
- neurons: targeted neuron modulation via contrastive search (Nous Research)
|
|
7
|
+
- attribute: activation patching + DIM decomposition (Multi-Agent Sycophancy)
|
|
8
|
+
|
|
9
|
+
The methods share one architecture-agnostic workflow (attribute exposes it
through the functions activation_patch and dim_decompose instead):
|
|
10
|
+
- .extract(model, inputs) -> hidden states / activations
|
|
11
|
+
- .fit(activations, labels) -> learned tool
|
|
12
|
+
- .apply(model, inputs) -> modified outputs / scores / explanations
|
|
13
|
+
|
|
14
|
+
Designed for cross-architecture comparison: transformer, Mamba/SSM, custom RNN.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
# Package version string; presumably kept in sync with the wheel metadata
# (archscope 0.2.2) — verify when cutting a release.
__version__ = "0.2.2"
|
|
18
|
+
|
|
19
|
+
from . import probes, sae, neurons, attribute, backends, circuits, transfer, bench, lens, diff
|
|
20
|
+
|
|
21
|
+
# Optional backend: kazdov_backend registers itself with Backend when it is
# imported, but it can only be imported when the kazdov repo is available.
# An ImportError is deliberately ignored so the rest of the package works.
try:
    from . import kazdov_backend  # noqa: F401
except ImportError:
    pass

# Public names exported by ``from archscope import *``.
__all__ = [
    "probes",
    "sae",
    "neurons",
    "attribute",
    "backends",
    "circuits",
    "transfer",
    "bench",
    "lens",
    "diff",
    "__version__",
]
|
archscope/_utils.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""Internal utilities — not part of the public API.
|
|
2
|
+
|
|
3
|
+
Shared layer-resolution logic used by neurons.py and attribute.py to find
|
|
4
|
+
the actual nn.Module corresponding to a layer_name string like
|
|
5
|
+
"layer_5.residual" across different HuggingFace model architectures.
|
|
6
|
+
"""
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Common HF architecture paths to layer ModuleList.
|
|
12
|
+
# Order matters: first match wins. Add new architectures here.
|
|
13
|
+
_LAYER_PATHS: list[tuple[str, str | None]] = [
|
|
14
|
+
("model", "layers"), # Llama / Mistral / Qwen / GPT-NeoX-style
|
|
15
|
+
("transformer", "h"), # GPT-2 / Falcon
|
|
16
|
+
("transformer", "blocks"), # MPT
|
|
17
|
+
("gpt_neox", "layers"), # Pythia
|
|
18
|
+
("backbone", "layers"), # Mamba / Mamba-2
|
|
19
|
+
("layers", None), # Direct .layers (some custom models, e.g. kazdov)
|
|
20
|
+
("h", None), # Direct .h
|
|
21
|
+
("blocks", None), # Direct .blocks (kazdov)
|
|
22
|
+
]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _parse_layer_index(layer_name: str) -> int | None:
|
|
26
|
+
"""Extract the integer index from a name like 'layer_5.residual'."""
|
|
27
|
+
try:
|
|
28
|
+
idx_part = layer_name.split("_")[1].split(".")[0]
|
|
29
|
+
return int(idx_part)
|
|
30
|
+
except (IndexError, ValueError):
|
|
31
|
+
return None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def resolve_layer_module(model: Any, layer_name: str):
    """Return the nn.Module corresponding to a layer_name across HF naming conventions.

    Handles: Llama, Mistral, Qwen, GPT-2, Falcon, MPT, Pythia, Mamba, custom .blocks.

    The integer N is parsed from ``layer_name`` (``"layer_5.residual"`` -> 5;
    the suffix after the dot is ignored). Each path in ``_LAYER_PATHS`` is
    tried in order, and the first container that can be indexed with N wins.

    Returns None if the layer name cannot be parsed or no path matches.
    """
    idx = _parse_layer_index(layer_name)
    if idx is None:
        return None
    for parent_attr, child_attr in _LAYER_PATHS:
        parent_obj = getattr(model, parent_attr, None)
        if parent_obj is None:
            continue
        layers = parent_obj if child_attr is None else getattr(parent_obj, child_attr, None)
        if layers is None:
            continue
        try:
            return layers[idx]
        except (IndexError, KeyError, TypeError):
            # Out of range, a dict-like container (e.g. an nn.ModuleDict keyed
            # by strings raises KeyError for an int), or not subscriptable at
            # all: fall through to the next candidate path.
            continue
    return None
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
_UNEMBED_PATHS = [
|
|
59
|
+
"lm_head", # Llama, Pythia, Mistral, Mamba, kazdov, most HF CausalLMs
|
|
60
|
+
"embed_out", # some HF models
|
|
61
|
+
"output_layer", # some custom models
|
|
62
|
+
]
|
|
63
|
+
|
|
64
|
+
_FINAL_NORM_PATHS: list[tuple[str, str | None]] = [
|
|
65
|
+
("model", "norm"), # Llama / Mistral
|
|
66
|
+
("gpt_neox", "final_layer_norm"), # Pythia
|
|
67
|
+
("transformer", "ln_f"), # GPT-2 / Falcon
|
|
68
|
+
("backbone", "norm_f"), # Mamba
|
|
69
|
+
("ln_f", None), # kazdov (top-level)
|
|
70
|
+
]
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def resolve_unembedding(model: Any):
    """Return the model's unembedding / LM-head module, or None.

    Each name in ``_UNEMBED_PATHS`` is looked up directly on ``model`` in
    order; the first attribute that exists and is not None is returned.
    """
    candidates = (getattr(model, attr_name, None) for attr_name in _UNEMBED_PATHS)
    return next((head for head in candidates if head is not None), None)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def resolve_final_norm(model: Any):
    """Return the model's final pre-unembedding norm module, or None.

    Walks ``_FINAL_NORM_PATHS`` in order and returns the first non-None hit.
    """

    def _lookup(owner_attr, leaf_attr):
        # None owner -> no hit; a None leaf means the owner is the norm itself.
        owner = getattr(model, owner_attr, None)
        if owner is None or leaf_attr is None:
            return owner
        return getattr(owner, leaf_attr, None)

    hits = (_lookup(owner_attr, leaf_attr) for owner_attr, leaf_attr in _FINAL_NORM_PATHS)
    return next((found for found in hits if found is not None), None)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def resolve_subcomponent_module(model: Any, idx: int, component: str):
    """Find an attention or MLP submodule inside a layer at index `idx`.

    component: "attention" or "mlp". Candidate attribute names on the layer
    are tried in order:
      attention: self_attn, attn, attention, self_attention (Falcon)
      mlp:       mlp, feed_forward, ffn

    Returns None if the layer or the component isn't found, or if
    ``component`` is neither "attention" nor "mlp".
    """
    layer = resolve_layer_module(model, f"layer_{idx}.residual")
    if layer is None:
        return None
    if component == "attention":
        candidates = ("self_attn", "attn", "attention", "self_attention")
    elif component == "mlp":
        candidates = ("mlp", "feed_forward", "ffn")
    else:
        return None
    for attr in candidates:
        sub = getattr(layer, attr, None)
        if sub is not None:
            return sub
    return None
|
archscope/attribute.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
"""Activation patching + DIM decomposition (Multi-Agent Sycophancy, 2605.12991).
|
|
2
|
+
|
|
3
|
+
Two methods:
|
|
4
|
+
- `activation_patch`: replace activations from one prompt with another at
|
|
5
|
+
specified layer range. Measures how much of the behavioral gap is "restored"
|
|
6
|
+
by the patch.
|
|
7
|
+
- `dim_decompose`: difference-in-means decomposition of attribution per
|
|
8
|
+
component (MLP vs attention).
|
|
9
|
+
|
|
10
|
+
Use cases:
|
|
11
|
+
- Localize behavior to specific layers (e.g., "L14-L18 restores 96.8% of gap")
|
|
12
|
+
- Separate attention vs MLP contribution
|
|
13
|
+
"""
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
|
|
19
|
+
from .backends import Backend
|
|
20
|
+
from ._utils import resolve_layer_module, resolve_subcomponent_module
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class PatchResult:
    """Summary of one activation-patching run.

    Fields appear in constructor (positional) order. As populated by
    ``activation_patch``: gap_restored = (patched - clean) / (baseline - clean).
    """

    # (min, max) of the patched layer indices.
    layer_range: tuple[int, int]
    # Fraction of the behavioral gap closed by patching.
    gap_restored: float
    # Label of the measured quantity (e.g. "logit_diff").
    target_metric: str
    # Metric on the unpatched source prompt.
    baseline_metric: float
    # Metric on the target prompt with source activations patched in.
    patched_metric: float
    # Metric on the unpatched target prompt.
    clean_metric: float
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class DIMResult:
    """Per-component attribution of a metric gap (attention vs MLP).

    Fields appear in constructor (positional) order.
    """

    # Component name -> fraction of the gap its patch closes,
    # e.g. {"attention": 0.45, "mlp": 0.02}.
    components: dict[str, float]
    # The gap being decomposed (dim_decompose stores metric_a - metric_b).
    total: float
    # (min, max) of the layer indices involved.
    layer_range: tuple[int, int]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _remove_hooks(hooks):
|
|
43
|
+
"""Cleanly detach a list of forward-hook handles."""
|
|
44
|
+
for h in hooks:
|
|
45
|
+
h.remove()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def activation_patch(
    model,
    prompt_source,
    prompt_target,
    layer_indices: list[int],
    metric_fn,
    backend_hint: str | None = None,
) -> PatchResult:
    """Replace activations at chosen layers with those from a source prompt.

    Work performed, in order:
      * backend extraction of ``layer_N.residual`` on ``prompt_source``;
      * ``prompt_target`` unpatched                         -> ``clean_metric``;
      * ``prompt_target`` with each chosen layer's output
        replaced by the source activation                   -> ``patched_metric``;
      * ``prompt_source`` unpatched                         -> ``baseline_metric``.

    ``gap_restored = (patched - clean) / (baseline - clean)``, or 0.0 when
    ``|baseline - clean| < 1e-9``.

    Args:
        model: any model exposing `model(**inputs, return_dict=True).logits`
        prompt_source: tokenized inputs to extract clean activations from
        prompt_target: tokenized inputs where activations will be replaced
        layer_indices: which layer indices to patch (any iterable of ints)
        metric_fn: function mapping `model_outputs → scalar` (e.g., logit diff)
        backend_hint: backend name for extraction

    Returns:
        PatchResult with the fraction of behavioral gap closed by patching.

    Raises:
        ValueError: if ``layer_indices`` is empty.

    Layers whose module cannot be resolved, or for which the backend returned
    no activation, run unmodified and are reported via ``warnings.warn``.
    NOTE(review): the source activation replaces the whole layer output, so
    both prompts presumably need matching batch size and sequence length —
    confirm with callers.
    """
    import warnings

    # Materialize once: a generator would otherwise be exhausted by the name
    # list below and break min()/max() at the end.
    layer_indices = list(layer_indices)
    if not layer_indices:
        raise ValueError("activation_patch requires at least one layer index")

    backend = Backend.for_model(model, hint=backend_hint)
    layer_names = [f"layer_{i}.residual" for i in layer_indices]

    # 1. Source activations to patch in. Pair them with the requested layers
    #    by record name when the backend uses those names; otherwise keep the
    #    positional pairing (custom backends may name their records differently).
    records = backend.extract(prompt_source, layers=layer_names)
    by_name = {rec.layer_name: rec.activations for rec in records}
    if all(name in by_name for name in layer_names):
        src_acts = [by_name[name] for name in layer_names]
    else:
        src_acts = [rec.activations for rec in records]
    paired = list(zip(layer_names, src_acts))
    unpatched = layer_names[len(paired):]  # requested layers with no activation

    # 2. Unpatched target run.
    with torch.no_grad():
        target_clean_out = model(**prompt_target, output_hidden_states=False, return_dict=True)
    clean_metric = metric_fn(target_clean_out)

    # 3. Patched target run. Hooks are registered inside the try so a failure
    #    part-way through never leaves earlier hooks attached to the model.
    hooks = []
    try:
        for layer_name, src_h in paired:
            module = resolve_layer_module(model, layer_name)
            if module is None:
                unpatched.append(layer_name)
                continue

            def hook(mod, inp, out, replacement=src_h):
                # Layers may return a tuple whose first element is the hidden
                # state; swap only that element and keep the rest.
                if isinstance(out, tuple):
                    return (replacement,) + out[1:]
                return replacement

            hooks.append(module.register_forward_hook(hook))

        if unpatched:
            warnings.warn(
                f"activation_patch: could not patch layers {unpatched}; "
                "they run unmodified",
                stacklevel=2,
            )
        with torch.no_grad():
            patched_out = model(**prompt_target, output_hidden_states=False, return_dict=True)
        patched_metric = metric_fn(patched_out)
    finally:
        _remove_hooks(hooks)

    # 4. Unpatched source run (no hooks attached).
    with torch.no_grad():
        src_out = model(**prompt_source, output_hidden_states=False, return_dict=True)
    source_metric = metric_fn(src_out)

    gap = source_metric - clean_metric
    gap_restored = 0.0 if abs(gap) < 1e-9 else (patched_metric - clean_metric) / gap

    return PatchResult(
        layer_range=(min(layer_indices), max(layer_indices)),
        gap_restored=float(gap_restored),
        target_metric="custom",
        baseline_metric=float(source_metric),
        patched_metric=float(patched_metric),
        clean_metric=float(clean_metric),
    )
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def dim_decompose(
    model,
    prompt_a,
    prompt_b,
    layer_indices: list[int],
    metric_fn,
    components: tuple[str, ...] = ("attention", "mlp"),
    backend_hint: str | None = None,  # kept for symmetry with activation_patch
) -> DIMResult:
    """Decompose a behavioral difference into per-component contributions.

    For each component (default: attention, mlp), captures its output at every
    layer in ``layer_indices`` during `prompt_a`, patches those outputs into
    the forward pass on `prompt_b`, and reports the fraction of the metric gap
    closed: ``(patched_metric - metric_b) / (metric_a - metric_b)``.

    ``DIMResult.total`` is ``metric_a - metric_b``. When that gap is below
    1e-9 in magnitude every contribution is 0.0 (the same guard
    activation_patch uses). A component with no matching submodule in any of
    the requested layers gets 0.0 and a warning.

    `backend_hint` is accepted but unused (this function uses module hooks
    directly via `resolve_subcomponent_module`).

    Raises:
        ValueError: if ``layer_indices`` is empty.
    """
    import warnings

    del backend_hint  # unused; kept for API symmetry

    layer_indices = list(layer_indices)
    if not layer_indices:
        raise ValueError("dim_decompose requires at least one layer index")

    with torch.no_grad():
        out_a = model(**prompt_a, return_dict=True)
        out_b = model(**prompt_b, return_dict=True)
    metric_a = metric_fn(out_a)
    metric_b = metric_fn(out_b)
    total_gap = metric_a - metric_b
    # Sign-safe near-zero guard. A ``total_gap + 1e-9`` denominator would
    # shrink negative gaps and divide by zero when total_gap == -1e-9.
    degenerate_gap = bool(abs(total_gap) < 1e-9)

    contributions: dict[str, float] = {}
    for comp in components:
        # Resolve each layer's submodule once; reused for capture and patch.
        modules = {}
        for idx in layer_indices:
            module = resolve_subcomponent_module(model, idx, comp)
            if module is not None:
                modules[idx] = module
        if not modules:
            warnings.warn(
                f"dim_decompose: no '{comp}' submodule found in layers "
                f"{layer_indices}; contribution reported as 0.0",
                stacklevel=2,
            )
            contributions[comp] = 0.0
            continue

        # 1) Capture component outputs during prompt_a.
        captured: dict[int, list] = {idx: [] for idx in modules}
        capture_hooks = []
        try:
            for idx, module in modules.items():

                def capture(mod, inp, out, store=captured[idx]):
                    store.append(out[0] if isinstance(out, tuple) else out)

                capture_hooks.append(module.register_forward_hook(capture))
            with torch.no_grad():
                model(**prompt_a, return_dict=True)
        finally:
            _remove_hooks(capture_hooks)

        # 2) Patch captured outputs into prompt_b's forward pass.
        patch_hooks = []
        try:
            for idx, module in modules.items():
                stored = captured[idx]
                if not stored:
                    continue

                def patch(mod, inp, out, repl=stored[0]):
                    # Replace only the first element of tuple outputs.
                    if isinstance(out, tuple):
                        return (repl,) + out[1:]
                    return repl

                patch_hooks.append(module.register_forward_hook(patch))
            with torch.no_grad():
                patched_out = model(**prompt_b, return_dict=True)
            patched_metric = metric_fn(patched_out)
        finally:
            _remove_hooks(patch_hooks)

        contributions[comp] = (
            0.0 if degenerate_gap else float((patched_metric - metric_b) / total_gap)
        )

    return DIMResult(
        components=contributions,
        total=float(total_gap),
        layer_range=(min(layer_indices), max(layer_indices)),
    )
|
archscope/backends.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
"""Architecture-agnostic activation extraction.
|
|
2
|
+
|
|
3
|
+
The core abstraction: a `Backend` knows how to hook into a model and pull out
|
|
4
|
+
hidden states at named layers, regardless of underlying framework
|
|
5
|
+
(PyTorch/JAX/custom).
|
|
6
|
+
|
|
7
|
+
Three backends implemented:
|
|
8
|
+
- TransformerBackend: HuggingFace transformers (residual stream per layer)
|
|
9
|
+
- MambaBackend: state-space models (hidden state + ssm state per layer)
|
|
10
|
+
- RecurrentBackend: generic RNN-like (extracts hidden state per timestep)
|
|
11
|
+
|
|
12
|
+
Custom architectures (e.g., kazdov MoBE-BCN) register via Backend.register().
|
|
13
|
+
"""
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
import abc
|
|
16
|
+
import torch
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class ActivationRecord:
    """Captured activations from a single forward pass.

    Attributes:
        layer_name: identifier of the layer. TransformerBackend and
            MambaBackend use "layer_N.<kind>" names such as
            "layer_5.residual" or "layer_5.ssm_state"; RecurrentBackend uses
            whatever keys the model's get_hidden_states() returns.
        activations: captured tensor; residual streams are typically
            (batch, seq_len, hidden_dim).
        meta: arch-specific metadata (e.g., {'kind': 'residual'} or {'kind': 'ssm_state'})
    """

    layer_name: str
    activations: Any  # torch.Tensor or jax.Array
    meta: dict
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Backend(abc.ABC):
    """Abstract interface — extract activations from any model architecture.

    Concrete backends register under a short name with
    ``@Backend.register("name")``; :meth:`for_model` chooses one for a model.
    Subclasses implement :meth:`layer_names`, :meth:`extract` and
    :meth:`hidden_dim`.
    """

    # name -> backend class. Only ever mutated in place by ``register``, so a
    # single dict is shared by every subclass.
    _registry: dict[str, type["Backend"]] = {}

    # HF ``config.model_type`` values auto-routed to the "transformer" backend.
    # That backend only needs ``output_hidden_states=True`` plus
    # ``config.num_hidden_layers`` / ``config.hidden_size``.
    # NOTE(review): confirm for each newly added family.
    _TRANSFORMER_MODEL_TYPES: tuple[str, ...] = (
        "llama", "mistral", "mixtral", "qwen2", "qwen3",
        "gpt2", "gpt_neox", "falcon", "mpt",
        "gemma", "gemma2", "phi", "phi3",
    )
    # HF ``config.model_type`` values auto-routed to the "mamba" backend.
    _MAMBA_MODEL_TYPES: tuple[str, ...] = ("mamba", "mamba2")

    @classmethod
    def register(cls, name: str):
        """Class decorator that stores the decorated class under ``name``."""
        def deco(klass):
            cls._registry[name] = klass
            return klass
        return deco

    @classmethod
    def for_model(cls, model: Any, hint: str | None = None) -> "Backend":
        """Auto-detect or use hint to select backend.

        Resolution order:
          1. ``hint``, if it names a registered backend (unknown hints are ignored);
          2. ``model.config.model_type`` in ``_TRANSFORMER_MODEL_TYPES`` -> "transformer";
          3. ``model.config.model_type`` in ``_MAMBA_MODEL_TYPES`` -> "mamba";
          4. the "recurrent" backend, if registered.

        Raises:
            ValueError: if nothing matches and no "recurrent" backend is registered.
        """
        if hint and hint in cls._registry:
            return cls._registry[hint](model)
        # Auto-detect via the HF config, when the model has one.
        model_type = getattr(getattr(model, "config", None), "model_type", None)
        if model_type in cls._TRANSFORMER_MODEL_TYPES:
            return cls._registry["transformer"](model)
        if model_type in cls._MAMBA_MODEL_TYPES:
            return cls._registry["mamba"](model)
        # Default fallback
        if "recurrent" in cls._registry:
            return cls._registry["recurrent"](model)
        raise ValueError(f"No backend matches model {type(model).__name__}. Register via Backend.register('name').")

    def __init__(self, model: Any):
        self.model = model

    @abc.abstractmethod
    def layer_names(self) -> list[str]:
        """Return list of layer identifiers we can hook."""
        ...

    @abc.abstractmethod
    def extract(self, inputs: Any, layers: list[str] | None = None) -> list[ActivationRecord]:
        """Run forward pass, return activations at requested layers (all if None)."""
        ...

    @abc.abstractmethod
    def hidden_dim(self, layer_name: str) -> int:
        """Dimensionality of activations at a given layer."""
        ...
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@Backend.register("transformer")
class TransformerBackend(Backend):
    """HuggingFace transformers backend — extracts residual stream per layer.

    Layer names are ``"layer_N.residual"`` for N in
    ``range(config.num_hidden_layers)``. Activations come from the model's
    ``hidden_states`` output: entry 0 is the embedding output, so block N
    maps to entry N + 1.

    NOTE(review): in many HF implementations the last ``hidden_states`` entry
    is taken after the final layer norm — confirm before comparing the last
    layer with earlier ones.
    """

    def layer_names(self) -> list[str]:
        # Standard HF decoders report their depth as config.num_hidden_layers.
        n_layers = getattr(self.model.config, "num_hidden_layers", 0)
        return [f"layer_{i}.residual" for i in range(n_layers)]

    def extract(self, inputs, layers=None):
        """Run one forward pass and return a residual-stream record per requested layer.

        ``layers=None`` means every layer; an explicit empty list yields no records.
        """
        layers = self.layer_names() if layers is None else list(layers)
        # Use HF's output_hidden_states=True for clean extraction.
        # Wrap in no_grad: extraction shouldn't build a backward graph.
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True, return_dict=True)
        records = []
        hidden_states = outputs.hidden_states  # tuple of (n_layers+1) tensors
        for layer_name in layers:
            # Parse "layer_N.residual" → N+1 (since [0] is embedding output)
            idx = int(layer_name.split("_")[1].split(".")[0]) + 1
            records.append(ActivationRecord(
                layer_name=layer_name,
                activations=hidden_states[idx],
                meta={"kind": "residual", "arch": "transformer"},
            ))
        return records

    def hidden_dim(self, layer_name: str) -> int:
        # Same width for every layer of a standard HF decoder.
        return self.model.config.hidden_size
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@Backend.register("mamba")
class MambaBackend(Backend):
    """Mamba/Mamba-2 backend — extracts residual stream AND SSM recurrent state.

    Works with HuggingFace MambaForCausalLM (and Mamba2ForCausalLM).

    Two flavors of activations exposed:
      - `layer_N.residual`  → residual stream after block N (B, T, hidden_size)
      - `layer_N.ssm_state` → final SSM recurrent state after processing the
        sequence at block N: shape (B, intermediate_size, ssm_state_size).
        This exposes the recurrent state used by Mamba-style models —
        useful when experiments need access to memory-like state rather
        than residual activations alone.
    """

    def layer_names(self) -> list[str]:
        # Mamba configs use n_layer; fall back to num_hidden_layers.
        n_layers = getattr(self.model.config, "n_layer", 0) or getattr(self.model.config, "num_hidden_layers", 0)
        out = []
        for i in range(n_layers):
            out.append(f"layer_{i}.residual")
            out.append(f"layer_{i}.ssm_state")
        return out

    def extract(self, inputs, layers=None):
        """Run one forward pass and return records for the requested layers.

        ``layers=None`` means every name from :meth:`layer_names`; an explicit
        empty list yields no records. ``ssm_state`` entries are skipped when
        the cache holds no state for that layer.
        """
        # list(): the name sequence is scanned several times below.
        layers = self.layer_names() if layers is None else list(layers)

        need_residual = any(".residual" in ln for ln in layers)
        need_ssm = any(".ssm_state" in ln for ln in layers)

        with torch.no_grad():
            if need_ssm:
                # Pass a DynamicCache so Mamba writes final SSM states to it.
                # NOTE(review): relies on the installed transformers exposing
                # DynamicCache from modeling_mamba with a
                # ``layers[i].recurrent_states`` layout — confirm per version.
                from transformers.models.mamba.modeling_mamba import DynamicCache
                cache = DynamicCache(config=self.model.config)
                outputs = self.model(
                    **inputs,
                    cache_params=cache,
                    use_cache=True,
                    output_hidden_states=need_residual,
                    return_dict=True,
                )
            else:
                outputs = self.model(
                    **inputs, output_hidden_states=True, return_dict=True
                )
                cache = None

        records = []
        for layer_name in layers:
            idx = int(layer_name.split("_")[1].split(".")[0])
            if ".residual" in layer_name:
                records.append(ActivationRecord(
                    layer_name=layer_name,
                    activations=outputs.hidden_states[idx + 1].detach(),  # +1 because [0] is embedding output
                    meta={"kind": "residual", "arch": "mamba", "shape_meaning": "(B, T, hidden_size)"},
                ))
            elif ".ssm_state" in layer_name:
                if cache is None or not hasattr(cache, "layers") or idx >= len(cache.layers):
                    continue
                ssm = cache.layers[idx].recurrent_states
                if ssm is None:
                    continue
                records.append(ActivationRecord(
                    layer_name=layer_name,
                    activations=ssm.detach(),
                    meta={
                        "kind": "ssm_state",
                        "arch": "mamba",
                        "shape_meaning": "(B, intermediate_size, ssm_state_size)",
                        "d_inner": ssm.shape[-2],
                        "d_state": ssm.shape[-1],
                    },
                ))
        return records

    def hidden_dim(self, layer_name: str) -> int:
        if ".ssm_state" in layer_name:
            # SSM state is (intermediate_size × ssm_state_size), flattened.
            d_inner = getattr(self.model.config, "intermediate_size", None)
            d_state = getattr(self.model.config, "state_size", None)
            if d_inner and d_state:
                return d_inner * d_state
            # Fallback: introspect from a block
            mixer = self.model.backbone.layers[0].mixer
            return mixer.intermediate_size * mixer.ssm_state_size
        return getattr(self.model.config, "hidden_size", None) or getattr(self.model.config, "d_model", None)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
@Backend.register("recurrent")
class RecurrentBackend(Backend):
    """Fallback backend for custom RNN-family models (e.g., kazdov MoBE-BCN).

    The default :meth:`extract` needs the model to provide
    ``get_hidden_states(inputs)`` returning a mapping of name -> tensor;
    otherwise subclass and override :meth:`extract` (or attach your own
    forward hooks).

    NOTE: this default implementation does not filter by the ``layers``
    argument — every entry from ``get_hidden_states`` becomes a record.
    """

    def layer_names(self) -> list[str]:
        # Depth from model.n_layer, then model.num_layers; else one layer.
        for depth_attr in ("n_layer", "num_layers"):
            if hasattr(self.model, depth_attr):
                depth = getattr(self.model, depth_attr)
                return [f"layer_{i}.hidden" for i in range(depth)]
        return ["layer_0.hidden"]

    def extract(self, inputs, layers=None):
        # Guard clause: without get_hidden_states there is nothing generic to do.
        if not hasattr(self.model, "get_hidden_states"):
            raise NotImplementedError(
                f"RecurrentBackend default extract() not implemented for {type(self.model).__name__}. "
                "Subclass and override extract(), or call model.get_hidden_states() yourself."
            )
        states = self.model.get_hidden_states(inputs)
        records = []
        for name, tensor in states.items():
            records.append(
                ActivationRecord(
                    layer_name=name,
                    activations=tensor,
                    meta={"kind": "hidden", "arch": "recurrent"},
                )
            )
        return records

    def hidden_dim(self, layer_name: str) -> int:
        # First width attribute the model defines wins.
        width_attr = next(
            (name for name in ("d_model", "hidden_size", "d_hidden", "n_embd") if hasattr(self.model, name)),
            None,
        )
        if width_attr is None:
            raise ValueError("Cannot infer hidden_dim — override hidden_dim() in subclass.")
        return getattr(self.model, width_attr)
|