ffca 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ffca/__init__.py ADDED
@@ -0,0 +1,41 @@
+ """FFCA — Feature-Function Curvature Analysis for any PyTorch model.
+
+ Quick start (Python):
+
+     from ffca import FFCAReport
+     from ffca.adapters import TabularAdapter
+
+     adapter = TabularAdapter(model, feature_names=cols)
+     report = FFCAReport(adapter, val_loader).run()
+     report.save("out/")
+
+ Or via the CLI: ``ffca-report --help``.
+ """
+ from .adapters import (
+     ChannelAdapter,
+     PixelAdapter,
+     TabularAdapter,
+     TransformerEmbeddingAdapter,
+     TransformerHeadAdapter,
+ )
+ from .checkpoint import CheckpointLoader
+ from .core import FFCAModelAdapter, FFCASignature
+ from .improvements_pkg import CauchyHVP, CoSensitivityGroups, TrustScore
+ from .report import FFCAReport
+
+ __version__ = "0.1.0a1"
+
+ __all__ = [
+     "FFCAReport",
+     "FFCAModelAdapter",
+     "FFCASignature",
+     "TabularAdapter",
+     "PixelAdapter",
+     "ChannelAdapter",
+     "TransformerEmbeddingAdapter",
+     "TransformerHeadAdapter",
+     "CheckpointLoader",
+     "CauchyHVP",
+     "TrustScore",
+     "CoSensitivityGroups",
+ ]
ffca/adapters/__init__.py ADDED
@@ -0,0 +1,16 @@
+ """Built-in adapters covering the four v0.1.0 model families."""
+ from .channel import ChannelAdapter
+ from .pixel import PixelAdapter
+ from .tabular import TabularAdapter
+ from .transformer import (
+     TransformerEmbeddingAdapter,
+     TransformerHeadAdapter,
+ )
+
+ __all__ = [
+     "TabularAdapter",
+     "PixelAdapter",
+     "ChannelAdapter",
+     "TransformerEmbeddingAdapter",
+     "TransformerHeadAdapter",
+ ]
ffca/adapters/channel.py ADDED
@@ -0,0 +1,145 @@
+ """ChannelAdapter — channel-level FFCA at any intermediate layer.
+
+ Splice mechanism:
+     1. Resolve the named layer fresh on every forward (smoothing may have
+        swapped its identity).
+     2. Capture: register a temporary forward_hook that records the activation
+        and is removed as soon as the pass completes.
+     3. Replace: register a hook that returns a pre-set leaf tensor as the
+        layer's output, so subsequent layers run from that leaf.
+
+ Both modes resolve the layer by *name* every time, so even if `smooth()`
+ replaces ReLU→Softplus mid-pipeline, the hook still attaches to the
+ right (current) instance.
+ """
+ from __future__ import annotations
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from ..core.adapter import FFCAModelAdapter, find_layer
+ from ..core.scalars import ScalarFn, predicted_class
+
+
+ class ChannelAdapter(FFCAModelAdapter):
+     def __init__(
+         self,
+         model: nn.Module,
+         layer_name: str,
+         *,
+         scalar: ScalarFn | None = None,
+         reduction: str = "spatial_mean",
+     ):
+         super().__init__(model, scalar=scalar or predicted_class())
+         self.layer_name = layer_name
+         # Verify the layer exists; we re-resolve each call.
+         _ = find_layer(model, layer_name)
+         self.reduction = reduction
+         self._activation_shape: tuple[int, ...] | None = None
+         self.feature_names: list[str] | None = None
+         self.n_features = -1
+         self.feature_shape = ()
+         self._raw_input: torch.Tensor | None = None
+
+     # ---------------------------------------------------------- splice utilities
+     def _capture_activation(self, x: torch.Tensor) -> torch.Tensor:
+         """Run a forward pass, capture the named layer's activation, return it."""
+         captured: list[torch.Tensor] = []
+         layer = find_layer(self.model, self.layer_name)
+
+         def hook(module, inputs, output):
+             captured.append(output.detach())
+
+         handle = layer.register_forward_hook(hook)
+         try:
+             with torch.no_grad():
+                 self.model(x)
+         finally:
+             handle.remove()
+         if not captured:
+             raise RuntimeError(
+                 f"Layer {self.layer_name!r} was not invoked during forward; "
+                 f"check that the layer name is correct for this model's forward path"
+             )
+         return captured[0]
+
+     def _forward_with_replacement(self, x: torch.Tensor,
+                                   replacement: torch.Tensor) -> torch.Tensor:
+         """Run a forward pass with the named layer's output replaced."""
+         layer = find_layer(self.model, self.layer_name)
+
+         def hook(module, inputs, output):
+             return replacement
+
+         handle = layer.register_forward_hook(hook)
+         try:
+             return self.model(x)
+         finally:
+             handle.remove()
+
+     # ---------------------------------------------------------- shape probe
+     def _probe(self, batch):
+         x = batch[0] if isinstance(batch, (list, tuple)) else batch
+         x = x.to(device=self.device(), dtype=self.dtype())
+         act = self._capture_activation(x)
+         self._activation_shape = tuple(act.shape[1:])
+         if self.reduction == "spatial_mean":
+             C = act.shape[1]
+             self.n_features = C
+             self.feature_shape = (C,)
+             self.feature_names = [f"ch_{i}" for i in range(C)]
+         elif self.reduction == "none":
+             d = int(np.prod(act.shape[1:]))
+             self.n_features = d
+             self.feature_shape = tuple(act.shape[1:])
+             self.feature_names = [f"unit_{i}" for i in range(d)]
+         else:
+             raise ValueError(f"unknown reduction {self.reduction!r}")
+
+     # ---------------------------------------------------------- adapter API
+     def feature_input(self, batch) -> torch.Tensor:
+         x = batch[0] if isinstance(batch, (list, tuple)) else batch
+         x = x.to(device=self.device(), dtype=self.dtype())
+
+         if self._activation_shape is None:
+             self._probe(batch)
+
+         act = self._capture_activation(x)  # full activation tensor
+
+         if self.reduction == "spatial_mean":
+             if act.dim() == 4:  # (B, C, H, W)
+                 feat = act.mean(dim=(2, 3))
+             elif act.dim() == 3:  # (B, C, L): channels-first, pool over length
+                 feat = act.mean(dim=-1)
+             else:
+                 feat = act
+         else:
+             feat = act.reshape(act.size(0), -1)
+
+         leaf = feat.clone().detach().requires_grad_(True)
+         self._raw_input = x
+         return leaf
+
+     def scalar_output(self, leaf: torch.Tensor, batch) -> torch.Tensor:
+         # Re-inflate the leaf to the original activation shape and run the
+         # rest of the model with the replacement hook.
+         if self.reduction == "spatial_mean":
+             shape = self._activation_shape
+             if len(shape) == 3:
+                 C, H, W = shape
+                 injected = leaf.view(leaf.size(0), C, 1, 1).expand(-1, C, H, W).contiguous()
+             elif len(shape) == 2:
+                 C, L = shape
+                 injected = leaf.view(leaf.size(0), C, 1).expand(-1, C, L).contiguous()
+             else:
+                 injected = leaf
+         else:
+             injected = leaf.reshape(leaf.size(0), *self._activation_shape)
+         out = self._forward_with_replacement(self._raw_input, injected)
+         return self._scalar(out, batch)
+
+     def channel_count(self) -> int:
+         if self._activation_shape is None:
+             raise RuntimeError("ChannelAdapter.feature_input must be called once first")
+         return self._activation_shape[0]
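
To make the splice mechanics concrete, here is a minimal round-trip sketch. The torchvision model, the "layer4" name, and the random batch are illustrative assumptions, not part of the package:

    import torch
    from torchvision.models import resnet18
    from ffca.adapters import ChannelAdapter

    model = resnet18(weights=None).eval()
    adapter = ChannelAdapter(model, "layer4")    # layer4 output: (B, 512, 7, 7)
    batch = (torch.randn(4, 3, 224, 224),)

    leaf = adapter.feature_input(batch)          # captured + pooled: (4, 512) leaf
    score = adapter.scalar_output(leaf, batch)   # re-injects leaf via the replace hook
    grads = torch.autograd.grad(score, leaf)[0]  # per-channel sensitivity, (4, 512)

Note that under `spatial_mean`, `scalar_output` broadcasts each pooled channel value back over all spatial positions, so gradients through the leaf measure the effect of shifting a channel's mean activation.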
ffca/adapters/pixel.py ADDED
@@ -0,0 +1,66 @@
+ """PixelAdapter — input-level FFCA on image/tensor inputs.
+
+ Use for: any model that consumes a (C, H, W) image, when you want pixel-level
+ explanations. The feature axis is the flattened C·H·W pixels.
+
+ Adds an `fbr()` helper to compute the Foreground/Background interaction
+ Ratio used in shortcut-learning diagnostics (e.g. Waterbirds): the proportion
+ of total interaction concentrated in a center foreground box versus the
+ surrounding background ring. FBR < 0.5 suggests a background shortcut.
+ """
+ from __future__ import annotations
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from ..core.adapter import FFCAModelAdapter
+ from ..core.scalars import ScalarFn, predicted_class
+
+
+ class PixelAdapter(FFCAModelAdapter):
+     def __init__(
+         self,
+         model: nn.Module,
+         *,
+         input_shape: tuple[int, int, int],  # (C, H, W)
+         scalar: ScalarFn | None = None,
+     ):
+         super().__init__(model, scalar=scalar or predicted_class())
+         self.feature_shape = tuple(input_shape)
+         self.n_features = int(np.prod(input_shape))
+         C, H, W = input_shape
+         # Naming every pixel is verbose (C·H·W entries); the CLI uses indices.
+         self.feature_names = [f"px_{c}_{h}_{w}" for c in range(C)
+                               for h in range(H) for w in range(W)]
+
+     def feature_input(self, batch) -> torch.Tensor:
+         x = batch[0] if isinstance(batch, (list, tuple)) else batch
+         x = x.to(device=self.device(), dtype=self.dtype())
+         return x.clone().detach().requires_grad_(True)
+
+     def scalar_output(self, x: torch.Tensor, batch) -> torch.Tensor:
+         out = self.model(x)
+         return self._scalar(out, batch)
+
+     # --- spatial helpers ---------------------------------------------------
+     def reshape_to_image(self, per_pixel_score: np.ndarray) -> np.ndarray:
+         """Map a (d,) per-pixel score back to its (C, H, W) layout."""
+         return per_pixel_score.reshape(self.feature_shape)
+
+     def fbr(self, per_pixel_interaction: np.ndarray, fg_frac: float = 0.5) -> float:
+         """Foreground/Background interaction ratio.
+
+         fg_frac: side-length fraction defining the center foreground box.
+         Returns mean(fg) / (mean(fg) + mean(bg)). Values < 0.5 hint at a
+         background shortcut.
+         """
+         C, H, W = self.feature_shape
+         img = per_pixel_interaction.reshape(self.feature_shape)
+         py = int(H * (1 - fg_frac) / 2)
+         px = int(W * (1 - fg_frac) / 2)
+         fg = img[:, py:H - py, px:W - px].mean()
+         bg_sum = img.sum() - img[:, py:H - py, px:W - px].sum()
+         bg_count = img.size - img[:, py:H - py, px:W - px].size
+         bg = bg_sum / max(bg_count, 1)
+         return float(fg / (fg + bg)) if (fg + bg) > 0 else float("nan")
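
A hedged FBR sketch; the tiny placeholder model and the random interaction vector stand in for a real classifier and the per-pixel interaction scores an FFCA report would produce:

    import numpy as np
    import torch.nn as nn
    from ffca.adapters import PixelAdapter

    cnn = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 10))  # placeholder model
    adapter = PixelAdapter(cnn, input_shape=(3, 64, 64))

    interaction = np.abs(np.random.randn(3 * 64 * 64))  # stand-in per-pixel scores
    heat = adapter.reshape_to_image(interaction)        # (3, 64, 64) heatmap
    ratio = adapter.fbr(interaction, fg_frac=0.5)       # center box vs background ring
    if ratio < 0.5:
        print("interaction mass is concentrated in the background ring")

With fg_frac=0.5 on a 64×64 image, py = px = int(64 * 0.25) = 16, so the foreground box is img[:, 16:48, 16:48] (32×32) and everything outside it is the background ring.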
ffca/adapters/tabular.py ADDED
@@ -0,0 +1,40 @@
+ """TabularAdapter — input-level FFCA on dense feature vectors.
+
+ Use for: MLPs, tabular transformers, scikit-style wrappers over an MLP.
+ The feature axis is the d input columns.
+ """
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+
+ from ..core.adapter import FFCAModelAdapter
+ from ..core.scalars import ScalarFn, predicted_class
+
+
+ class TabularAdapter(FFCAModelAdapter):
+     def __init__(
+         self,
+         model: nn.Module,
+         *,
+         feature_names: list[str] | None = None,
+         n_features: int | None = None,
+         scalar: ScalarFn | None = None,
+     ):
+         super().__init__(model, scalar=scalar or predicted_class())
+         if feature_names is None and n_features is None:
+             raise ValueError("provide feature_names or n_features")
+         self.feature_names = feature_names
+         self.n_features = n_features if n_features is not None else len(feature_names)
+         self.feature_shape = (self.n_features,)
+         if feature_names is None:
+             self.feature_names = [f"feature_{i}" for i in range(self.n_features)]
+
+     def feature_input(self, batch) -> torch.Tensor:
+         x = batch[0] if isinstance(batch, (list, tuple)) else batch
+         x = x.to(device=self.device(), dtype=self.dtype())
+         return x.clone().detach().requires_grad_(True)
+
+     def scalar_output(self, x: torch.Tensor, batch) -> torch.Tensor:
+         out = self.model(x)
+         return self._scalar(out, batch)
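
A self-contained sketch of the leaf round trip on a toy MLP (the model, sizes, and dummy batch are illustrative, not part of the package):

    import torch
    import torch.nn as nn
    from ffca.adapters import TabularAdapter

    mlp = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 3))
    adapter = TabularAdapter(mlp, n_features=8)  # names default to feature_0..feature_7
    batch = (torch.randn(16, 8),)

    leaf = adapter.feature_input(batch)          # detached copy with requires_grad=True
    score = adapter.scalar_output(leaf, batch)
    grad = torch.autograd.grad(score, leaf)[0]   # (16, 8) per-column sensitivity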
ffca/adapters/transformer.py ADDED
@@ -0,0 +1,236 @@
+ """Transformer adapters for FFCA — embedding-level and attention-head.
+
+ Both adapters work on any HuggingFace `AutoModel` / `AutoModelForCausalLM`
+ without modification. They auto-detect the input-embedding attribute and
+ the per-layer attention block by walking common name conventions:
+
+     GPT-2 / DistilGPT2 : model.transformer.wte            (embeddings)
+                          model.transformer.h[L].attn      (attention block)
+     BERT / DistilBERT  : model.bert.embeddings.word_embeddings
+                          model.bert.encoder.layer[L].attention.self
+     LLaMA / Mistral    : model.model.embed_tokens
+                          model.model.layers[L].self_attn
+     Generic fallback   : the first nn.Embedding found in the model
+
+ If auto-detection fails, pass `embedding_module=` / `attention_layer=` by
+ hand. Both adapters use the splice trick (forward_hook re-resolved by name
+ each call) so the package's automatic ReLU→Softplus / MaxPool→AvgPool /
+ flash-SDP→math-SDP smoothing does not break them.
+ """
+ from __future__ import annotations
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from ..core.adapter import FFCAModelAdapter
+ from ..core.scalars import ScalarFn, predicted_class
+
+
+ # ---------------------------------------------------------------- locator
+ def _find_embedding(model: nn.Module) -> nn.Embedding:
+     """Return the input-token embedding module of an HF transformer."""
+     for path in (
+         "transformer.wte",             # GPT-2 family
+         "model.embed_tokens",          # LLaMA, Mistral
+         "model.decoder.embed_tokens",  # OPT, BART
+         "bert.embeddings.word_embeddings",
+         "embeddings.word_embeddings",
+         "shared",                      # T5
+     ):
+         obj = model
+         ok = True
+         for part in path.split("."):
+             if not hasattr(obj, part):
+                 ok = False; break
+             obj = getattr(obj, part)
+         if ok and isinstance(obj, nn.Embedding):
+             return obj
+     # Fallback: first nn.Embedding in the module tree
+     for m in model.modules():
+         if isinstance(m, nn.Embedding):
+             return m
+     raise AttributeError("No nn.Embedding found in the model")
+
+
+ def _find_attention_layer(model: nn.Module, layer_idx: int = -1) -> nn.Module:
+     """Return the attention module at the requested encoder layer index."""
+     for path in (
+         ("transformer.h", "attn"),      # GPT-2
+         ("model.layers", "self_attn"),  # LLaMA
+         ("bert.encoder.layer", "attention.self"),
+         ("encoder.layer", "attention.self"),
+         ("model.encoder.layers", "self_attn"),
+         ("model.decoder.layers", "self_attn"),
+     ):
+         block_path, attr_path = path
+         obj = model
+         ok = True
+         for part in block_path.split("."):
+             if not hasattr(obj, part):
+                 ok = False; break
+             obj = getattr(obj, part)
+         if not ok or not hasattr(obj, "__getitem__"):
+             continue
+         try:
+             block = obj[layer_idx]
+         except (IndexError, TypeError):
+             continue
+         attr_obj = block
+         for p in attr_path.split("."):
+             if not hasattr(attr_obj, p):
+                 attr_obj = None; break
+             attr_obj = getattr(attr_obj, p)
+         if attr_obj is not None and isinstance(attr_obj, nn.Module):
+             return attr_obj
+     raise AttributeError(
+         f"Could not auto-locate attention block at layer index {layer_idx}; "
+         f"pass attention_layer= explicitly"
+     )
+
+
+ def _last_token_logit_scalar(out: torch.Tensor, batch=None) -> torch.Tensor:
+     """Default LLM scalar: max logit at the last sequence position."""
+     if hasattr(out, "logits"):
+         last = out.logits[:, -1, :]
+     elif hasattr(out, "last_hidden_state"):
+         last = out.last_hidden_state[:, -1, :]
+     else:
+         last = out
+     return last.max(dim=-1).values.sum()
+
+
+ # ---------------------------------------------------------------- embedding adapter
+ class TransformerEmbeddingAdapter(FFCAModelAdapter):
+     """Treat token × hidden-dim input embeddings as the feature axis.
+
+     For a sequence of length T and hidden size H, this yields T·H features.
+     Suitable for understanding which token position / which hidden dim drives
+     the model's prediction.
+
+     Args:
+         model           : HF model (AutoModel / AutoModelForCausalLM)
+         seq_len, hidden : sequence length and hidden size of the input
+         scalar          : optional ScalarFn; defaults to last-token max-logit
+         embedding_module: optional override for the nn.Embedding to differentiate
+     """
+     def __init__(
+         self,
+         model: nn.Module,
+         seq_len: int,
+         hidden: int,
+         *,
+         scalar: ScalarFn | None = None,
+         embedding_module: nn.Embedding | None = None,
+     ):
+         super().__init__(model, scalar=scalar or _last_token_logit_scalar)
+         self.seq_len = seq_len
+         self.hidden = hidden
+         self.n_features = seq_len * hidden
+         self.feature_shape = (seq_len, hidden)
+         self.feature_names = [f"t{t}_h{h}" for t in range(seq_len)
+                               for h in range(hidden)]
+         self._emb = embedding_module or _find_embedding(model)
+
+     def feature_input(self, batch) -> torch.Tensor:
+         ids = batch["input_ids"] if isinstance(batch, dict) else batch[0]
+         ids = ids.to(self.device())
+         with torch.no_grad():
+             embs = self._emb(ids)
+         return embs.clone().detach().requires_grad_(True)
+
+     def scalar_output(self, embs: torch.Tensor, batch) -> torch.Tensor:
+         kw = {}
+         if isinstance(batch, dict):
+             if "attention_mask" in batch:
+                 kw["attention_mask"] = batch["attention_mask"].to(self.device())
+         out = self.model(inputs_embeds=embs, **kw)
+         return self._scalar(out, batch)
+
+
+ # ---------------------------------------------------------------- head adapter
+ class TransformerHeadAdapter(FFCAModelAdapter):
+     """Per-attention-head pooled activations from a chosen encoder layer.
+
+     For an attention block with H heads of dim D, the feature axis is H·D.
+     The mean-pool over tokens turns the (B, T, H·D) attention output into a
+     (B, H, D) feature tensor that FFCA can differentiate.
+
+     Args:
+         model            : HF model
+         n_heads, head_dim: from the model's config
+         layer_idx        : which layer (negative indexes from the end; -1 = last)
+         attention_layer  : optional override for the attention module
+         scalar           : optional ScalarFn
+     """
+     def __init__(
+         self,
+         model: nn.Module,
+         n_heads: int,
+         head_dim: int,
+         *,
+         layer_idx: int = -1,
+         attention_layer: nn.Module | None = None,
+         scalar: ScalarFn | None = None,
+     ):
+         super().__init__(model, scalar=scalar or _last_token_logit_scalar)
+         self.n_heads = n_heads
+         self.head_dim = head_dim
+         self.layer_idx = layer_idx
+         self.n_features = n_heads * head_dim
+         self.feature_shape = (n_heads, head_dim)
+         self.feature_names = [f"h{h}_d{d}" for h in range(n_heads)
+                               for d in range(head_dim)]
+         self._attn = attention_layer or _find_attention_layer(model, layer_idx)
+         self._batch_ids: torch.Tensor | None = None
+         self._batch_attn: torch.Tensor | None = None
+
+     def feature_input(self, batch) -> torch.Tensor:
+         ids = batch["input_ids"] if isinstance(batch, dict) else batch[0]
+         attn_mask = (batch.get("attention_mask")
+                      if isinstance(batch, dict) else None)
+         ids = ids.to(self.device())
+         if attn_mask is not None:
+             attn_mask = attn_mask.to(self.device())
+
+         captured: list[torch.Tensor] = []
+         def hook(module, inputs, output):
+             t = output[0] if isinstance(output, tuple) else output
+             captured.append(t.detach())
+         h = self._attn.register_forward_hook(hook)
+         try:
+             with torch.no_grad():
+                 kw = {"input_ids": ids}
+                 if attn_mask is not None:
+                     kw["attention_mask"] = attn_mask
+                 self.model(**kw)
+         finally:
+             h.remove()
+         if not captured:
+             raise RuntimeError("attention hook never fired; "
+                                "check `attention_layer` is on the forward path")
+         c = captured[0]
+         # Pool over tokens, shape → (B, n_heads, head_dim)
+         flat = c.mean(dim=1).reshape(c.size(0), self.n_heads, self.head_dim)
+         leaf = flat.clone().detach().requires_grad_(True)
+         self._batch_ids = ids
+         self._batch_attn = attn_mask
+         return leaf
+
+     def scalar_output(self, leaf: torch.Tensor, batch) -> torch.Tensor:
+         T = self._batch_ids.size(1)
+         injected = leaf.reshape(leaf.size(0), 1, self.n_heads * self.head_dim) \
+                        .expand(-1, T, -1).contiguous()
+         def hook(module, inputs, output):
+             if isinstance(output, tuple):
+                 return (injected,) + output[1:]
+             return injected
+         handle = self._attn.register_forward_hook(hook)
+         try:
+             kw = {"input_ids": self._batch_ids}
+             if self._batch_attn is not None:
+                 kw["attention_mask"] = self._batch_attn
+             out = self.model(**kw)
+         finally:
+             handle.remove()
+         return self._scalar(out, batch)
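
A minimal embedding-level sketch against distilgpt2 (the checkpoint choice, the prompt, and hidden size 768 are assumptions about this particular model, not requirements of the adapter):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from ffca.adapters import TransformerEmbeddingAdapter

    tok = AutoTokenizer.from_pretrained("distilgpt2")
    model = AutoModelForCausalLM.from_pretrained("distilgpt2").eval()
    batch = dict(tok("the quick brown fox", return_tensors="pt"))  # plain dict for the adapter
    T = batch["input_ids"].shape[1]

    adapter = TransformerEmbeddingAdapter(model, seq_len=T, hidden=768)
    embs = adapter.feature_input(batch)          # (1, T, 768) leaf embeddings
    score = adapter.scalar_output(embs, batch)   # default: last-token max logit
    sal = torch.autograd.grad(score, embs)[0]    # token × hidden-dim saliency

TransformerHeadAdapter follows the same pattern, with n_heads and head_dim read from model.config (for distilgpt2: 12 heads of dim 64).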
ffca/checkpoint.py ADDED
@@ -0,0 +1,58 @@
+ """CheckpointLoader — yield a fresh model for each saved state_dict.
+
+ For v0.1.0 we support plain `torch.save(model.state_dict(), path)` files.
+ Lightning / HF / SafeTensors formats can be added via extension-based
+ detection in v0.2.
+ """
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Callable, Iterator, Sequence
+
+ import torch
+ import torch.nn as nn
+
+
+ class CheckpointLoader:
+     """Iterate (epoch_label, model) over a list of saved state_dicts.
+
+     Args:
+         model_factory: () -> fresh nn.Module instance (same architecture as
+             what was saved).
+         checkpoints: list of paths (or (label, path) pairs).
+         device: target device for the loaded model.
+     """
+
+     def __init__(
+         self,
+         model_factory: Callable[[], nn.Module],
+         checkpoints: Sequence[str | tuple[str, str]],
+         device: str | torch.device = "cpu",
+     ):
+         self.model_factory = model_factory
+         self.checkpoints = [
+             (Path(c).stem, Path(c)) if isinstance(c, str) else (str(c[0]), Path(c[1]))
+             for c in checkpoints
+         ]
+         self.device = torch.device(device)
+
+     def __len__(self) -> int:
+         return len(self.checkpoints)
+
+     def __iter__(self) -> Iterator[tuple[str, nn.Module]]:
+         for label, path in self.checkpoints:
+             model = self.model_factory()
+             state = torch.load(path, map_location=self.device, weights_only=False)
+             # Accept either a {"state_dict": ...} wrapper or a bare state_dict.
+             if isinstance(state, dict) and "state_dict" in state:
+                 state = state["state_dict"]
+             try:
+                 model.load_state_dict(state, strict=False)
+             except Exception as e:
+                 raise RuntimeError(
+                     f"Failed to load checkpoint {path}: {e}. "
+                     f"Ensure model_factory() returns the same architecture."
+                 ) from e
+             model.to(self.device)
+             model.eval()
+             yield label, model
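
Usage sketch; the checkpoint paths are hypothetical files saved during training with torch.save(model.state_dict(), path):

    import torch.nn as nn
    from ffca import CheckpointLoader

    def make_model() -> nn.Module:
        return nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 2))

    loader = CheckpointLoader(make_model, ["ckpt/epoch01.pt", "ckpt/epoch10.pt"])
    for label, model in loader:  # labels come from Path.stem: "epoch01", "epoch10"
        ...  # wrap `model` in an adapter and run an FFCA report per epoch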