slide2vec 4.5.3__tar.gz → 4.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {slide2vec-4.5.3 → slide2vec-4.6.0}/PKG-INFO +1 -1
- {slide2vec-4.5.3 → slide2vec-4.6.0}/pyproject.toml +2 -2
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/__init__.py +1 -1
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/__init__.py +4 -0
- slide2vec-4.6.0/slide2vec/encoders/base.py +335 -0
- slide2vec-4.6.0/slide2vec/encoders/models/conch.py +181 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/gigapath.py +28 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/hibou.py +36 -1
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/hoptimus.py +35 -5
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/lunit.py +1 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/midnight.py +36 -1
- slide2vec-4.6.0/slide2vec/encoders/models/musk.py +124 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/phikon.py +53 -1
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/prost40m.py +1 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/inference.py +15 -1
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec.egg-info/PKG-INFO +1 -1
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec.egg-info/SOURCES.txt +4 -1
- slide2vec-4.6.0/tests/test_dense_extraction.py +487 -0
- slide2vec-4.6.0/tests/test_dense_locality_gated.py +162 -0
- slide2vec-4.6.0/tests/test_tiling_pipeline.py +25 -0
- slide2vec-4.5.3/slide2vec/encoders/base.py +0 -161
- slide2vec-4.5.3/slide2vec/encoders/models/conch.py +0 -93
- slide2vec-4.5.3/slide2vec/encoders/models/musk.py +0 -69
- {slide2vec-4.5.3 → slide2vec-4.6.0}/LICENSE +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/README.md +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/setup.cfg +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/__main__.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/api.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/artifacts.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/cli.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/configs/__init__.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/configs/default.yaml +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/configs/resources.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/data/__init__.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/data/dataset.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/data/tile_reader.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/data/tile_store.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/distributed/__init__.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/distributed/direct_embed_worker.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/distributed/pipeline_worker.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/__init__.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/moozy/__init__.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/moozy/blocks.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/moozy/case.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/moozy/loading.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/moozy/slide.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/moozy/types.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/prism.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/titan.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/uni.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/virchow.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/registry.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/validation.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/progress.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/__init__.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/artifacts_collect.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/batching.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/cpu_budget.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/distributed.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/distributed_stage.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/embedding.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/embedding_persist.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/embedding_pipeline.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/hierarchical.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/manifest.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/model_settings.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/patient_pipeline.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/persist_callbacks.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/persistence.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/process_list.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/progress_bridge.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/registry.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/serialization.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/slide_encode.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/tiling.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/tiling_pipeline.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/types.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/worker_io.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/utils/__init__.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/utils/config.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/utils/coordinates.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/utils/log_utils.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/utils/tiling_io.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/utils/utils.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec.egg-info/dependency_links.txt +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec.egg-info/entry_points.txt +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec.egg-info/not-zip-safe +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec.egg-info/requires.txt +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec.egg-info/top_level.txt +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_architecture_runtime_split.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_encoder_registry.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_hs2p_package_cutover.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_output_consistency.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_progress.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_regression_core.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_regression_inference.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_regression_models.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_runtime_batching.py +0 -0
- {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_tile_store.py +0 -0
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "slide2vec"
|
|
7
|
-
version = "4.5.3"
|
|
7
|
+
version = "4.6.0"
|
|
8
8
|
description = "Embedding of whole slide images with Foundation Models"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
requires-python = ">=3.10"
|
|
@@ -164,7 +164,7 @@ no_implicit_reexport = true
|
|
|
164
164
|
max-line-length = 160
|
|
165
165
|
|
|
166
166
|
[tool.bumpver]
|
|
167
|
-
current_version = "4.5.3"
|
|
167
|
+
current_version = "4.6.0"
|
|
168
168
|
version_pattern = "MAJOR.MINOR.PATCH"
|
|
169
169
|
commit = false # We do version bumping in CI, not as a commit
|
|
170
170
|
tag = false # Git tag already exists — we don't auto-tag
|
|
@@ -10,6 +10,8 @@ from slide2vec.encoders.base import (
|
|
|
10
10
|
SlideEncoder,
|
|
11
11
|
TileEncoder,
|
|
12
12
|
TimmTileEncoder,
|
|
13
|
+
reshape_tokens_to_grid,
|
|
14
|
+
resolve_recommended_dynamic_img_size,
|
|
13
15
|
resolve_requested_output_variant,
|
|
14
16
|
)
|
|
15
17
|
from slide2vec.encoders.registry import (
|
|
@@ -29,6 +31,8 @@ __all__ = [
|
|
|
29
31
|
"TileEncoder",
|
|
30
32
|
"SlideEncoder",
|
|
31
33
|
"TimmTileEncoder",
|
|
34
|
+
"reshape_tokens_to_grid",
|
|
35
|
+
"resolve_recommended_dynamic_img_size",
|
|
32
36
|
"resolve_requested_output_variant",
|
|
33
37
|
"encoder_registry",
|
|
34
38
|
"register_encoder",
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
"""Encoder abstractions for tile-level and slide-level feature extraction."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
import timm
|
|
7
|
+
import torch
|
|
8
|
+
from timm.data import create_transform, resolve_data_config
|
|
9
|
+
from torch import Tensor
|
|
10
|
+
from torchvision.transforms import v2
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def preferred_default_device() -> torch.device:
    """Pick the device new encoders report by default.

    Returns ``torch.device("cuda")`` when torch can see a CUDA device and
    ``torch.device("cpu")`` otherwise.
    """
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_type)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def resolve_requested_output_variant(
|
|
18
|
+
output_variant: str | None,
|
|
19
|
+
*,
|
|
20
|
+
default: str = "default",
|
|
21
|
+
allowed: tuple[str, ...] = ("default",),
|
|
22
|
+
) -> str:
|
|
23
|
+
"""Normalize and validate a requested encoder output variant."""
|
|
24
|
+
resolved = output_variant or default
|
|
25
|
+
if resolved not in allowed:
|
|
26
|
+
available = ", ".join(allowed)
|
|
27
|
+
raise ValueError(
|
|
28
|
+
f"Unsupported output_variant '{resolved}'. Available: {available}"
|
|
29
|
+
)
|
|
30
|
+
return resolved
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def resolve_recommended_dynamic_img_size(
|
|
34
|
+
*,
|
|
35
|
+
requested: bool | None,
|
|
36
|
+
recommended: bool,
|
|
37
|
+
allow_non_recommended: bool,
|
|
38
|
+
encoder_name: str,
|
|
39
|
+
) -> bool:
|
|
40
|
+
"""Resolve ``dynamic_img_size`` against an encoder's card-recommended value.
|
|
41
|
+
|
|
42
|
+
``None`` uses the recommended value. A value that differs from the
|
|
43
|
+
recommendation requires ``allow_non_recommended_settings=True`` (e.g. dense
|
|
44
|
+
feature extraction deliberately enabling variable input size, justified by
|
|
45
|
+
the registration / native-size-no-op tests); otherwise it raises, so a
|
|
46
|
+
pipeline never silently runs an encoder outside its documented config.
|
|
47
|
+
"""
|
|
48
|
+
if requested is None:
|
|
49
|
+
return recommended
|
|
50
|
+
if requested != recommended and not allow_non_recommended:
|
|
51
|
+
raise ValueError(
|
|
52
|
+
f"Encoder '{encoder_name}' recommends dynamic_img_size={recommended} "
|
|
53
|
+
f"(per its model card); got dynamic_img_size={requested}, which deviates "
|
|
54
|
+
"from the recommended setting. Pass allow_non_recommended_settings=True "
|
|
55
|
+
"to override it deliberately (e.g. dense extraction needs variable input "
|
|
56
|
+
"size; this is a native-size no-op, verified in the encoder tests)."
|
|
57
|
+
)
|
|
58
|
+
return requested
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def reshape_tokens_to_grid(
    tokens: Tensor,
    *,
    grid_h: int,
    grid_w: int,
    num_prefix_tokens: int,
    encoder_name: str,
) -> Tensor:
    """Turn a ViT token sequence ``(B, T, d)`` into a feature map ``(B, d, h, w)``.

    The first ``num_prefix_tokens`` tokens (CLS + register tokens) are dropped.
    The remaining patch tokens are laid out row-major,
    ``[(0,0), (0,1), ..., (h-1, w-1)]``, so moving the channel axis forward and
    reshaping to ``(B, d, h, w)`` recovers the spatial layout. The original
    implementation was verified bit-for-bit against timm's
    ``get_intermediate_layers(..., reshape=True)`` in the encoder tests.

    A token count that does not equal ``grid_h * grid_w`` after stripping is
    a hard error. A silent reshape would train a decoder on spatially
    corrupted features.

    Raises:
        ValueError: if ``tokens`` is not 3-D, or the patch-token count does not
            match the requested grid.
    """
    if tokens.ndim != 3:
        raise ValueError(
            f"Dense extraction for '{encoder_name}' expected a (B, T, d) token "
            f"sequence from the backbone, got shape {tuple(tokens.shape)}. This "
            "encoder may not expose a recoverable patch-token grid."
        )
    total_tokens = tokens.shape[1]
    grid_tokens = tokens[:, num_prefix_tokens:, :]
    batch_size, patch_count, channels = grid_tokens.shape
    wanted = grid_h * grid_w
    if patch_count != wanted:
        raise ValueError(
            f"Dense token accounting mismatch for '{encoder_name}': backbone "
            f"returned {total_tokens} tokens; after stripping "
            f"{num_prefix_tokens} prefix token(s), {patch_count} remain, but the "
            f"{grid_h}x{grid_w} grid expects {wanted}. Check the prefix-token "
            "count and the input-size / patch-size / grid geometry."
        )
    # (B, N, d) -> (B, d, N) -> (B, d, h, w)
    channels_first = grid_tokens.permute(0, 2, 1)
    return channels_first.reshape(batch_size, channels, grid_h, grid_w)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class Encoder(ABC):
|
|
103
|
+
"""Shared lifecycle contract for all encoders."""
|
|
104
|
+
|
|
105
|
+
@property
|
|
106
|
+
@abstractmethod
|
|
107
|
+
def encode_dim(self) -> int:
|
|
108
|
+
"""Dimensionality of the output feature vector."""
|
|
109
|
+
...
|
|
110
|
+
|
|
111
|
+
@property
|
|
112
|
+
@abstractmethod
|
|
113
|
+
def device(self) -> torch.device:
|
|
114
|
+
"""Current device of the encoder."""
|
|
115
|
+
...
|
|
116
|
+
|
|
117
|
+
@abstractmethod
|
|
118
|
+
def to(self, device: torch.device | str) -> "Encoder":
|
|
119
|
+
"""Move encoder to the given device. Returns self."""
|
|
120
|
+
...
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class TileEncoder(Encoder):
    """Abstract encoder that maps image tiles directly to feature vectors.

    Subclasses must provide ``get_transform`` and ``encode_tiles``. Dense
    (spatial-grid) extraction is opt-in. ``encode_tiles_dense``,
    ``patch_size`` and ``get_dense_transform`` all raise
    ``NotImplementedError`` unless a subclass overrides them.
    """

    @abstractmethod
    def get_transform(self) -> Callable:
        """Return the pooled preprocessing pipeline (PIL Image or ndarray -> Tensor)."""
        ...

    @abstractmethod
    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Map a ``(B, C, H, W)`` tile batch to ``(B, D)`` features."""
        ...

    def encode_tiles_dense(self, batch: Tensor) -> Tensor:
        """Map a ``(B, C, H, W)`` tile batch to a ``(B, d, h, w)`` feature grid.

        ``d`` is the per-token feature width; ``h, w`` is the token grid
        (``H / patch``, ``W / patch``). Only ViT tile encoders with a
        recoverable patch grid override this. Vision-language and slide-native
        encoders (no usable patch grid) keep this default.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        encoder_name = type(self).__name__
        raise NotImplementedError(
            f"{encoder_name} does not support dense (spatial-grid) feature "
            "extraction. Dense extraction requires a ViT tile encoder whose patch "
            "tokens can be reshaped into a spatial grid."
        )

    @property
    def patch_size(self) -> tuple[int, int]:
        """Backbone patch size ``(patch_h, patch_w)``, for dense-capable encoders only.

        This is a property of the frozen model and is used to resolve the
        dense token grid. The base implementation is unsupported, like
        ``encode_tiles_dense``.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        encoder_name = type(self).__name__
        raise NotImplementedError(
            f"{encoder_name} does not expose a patch size (only dense-capable "
            "ViT tile encoders do)."
        )

    def get_dense_transform(self) -> Callable:
        """Photometric-only (normalization) transform used for dense extraction.

        Implementations apply ONLY the encoder's per-channel mean/std
        normalization, with no Resize and no CenterCrop. The dense feature grid
        then covers the *full* source tile and stays spatially registered to it.

        This deliberately differs from the pooled ``get_transform``. Some
        encoders resize and then center-crop there (GigaPath ``Resize(256) ->
        CenterCrop(224)``; Lunit ``crop_pct=0.9 -> Resize(248) ->
        CenterCrop(224)``). That drops the tile margins and would misregister
        the grid against a dense target mask.

        Geometry (padding to a patch multiple, optional resize, cropping logits
        back) belongs to the dense pipeline, not the encoder.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        encoder_name = type(self).__name__
        raise NotImplementedError(
            f"{encoder_name} does not provide a dense transform. Only "
            "encoders that support dense (spatial-grid) extraction define one."
        )
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class SlideEncoder(Encoder):
    """Abstract encoder that pools tile features into one slide-level embedding.

    Tile-level encoding is delegated to an attached ``tile_encoder``.
    """

    # Tile encoder used by ``encode_tiles``; None until one is attached.
    tile_encoder: TileEncoder | None = None

    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Encode tiles by delegating to the attached ``tile_encoder``.

        Raises:
            AttributeError: if no ``tile_encoder`` has been attached.
        """
        delegate = self.tile_encoder
        if delegate is not None:
            return delegate.encode_tiles(batch)
        raise AttributeError("slide encoders must attach a tile_encoder before encoding tiles")

    @abstractmethod
    def encode_slide(
        self,
        tile_features: Tensor,
        coordinates: Tensor | None = None,
        *,
        tile_size_lv0: int | None = None,
    ) -> Tensor:
        """Collapse tile-level features into a single slide-level embedding."""
        ...

    def prepare_coordinates(
        self,
        coordinates: Tensor,
        *,
        base_spacing_um: float,
        requested_spacing_um: float,
    ) -> Tensor:
        """Model-specific coordinate normalization hook.

        The base implementation returns ``coordinates`` unchanged and ignores
        both spacing arguments.
        """
        return coordinates
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class PatientEncoder(Encoder):
    """Abstract encoder that aggregates slide embeddings into a patient embedding.

    Like slide encoders, it encodes tiles through an attached ``tile_encoder``
    and pools tile features per slide. It additionally aggregates slides per
    patient.
    """

    # Tile encoder used by ``encode_tiles``; None until one is attached.
    tile_encoder: TileEncoder | None = None

    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Encode tiles by delegating to the attached ``tile_encoder``.

        Raises:
            AttributeError: if no ``tile_encoder`` has been attached.
        """
        delegate = self.tile_encoder
        if delegate is not None:
            return delegate.encode_tiles(batch)
        raise AttributeError("patient encoders must attach a tile_encoder before encoding tiles")

    @abstractmethod
    def encode_slide(
        self,
        tile_features: Tensor,
        coordinates: Tensor | None = None,
        *,
        tile_size_lv0: int | None = None,
    ) -> Tensor:
        """Collapse tile-level features into a single slide-level embedding."""
        ...

    @abstractmethod
    def encode_patient(self, slide_embeddings: Tensor) -> Tensor:
        """Combine slide embeddings ``[S, D]`` into one patient-level embedding ``[D]``."""
        ...
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class TimmTileEncoder(TileEncoder):
    """Tile encoder wrapping a model built by ``timm.create_model``.

    The model is created with ``pretrained=True`` and ``num_classes=0``
    (either may be overridden through ``timm_kwargs``) and put in eval mode.
    Pooled features come from the model's forward pass. Dense features come
    from ``forward_features`` folded into a patch grid.
    """

    def __init__(
        self,
        model_name: str,
        *,
        output_variant: str | None = None,
        **timm_kwargs,
    ):
        # Caller-supplied kwargs take precedence over these defaults.
        create_kwargs = {"pretrained": True, "num_classes": 0, **timm_kwargs}
        self._model = timm.create_model(model_name, **create_kwargs).eval()
        self._device = preferred_default_device()
        # Keep an _output_variant that was already set before this initializer
        # ran (e.g. by a subclass); only resolve one when it is absent.
        if not hasattr(self, "_output_variant"):
            self._output_variant = resolve_requested_output_variant(output_variant)

    def _resolve_timm_data_config(self) -> dict:
        """Resolve timm's data config (input size, interpolation, mean/std, ...) for the model."""
        return resolve_data_config(self._model.pretrained_cfg, model=self._model)

    def get_transform(self) -> Callable:
        """Pooled transform, built by timm from the model's resolved data config."""
        return create_transform(**self._resolve_timm_data_config())

    def get_dense_transform(self) -> Callable:
        """Normalization-only transform for dense extraction (no Resize / CenterCrop).

        The mean/std come from the same resolved data config that
        ``get_transform`` uses. The photometric pipeline therefore matches
        pooled extraction even for encoders with custom normalization (e.g.
        H-optimus 0.7072.../0.2119...).
        """
        cfg = self._resolve_timm_data_config()
        steps = [
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=cfg["mean"], std=cfg["std"]),
        ]
        return v2.Compose(steps)

    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Pooled features: the timm model's forward output for ``batch``."""
        return self._model(batch)

    def _dense_patch_size(self) -> tuple[int, int]:
        """Backbone patch size as ``(patch_h, patch_w)``."""
        raw = self._model.patch_embed.patch_size
        if not isinstance(raw, int):
            first, second = raw
            return int(first), int(second)
        return raw, raw

    @property
    def patch_size(self) -> tuple[int, int]:
        """Backbone patch size ``(patch_h, patch_w)``."""
        return self._dense_patch_size()

    def _dense_num_prefix_tokens(self) -> int:
        """Count of leading non-patch tokens (CLS + register tokens)."""
        return int(self._model.num_prefix_tokens)

    def encode_tiles_dense(self, batch: Tensor) -> Tensor:
        """Encode tiles into a dense spatial grid. (B, C, H, W) -> (B, d, h, w).

        Runs the frozen backbone's ``forward_features`` and folds the
        patch-token sequence back into its spatial grid, discarding
        CLS/register tokens. The backbone must accept ``batch`` at its current
        spatial size: timm ViTs need ``dynamic_img_size=True`` for sizes other
        than their native input.

        Raises:
            ValueError: if ``batch`` is not 4-D, or ``H, W`` are not divisible by
                the patch size.
        """
        if batch.ndim != 4:
            raise ValueError(
                "encode_tiles_dense expects a (B, C, H, W) batch, got shape "
                f"{tuple(batch.shape)}."
            )
        height, width = batch.shape[-2:]
        patch_h, patch_w = self._dense_patch_size()
        if height % patch_h != 0 or width % patch_w != 0:
            raise ValueError(
                f"Dense extraction for '{type(self).__name__}' requires input "
                f"divisible by the patch size: got {height}x{width}, patch "
                f"{patch_h}x{patch_w}. Pad the tile up to a patch multiple first."
            )
        tokens = self._model.forward_features(batch)
        prefix_count = self._dense_num_prefix_tokens()
        return reshape_tokens_to_grid(
            tokens,
            grid_h=height // patch_h,
            grid_w=width // patch_w,
            num_prefix_tokens=prefix_count,
            encoder_name=type(self).__name__,
        )

    @property
    def encode_dim(self) -> int:
        """Pooled feature width, as reported by the timm model's ``num_features``."""
        return self._model.num_features

    @property
    def device(self) -> torch.device:
        """Device most recently set by ``__init__`` or ``to``."""
        return self._device

    def to(self, device: torch.device | str) -> "TimmTileEncoder":
        """Move the wrapped model to ``device`` and return ``self``."""
        target = torch.device(device)
        self._device = target
        self._model = self._model.to(target)
        return self
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
"""CONCH and CONCH v1.5 encoder implementations.
|
|
2
|
+
|
|
3
|
+
CONCH requires the ``conch`` package (pip install conch).
|
|
4
|
+
CONCH v1.5 requires ``transformers`` and uses the TITAN model to extract
|
|
5
|
+
the CONCH v1.5 backbone.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
from typing import Callable
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
from torch import Tensor
|
|
13
|
+
from torchvision.transforms import v2
|
|
14
|
+
from transformers import AutoModel
|
|
15
|
+
|
|
16
|
+
from slide2vec.encoders.base import (
|
|
17
|
+
TileEncoder,
|
|
18
|
+
preferred_default_device,
|
|
19
|
+
reshape_tokens_to_grid,
|
|
20
|
+
resolve_requested_output_variant,
|
|
21
|
+
)
|
|
22
|
+
from slide2vec.encoders.registry import register_encoder
|
|
23
|
+
|
|
24
|
+
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
|
25
|
+
_IMAGENET_STD = (0.229, 0.224, 0.225)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _normalize_only_transform(
    *,
    mean: tuple[float, float, float],
    std: tuple[float, float, float],
) -> Callable:
    """Build a photometric-only pipeline: to image, float32 in [0, 1], normalize.

    No resize or crop is applied, so the output keeps the input's spatial size.
    """
    to_image = v2.ToImage()
    to_unit_float = v2.ToDtype(torch.float32, scale=True)
    normalize = v2.Normalize(mean=mean, std=std)
    return v2.Compose([to_image, to_unit_float, normalize])
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _patch_size_from_trunk(trunk) -> tuple[int, int]:
|
|
41
|
+
patch = trunk.patch_embed.patch_size
|
|
42
|
+
if isinstance(patch, int):
|
|
43
|
+
return patch, patch
|
|
44
|
+
patch_h, patch_w = patch
|
|
45
|
+
return int(patch_h), int(patch_w)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _encode_trunk_dense(*, trunk, batch: Tensor, encoder_name: str) -> Tensor:
|
|
49
|
+
if batch.ndim != 4:
|
|
50
|
+
raise ValueError(
|
|
51
|
+
"encode_tiles_dense expects a (B, C, H, W) batch, got shape "
|
|
52
|
+
f"{tuple(batch.shape)}."
|
|
53
|
+
)
|
|
54
|
+
_, _, height, width = batch.shape
|
|
55
|
+
patch_h, patch_w = _patch_size_from_trunk(trunk)
|
|
56
|
+
if height % patch_h != 0 or width % patch_w != 0:
|
|
57
|
+
raise ValueError(
|
|
58
|
+
f"Dense extraction for '{encoder_name}' requires input divisible by "
|
|
59
|
+
f"the patch size: got {height}x{width}, patch {patch_h}x{patch_w}. "
|
|
60
|
+
"Pad the tile up to a patch multiple first."
|
|
61
|
+
)
|
|
62
|
+
if hasattr(trunk, "forward_features"):
|
|
63
|
+
tokens = trunk.forward_features(batch)
|
|
64
|
+
else:
|
|
65
|
+
tokens = trunk(batch)
|
|
66
|
+
return reshape_tokens_to_grid(
|
|
67
|
+
tokens,
|
|
68
|
+
grid_h=height // patch_h,
|
|
69
|
+
grid_w=width // patch_w,
|
|
70
|
+
num_prefix_tokens=int(getattr(trunk, "num_prefix_tokens", 1)),
|
|
71
|
+
encoder_name=encoder_name,
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@register_encoder(
    "conch",
    output_variants={"default": {"encode_dim": 512}},
    default_output_variant="default",
    input_size=448,
    supported_spacing_um=0.5,
    precision="fp32",
    source="MahmoodLab/conch",
)
class CONCH(TileEncoder):
    """CONCH (``conch_ViT-B-16``) tile encoder loaded through the ``conch`` package.

    Pooled features come from ``encode_image`` (called with
    ``proj_contrast=False, normalize=False``). Dense features come from the
    vision trunk's patch tokens.
    """

    # OpenAI-CLIP per-channel statistics (the values of open_clip's
    # OPENAI_DATASET_MEAN / OPENAI_DATASET_STD). Used only as a last-resort
    # fallback for the dense transform, when neither the pooled transform
    # nor the ``conch`` constants module exposes them.
    _OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    _OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(self, *, output_variant: str | None = None):
        from conch.open_clip_custom import create_model_from_pretrained

        self._model, self._transform = create_model_from_pretrained(
            "conch_ViT-B-16", "hf_hub:MahmoodLab/conch"
        )
        self._model.eval()
        self._device = preferred_default_device()
        self._output_variant = resolve_requested_output_variant(output_variant)

    def get_transform(self) -> Callable:
        """Pooled transform returned by ``create_model_from_pretrained``."""
        return self._transform

    @staticmethod
    def _normalize_stats(transform) -> tuple[tuple[float, ...], tuple[float, ...]] | None:
        """Return ``(mean, std)`` from the first step of ``transform`` exposing both.

        Scans ``transform.transforms`` (a ``Compose``-style pipeline). Returns
        ``None`` when ``transform`` has no such steps or none carries mean/std.
        """
        for step in getattr(transform, "transforms", ()):
            mean = getattr(step, "mean", None)
            std = getattr(step, "std", None)
            if mean is not None and std is not None:
                return tuple(float(v) for v in mean), tuple(float(v) for v in std)
        return None

    def get_dense_transform(self) -> Callable:
        """Normalization-only transform (no Resize / CenterCrop) for dense extraction.

        The mean/std must match the pooled ``get_transform`` pipeline, so that
        dense and pooled features see identical photometrics. Sources, in
        order of preference:

        1. the ``Normalize`` step of the pooled transform itself;
        2. ``conch``'s ``OPENAI_DATASET_MEAN`` / ``OPENAI_DATASET_STD``;
        3. the same OpenAI-CLIP values hard-coded on this class.

        The previous fallback (ImageNet statistics behind a blanket
        ``except Exception``) silently produced a different normalization
        than the pooled path. Only a missing constants module (ImportError)
        is tolerated now.
        """
        stats = self._normalize_stats(getattr(self, "_transform", None))
        if stats is None:
            try:
                from conch.open_clip_custom.constants import (
                    OPENAI_DATASET_MEAN,
                    OPENAI_DATASET_STD,
                )
            except ImportError:
                stats = (self._OPENAI_CLIP_MEAN, self._OPENAI_CLIP_STD)
            else:
                stats = (
                    tuple(float(v) for v in OPENAI_DATASET_MEAN),
                    tuple(float(v) for v in OPENAI_DATASET_STD),
                )
        mean, std = stats
        return _normalize_only_transform(mean=mean, std=std)

    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Pooled image features from ``encode_image`` (no contrastive projection)."""
        return self._model.encode_image(batch, proj_contrast=False, normalize=False)

    def encode_tiles_dense(self, batch: Tensor) -> Tensor:
        """Dense ``(B, d, h, w)`` grid from the ViT trunk's patch tokens."""
        # Use the ViT trunk tokens directly. self._model.visual(...) returns
        # attentional-pool tokens for captioning/contrast, not a spatial patch grid.
        return _encode_trunk_dense(
            trunk=self._model.visual.trunk,
            batch=batch,
            encoder_name=type(self).__name__,
        )

    @property
    def encode_dim(self) -> int:
        """Pooled feature width (512)."""
        return 512

    @property
    def device(self) -> torch.device:
        """Device most recently set by ``__init__`` or ``to``."""
        return self._device

    def to(self, device: torch.device | str) -> "CONCH":
        """Move the model to ``device`` and return ``self``."""
        self._device = torch.device(device)
        self._model = self._model.to(self._device)
        return self
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@register_encoder(
    "conchv15",
    output_variants={"default": {"encode_dim": 768}},
    default_output_variant="default",
    input_size=448,
    supported_spacing_um=0.5,
    precision="fp16",
    source="MahmoodLab/TITAN",
)
class CONCHv15(TileEncoder):
    """CONCH v1.5 tile encoder, obtained from the TITAN checkpoint.

    ``MahmoodLab/TITAN`` is loaded with ``trust_remote_code=True``, and its
    ``return_conch()`` supplies both the backbone and the pooled transform.
    """

    def __init__(self, *, output_variant: str | None = None):
        titan_model = AutoModel.from_pretrained("MahmoodLab/TITAN", trust_remote_code=True)
        backbone, pooled_transform = titan_model.return_conch()
        backbone.eval()
        self._model = backbone
        self._transform = pooled_transform
        self._device = preferred_default_device()
        self._output_variant = resolve_requested_output_variant(output_variant)

    def get_transform(self) -> Callable:
        """Pooled transform returned by TITAN's ``return_conch()``."""
        return self._transform

    def get_dense_transform(self) -> Callable:
        """Normalization-only (ImageNet mean/std) transform for dense extraction."""
        return _normalize_only_transform(std=_IMAGENET_STD, mean=_IMAGENET_MEAN)

    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Pooled features: the backbone's forward output for ``batch``."""
        return self._model(batch)

    def encode_tiles_dense(self, batch: Tensor) -> Tensor:
        """Dense ``(B, d, h, w)`` grid from the backbone trunk's patch tokens."""
        trunk = self._model.trunk
        return _encode_trunk_dense(
            encoder_name=type(self).__name__,
            trunk=trunk,
            batch=batch,
        )

    @property
    def encode_dim(self) -> int:
        """Pooled feature width (768)."""
        return 768

    @property
    def device(self) -> torch.device:
        """Device most recently set by ``__init__`` or ``to``."""
        return self._device

    def to(self, device: torch.device | str) -> "CONCHv15":
        """Move the backbone to ``device`` and return ``self``."""
        target = torch.device(device)
        self._device = target
        self._model = self._model.to(target)
        return self
|
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
"""Prov-GigaPath encoder implementation."""
|
|
2
2
|
|
|
3
|
+
from typing import Callable
|
|
4
|
+
|
|
3
5
|
import torch
|
|
6
|
+
from torchvision.transforms import v2
|
|
4
7
|
|
|
5
8
|
from slide2vec.encoders.base import (
|
|
6
9
|
SlideEncoder,
|
|
@@ -10,6 +13,14 @@ from slide2vec.encoders.base import (
|
|
|
10
13
|
)
|
|
11
14
|
from slide2vec.encoders.registry import register_encoder
|
|
12
15
|
|
|
16
|
+
# Prov-GigaPath model card transform: resize the 256px tile to 256 (no-op),
|
|
17
|
+
# center-crop to the model's native 224, ImageNet normalization. timm's packaged
|
|
18
|
+
# pretrained_cfg reports crop_pct=1.0 -> get_transform would instead Resize(224),
|
|
19
|
+
# downscaling the whole tile to ~0.57 mpp; the paper feeds the center 224 at the
|
|
20
|
+
# native 0.5 mpp. https://www.nature.com/articles/s41586-024-07441-w
|
|
21
|
+
_GIGAPATH_MEAN = (0.485, 0.456, 0.406)
|
|
22
|
+
_GIGAPATH_STD = (0.229, 0.224, 0.225)
|
|
23
|
+
|
|
13
24
|
|
|
14
25
|
@register_encoder(
|
|
15
26
|
"gigapath",
|
|
@@ -25,8 +36,25 @@ class GigaPath(TimmTileEncoder):
|
|
|
25
36
|
super().__init__(
|
|
26
37
|
"hf_hub:prov-gigapath/prov-gigapath",
|
|
27
38
|
output_variant=output_variant,
|
|
39
|
+
dynamic_img_size=True,
|
|
28
40
|
)
|
|
29
41
|
|
|
42
|
+
def get_transform(self) -> Callable:
    """Pooled-extraction transform following the Prov-GigaPath recipe.

    Resize the shorter side to 256 (a no-op for 256px tiles, bicubic with
    antialiasing otherwise), center-crop to the model's native 224px input,
    convert to float32 in [0, 1], then ImageNet-normalize.
    """
    # POOLED transform only: the center crop (paper recipe, center 224 @ native
    # 0.5 mpp) discards the tile margins, so dense extraction must NOT route
    # through here. The dense path uses ``get_dense_transform``, inherited from
    # TimmTileEncoder. That transform is normalization-only (no Resize, no
    # CenterCrop) and keeps the full tile, so a 256px tile yields a 16x16 token
    # grid. Any padding/resizing for dense extraction is the dense pipeline's
    # job; encode_tiles_dense (also inherited) operates on whatever batch it is fed.
    return v2.Compose([
        v2.ToImage(),
        v2.Resize(256, interpolation=v2.InterpolationMode.BICUBIC, antialias=True),
        v2.CenterCrop(224),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=_GIGAPATH_MEAN, std=_GIGAPATH_STD),
    ])
|
|
57
|
+
|
|
30
58
|
|
|
31
59
|
@register_encoder(
|
|
32
60
|
"gigapath-slide",
|
|
@@ -10,7 +10,12 @@ from torch import Tensor
|
|
|
10
10
|
from torchvision.transforms import v2
|
|
11
11
|
from transformers import AutoModel
|
|
12
12
|
|
|
13
|
-
from slide2vec.encoders.base import
|
|
13
|
+
from slide2vec.encoders.base import (
|
|
14
|
+
TileEncoder,
|
|
15
|
+
preferred_default_device,
|
|
16
|
+
reshape_tokens_to_grid,
|
|
17
|
+
resolve_requested_output_variant,
|
|
18
|
+
)
|
|
14
19
|
from slide2vec.encoders.registry import register_encoder
|
|
15
20
|
|
|
16
21
|
_HIBOU_MEAN = (0.7068, 0.5755, 0.722)
|
|
@@ -40,10 +45,40 @@ class _HibouBase(TileEncoder):
|
|
|
40
45
|
def get_transform(self) -> Callable:
    """Pooled-extraction transform, as built by ``_hibou_transform``."""
    pooled_transform = _hibou_transform()
    return pooled_transform
|
|
42
47
|
|
|
48
|
+
def get_dense_transform(self) -> Callable:
    """Normalization-only transform for dense extraction (no resize, no crop).

    Converts the input to a float32 image tensor scaled to [0, 1], then
    normalizes it with the Hibou mean/std.
    """
    pipeline = [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=_HIBOU_MEAN, std=_HIBOU_STD),
    ]
    return v2.Compose(pipeline)
|
|
54
|
+
|
|
43
55
|
def encode_tiles(self, batch: Tensor) -> Tensor:
    """Pooled features: the backbone's ``pooler_output`` for ``batch``."""
    return self._model(pixel_values=batch).pooler_output
|
|
46
58
|
|
|
59
|
+
def encode_tiles_dense(self, batch: Tensor) -> Tensor:
    """Encode tiles into a dense spatial grid. (B, C, H, W) -> (B, d, h, w).

    Runs the HF backbone and takes ``last_hidden_state``. It then drops one
    leading token plus ``config.num_register_tokens`` register tokens (0 when
    unset) before folding the patch tokens into an
    ``(H / patch_h, W / patch_w)`` grid.

    Raises:
        ValueError: if ``batch`` is not 4-D, or ``H``/``W`` are not divisible by
            the patch size.
    """
    if batch.ndim != 4:
        raise ValueError(
            "encode_tiles_dense expects a (B, C, H, W) batch, got shape "
            f"{tuple(batch.shape)}."
        )
    _, _, height, width = batch.shape
    config = self._model.config
    # Configs usually store an int patch size; accept an (h, w) pair too instead
    # of failing inside int().
    patch = config.patch_size
    if isinstance(patch, (tuple, list)):
        patch_h, patch_w = (int(p) for p in patch)
    else:
        patch_h = patch_w = int(patch)
    if height % patch_h != 0 or width % patch_w != 0:
        raise ValueError(
            f"Dense extraction for '{type(self).__name__}' requires input "
            f"divisible by the patch size: got {height}x{width}, patch "
            f"{patch_h}x{patch_w}. Pad the tile up to a patch multiple first."
        )
    output = self._model(pixel_values=batch)
    # ``or 0`` also covers configs that define the attribute as None.
    num_registers = int(getattr(config, "num_register_tokens", 0) or 0)
    return reshape_tokens_to_grid(
        output.last_hidden_state,
        grid_h=height // patch_h,
        grid_w=width // patch_w,
        num_prefix_tokens=1 + num_registers,
        encoder_name=type(self).__name__,
    )
|
|
81
|
+
|
|
47
82
|
@property
def encode_dim(self) -> int:
    """Pooled feature width, as stored in ``self._encode_dim``."""
    dim = self._encode_dim
    return dim
|