kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in its public registry.
- eva/core/callbacks/__init__.py +3 -2
- eva/core/callbacks/config.py +143 -0
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +10 -2
- eva/core/data/datasets/classification/__init__.py +5 -2
- eva/core/data/datasets/classification/embeddings.py +15 -135
- eva/core/data/datasets/classification/multi_embeddings.py +110 -0
- eva/core/data/datasets/embeddings.py +167 -0
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/data/transforms/__init__.py +3 -1
- eva/core/data/transforms/padding/__init__.py +5 -0
- eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
- eva/core/data/transforms/sampling/__init__.py +5 -0
- eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
- eva/core/loggers/__init__.py +7 -0
- eva/core/loggers/dummy.py +38 -0
- eva/core/loggers/experimental_loggers.py +8 -0
- eva/core/loggers/log/__init__.py +6 -0
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +74 -0
- eva/core/loggers/log/utils.py +13 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +16 -15
- eva/core/models/modules/module.py +25 -1
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/_recorder.py +69 -7
- eva/core/trainers/functional.py +23 -5
- eva/core/trainers/trainer.py +20 -6
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +24 -4
- eva/vision/data/datasets/_utils.py +3 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +6 -2
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +31 -47
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +67 -0
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +40 -15
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.0.dist-info/METADATA +553 -0
- kaiko_eva-0.1.0.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/data/datasets/classification/total_segmentator.py +0 -213
- eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.1.dist-info/METADATA +0 -405
- kaiko_eva-0.0.1.dist-info/RECORD +0 -110
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/wsi.py
@@ -0,0 +1,187 @@
+"""Dataset classes for whole-slide images."""
+
+import bisect
+import os
+from typing import Callable, List
+
+from loguru import logger
+from torch.utils.data import dataset as torch_datasets
+from torchvision import tv_tensors
+from torchvision.transforms.v2 import functional
+from typing_extensions import override
+
+from eva.vision.data import wsi
+from eva.vision.data.datasets import vision
+from eva.vision.data.wsi.patching import samplers
+
+
+class WsiDataset(vision.VisionDataset):
+    """Dataset class for reading patches from whole-slide images."""
+
+    def __init__(
+        self,
+        file_path: str,
+        width: int,
+        height: int,
+        sampler: samplers.Sampler,
+        target_mpp: float,
+        overwrite_mpp: float | None = None,
+        backend: str = "openslide",
+        image_transforms: Callable | None = None,
+    ):
+        """Initializes a new dataset instance.
+
+        Args:
+            file_path: Path to the whole-slide image file.
+            width: Width of the patches to be extracted, in pixels.
+            height: Height of the patches to be extracted, in pixels.
+            sampler: The sampler to use for sampling patch coordinates.
+            target_mpp: Target microns per pixel (mpp) for the patches.
+            overwrite_mpp: The microns per pixel (mpp) value to use when missing in WSI metadata.
+            backend: The backend to use for reading the whole-slide images.
+            image_transforms: Transforms to apply to the extracted image patches.
+        """
+        super().__init__()
+
+        self._file_path = file_path
+        self._width = width
+        self._height = height
+        self._sampler = sampler
+        self._target_mpp = target_mpp
+        self._overwrite_mpp = overwrite_mpp
+        self._backend = backend
+        self._image_transforms = image_transforms
+
+    @override
+    def __len__(self):
+        return len(self._coords.x_y)
+
+    @override
+    def filename(self, index: int) -> str:
+        return f"{self._file_path}_{index}"
+
+    @property
+    def _wsi(self) -> wsi.Wsi:
+        return wsi.get_cached_wsi(self._file_path, self._backend, self._overwrite_mpp)
+
+    @property
+    def _coords(self) -> wsi.PatchCoordinates:
+        return wsi.get_cached_coords(
+            file_path=self._file_path,
+            width=self._width,
+            height=self._height,
+            target_mpp=self._target_mpp,
+            overwrite_mpp=self._overwrite_mpp,
+            sampler=self._sampler,
+            backend=self._backend,
+        )
+
+    @override
+    def __getitem__(self, index: int) -> tv_tensors.Image:
+        x, y = self._coords.x_y[index]
+        width, height, level_idx = self._coords.width, self._coords.height, self._coords.level_idx
+        patch = self._wsi.read_region((x, y), level_idx, (width, height))
+        patch = functional.to_image(patch)
+        patch = self._apply_transforms(patch)
+        return patch
+
+    def _apply_transforms(self, image: tv_tensors.Image) -> tv_tensors.Image:
+        if self._image_transforms is not None:
+            image = self._image_transforms(image)
+        return image
+
+
+class MultiWsiDataset(vision.VisionDataset):
+    """Dataset class for reading patches from multiple whole-slide images."""
+
+    def __init__(
+        self,
+        root: str,
+        file_paths: List[str],
+        width: int,
+        height: int,
+        sampler: samplers.Sampler,
+        target_mpp: float,
+        overwrite_mpp: float | None = None,
+        backend: str = "openslide",
+        image_transforms: Callable | None = None,
+    ):
+        """Initializes a new dataset instance.
+
+        Args:
+            root: Root directory of the dataset.
+            file_paths: List of paths to the whole-slide image files, relative to the root.
+            width: Width of the patches to be extracted, in pixels.
+            height: Height of the patches to be extracted, in pixels.
+            target_mpp: Target microns per pixel (mpp) for the patches.
+            overwrite_mpp: The microns per pixel (mpp) value to use when missing in WSI metadata.
+            sampler: The sampler to use for sampling patch coordinates.
+            backend: The backend to use for reading the whole-slide images.
+            image_transforms: Transforms to apply to the extracted image patches.
+        """
+        super().__init__()
+
+        self._root = root
+        self._file_paths = file_paths
+        self._width = width
+        self._height = height
+        self._target_mpp = target_mpp
+        self._overwrite_mpp = overwrite_mpp
+        self._sampler = sampler
+        self._backend = backend
+        self._image_transforms = image_transforms
+
+        self._concat_dataset: torch_datasets.ConcatDataset
+
+    @property
+    def datasets(self) -> List[WsiDataset]:
+        """Returns the list of WSI datasets."""
+        return self._concat_dataset.datasets  # type: ignore
+
+    @property
+    def cumulative_sizes(self) -> List[int]:
+        """Returns the cumulative sizes of the WSI datasets."""
+        return self._concat_dataset.cumulative_sizes
+
+    @override
+    def configure(self) -> None:
+        self._concat_dataset = torch_datasets.ConcatDataset(datasets=self._load_datasets())
+
+    @override
+    def __len__(self) -> int:
+        return len(self._concat_dataset)
+
+    @override
+    def __getitem__(self, index: int) -> tv_tensors.Image:
+        return self._concat_dataset[index]
+
+    @override
+    def filename(self, index: int) -> str:
+        return os.path.basename(self._file_paths[self._get_dataset_idx(index)])
+
+    def _load_datasets(self) -> list[WsiDataset]:
+        logger.info(f"Initializing dataset with {len(self._file_paths)} WSIs ...")
+        wsi_datasets = []
+        for file_path in self._file_paths:
+            file_path = (
+                os.path.join(self._root, file_path) if self._root not in file_path else file_path
+            )
+            if not os.path.exists(file_path):
+                raise FileNotFoundError(f"File not found: {file_path}")
+
+            wsi_datasets.append(
+                WsiDataset(
+                    file_path=file_path,
+                    width=self._width,
+                    height=self._height,
+                    sampler=self._sampler,
+                    target_mpp=self._target_mpp,
+                    overwrite_mpp=self._overwrite_mpp,
+                    backend=self._backend,
+                    image_transforms=self._image_transforms,
+                )
+            )
+        return wsi_datasets
+
+    def _get_dataset_idx(self, index: int) -> int:
+        return bisect.bisect_right(self.cumulative_sizes, index)
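For orientation, here is a minimal usage sketch of the new dataset classes. The root and slide file names are hypothetical, and the sampler class name is assumed from the new eva/vision/data/wsi/patching/samplers/grid.py module:

from eva.vision.data.datasets import wsi as wsi_datasets
from eva.vision.data.wsi.patching import samplers

dataset = wsi_datasets.MultiWsiDataset(
    root="/data/slides",                      # hypothetical root directory
    file_paths=["case_0.svs", "case_1.svs"],  # hypothetical slide files
    width=224,
    height=224,
    sampler=samplers.GridSampler(),           # class name assumed from samplers/grid.py
    target_mpp=0.5,
)
dataset.configure()  # builds the internal ConcatDataset of per-slide WsiDataset objects
patch = dataset[0]   # tv_tensors.Image of the first sampled patch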
eva/vision/data/transforms/__init__.py
@@ -1,5 +1,6 @@
 """Vision data transforms."""
 
-from eva.vision.data.transforms.common import ResizeAndCrop
+from eva.vision.data.transforms.common import ResizeAndClamp, ResizeAndCrop
+from eva.vision.data.transforms.normalization import Clamp, RescaleIntensity
 
-__all__ = ["ResizeAndCrop"]
+__all__ = ["ResizeAndCrop", "ResizeAndClamp", "Clamp", "RescaleIntensity"]
eva/vision/data/transforms/common/__init__.py
@@ -1,5 +1,6 @@
 """Common vision transforms."""
 
+from eva.vision.data.transforms.common.resize_and_clamp import ResizeAndClamp
 from eva.vision.data.transforms.common.resize_and_crop import ResizeAndCrop
 
-__all__ = ["ResizeAndCrop"]
+__all__ = ["ResizeAndClamp", "ResizeAndCrop"]
eva/vision/data/transforms/common/resize_and_clamp.py
@@ -0,0 +1,51 @@
+"""Specialized transforms for resizing, clamping and range normalizing."""
+
+from typing import Callable, Sequence, Tuple
+
+from torchvision.transforms import v2
+
+from eva.vision.data.transforms import normalization
+
+
+class ResizeAndClamp(v2.Compose):
+    """Resizes, crops, clamps and normalizes an input image."""
+
+    def __init__(
+        self,
+        size: int | Sequence[int] = 224,
+        clamp_range: Tuple[int, int] = (-1024, 1024),
+        mean: Sequence[float] = (0.0, 0.0, 0.0),
+        std: Sequence[float] = (1.0, 1.0, 1.0),
+    ) -> None:
+        """Initializes the transform object.
+
+        Args:
+            size: Desired output size of the crop. If size is an `int` instead
+                of sequence like (h, w), a square crop (size, size) is made.
+            clamp_range: The lower and upper bound to clamp the pixel values.
+            mean: Sequence of means for each image channel.
+            std: Sequence of standard deviations for each image channel.
+        """
+        self._size = size
+        self._clamp_range = clamp_range
+        self._mean = mean
+        self._std = std
+
+        super().__init__(transforms=self._build_transforms())
+
+    def _build_transforms(self) -> Sequence[Callable]:
+        """Builds and returns the list of transforms."""
+        transforms = [
+            v2.Resize(size=self._size),
+            v2.CenterCrop(size=self._size),
+            normalization.Clamp(out_range=self._clamp_range),
+            normalization.RescaleIntensity(
+                in_range=self._clamp_range,
+                out_range=(0.0, 1.0),
+            ),
+            v2.Normalize(
+                mean=self._mean,
+                std=self._std,
+            ),
+        ]
+        return transforms
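A minimal sketch of the resulting pipeline applied to a synthetic CT-like tensor (the intensity values below are illustrative):

import torch

from eva.vision.data.transforms import ResizeAndClamp

transform = ResizeAndClamp(size=224, clamp_range=(-1024, 1024))
image = torch.randint(-2048, 3072, (3, 512, 512)).float()  # synthetic intensity values
output = transform(image)  # 3x224x224, clamped, rescaled to [0, 1], then normalized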
eva/vision/data/transforms/common/resize_and_crop.py
@@ -3,10 +3,10 @@
 from typing import Callable, Sequence
 
 import torch
-
+from torchvision.transforms import v2
 
 
-class ResizeAndCrop(torch_transforms.Compose):
+class ResizeAndCrop(v2.Compose):
     """Resizes, crops and normalizes an input image while preserving its aspect ratio."""
 
     def __init__(
@@ -32,11 +32,10 @@ class ResizeAndCrop(torch_transforms.Compose):
     def _build_transforms(self) -> Sequence[Callable]:
         """Builds and returns the list of transforms."""
         transforms = [
-
-
-
-
-            torch_transforms.Normalize(
+            v2.Resize(size=self._size),
+            v2.CenterCrop(size=self._size),
+            v2.ToDtype(torch.float32, scale=True),
+            v2.Normalize(
                 mean=self._mean,
                 std=self._std,
             ),
eva/vision/data/transforms/normalization/clamp.py
@@ -0,0 +1,43 @@
+"""Image clamp transform."""
+
+import functools
+from typing import Any, Dict, Tuple
+
+import torch
+import torchvision.transforms.v2 as torch_transforms
+from torchvision import tv_tensors
+from typing_extensions import override
+
+
+class Clamp(torch_transforms.Transform):
+    """Clamps all elements in input into a specific range."""
+
+    def __init__(self, out_range: Tuple[int, int]) -> None:
+        """Initializes the transform.
+
+        Args:
+            out_range: The lower and upper bound of the range to
+                be clamped to.
+        """
+        super().__init__()
+
+        self._out_range = out_range
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(torch.Tensor)
+    def _(self, inpt: torch.Tensor, params: Dict[str, Any]) -> Any:
+        return torch.clamp(inpt, min=self._out_range[0], max=self._out_range[1])
+
+    @_transform.register(tv_tensors.Image)
+    def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+        inpt_clamp = torch.clamp(inpt, min=self._out_range[0], max=self._out_range[1])
+        return tv_tensors.wrap(inpt_clamp, like=inpt)
+
+    @_transform.register(tv_tensors.BoundingBoxes)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: tv_tensors.BoundingBoxes | tv_tensors.Mask, params: Dict[str, Any]) -> Any:
+        return inpt
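The singledispatch registrations make the clamp type-aware: plain tensors and Images are clamped, while Masks and BoundingBoxes pass through unchanged. A minimal sketch:

import torch
from torchvision import tv_tensors

from eva.vision.data.transforms import Clamp

clamp = Clamp(out_range=(-1024, 1024))
image = tv_tensors.Image(torch.tensor([[[-2000.0, 0.0, 3000.0]]]))
mask = tv_tensors.Mask(torch.tensor([[[0, 1, 2]]]))
clamp(image)  # values become [-1024.0, 0.0, 1024.0]
clamp(mask)   # segmentation targets are left untouched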
eva/vision/data/transforms/normalization/functional/rescale_intensity.py
@@ -0,0 +1,28 @@
+"""Intensity level functions."""
+
+import sys
+from typing import Tuple
+
+import torch
+
+
+def rescale_intensity(
+    image: torch.Tensor,
+    in_range: Tuple[float, float] | None = None,
+    out_range: Tuple[float, float] = (0.0, 1.0),
+) -> torch.Tensor:
+    """Stretches or shrinks the image intensity levels.
+
+    Args:
+        image: The image tensor as float-type.
+        in_range: The input data range. If `None`, it will
+            fetch the min and max of the input image.
+        out_range: The desired intensity range of the output.
+
+    Returns:
+        The image tensor after stretching or shrinking its intensity levels.
+    """
+    imin, imax = in_range or (image.min(), image.max())
+    omin, omax = out_range
+    image_scaled = (image - imin) / (imax - imin + sys.float_info.epsilon)
+    return image_scaled * (omax - omin) + omin
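A quick numeric check of the function (the import path follows the new functional package):

import torch

from eva.vision.data.transforms.normalization.functional import rescale_intensity

x = torch.tensor([-1024.0, 0.0, 1024.0])
y = rescale_intensity(x, in_range=(-1024.0, 1024.0), out_range=(0.0, 1.0))
# y ≈ tensor([0.0, 0.5, 1.0]): (x - imin) / (imax - imin + eps), scaled into out_range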
eva/vision/data/transforms/normalization/rescale_intensity.py
@@ -0,0 +1,53 @@
+"""Intensity level scaling transform."""
+
+import functools
+from typing import Any, Dict, Tuple
+
+import torch
+import torchvision.transforms.v2 as torch_transforms
+from torchvision import tv_tensors
+from typing_extensions import override
+
+from eva.vision.data.transforms.normalization import functional
+
+
+class RescaleIntensity(torch_transforms.Transform):
+    """Stretches or shrinks the image intensity levels."""
+
+    def __init__(
+        self,
+        in_range: Tuple[float, float] | None = None,
+        out_range: Tuple[float, float] = (0.0, 1.0),
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            in_range: The input data range. If `None`, it will
+                fetch the min and max of the input image.
+            out_range: The desired intensity range of the output.
+        """
+        super().__init__()
+
+        self._in_range = in_range
+        self._out_range = out_range
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(torch.Tensor)
+    def _(self, inpt: torch.Tensor, params: Dict[str, Any]) -> Any:
+        return functional.rescale_intensity(
+            inpt, in_range=self._in_range, out_range=self._out_range
+        )
+
+    @_transform.register(tv_tensors.Image)
+    def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+        scaled_inpt = functional.rescale_intensity(inpt, out_range=self._out_range)
+        return tv_tensors.wrap(scaled_inpt, like=inpt)
+
+    @_transform.register(tv_tensors.BoundingBoxes)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: tv_tensors.BoundingBoxes | tv_tensors.Mask, params: Dict[str, Any]) -> Any:
+        return inpt
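Usage mirrors Clamp. Note that the tv_tensors.Image registration above does not forward in_range, so wrapped images are always rescaled from their own minimum and maximum. A minimal sketch on a plain tensor:

import torch

from eva.vision.data.transforms import RescaleIntensity

rescale = RescaleIntensity(in_range=(-1024.0, 1024.0), out_range=(0.0, 1.0))
rescale(torch.tensor([-1024.0, 512.0, 1024.0]))  # ≈ tensor([0.00, 0.75, 1.00])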
eva/vision/data/wsi/__init__.py
@@ -0,0 +1,16 @@
+"""WSI API."""
+
+from eva.vision.data.wsi.backends import Wsi, get_cached_wsi, wsi_backend
+from eva.vision.data.wsi.patching.coordinates import PatchCoordinates, get_cached_coords
+from eva.vision.data.wsi.patching.mask import Mask, get_mask, get_mask_level
+
+__all__ = [
+    "Wsi",
+    "PatchCoordinates",
+    "Mask",
+    "get_cached_coords",
+    "wsi_backend",
+    "get_cached_wsi",
+    "get_mask",
+    "get_mask_level",
+]
eva/vision/data/wsi/backends/__init__.py
@@ -0,0 +1,69 @@
+"""WSI Backends API."""
+
+import functools
+import importlib.util
+from typing import Callable
+
+from eva.vision.data.wsi.backends.base import Wsi
+
+LRU_CACHE_SIZE = 32
+
+
+def _is_openslide_available() -> bool:
+    """Whether the OpenSlide library is available."""
+    return importlib.util.find_spec("openslide") is not None
+
+
+def _is_tiffslide_available() -> bool:
+    """Whether the TiffSlide library is available."""
+    return importlib.util.find_spec("tiffslide") is not None
+
+
+def is_backend_available(backend: str) -> bool:
+    """Whether the specified backend is available."""
+    match backend:
+        case "openslide":
+            return _is_openslide_available()
+        case "tiffslide":
+            return _is_tiffslide_available()
+    return False
+
+
+def wsi_backend(backend: str = "openslide") -> Callable[..., Wsi]:
+    """Returns the backend to use for reading the whole-slide images."""
+    match backend:
+        case "openslide":
+            if _is_openslide_available():
+                from eva.vision.data.wsi.backends.openslide import WsiOpenslide
+
+                return WsiOpenslide
+            else:
+                raise ValueError(
+                    "Missing optional dependency: openslide.\n"
+                    "Please install using `pip install openslide-python`."
+                )
+        case "tiffslide":
+            if _is_tiffslide_available():
+                from eva.vision.data.wsi.backends.tiffslide import WsiTiffslide
+
+                return WsiTiffslide
+            else:
+                raise ValueError(
+                    "Missing optional dependency: tiffslide.\n"
+                    "Please install using `pip install tiffslide`."
+                )
+        case "pil":
+            from eva.vision.data.wsi.backends.pil import PILImage
+
+            return PILImage
+        case _:
+            raise ValueError(f"Unknown WSI backend selected: {backend}")
+
+
+@functools.lru_cache(LRU_CACHE_SIZE)
+def get_cached_wsi(file_path: str, backend: str, overwrite_mpp: float | None = None) -> Wsi:
+    """Returns a cached instance of the whole-slide image backend reader."""
+    return wsi_backend(backend)(file_path, overwrite_mpp)
+
+
+__all__ = ["Wsi", "wsi_backend", "get_cached_wsi", "_is_openslide_available"]
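A minimal sketch of resolving a backend class and opening a slide through the cache (the slide path is hypothetical):

from eva.vision.data import wsi

backend_cls = wsi.wsi_backend("openslide")  # raises ValueError if openslide-python is missing
slide = wsi.get_cached_wsi("/data/slides/case_0.svs", backend="openslide")
print(slide.mpp, slide.level_dimensions[0])

Repeated calls with the same arguments return the same reader instance via the functools.lru_cache.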
eva/vision/data/wsi/backends/base.py
@@ -0,0 +1,115 @@
+"""Base Module for loading data from WSI files."""
+
+import abc
+from typing import Any, Sequence, Tuple
+
+import numpy as np
+
+
+class Wsi(abc.ABC):
+    """Base class for loading data from Whole Slide Image (WSI) files."""
+
+    def __init__(self, file_path: str, overwrite_mpp: float | None = None):
+        """Initializes a Wsi object.
+
+        Args:
+            file_path: The path to the WSI file.
+            overwrite_mpp: The microns per pixel (mpp) value to use when missing in WSI metadata.
+        """
+        self._wsi = self.open_file(file_path)
+        self._overwrite_mpp = overwrite_mpp
+
+    @abc.abstractmethod
+    def open_file(self, file_path: str) -> Any:
+        """Opens the WSI file.
+
+        Args:
+            file_path: The path to the WSI file.
+        """
+
+    @property
+    @abc.abstractmethod
+    def level_dimensions(self) -> Sequence[Tuple[int, int]]:
+        """A list of (width, height) tuples for each level, from highest to lowest resolution."""
+
+    @property
+    @abc.abstractmethod
+    def level_downsamples(self) -> Sequence[float]:
+        """A list of downsampling factors for each level, relative to the highest resolution."""
+
+    @property
+    @abc.abstractmethod
+    def mpp(self) -> float:
+        """Microns per pixel at the highest resolution (level 0)."""
+
+    @abc.abstractmethod
+    def _read_region(
+        self, location: Tuple[int, int], level: int, size: Tuple[int, int]
+    ) -> np.ndarray:
+        """Abstract method to read a region at a specified zoom level."""
+
+    def read_region(
+        self, location: Tuple[int, int], level: int, size: Tuple[int, int]
+    ) -> np.ndarray:
+        """Reads and returns image data for a specified region and zoom level.
+
+        Args:
+            location: Top-left corner (x, y) to start reading at level 0.
+            level: WSI level to read from.
+            size: Region size as (width, height) in pixels at the selected read level.
+                Remember to scale the size correctly.
+        """
+        self._verify_location(location, size)
+        data = self._read_region(location, level, size)
+        return self._read_postprocess(data)
+
+    def get_closest_level(self, target_mpp: float) -> int:
+        """Calculates the slide level that is closest to the target mpp.
+
+        Args:
+            target_mpp: The target microns per pixel (mpp) value.
+        """
+        # Calculate the mpp for each level
+        level_mpps = self.mpp * np.array(self.level_downsamples)
+
+        # Ignore levels with higher mpp
+        level_mpps_filtered = level_mpps.copy()
+        level_mpps_filtered[level_mpps_filtered > target_mpp] = 0
+
+        if level_mpps_filtered.max() == 0:
+            # When all levels have higher mpp than target_mpp return the level with lowest mpp
+            level_idx = np.argmin(level_mpps)
+        else:
+            level_idx = np.argmax(level_mpps_filtered)
+
+        return int(level_idx)
+
+    def _verify_location(self, location: Tuple[int, int], size: Tuple[int, int]) -> None:
+        """Verifies that the requested region is within the slide dimensions.
+
+        Args:
+            location: Top-left corner (x, y) to start reading at level 0.
+            size: Region size as (width, height) in pixels at the selected read level.
+        """
+        x_max, y_max = self.level_dimensions[0]
+        x_scale = x_max / self.level_dimensions[0][0]
+        y_scale = y_max / self.level_dimensions[0][1]
+
+        if (
+            int(location[0] + x_scale * size[0]) > x_max
+            or int(location[1] + y_scale * size[1]) > y_max
+        ):
+            raise ValueError(f"Out of bounds region: {location}, {size}")
+
+    def _read_postprocess(self, data: np.ndarray) -> np.ndarray:
+        """Post-processes the read region data.
+
+        Args:
+            data: The read region data as a numpy array of shape (height, width, channels).
+        """
+        # Change color to white where the alpha channel is 0
+        if data.shape[2] == 4:
+            data[data[:, :, 3] == 0] = 255
+
+        return data[:, :, :3]
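A worked example of the selection rule in get_closest_level, using illustrative slide metadata:

import numpy as np

# A slide with mpp=0.25 at level 0 and downsamples (1.0, 4.0, 16.0) has
# per-level mpps (0.25, 1.0, 4.0). For target_mpp=0.5, levels with a higher
# mpp are zeroed out and the largest remaining mpp (here level 0) is chosen.
mpp, downsamples, target_mpp = 0.25, np.array([1.0, 4.0, 16.0]), 0.5
level_mpps = mpp * downsamples
filtered = level_mpps.copy()
filtered[filtered > target_mpp] = 0
level_idx = int(np.argmin(level_mpps)) if filtered.max() == 0 else int(np.argmax(filtered))
assert level_idx == 0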