interpkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,90 @@
1
+ """ablate — zero or mean ablate a module and measure the effect on output."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ import torch
8
+
9
+ from interpkit.ops.patch import _get_module
10
+
11
+ if TYPE_CHECKING:
12
+ from interpkit.core.model import Model
13
+
14
+
15
def run_ablate(
    model: "Model",
    input_data: Any,
    *,
    at: str,
    method: str = "zero",
) -> dict[str, Any]:
    """Ablate module *at* and measure the effect on output logits.

    Runs one clean forward pass, then a second pass with a forward hook on the
    target module that overwrites its tensor output (or the first element of a
    tuple/list output), and compares the two results with
    ``_compute_ablation_effect``.

    Parameters
    ----------
    model:
        Wrapper object providing ``_prepare``, ``_forward`` and ``_model``.
    input_data:
        Raw input; passed through ``model._prepare`` before the forward passes.
    at:
        Name of the module to ablate, resolved with ``_get_module``.
    method:
        ``"zero"`` replaces the module output with zeros.
        ``"mean"`` replaces it with the mean activation across the sequence dimension.

    Returns
    -------
    dict
        Keys ``"module"``, ``"method"``, ``"effect"``, ``"clean_logits"`` and
        ``"ablated_logits"``.  The same dict is passed to ``render_ablate``.

    Raises
    ------
    ValueError
        If *method* is neither ``"zero"`` nor ``"mean"``.
    """
    # Validate before touching the model: raising from inside the hook would
    # abort the forward pass mid-way and used to leave the hook registered.
    if method not in ("zero", "mean"):
        raise ValueError(f"Unknown ablation method: {method!r}. Use 'zero' or 'mean'.")

    from interpkit.core.render import render_ablate

    model_input = model._prepare(input_data)
    target_mod = _get_module(model._model, at)

    # 1. Clean forward — get baseline logits
    clean_logits = model._forward(model_input)

    # 2. Ablated forward
    def _ablate_hook(_mod: torch.nn.Module, _inp: Any, output: Any) -> Any:
        if isinstance(output, torch.Tensor):
            t = output
        elif (
            isinstance(output, (tuple, list))
            and len(output) > 0
            and isinstance(output[0], torch.Tensor)
        ):
            t = output[0]
        else:
            # Nothing tensor-like to replace; leave the output untouched.
            return output

        if method == "zero":
            replacement = torch.zeros_like(t)
        else:
            # Rank >= 3 tensors are averaged over the second-to-last axis
            # (presumably the sequence axis — TODO confirm layout); lower-rank
            # tensors are averaged over their last axis.
            mean_dim = -2 if t.dim() >= 3 else -1
            replacement = t.mean(dim=mean_dim, keepdim=True).expand_as(t)

        if isinstance(output, torch.Tensor):
            return replacement
        # Preserve any extra outputs (returned as a tuple, as before).
        return (replacement,) + tuple(output[1:])

    handle = target_mod.register_forward_hook(_ablate_hook)
    try:
        ablated_logits = model._forward(model_input)
    finally:
        # Always detach the hook so a failed forward cannot leave the model ablated.
        handle.remove()

    effect = _compute_ablation_effect(clean_logits, ablated_logits)

    result = {
        "module": at,
        "method": method,
        "effect": effect,
        "clean_logits": clean_logits,
        "ablated_logits": ablated_logits,
    }
    render_ablate(result)
    return result
75
+
76
+
77
+ def _compute_ablation_effect(clean: torch.Tensor, ablated: torch.Tensor) -> float:
78
+ """Measure how much ablation changed the output (0 = no change, 1 = max change)."""
79
+ clean_flat = clean.view(-1, clean.shape[-1]).float()
80
+ ablated_flat = ablated.view(-1, ablated.shape[-1]).float()
81
+
82
+ if clean_flat.shape[0] > 1:
83
+ clean_flat = clean_flat[-1:]
84
+ ablated_flat = ablated_flat[-1:]
85
+
86
+ clean_probs = torch.softmax(clean_flat, dim=-1)
87
+ ablated_probs = torch.softmax(ablated_flat, dim=-1)
88
+
89
+ cosine_sim = torch.nn.functional.cosine_similarity(clean_probs, ablated_probs, dim=-1)
90
+ return (1.0 - cosine_sim.item())
@@ -0,0 +1,67 @@
1
+ """activations — extract raw activation tensors at any named module."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ import torch
8
+
9
+ from interpkit.ops.patch import _get_module
10
+
11
+ if TYPE_CHECKING:
12
+ from interpkit.core.model import Model
13
+
14
+
15
def run_activations(
    model: "Model",
    input_data: Any,
    *,
    at: str | list[str],
    print_stats: bool = True,
) -> dict[str, torch.Tensor] | torch.Tensor:
    """Extract activations at one or more named modules.

    The model's activation cache (``model._get_cached``) is consulted first.
    On a miss, forward hooks are registered on every requested module, one
    forward pass is run under ``torch.no_grad()``, and a detached clone of each
    module's tensor output (or of the first element of a tuple/list output) is
    stored.

    Returns a single tensor if *at* is a string, or a dict if *at* is a list.

    Raises
    ------
    RuntimeError
        If *at* is a string and that module produced no tensor output.
    """
    model_input = model._prepare(input_data)
    single = isinstance(at, str)
    module_names = [at] if single else list(at)

    # Check activation cache first (pass prepared input to avoid re-tokenizing)
    cached = model._get_cached(input_data, module_names, _prepared_input=model_input)
    if cached is not None:
        cache = cached
    else:
        cache = {}

        def _make_hook(name: str):
            def hook_fn(_mod: torch.nn.Module, _inp: Any, output: Any) -> None:
                if isinstance(output, torch.Tensor):
                    t = output
                elif (
                    isinstance(output, (tuple, list))
                    and len(output) > 0
                    and isinstance(output[0], torch.Tensor)
                ):
                    t = output[0]
                else:
                    return
                cache[name] = t.detach().clone()

            return hook_fn

        hooks = []
        try:
            for name in module_names:
                mod = _get_module(model._model, name)
                hooks.append(mod.register_forward_hook(_make_hook(name)))

            with torch.no_grad():
                model._forward(model_input)
        finally:
            # Remove hooks even if lookup or the forward pass failed, so the
            # model is never left with stale capture hooks attached.
            for h in hooks:
                h.remove()

    if print_stats:
        from interpkit.core.render import render_activations

        render_activations(cache)

    if single:
        if at not in cache:
            raise RuntimeError(f"Module '{at}' produced no tensor output.")
        return cache[at]

    return cache
