slide2vec 4.6.1__tar.gz → 4.6.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {slide2vec-4.6.1 → slide2vec-4.6.3}/PKG-INFO +1 -1
- {slide2vec-4.6.1 → slide2vec-4.6.3}/pyproject.toml +2 -2
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/__init__.py +1 -1
- slide2vec-4.6.3/slide2vec/encoders/base.py +653 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/conch.py +33 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/hibou.py +40 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/midnight.py +40 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/phikon.py +42 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/utils/config.py +25 -7
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec.egg-info/PKG-INFO +1 -1
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec.egg-info/SOURCES.txt +1 -0
- slide2vec-4.6.3/tests/test_attention_extraction.py +427 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_regression_core.py +28 -0
- slide2vec-4.6.1/slide2vec/encoders/base.py +0 -335
- {slide2vec-4.6.1 → slide2vec-4.6.3}/LICENSE +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/README.md +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/setup.cfg +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/__main__.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/api.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/artifacts.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/cli.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/configs/__init__.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/configs/default.yaml +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/configs/resources.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/data/__init__.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/data/dataset.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/data/tile_reader.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/data/tile_store.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/distributed/__init__.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/distributed/direct_embed_worker.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/distributed/pipeline_worker.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/__init__.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/__init__.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/gigapath.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/hoptimus.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/lunit.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/moozy/__init__.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/moozy/blocks.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/moozy/case.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/moozy/loading.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/moozy/slide.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/moozy/types.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/musk.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/prism.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/prost40m.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/titan.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/uni.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/models/virchow.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/registry.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/encoders/validation.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/inference.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/progress.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/__init__.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/artifacts_collect.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/batching.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/cpu_budget.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/distributed.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/distributed_stage.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/embedding.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/embedding_persist.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/embedding_pipeline.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/hierarchical.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/manifest.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/model_settings.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/patient_pipeline.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/persist_callbacks.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/persistence.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/process_list.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/progress_bridge.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/registry.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/serialization.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/slide_encode.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/tiling.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/tiling_pipeline.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/types.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/runtime/worker_io.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/utils/__init__.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/utils/coordinates.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/utils/log_utils.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/utils/tiling_io.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec/utils/utils.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec.egg-info/dependency_links.txt +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec.egg-info/entry_points.txt +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec.egg-info/not-zip-safe +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec.egg-info/requires.txt +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/slide2vec.egg-info/top_level.txt +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_architecture_runtime_split.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_dense_extraction.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_dense_locality_gated.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_encoder_registry.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_hs2p_package_cutover.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_output_consistency.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_progress.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_regression_inference.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_regression_models.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_runtime_batching.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_tile_store.py +0 -0
- {slide2vec-4.6.1 → slide2vec-4.6.3}/tests/test_tiling_pipeline.py +0 -0
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "slide2vec"
|
|
7
|
-
version = "4.6.1"
|
|
7
|
+
version = "4.6.3"
|
|
8
8
|
description = "Embedding of whole slide images with Foundation Models"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
requires-python = ">=3.10"
|
|
@@ -164,7 +164,7 @@ no_implicit_reexport = true
|
|
|
164
164
|
max-line-length = 160
|
|
165
165
|
|
|
166
166
|
[tool.bumpver]
|
|
167
|
-
current_version = "4.6.1"
|
|
167
|
+
current_version = "4.6.3"
|
|
168
168
|
version_pattern = "MAJOR.MINOR.PATCH"
|
|
169
169
|
commit = false # We do version bumping in CI, not as a commit
|
|
170
170
|
tag = false # Git tag already exists — we don't auto-tag
|
|
@@ -0,0 +1,653 @@
|
|
|
1
|
+
"""Encoder abstractions for tile-level and slide-level feature extraction."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from typing import Callable
|
|
6
|
+
|
|
7
|
+
import timm
|
|
8
|
+
import torch
|
|
9
|
+
from timm.data import create_transform, resolve_data_config
|
|
10
|
+
from torch import Tensor
|
|
11
|
+
from torchvision.transforms import v2
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def preferred_default_device() -> torch.device:
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@contextmanager
def hf_eager_attention(model):
    """Force an HF model onto eager attention for the duration of the block.

    Recent ``transformers`` default to an SDPA attention implementation that
    *silently ignores* ``output_attentions=True`` (it warns and returns
    ``attentions=None``), so attention extraction must temporarily switch the
    model to ``eager``, which materializes the weights. The previous
    implementation is restored on exit.

    The context manager is a no-op when the model has no callable
    ``set_attn_implementation``, when ``model.config._attn_implementation`` is
    missing/``None``, or when it is already ``"eager"``.

    Switching to eager is best-effort: if the setter raises, the block still
    runs on the model's current implementation. If *restoring* the previous
    implementation fails, a ``RuntimeWarning`` is emitted instead of failing
    silently, because the model would otherwise stay on eager attention for
    the rest of the process without any signal.
    """
    import warnings  # local import keeps this block self-contained

    setter = getattr(model, "set_attn_implementation", None)
    prev = getattr(getattr(model, "config", None), "_attn_implementation", None)
    changed = False
    if callable(setter) and prev not in (None, "eager"):
        try:
            setter("eager")
        except Exception:
            # Deliberate best-effort: run the block on the current implementation.
            pass
        else:
            changed = True
    try:
        yield
    finally:
        # Only undo what this context manager actually changed.
        if changed:
            try:
                setter(prev)
            except Exception as exc:
                warnings.warn(
                    f"Could not restore attention implementation {prev!r} on "
                    f"{type(model).__name__}; the model remains on 'eager' "
                    f"attention ({exc!r}).",
                    RuntimeWarning,
                    stacklevel=2,
                )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def resolve_requested_output_variant(
|
|
49
|
+
output_variant: str | None,
|
|
50
|
+
*,
|
|
51
|
+
default: str = "default",
|
|
52
|
+
allowed: tuple[str, ...] = ("default",),
|
|
53
|
+
) -> str:
|
|
54
|
+
"""Normalize and validate a requested encoder output variant."""
|
|
55
|
+
resolved = output_variant or default
|
|
56
|
+
if resolved not in allowed:
|
|
57
|
+
available = ", ".join(allowed)
|
|
58
|
+
raise ValueError(
|
|
59
|
+
f"Unsupported output_variant '{resolved}'. Available: {available}"
|
|
60
|
+
)
|
|
61
|
+
return resolved
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def resolve_recommended_dynamic_img_size(
|
|
65
|
+
*,
|
|
66
|
+
requested: bool | None,
|
|
67
|
+
recommended: bool,
|
|
68
|
+
allow_non_recommended: bool,
|
|
69
|
+
encoder_name: str,
|
|
70
|
+
) -> bool:
|
|
71
|
+
"""Resolve ``dynamic_img_size`` against an encoder's card-recommended value.
|
|
72
|
+
|
|
73
|
+
``None`` uses the recommended value. A value that differs from the
|
|
74
|
+
recommendation requires ``allow_non_recommended_settings=True`` (e.g. dense
|
|
75
|
+
feature extraction deliberately enabling variable input size, justified by
|
|
76
|
+
the registration / native-size-no-op tests); otherwise it raises, so a
|
|
77
|
+
pipeline never silently runs an encoder outside its documented config.
|
|
78
|
+
"""
|
|
79
|
+
if requested is None:
|
|
80
|
+
return recommended
|
|
81
|
+
if requested != recommended and not allow_non_recommended:
|
|
82
|
+
raise ValueError(
|
|
83
|
+
f"Encoder '{encoder_name}' recommends dynamic_img_size={recommended} "
|
|
84
|
+
f"(per its model card); got dynamic_img_size={requested}, which deviates "
|
|
85
|
+
"from the recommended setting. Pass allow_non_recommended_settings=True "
|
|
86
|
+
"to override it deliberately (e.g. dense extraction needs variable input "
|
|
87
|
+
"size; this is a native-size no-op, verified in the encoder tests)."
|
|
88
|
+
)
|
|
89
|
+
return requested
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def reshape_tokens_to_grid(
    tokens: Tensor,
    *,
    grid_h: int,
    grid_w: int,
    num_prefix_tokens: int,
    encoder_name: str,
) -> Tensor:
    """Turn a ``(B, T, d)`` token sequence into a ``(B, d, grid_h, grid_w)`` grid.

    The first ``num_prefix_tokens`` entries along the token axis are dropped;
    the rest are moved channel-first and reshaped so that token ``i * grid_w + j``
    lands at spatial position ``(i, j)``.

    Raises:
        ValueError: if ``tokens`` is not 3-dimensional, or if the number of
            tokens left after dropping the prefix differs from
            ``grid_h * grid_w`` (a silent reshape would scramble the layout).
    """
    if tokens.ndim != 3:
        raise ValueError(
            f"Dense extraction for '{encoder_name}' expected a (B, T, d) token "
            f"sequence from the backbone, got shape {tuple(tokens.shape)}. This "
            "encoder may not expose a recoverable patch-token grid."
        )
    spatial_tokens = tokens[:, num_prefix_tokens:, :]
    n_batch, n_remaining, n_channels = spatial_tokens.shape
    n_cells = grid_h * grid_w
    if n_remaining != n_cells:
        raise ValueError(
            f"Dense token accounting mismatch for '{encoder_name}': backbone "
            f"returned {tokens.shape[1]} tokens; after stripping "
            f"{num_prefix_tokens} prefix token(s), {n_remaining} remain, but the "
            f"{grid_h}x{grid_w} grid expects {n_cells}. Check the prefix-token "
            "count and the input-size / patch-size / grid geometry."
        )
    # (B, P, d) -> (B, d, P) -> (B, d, h, w)
    channel_first = spatial_tokens.permute(0, 2, 1)
    return channel_first.reshape(n_batch, n_channels, grid_h, grid_w)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def prefix_attention_to_grid(
    attn_weights: Tensor,
    *,
    num_prefix_tokens: int,
    include_registers: bool,
    grid_h: int,
    grid_w: int,
    encoder_name: str,
) -> Tensor:
    """Reshape prefix-token attention rows of one block into spatial maps.

    ``attn_weights`` must be ``(B, nh, N, N)`` with
    ``N == num_prefix_tokens + grid_h * grid_w``. The query rows kept are the
    first row only, or the first ``num_prefix_tokens`` rows when
    ``include_registers`` is true. Of each kept row only the key columns after
    the prefix tokens are used, and they are reshaped to ``(grid_h, grid_w)``.

    Returns:
        A ``(B, K, grid_h, grid_w)`` tensor with ``K = num_query * nh``, where
        channel ``q * nh + head`` holds query row ``q`` of head ``head``.

    Raises:
        ValueError: if the input is not 4-D, not square in its last two
            dimensions, ``num_prefix_tokens < 1``, or the token count does not
            match the prefix count plus the grid size.
    """
    if attn_weights.ndim != 4:
        raise ValueError(
            f"Attention extraction for '{encoder_name}' expected (B, nh, N, N) "
            f"attention weights, got shape {tuple(attn_weights.shape)}."
        )
    n_batch, n_heads, n_rows, n_cols = attn_weights.shape
    if n_rows != n_cols:
        raise ValueError(
            f"Attention extraction for '{encoder_name}' expected square attention "
            f"(query==key tokens), got {n_rows}x{n_cols}."
        )
    if num_prefix_tokens < 1:
        raise ValueError(
            f"Attention extraction for '{encoder_name}' needs at least one prefix "
            f"token (the CLS query row), got num_prefix_tokens={num_prefix_tokens}."
        )
    n_cells = grid_h * grid_w
    n_expected = num_prefix_tokens + n_cells
    if n_cols != n_expected:
        raise ValueError(
            f"Attention token accounting mismatch for '{encoder_name}': block "
            f"returned {n_cols} tokens; with {num_prefix_tokens} prefix "
            f"token(s) the {grid_h}x{grid_w} grid expects {n_expected}. Check the "
            "prefix-token count and the input-size / patch-size / grid geometry."
        )
    n_queries = 1 if not include_registers else num_prefix_tokens
    # Keep prefix query rows and patch key columns only: (B, nh, q, P).
    selected = attn_weights[:, :, :n_queries, num_prefix_tokens:]
    # Swap head and query axes so flattening gives query-outer, head-inner channels.
    query_major = selected.transpose(1, 2)
    flat = query_major.reshape(n_batch, n_queries * n_heads, n_cells)
    return flat.reshape(n_batch, n_queries * n_heads, grid_h, grid_w)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def timm_self_attention_weights(attn_module, x: Tensor) -> Tensor:
    """Compute softmax attention weights ``(B, nh, N, N)`` from an attention module.

    The module's ``qkv`` projection is applied to ``x`` (``(B, N, C)``), the
    result is split into per-head query and key tensors, the module's
    ``q_norm`` / ``k_norm`` are applied, the query is multiplied by the
    module's ``scale``, and the query-key product is softmaxed over the key
    axis. The value projection is discarded and no dropout is applied.

    ``head_dim`` is read from the module when present, otherwise derived as
    ``C // num_heads``.

    Raises:
        NotImplementedError: if ``attn_module`` has no ``qkv`` attribute.
    """
    if not hasattr(attn_module, "qkv"):
        raise NotImplementedError(
            f"{type(attn_module).__name__} has no fused 'qkv' projection; attention "
            "extraction currently supports timm ViT Attention blocks only."
        )
    n_batch, n_tokens, n_channels = x.shape
    heads = int(attn_module.num_heads)
    per_head = int(getattr(attn_module, "head_dim", n_channels // heads))
    projected = attn_module.qkv(x).reshape(n_batch, n_tokens, 3, heads, per_head)
    # (B, N, 3, nh, hd) -> (3, B, nh, N, hd); the value slice is not needed.
    query, key, _ = projected.permute(2, 0, 3, 1, 4).unbind(0)
    # q_norm / k_norm are applied unconditionally, as in the module's own attributes.
    query = attn_module.q_norm(query)
    key = attn_module.k_norm(key)
    query = query * attn_module.scale
    scores = query @ key.transpose(-2, -1)
    return scores.softmax(dim=-1)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def resolve_block_indices(blocks, num_blocks: int, *, encoder_name: str) -> list[int]:
    """Convert a block selection into non-negative indices below ``num_blocks``.

    Each entry is cast to ``int``; negative entries count from the end. The
    result keeps the caller's order and any duplicates.

    Raises:
        ValueError: if an entry falls outside ``[0, num_blocks)`` after
            normalization.
    """

    def _normalize(raw) -> int:
        position = int(raw)
        index = position + num_blocks if position < 0 else position
        if index < 0 or index >= num_blocks:
            raise ValueError(
                f"Attention extraction for '{encoder_name}': block index {raw} is out "
                f"of range for a backbone with {num_blocks} transformer blocks."
            )
        return index

    return [_normalize(raw) for raw in blocks]
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def timm_trunk_attention(
    trunk,
    batch: Tensor,
    *,
    blocks: tuple[int, ...] = (-1,),
    include_registers: bool = False,
    encoder_name: str,
) -> Tensor:
    """Collect prefix-token attention maps from selected blocks of a ViT trunk.

    Steps performed:

    1. Validate that ``batch`` is 4-D and that its height/width are multiples
       of ``trunk.patch_embed.patch_size``.
    2. Normalize ``blocks`` against ``len(trunk.blocks)`` via
       :func:`resolve_block_indices`.
    3. Register a forward-pre-hook on each selected ``trunk.blocks[i].attn``
       that stores the module's first positional input, run
       ``trunk.forward_features(batch)``, and always remove the hooks.
    4. For every requested block (caller order, duplicates kept) recompute the
       attention weights with :func:`timm_self_attention_weights` and fold them
       with :func:`prefix_attention_to_grid`.

    The prefix-token count is ``trunk.num_prefix_tokens`` (``1`` if absent).

    Returns:
        The per-block grids concatenated along dim 1.

    Raises:
        ValueError: on a non-4-D batch or a size not divisible by the patch size.
        NotImplementedError: if ``trunk`` has no ``blocks`` attribute.
    """
    if batch.ndim != 4:
        raise ValueError(
            "encode_tiles_attention expects a (B, C, H, W) batch, got shape "
            f"{tuple(batch.shape)}."
        )
    height, width = batch.shape[-2:]
    patch = trunk.patch_embed.patch_size
    if isinstance(patch, int):
        patch_h = patch_w = patch
    else:
        patch_h, patch_w = int(patch[0]), int(patch[1])
    if height % patch_h or width % patch_w:
        raise ValueError(
            f"Attention extraction for '{encoder_name}' requires input divisible by "
            f"the patch size: got {height}x{width}, patch {patch_h}x{patch_w}. Pad "
            "the tile up to a patch multiple first."
        )
    if not hasattr(trunk, "blocks"):
        raise NotImplementedError(
            f"{encoder_name} has no '.blocks' transformer stack; attention extraction "
            "supports timm ViT-style backbones only."
        )
    layers = trunk.blocks
    selected = resolve_block_indices(blocks, len(layers), encoder_name=encoder_name)

    attn_inputs: dict[int, Tensor] = {}

    def _capture_into(slot: int):
        # Bind ``slot`` per hook so each block writes to its own key.
        def _pre_hook(_module, args):
            attn_inputs[slot] = args[0]

        return _pre_hook

    # One hook per distinct block, even if the caller requested it twice.
    hook_handles = [
        layers[slot].attn.register_forward_pre_hook(_capture_into(slot))
        for slot in sorted(set(selected))
    ]
    try:
        trunk.forward_features(batch)
    finally:
        for hook_handle in hook_handles:
            hook_handle.remove()

    grid_h = height // patch_h
    grid_w = width // patch_w
    num_prefix = int(getattr(trunk, "num_prefix_tokens", 1))
    per_block = [
        prefix_attention_to_grid(
            timm_self_attention_weights(layers[slot].attn, attn_inputs[slot]),
            num_prefix_tokens=num_prefix,
            include_registers=include_registers,
            grid_h=grid_h,
            grid_w=grid_w,
            encoder_name=encoder_name,
        )
        for slot in selected
    ]
    # Block-outer (caller order); within a block the order set by prefix_attention_to_grid.
    return torch.cat(per_block, dim=1)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def attentions_tuple_to_grids(
    attentions,
    *,
    num_prefix_tokens: int,
    blocks: tuple[int, ...],
    include_registers: bool,
    grid_h: int,
    grid_w: int,
    encoder_name: str,
) -> Tensor:
    """Convert a per-layer sequence of attention tensors into stacked grids.

    ``attentions`` is indexed per layer; ``blocks`` is normalized against its
    length with :func:`resolve_block_indices`, each selected entry is folded
    with :func:`prefix_attention_to_grid`, and the results are concatenated
    along dim 1 in the order the blocks were requested.

    Raises:
        NotImplementedError: if ``attentions`` is falsy (``None`` or empty),
            i.e. the model did not return attention weights.
    """
    if not attentions:
        raise NotImplementedError(
            f"{encoder_name} returned no attentions; the model must support "
            "output_attentions=True (an eager/recompute attention implementation, "
            "not a fused SDPA path that discards the weights)."
        )
    selected = resolve_block_indices(blocks, len(attentions), encoder_name=encoder_name)
    per_block = []
    for layer_index in selected:
        layer_grid = prefix_attention_to_grid(
            attentions[layer_index],
            num_prefix_tokens=num_prefix_tokens,
            include_registers=include_registers,
            grid_h=grid_h,
            grid_w=grid_w,
            encoder_name=encoder_name,
        )
        per_block.append(layer_grid)
    return torch.cat(per_block, dim=1)
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
class Encoder(ABC):
    """Shared lifecycle contract for all encoders.

    Abstract base: concrete encoders must provide ``encode_dim``, ``device``
    and ``to``. No behavior is implemented here.
    """

    @property
    @abstractmethod
    def encode_dim(self) -> int:
        """Dimensionality of the output feature vector."""
        ...

    @property
    @abstractmethod
    def device(self) -> torch.device:
        """Current device of the encoder."""
        ...

    @abstractmethod
    def to(self, device: torch.device | str) -> "Encoder":
        """Move encoder to the given device. Returns self.

        Args:
            device: target device, as a ``torch.device`` or a device string.
        """
        ...
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
class TileEncoder(Encoder):
    """Base class for encoders that operate directly on image tiles.

    Subclasses must implement ``get_transform`` and ``encode_tiles``. The
    dense / attention / patch-size / dense-transform members are optional
    capabilities: their defaults here raise ``NotImplementedError``.
    """

    @abstractmethod
    def get_transform(self) -> Callable:
        """Image transform pipeline (PIL Image or ndarray -> Tensor)."""
        ...

    @abstractmethod
    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Encode a batch of tiles. (B, C, H, W) -> (B, D)."""
        ...

    def encode_tiles_dense(self, batch: Tensor) -> Tensor:
        """Encode tiles into a dense spatial feature grid. (B, C, H, W) -> (B, d, h, w).

        Default: unsupported. ViT tile encoders with a recoverable patch grid
        override this; vision-language / slide-native encoders (no usable patch
        grid) do not. ``d`` is the per-token feature dim and ``h, w`` the token
        grid (``H / patch``, ``W / patch``).

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not support dense (spatial-grid) feature "
            "extraction. Dense extraction requires a ViT tile encoder whose patch "
            "tokens can be reshaped into a spatial grid."
        )

    def encode_tiles_attention(
        self,
        batch: Tensor,
        *,
        blocks: tuple[int, ...] = (-1,),
        include_registers: bool = False,
    ) -> Tensor:
        """Encode tiles into per-head CLS/register self-attention maps.

        ``(B, C, H, W) -> (B, K, h, w)`` where each channel is one prefix-token
        query row's self-attention over the patch grid for one head, stacked in the
        deterministic order ``[block][cls, reg…][head]``. ``K = len(blocks) *
        (1 + M·include_registers) * nh`` with ``M`` the model's register-token count
        (``0`` for models without them) and ``nh`` the head count.

        The per-head CLS attention of a frozen ViT doubles as a dense per-pixel
        feature (Ramchandani et al., arXiv:2602.18747); ``include_registers`` adds
        the register-token query rows (Darcet et al.) as extra, optional channels.
        Default: unsupported — overridden by ViT tile encoders whose attention
        blocks can be hooked (timm ViTs and HF transformers ViTs).

        Args:
            batch: input tiles.
            blocks: transformer block indices to extract; defaults to ``(-1,)``.
            include_registers: whether register-token rows are included.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not support attention-map extraction. It "
            "requires a ViT tile encoder whose self-attention can be recovered "
            "(timm Attention blocks, or an HF transformer with output_attentions)."
        )

    @property
    def patch_size(self) -> tuple[int, int]:
        """Backbone patch size ``(patch_h, patch_w)`` — only for dense encoders.

        Encoder-authoritative (a property of the frozen model), used to resolve the
        dense token grid. Default: unsupported, mirroring ``encode_tiles_dense``.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not expose a patch size (only dense-capable "
            "ViT tile encoders do)."
        )

    def get_dense_transform(self) -> Callable:
        """Photometric (normalization-only) transform for dense extraction.

        Returns a transform that applies ONLY this encoder's normalization
        (per-channel mean/std) — **no Resize, no CenterCrop** — so the dense
        feature grid covers the *full* source tile and stays spatially registered
        to it. This deliberately differs from ``get_transform`` (the pooled recipe):
        some encoders resize-then-center-crop there (GigaPath ``Resize(256) ->
        CenterCrop(224)``; Lunit ``crop_pct=0.9 -> Resize(248) -> CenterCrop(224)``),
        which drops the tile margins and would misregister the grid against a dense
        target mask. Geometry (padding to a patch multiple, optional resize,
        cropping logits back) is the dense pipeline's responsibility, not the
        encoder's. Default: unsupported, mirroring ``encode_tiles_dense``.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not provide a dense transform. Only "
            "encoders that support dense (spatial-grid) extraction define one."
        )
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
class SlideEncoder(Encoder):
    """Abstract encoder that turns tile features into one slide-level feature.

    Tile encoding is delegated to ``tile_encoder``, which stays ``None`` until
    a concrete tile encoder is attached.
    """

    # Delegate used by ``encode_tiles``; ``None`` until one is attached.
    tile_encoder: TileEncoder | None = None

    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Forward ``batch`` to the attached tile encoder's ``encode_tiles``.

        Raises:
            AttributeError: if no tile encoder has been attached.
        """
        delegate = self.tile_encoder
        if delegate is not None:
            return delegate.encode_tiles(batch)
        raise AttributeError("slide encoders must attach a tile_encoder before encoding tiles")

    @abstractmethod
    def encode_slide(
        self,
        tile_features: Tensor,
        coordinates: Tensor | None = None,
        *,
        tile_size_lv0: int | None = None,
    ) -> Tensor:
        """Reduce tile-level features to a single slide-level embedding."""
        ...

    def prepare_coordinates(
        self,
        coordinates: Tensor,
        *,
        base_spacing_um: float,
        requested_spacing_um: float,
    ) -> Tensor:
        """Extension point for model-specific coordinate normalization.

        The base implementation ignores both spacing arguments and returns
        ``coordinates`` unchanged.
        """
        return coordinates
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
class PatientEncoder(Encoder):
    """Encoder base that combines slide embeddings into a patient embedding."""

    # Optional tile-level backbone; ``encode_tiles`` delegates to it once attached.
    tile_encoder: TileEncoder | None = None

    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Encode a tile batch by delegating to the attached ``tile_encoder``.

        Raises ``AttributeError`` when no tile encoder has been attached.
        """
        backbone = self.tile_encoder
        if backbone is not None:
            return backbone.encode_tiles(batch)
        raise AttributeError("patient encoders must attach a tile_encoder before encoding tiles")

    @abstractmethod
    def encode_slide(
        self,
        tile_features: Tensor,
        coordinates: Tensor | None = None,
        *,
        tile_size_lv0: int | None = None,
    ) -> Tensor:
        """Reduce tile-level features to a single slide-level embedding."""
        ...

    @abstractmethod
    def encode_patient(self, slide_embeddings: Tensor) -> Tensor:
        """Combine slide embeddings [S, D] into one patient-level embedding [D]."""
        ...
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
class TimmTileEncoder(TileEncoder):
    """Shared base for tile encoders whose backbone is a timm model."""

    def __init__(
        self,
        model_name: str,
        *,
        output_variant: str | None = None,
        **timm_kwargs,
    ):
        # Caller-supplied timm kwargs override the two defaults below.
        create_kwargs = {"pretrained": True, "num_classes": 0, **timm_kwargs}
        self._model = timm.create_model(model_name, **create_kwargs).eval()
        self._device = preferred_default_device()
        # An ``_output_variant`` already present on the instance is left
        # alone; it is only resolved from the argument when absent.
        if not hasattr(self, "_output_variant"):
            self._output_variant = resolve_requested_output_variant(output_variant)

    def get_transform(self) -> Callable:
        """Build the pooled-extraction transform from the model's resolved data config."""
        backbone = self._model
        return create_transform(**resolve_data_config(backbone.pretrained_cfg, model=backbone))

    def get_dense_transform(self) -> Callable:
        """Return the normalization-only transform for dense extraction.

        No Resize/CenterCrop is applied (see ``TileEncoder.get_dense_transform``).
        The mean/std are read from the same resolved data config that
        ``get_transform`` uses, so the photometric pipeline matches pooled
        extraction even for encoders with custom normalization (e.g. H-optimus
        0.7072.../0.2119...); verified per-encoder.
        """
        backbone = self._model
        data_config = resolve_data_config(backbone.pretrained_cfg, model=backbone)
        steps = [
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=data_config["mean"], std=data_config["std"]),
        ]
        return v2.Compose(steps)

    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Run the backbone on ``batch`` and return its output."""
        return self._model(batch)

    def _dense_patch_size(self) -> tuple[int, int]:
        """Return the backbone patch size as ``(patch_h, patch_w)``."""
        size = self._model.patch_embed.patch_size
        if not isinstance(size, int):
            rows, cols = size
            return int(rows), int(cols)
        # A bare int is used for both dimensions.
        return size, size

    @property
    def patch_size(self) -> tuple[int, int]:
        """Public accessor for the backbone patch size ``(patch_h, patch_w)``."""
        return self._dense_patch_size()

    def _dense_num_prefix_tokens(self) -> int:
        """Return how many leading non-patch tokens (CLS + register tokens) there are."""
        return int(self._model.num_prefix_tokens)

    def encode_tiles_dense(self, batch: Tensor) -> Tensor:
        """Encode tiles into a dense spatial grid: (B, C, H, W) -> (B, d, h, w).

        The frozen backbone's ``forward_features`` is run and the resulting
        patch-token sequence is folded back into its spatial grid, dropping
        the CLS/register tokens. The backbone has to accept ``batch`` at its
        current spatial size (timm ViTs need ``dynamic_img_size=True`` for
        sizes other than their native input), and ``H, W`` have to be
        divisible by the patch size.

        Raises ``ValueError`` when ``batch`` is not 4-dimensional or when its
        height/width is not a multiple of the patch size.
        """
        if batch.ndim != 4:
            raise ValueError(
                "encode_tiles_dense expects a (B, C, H, W) batch, got shape "
                f"{tuple(batch.shape)}."
            )
        height, width = batch.shape[-2:]
        patch_h, patch_w = self._dense_patch_size()
        fits_patch_grid = height % patch_h == 0 and width % patch_w == 0
        if not fits_patch_grid:
            raise ValueError(
                f"Dense extraction for '{type(self).__name__}' requires input "
                f"divisible by the patch size: got {height}x{width}, patch "
                f"{patch_h}x{patch_w}. Pad the tile up to a patch multiple first."
            )
        grid_h, grid_w = height // patch_h, width // patch_w
        # Run the backbone before looking up the prefix-token count so the
        # evaluation order of the two model accesses stays as it was.
        tokens = self._model.forward_features(batch)
        return reshape_tokens_to_grid(
            tokens,
            grid_h=grid_h,
            grid_w=grid_w,
            num_prefix_tokens=self._dense_num_prefix_tokens(),
            encoder_name=type(self).__name__,
        )

    def encode_tiles_attention(
        self,
        batch: Tensor,
        *,
        blocks: tuple[int, ...] = (-1,),
        include_registers: bool = False,
    ) -> Tensor:
        """Encode tiles into per-head prefix-token attention maps (timm ViT family).

        Each selected block's attention input is captured with a
        forward-pre-hook on ``blocks[i].attn`` (the fused SDPA kernel never
        materializes the attention matrix). The softmax weights are then
        recomputed from the module's own projection
        (:func:`timm_self_attention_weights`) and the prefix-token query rows
        are folded into spatial grids (:func:`prefix_attention_to_grid`).
        The output is ``(B, K, h, w)`` in ``[block][cls, reg…][head]`` order --
        see :meth:`TileEncoder.encode_tiles_attention`.
        """
        encoder_name = type(self).__name__
        return timm_trunk_attention(
            self._model,
            batch,
            blocks=blocks,
            include_registers=include_registers,
            encoder_name=encoder_name,
        )

    @property
    def encode_dim(self) -> int:
        """Feature dimension, taken from the backbone's ``num_features``."""
        return self._model.num_features

    @property
    def device(self) -> torch.device:
        """Device this encoder is currently assigned to."""
        return self._device

    def to(self, device: torch.device | str) -> "TimmTileEncoder":
        """Move the backbone to ``device``, record it, and return ``self``."""
        target = torch.device(device)
        self._device = target
        self._model = self._model.to(target)
        return self
|