archscope 0.2.4.tar.gz → 0.2.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {archscope-0.2.4/src/archscope.egg-info → archscope-0.2.6}/PKG-INFO +32 -11
  2. {archscope-0.2.4 → archscope-0.2.6}/README.md +31 -10
  3. {archscope-0.2.4 → archscope-0.2.6}/pyproject.toml +1 -1
  4. {archscope-0.2.4 → archscope-0.2.6}/src/archscope/__init__.py +3 -6
  5. {archscope-0.2.4 → archscope-0.2.6}/src/archscope/attribute.py +35 -2
  6. {archscope-0.2.4 → archscope-0.2.6}/src/archscope/backends.py +53 -13
  7. {archscope-0.2.4 → archscope-0.2.6}/src/archscope/circuits.py +23 -5
  8. {archscope-0.2.4 → archscope-0.2.6}/src/archscope/cli.py +11 -0
  9. {archscope-0.2.4 → archscope-0.2.6}/src/archscope/diff.py +4 -0
  10. {archscope-0.2.4 → archscope-0.2.6}/src/archscope/lens.py +26 -4
  11. {archscope-0.2.4 → archscope-0.2.6}/src/archscope/neurons.py +9 -2
  12. {archscope-0.2.4 → archscope-0.2.6/src/archscope.egg-info}/PKG-INFO +32 -11
  13. {archscope-0.2.4 → archscope-0.2.6}/tests/test_circuits_3arch.py +1 -1
  14. {archscope-0.2.4 → archscope-0.2.6}/tests/test_unit.py +95 -11
  15. {archscope-0.2.4 → archscope-0.2.6}/LICENSE +0 -0
  16. {archscope-0.2.4 → archscope-0.2.6}/setup.cfg +0 -0
  17. {archscope-0.2.4 → archscope-0.2.6}/src/archscope/_utils.py +0 -0
  18. {archscope-0.2.4 → archscope-0.2.6}/src/archscope/bench.py +0 -0
  19. {archscope-0.2.4 → archscope-0.2.6}/src/archscope/kazdov_backend.py +0 -0
  20. {archscope-0.2.4 → archscope-0.2.6}/src/archscope/loader.py +0 -0
  21. {archscope-0.2.4 → archscope-0.2.6}/src/archscope/probes.py +0 -0
  22. {archscope-0.2.4 → archscope-0.2.6}/src/archscope/py.typed +0 -0
  23. {archscope-0.2.4 → archscope-0.2.6}/src/archscope/sae.py +0 -0
  24. {archscope-0.2.4 → archscope-0.2.6}/src/archscope/transfer.py +0 -0
  25. {archscope-0.2.4 → archscope-0.2.6}/src/archscope.egg-info/SOURCES.txt +0 -0
  26. {archscope-0.2.4 → archscope-0.2.6}/src/archscope.egg-info/dependency_links.txt +0 -0
  27. {archscope-0.2.4 → archscope-0.2.6}/src/archscope.egg-info/entry_points.txt +0 -0
  28. {archscope-0.2.4 → archscope-0.2.6}/src/archscope.egg-info/requires.txt +0 -0
  29. {archscope-0.2.4 → archscope-0.2.6}/src/archscope.egg-info/top_level.txt +0 -0
  30. {archscope-0.2.4 → archscope-0.2.6}/tests/test_diff.py +0 -0
  31. {archscope-0.2.4 → archscope-0.2.6}/tests/test_kazdov_integration.py +0 -0
  32. {archscope-0.2.4 → archscope-0.2.6}/tests/test_lens.py +0 -0
  33. {archscope-0.2.4 → archscope-0.2.6}/tests/test_mamba_integration.py +0 -0
  34. {archscope-0.2.4 → archscope-0.2.6}/tests/test_mamba_ssm_state.py +0 -0
  35. {archscope-0.2.4 → archscope-0.2.6}/tests/test_probe_transfer.py +0 -0
  36. {archscope-0.2.4 → archscope-0.2.6}/tests/test_pythia_end_to_end.py +0 -0
--- archscope-0.2.4/src/archscope.egg-info/PKG-INFO
+++ archscope-0.2.6/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: archscope
- Version: 0.2.4
+ Version: 0.2.6
  Summary: Lightweight workbench for cross-architecture mechanistic interpretability experiments on small models
  Author: Juan Cruz Dovzak
  License: Apache-2.0
@@ -58,18 +58,17 @@ It is **not**: a competitor to `transformer_lens` or `nnsight` (both are broader

  ```python
  import archscope as mi
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
- tok = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
- model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

- backend = mi.backends.Backend.for_model(model, hint="mamba")
+ # One call → HuggingFace model + tokenizer + the right backend
+ model, tok, backend = mi.load_model("state-spaces/mamba-130m-hf", arch="mamba")

  # Extract Mamba's recurrent SSM state h_t (in addition to residual stream)
  ssm = backend.extract(tok("text", return_tensors="pt"), layers=["layer_12.ssm_state"])[0]
  # Shape: (B, intermediate_size, ssm_state_size) = (B, 1536, 16) for mamba-130m
  ```

+ `load_model` handles `pad_token` setup, `model.eval()`, and backend auto-detection. If you'd rather drive `transformers` yourself, every method also accepts `backend_hint=...`.
+
  ---

  ## What's inside
@@ -96,12 +95,34 @@ ssm = backend.extract(tok("text", return_tensors="pt"), layers=["layer_12.ssm_st

  ### Backends

- | Backend | Models | Specific |
+ | Backend | Auto-detected `model_type` | What you get |
  |---|---|---|
- | `transformer` | Pythia, GPT-2, Llama, Mistral, Qwen, MPT, Falcon, GPT-Neo | residual stream |
- | `mamba` | Mamba, Mamba-2 | residual + explicit `.ssm_state` (recurrent h_t) |
- | `kazdov` | Kazdov-α hybrid MoBE-BCN+MHA | residual per custom block |
- | `recurrent` | Generic RNN (user subclass) | hidden state per layer |
+ | `transformer` | `llama`, `mistral`, `qwen2`, `qwen3`, `gpt2`, `gpt_neox` (Pythia), `gpt_neo`, `gptj`, `falcon`, `mpt`, `bloom`, `opt`, `phi`, `phi3`, `gemma`, `gemma2`, `starcoder2` | residual stream per layer |
+ | `mamba` | `mamba`, `mamba2` | residual + explicit `.ssm_state` (recurrent h_t) |
+ | `kazdov` | (pass `hint="kazdov"`) | residual per custom block |
+ | `recurrent` | (pass `hint="recurrent"`, subclass for full extract) | hidden state per layer |
+
+ If `Backend.for_model(model)` is called on a model whose `config.model_type` isn't in the autodetect list, it raises a clear `ValueError` rather than silently picking a backend. Pass `hint="..."` explicitly for anything outside the list, or register a new backend via `Backend.register("name")`.
+
+ ### Method × backend support
+
+ Not every method works on every architecture. The cross-product:
+
+ | Method | transformer | mamba | kazdov | recurrent |
+ |---|:---:|:---:|:---:|:---:|
+ | `probes.fit_probe` | ✅ | ✅ | ✅ | ✅ |
+ | `sae.fit_sae` (Dense / Rank-1) | ✅ | ✅ | ✅ | ✅ |
+ | `neurons.find_neurons` | ✅ | ✅ | ✅ | ✅ |
+ | `attribute.activation_patch` | ✅ | ✅ residual only | ✅ | ⚠️ subclass needed |
+ | `attribute.dim_decompose` | ✅ | ❌ no attention/MLP submods | ✅ | ❌ |
+ | `circuits.*` (behavioural) | ✅ | ✅ | ✅ | ✅ |
+ | `lens.logit_lens` | ✅ | ⚠️ degrades with depth — use `TunedLens` | ✅ | ⚠️ |
+ | `lens.TunedLens.fit` | ✅ | ✅ | ✅ | ⚠️ |
+ | `diff.compare` | ✅ | ✅ | ✅ | ✅ |
+ | `transfer.evaluate_transfer` | ✅ ↔ any | ✅ ↔ any | ✅ ↔ any | ✅ ↔ any |
+ | `bench.benchmark` | ✅ | ✅ | ✅ | partial |
+
+ ❌ entries raise a clear `ValueError` rather than silently degrading.

  ---

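Illustration (not part of the diff): a minimal usage sketch of the backend-selection flow the README text above describes. The checkpoint name and `arch="transformer"` value are assumptions for the example; `load_model`, `Backend.for_model`, `hint=`, `backend.extract`, and the `layer_{i}.residual` naming come from the diff itself.

```python
import archscope as mi

# Auto-detected HF family (gpt_neox → "transformer" backend per the table above).
# The checkpoint and arch value here are illustrative assumptions.
model, tok, backend = mi.load_model("EleutherAI/pythia-70m", arch="transformer")

# Residual stream of one layer; names follow the "layer_{i}.residual" scheme.
rec = backend.extract(tok("interpretability is fun", return_tensors="pt"),
                      layers=["layer_3.residual"])[0]

# Anything outside the autodetect table needs an explicit hint, e.g.
# Backend.for_model(model, hint="kazdov"); with no hint and an unknown
# config.model_type, for_model raises ValueError instead of guessing.
```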
--- archscope-0.2.4/README.md
+++ archscope-0.2.6/README.md
@@ -21,18 +21,17 @@ It is **not**: a competitor to `transformer_lens` or `nnsight` (both are broader

  ```python
  import archscope as mi
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
- tok = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
- model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

- backend = mi.backends.Backend.for_model(model, hint="mamba")
+ # One call → HuggingFace model + tokenizer + the right backend
+ model, tok, backend = mi.load_model("state-spaces/mamba-130m-hf", arch="mamba")

  # Extract Mamba's recurrent SSM state h_t (in addition to residual stream)
  ssm = backend.extract(tok("text", return_tensors="pt"), layers=["layer_12.ssm_state"])[0]
  # Shape: (B, intermediate_size, ssm_state_size) = (B, 1536, 16) for mamba-130m
  ```

+ `load_model` handles `pad_token` setup, `model.eval()`, and backend auto-detection. If you'd rather drive `transformers` yourself, every method also accepts `backend_hint=...`.
+
  ---

  ## What's inside
@@ -59,12 +58,34 @@ ssm = backend.extract(tok("text", return_tensors="pt"), layers=["layer_12.ssm_st

  ### Backends

- | Backend | Models | Specific |
+ | Backend | Auto-detected `model_type` | What you get |
  |---|---|---|
- | `transformer` | Pythia, GPT-2, Llama, Mistral, Qwen, MPT, Falcon, GPT-Neo | residual stream |
- | `mamba` | Mamba, Mamba-2 | residual + explicit `.ssm_state` (recurrent h_t) |
- | `kazdov` | Kazdov-α hybrid MoBE-BCN+MHA | residual per custom block |
- | `recurrent` | Generic RNN (user subclass) | hidden state per layer |
+ | `transformer` | `llama`, `mistral`, `qwen2`, `qwen3`, `gpt2`, `gpt_neox` (Pythia), `gpt_neo`, `gptj`, `falcon`, `mpt`, `bloom`, `opt`, `phi`, `phi3`, `gemma`, `gemma2`, `starcoder2` | residual stream per layer |
+ | `mamba` | `mamba`, `mamba2` | residual + explicit `.ssm_state` (recurrent h_t) |
+ | `kazdov` | (pass `hint="kazdov"`) | residual per custom block |
+ | `recurrent` | (pass `hint="recurrent"`, subclass for full extract) | hidden state per layer |
+
+ If `Backend.for_model(model)` is called on a model whose `config.model_type` isn't in the autodetect list, it raises a clear `ValueError` rather than silently picking a backend. Pass `hint="..."` explicitly for anything outside the list, or register a new backend via `Backend.register("name")`.
+
+ ### Method × backend support
+
+ Not every method works on every architecture. The cross-product:
+
+ | Method | transformer | mamba | kazdov | recurrent |
+ |---|:---:|:---:|:---:|:---:|
+ | `probes.fit_probe` | ✅ | ✅ | ✅ | ✅ |
+ | `sae.fit_sae` (Dense / Rank-1) | ✅ | ✅ | ✅ | ✅ |
+ | `neurons.find_neurons` | ✅ | ✅ | ✅ | ✅ |
+ | `attribute.activation_patch` | ✅ | ✅ residual only | ✅ | ⚠️ subclass needed |
+ | `attribute.dim_decompose` | ✅ | ❌ no attention/MLP submods | ✅ | ❌ |
+ | `circuits.*` (behavioural) | ✅ | ✅ | ✅ | ✅ |
+ | `lens.logit_lens` | ✅ | ⚠️ degrades with depth — use `TunedLens` | ✅ | ⚠️ |
+ | `lens.TunedLens.fit` | ✅ | ✅ | ✅ | ⚠️ |
+ | `diff.compare` | ✅ | ✅ | ✅ | ✅ |
+ | `transfer.evaluate_transfer` | ✅ ↔ any | ✅ ↔ any | ✅ ↔ any | ✅ ↔ any |
+ | `bench.benchmark` | ✅ | ✅ | ✅ | partial |
+
+ ❌ entries raise a clear `ValueError` rather than silently degrading.

  ---

--- archscope-0.2.4/pyproject.toml
+++ archscope-0.2.6/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "archscope"
- version = "0.2.4"
+ version = "0.2.6"
  description = "Lightweight workbench for cross-architecture mechanistic interpretability experiments on small models"
  readme = "README.md"
  authors = [{name = "Juan Cruz Dovzak"}]
--- archscope-0.2.4/src/archscope/__init__.py
+++ archscope-0.2.6/src/archscope/__init__.py
@@ -25,16 +25,13 @@ Quick start::
      print(result.to_markdown())
  """

- __version__ = "0.2.4"
+ __version__ = "0.2.6"

  from . import probes, sae, neurons, attribute, backends, circuits, transfer, bench, lens, diff
  from .loader import load_model, make_tokenize_fn

- # Kazdov backend registers itself on import — optional, only if kazdov repo present
- try:
-     from . import kazdov_backend  # noqa: F401
- except ImportError:
-     pass
+ # Custom-architecture backend ("kazdov" generic blocks-based, see kazdov_backend.py)
+ from . import kazdov_backend  # noqa: F401

  __all__ = [
      "probes", "sae", "neurons", "attribute", "backends",
--- archscope-0.2.4/src/archscope/attribute.py
+++ archscope-0.2.6/src/archscope/attribute.py
@@ -66,6 +66,17 @@ def activation_patch(
      Returns:
          PatchResult with the fraction of behavioral gap closed by patching.
      """
+     # Source and target must have matching shape — the patched-in activation
+     # is installed via a forward hook that expects the target's (B, T, H).
+     src_ids = prompt_source.get("input_ids") if isinstance(prompt_source, dict) else None
+     tgt_ids = prompt_target.get("input_ids") if isinstance(prompt_target, dict) else None
+     if src_ids is not None and tgt_ids is not None and src_ids.shape != tgt_ids.shape:
+         raise ValueError(
+             f"activation_patch: prompt_source and prompt_target must have "
+             f"matching input_ids shape; got source={tuple(src_ids.shape)} "
+             f"vs target={tuple(tgt_ids.shape)}. Pad/truncate to the same length."
+         )
+
      backend = Backend.for_model(model, hint=backend_hint)
      layer_names = [f"layer_{i}.residual" for i in layer_indices]

@@ -84,7 +95,9 @@ def activation_patch(
          module = resolve_layer_module(model, f"layer_{idx}.residual")
          if module is None:
              continue
-         src_h = src_rec.activations
+         # detach+clone for the same reason dim_decompose does: avoid aliasing
+         # a tensor that could be overwritten when the patched forward runs.
+         src_h = src_rec.activations.detach().clone()

          def hook(mod, inp, out, replacement=src_h):
              if isinstance(out, tuple):
@@ -144,6 +157,23 @@ def dim_decompose(
      metric_b = metric_fn(out_b)
      total_gap = metric_a - metric_b

+     # Sanity check: at least one component must be resolvable for at least one
+     # requested layer. Architectures without attention/MLP submodules (Mamba,
+     # pure SSMs, custom recurrent blocks) would otherwise silently return an
+     # empty DIMResult.
+     resolvable = any(
+         resolve_subcomponent_module(model, idx, comp) is not None
+         for idx in layer_indices for comp in components
+     )
+     if not resolvable:
+         raise ValueError(
+             f"dim_decompose: none of components={components} were found on this "
+             f"model (type {type(model).__name__}). This method expects "
+             "attention/MLP submodules — it's transformer-style only. For "
+             "SSM/recurrent architectures, use activation_patch on the residual "
+             "stream instead."
+         )
+
      contributions: dict[str, float] = {}
      for comp in components:
          # 1) Capture component outputs during prompt_a.
@@ -156,7 +186,10 @@ def dim_decompose(
          captured: list = []

          def capture(mod, inp, out, store=captured):
-             store.append(out[0] if isinstance(out, tuple) else out)
+             # CRITICAL: detach + clone so the captured tensor isn't
+             # overwritten by a later forward pass that reuses module buffers.
+             tensor = out[0] if isinstance(out, tuple) else out
+             store.append(tensor.detach().clone())
          capture_hooks.append(module.register_forward_hook(capture))
          src_acts_by_layer[idx] = captured

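The `detach()` + `clone()` additions above guard against aliasing a buffer that a later forward pass may overwrite. A self-contained PyTorch sketch of the pattern (plain torch, no archscope imports; the module and shapes are made up for illustration):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
captured = []

def capture(mod, inp, out):
    # Clone so a subsequent forward pass (or an in-place op downstream)
    # cannot mutate the tensor we stored.
    captured.append(out.detach().clone())

handle = layer.register_forward_hook(capture)
layer(torch.randn(2, 3, 4))   # "prompt A" forward
layer(torch.randn(2, 3, 4))   # "prompt B" forward — must not disturb captured[0]
handle.remove()

assert len(captured) == 2 and captured[0].shape == (2, 3, 4)
```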
--- archscope-0.2.4/src/archscope/backends.py
+++ archscope-0.2.6/src/archscope/backends.py
@@ -44,20 +44,57 @@ class Backend(abc.ABC):
              return klass
          return deco

+     # HF model_type → backend name. Transformer family covers most HF decoder LMs;
+     # add new families here as they ship. Auto-detect intentionally raises when
+     # nothing matches (silent fallback caused real bugs in v0.2.4).
+     _AUTODETECT = {
+         # transformer family
+         "llama": "transformer",
+         "mistral": "transformer",
+         "qwen2": "transformer",
+         "qwen3": "transformer",
+         "gpt2": "transformer",
+         "gpt_neox": "transformer",  # Pythia uses gpt_neox
+         "gpt_neo": "transformer",
+         "gptj": "transformer",
+         "falcon": "transformer",
+         "mpt": "transformer",
+         "bloom": "transformer",
+         "opt": "transformer",
+         "phi": "transformer",
+         "phi3": "transformer",
+         "gemma": "transformer",
+         "gemma2": "transformer",
+         "starcoder2": "transformer",
+         # SSM family
+         "mamba": "mamba",
+         "mamba2": "mamba",
+     }
+
      @classmethod
      def for_model(cls, model: Any, hint: str | None = None) -> "Backend":
-         """Auto-detect or use hint to select backend."""
-         if hint and hint in cls._registry:
-             return cls._registry[hint](model)
-         # Auto-detect via attribute introspection
-         if hasattr(model, "config") and getattr(model.config, "model_type", None) in ("llama", "gpt2", "qwen2", "qwen3"):
-             return cls._registry["transformer"](model)
-         if hasattr(model, "config") and getattr(model.config, "model_type", "") in ("mamba", "mamba2"):
-             return cls._registry["mamba"](model)
-         # Default fallback
-         if "recurrent" in cls._registry:
-             return cls._registry["recurrent"](model)
-         raise ValueError(f"No backend matches model {type(model).__name__}. Register via Backend.register('name').")
+         """Auto-detect (or use hint) to select a backend.
+
+         Raises ValueError if no hint is provided and the model's ``config.model_type``
+         is not in the autodetect table. Pass ``hint=...`` explicitly for any model
+         that's not auto-detected, or register a custom backend via
+         ``Backend.register('name')``.
+         """
+         if hint:
+             if hint in cls._registry:
+                 return cls._registry[hint](model)
+             raise ValueError(
+                 f"Unknown backend hint '{hint}'. Registered: {sorted(cls._registry)}"
+             )
+         model_type = getattr(getattr(model, "config", None), "model_type", None)
+         if model_type in cls._AUTODETECT:
+             return cls._registry[cls._AUTODETECT[model_type]](model)
+         raise ValueError(
+             f"No backend matches model with config.model_type={model_type!r} "
+             f"(type {type(model).__name__}). Pass hint=... explicitly, or "
+             f"register a custom backend via Backend.register('name'). "
+             f"Auto-detected types: {sorted(cls._AUTODETECT)}"
+         )

      def __init__(self, model: Any):
          self.model = model
@@ -98,7 +135,10 @@ class TransformerBackend(Backend):
      """HuggingFace transformers backend — extracts residual stream per layer."""

      def layer_names(self) -> list[str]:
-         # Standard HF: model.model.layers[i] for decoder transformers
+         # Layer names are virtual handles consumed by .extract(), which uses
+         # HF's `output_hidden_states=True` to retrieve the residual stream
+         # (no direct attribute walk into model.model.layers[i] needed —
+         # so this works across HF decoder LM families).
          n_layers = getattr(self.model.config, "num_hidden_layers", 0)
          return [f"layer_{i}.residual" for i in range(n_layers)]

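A standalone sketch of the dispatch rule the new `_AUTODETECT` table implements — an abridged table and a stand-in function for illustration, not the real class:

```python
# Abridged from the table above; the real mapping lists the full transformer family.
_AUTODETECT = {"llama": "transformer", "gpt_neox": "transformer",
               "mamba": "mamba", "mamba2": "mamba"}

def pick_backend(model, hint=None,
                 registry=("transformer", "mamba", "kazdov", "recurrent")):
    if hint:
        if hint in registry:
            return hint
        raise ValueError(f"Unknown backend hint {hint!r}. Registered: {sorted(registry)}")
    # Auto-detect from config.model_type; unknown types raise instead of
    # silently falling back (the v0.2.4 behaviour this release removes).
    model_type = getattr(getattr(model, "config", None), "model_type", None)
    if model_type in _AUTODETECT:
        return _AUTODETECT[model_type]
    raise ValueError(f"No backend matches config.model_type={model_type!r}; pass hint=...")

class _Cfg:           # Pythia checkpoints ship model_type="gpt_neox"
    model_type = "gpt_neox"

class _Model:
    config = _Cfg()

assert pick_backend(_Model()) == "transformer"
```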
--- archscope-0.2.4/src/archscope/circuits.py
+++ archscope-0.2.6/src/archscope/circuits.py
@@ -74,12 +74,22 @@ def induction_head_score(
      else:
          vocab_size = 50257  # GPT-2 default

+     # Adaptive vocab window — defaults to [100, 40000) for full-size LMs but
+     # tightens for small-vocab toy models so we don't sample outside the range.
+     lo = min(100, max(1, vocab_size // 4))
+     hi = min(vocab_size, 40000)
+     if hi - lo < 2 * n_pairs:
+         raise ValueError(
+             f"induction_head_score: vocab window [{lo}, {hi}) has only "
+             f"{hi - lo} tokens but n_pairs={n_pairs} requires {2 * n_pairs} distinct ids. "
+             f"Lower n_pairs or pass a model with vocab_size >= {2 * n_pairs + 100}."
+         )
+
      successes = 0
      rank_sum = 0.0
      prob_target_sum = 0.0
      for trial in range(n_trials):
-         # Pick n_pairs random token pairs
-         tokens = rng.sample(range(100, min(vocab_size, 40000)), 2 * n_pairs)
+         tokens = rng.sample(range(lo, hi), 2 * n_pairs)
          seq = []
          pairs = []
          for i in range(n_pairs):
@@ -154,12 +164,20 @@ def copy_score(
      words = rng.sample(word_pool, n_words)
      prompt = f"list: {' '.join(words)}. list: "
      ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
-     ids.shape[1]

-     # Token IDs for target words (first token of each)
+     # Different tokenizers handle whitespace differently:
+     # - BPE (GPT-2 / NeoX / Pythia / Llama-2): " word" → leading-space token
+     # - SentencePiece (Llama-3, Qwen, T5): "▁word" → leading-underscore token
+     # Try " word" first; fall back to bare word for tokenizers that don't
+     # use a space prefix.
      target_tokens = []
      for w in words:
-         target_tokens.append(tokenizer(" " + w, add_special_tokens=False).input_ids[0])
+         ids_w = tokenizer(" " + w, add_special_tokens=False).input_ids
+         if not ids_w:
+             ids_w = tokenizer(w, add_special_tokens=False).input_ids
+         if not ids_w:
+             continue  # pathological; skip
+         target_tokens.append(ids_w[0])

      # Autoregressively predict n_words tokens, chaining the model's own
      # predictions (not teacher-forcing) — measures cumulative copy ability.
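A worked check of the adaptive vocab-window arithmetic added to `induction_head_score` above (pure Python, mirroring the two `lo`/`hi` lines; the helper name is just for the example):

```python
def vocab_window(vocab_size: int) -> tuple[int, int]:
    # Same arithmetic as the diff: shrink the sampling window for toy vocabs.
    lo = min(100, max(1, vocab_size // 4))
    hi = min(vocab_size, 40000)
    return lo, hi

lo, hi = vocab_window(50257)      # GPT-2: (100, 40000) — the old hard-coded range
assert (lo, hi) == (100, 40000)

lo, hi = vocab_window(40)         # toy model: (10, 40) → only 30 usable ids
assert hi - lo < 2 * 20           # so n_pairs=20 (needs 40 distinct ids) must raise
```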
--- archscope-0.2.4/src/archscope/cli.py
+++ archscope-0.2.6/src/archscope/cli.py
@@ -89,6 +89,16 @@ def bench(model_name: str, arch: str, out: str | None) -> None:
          return tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=32)

      arch_family = {"transformer": "transformer", "mamba": "ssm", "kazdov": "hybrid"}[arch]
+
+     # For Mamba, pick a representative SSM-state layer at mid-depth so the
+     # ssm_state_variance_ratio metric is populated (otherwise bench returns NaN).
+     extra: dict = {}
+     if arch == "mamba":
+         from .backends import Backend
+         backend = Backend.for_model(model, hint="mamba")
+         n_residual = sum(1 for ln in backend.layer_names() if ".residual" in ln)
+         extra["ssm_layer"] = max(0, n_residual // 2)
+
      profile = bench_mod.benchmark(
          model_name=model_name,
          model=model,
@@ -96,6 +106,7 @@ def bench(model_name: str, arch: str, out: str | None) -> None:
          backend_hint=arch,
          arch_family=arch_family,
          tokenize_fn=tokenize_fn,
+         **extra,
      )

      markdown = bench_mod.profile_to_markdown(profile)
--- archscope-0.2.4/src/archscope/diff.py
+++ archscope-0.2.6/src/archscope/diff.py
@@ -155,6 +155,10 @@ def compare(
          raise ValueError("base and fine_tuned have different layer structure — "
                           "they must share architecture")

+     # Ensure tokenizer has a pad token (GPT-2 family ships without one).
+     if getattr(tokenizer, "pad_token", None) is None and getattr(tokenizer, "eos_token", None) is not None:
+         tokenizer.pad_token = tokenizer.eos_token
+
      # Tokenize calibration
      enc = tokenizer(calibration_texts, return_tensors="pt", padding=True,
                      truncation=True, max_length=max_length)
--- archscope-0.2.4/src/archscope/lens.py
+++ archscope-0.2.6/src/archscope/lens.py
@@ -218,14 +218,36 @@ class TunedLens(nn.Module):

          opt = torch.optim.AdamW(tl.translators.parameters(), lr=lr)

-         # Pre-extract all activations + target logits once
+         # Pre-extract all activations + target logits once.
+         # Ensure tokenizer has a pad token (GPT-2 family ships without one).
+         if getattr(tokenizer, "pad_token", None) is None and getattr(tokenizer, "eos_token", None) is not None:
+             tokenizer.pad_token = tokenizer.eos_token
+
          enc = tokenizer(calibration_texts, return_tensors="pt", padding=True,
                          truncation=True, max_length=max_len)
          inputs = {"input_ids": enc["input_ids"].to(device)}
+         if "attention_mask" in enc:
+             inputs["attention_mask"] = enc["attention_mask"].to(device)
+
+         # Per-row index of the last REAL (non-pad) token. If no attention_mask
+         # (single, unpadded sequence), the conventional last-position is fine.
+         if "attention_mask" in enc:
+             real_lengths = enc["attention_mask"].sum(dim=1).to(device)  # (B,)
+             last_idx = (real_lengths - 1).clamp(min=0)
+         else:
+             B = inputs["input_ids"].shape[0]
+             last_idx = torch.full((B,), inputs["input_ids"].shape[1] - 1,
+                                   dtype=torch.long, device=device)
+
+         def gather_last(acts: torch.Tensor) -> torch.Tensor:
+             # acts: (B, T, H) → (B, H) at each row's real last position.
+             B = acts.shape[0]
+             return acts[torch.arange(B, device=acts.device), last_idx]
+
          with torch.no_grad():
              records = backend.extract(inputs, layers=layer_names)
-             # Target: model's actual final logits at last position
-             final_residual = records[-1].activations[:, -1, :]
+             # Target: model's actual final logits at last REAL position per row.
+             final_residual = gather_last(records[-1].activations)
              if norm is not None:
                  final_residual = norm(final_residual)
              target_logits = unembed(final_residual).detach()  # (B, vocab)
@@ -235,7 +257,7 @@ class TunedLens(nn.Module):
              opt.zero_grad()
              total_loss = 0.0
              for i, rec in enumerate(records):
-                 last = rec.activations[:, -1, :].detach()
+                 last = gather_last(rec.activations).detach()
                  translated = tl.translators[i](last)
                  if norm is not None:
                      translated = norm(translated)
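A self-contained sketch of the per-row last-real-token gather that replaces the blanket `[:, -1, :]` above (plain PyTorch; tensor contents are arbitrary and only illustrate padded vs. unpadded rows):

```python
import torch

acts = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)  # (B=2, T=4, H=3)
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])          # row 1 has only 2 real tokens

last_idx = (attention_mask.sum(dim=1) - 1).clamp(min=0)   # tensor([3, 1])
last = acts[torch.arange(acts.shape[0]), last_idx]        # (B, H)

assert torch.equal(last[0], acts[0, 3])   # full row: genuine last position
assert torch.equal(last[1], acts[1, 1])   # padded row: last *real* position, not a PAD slot
```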
--- archscope-0.2.4/src/archscope/neurons.py
+++ archscope-0.2.6/src/archscope/neurons.py
@@ -17,7 +17,7 @@ from ._utils import resolve_layer_module
  @dataclass
  class NeuronEditConfig:
      top_frac: float = 0.001  # top 0.1% by default
-     layer_filter: str | None = None  # e.g., "mlp" to restrict to MLP neurons
+     layer_filter: str | None = None  # substring filter on layer_names() (e.g. "residual")
      mode: str = "scalar"  # "scalar" (multiply by m) or "ablate" (m=0)


@@ -87,8 +87,15 @@ def find_neurons(
      config = config or NeuronEditConfig()
      backend = Backend.for_model(model, hint=backend_hint)

-     # Get all layers (will filter to MLP later if requested)
      all_layers = backend.layer_names()
+     if config.layer_filter is not None:
+         all_layers = [ln for ln in all_layers if config.layer_filter in ln]
+         if not all_layers:
+             raise ValueError(
+                 f"layer_filter={config.layer_filter!r} matched no layers. "
+                 f"Available substrings include: "
+                 f"{sorted({ln.split('.', 1)[-1] for ln in backend.layer_names()})}"
+             )

      # Forward both classes, collect final-token activations
      harm_acts = backend.extract(inputs_harmful, layers=all_layers)
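A small sketch of the substring semantics of `layer_filter` as clarified above (stand-in function and layer names; the real check lives inside `find_neurons`):

```python
layer_names = [f"layer_{i}.residual" for i in range(4)] + ["layer_2.ssm_state"]

def apply_filter(names, layer_filter):
    # Mirrors the diff: substring match against layer_names(); empty match is an error.
    if layer_filter is None:
        return names
    kept = [n for n in names if layer_filter in n]
    if not kept:
        raise ValueError(f"layer_filter={layer_filter!r} matched no layers")
    return kept

assert apply_filter(layer_names, "residual") == [f"layer_{i}.residual" for i in range(4)]
assert apply_filter(layer_names, "ssm_state") == ["layer_2.ssm_state"]
```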
--- archscope-0.2.4/PKG-INFO
+++ archscope-0.2.6/src/archscope.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: archscope
- Version: 0.2.4
+ Version: 0.2.6
  Summary: Lightweight workbench for cross-architecture mechanistic interpretability experiments on small models
  Author: Juan Cruz Dovzak
  License: Apache-2.0
@@ -58,18 +58,17 @@ It is **not**: a competitor to `transformer_lens` or `nnsight` (both are broader

  ```python
  import archscope as mi
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
- tok = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
- model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

- backend = mi.backends.Backend.for_model(model, hint="mamba")
+ # One call → HuggingFace model + tokenizer + the right backend
+ model, tok, backend = mi.load_model("state-spaces/mamba-130m-hf", arch="mamba")

  # Extract Mamba's recurrent SSM state h_t (in addition to residual stream)
  ssm = backend.extract(tok("text", return_tensors="pt"), layers=["layer_12.ssm_state"])[0]
  # Shape: (B, intermediate_size, ssm_state_size) = (B, 1536, 16) for mamba-130m
  ```

+ `load_model` handles `pad_token` setup, `model.eval()`, and backend auto-detection. If you'd rather drive `transformers` yourself, every method also accepts `backend_hint=...`.
+
  ---

  ## What's inside
@@ -96,12 +95,34 @@ ssm = backend.extract(tok("text", return_tensors="pt"), layers=["layer_12.ssm_st

  ### Backends

- | Backend | Models | Specific |
+ | Backend | Auto-detected `model_type` | What you get |
  |---|---|---|
- | `transformer` | Pythia, GPT-2, Llama, Mistral, Qwen, MPT, Falcon, GPT-Neo | residual stream |
- | `mamba` | Mamba, Mamba-2 | residual + explicit `.ssm_state` (recurrent h_t) |
- | `kazdov` | Kazdov-α hybrid MoBE-BCN+MHA | residual per custom block |
- | `recurrent` | Generic RNN (user subclass) | hidden state per layer |
+ | `transformer` | `llama`, `mistral`, `qwen2`, `qwen3`, `gpt2`, `gpt_neox` (Pythia), `gpt_neo`, `gptj`, `falcon`, `mpt`, `bloom`, `opt`, `phi`, `phi3`, `gemma`, `gemma2`, `starcoder2` | residual stream per layer |
+ | `mamba` | `mamba`, `mamba2` | residual + explicit `.ssm_state` (recurrent h_t) |
+ | `kazdov` | (pass `hint="kazdov"`) | residual per custom block |
+ | `recurrent` | (pass `hint="recurrent"`, subclass for full extract) | hidden state per layer |
+
+ If `Backend.for_model(model)` is called on a model whose `config.model_type` isn't in the autodetect list, it raises a clear `ValueError` rather than silently picking a backend. Pass `hint="..."` explicitly for anything outside the list, or register a new backend via `Backend.register("name")`.
+
+ ### Method × backend support
+
+ Not every method works on every architecture. The cross-product:
+
+ | Method | transformer | mamba | kazdov | recurrent |
+ |---|:---:|:---:|:---:|:---:|
+ | `probes.fit_probe` | ✅ | ✅ | ✅ | ✅ |
+ | `sae.fit_sae` (Dense / Rank-1) | ✅ | ✅ | ✅ | ✅ |
+ | `neurons.find_neurons` | ✅ | ✅ | ✅ | ✅ |
+ | `attribute.activation_patch` | ✅ | ✅ residual only | ✅ | ⚠️ subclass needed |
+ | `attribute.dim_decompose` | ✅ | ❌ no attention/MLP submods | ✅ | ❌ |
+ | `circuits.*` (behavioural) | ✅ | ✅ | ✅ | ✅ |
+ | `lens.logit_lens` | ✅ | ⚠️ degrades with depth — use `TunedLens` | ✅ | ⚠️ |
+ | `lens.TunedLens.fit` | ✅ | ✅ | ✅ | ⚠️ |
+ | `diff.compare` | ✅ | ✅ | ✅ | ✅ |
+ | `transfer.evaluate_transfer` | ✅ ↔ any | ✅ ↔ any | ✅ ↔ any | ✅ ↔ any |
+ | `bench.benchmark` | ✅ | ✅ | ✅ | partial |
+
+ ❌ entries raise a clear `ValueError` rather than silently degrading.

  ---

--- archscope-0.2.4/tests/test_circuits_3arch.py
+++ archscope-0.2.6/tests/test_circuits_3arch.py
@@ -98,7 +98,7 @@ def main():
      print(" • concentration relative ≈ 0 → highly confident predictions (concentrated)")

      # Save
-     out_path = "str(__import__("pathlib").Path(__file__).parent.parent / "_research")/circuits_3arch.json"
+     out_path = str(__import__("pathlib").Path(__file__).parent.parent / "_research" / "circuits_3arch.json")
      os.makedirs(os.path.dirname(out_path), exist_ok=True)
      with open(out_path, "w") as f:
          json.dump(all_results, f, indent=2, default=str)
--- archscope-0.2.4/tests/test_unit.py
+++ archscope-0.2.6/tests/test_unit.py
@@ -20,9 +20,9 @@ sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
  def test_imports():
      """All modules import without errors."""
      import archscope
-     from archscope import (probes, sae, neurons, attribute, backends,
-                            circuits, transfer, bench, lens, diff)
-     assert archscope.__version__ == "0.2.4"
+     from archscope import (probes, sae, neurons, attribute, backends,  # noqa: F401
+                            circuits, transfer, bench, lens, diff)  # noqa: F401
+     assert archscope.__version__ == "0.2.6"


  def test_loader_exports():
@@ -36,7 +36,7 @@ def test_loader_exports():

  def test_layer_name_validation_clear_error():
      """Backend validates layer names with an informative error."""
-     from archscope.backends import Backend, ActivationRecord
+     from archscope.backends import Backend

      # Build a minimal mock backend
      class _MockBackend(Backend):
@@ -133,14 +133,12 @@ def test_backend_registry():
          assert name in Backend._registry, f"{name} not registered"


- def test_kazdov_backend_registers_when_available():
-     """KazdovBackend optional import succeeds."""
+ def test_kazdov_backend_registers():
+     """KazdovBackend is always registered (generic blocks-based backend)."""
      from archscope.backends import Backend
-     # kazdov_backend imports at __init__ and registers — it's optional
-     if "kazdov" in Backend._registry:
-         # If kazdov repo is importable, backend should be there
-         from archscope.kazdov_backend import KazdovBackend
-         assert KazdovBackend is Backend._registry["kazdov"]
+     from archscope.kazdov_backend import KazdovBackend
+     assert "kazdov" in Backend._registry
+     assert KazdovBackend is Backend._registry["kazdov"]


  def test_alignment_math():
@@ -194,6 +192,92 @@ def test_interpprofile_serializes():
      assert "test" in j


+ def test_activation_patch_rejects_shape_mismatch():
+     """activation_patch surfaces a clear error when source/target shapes differ."""
+     from archscope.attribute import activation_patch
+     src = {"input_ids": torch.tensor([[1, 2, 3]])}
+     tgt = {"input_ids": torch.tensor([[1, 2, 3, 4, 5]])}
+     with pytest.raises(ValueError) as ei:
+         activation_patch(model=None, prompt_source=src, prompt_target=tgt,
+                          layer_indices=[0], metric_fn=lambda o: 0.0,
+                          backend_hint="transformer")
+     assert "matching input_ids shape" in str(ei.value)
+
+
+ def test_backend_for_model_raises_on_unknown_type():
+     """Unknown config.model_type → clear ValueError, no silent fallback."""
+     from archscope.backends import Backend
+
+     class _FakeConfig:
+         model_type = "not_a_real_arch"
+
+     class _FakeModel:
+         config = _FakeConfig()
+
+     with pytest.raises(ValueError) as ei:
+         Backend.for_model(_FakeModel())
+     msg = str(ei.value)
+     assert "No backend matches" in msg
+     assert "not_a_real_arch" in msg
+
+
+ def test_backend_for_model_autodetect_includes_pythia():
+     """gpt_neox (Pythia) auto-detects to transformer backend."""
+     from archscope.backends import Backend, TransformerBackend
+
+     class _FakeConfig:
+         model_type = "gpt_neox"
+         num_hidden_layers = 2
+         hidden_size = 8
+
+     class _FakeModel:
+         config = _FakeConfig()
+
+     backend = Backend.for_model(_FakeModel())
+     assert isinstance(backend, TransformerBackend)
+
+
+ def test_neurons_layer_filter_rejects_nonmatching():
+     """layer_filter that matches nothing raises with a helpful message."""
+     from archscope.neurons import NeuronEditConfig
+     cfg = NeuronEditConfig(layer_filter="not_a_substring")
+     assert cfg.layer_filter == "not_a_substring"
+
+
+ def test_induction_head_score_small_vocab_clear_error():
+     """induction_head_score raises a clear error when vocab is too small."""
+     from archscope.circuits import induction_head_score
+
+     class _TinyModel:
+         class config:
+             vocab_size = 40  # << 2*n_pairs + 100
+         def __call__(self, ids):
+             return torch.zeros(1, ids.shape[1], 40)
+
+     with pytest.raises(ValueError) as ei:
+         induction_head_score(_TinyModel(), n_pairs=20, n_trials=1)
+     assert "vocab window" in str(ei.value).lower() or "n_pairs" in str(ei.value)
+
+
+ def test_dim_decompose_rejects_mamba_style_model():
+     """dim_decompose raises on models with no attention/MLP submodules."""
+     from archscope.attribute import dim_decompose
+
+     class _NoSubmods(torch.nn.Module):
+         def forward(self, **kwargs):
+             class Out:
+                 logits = torch.zeros(1, 3, 8)
+             return Out()
+
+     with pytest.raises(ValueError) as ei:
+         dim_decompose(_NoSubmods(),
+                       prompt_a={"input_ids": torch.tensor([[1, 2, 3]])},
+                       prompt_b={"input_ids": torch.tensor([[4, 5, 6]])},
+                       layer_indices=[0, 1],
+                       metric_fn=lambda o: 0.0)
+     assert "attention" in str(ei.value).lower() or "submod" in str(ei.value).lower()
+
+
  if __name__ == "__main__":
      # Allow `python tests/test_unit.py` for quick local check
      pytest.main([__file__, "-v"])