@@ -0,0 +1,234 @@
1
+ """attention — capture and display attention patterns for transformer models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from rich.console import Console
11
+
12
+ from interpkit.ops.patch import _get_module
13
+
14
+ if TYPE_CHECKING:
15
+ from interpkit.core.model import Model
16
+
17
+ console = Console()
18
+
19
+
20
def run_attention(
    model: "Model",
    input_data: Any,
    *,
    layer: int | None = None,
    head: int | None = None,
    save: str | None = None,
    html: str | None = None,
) -> list[dict[str, Any]] | None:
    """Capture attention weights and display a summary.

    Computes attention weights manually from Q/K projections via hooks,
    since modern transformer implementations (SDPA, FlashAttention) don't
    return attention weights.

    Parameters
    ----------
    model:
        Wrapper providing ``arch_info``, ``_prepare``, ``_forward``,
        ``_tokenizer`` and ``_model``.
    input_data:
        Raw input.  Token labels are only produced when this is a string and
        the model has a tokenizer.
    layer:
        If given, only attention modules whose name contains ``.<layer>.`` are
        hooked (modules without a numeric component in the name are kept).
    head:
        If given, only this head index is reported.
    save:
        Optional path passed to the plotting helpers.
    html:
        Optional path for an HTML export.

    Returns
    -------
    list of dict or None
        One dict per (layer, head) with keys ``"layer"``, ``"head"``,
        ``"top_pairs"``, ``"entropy"`` and ``"weights"``; ``None`` when no
        attention modules are detected or no weights could be computed.
    """
    from interpkit.core.render import render_attention

    arch = model.arch_info
    attn_modules = [m for m in arch.modules if m.role == "attention"]

    if not attn_modules:
        console.print(
            "\n [yellow]attention not available:[/yellow] no attention modules detected"
            f" for {arch.arch_family or 'this model'}.\n"
        )
        return None

    model_input = model._prepare(input_data)

    tokens = None
    if model._tokenizer is not None and isinstance(input_data, str):
        encoded = model._tokenizer(input_data, return_tensors="pt")
        token_ids = encoded["input_ids"][0].tolist()
        tokens = model._tokenizer.convert_ids_to_tokens(token_ids)

    # Capture Q and K projections to compute attention weights manually
    qk_cache: dict[str, dict[str, torch.Tensor]] = {}
    hooks = []

    # Submodule-name fragments identifying fused QKV, query, or key projections.
    projection_markers = ("c_attn", "qkv", "q_proj", "query", "k_proj", "key")

    def _make_qk_hook(name: str, attn_name: str):
        def hook_fn(_mod: torch.nn.Module, _inp: Any, output: Any) -> None:
            if isinstance(output, torch.Tensor):
                t = output
            elif (
                isinstance(output, (tuple, list))
                and len(output) > 0
                and isinstance(output[0], torch.Tensor)
            ):
                t = output[0]
            else:
                # Non-tensor outputs cannot be cached; skip rather than crash.
                return
            qk_cache.setdefault(attn_name, {})[name] = t.detach()

        return hook_fn

    try:
        for mod_info in attn_modules:
            if layer is not None:
                layer_match = re.search(r"\.(\d+)\.", mod_info.name)
                if layer_match and int(layer_match.group(1)) != layer:
                    continue

            attn_mod = _get_module(model._model, mod_info.name)

            # Find Q/K projection submodules
            for child_name, child_mod in attn_mod.named_modules():
                lowered = child_name.lower()
                if any(marker in lowered for marker in projection_markers):
                    hooks.append(
                        child_mod.register_forward_hook(_make_qk_hook(child_name, mod_info.name))
                    )

        with torch.no_grad():
            model._forward(model_input)
    finally:
        # Detach every hook even when the forward pass or a lookup fails.
        for h in hooks:
            h.remove()

    # Compute attention weights from cached Q/K
    results: list[dict[str, Any]] = []

    for attn_name, projections in qk_cache.items():
        layer_match = re.search(r"\.(\d+)\.", attn_name)
        layer_idx = int(layer_match.group(1)) if layer_match else 0

        # Falls back to 12 heads when the architecture info has no head count.
        attn_weights = _compute_attention_from_projections(
            projections, arch.num_attention_heads or 12
        )

        if attn_weights is None:
            continue

        num_heads = attn_weights.shape[0]
        for head_idx in range(num_heads):
            if head is not None and head_idx != head:
                continue

            head_attn = attn_weights[head_idx]
            results.append({
                "layer": layer_idx,
                "head": head_idx,
                "top_pairs": _get_top_pairs(head_attn, k=5),
                "entropy": _attention_entropy(head_attn),
                "weights": head_attn,
            })

    if not results:
        console.print(
            "\n [yellow]attention:[/yellow] could not compute attention weights.\n"
        )
        return None

    model_name = arch.arch_family or "model"
    render_attention(results, tokens, model_name)

    if save is not None:
        from interpkit.core.plot import plot_attention, plot_attention_multi

        if layer is not None and head is not None and len(results) == 1:
            plot_attention(results[0]["weights"], tokens, layer=results[0]["layer"],
                           head=results[0]["head"], save_path=save)
        else:
            plot_attention_multi(results, tokens, save_path=save)

    if html is not None:
        from interpkit.core.html import html_attention as gen_html_attention
        from interpkit.core.html import save_html

        # Tensors are converted to nested lists for the HTML generator.
        serializable = []
        for r in results:
            entry = {**r}
            w = r.get("weights")
            if isinstance(w, torch.Tensor):
                entry["weights"] = w.tolist()
            serializable.append(entry)
        save_html(gen_html_attention(serializable, tokens), html)

    return results
154
+
155
+
156
+ def _compute_attention_from_projections(
157
+ projections: dict[str, torch.Tensor],
158
+ num_heads: int,
159
+ ) -> torch.Tensor | None:
160
+ """Compute attention weights from captured QKV or Q/K projections."""
161
+ # GPT-2 style: c_attn produces [Q, K, V] concatenated
162
+ for key, tensor in projections.items():
163
+ if "c_attn" in key or "qkv" in key:
164
+ # tensor shape: (batch, seq, 3 * hidden) or (seq, 3 * hidden)
165
+ if tensor.dim() == 3:
166
+ tensor = tensor[0] # drop batch
167
+ hidden = tensor.shape[-1] // 3
168
+ q, k, _v = tensor.split(hidden, dim=-1)
169
+ return _qk_to_attention(q, k, num_heads)
170
+
171
+ # Separate Q and K projections
172
+ q_tensor = None
173
+ k_tensor = None
174
+ for key, tensor in projections.items():
175
+ if "q_proj" in key or "query" in key:
176
+ q_tensor = tensor
177
+ elif "k_proj" in key or "key" in key:
178
+ k_tensor = tensor
179
+
180
+ if q_tensor is not None and k_tensor is not None:
181
+ if q_tensor.dim() == 3:
182
+ q_tensor = q_tensor[0]
183
+ if k_tensor.dim() == 3:
184
+ k_tensor = k_tensor[0]
185
+ return _qk_to_attention(q_tensor, k_tensor, num_heads)
186
+
187
+ return None
188
+
189
+
190
+ def _qk_to_attention(
191
+ q: torch.Tensor, k: torch.Tensor, num_heads: int,
192
+ ) -> torch.Tensor:
193
+ """Compute attention weights from Q and K tensors.
194
+
195
+ q, k: (seq_len, hidden_size)
196
+ Returns: (num_heads, seq_len, seq_len)
197
+ """
198
+ seq_len, hidden = q.shape
199
+ head_dim = hidden // num_heads
200
+
201
+ q = q.view(seq_len, num_heads, head_dim).transpose(0, 1) # (heads, seq, head_dim)
202
+ k = k.view(seq_len, num_heads, head_dim).transpose(0, 1)
203
+
204
+ scale = head_dim ** 0.5
205
+ scores = torch.matmul(q, k.transpose(-2, -1)) / scale # (heads, seq, seq)
206
+
207
+ # Apply causal mask
208
+ causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), diagonal=1)
209
+ scores.masked_fill_(causal_mask.unsqueeze(0), float("-inf"))
210
+
211
+ return F.softmax(scores, dim=-1)
212
+
213
+
214
+ def _get_top_pairs(
215
+ attn: torch.Tensor, k: int = 5,
216
+ ) -> list[tuple[int, int, float]]:
217
+ """Find top-k (source_pos, target_pos, score) pairs in an attention matrix."""
218
+ flat = attn.view(-1)
219
+ topk_vals, topk_idxs = flat.topk(min(k, flat.numel()))
220
+ seq_len = attn.shape[-1]
221
+ pairs = []
222
+ for val, idx in zip(topk_vals.tolist(), topk_idxs.tolist()):
223
+ src = idx // seq_len
224
+ tgt = idx % seq_len
225
+ pairs.append((src, tgt, val))
226
+ return pairs
227
+
228
+
229
+ def _attention_entropy(attn: torch.Tensor) -> float:
230
+ """Mean entropy of attention distributions across query positions."""
231
+ eps = 1e-10
232
+ log_attn = torch.log(attn + eps)
233
+ entropy_per_query = -(attn * log_attn).sum(dim=-1)
234
+ return entropy_per_query.mean().item()
@@ -0,0 +1,206 @@
1
+ """attribute — gradient saliency over input tokens or pixels."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ import torch
8
+
9
+ if TYPE_CHECKING:
10
+ from interpkit.core.model import Model
11
+
12
+
13
def run_attribute(
    model: "Model",
    input_data: Any,
    *,
    target: int | None = None,
    save: str | None = None,
    html: str | None = None,
) -> None:
    """Dispatch gradient-saliency attribution according to the input type.

    * a string with an image file extension goes to ``_attribute_image``
      (receives *target* and *save*);
    * any other string goes to ``_attribute_text``
      (receives *target*, *save* and *html*);
    * everything else goes to ``_attribute_tensor`` (receives *target* only).
    """
    # Guard clause: non-string inputs are handled as raw tensors.
    if not isinstance(input_data, str):
        _attribute_tensor(model, input_data, target=target)
        return

    if _is_image_path(input_data):
        _attribute_image(model, input_data, target=target, save=save)
    else:
        _attribute_text(model, input_data, target=target, save=save, html=html)
