ffca 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ffca/__init__.py +41 -0
- ffca/adapters/__init__.py +16 -0
- ffca/adapters/channel.py +145 -0
- ffca/adapters/pixel.py +66 -0
- ffca/adapters/tabular.py +40 -0
- ffca/adapters/transformer.py +236 -0
- ffca/checkpoint.py +58 -0
- ffca/cli.py +662 -0
- ffca/core/__init__.py +45 -0
- ffca/core/adapter.py +150 -0
- ffca/core/archetypes.py +142 -0
- ffca/core/derivatives.py +152 -0
- ffca/core/scalars.py +110 -0
- ffca/core/signature.py +72 -0
- ffca/core/smoothing.py +162 -0
- ffca/diagnostics.py +495 -0
- ffca/improvements_pkg/__init__.py +6 -0
- ffca/improvements_pkg/cauchy_hvp.py +205 -0
- ffca/improvements_pkg/co_sensitivity.py +193 -0
- ffca/improvements_pkg/trust_score.py +121 -0
- ffca/report.py +437 -0
- ffca/viz/__init__.py +93 -0
- ffca/viz/diagnostics.py +126 -0
- ffca/viz/dynamic.py +148 -0
- ffca/viz/spatial.py +181 -0
- ffca/viz/static.py +147 -0
- ffca-0.1.0a1.dist-info/METADATA +324 -0
- ffca-0.1.0a1.dist-info/RECORD +31 -0
- ffca-0.1.0a1.dist-info/WHEEL +4 -0
- ffca-0.1.0a1.dist-info/entry_points.txt +2 -0
- ffca-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ffca/__init__.py
ADDED
@@ -0,0 +1,41 @@
"""FFCA — Feature-Function Curvature Analysis for any PyTorch model.

Quick start (Python):

    from ffca import FFCAReport
    from ffca.adapters import TabularAdapter

    adapter = TabularAdapter(model, feature_names=cols)
    report = FFCAReport(adapter, val_loader).run()
    report.save("out/")

Or via the CLI: ``ffca-report --help``.
"""
from .adapters import (
    ChannelAdapter,
    PixelAdapter,
    TabularAdapter,
    TransformerEmbeddingAdapter,
    TransformerHeadAdapter,
)
from .checkpoint import CheckpointLoader
from .core import FFCAModelAdapter, FFCASignature
from .improvements_pkg import CauchyHVP, CoSensitivityGroups, TrustScore
from .report import FFCAReport

__version__ = "0.1.0a1"

__all__ = [
    "FFCAReport",
    "FFCAModelAdapter",
    "FFCASignature",
    "TabularAdapter",
    "PixelAdapter",
    "ChannelAdapter",
    "TransformerEmbeddingAdapter",
    "TransformerHeadAdapter",
    "CheckpointLoader",
    "CauchyHVP",
    "TrustScore",
    "CoSensitivityGroups",
]
ffca/adapters/__init__.py
ADDED

@@ -0,0 +1,16 @@
"""Built-in adapters covering the four v0.1.0 model families."""
from .channel import ChannelAdapter
from .pixel import PixelAdapter
from .tabular import TabularAdapter
from .transformer import (
    TransformerEmbeddingAdapter,
    TransformerHeadAdapter,
)

__all__ = [
    "TabularAdapter",
    "PixelAdapter",
    "ChannelAdapter",
    "TransformerEmbeddingAdapter",
    "TransformerHeadAdapter",
]
ffca/adapters/channel.py
ADDED
@@ -0,0 +1,145 @@
"""ChannelAdapter — channel-level FFCA at any intermediate layer.

Splice mechanism:
  1. Resolve the named layer fresh on every forward (smoothing may have
     swapped its identity).
  2. Capture: register a one-shot forward_hook that records the activation
     and unhooks itself.
  3. Replace: register a hook that returns a pre-set leaf tensor as the
     layer's output, so subsequent layers run from that leaf.

Both modes resolve the layer by *name* every time, so even if `smooth()`
replaces ReLU→Softplus mid-pipeline, the hook still attaches to the
right (current) instance.
"""
from __future__ import annotations

import numpy as np
import torch
import torch.nn as nn

from ..core.adapter import FFCAModelAdapter, find_layer
from ..core.scalars import ScalarFn, predicted_class


class ChannelAdapter(FFCAModelAdapter):
    def __init__(
        self,
        model: nn.Module,
        layer_name: str,
        *,
        scalar: ScalarFn | None = None,
        reduction: str = "spatial_mean",
    ):
        super().__init__(model, scalar=scalar or predicted_class())
        self.layer_name = layer_name
        # Verify the layer exists; we re-resolve each call.
        _ = find_layer(model, layer_name)
        self.reduction = reduction
        self._activation_shape: tuple[int, ...] | None = None
        self.feature_names: list[str] | None = None
        self.n_features = -1
        self.feature_shape = ()
        self._raw_input: torch.Tensor | None = None

    # ---------------------------------------------------------- splice utilities
    def _capture_activation(self, x: torch.Tensor) -> torch.Tensor:
        """Run a forward pass, capture the named layer's activation, return it."""
        captured: list[torch.Tensor] = []
        layer = find_layer(self.model, self.layer_name)

        def hook(module, inputs, output):
            captured.append(output.detach())

        handle = layer.register_forward_hook(hook)
        try:
            with torch.no_grad():
                self.model(x)
        finally:
            handle.remove()
        if not captured:
            raise RuntimeError(
                f"Layer {self.layer_name!r} was not invoked during forward; "
                f"check the layer name is correct for this model's forward path"
            )
        return captured[0]

    def _forward_with_replacement(self, x: torch.Tensor,
                                  replacement: torch.Tensor) -> torch.Tensor:
        """Run a forward pass with the named layer's output replaced."""
        layer = find_layer(self.model, self.layer_name)

        def hook(module, inputs, output):
            return replacement

        handle = layer.register_forward_hook(hook)
        try:
            return self.model(x)
        finally:
            handle.remove()

    # ---------------------------------------------------------- shape probe
    def _probe(self, batch):
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        x = x.to(device=self.device(), dtype=self.dtype())
        act = self._capture_activation(x)
        self._activation_shape = tuple(act.shape[1:])
        if self.reduction == "spatial_mean":
            C = act.shape[1]
            self.n_features = C
            self.feature_shape = (C,)
            self.feature_names = [f"ch_{i}" for i in range(C)]
        elif self.reduction == "none":
            d = int(np.prod(act.shape[1:]))
            self.n_features = d
            self.feature_shape = tuple(act.shape[1:])
            self.feature_names = [f"unit_{i}" for i in range(d)]
        else:
            raise ValueError(f"unknown reduction {self.reduction!r}")

    # ---------------------------------------------------------- adapter API
    def feature_input(self, batch) -> torch.Tensor:
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        x = x.to(device=self.device(), dtype=self.dtype())

        if self._activation_shape is None:
            self._probe(batch)

        act = self._capture_activation(x)  # full activation tensor

        if self.reduction == "spatial_mean":
            if act.dim() == 4:  # (B, C, H, W)
                feat = act.mean(dim=(2, 3))
            elif act.dim() == 3:  # (B, C, L) or (B, T, C)
                feat = act.mean(dim=-1)
            else:
                feat = act
        else:
            feat = act.reshape(act.size(0), -1)

        leaf = feat.clone().detach().requires_grad_(True)
        self._raw_input = x
        return leaf

    def scalar_output(self, leaf: torch.Tensor, batch) -> torch.Tensor:
        # Re-inflate the leaf to the original activation shape and run the
        # rest of the model with the replacement hook.
        if self.reduction == "spatial_mean":
            shape = self._activation_shape
            if len(shape) == 3:
                C, H, W = shape
                injected = leaf.view(leaf.size(0), C, 1, 1).expand(-1, C, H, W).contiguous()
            elif len(shape) == 2:
                C, L = shape
                injected = leaf.view(leaf.size(0), C, 1).expand(-1, C, L).contiguous()
            else:
                injected = leaf
        else:
            injected = leaf.reshape(leaf.size(0), *self._activation_shape)
        out = self._forward_with_replacement(self._raw_input, injected)
        return self._scalar(out, batch)

    def channel_count(self) -> int:
        if self._activation_shape is None:
            raise RuntimeError("ChannelAdapter.feature_input must be called once first")
        return self._activation_shape[0]
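For orientation, a minimal end-to-end sketch of the splice (not part of the wheel; the torchvision ResNet-18 and the "layer4" name are illustrative, and the lambda scalar assumes a ScalarFn is any (output, batch) -> scalar callable, as the adapters' `self._scalar(out, batch)` calls suggest):

    import torch
    from torchvision.models import resnet18
    from ffca.adapters import ChannelAdapter

    model = resnet18(weights=None).eval()
    adapter = ChannelAdapter(model, "layer4",
                             scalar=lambda out, batch: out.sum())

    x = torch.randn(2, 3, 224, 224)
    leaf = adapter.feature_input(x)     # (2, 512) per-channel spatial means, requires_grad
    s = adapter.scalar_output(leaf, x)  # re-runs the model tail from the injected leaf
    s.backward()                        # leaf.grad holds d(scalar)/d(channel)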
ffca/adapters/pixel.py
ADDED
@@ -0,0 +1,66 @@
"""PixelAdapter — input-level FFCA on image/tensor inputs.

Use for: any model that consumes a (C, H, W) image when you want pixel-level
explanations. The feature axis is the flattened C·H·W pixels.

Adds an `fbr()` helper to compute the Foreground/Background interaction
Ratio used in shortcut-learning diagnostics (e.g. Waterbirds): the proportion
of total interaction concentrated in a center foreground box vs. the
surrounding background ring. FBR < 0.5 suggests a background shortcut.
"""
from __future__ import annotations

import numpy as np
import torch
import torch.nn as nn

from ..core.adapter import FFCAModelAdapter
from ..core.scalars import ScalarFn, predicted_class


class PixelAdapter(FFCAModelAdapter):
    def __init__(
        self,
        model: nn.Module,
        *,
        input_shape: tuple[int, int, int],  # (C, H, W)
        scalar: ScalarFn | None = None,
    ):
        super().__init__(model, scalar=scalar or predicted_class())
        self.feature_shape = tuple(input_shape)
        self.n_features = int(np.prod(input_shape))
        C, H, W = input_shape
        # Don't name every pixel by default (too many); CLI uses indices
        self.feature_names = [f"px_{c}_{h}_{w}" for c in range(C)
                              for h in range(H) for w in range(W)]

    def feature_input(self, batch) -> torch.Tensor:
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        x = x.to(device=self.device(), dtype=self.dtype())
        return x.clone().detach().requires_grad_(True)

    def scalar_output(self, x: torch.Tensor, batch) -> torch.Tensor:
        out = self.model(x)
        return self._scalar(out, batch)

    # --- spatial helpers ---------------------------------------------------
    def reshape_to_image(self, per_pixel_score: np.ndarray) -> np.ndarray:
        """Map a (d,) per-pixel score back to its (C, H, W) layout."""
        return per_pixel_score.reshape(self.feature_shape)

    def fbr(self, per_pixel_interaction: np.ndarray, fg_frac: float = 0.5) -> float:
        """Foreground/Background interaction ratio.

        fg_frac: side-length fraction defining the center foreground box.
        Returns mean(fg) / (mean(fg) + mean(bg)). Values < 0.5 hint at a
        background shortcut.
        """
        C, H, W = self.feature_shape
        img = per_pixel_interaction.reshape(self.feature_shape)
        py = int(H * (1 - fg_frac) / 2)
        px = int(W * (1 - fg_frac) / 2)
        fg = img[:, py:H - py, px:W - px].mean()
        bg_sum = img.sum() - img[:, py:H - py, px:W - px].sum()
        bg_count = img.size - img[:, py:H - py, px:W - px].size
        bg = bg_sum / max(bg_count, 1)
        return float(fg / (fg + bg)) if (fg + bg) > 0 else float("nan")
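To make the fbr() geometry concrete, a standalone sketch of the same arithmetic (pure NumPy, no model needed): with H = W = 32 and fg_frac = 0.5, the foreground box is the central 16×16 patch, and a spatially uniform interaction map scores exactly 0.5.

    import numpy as np

    C, H, W = 3, 32, 32
    fg_frac = 0.5
    py = px = int(H * (1 - fg_frac) / 2)      # 8 → foreground box is [:, 8:24, 8:24]
    img = np.ones((C, H, W))                   # spatially uniform interaction map
    fg = img[:, py:H - py, px:W - px].mean()   # 1.0
    bg_sum = img.sum() - img[:, py:H - py, px:W - px].sum()
    bg = bg_sum / (img.size - img[:, py:H - py, px:W - px].size)
    print(fg / (fg + bg))                      # 0.5 — no foreground concentration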
ffca/adapters/tabular.py
ADDED
@@ -0,0 +1,40 @@
"""TabularAdapter — input-level FFCA on dense feature vectors.

Use for: MLPs, tabular transformers, scikit-style wrappers over an MLP.
The feature axis is the d input columns.
"""
from __future__ import annotations

import torch
import torch.nn as nn

from ..core.adapter import FFCAModelAdapter
from ..core.scalars import ScalarFn, predicted_class


class TabularAdapter(FFCAModelAdapter):
    def __init__(
        self,
        model: nn.Module,
        *,
        feature_names: list[str] | None = None,
        n_features: int | None = None,
        scalar: ScalarFn | None = None,
    ):
        super().__init__(model, scalar=scalar or predicted_class())
        if feature_names is None and n_features is None:
            raise ValueError("provide feature_names or n_features")
        self.feature_names = feature_names
        self.n_features = n_features if n_features is not None else len(feature_names)
        self.feature_shape = (self.n_features,)
        if feature_names is None:
            self.feature_names = [f"feature_{i}" for i in range(self.n_features)]

    def feature_input(self, batch) -> torch.Tensor:
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        x = x.to(device=self.device(), dtype=self.dtype())
        return x.clone().detach().requires_grad_(True)

    def scalar_output(self, x: torch.Tensor, batch) -> torch.Tensor:
        out = self.model(x)
        return self._scalar(out, batch)
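The input-level adapters share the same two-call contract as above; a toy MLP sketch (not part of the wheel; the column names are made up, and the lambda scalar again assumes any (output, batch) -> scalar callable is a valid ScalarFn):

    import torch
    import torch.nn as nn
    from ffca.adapters import TabularAdapter

    model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2)).eval()
    adapter = TabularAdapter(model, feature_names=["age", "bmi", "hr", "bp"],
                             scalar=lambda out, batch: out[:, 1].sum())

    x = torch.randn(8, 4)
    leaf = adapter.feature_input(x)     # (8, 4) leaf with requires_grad=True
    s = adapter.scalar_output(leaf, x)  # class-1 logit summed over the batch
    s.backward()                        # leaf.grad: per-column sensitivities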
ffca/adapters/transformer.py
ADDED

@@ -0,0 +1,236 @@
"""Transformer adapters for FFCA — embedding-level and attention-head.

Both adapters work on any HuggingFace `AutoModel` / `AutoModelForCausalLM`
without modification. They auto-detect the input-embedding attribute and
the per-layer attention block by walking common naming conventions:

    GPT-2 / DistilGPT2 : model.transformer.wte (embeddings)
                         model.transformer.h[L].attn (attention block)
    BERT / DistilBERT  : model.bert.embeddings.word_embeddings
                         model.bert.encoder.layer[L].attention.self
    LLaMA / Mistral    : model.model.embed_tokens
                         model.model.layers[L].self_attn
    Generic fallback   : the first nn.Embedding found in the model

If auto-detection fails, pass `embedding_module=` / `attention_layer=` by
hand. Both adapters use the splice trick (forward_hook re-resolved by name
each call) so the package's automatic ReLU→Softplus / MaxPool→AvgPool /
flash-SDP→math-SDPA smoothing does not break them.
"""
from __future__ import annotations

import numpy as np
import torch
import torch.nn as nn

from ..core.adapter import FFCAModelAdapter
from ..core.scalars import ScalarFn, predicted_class


# ---------------------------------------------------------------- locator
def _find_embedding(model: nn.Module) -> nn.Embedding:
    """Return the input-token embedding module of an HF transformer."""
    for path in (
        "transformer.wte",             # GPT-2 family
        "model.embed_tokens",          # LLaMA, Mistral
        "model.decoder.embed_tokens",  # OPT, BART
        "bert.embeddings.word_embeddings",
        "embeddings.word_embeddings",
        "shared",                      # T5
    ):
        obj = model
        ok = True
        for part in path.split("."):
            if not hasattr(obj, part):
                ok = False
                break
            obj = getattr(obj, part)
        if ok and isinstance(obj, nn.Embedding):
            return obj
    # Fallback: first nn.Embedding in the module tree
    for m in model.modules():
        if isinstance(m, nn.Embedding):
            return m
    raise AttributeError("No nn.Embedding found in the model")


def _find_attention_layer(model: nn.Module, layer_idx: int = -1) -> nn.Module:
    """Return the attention module at the requested encoder layer index."""
    for block_path, attr_path in (
        ("transformer.h", "attn"),      # GPT-2
        ("model.layers", "self_attn"),  # LLaMA
        ("bert.encoder.layer", "attention.self"),
        ("encoder.layer", "attention.self"),
        ("model.encoder.layers", "self_attn"),
        ("model.decoder.layers", "self_attn"),
    ):
        obj = model
        ok = True
        for part in block_path.split("."):
            if not hasattr(obj, part):
                ok = False
                break
            obj = getattr(obj, part)
        if not ok or not hasattr(obj, "__getitem__"):
            continue
        try:
            block = obj[layer_idx]
        except (IndexError, TypeError):
            continue
        attr_obj = block
        for p in attr_path.split("."):
            if not hasattr(attr_obj, p):
                attr_obj = None
                break
            attr_obj = getattr(attr_obj, p)
        if attr_obj is not None and isinstance(attr_obj, nn.Module):
            return attr_obj
    raise AttributeError(
        f"Could not auto-locate attention block at layer index {layer_idx}; "
        f"pass attention_layer= explicitly"
    )


def _last_token_logit_scalar(out: torch.Tensor, batch=None) -> torch.Tensor:
    """Default LLM scalar: max logit at the last sequence position."""
    if hasattr(out, "logits"):
        last = out.logits[:, -1, :]
    elif hasattr(out, "last_hidden_state"):
        last = out.last_hidden_state[:, -1, :]
    else:
        last = out
    return last.max(dim=-1).values.sum()


# ---------------------------------------------------------------- embedding adapter
class TransformerEmbeddingAdapter(FFCAModelAdapter):
    """Treat token × hidden-dim input embeddings as the feature axis.

    For a sequence of length T and hidden size H, this yields T·H features.
    Suitable for understanding which token position / which hidden dim drives
    the model's prediction.

    Args:
        model           : HF model (AutoModel / AutoModelForCausalLM)
        seq_len, hidden : sequence length and hidden size of the input
        scalar          : optional ScalarFn; defaults to last-token max-logit
        embedding_module: optional override for the nn.Embedding to differentiate
    """
    def __init__(
        self,
        model: nn.Module,
        seq_len: int,
        hidden: int,
        *,
        scalar: ScalarFn | None = None,
        embedding_module: nn.Embedding | None = None,
    ):
        super().__init__(model, scalar=scalar or _last_token_logit_scalar)
        self.seq_len = seq_len
        self.hidden = hidden
        self.n_features = seq_len * hidden
        self.feature_shape = (seq_len, hidden)
        self.feature_names = [f"t{t}_h{h}" for t in range(seq_len)
                              for h in range(hidden)]
        self._emb = embedding_module or _find_embedding(model)

    def feature_input(self, batch) -> torch.Tensor:
        ids = batch["input_ids"] if isinstance(batch, dict) else batch[0]
        ids = ids.to(self.device())
        with torch.no_grad():
            embs = self._emb(ids)
        return embs.clone().detach().requires_grad_(True)

    def scalar_output(self, embs: torch.Tensor, batch) -> torch.Tensor:
        kw = {}
        if isinstance(batch, dict):
            if "attention_mask" in batch:
                kw["attention_mask"] = batch["attention_mask"].to(self.device())
        out = self.model(inputs_embeds=embs, **kw)
        return self._scalar(out, batch)


# ---------------------------------------------------------------- head adapter
class TransformerHeadAdapter(FFCAModelAdapter):
    """Per-attention-head pooled activations from a chosen encoder layer.

    For an attention block with H heads of dim D, the feature axis is H·D.
    The mean-pool over tokens turns the (B, T, H·D) attention output into a
    (B, H, D) feature tensor that FFCA can differentiate.

    Args:
        model            : HF model
        n_heads, head_dim: from the model's config
        layer_idx        : which layer (negative indexes from the end; -1 = last)
        attention_layer  : optional override for the attention module
        scalar           : optional ScalarFn
    """
    def __init__(
        self,
        model: nn.Module,
        n_heads: int,
        head_dim: int,
        *,
        layer_idx: int = -1,
        attention_layer: nn.Module | None = None,
        scalar: ScalarFn | None = None,
    ):
        super().__init__(model, scalar=scalar or _last_token_logit_scalar)
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.layer_idx = layer_idx
        self.n_features = n_heads * head_dim
        self.feature_shape = (n_heads, head_dim)
        self.feature_names = [f"h{h}_d{d}" for h in range(n_heads)
                              for d in range(head_dim)]
        self._attn = attention_layer or _find_attention_layer(model, layer_idx)
        self._batch_ids: torch.Tensor | None = None
        self._batch_attn: torch.Tensor | None = None

    def feature_input(self, batch) -> torch.Tensor:
        ids = batch["input_ids"] if isinstance(batch, dict) else batch[0]
        attn_mask = (batch.get("attention_mask")
                     if isinstance(batch, dict) else None)
        ids = ids.to(self.device())
        if attn_mask is not None:
            attn_mask = attn_mask.to(self.device())

        captured: list[torch.Tensor] = []

        def hook(module, inputs, output):
            t = output[0] if isinstance(output, tuple) else output
            captured.append(t.detach())

        h = self._attn.register_forward_hook(hook)
        try:
            with torch.no_grad():
                kw = {"input_ids": ids}
                if attn_mask is not None:
                    kw["attention_mask"] = attn_mask
                self.model(**kw)
        finally:
            h.remove()
        if not captured:
            raise RuntimeError("attention hook never fired; "
                               "check `attention_layer` is on the forward path")
        c = captured[0]
        # Pool over tokens, shape → (B, n_heads, head_dim)
        flat = c.mean(dim=1).reshape(c.size(0), self.n_heads, self.head_dim)
        leaf = flat.clone().detach().requires_grad_(True)
        self._batch_ids = ids
        self._batch_attn = attn_mask
        return leaf

    def scalar_output(self, leaf: torch.Tensor, batch) -> torch.Tensor:
        T = self._batch_ids.size(1)
        injected = leaf.reshape(leaf.size(0), 1, self.n_heads * self.head_dim) \
                       .expand(-1, T, -1).contiguous()

        def hook(module, inputs, output):
            if isinstance(output, tuple):
                return (injected,) + output[1:]
            return injected

        handle = self._attn.register_forward_hook(hook)
        try:
            kw = {"input_ids": self._batch_ids}
            if self._batch_attn is not None:
                kw["attention_mask"] = self._batch_attn
            out = self.model(**kw)
        finally:
            handle.remove()
        return self._scalar(out, batch)
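A minimal usage sketch for the embedding adapter (not part of the wheel; distilgpt2 is illustrative, and the batch is converted to a plain dict because the adapter checks isinstance(batch, dict)):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from ffca.adapters import TransformerEmbeddingAdapter

    tok = AutoTokenizer.from_pretrained("distilgpt2")
    model = AutoModelForCausalLM.from_pretrained("distilgpt2").eval()

    batch = dict(tok("the quick brown fox", return_tensors="pt"))
    T = batch["input_ids"].size(1)
    adapter = TransformerEmbeddingAdapter(model, seq_len=T,
                                          hidden=model.config.hidden_size)

    embs = adapter.feature_input(batch)     # (1, T, hidden) leaf embeddings
    s = adapter.scalar_output(embs, batch)  # last-token max-logit by default
    s.backward()                            # embs.grad: per token × hidden dim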
ffca/checkpoint.py
ADDED
@@ -0,0 +1,58 @@
"""CheckpointLoader — yield a fresh model for each saved state_dict.

For v0.1.0 we support plain `torch.save(model.state_dict(), path)` files.
Lightning / HF / SafeTensors formats can be added with detect-by-extension
in v0.2.
"""
from __future__ import annotations

from pathlib import Path
from typing import Callable, Iterator, Sequence

import torch
import torch.nn as nn


class CheckpointLoader:
    """Iterate (epoch_label, model) over a list of saved state_dicts.

    Args:
        model_factory: () -> fresh nn.Module instance (same architecture as
            what was saved).
        checkpoints: list of paths (or (label, path) pairs).
        device: target device for the loaded model.
    """

    def __init__(
        self,
        model_factory: Callable[[], nn.Module],
        checkpoints: Sequence[str | tuple[str, str]],
        device: str | torch.device = "cpu",
    ):
        self.model_factory = model_factory
        self.checkpoints = [
            (Path(c).stem, Path(c)) if isinstance(c, str) else (str(c[0]), Path(c[1]))
            for c in checkpoints
        ]
        self.device = torch.device(device)

    def __len__(self) -> int:
        return len(self.checkpoints)

    def __iter__(self) -> Iterator[tuple[str, nn.Module]]:
        for label, path in self.checkpoints:
            model = self.model_factory()
            state = torch.load(path, map_location=self.device, weights_only=False)
            # Accept either {"state_dict": ...} or a bare state_dict
            if isinstance(state, dict) and "state_dict" in state:
                state = state["state_dict"]
            try:
                model.load_state_dict(state, strict=False)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to load checkpoint {path}: {e}. "
                    f"Ensure model_factory() returns the same architecture."
                )
            model.to(self.device)
            model.eval()
            yield label, model
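A minimal usage sketch (not part of the wheel; the toy factory and checkpoint names are illustrative, and the per-epoch FFCAReport call follows the package quick start):

    import torch
    import torch.nn as nn
    from ffca import CheckpointLoader
    from ffca.adapters import TabularAdapter

    def make_model() -> nn.Module:
        return nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 2))

    # Save two illustrative snapshots so the loop below has something to load.
    for name in ("e1.pt", "e5.pt"):
        torch.save(make_model().state_dict(), name)

    loader = CheckpointLoader(
        make_model,
        [("epoch1", "e1.pt"), ("epoch5", "e5.pt")],  # (label, path) pairs
    )
    for label, model in loader:
        adapter = TabularAdapter(model, n_features=8)  # fresh adapter per epoch
        ...  # e.g. FFCAReport(adapter, val_loader).run() as in the quick start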