slide2vec 4.5.3__tar.gz → 4.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99) hide show
  1. {slide2vec-4.5.3 → slide2vec-4.6.0}/PKG-INFO +1 -1
  2. {slide2vec-4.5.3 → slide2vec-4.6.0}/pyproject.toml +2 -2
  3. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/__init__.py +1 -1
  4. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/__init__.py +4 -0
  5. slide2vec-4.6.0/slide2vec/encoders/base.py +335 -0
  6. slide2vec-4.6.0/slide2vec/encoders/models/conch.py +181 -0
  7. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/gigapath.py +28 -0
  8. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/hibou.py +36 -1
  9. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/hoptimus.py +35 -5
  10. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/lunit.py +1 -0
  11. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/midnight.py +36 -1
  12. slide2vec-4.6.0/slide2vec/encoders/models/musk.py +124 -0
  13. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/phikon.py +53 -1
  14. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/prost40m.py +1 -0
  15. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/inference.py +15 -1
  16. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec.egg-info/PKG-INFO +1 -1
  17. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec.egg-info/SOURCES.txt +4 -1
  18. slide2vec-4.6.0/tests/test_dense_extraction.py +487 -0
  19. slide2vec-4.6.0/tests/test_dense_locality_gated.py +162 -0
  20. slide2vec-4.6.0/tests/test_tiling_pipeline.py +25 -0
  21. slide2vec-4.5.3/slide2vec/encoders/base.py +0 -161
  22. slide2vec-4.5.3/slide2vec/encoders/models/conch.py +0 -93
  23. slide2vec-4.5.3/slide2vec/encoders/models/musk.py +0 -69
  24. {slide2vec-4.5.3 → slide2vec-4.6.0}/LICENSE +0 -0
  25. {slide2vec-4.5.3 → slide2vec-4.6.0}/README.md +0 -0
  26. {slide2vec-4.5.3 → slide2vec-4.6.0}/setup.cfg +0 -0
  27. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/__main__.py +0 -0
  28. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/api.py +0 -0
  29. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/artifacts.py +0 -0
  30. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/cli.py +0 -0
  31. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/configs/__init__.py +0 -0
  32. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/configs/default.yaml +0 -0
  33. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/configs/resources.py +0 -0
  34. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/data/__init__.py +0 -0
  35. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/data/dataset.py +0 -0
  36. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/data/tile_reader.py +0 -0
  37. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/data/tile_store.py +0 -0
  38. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/distributed/__init__.py +0 -0
  39. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/distributed/direct_embed_worker.py +0 -0
  40. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/distributed/pipeline_worker.py +0 -0
  41. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/__init__.py +0 -0
  42. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/moozy/__init__.py +0 -0
  43. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/moozy/blocks.py +0 -0
  44. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/moozy/case.py +0 -0
  45. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/moozy/loading.py +0 -0
  46. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/moozy/slide.py +0 -0
  47. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/moozy/types.py +0 -0
  48. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/prism.py +0 -0
  49. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/titan.py +0 -0
  50. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/uni.py +0 -0
  51. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/models/virchow.py +0 -0
  52. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/registry.py +0 -0
  53. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/encoders/validation.py +0 -0
  54. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/progress.py +0 -0
  55. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/__init__.py +0 -0
  56. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/artifacts_collect.py +0 -0
  57. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/batching.py +0 -0
  58. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/cpu_budget.py +0 -0
  59. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/distributed.py +0 -0
  60. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/distributed_stage.py +0 -0
  61. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/embedding.py +0 -0
  62. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/embedding_persist.py +0 -0
  63. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/embedding_pipeline.py +0 -0
  64. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/hierarchical.py +0 -0
  65. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/manifest.py +0 -0
  66. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/model_settings.py +0 -0
  67. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/patient_pipeline.py +0 -0
  68. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/persist_callbacks.py +0 -0
  69. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/persistence.py +0 -0
  70. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/process_list.py +0 -0
  71. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/progress_bridge.py +0 -0
  72. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/registry.py +0 -0
  73. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/serialization.py +0 -0
  74. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/slide_encode.py +0 -0
  75. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/tiling.py +0 -0
  76. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/tiling_pipeline.py +0 -0
  77. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/types.py +0 -0
  78. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/runtime/worker_io.py +0 -0
  79. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/utils/__init__.py +0 -0
  80. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/utils/config.py +0 -0
  81. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/utils/coordinates.py +0 -0
  82. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/utils/log_utils.py +0 -0
  83. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/utils/tiling_io.py +0 -0
  84. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec/utils/utils.py +0 -0
  85. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec.egg-info/dependency_links.txt +0 -0
  86. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec.egg-info/entry_points.txt +0 -0
  87. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec.egg-info/not-zip-safe +0 -0
  88. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec.egg-info/requires.txt +0 -0
  89. {slide2vec-4.5.3 → slide2vec-4.6.0}/slide2vec.egg-info/top_level.txt +0 -0
  90. {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_architecture_runtime_split.py +0 -0
  91. {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_encoder_registry.py +0 -0
  92. {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_hs2p_package_cutover.py +0 -0
  93. {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_output_consistency.py +0 -0
  94. {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_progress.py +0 -0
  95. {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_regression_core.py +0 -0
  96. {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_regression_inference.py +0 -0
  97. {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_regression_models.py +0 -0
  98. {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_runtime_batching.py +0 -0
  99. {slide2vec-4.5.3 → slide2vec-4.6.0}/tests/test_tile_store.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: slide2vec
3
- Version: 4.5.3
3
+ Version: 4.6.0
4
4
  Summary: Embedding of whole slide images with Foundation Models
5
5
  Author-email: Clément Grisi <clement.grisi@radboudumc.nl>
6
6
  License-Expression: Apache-2.0
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "slide2vec"
7
- version = "4.5.3"
7
+ version = "4.6.0"
8
8
  description = "Embedding of whole slide images with Foundation Models"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10"
@@ -164,7 +164,7 @@ no_implicit_reexport = true
164
164
  max-line-length = 160
165
165
 
166
166
  [tool.bumpver]
167
- current_version = "4.5.3"
167
+ current_version = "4.6.0"
168
168
  version_pattern = "MAJOR.MINOR.PATCH"
169
169
  commit = false # We do version bumping in CI, not as a commit
170
170
  tag = false # Git tag already exists — we don't auto-tag
@@ -11,7 +11,7 @@ from slide2vec.api import (
11
11
  from slide2vec.artifacts import HierarchicalEmbeddingArtifact, SlideEmbeddingArtifact, TileEmbeddingArtifact
12
12
 
13
13
 
14
- __version__ = "4.5.3"
14
+ __version__ = "4.6.0"
15
15
 
16
16
  __all__ = [
17
17
  "Model",
@@ -10,6 +10,8 @@ from slide2vec.encoders.base import (
10
10
  SlideEncoder,
11
11
  TileEncoder,
12
12
  TimmTileEncoder,
13
+ reshape_tokens_to_grid,
14
+ resolve_recommended_dynamic_img_size,
13
15
  resolve_requested_output_variant,
14
16
  )
15
17
  from slide2vec.encoders.registry import (
@@ -29,6 +31,8 @@ __all__ = [
29
31
  "TileEncoder",
30
32
  "SlideEncoder",
31
33
  "TimmTileEncoder",
34
+ "reshape_tokens_to_grid",
35
+ "resolve_recommended_dynamic_img_size",
32
36
  "resolve_requested_output_variant",
33
37
  "encoder_registry",
34
38
  "register_encoder",
@@ -0,0 +1,335 @@
1
+ """Encoder abstractions for tile-level and slide-level feature extraction."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Callable
5
+
6
+ import timm
7
+ import torch
8
+ from timm.data import create_transform, resolve_data_config
9
+ from torch import Tensor
10
+ from torchvision.transforms import v2
11
+
12
+
13
def preferred_default_device() -> torch.device:
    """Return ``cuda`` when torch reports an available GPU, else ``cpu``."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
15
+
16
+
17
+ def resolve_requested_output_variant(
18
+ output_variant: str | None,
19
+ *,
20
+ default: str = "default",
21
+ allowed: tuple[str, ...] = ("default",),
22
+ ) -> str:
23
+ """Normalize and validate a requested encoder output variant."""
24
+ resolved = output_variant or default
25
+ if resolved not in allowed:
26
+ available = ", ".join(allowed)
27
+ raise ValueError(
28
+ f"Unsupported output_variant '{resolved}'. Available: {available}"
29
+ )
30
+ return resolved
31
+
32
+
33
+ def resolve_recommended_dynamic_img_size(
34
+ *,
35
+ requested: bool | None,
36
+ recommended: bool,
37
+ allow_non_recommended: bool,
38
+ encoder_name: str,
39
+ ) -> bool:
40
+ """Resolve ``dynamic_img_size`` against an encoder's card-recommended value.
41
+
42
+ ``None`` uses the recommended value. A value that differs from the
43
+ recommendation requires ``allow_non_recommended_settings=True`` (e.g. dense
44
+ feature extraction deliberately enabling variable input size, justified by
45
+ the registration / native-size-no-op tests); otherwise it raises, so a
46
+ pipeline never silently runs an encoder outside its documented config.
47
+ """
48
+ if requested is None:
49
+ return recommended
50
+ if requested != recommended and not allow_non_recommended:
51
+ raise ValueError(
52
+ f"Encoder '{encoder_name}' recommends dynamic_img_size={recommended} "
53
+ f"(per its model card); got dynamic_img_size={requested}, which deviates "
54
+ "from the recommended setting. Pass allow_non_recommended_settings=True "
55
+ "to override it deliberately (e.g. dense extraction needs variable input "
56
+ "size; this is a native-size no-op, verified in the encoder tests)."
57
+ )
58
+ return requested
59
+
60
+
61
def reshape_tokens_to_grid(
    tokens: Tensor,
    *,
    grid_h: int,
    grid_w: int,
    num_prefix_tokens: int,
    encoder_name: str,
) -> Tensor:
    """Fold a ViT ``(B, T, d)`` token sequence into a dense ``(B, d, h, w)`` grid.

    Strips the leading ``num_prefix_tokens`` (CLS + register tokens) and reshapes
    the remaining patch tokens back into their row-major spatial grid. The token
    at sequence index ``num_prefix_tokens + r * grid_w + c`` lands at
    ``[:, :, r, c]``, so ``transpose(1, 2).reshape(B, d, h, w)`` recovers the
    spatial layout. The original author reports this was verified against timm's
    ``get_intermediate_layers(..., reshape=True)`` in the encoder tests.

    Args:
        tokens: Backbone output of shape ``(B, T, d)``.
        grid_h: Number of token rows (``H // patch_h``).
        grid_w: Number of token columns (``W // patch_w``).
        num_prefix_tokens: Leading non-patch tokens to discard; must be >= 0.
        encoder_name: Used only to make error messages actionable.

    Returns:
        Tensor of shape ``(B, d, grid_h, grid_w)``.

    Raises:
        ValueError: if ``tokens`` is not 3-D, if ``num_prefix_tokens``,
            ``grid_h`` or ``grid_w`` is negative, or if the post-strip token
            count differs from ``grid_h * grid_w``. Failing loudly is deliberate:
            a silent reshape would train a decoder on spatially corrupted
            features, which is worse than a hard failure.
    """
    if tokens.ndim != 3:
        raise ValueError(
            f"Dense extraction for '{encoder_name}' expected a (B, T, d) token "
            f"sequence from the backbone, got shape {tuple(tokens.shape)}. This "
            "encoder may not expose a recoverable patch-token grid."
        )
    # A negative prefix count would slice from the END of the sequence
    # (tokens[:, -k:, :]) and could pass the count check below with the wrong
    # tokens. Negative grid dims could likewise satisfy grid_h * grid_w and only
    # fail later inside reshape with an opaque RuntimeError.
    if num_prefix_tokens < 0 or grid_h < 0 or grid_w < 0:
        raise ValueError(
            f"Dense extraction for '{encoder_name}' got invalid geometry: "
            f"num_prefix_tokens={num_prefix_tokens}, grid={grid_h}x{grid_w}; "
            "all must be non-negative."
        )
    patch_tokens = tokens[:, num_prefix_tokens:, :]
    batch_size, num_tokens, dim = patch_tokens.shape
    expected = grid_h * grid_w
    if num_tokens != expected:
        raise ValueError(
            f"Dense token accounting mismatch for '{encoder_name}': backbone "
            f"returned {tokens.shape[1]} tokens; after stripping "
            f"{num_prefix_tokens} prefix token(s), {num_tokens} remain, but the "
            f"{grid_h}x{grid_w} grid expects {expected}. Check the prefix-token "
            "count and the input-size / patch-size / grid geometry."
        )
    return patch_tokens.transpose(1, 2).reshape(batch_size, dim, grid_h, grid_w)
100
+
101
+
102
+ class Encoder(ABC):
103
+ """Shared lifecycle contract for all encoders."""
104
+
105
+ @property
106
+ @abstractmethod
107
+ def encode_dim(self) -> int:
108
+ """Dimensionality of the output feature vector."""
109
+ ...
110
+
111
+ @property
112
+ @abstractmethod
113
+ def device(self) -> torch.device:
114
+ """Current device of the encoder."""
115
+ ...
116
+
117
+ @abstractmethod
118
+ def to(self, device: torch.device | str) -> "Encoder":
119
+ """Move encoder to the given device. Returns self."""
120
+ ...
121
+
122
+
123
class TileEncoder(Encoder):
    """Abstract base for encoders that embed image tiles directly.

    Concrete subclasses supply the pooled path (``get_transform`` and
    ``encode_tiles``). The dense, spatial-grid path is opt-in: the defaults of
    ``encode_tiles_dense``, ``patch_size`` and ``get_dense_transform`` raise
    ``NotImplementedError``, and only dense-capable ViT tile encoders override
    them.
    """

    @abstractmethod
    def get_transform(self) -> Callable:
        """Return the pooled preprocessing pipeline (PIL Image or ndarray -> Tensor)."""

    @abstractmethod
    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Embed a batch of tiles: ``(B, C, H, W) -> (B, D)``."""

    def encode_tiles_dense(self, batch: Tensor) -> Tensor:
        """Embed tiles as a dense spatial feature grid: ``(B, C, H, W) -> (B, d, h, w)``.

        ``d`` is the per-token feature width and ``(h, w)`` the token grid
        (``H / patch``, ``W / patch``). ViT tile encoders with a recoverable
        patch grid override this. Vision-language / slide-native encoders have
        no usable patch grid and keep the default.

        Raises:
            NotImplementedError: unless overridden.
        """
        encoder_cls = type(self).__name__
        raise NotImplementedError(
            f"{encoder_cls} does not support dense (spatial-grid) feature "
            "extraction. Dense extraction requires a ViT tile encoder whose patch "
            "tokens can be reshaped into a spatial grid."
        )

    @property
    def patch_size(self) -> tuple[int, int]:
        """Backbone patch size ``(patch_h, patch_w)``; defined only by dense encoders.

        It is a property of the frozen model and is used to resolve the dense
        token grid.

        Raises:
            NotImplementedError: unless overridden. This is not an
                ``AttributeError``, so ``hasattr(encoder, "patch_size")``
                propagates it rather than returning ``False``.
        """
        encoder_cls = type(self).__name__
        raise NotImplementedError(
            f"{encoder_cls} does not expose a patch size (only dense-capable "
            "ViT tile encoders do)."
        )

    def get_dense_transform(self) -> Callable:
        """Return the photometric (normalization-only) transform for dense extraction.

        Unlike ``get_transform`` (the pooled recipe), this applies ONLY the
        encoder's per-channel mean/std normalization: no Resize and no
        CenterCrop. The dense grid must cover the *full* source tile and stay
        spatially registered to it. Some pooled recipes resize then center-crop
        (GigaPath ``Resize(256) -> CenterCrop(224)``; Lunit
        ``crop_pct=0.9 -> Resize(248) -> CenterCrop(224)``). That drops the tile
        margins and would misregister the grid against a dense target mask.
        Geometry (padding to a patch multiple, optional resize, cropping logits
        back) belongs to the dense pipeline, not the encoder.

        Raises:
            NotImplementedError: unless overridden.
        """
        encoder_cls = type(self).__name__
        raise NotImplementedError(
            f"{encoder_cls} does not provide a dense transform. Only "
            "encoders that support dense (spatial-grid) extraction define one."
        )
180
+
181
+
182
class SlideEncoder(Encoder):
    """Abstract base for encoders that pool tile features into a slide embedding.

    Tile embedding is delegated to an attached ``tile_encoder``. Subclasses
    implement ``encode_slide`` and may override ``prepare_coordinates``.
    """

    # None until a tile encoder is attached; encode_tiles refuses to run before then.
    tile_encoder: TileEncoder | None = None

    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Embed tiles with the attached ``tile_encoder``.

        Raises:
            AttributeError: if no ``tile_encoder`` has been attached.
        """
        delegate = self.tile_encoder
        if delegate is not None:
            return delegate.encode_tiles(batch)
        raise AttributeError("slide encoders must attach a tile_encoder before encoding tiles")

    @abstractmethod
    def encode_slide(
        self,
        tile_features: Tensor,
        coordinates: Tensor | None = None,
        *,
        tile_size_lv0: int | None = None,
    ) -> Tensor:
        """Aggregate tile-level features into one slide-level embedding."""

    def prepare_coordinates(
        self,
        coordinates: Tensor,
        *,
        base_spacing_um: float,
        requested_spacing_um: float,
    ) -> Tensor:
        """Model-specific coordinate normalization hook; the default is the identity."""
        return coordinates
212
+
213
+
214
class PatientEncoder(Encoder):
    """Abstract base for encoders that aggregate slide embeddings into a patient embedding.

    Like slide encoders, tile embedding is delegated to an attached
    ``tile_encoder``. Subclasses implement both ``encode_slide`` and
    ``encode_patient``.
    """

    # None until a tile encoder is attached; encode_tiles refuses to run before then.
    tile_encoder: TileEncoder | None = None

    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Embed tiles with the attached ``tile_encoder``.

        Raises:
            AttributeError: if no ``tile_encoder`` has been attached.
        """
        delegate = self.tile_encoder
        if delegate is not None:
            return delegate.encode_tiles(batch)
        raise AttributeError("patient encoders must attach a tile_encoder before encoding tiles")

    @abstractmethod
    def encode_slide(
        self,
        tile_features: Tensor,
        coordinates: Tensor | None = None,
        *,
        tile_size_lv0: int | None = None,
    ) -> Tensor:
        """Aggregate tile-level features into one slide-level embedding."""

    @abstractmethod
    def encode_patient(self, slide_embeddings: Tensor) -> Tensor:
        """Aggregate slide embeddings ``[S, D]`` into one patient-level embedding ``[D]``."""
239
+
240
+
241
class TimmTileEncoder(TileEncoder):
    """Tile-encoder base class backed by a ``timm`` model.

    Construction calls ``timm.create_model(model_name, pretrained=True,
    num_classes=0, **timm_kwargs)``, where caller kwargs override those two
    defaults, and switches the model to eval mode. ``device`` starts as
    ``preferred_default_device()``. The model itself is not moved in
    ``__init__``; call ``to`` to place it.
    """

    def __init__(
        self,
        model_name: str,
        *,
        output_variant: str | None = None,
        **timm_kwargs,
    ):
        create_kwargs = {"pretrained": True, "num_classes": 0, **timm_kwargs}
        self._model = timm.create_model(model_name, **create_kwargs).eval()
        self._device = preferred_default_device()
        # Keep a variant a subclass may already have resolved before calling
        # super().__init__; only resolve here when none is set.
        if not hasattr(self, "_output_variant"):
            self._output_variant = resolve_requested_output_variant(output_variant)

    def get_transform(self) -> Callable:
        """Pooled preprocessing built by timm from the model's resolved data config."""
        return create_transform(
            **resolve_data_config(self._model.pretrained_cfg, model=self._model)
        )

    def get_dense_transform(self) -> Callable:
        """Normalization-only transform for dense extraction (no Resize/CenterCrop).

        Mean/std come from the same resolved timm data config that
        ``get_transform`` uses. The photometric pipeline therefore matches
        pooled extraction even for encoders with custom normalization
        (e.g. H-optimus 0.7072.../0.2119...).
        """
        stats = resolve_data_config(self._model.pretrained_cfg, model=self._model)
        to_image = v2.ToImage()
        to_float = v2.ToDtype(torch.float32, scale=True)
        normalize = v2.Normalize(mean=stats["mean"], std=stats["std"])
        return v2.Compose([to_image, to_float, normalize])

    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Pooled embeddings: the timm model's forward output for ``batch``."""
        return self._model(batch)

    def _dense_patch_size(self) -> tuple[int, int]:
        """Backbone patch size as ``(patch_h, patch_w)``, read from ``patch_embed.patch_size``."""
        size = self._model.patch_embed.patch_size
        if not isinstance(size, int):
            patch_h, patch_w = size
            return int(patch_h), int(patch_w)
        return size, size

    @property
    def patch_size(self) -> tuple[int, int]:
        """Backbone patch size ``(patch_h, patch_w)``."""
        return self._dense_patch_size()

    def _dense_num_prefix_tokens(self) -> int:
        """Number of leading non-patch tokens (CLS + registers), from ``num_prefix_tokens``."""
        return int(self._model.num_prefix_tokens)

    def encode_tiles_dense(self, batch: Tensor) -> Tensor:
        """Encode tiles into a dense spatial grid: ``(B, C, H, W) -> (B, d, h, w)``.

        Runs the frozen backbone's ``forward_features`` and folds the
        patch-token sequence back into its spatial grid. CLS/register tokens are
        discarded. The backbone must accept ``batch`` at its current spatial size:
        timm ViTs need ``dynamic_img_size=True`` for non-native sizes.

        Raises:
            ValueError: if ``batch`` is not 4-D, if ``H``/``W`` are not
                multiples of the patch size, or if the token count does not
                match the grid.
        """
        if batch.ndim != 4:
            raise ValueError(
                "encode_tiles_dense expects a (B, C, H, W) batch, got shape "
                f"{tuple(batch.shape)}."
            )
        height, width = batch.shape[-2:]
        patch_h, patch_w = self._dense_patch_size()
        encoder_name = type(self).__name__
        if height % patch_h or width % patch_w:
            raise ValueError(
                f"Dense extraction for '{encoder_name}' requires input "
                f"divisible by the patch size: got {height}x{width}, patch "
                f"{patch_h}x{patch_w}. Pad the tile up to a patch multiple first."
            )
        token_seq = self._model.forward_features(batch)
        return reshape_tokens_to_grid(
            token_seq,
            grid_h=height // patch_h,
            grid_w=width // patch_w,
            num_prefix_tokens=self._dense_num_prefix_tokens(),
            encoder_name=encoder_name,
        )

    @property
    def encode_dim(self) -> int:
        """Feature width reported by the timm model (``num_features``)."""
        return self._model.num_features

    @property
    def device(self) -> torch.device:
        """Device most recently assigned to this encoder."""
        return self._device

    def to(self, device: torch.device | str) -> "TimmTileEncoder":
        """Move the timm model to ``device`` and return ``self``."""
        target = torch.device(device)
        self._device = target
        self._model = self._model.to(target)
        return self
@@ -0,0 +1,181 @@
1
+ """CONCH and CONCH v1.5 encoder implementations.
2
+
3
+ CONCH requires the ``conch`` package (pip install conch).
4
+ CONCH v1.5 requires ``transformers`` and uses the TITAN model to extract
5
+ the CONCH v1.5 backbone.
6
+ """
7
+
8
+
9
+ from typing import Callable
10
+
11
+ import torch
12
+ from torch import Tensor
13
+ from torchvision.transforms import v2
14
+ from transformers import AutoModel
15
+
16
+ from slide2vec.encoders.base import (
17
+ TileEncoder,
18
+ preferred_default_device,
19
+ reshape_tokens_to_grid,
20
+ resolve_requested_output_variant,
21
+ )
22
+ from slide2vec.encoders.registry import register_encoder
23
+
24
+ _IMAGENET_MEAN = (0.485, 0.456, 0.406)
25
+ _IMAGENET_STD = (0.229, 0.224, 0.225)
26
+
27
+
28
def _normalize_only_transform(
    *,
    mean: tuple[float, float, float],
    std: tuple[float, float, float],
) -> Callable:
    """Build a photometric-only transform: ToImage -> float32 (scale=True) -> Normalize.

    No Resize or crop is applied, so the tile's spatial size is preserved.
    """
    to_image = v2.ToImage()
    to_float = v2.ToDtype(torch.float32, scale=True)
    normalize = v2.Normalize(mean=mean, std=std)
    return v2.Compose([to_image, to_float, normalize])
38
+
39
+
40
def _patch_size_from_trunk(trunk) -> tuple[int, int]:
    """Return ``trunk.patch_embed.patch_size`` as an ``(h, w)`` pair.

    An int patch size is duplicated as-is; a two-element pair is unpacked and
    each side cast to ``int``.
    """
    raw = trunk.patch_embed.patch_size
    if not isinstance(raw, int):
        rows, cols = raw
        return int(rows), int(cols)
    return raw, raw
46
+
47
+
48
def _encode_trunk_dense(*, trunk, batch: Tensor, encoder_name: str) -> Tensor:
    """Run a ViT ``trunk`` on ``batch`` and fold its patch tokens into ``(B, d, h, w)``.

    Calls ``trunk.forward_features`` when the trunk has it, otherwise
    ``trunk(batch)``. Leading non-patch tokens are stripped using
    ``trunk.num_prefix_tokens``. When the trunk lacks that attribute the
    fallback is 1, presumably a single CLS token.

    Raises:
        ValueError: if ``batch`` is not 4-D, if its height/width are not
            multiples of the trunk's patch size, or (via
            ``reshape_tokens_to_grid``) if the token count does not match the grid.
    """
    if batch.ndim != 4:
        raise ValueError(
            "encode_tiles_dense expects a (B, C, H, W) batch, got shape "
            f"{tuple(batch.shape)}."
        )
    height, width = batch.shape[-2:]
    patch_h, patch_w = _patch_size_from_trunk(trunk)
    if height % patch_h or width % patch_w:
        raise ValueError(
            f"Dense extraction for '{encoder_name}' requires input divisible by "
            f"the patch size: got {height}x{width}, patch {patch_h}x{patch_w}. "
            "Pad the tile up to a patch multiple first."
        )
    run = trunk.forward_features if hasattr(trunk, "forward_features") else trunk
    tokens = run(batch)
    prefix_count = int(getattr(trunk, "num_prefix_tokens", 1))
    return reshape_tokens_to_grid(
        tokens,
        grid_h=height // patch_h,
        grid_w=width // patch_w,
        num_prefix_tokens=prefix_count,
        encoder_name=encoder_name,
    )
73
+
74
+
75
@register_encoder(
    "conch",
    output_variants={"default": {"encode_dim": 512}},
    default_output_variant="default",
    input_size=448,
    supported_spacing_um=0.5,
    precision="fp32",
    source="MahmoodLab/conch",
)
class CONCH(TileEncoder):
    """CONCH (``conch_ViT-B-16``) tile encoder loaded from ``hf_hub:MahmoodLab/conch``.

    Requires the ``conch`` package. Pooled embeddings are 512-d and come from
    ``encode_image(..., proj_contrast=False, normalize=False)``. Dense features
    come from the ViT trunk's patch tokens.
    """

    # OpenAI-CLIP normalization statistics: the values of open_clip's
    # OPENAI_DATASET_MEAN / OPENAI_DATASET_STD, which get_dense_transform imports
    # from conch. They are used only if that import fails, so the dense path keeps
    # the same OpenAI-CLIP normalization instead of silently switching to ImageNet.
    _OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    _OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(self, *, output_variant: str | None = None):
        from conch.open_clip_custom import create_model_from_pretrained

        self._model, self._transform = create_model_from_pretrained(
            "conch_ViT-B-16", "hf_hub:MahmoodLab/conch"
        )
        self._model.eval()
        self._device = preferred_default_device()
        self._output_variant = resolve_requested_output_variant(output_variant)

    def get_transform(self) -> Callable:
        """Pooled preprocessing returned by ``create_model_from_pretrained``."""
        return self._transform

    def get_dense_transform(self) -> Callable:
        """Normalization-only transform for dense extraction (no resize, no crop).

        Uses the OpenAI-CLIP mean/std from ``conch.open_clip_custom.constants``.
        If that module cannot be imported, it falls back to the same OpenAI-CLIP
        values stored on the class. It does not fall back to ImageNet
        statistics, which would silently diverge from the primary path.
        """
        try:
            from conch.open_clip_custom.constants import (
                OPENAI_DATASET_MEAN,
                OPENAI_DATASET_STD,
            )
        except ImportError:
            mean, std = self._OPENAI_CLIP_MEAN, self._OPENAI_CLIP_STD
        else:
            mean = tuple(float(v) for v in OPENAI_DATASET_MEAN)
            std = tuple(float(v) for v in OPENAI_DATASET_STD)
        return _normalize_only_transform(mean=mean, std=std)

    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Pooled 512-d image embeddings (no contrastive projection, no L2 normalization)."""
        return self._model.encode_image(batch, proj_contrast=False, normalize=False)

    def encode_tiles_dense(self, batch: Tensor) -> Tensor:
        """Dense ``(B, d, h, w)`` grid from the ViT trunk's patch tokens."""
        # Use the ViT trunk tokens directly. self._model.visual(...) returns
        # attentional-pool tokens for captioning/contrast, not a spatial patch grid.
        return _encode_trunk_dense(
            trunk=self._model.visual.trunk,
            batch=batch,
            encoder_name=type(self).__name__,
        )

    @property
    def encode_dim(self) -> int:
        """Pooled embedding width (512)."""
        return 512

    @property
    def device(self) -> torch.device:
        """Device most recently assigned to this encoder."""
        return self._device

    def to(self, device: torch.device | str) -> "CONCH":
        """Move the model to ``device`` and return ``self``."""
        self._device = torch.device(device)
        self._model = self._model.to(self._device)
        return self
135
+
136
+
137
@register_encoder(
    "conchv15",
    output_variants={"default": {"encode_dim": 768}},
    default_output_variant="default",
    input_size=448,
    supported_spacing_um=0.5,
    precision="fp16",
    source="MahmoodLab/TITAN",
)
class CONCHv15(TileEncoder):
    """CONCH v1.5 tile encoder, extracted from the TITAN checkpoint.

    The backbone and its pooled transform come from
    ``AutoModel.from_pretrained("MahmoodLab/TITAN", trust_remote_code=True)``
    followed by ``return_conch()``.
    """

    def __init__(self, *, output_variant: str | None = None):
        titan_model = AutoModel.from_pretrained("MahmoodLab/TITAN", trust_remote_code=True)
        backbone, pooled_transform = titan_model.return_conch()
        self._model = backbone
        self._transform = pooled_transform
        self._model.eval()
        self._device = preferred_default_device()
        self._output_variant = resolve_requested_output_variant(output_variant)

    def get_transform(self) -> Callable:
        """Pooled preprocessing returned by TITAN's ``return_conch``."""
        return self._transform

    def get_dense_transform(self) -> Callable:
        """Normalization-only (ImageNet mean/std) transform for dense extraction."""
        return _normalize_only_transform(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)

    def encode_tiles(self, batch: Tensor) -> Tensor:
        """Pooled 768-d embeddings: the backbone's forward output."""
        return self._model(batch)

    def encode_tiles_dense(self, batch: Tensor) -> Tensor:
        """Dense ``(B, d, h, w)`` grid from the backbone trunk's patch tokens."""
        trunk = self._model.trunk
        return _encode_trunk_dense(trunk=trunk, batch=batch, encoder_name=type(self).__name__)

    @property
    def encode_dim(self) -> int:
        """Pooled embedding width (768)."""
        return 768

    @property
    def device(self) -> torch.device:
        """Device most recently assigned to this encoder."""
        return self._device

    def to(self, device: torch.device | str) -> "CONCHv15":
        """Move the model to ``device`` and return ``self``."""
        target = torch.device(device)
        self._device = target
        self._model = self._model.to(target)
        return self
@@ -1,6 +1,9 @@
1
1
  """Prov-GigaPath encoder implementation."""
2
2
 
3
+ from typing import Callable
4
+
3
5
  import torch
6
+ from torchvision.transforms import v2
4
7
 
5
8
  from slide2vec.encoders.base import (
6
9
  SlideEncoder,
@@ -10,6 +13,14 @@ from slide2vec.encoders.base import (
10
13
  )
11
14
  from slide2vec.encoders.registry import register_encoder
12
15
 
16
+ # Prov-GigaPath model card transform: resize the 256px tile to 256 (no-op),
17
+ # center-crop to the model's native 224, ImageNet normalization. timm's packaged
18
+ # pretrained_cfg reports crop_pct=1.0 -> get_transform would instead Resize(224),
19
+ # downscaling the whole tile to ~0.57 mpp; the paper feeds the center 224 at the
20
+ # native 0.5 mpp. https://www.nature.com/articles/s41586-024-07441-w
21
+ _GIGAPATH_MEAN = (0.485, 0.456, 0.406)
22
+ _GIGAPATH_STD = (0.229, 0.224, 0.225)
23
+
13
24
 
14
25
  @register_encoder(
15
26
  "gigapath",
@@ -25,8 +36,25 @@ class GigaPath(TimmTileEncoder):
25
36
  super().__init__(
26
37
  "hf_hub:prov-gigapath/prov-gigapath",
27
38
  output_variant=output_variant,
39
+ dynamic_img_size=True,
28
40
  )
29
41
 
42
def get_transform(self) -> Callable:
    """Pooled-extraction transform following the Prov-GigaPath model card.

    Pipeline: ToImage -> bicubic, antialiased Resize(256) -> CenterCrop(224) ->
    float32 (``scale=True``) -> ImageNet mean/std normalization. On a 256px tile
    the resize is a no-op and the crop keeps the central 224px at native
    spacing. timm's packaged ``crop_pct=1.0`` recipe would instead resize the
    whole tile down to 224.
    """
    # POOLED path only. Dense extraction must not route through this transform:
    # the center crop discards the tile margins, so a grid built from it would
    # not cover (or stay registered to) the full source tile. The dense path
    # uses the get_dense_transform inherited from TimmTileEncoder, which is
    # normalization-only (no Resize, no CenterCrop). With 16px patches, a 256px
    # tile then yields a 16x16 token grid over the whole tile.
    # encode_tiles_dense (also inherited) is transform-agnostic and encodes
    # whatever batch the dense pipeline feeds it.
    return v2.Compose([
        v2.ToImage(),
        v2.Resize(256, interpolation=v2.InterpolationMode.BICUBIC, antialias=True),
        v2.CenterCrop(224),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=_GIGAPATH_MEAN, std=_GIGAPATH_STD),
    ])
57
+
30
58
 
31
59
  @register_encoder(
32
60
  "gigapath-slide",
@@ -10,7 +10,12 @@ from torch import Tensor
10
10
  from torchvision.transforms import v2
11
11
  from transformers import AutoModel
12
12
 
13
- from slide2vec.encoders.base import TileEncoder, preferred_default_device, resolve_requested_output_variant
13
+ from slide2vec.encoders.base import (
14
+ TileEncoder,
15
+ preferred_default_device,
16
+ reshape_tokens_to_grid,
17
+ resolve_requested_output_variant,
18
+ )
14
19
  from slide2vec.encoders.registry import register_encoder
15
20
 
16
21
  _HIBOU_MEAN = (0.7068, 0.5755, 0.722)
@@ -40,10 +45,40 @@ class _HibouBase(TileEncoder):
40
45
def get_transform(self) -> Callable:
    """Pooled-extraction transform, built by the module-level ``_hibou_transform``."""
    pipeline = _hibou_transform()
    return pipeline
42
47
 
48
def get_dense_transform(self) -> Callable:
    """Normalization-only transform for dense extraction (Hibou mean/std, no resize or crop)."""
    steps = [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=_HIBOU_MEAN, std=_HIBOU_STD),
    ]
    return v2.Compose(steps)
54
+
43
55
def encode_tiles(self, batch: Tensor) -> Tensor:
    """Pooled embeddings: the backbone's ``pooler_output`` for ``batch``."""
    return self._model(pixel_values=batch).pooler_output
46
58
 
59
def encode_tiles_dense(self, batch: Tensor) -> Tensor:
    """Encode tiles into a dense ``(B, d, h, w)`` grid from the backbone's last hidden state.

    Prefix tokens stripped = 1 CLS + ``config.num_register_tokens``, which
    defaults to 0 when the config lacks it.

    Raises:
        ValueError: if ``batch`` is not 4-D, if ``H``/``W`` are not multiples
            of ``config.patch_size``, or if the token count does not match the grid.
    """
    if batch.ndim != 4:
        raise ValueError(
            "encode_tiles_dense expects a (B, C, H, W) batch, got shape "
            f"{tuple(batch.shape)}."
        )
    height, width = batch.shape[-2:]
    config = self._model.config
    patch = int(config.patch_size)
    encoder_name = type(self).__name__
    if height % patch or width % patch:
        raise ValueError(
            f"Dense extraction for '{encoder_name}' requires input "
            f"divisible by the patch size: got {height}x{width}, patch "
            f"{patch}. Pad the tile up to a patch multiple first."
        )
    hidden = self._model(pixel_values=batch).last_hidden_state
    register_count = int(getattr(config, "num_register_tokens", 0))
    return reshape_tokens_to_grid(
        hidden,
        grid_h=height // patch,
        grid_w=width // patch,
        num_prefix_tokens=1 + register_count,
        encoder_name=encoder_name,
    )
81
+
47
82
  @property
48
83
  def encode_dim(self) -> int:
49
84
  return self._encode_dim