35
+
36
+
37
+ def _attribute_text(model: "Model", text: str, *, target: int | None, save: str | None = None, html: str | None = None) -> None:
38
+ from interpkit.core.render import render_attribution_tokens
39
+
40
+ if model._tokenizer is None:
41
+ raise ValueError("No tokenizer available for text attribution.")
42
+
43
+ encoded = model._tokenizer(text, return_tensors="pt")
44
+ input_ids = encoded["input_ids"].to(model._device)
45
+
46
+ # Get embedding layer
47
+ embed_layer = _find_embedding(model._model)
48
+ if embed_layer is None:
49
+ raise RuntimeError("Could not find embedding layer for gradient attribution.")
50
+
51
+ embeddings = embed_layer(input_ids)
52
+ embeddings = embeddings.detach().requires_grad_(True)
53
+
54
+ # Replace embedding output with our gradient-tracked version
55
+ original_forward = embed_layer.forward
56
+
57
+ def _patched_forward(*args: Any, **kwargs: Any) -> torch.Tensor:
58
+ return embeddings
59
+
60
+ embed_layer.forward = _patched_forward # type: ignore[assignment]
61
+
62
+ try:
63
+ model_kwargs = {k: v.to(model._device) for k, v in encoded.items() if k != "input_ids"}
64
+ out = model._model(input_ids, **model_kwargs)
65
+
66
+ logits = out.logits if hasattr(out, "logits") else (out[0] if isinstance(out, (tuple, list)) else out)
67
+
68
+ # Pick target: last-position argmax if not specified
69
+ if logits.dim() == 3:
70
+ logits_last = logits[0, -1, :]
71
+ else:
72
+ logits_last = logits[0]
73
+
74
+ if target is None:
75
+ target = logits_last.argmax().item()
76
+
77
+ score = logits_last[target]
78
+ score.backward()
79
+ finally:
80
+ embed_layer.forward = original_forward # type: ignore[assignment]
81
+
82
+ if embeddings.grad is None:
83
+ raise RuntimeError("Gradient computation failed — no gradients on embeddings.")
84
+
85
+ # Per-token importance: L2 norm of gradient over the embedding dimension
86
+ token_grads = embeddings.grad[0] # (seq_len, hidden)
87
+ token_scores = token_grads.norm(dim=-1).tolist() # (seq_len,)
88
+
89
+ tokens = model._tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
90
+ render_attribution_tokens(tokens, token_scores)
91
+
92
+ if save is not None:
93
+ from interpkit.core.plot import plot_attribution
94
+
95
+ plot_attribution(tokens, token_scores, save_path=save)
96
+
97
+ if html is not None:
98
+ from interpkit.core.html import html_attribution as gen_html_attribution
99
+ from interpkit.core.html import save_html
100
+
101
+ save_html(gen_html_attribution(tokens, token_scores), html)
102
+
103
+
104
def _attribute_image(model: "Model", image_path: str, *, target: int | None, save: str | None = None) -> None:
    """Compute gradient saliency over the pixels of an image and save a heatmap.

    The image is loaded with ``_load_image``.  The gradient of the target
    logit (argmax of the first output row when *target* is ``None``) with
    respect to the pixel tensor is passed to ``render_attribution_heatmap``.
    The heatmap is written to *save*, or to ``"attribution_heatmap.png"`` when
    *save* is not given.

    Raises
    ------
    RuntimeError
        If no gradient reaches the pixel values.
    """
    from interpkit.core.inputs import _load_image
    from interpkit.core.render import render_attribution_heatmap

    processed = _load_image(
        image_path,
        image_processor=model._image_processor,
        device=model._device,
    )

    if isinstance(processed, dict):
        # Prefer the conventional key; otherwise fall back to the first entry.
        pixel_key = "pixel_values" if "pixel_values" in processed else next(iter(processed))
        pixel_values = processed[pixel_key].requires_grad_(True)
        model_input = {**processed, pixel_key: pixel_values}
        out = model._model(**model_input)
    else:
        pixel_values = processed.requires_grad_(True)
        out = model._model(pixel_values)

    logits = out.logits if hasattr(out, "logits") else (out[0] if isinstance(out, (tuple, list)) else out)

    logits_flat = logits[0] if logits.dim() > 1 else logits

    if target is None:
        target = logits_flat.argmax().item()

    score = logits_flat[target]
    # autograd.grad (instead of backward) avoids accumulating gradients into
    # every model parameter's .grad as a side effect.
    (pixel_grad,) = torch.autograd.grad(score, pixel_values, allow_unused=True)

    if pixel_grad is None:
        raise RuntimeError("Gradient computation failed — no gradients on pixel values.")

    out_path = save or "attribution_heatmap.png"
    render_attribution_heatmap(pixel_grad[0], output_path=out_path)
