archscope-0.2.2-py3-none-any.whl

archscope/__init__.py ADDED
@@ -0,0 +1,30 @@
+ """archscope: unified mech interp toolkit across small + RNN + transformer.
+
+ Four core methods unified under a single API:
+ - probes: linear/MLP probes over hidden states (Drop the Act inspired)
+ - sae: sparse autoencoders for residual + recurrent state (WriteSAE)
+ - neurons: targeted neuron modulation via contrastive search (Nous Research)
+ - attribute: activation patching + DIM decomposition (Multi-Agent Sycophancy)
+
+ Each method exposes the same architecture-agnostic API:
+ - .extract(model, inputs) -> hidden states / activations
+ - .fit(activations, labels) -> learned tool
+ - .apply(model, inputs) -> modified outputs / scores / explanations
+
+ Designed for cross-architecture comparison: transformer, Mamba/SSM, custom RNN.
+ """
+
+ __version__ = "0.2.2"
+
+ from . import probes, sae, neurons, attribute, backends, circuits, transfer, bench, lens, diff
+
+ # Kazdov backend registers itself on import — optional, only if kazdov repo present
+ try:
+     from . import kazdov_backend  # noqa: F401
+ except ImportError:
+     pass
+
+ __all__ = [
+     "probes", "sae", "neurons", "attribute", "backends",
+     "circuits", "transfer", "bench", "lens", "diff", "__version__",
+ ]
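
For orientation, here is a minimal sketch of the workflow the docstring above describes, using the Backend API defined in backends.py below. The gpt2 checkpoint is just an example; the probe/sae calls are left as comments because probes.py and sae.py are not part of this diff.

from transformers import AutoModelForCausalLM, AutoTokenizer
from archscope.backends import Backend

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok("hello world", return_tensors="pt")

backend = Backend.for_model(model)  # auto-detected from config.model_type
acts = backend.extract(inputs, layers=["layer_5.residual"])
print(acts[0].activations.shape)  # torch.Size([1, 2, 768]) for gpt2

# probes/sae/neurons/attribute follow the same .extract/.fit/.apply contract;
# their concrete class names are not shown in this diff, so none are called here.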
archscope/_utils.py ADDED
@@ -0,0 +1,113 @@
+ """Internal utilities — not part of the public API.
+
+ Shared layer-resolution logic used by neurons.py and attribute.py to find
+ the actual nn.Module corresponding to a layer_name string like
+ "layer_5.residual" across different HuggingFace model architectures.
+ """
+ from __future__ import annotations
+ from typing import Any
+
+
+ # Common HF architecture paths to the layer ModuleList.
+ # Order matters: first match wins. Add new architectures here.
+ _LAYER_PATHS: list[tuple[str, str | None]] = [
+     ("model", "layers"),        # Llama / Mistral / Qwen
+     ("transformer", "h"),       # GPT-2 / Falcon
+     ("transformer", "blocks"),  # MPT
+     ("gpt_neox", "layers"),     # GPT-NeoX / Pythia
+     ("backbone", "layers"),     # Mamba / Mamba-2
+     ("layers", None),           # Direct .layers (some custom models, e.g. kazdov)
+     ("h", None),                # Direct .h
+     ("blocks", None),           # Direct .blocks (kazdov)
+ ]
+
+
+ def _parse_layer_index(layer_name: str) -> int | None:
+     """Extract the integer index from a name like 'layer_5.residual'."""
+     try:
+         idx_part = layer_name.split("_")[1].split(".")[0]
+         return int(idx_part)
+     except (IndexError, ValueError):
+         return None
+
+
+ def resolve_layer_module(model: Any, layer_name: str):
+     """Return the nn.Module corresponding to a layer_name across HF naming conventions.
+
+     Handles: Llama, Mistral, Qwen, GPT-2, Falcon, MPT, Pythia, Mamba, custom .blocks.
+
+     Returns None if the layer name cannot be parsed or no path matches.
+     """
+     idx = _parse_layer_index(layer_name)
+     if idx is None:
+         return None
+     for parent_attr, child_attr in _LAYER_PATHS:
+         parent_obj = getattr(model, parent_attr, None)
+         if parent_obj is None:
+             continue
+         layers = parent_obj if child_attr is None else getattr(parent_obj, child_attr, None)
+         if layers is None:
+             continue
+         try:
+             return layers[idx]
+         except (IndexError, TypeError):
+             continue
+     return None
+
+
+ _UNEMBED_PATHS = [
+     "lm_head",       # Llama, Pythia, Mistral, Mamba, kazdov, most HF CausalLMs
+     "embed_out",     # some HF models
+     "output_layer",  # some custom models
+ ]
+
+ _FINAL_NORM_PATHS: list[tuple[str, str | None]] = [
+     ("model", "norm"),                 # Llama / Mistral
+     ("gpt_neox", "final_layer_norm"),  # Pythia
+     ("transformer", "ln_f"),           # GPT-2 / Falcon
+     ("backbone", "norm_f"),            # Mamba
+     ("ln_f", None),                    # kazdov (top-level)
+ ]
+
+
+ def resolve_unembedding(model: Any):
+     """Find the model's unembedding / lm_head module. Returns nn.Module or None."""
+     for path in _UNEMBED_PATHS:
+         m = getattr(model, path, None)
+         if m is not None:
+             return m
+     return None
+
+
+ def resolve_final_norm(model: Any):
+     """Find the model's final pre-unembedding layer norm. Returns module or None."""
+     for parent_attr, child_attr in _FINAL_NORM_PATHS:
+         parent_obj = getattr(model, parent_attr, None)
+         if parent_obj is None:
+             continue
+         norm = parent_obj if child_attr is None else getattr(parent_obj, child_attr, None)
+         if norm is not None:
+             return norm
+     return None
+
+
+ def resolve_subcomponent_module(model: Any, idx: int, component: str):
+     """Find an attention or MLP submodule inside a layer at index `idx`.
+
+     component: "attention" or "mlp".
+     Returns None if the component isn't found.
+     """
+     layer = resolve_layer_module(model, f"layer_{idx}.residual")
+     if layer is None:
+         return None
+     if component == "attention":
+         for attr in ("self_attn", "attn", "attention"):
+             sub = getattr(layer, attr, None)
+             if sub is not None:
+                 return sub
+     elif component == "mlp":
+         for attr in ("mlp", "feed_forward", "ffn"):
+             sub = getattr(layer, attr, None)
+             if sub is not None:
+                 return sub
+     return None
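
A quick sanity check of the resolution tables against a stock HF model: gpt2 resolves through the ("transformer", "h") and ("transformer", "ln_f") paths. These functions are documented as internal, so this is illustration rather than public API.

from transformers import GPT2LMHeadModel

from archscope._utils import (
    resolve_final_norm,
    resolve_layer_module,
    resolve_unembedding,
)

model = GPT2LMHeadModel.from_pretrained("gpt2")

assert resolve_layer_module(model, "layer_5.residual") is model.transformer.h[5]
assert resolve_final_norm(model) is model.transformer.ln_f
assert resolve_unembedding(model) is model.lm_head
assert resolve_layer_module(model, "not_a_layer_name") is None  # unparseable -> None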
archscope/attribute.py ADDED
@@ -0,0 +1,201 @@
+ """Activation patching + DIM decomposition (Multi-Agent Sycophancy, 2605.12991).
+
+ Two methods:
+ - `activation_patch`: run one prompt while splicing in another prompt's
+   activations over a specified layer range. Measures how much of the
+   behavioral gap is "restored" by the patch.
+ - `dim_decompose`: difference-in-means decomposition of attribution per
+   component (MLP vs attention).
+
+ Use cases:
+ - Localize behavior to specific layers (e.g., "L14-L18 restores 96.8% of gap")
+ - Separate attention vs MLP contribution
+ """
+ from __future__ import annotations
+ from dataclasses import dataclass
+
+ import torch
+
+ from .backends import Backend
+ from ._utils import resolve_layer_module, resolve_subcomponent_module
+
+
+ @dataclass
+ class PatchResult:
+     """Outcome of a single activation_patch experiment."""
+     layer_range: tuple[int, int]
+     gap_restored: float     # fraction of behavioral gap closed by patching
+     target_metric: str      # what we measured (e.g., "logit_diff")
+     baseline_metric: float  # metric on the clean source prompt
+     patched_metric: float   # metric on the target prompt with patched activations
+     clean_metric: float     # metric on the unpatched target prompt
+
+
+ @dataclass
+ class DIMResult:
+     """Difference-in-means attribution per component (attention vs MLP)."""
+     components: dict[str, float]  # e.g., {"attention": 0.45, "mlp": 0.02}
+     total: float
+     layer_range: tuple[int, int]
+
+
+ def _remove_hooks(hooks):
+     """Cleanly detach a list of forward-hook handles."""
+     for h in hooks:
+         h.remove()
+
+
+ def activation_patch(
+     model,
+     prompt_source,
+     prompt_target,
+     layer_indices: list[int],
+     metric_fn,
+     backend_hint: str | None = None,
+ ) -> PatchResult:
+     """Replace activations at chosen layers with those from a source prompt.
+
+     Args:
+         model: any model exposing `model(**inputs, return_dict=True).logits`
+         prompt_source: tokenized inputs to extract clean activations from
+         prompt_target: tokenized inputs where activations will be replaced
+         layer_indices: which layer indices to patch
+         metric_fn: function mapping `model_outputs -> scalar` (e.g., logit diff)
+         backend_hint: backend name for extraction
+
+     Returns:
+         PatchResult with the fraction of behavioral gap closed by patching.
+     """
+     backend = Backend.for_model(model, hint=backend_hint)
+     layer_names = [f"layer_{i}.residual" for i in layer_indices]
+
+     # 1. Clean source: extract activations to patch in.
+     src_acts = backend.extract(prompt_source, layers=layer_names)
+
+     # 2. Clean target: baseline metric.
+     with torch.no_grad():
+         target_clean_out = model(**prompt_target, output_hidden_states=False, return_dict=True)
+     clean_metric = metric_fn(target_clean_out)
+
+     # 3. Patched target: hook in source activations.
+     hooks = []
+     for layer_name, src_rec in zip(layer_names, src_acts):
+         module = resolve_layer_module(model, layer_name)
+         if module is None:
+             continue
+         src_h = src_rec.activations
+
+         # Bind src_h via a default argument so each hook keeps its own tensor.
+         def hook(mod, inp, out, replacement=src_h):
+             if isinstance(out, tuple):
+                 return (replacement,) + out[1:]
+             return replacement
+         hooks.append(module.register_forward_hook(hook))
+
+     try:
+         with torch.no_grad():
+             patched_out = model(**prompt_target, output_hidden_states=False, return_dict=True)
+         patched_metric = metric_fn(patched_out)
+     finally:
+         _remove_hooks(hooks)
+
+     # Source baseline (no hooks).
+     with torch.no_grad():
+         src_out = model(**prompt_source, output_hidden_states=False, return_dict=True)
+     source_metric = metric_fn(src_out)
+
+     gap = source_metric - clean_metric
+     gap_restored = 0.0 if abs(gap) < 1e-9 else (patched_metric - clean_metric) / gap
+
+     return PatchResult(
+         layer_range=(min(layer_indices), max(layer_indices)),
+         gap_restored=float(gap_restored),
+         target_metric="custom",
+         baseline_metric=float(source_metric),
+         patched_metric=float(patched_metric),
+         clean_metric=float(clean_metric),
+     )
+
+
+ def dim_decompose(
+     model,
+     prompt_a,
+     prompt_b,
+     layer_indices: list[int],
+     metric_fn,
+     components: tuple[str, ...] = ("attention", "mlp"),
+     backend_hint: str | None = None,  # kept for symmetry with activation_patch
+ ) -> DIMResult:
+     """Decompose a behavioral difference into per-component contributions.
+
+     For each component (default: attention, mlp), captures its output during
+     `prompt_a`, then patches that output into the forward pass on `prompt_b`,
+     and measures the fraction of the metric gap that the patch closes.
+
+     `backend_hint` is accepted but unused (this function uses module hooks
+     directly via `resolve_subcomponent_module`).
+     """
+     del backend_hint  # unused; kept for API symmetry
+
+     with torch.no_grad():
+         out_a = model(**prompt_a, return_dict=True)
+         out_b = model(**prompt_b, return_dict=True)
+     metric_a = metric_fn(out_a)
+     metric_b = metric_fn(out_b)
+     total_gap = metric_a - metric_b
+
+     contributions: dict[str, float] = {}
+     for comp in components:
+         # 1) Capture component outputs during prompt_a.
+         capture_hooks = []
+         src_acts_by_layer: dict[int, list] = {}
+         for idx in layer_indices:
+             module = resolve_subcomponent_module(model, idx, comp)
+             if module is None:
+                 continue
+             captured: list = []
+
+             # Bind the per-layer store via a default argument (same trick as above).
+             def capture(mod, inp, out, store=captured):
+                 store.append(out[0] if isinstance(out, tuple) else out)
+             capture_hooks.append(module.register_forward_hook(capture))
+             src_acts_by_layer[idx] = captured
+
+         try:
+             with torch.no_grad():
+                 model(**prompt_a, return_dict=True)
+         finally:
+             _remove_hooks(capture_hooks)
+
+         # 2) Patch captured outputs into prompt_b's forward pass.
+         patch_hooks = []
+         for idx in layer_indices:
+             if idx not in src_acts_by_layer:
+                 continue
+             module = resolve_subcomponent_module(model, idx, comp)
+             if module is None:
+                 continue
+             stored = src_acts_by_layer[idx]
+             if not stored:
+                 continue
+             captured_out = stored[0]
+
+             def patch(mod, inp, out, repl=captured_out):
+                 if isinstance(out, tuple):
+                     return (repl,) + out[1:]
+                 return repl
+             patch_hooks.append(module.register_forward_hook(patch))
+
+         try:
+             with torch.no_grad():
+                 patched_out = model(**prompt_b, return_dict=True)
+             patched_metric = metric_fn(patched_out)
+         finally:
+             _remove_hooks(patch_hooks)
+
+         # Guard a vanishing gap the same way activation_patch does.
+         if abs(total_gap) < 1e-9:
+             contributions[comp] = 0.0
+         else:
+             contributions[comp] = float((patched_metric - metric_b) / total_gap)
+
+     return DIMResult(
+         components=contributions,
+         total=float(total_gap),
+         layer_range=(min(layer_indices), max(layer_indices)),
+     )
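
A sketch of driving activation_patch end to end, assuming a logit-difference metric over two answer tokens. The prompts and tokens are illustrative: both prompts tokenize to the same length under GPT-2's tokenizer, so the patched activations match shapes, and taking index [0] keeps the sketch safe even if an answer word spans multiple tokens.

from transformers import AutoModelForCausalLM, AutoTokenizer

from archscope.attribute import activation_patch

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

src = tok("The capital of France is", return_tensors="pt")
tgt = tok("The capital of Italy is", return_tensors="pt")
paris = tok(" Paris")["input_ids"][0]
rome = tok(" Rome")["input_ids"][0]

def logit_diff(out):
    # Scalar preference for " Paris" over " Rome" at the last position.
    last = out.logits[0, -1]
    return (last[paris] - last[rome]).item()

result = activation_patch(model, src, tgt, layer_indices=[4, 5, 6], metric_fn=logit_diff)
print(f"Patching L4-L6 restores {result.gap_restored:.1%} of the behavioral gap")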
archscope/backends.py ADDED
@@ -0,0 +1,236 @@
+ """Architecture-agnostic activation extraction.
+
+ The core abstraction: a `Backend` knows how to hook into a model and pull out
+ hidden states at named layers, regardless of underlying framework
+ (PyTorch/JAX/custom).
+
+ Three backends implemented:
+ - TransformerBackend: HuggingFace transformers (residual stream per layer)
+ - MambaBackend: state-space models (hidden state + ssm state per layer)
+ - RecurrentBackend: generic RNN-like (extracts hidden state per timestep)
+
+ Custom architectures (e.g., kazdov MoBE-BCN) register via Backend.register().
+ """
+ from __future__ import annotations
+ import abc
+ from dataclasses import dataclass
+ from typing import Any
+
+ import torch
+
+
+ @dataclass
+ class ActivationRecord:
+     """Captured activations from a single forward pass.
+
+     Attributes:
+         layer_name: identifier of the layer (e.g., "layer_5.residual")
+         activations: tensor, typically of shape (batch, seq_len, hidden_dim)
+         meta: arch-specific metadata (e.g., {'kind': 'residual'} or {'kind': 'ssm_state'})
+     """
+     layer_name: str
+     activations: Any  # torch.Tensor or jax.Array
+     meta: dict
+
+
+ class Backend(abc.ABC):
+     """Abstract interface — extract activations from any model architecture."""
+
+     _registry: dict[str, type["Backend"]] = {}
+
+     @classmethod
+     def register(cls, name: str):
+         def deco(klass):
+             cls._registry[name] = klass
+             return klass
+         return deco
+
+     @classmethod
+     def for_model(cls, model: Any, hint: str | None = None) -> "Backend":
+         """Auto-detect or use hint to select backend."""
+         if hint and hint in cls._registry:
+             return cls._registry[hint](model)
+         # Auto-detect via config.model_type introspection.
+         model_type = getattr(getattr(model, "config", None), "model_type", None)
+         if model_type in ("mamba", "mamba2"):
+             return cls._registry["mamba"](model)
+         if model_type in ("llama", "mistral", "qwen2", "qwen3", "gpt2", "gpt_neox", "falcon", "mpt"):
+             return cls._registry["transformer"](model)
+         # Default fallback
+         if "recurrent" in cls._registry:
+             return cls._registry["recurrent"](model)
+         raise ValueError(f"No backend matches model {type(model).__name__}. Register via Backend.register('name').")
+
+     def __init__(self, model: Any):
+         self.model = model
+
+     @abc.abstractmethod
+     def layer_names(self) -> list[str]:
+         """Return list of layer identifiers we can hook."""
+         ...
+
+     @abc.abstractmethod
+     def extract(self, inputs: Any, layers: list[str] | None = None) -> list[ActivationRecord]:
+         """Run forward pass, return activations at requested layers (all if None)."""
+         ...
+
+     @abc.abstractmethod
+     def hidden_dim(self, layer_name: str) -> int:
+         """Dimensionality of activations at a given layer."""
+         ...
+
+
+ @Backend.register("transformer")
+ class TransformerBackend(Backend):
+     """HuggingFace transformers backend — extracts residual stream per layer."""
+
+     def layer_names(self) -> list[str]:
+         # HF configs expose num_hidden_layers for decoder stacks.
+         n_layers = getattr(self.model.config, "num_hidden_layers", 0)
+         return [f"layer_{i}.residual" for i in range(n_layers)]
+
+     def extract(self, inputs, layers=None):
+         layers = layers or self.layer_names()
+         # Use HF's output_hidden_states=True for clean extraction.
+         # Wrap in no_grad: extraction shouldn't build a backward graph.
+         with torch.no_grad():
+             outputs = self.model(**inputs, output_hidden_states=True, return_dict=True)
+         records = []
+         hidden_states = outputs.hidden_states  # tuple of (n_layers + 1) tensors
+         for layer_name in layers:
+             # Parse "layer_N.residual" -> N + 1 (since hidden_states[0] is the embedding output)
+             idx = int(layer_name.split("_")[1].split(".")[0]) + 1
+             records.append(ActivationRecord(
+                 layer_name=layer_name,
+                 activations=hidden_states[idx],
+                 meta={"kind": "residual", "arch": "transformer"},
+             ))
+         return records
+
+     def hidden_dim(self, layer_name: str) -> int:
+         return self.model.config.hidden_size
+
+
+ @Backend.register("mamba")
+ class MambaBackend(Backend):
+     """Mamba/Mamba-2 backend — extracts residual stream AND SSM recurrent state.
+
+     Works with HuggingFace MambaForCausalLM (and Mamba2ForCausalLM).
+
+     Two flavors of activations exposed:
+     - `layer_N.residual` -> residual stream after block N (B, T, hidden_size)
+     - `layer_N.ssm_state` -> final SSM recurrent state after processing the
+       sequence at block N: shape (B, intermediate_size, ssm_state_size).
+       This exposes the recurrent state used by Mamba-style models —
+       useful when experiments need access to memory-like state rather
+       than residual activations alone.
+     """
+
+     def layer_names(self) -> list[str]:
+         n_layers = getattr(self.model.config, "n_layer", 0) or getattr(self.model.config, "num_hidden_layers", 0)
+         out = []
+         for i in range(n_layers):
+             out.append(f"layer_{i}.residual")
+             out.append(f"layer_{i}.ssm_state")
+         return out
+
+     def extract(self, inputs, layers=None):
+         layers = layers or self.layer_names()
+
+         need_residual = any(".residual" in ln for ln in layers)
+         need_ssm = any(".ssm_state" in ln for ln in layers)
+
+         with torch.no_grad():
+             if need_ssm:
+                 # With use_cache=True, MambaForCausalLM allocates a MambaCache
+                 # and returns it as outputs.cache_params; its ssm_states hold
+                 # the final recurrent state per layer. Exact cache layout
+                 # varies across transformers versions, so access defensively.
+                 outputs = self.model(
+                     **inputs,
+                     use_cache=True,
+                     output_hidden_states=need_residual,
+                     return_dict=True,
+                 )
+                 cache = getattr(outputs, "cache_params", None)
+             else:
+                 outputs = self.model(
+                     **inputs, output_hidden_states=True, return_dict=True
+                 )
+                 cache = None
+
+         records = []
+         for layer_name in layers:
+             idx = int(layer_name.split("_")[1].split(".")[0])
+             if ".residual" in layer_name:
+                 records.append(ActivationRecord(
+                     layer_name=layer_name,
+                     activations=outputs.hidden_states[idx + 1].detach(),  # [0] is the embedding output
+                     meta={"kind": "residual", "arch": "mamba", "shape_meaning": "(B, T, hidden_size)"},
+                 ))
+             elif ".ssm_state" in layer_name:
+                 ssm_states = getattr(cache, "ssm_states", None)
+                 if ssm_states is None or idx >= len(ssm_states):
+                     continue
+                 ssm = ssm_states[idx]
+                 if ssm is None:
+                     continue
+                 records.append(ActivationRecord(
+                     layer_name=layer_name,
+                     activations=ssm.detach(),
+                     meta={
+                         "kind": "ssm_state",
+                         "arch": "mamba",
+                         "shape_meaning": "(B, intermediate_size, ssm_state_size)",
+                         "d_inner": ssm.shape[-2],
+                         "d_state": ssm.shape[-1],
+                     },
+                 ))
+         return records
+
+     def hidden_dim(self, layer_name: str) -> int:
+         if ".ssm_state" in layer_name:
+             # SSM state is (intermediate_size × ssm_state_size)
+             d_inner = getattr(self.model.config, "intermediate_size", None)
+             d_state = getattr(self.model.config, "state_size", None)
+             if d_inner and d_state:
+                 return d_inner * d_state
+             # Fallback: introspect from a block
+             mixer = self.model.backbone.layers[0].mixer
+             return mixer.intermediate_size * mixer.ssm_state_size
+         return getattr(self.model.config, "hidden_size", None) or getattr(self.model.config, "d_model", None)
+
+
+ @Backend.register("recurrent")
+ class RecurrentBackend(Backend):
+     """Generic recurrent backend — for custom RNN-family models (e.g., kazdov MoBE-BCN).
+
+     Expects the model to expose:
+     - .get_hidden_states(inputs) -> dict[str, tensor]
+     OR user-registered forward hooks.
+
+     Custom models should subclass and override `extract`.
+     """
+
+     def layer_names(self) -> list[str]:
+         # Default: try common RNN attribute names
+         if hasattr(self.model, "n_layer"):
+             return [f"layer_{i}.hidden" for i in range(self.model.n_layer)]
+         if hasattr(self.model, "num_layers"):
+             return [f"layer_{i}.hidden" for i in range(self.model.num_layers)]
+         return ["layer_0.hidden"]
+
+     def extract(self, inputs, layers=None):
+         # Generic — subclasses should override.
+         if hasattr(self.model, "get_hidden_states"):
+             hs = self.model.get_hidden_states(inputs)
+             return [
+                 ActivationRecord(layer_name=k, activations=v, meta={"kind": "hidden", "arch": "recurrent"})
+                 for k, v in hs.items()
+             ]
+         raise NotImplementedError(
+             f"RecurrentBackend default extract() is not implemented for {type(self.model).__name__}. "
+             "Subclass and override extract(), or call model.get_hidden_states() yourself."
+         )
+
+     def hidden_dim(self, layer_name: str) -> int:
+         for attr in ("d_model", "hidden_size", "d_hidden", "n_embd"):
+             if hasattr(self.model, attr):
+                 return getattr(self.model, attr)
+         raise ValueError("Cannot infer hidden_dim — override hidden_dim() in subclass.")
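
Registering a custom backend is the decorator plus the three abstract methods. A minimal sketch for a hypothetical hand-rolled RNN that exposes .cells, .d_hidden, and a get_hidden_states() helper (none of which archscope itself provides):

from archscope.backends import ActivationRecord, Backend

@Backend.register("toy_rnn")
class ToyRNNBackend(Backend):
    def layer_names(self):
        return [f"layer_{i}.hidden" for i in range(len(self.model.cells))]

    def extract(self, inputs, layers=None):
        layers = layers or self.layer_names()
        hs = self.model.get_hidden_states(inputs)  # assumed: dict[layer_name, Tensor]
        return [
            ActivationRecord(layer_name=ln, activations=hs[ln],
                             meta={"kind": "hidden", "arch": "toy_rnn"})
            for ln in layers
            if ln in hs
        ]

    def hidden_dim(self, layer_name):
        return self.model.d_hidden

# Explicit selection via the hint; without it, auto-detection would fall back
# to the generic RecurrentBackend:
# backend = Backend.for_model(my_rnn, hint="toy_rnn")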