slide2vec 4.6.1__tar.gz → 4.6.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98) hide show
  1. {slide2vec-4.6.1 → slide2vec-4.6.3}/PKG-INFO +1 -1
  2. {slide2vec-4.6.1 → slide2vec-4.6.3}/pyproject.toml +2 -2
  3. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/__init__.py +1 -1
  4. slide2vec-4.6.3/slide2vec/encoders/base.py +653 -0
  5. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/conch.py +33 -0
  6. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/hibou.py +40 -0
  7. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/midnight.py +40 -0
  8. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/phikon.py +42 -0
  9. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/utils/config.py +25 -7
  10. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec.egg-info/PKG-INFO +1 -1
  11. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec.egg-info/SOURCES.txt +1 -0
  12. slide2vec-4.6.3/tests/test_attention_extraction.py +427 -0
  13. {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_regression_core.py +28 -0
  14. slide2vec-4.6.1/slide2vec/encoders/base.py +0 -335
  15. {slide2vec-4.6.1 → slide2vec-4.6.3}/LICENSE +0 -0
  16. {slide2vec-4.6.1 → slide2vec-4.6.3}/README.md +0 -0
  17. {slide2vec-4.6.1 → slide2vec-4.6.3}/setup.cfg +0 -0
  18. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/__main__.py +0 -0
  19. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/api.py +0 -0
  20. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/artifacts.py +0 -0
  21. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/cli.py +0 -0
  22. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/configs/__init__.py +0 -0
  23. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/configs/default.yaml +0 -0
  24. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/configs/resources.py +0 -0
  25. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/data/__init__.py +0 -0
  26. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/data/dataset.py +0 -0
  27. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/data/tile_reader.py +0 -0
  28. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/data/tile_store.py +0 -0
  29. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/distributed/__init__.py +0 -0
  30. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/distributed/direct_embed_worker.py +0 -0
  31. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/distributed/pipeline_worker.py +0 -0
  32. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/__init__.py +0 -0
  33. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/__init__.py +0 -0
  34. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/gigapath.py +0 -0
  35. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/hoptimus.py +0 -0
  36. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/lunit.py +0 -0
  37. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/moozy/__init__.py +0 -0
  38. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/moozy/blocks.py +0 -0
  39. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/moozy/case.py +0 -0
  40. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/moozy/loading.py +0 -0
  41. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/moozy/slide.py +0 -0
  42. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/moozy/types.py +0 -0
  43. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/musk.py +0 -0
  44. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/prism.py +0 -0
  45. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/prost40m.py +0 -0
  46. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/titan.py +0 -0
  47. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/uni.py +0 -0
  48. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/virchow.py +0 -0
  49. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/registry.py +0 -0
  50. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/validation.py +0 -0
  51. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/inference.py +0 -0
  52. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/progress.py +0 -0
  53. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/__init__.py +0 -0
  54. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/artifacts_collect.py +0 -0
  55. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/batching.py +0 -0
  56. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/cpu_budget.py +0 -0
  57. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/distributed.py +0 -0
  58. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/distributed_stage.py +0 -0
  59. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/embedding.py +0 -0
  60. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/embedding_persist.py +0 -0
  61. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/embedding_pipeline.py +0 -0
  62. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/hierarchical.py +0 -0
  63. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/manifest.py +0 -0
  64. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/model_settings.py +0 -0
  65. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/patient_pipeline.py +0 -0
  66. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/persist_callbacks.py +0 -0
  67. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/persistence.py +0 -0
  68. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/process_list.py +0 -0
  69. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/progress_bridge.py +0 -0
  70. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/registry.py +0 -0
  71. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/serialization.py +0 -0
  72. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/slide_encode.py +0 -0
  73. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/tiling.py +0 -0
  74. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/tiling_pipeline.py +0 -0
  75. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/types.py +0 -0
  76. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/worker_io.py +0 -0
  77. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/utils/__init__.py +0 -0
  78. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/utils/coordinates.py +0 -0
  79. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/utils/log_utils.py +0 -0
  80. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/utils/tiling_io.py +0 -0
  81. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/utils/utils.py +0 -0
  82. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec.egg-info/dependency_links.txt +0 -0
  83. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec.egg-info/entry_points.txt +0 -0
  84. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec.egg-info/not-zip-safe +0 -0
  85. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec.egg-info/requires.txt +0 -0
  86. {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec.egg-info/top_level.txt +0 -0
  87. {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_architecture_runtime_split.py +0 -0
  88. {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_dense_extraction.py +0 -0
  89. {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_dense_locality_gated.py +0 -0
  90. {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_encoder_registry.py +0 -0
  91. {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_hs2p_package_cutover.py +0 -0
  92. {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_output_consistency.py +0 -0
  93. {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_progress.py +0 -0
  94. {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_regression_inference.py +0 -0
  95. {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_regression_models.py +0 -0
  96. {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_runtime_batching.py +0 -0
  97. {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_tile_store.py +0 -0
  98. {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_tiling_pipeline.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: slide2vec
3
- Version: 4.6.1
3
+ Version: 4.6.3
4
4
  Summary: Embedding of whole slide images with Foundation Models
5
5
  Author-email: Clément Grisi <clement.grisi@radboudumc.nl>
6
6
  License-Expression: Apache-2.0
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "slide2vec"
7
- version = "4.6.1"
7
+ version = "4.6.3"
8
8
  description = "Embedding of whole slide images with Foundation Models"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10"
@@ -164,7 +164,7 @@ no_implicit_reexport = true
164
164
  max-line-length = 160
165
165
 
166
166
  [tool.bumpver]
167
- current_version = "4.6.1"
167
+ current_version = "4.6.3"
168
168
  version_pattern = "MAJOR.MINOR.PATCH"
169
169
  commit = false # We do version bumping in CI, not as a commit
170
170
  tag = false # Git tag already exists — we don't auto-tag
@@ -11,7 +11,7 @@ from slide2vec.api import (
11
11
  from slide2vec.artifacts import HierarchicalEmbeddingArtifact, SlideEmbeddingArtifact, TileEmbeddingArtifact
12
12
 
13
13
 
14
- __version__ = "4.6.1"
14
+ __version__ = "4.6.3"
15
15
 
16
16
  __all__ = [
17
17
  "Model",
@@ -0,0 +1,653 @@
1
+ """Encoder abstractions for tile-level and slide-level feature extraction."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from contextlib import contextmanager
5
+ from typing import Callable
6
+
7
+ import timm
8
+ import torch
9
+ from timm.data import create_transform, resolve_data_config
10
+ from torch import Tensor
11
+ from torchvision.transforms import v2
12
+
13
+
14
+ def preferred_default_device() -> torch.device:
15
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+
18
+ @contextmanager
19
+ def hf_eager_attention(model):
20
+ """Force an HF model onto eager attention for the duration of the block.
21
+
22
+ Recent ``transformers`` default to an SDPA attention implementation that
23
+ *silently ignores* ``output_attentions=True`` (it warns and returns
24
+ ``attentions=None``), so attention extraction must temporarily switch the
25
+ model to ``eager``, which materializes the weights. The previous
26
+ implementation is restored on exit. A no-op for models that lack
27
+ ``set_attn_implementation`` or are already eager (e.g. some vendored
28
+ ``trust_remote_code`` backbones that always compute the weights)."""
29
+ setter = getattr(model, "set_attn_implementation", None)
30
+ prev = getattr(getattr(model, "config", None), "_attn_implementation", None)
31
+ changed = False
32
+ if callable(setter) and prev not in (None, "eager"):
33
+ try:
34
+ setter("eager")
35
+ changed = True
36
+ except Exception:
37
+ changed = False
38
+ try:
39
+ yield
40
+ finally:
41
+ if changed:
42
+ try:
43
+ setter(prev)
44
+ except Exception:
45
+ pass
46
+
47
+
48
+ def resolve_requested_output_variant(
49
+ output_variant: str | None,
50
+ *,
51
+ default: str = "default",
52
+ allowed: tuple[str, ...] = ("default",),
53
+ ) -> str:
54
+ """Normalize and validate a requested encoder output variant."""
55
+ resolved = output_variant or default
56
+ if resolved not in allowed:
57
+ available = ", ".join(allowed)
58
+ raise ValueError(
59
+ f"Unsupported output_variant '{resolved}'. Available: {available}"
60
+ )
61
+ return resolved
62
+
63
+
64
+ def resolve_recommended_dynamic_img_size(
65
+ *,
66
+ requested: bool | None,
67
+ recommended: bool,
68
+ allow_non_recommended: bool,
69
+ encoder_name: str,
70
+ ) -> bool:
71
+ """Resolve ``dynamic_img_size`` against an encoder's card-recommended value.
72
+
73
+ ``None`` uses the recommended value. A value that differs from the
74
+ recommendation requires ``allow_non_recommended_settings=True`` (e.g. dense
75
+ feature extraction deliberately enabling variable input size, justified by
76
+ the registration / native-size-no-op tests); otherwise it raises, so a
77
+ pipeline never silently runs an encoder outside its documented config.
78
+ """
79
+ if requested is None:
80
+ return recommended
81
+ if requested != recommended and not allow_non_recommended:
82
+ raise ValueError(
83
+ f"Encoder '{encoder_name}' recommends dynamic_img_size={recommended} "
84
+ f"(per its model card); got dynamic_img_size={requested}, which deviates "
85
+ "from the recommended setting. Pass allow_non_recommended_settings=True "
86
+ "to override it deliberately (e.g. dense extraction needs variable input "
87
+ "size; this is a native-size no-op, verified in the encoder tests)."
88
+ )
89
+ return requested
90
+
91
+
92
+ def reshape_tokens_to_grid(
93
+ tokens: Tensor,
94
+ *,
95
+ grid_h: int,
96
+ grid_w: int,
97
+ num_prefix_tokens: int,
98
+ encoder_name: str,
99
+ ) -> Tensor:
100
+ """Fold a ViT ``(B, T, d)`` token sequence into a dense ``(B, d, h, w)`` grid.
101
+
102
+ Strips the leading ``num_prefix_tokens`` (CLS + register tokens) and reshapes
103
+ the remaining patch tokens back into their row-major spatial grid. ViT patch
104
+ tokens are emitted in row-major order ``[(0,0), (0,1), ..., (h-1, w-1)]`` after
105
+ the prefix tokens, so ``transpose(1, 2).reshape(B, d, h, w)`` recovers the
106
+ spatial layout. Verified bit-for-bit against timm's
107
+ ``get_intermediate_layers(..., reshape=True)`` in the encoder tests.
108
+
109
+ Fails loudly if the post-strip token count does not match ``grid_h * grid_w``:
110
+ a silent reshape would train a decoder on spatially corrupted features, which
111
+ is worse than a hard failure.
112
+ """
113
+ if tokens.ndim != 3:
114
+ raise ValueError(
115
+ f"Dense extraction for '{encoder_name}' expected a (B, T, d) token "
116
+ f"sequence from the backbone, got shape {tuple(tokens.shape)}. This "
117
+ "encoder may not expose a recoverable patch-token grid."
118
+ )
119
+ patch_tokens = tokens[:, num_prefix_tokens:, :]
120
+ batch_size, num_tokens, dim = patch_tokens.shape
121
+ expected = grid_h * grid_w
122
+ if num_tokens != expected:
123
+ raise ValueError(
124
+ f"Dense token accounting mismatch for '{encoder_name}': backbone "
125
+ f"returned {tokens.shape[1]} tokens; after stripping "
126
+ f"{num_prefix_tokens} prefix token(s), {num_tokens} remain, but the "
127
+ f"{grid_h}x{grid_w} grid expects {expected}. Check the prefix-token "
128
+ "count and the input-size / patch-size / grid geometry."
129
+ )
130
+ return patch_tokens.transpose(1, 2).reshape(batch_size, dim, grid_h, grid_w)
131
+
132
+
133
+ def prefix_attention_to_grid(
134
+ attn_weights: Tensor,
135
+ *,
136
+ num_prefix_tokens: int,
137
+ include_registers: bool,
138
+ grid_h: int,
139
+ grid_w: int,
140
+ encoder_name: str,
141
+ ) -> Tensor:
142
+ """Fold one block's self-attention into per-prefix-token spatial maps.
143
+
144
+ Given the full attention weights ``(B, nh, N, N)`` of one transformer block
145
+ (rows = query tokens, columns = key tokens, each row a softmax over keys),
146
+ select the **prefix-token query rows** — the CLS token always, the ``M``
147
+ register tokens too when ``include_registers`` — slice the **patch key columns**,
148
+ and reshape each selected row back into its ``(grid_h, grid_w)`` spatial layout.
149
+
150
+ The output is ``(B, K, grid_h, grid_w)`` with ``K = num_query * nh`` channels in
151
+ the deterministic order ``[cls, reg…][head]`` (query-token outer, head inner):
152
+ channel ``q * nh + head`` is prefix-query ``q``'s attention from head ``head``.
153
+ ``num_query = num_prefix_tokens`` when ``include_registers`` else ``1`` (CLS only).
154
+
155
+ This is the attention analog of :func:`reshape_tokens_to_grid` (one folds patch
156
+ *tokens* into a grid, the other folds prefix-token *attention* into grids); it
157
+ reuses the same ``num_prefix_tokens`` split. Per-head is preserved on purpose —
158
+ head specialization is the signal the downstream pixel-classifier exploits, and
159
+ reducing it would be lossy and irreversible in the cache. Fails loud on a token
160
+ accounting mismatch rather than silently mis-reshaping.
161
+ """
162
+ if attn_weights.ndim != 4:
163
+ raise ValueError(
164
+ f"Attention extraction for '{encoder_name}' expected (B, nh, N, N) "
165
+ f"attention weights, got shape {tuple(attn_weights.shape)}."
166
+ )
167
+ batch_size, num_heads, num_query_tokens, num_key_tokens = attn_weights.shape
168
+ if num_query_tokens != num_key_tokens:
169
+ raise ValueError(
170
+ f"Attention extraction for '{encoder_name}' expected square attention "
171
+ f"(query==key tokens), got {num_query_tokens}x{num_key_tokens}."
172
+ )
173
+ if num_prefix_tokens < 1:
174
+ raise ValueError(
175
+ f"Attention extraction for '{encoder_name}' needs at least one prefix "
176
+ f"token (the CLS query row), got num_prefix_tokens={num_prefix_tokens}."
177
+ )
178
+ num_patches = grid_h * grid_w
179
+ expected = num_prefix_tokens + num_patches
180
+ if num_key_tokens != expected:
181
+ raise ValueError(
182
+ f"Attention token accounting mismatch for '{encoder_name}': block "
183
+ f"returned {num_key_tokens} tokens; with {num_prefix_tokens} prefix "
184
+ f"token(s) the {grid_h}x{grid_w} grid expects {expected}. Check the "
185
+ "prefix-token count and the input-size / patch-size / grid geometry."
186
+ )
187
+ num_query = num_prefix_tokens if include_registers else 1
188
+ # rows: prefix query tokens [0:num_query]; columns: patch keys [num_prefix:].
189
+ patch_rows = attn_weights[:, :, :num_query, num_prefix_tokens:] # (B, nh, q, P)
190
+ # (B, nh, q, P) -> (B, q, nh, P) so reshape yields the [query][head] channel order.
191
+ maps = patch_rows.permute(0, 2, 1, 3).reshape(batch_size, num_query * num_heads, num_patches)
192
+ return maps.reshape(batch_size, num_query * num_heads, grid_h, grid_w)
193
+
194
+
195
+ def timm_self_attention_weights(attn_module, x: Tensor) -> Tensor:
196
+ """Recompute a timm ``Attention`` block's softmax weights ``(B, nh, N, N)``.
197
+
198
+ timm's attention runs a *fused* SDPA kernel by default, which never
199
+ materializes the attention matrix. To recover it we re-run the projection from
200
+ the module's own input ``x`` (the post-``norm1`` residual-branch input, captured
201
+ via a forward-pre-hook) using the module's own ``qkv`` / ``q_norm`` / ``k_norm``
202
+ / ``num_heads`` / ``head_dim`` / ``scale`` — i.e. exactly the non-fused branch of
203
+ ``Attention.forward``, so the result is bit-equivalent to the weights the fused
204
+ kernel applies internally. Dropout is omitted (extraction runs under ``eval``).
205
+ """
206
+ if not hasattr(attn_module, "qkv"):
207
+ raise NotImplementedError(
208
+ f"{type(attn_module).__name__} has no fused 'qkv' projection; attention "
209
+ "extraction currently supports timm ViT Attention blocks only."
210
+ )
211
+ batch_size, num_tokens, _ = x.shape
212
+ num_heads = int(attn_module.num_heads)
213
+ head_dim = int(getattr(attn_module, "head_dim", x.shape[-1] // num_heads))
214
+ qkv = (
215
+ attn_module.qkv(x)
216
+ .reshape(batch_size, num_tokens, 3, num_heads, head_dim)
217
+ .permute(2, 0, 3, 1, 4)
218
+ )
219
+ q, k, _ = qkv.unbind(0)
220
+ # q_norm / k_norm are Identity unless the model uses QK-norm; apply them either way.
221
+ q = attn_module.q_norm(q)
222
+ k = attn_module.k_norm(k)
223
+ q = q * attn_module.scale
224
+ attn = q @ k.transpose(-2, -1)
225
+ return attn.softmax(dim=-1)
226
+
227
+
228
+ def resolve_block_indices(blocks, num_blocks: int, *, encoder_name: str) -> list[int]:
229
+ """Normalize a (possibly negative) block selection against ``num_blocks``.
230
+
231
+ Preserves caller order (so the recorded ``[block]`` channel order is the order
232
+ requested) and validates each index, failing loud on an out-of-range block.
233
+ """
234
+ resolved: list[int] = []
235
+ for raw in blocks:
236
+ idx = int(raw)
237
+ if idx < 0:
238
+ idx += num_blocks
239
+ if not (0 <= idx < num_blocks):
240
+ raise ValueError(
241
+ f"Attention extraction for '{encoder_name}': block index {raw} is out "
242
+ f"of range for a backbone with {num_blocks} transformer blocks."
243
+ )
244
+ resolved.append(idx)
245
+ return resolved
246
+
247
+
248
+ def timm_trunk_attention(
249
+ trunk,
250
+ batch: Tensor,
251
+ *,
252
+ blocks: tuple[int, ...] = (-1,),
253
+ include_registers: bool = False,
254
+ encoder_name: str,
255
+ ) -> Tensor:
256
+ """Extract per-head prefix-token attention maps from a timm ViT trunk.
257
+
258
+ The reusable core of :meth:`TimmTileEncoder.encode_tiles_attention`, factored
259
+ out so wrapper encoders that embed a timm ``VisionTransformer`` (CONCH's
260
+ ``visual.trunk``, CONCH v1.5's ``trunk``) reuse the exact same path on their
261
+ inner trunk — the attention analog of how ``_encode_trunk_dense`` is shared.
262
+
263
+ Captures each selected block's attention input via a forward-pre-hook on
264
+ ``trunk.blocks[i].attn`` (the fused SDPA kernel never materializes the matrix),
265
+ recomputes the softmax weights (:func:`timm_self_attention_weights`), and folds
266
+ the prefix-token query rows into spatial grids (:func:`prefix_attention_to_grid`).
267
+ Patch size and prefix-token count are read from the trunk
268
+ (``patch_embed.patch_size`` / ``num_prefix_tokens``). Output ``(B, K, h, w)`` in
269
+ ``[block][cls, reg…][head]`` order.
270
+ """
271
+ if batch.ndim != 4:
272
+ raise ValueError(
273
+ "encode_tiles_attention expects a (B, C, H, W) batch, got shape "
274
+ f"{tuple(batch.shape)}."
275
+ )
276
+ _, _, height, width = batch.shape
277
+ patch = trunk.patch_embed.patch_size
278
+ patch_h, patch_w = (patch, patch) if isinstance(patch, int) else (int(patch[0]), int(patch[1]))
279
+ if height % patch_h != 0 or width % patch_w != 0:
280
+ raise ValueError(
281
+ f"Attention extraction for '{encoder_name}' requires input divisible by "
282
+ f"the patch size: got {height}x{width}, patch {patch_h}x{patch_w}. Pad "
283
+ "the tile up to a patch multiple first."
284
+ )
285
+ if not hasattr(trunk, "blocks"):
286
+ raise NotImplementedError(
287
+ f"{encoder_name} has no '.blocks' transformer stack; attention extraction "
288
+ "supports timm ViT-style backbones only."
289
+ )
290
+ block_list = trunk.blocks
291
+ resolved = resolve_block_indices(blocks, len(block_list), encoder_name=encoder_name)
292
+
293
+ captured: dict[int, Tensor] = {}
294
+
295
+ def _make_hook(index: int):
296
+ def _hook(_module, inputs):
297
+ captured[index] = inputs[0]
298
+
299
+ return _hook
300
+
301
+ handles = []
302
+ for index in sorted(set(resolved)):
303
+ handles.append(block_list[index].attn.register_forward_pre_hook(_make_hook(index)))
304
+ try:
305
+ trunk.forward_features(batch)
306
+ finally:
307
+ for handle in handles:
308
+ handle.remove()
309
+
310
+ grid_h, grid_w = height // patch_h, width // patch_w
311
+ num_prefix = int(getattr(trunk, "num_prefix_tokens", 1))
312
+ grids = []
313
+ for index in resolved:
314
+ attn_weights = timm_self_attention_weights(block_list[index].attn, captured[index])
315
+ grids.append(
316
+ prefix_attention_to_grid(
317
+ attn_weights,
318
+ num_prefix_tokens=num_prefix,
319
+ include_registers=include_registers,
320
+ grid_h=grid_h,
321
+ grid_w=grid_w,
322
+ encoder_name=encoder_name,
323
+ )
324
+ )
325
+ return torch.cat(grids, dim=1) # [block] outer (caller order), [cls, reg…][head] inner
326
+
327
+
328
+ def attentions_tuple_to_grids(
329
+ attentions,
330
+ *,
331
+ num_prefix_tokens: int,
332
+ blocks: tuple[int, ...],
333
+ include_registers: bool,
334
+ grid_h: int,
335
+ grid_w: int,
336
+ encoder_name: str,
337
+ ) -> Tensor:
338
+ """Fold an HF ``output_attentions`` tuple into stacked prefix-token grids.
339
+
340
+ HF transformer ViTs expose every block's softmax attention directly (no
341
+ fused-kernel recompute), as a per-layer tuple of ``(B, nh, N, N)`` tensors.
342
+ This selects the requested blocks (:func:`resolve_block_indices`), folds each
343
+ to spatial grids (:func:`prefix_attention_to_grid`), and concatenates them in
344
+ ``[block][cls, reg…][head]`` order — the shared core of the HF-path encoders
345
+ (Phikon, Hibou, Midnight), which differ only in ``num_prefix_tokens``.
346
+ """
347
+ if not attentions:
348
+ raise NotImplementedError(
349
+ f"{encoder_name} returned no attentions; the model must support "
350
+ "output_attentions=True (an eager/recompute attention implementation, "
351
+ "not a fused SDPA path that discards the weights)."
352
+ )
353
+ resolved = resolve_block_indices(blocks, len(attentions), encoder_name=encoder_name)
354
+ grids = [
355
+ prefix_attention_to_grid(
356
+ attentions[index],
357
+ num_prefix_tokens=num_prefix_tokens,
358
+ include_registers=include_registers,
359
+ grid_h=grid_h,
360
+ grid_w=grid_w,
361
+ encoder_name=encoder_name,
362
+ )
363
+ for index in resolved
364
+ ]
365
+ return torch.cat(grids, dim=1)
366
+
367
+
368
+ class Encoder(ABC):
369
+ """Shared lifecycle contract for all encoders."""
370
+
371
+ @property
372
+ @abstractmethod
373
+ def encode_dim(self) -> int:
374
+ """Dimensionality of the output feature vector."""
375
+ ...
376
+
377
+ @property
378
+ @abstractmethod
379
+ def device(self) -> torch.device:
380
+ """Current device of the encoder."""
381
+ ...
382
+
383
+ @abstractmethod
384
+ def to(self, device: torch.device | str) -> "Encoder":
385
+ """Move encoder to the given device. Returns self."""
386
+ ...
387
+
388
+
389
+ class TileEncoder(Encoder):
390
+ """Base class for encoders that operate directly on image tiles."""
391
+
392
+ @abstractmethod
393
+ def get_transform(self) -> Callable:
394
+ """Image transform pipeline (PIL Image or ndarray -> Tensor)."""
395
+ ...
396
+
397
+ @abstractmethod
398
+ def encode_tiles(self, batch: Tensor) -> Tensor:
399
+ """Encode a batch of tiles. (B, C, H, W) -> (B, D)."""
400
+ ...
401
+
402
+ def encode_tiles_dense(self, batch: Tensor) -> Tensor:
403
+ """Encode tiles into a dense spatial feature grid. (B, C, H, W) -> (B, d, h, w).
404
+
405
+ Default: unsupported. ViT tile encoders with a recoverable patch grid
406
+ override this; vision-language / slide-native encoders (no usable patch
407
+ grid) do not. ``d`` is the per-token feature dim and ``h, w`` the token
408
+ grid (``H / patch``, ``W / patch``).
409
+ """
410
+ raise NotImplementedError(
411
+ f"{type(self).__name__} does not support dense (spatial-grid) feature "
412
+ "extraction. Dense extraction requires a ViT tile encoder whose patch "
413
+ "tokens can be reshaped into a spatial grid."
414
+ )
415
+
416
+ def encode_tiles_attention(
417
+ self,
418
+ batch: Tensor,
419
+ *,
420
+ blocks: tuple[int, ...] = (-1,),
421
+ include_registers: bool = False,
422
+ ) -> Tensor:
423
+ """Encode tiles into per-head CLS/register self-attention maps.
424
+
425
+ ``(B, C, H, W) -> (B, K, h, w)`` where each channel is one prefix-token
426
+ query row's self-attention over the patch grid for one head, stacked in the
427
+ deterministic order ``[block][cls, reg…][head]``. ``K = len(blocks) *
428
+ (1 + M·include_registers) * nh`` with ``M`` the model's register-token count
429
+ (``0`` for models without them) and ``nh`` the head count.
430
+
431
+ The per-head CLS attention of a frozen ViT doubles as a dense per-pixel
432
+ feature (Ramchandani et al., arXiv:2602.18747); ``include_registers`` adds
433
+ the register-token query rows (Darcet et al.) as extra, optional channels.
434
+ Default: unsupported — overridden by ViT tile encoders whose attention
435
+ blocks can be hooked (timm ViTs and HF transformers ViTs).
436
+ """
437
+ raise NotImplementedError(
438
+ f"{type(self).__name__} does not support attention-map extraction. It "
439
+ "requires a ViT tile encoder whose self-attention can be recovered "
440
+ "(timm Attention blocks, or an HF transformer with output_attentions)."
441
+ )
442
+
443
+ @property
444
+ def patch_size(self) -> tuple[int, int]:
445
+ """Backbone patch size ``(patch_h, patch_w)`` — only for dense encoders.
446
+
447
+ Encoder-authoritative (a property of the frozen model), used to resolve the
448
+ dense token grid. Default: unsupported, mirroring ``encode_tiles_dense``.
449
+ """
450
+ raise NotImplementedError(
451
+ f"{type(self).__name__} does not expose a patch size (only dense-capable "
452
+ "ViT tile encoders do)."
453
+ )
454
+
455
+ def get_dense_transform(self) -> Callable:
456
+ """Photometric (normalization-only) transform for dense extraction.
457
+
458
+ Returns a transform that applies ONLY this encoder's normalization
459
+ (per-channel mean/std) — **no Resize, no CenterCrop** — so the dense
460
+ feature grid covers the *full* source tile and stays spatially registered
461
+ to it. This deliberately differs from ``get_transform`` (the pooled recipe):
462
+ some encoders resize-then-center-crop there (GigaPath ``Resize(256) ->
463
+ CenterCrop(224)``; Lunit ``crop_pct=0.9 -> Resize(248) -> CenterCrop(224)``),
464
+ which drops the tile margins and would misregister the grid against a dense
465
+ target mask. Geometry (padding to a patch multiple, optional resize,
466
+ cropping logits back) is the dense pipeline's responsibility, not the
467
+ encoder's. Default: unsupported, mirroring ``encode_tiles_dense``.
468
+ """
469
+ raise NotImplementedError(
470
+ f"{type(self).__name__} does not provide a dense transform. Only "
471
+ "encoders that support dense (spatial-grid) extraction define one."
472
+ )
473
+
474
+
475
+ class SlideEncoder(Encoder):
476
+ """Base class for encoders that pool tile features into slide features."""
477
+
478
+ tile_encoder: TileEncoder | None = None
479
+
480
+ def encode_tiles(self, batch: Tensor) -> Tensor:
481
+ if self.tile_encoder is None:
482
+ raise AttributeError("slide encoders must attach a tile_encoder before encoding tiles")
483
+ return self.tile_encoder.encode_tiles(batch)
484
+
485
+ @abstractmethod
486
+ def encode_slide(
487
+ self,
488
+ tile_features: Tensor,
489
+ coordinates: Tensor | None = None,
490
+ *,
491
+ tile_size_lv0: int | None = None,
492
+ ) -> Tensor:
493
+ """Pool tile-level features into a single slide-level embedding."""
494
+ ...
495
+
496
+ def prepare_coordinates(
497
+ self,
498
+ coordinates: Tensor,
499
+ *,
500
+ base_spacing_um: float,
501
+ requested_spacing_um: float,
502
+ ) -> Tensor:
503
+ """Hook for model-specific coordinate normalization."""
504
+ return coordinates
505
+
506
+
507
+ class PatientEncoder(Encoder):
508
+ """Base class for encoders that aggregate slide embeddings into patient embeddings."""
509
+
510
+ tile_encoder: TileEncoder | None = None
511
+
512
+ def encode_tiles(self, batch: Tensor) -> Tensor:
513
+ if self.tile_encoder is None:
514
+ raise AttributeError("patient encoders must attach a tile_encoder before encoding tiles")
515
+ return self.tile_encoder.encode_tiles(batch)
516
+
517
+ @abstractmethod
518
+ def encode_slide(
519
+ self,
520
+ tile_features: Tensor,
521
+ coordinates: Tensor | None = None,
522
+ *,
523
+ tile_size_lv0: int | None = None,
524
+ ) -> Tensor:
525
+ """Pool tile-level features into a single slide-level embedding."""
526
+ ...
527
+
528
+ @abstractmethod
529
+ def encode_patient(self, slide_embeddings: Tensor) -> Tensor:
530
+ """Aggregate slide embeddings [S, D] into a single patient-level embedding [D]."""
531
+ ...
532
+
533
+
534
class TimmTileEncoder(TileEncoder):
    """Shared base for tile encoders whose backbone is created through timm."""

    def __init__(
        self,
        model_name: str,
        *,
        output_variant: str | None = None,
        **timm_kwargs,
    ):
        """Create the timm backbone in eval mode and pick the default device.

        Args:
            model_name: Name handed to ``timm.create_model``.
            output_variant: Requested output variant; only resolved when the
                instance does not already carry ``_output_variant``.
            **timm_kwargs: Extra ``timm.create_model`` arguments. They take
                precedence over the defaults ``pretrained=True`` and
                ``num_classes=0``.
        """
        create_kwargs = {"pretrained": True, "num_classes": 0, **timm_kwargs}
        backbone = timm.create_model(model_name, **create_kwargs)
        self._model = backbone.eval()
        self._device = preferred_default_device()
        # Leave a value that is already present on the instance untouched.
        if not hasattr(self, "_output_variant"):
            self._output_variant = resolve_requested_output_variant(output_variant)

    def get_transform(self) -> Callable:
        """Return the transform built from the backbone's resolved data config."""
        backbone = self._model
        resolved = resolve_data_config(backbone.pretrained_cfg, model=backbone)
        return create_transform(**resolved)

    def get_dense_transform(self) -> Callable:
        """Return the normalization-only transform used for dense extraction."""
        # No Resize/CenterCrop here (see TileEncoder.get_dense_transform): only
        # tensor conversion and normalization. mean/std are read from the same
        # resolved data config that get_transform uses, so the photometric
        # pipeline matches pooled extraction even for encoders with custom
        # normalization (e.g. H-optimus 0.7072.../0.2119...); verified per-encoder.
        backbone = self._model
        resolved = resolve_data_config(backbone.pretrained_cfg, model=backbone)
        steps = [
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=resolved["mean"], std=resolved["std"]),
        ]
        return v2.Compose(steps)

    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Run ``batch`` through the backbone's regular forward pass."""
        backbone = self._model
        return backbone(batch)

    def _dense_patch_size(self) -> tuple[int, int]:
        """Backbone patch size as ``(patch_h, patch_w)``."""
        size = self._model.patch_embed.patch_size
        if not isinstance(size, int):
            rows, cols = size
            return int(rows), int(cols)
        # A bare int means a square patch.
        return size, size

    @property
    def patch_size(self) -> tuple[int, int]:
        """Public accessor for the backbone patch size ``(patch_h, patch_w)``."""
        return self._dense_patch_size()

    def _dense_num_prefix_tokens(self) -> int:
        """Number of leading non-patch tokens (CLS + register tokens)."""
        prefix_count = self._model.num_prefix_tokens
        return int(prefix_count)

    def encode_tiles_dense(self, batch: Tensor) -> Tensor:
        """Encode tiles into a dense spatial grid. (B, C, H, W) -> (B, d, h, w).

        The backbone's ``forward_features`` output is folded back from a
        patch-token sequence into its spatial grid, dropping the CLS/register
        prefix tokens. The backbone has to accept ``batch`` at its current
        spatial size (timm ViTs need ``dynamic_img_size=True`` for sizes other
        than their native input).

        Raises:
            ValueError: If ``batch`` is not 4-dimensional, or if ``H``/``W``
                is not a multiple of the patch size.
        """
        if batch.ndim != 4:
            raise ValueError(
                "encode_tiles_dense expects a (B, C, H, W) batch, got shape "
                f"{tuple(batch.shape)}."
            )
        height, width = batch.shape[2], batch.shape[3]
        patch_h, patch_w = self._dense_patch_size()
        fits_patch_grid = height % patch_h == 0 and width % patch_w == 0
        if not fits_patch_grid:
            raise ValueError(
                f"Dense extraction for '{type(self).__name__}' requires input "
                f"divisible by the patch size: got {height}x{width}, patch "
                f"{patch_h}x{patch_w}. Pad the tile up to a patch multiple first."
            )
        token_sequence = self._model.forward_features(batch)
        return reshape_tokens_to_grid(
            token_sequence,
            grid_h=height // patch_h,
            grid_w=width // patch_w,
            num_prefix_tokens=self._dense_num_prefix_tokens(),
            encoder_name=type(self).__name__,
        )

    def encode_tiles_attention(
        self,
        batch: Tensor,
        *,
        blocks: tuple[int, ...] = (-1,),
        include_registers: bool = False,
    ) -> Tensor:
        """Encode tiles into per-head prefix-token attention maps (timm ViT family).

        The work is delegated to ``timm_trunk_attention`` on the wrapped
        backbone. The attention input of each selected block is captured with
        a forward-pre-hook on ``blocks[i].attn``, because the fused SDPA kernel
        never materializes the attention matrix. The softmax weights are then
        recomputed from the module's own projection
        (:func:`timm_self_attention_weights`), and the prefix-token query rows
        are folded into spatial grids (:func:`prefix_attention_to_grid`).
        Output is ``(B, K, h, w)`` in ``[block][cls, reg…][head]`` order — see
        :meth:`TileEncoder.encode_tiles_attention`.
        """
        attention_maps = timm_trunk_attention(
            self._model,
            batch,
            blocks=blocks,
            include_registers=include_registers,
            encoder_name=type(self).__name__,
        )
        return attention_maps

    @property
    def encode_dim(self) -> int:
        """Feature dimension reported by the backbone (``num_features``)."""
        return self._model.num_features

    @property
    def device(self) -> torch.device:
        """Device currently recorded for this encoder."""
        return self._device

    def to(self, device: torch.device | str) -> "TimmTileEncoder":
        """Move the backbone to ``device``, record it, and return ``self``."""
        self._device = torch.device(device)
        moved_backbone = self._model.to(self._device)
        self._model = moved_backbone
        return self