141
+
142
+
143
def _attribute_tensor(model: "Model", tensor_input: Any, *, target: int | None) -> None:
    """Compute gradient saliency for a raw tensor (or dict of tensors) input.

    The input is prepared with ``model._prepare``.  For a dict, the first
    floating-point tensor entry is the one differentiated.  The gradient of
    the target logit (argmax over the flattened logits when *target* is
    ``None``) is reduced to per-feature scores and rendered with
    ``render_attribution_tokens`` using ``feat_<i>`` labels.

    Raises
    ------
    ValueError
        If a dict input contains no floating-point tensor.
    RuntimeError
        If no gradient reaches the chosen input tensor.
    """
    from interpkit.core.render import render_attribution_tokens

    inp = model._prepare(tensor_input)

    if isinstance(inp, dict):
        # Shallow copy so the prepared dict (possibly the caller's) is not mutated.
        inp = dict(inp)
        # Pick the first tensor-valued entry
        for k, v in inp.items():
            if isinstance(v, torch.Tensor) and v.is_floating_point():
                # A detached leaf leaves the original tensor's autograd state untouched.
                grad_tensor = v.detach().requires_grad_(True)
                inp[k] = grad_tensor
                break
        else:
            raise ValueError("No floating-point tensor found in input dict.")
        out = model._model(**inp)
    else:
        grad_tensor = inp.detach().requires_grad_(True)
        out = model._model(grad_tensor)

    logits = out.logits if hasattr(out, "logits") else (out[0] if isinstance(out, (tuple, list)) else out)
    # reshape (not view): model outputs are not guaranteed to be contiguous.
    logits_flat = logits.reshape(-1)

    if target is None:
        target = logits_flat.argmax().item()

    score = logits_flat[target]
    # autograd.grad returns a fresh gradient each call; backward() would
    # accumulate into .grad on the input and on every model parameter.
    (input_grad,) = torch.autograd.grad(score, grad_tensor, allow_unused=True)

    if input_grad is None:
        raise RuntimeError("Gradient computation failed.")

    # Flatten to feature importance
    grad = input_grad.detach().float()
    if grad.dim() > 1:
        # L2 norm across the leading axis for each flattened feature position.
        feature_scores = grad.reshape(grad.shape[0], -1).norm(dim=0).tolist()
    else:
        feature_scores = grad.abs().tolist()

    labels = [f"feat_{i}" for i in range(len(feature_scores))]
    render_attribution_tokens(labels, feature_scores)
184
+
185
+
186
+ def _find_embedding(model: torch.nn.Module) -> torch.nn.Module | None:
187
+ """Find the token embedding layer."""
188
+ for name, mod in model.named_modules():
189
+ if isinstance(mod, torch.nn.Embedding) and "token" not in name.lower().replace("token", ""):
190
+ pass
191
+ if isinstance(mod, torch.nn.Embedding):
192
+ # Pick the largest embedding (usually token embeddings, not position)
193
+ if mod.num_embeddings > 1000:
194
+ return mod
195
+
196
+ # Fallback: first embedding
197
+ for _name, mod in model.named_modules():
198
+ if isinstance(mod, torch.nn.Embedding):
199
+ return mod
200
+
201
+ return None
202
+
203
+
204
+ def _is_image_path(s: str) -> bool:
205
+ import os
206
+ return os.path.splitext(s)[1].lower() in {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
interpkit/ops/diff.py ADDED
@@ -0,0 +1,79 @@
1
+ """diff — compare activations between two models on the same input."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ import torch
8
+
9
+ if TYPE_CHECKING:
10
+ from interpkit.core.model import Model
11
+
12
+
13
def run_diff(
    model_a: "Model",
    model_b: "Model",
    input_data: Any,
    *,
    save: str | None = None,
) -> list[dict[str, Any]]:
    """Compare activations between two models at all discovered layers.

    Module names present in both models are collected (``arch_info.layer_names``,
    falling back to every module with parameters).  Activations for the same
    input are captured from each model, and each pair is flattened and compared
    by cosine distance.  When the flattened sizes differ, both are truncated
    to the shorter length before comparison.

    Returns a list of dicts (keys ``"module"`` and ``"distance"``) sorted by
    cosine distance (highest change first); an empty list when the models
    share no module names.  The results are rendered with ``render_diff`` and,
    if *save* is given, plotted with ``plot_diff``.
    """
    from interpkit.core.render import render_diff
    from interpkit.ops.activations import run_activations

    # Find shared layer-like modules
    layers_a = set(model_a.arch_info.layer_names or [])
    layers_b = set(model_b.arch_info.layer_names or [])

    # If no layers detected, fall back to all named modules
    if not layers_a:
        layers_a = {m.name for m in model_a.arch_info.modules if m.param_count > 0}
    if not layers_b:
        layers_b = {m.name for m in model_b.arch_info.modules if m.param_count > 0}

    shared_layers = sorted(layers_a & layers_b)

    if not shared_layers:
        from rich.console import Console
        Console().print("\n [yellow]diff:[/yellow] no shared modules found between the two models.\n")
        return []

    acts_a = run_activations(model_a, input_data, at=shared_layers, print_stats=False)
    acts_b = run_activations(model_b, input_data, at=shared_layers, print_stats=False)

    results: list[dict[str, Any]] = []
    for name in shared_layers:
        # A module may have produced no tensor output in one of the models.
        if name not in acts_a or name not in acts_b:
            continue

        # reshape (not view): captured activations may be non-contiguous.
        a = acts_a[name].float().reshape(-1)
        b = acts_b[name].float().reshape(-1)

        if a.shape != b.shape:
            min_size = min(a.numel(), b.numel())
            a = a[:min_size]
            b = b[:min_size]

        cosine_sim = torch.nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0), dim=-1)
        results.append({
            "module": name,
            "distance": 1.0 - cosine_sim.item(),
        })

    results.sort(key=lambda r: r["distance"], reverse=True)

    model_a_name = model_a.arch_info.arch_family or "model_a"
    model_b_name = model_b.arch_info.arch_family or "model_b"
    render_diff(results, model_a_name, model_b_name)

    if save is not None:
        from interpkit.core.plot import plot_diff

        plot_diff(results, model_a_name=model_a_name, model_b_name=model_b_name, save_path=save)

    return results
@@ -0,0 +1,14 @@
1
+ """inspect — module tree with types, param counts, output shapes, and detected roles."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ if TYPE_CHECKING:
8
+ from interpkit.core.model import Model
9
+
10
+
11
def run_inspect(model: "Model") -> None:
    """Pass the model's ``arch_info`` to ``render_inspect`` for display."""
    # Imported lazily, inside the function, like the other ops entry points.
    from interpkit.core.render import render_inspect

    arch_info = model.arch_info
    render_inspect(arch_info)