kaiko-eva 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic.
- eva/core/callbacks/__init__.py +2 -2
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +2 -2
- eva/core/data/datasets/classification/__init__.py +8 -0
- eva/core/data/datasets/classification/embeddings.py +34 -0
- eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
- eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/loggers/experimental_loggers.py +2 -2
- eva/core/loggers/log/__init__.py +3 -2
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +10 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +10 -4
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/functional.py +1 -0
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +30 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +12 -1
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +16 -17
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +56 -13
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +33 -12
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.0.dist-info/METADATA +553 -0
- kaiko_eva-0.1.0.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
- eva/.DS_Store +0 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/data/datasets/embeddings/__init__.py +0 -13
- eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
- eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/models/.DS_Store +0 -0
- eva/vision/models/networks/.DS_Store +0 -0
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.2.dist-info/METADATA +0 -431
- kaiko_eva-0.0.2.dist-info/RECORD +0 -127
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py}

@@ -3,24 +3,30 @@
 import functools
 import os
 from glob import glob
-from typing import Callable, Dict, List, Literal, Tuple
+from typing import Any, Callable, Dict, List, Literal, Tuple
 
 import numpy as np
+import numpy.typing as npt
+import torch
+import tqdm
 from torchvision import tv_tensors
 from torchvision.datasets import utils
 from typing_extensions import override
 
-from eva.vision.data.datasets import
+from eva.vision.data.datasets import _validators, structs
 from eva.vision.data.datasets.segmentation import base
-from eva.vision.utils import
+from eva.vision.utils import io
 
 
 class TotalSegmentator2D(base.ImageSegmentation):
     """TotalSegmentator 2D segmentation dataset."""
 
     _expected_dataset_lengths: Dict[str, int] = {
-        "train_small":
-        "val_small":
+        "train_small": 35089,
+        "val_small": 1283,
+        "train_full": 278190,
+        "val_full": 14095,
+        "test_full": 25578,
     }
     """Dataset version and split to the expected size."""
 
@@ -45,13 +51,20 @@ class TotalSegmentator2D(base.ImageSegmentation):
     ]
     """Resources for the small dataset version."""
 
+    _license: str = (
+        "Creative Commons Attribution 4.0 International "
+        "(https://creativecommons.org/licenses/by/4.0/deed.en)"
+    )
+    """Dataset license."""
+
     def __init__(
         self,
         root: str,
-        split: Literal["train", "val"] | None,
-        version: Literal["small", "full"] | None = "
+        split: Literal["train", "val", "test"] | None,
+        version: Literal["small", "full"] | None = "full",
         download: bool = False,
-
+        classes: List[str] | None = None,
+        optimize_mask_loading: bool = True,
         transforms: Callable | None = None,
     ) -> None:
         """Initialize dataset.
@@ -66,7 +79,12 @@ class TotalSegmentator2D(base.ImageSegmentation):
                 Note that the download will be executed only by additionally
                 calling the :meth:`prepare_data` method and if the data does not
                 exist yet on disk.
-
+            classes: Whether to configure the dataset with a subset of classes.
+                If `None`, it will use all of them.
+            optimize_mask_loading: Whether to pre-process the segmentation masks
+                in order to optimize the loading time. In the `setup` method, it
+                will reformat the binary one-hot masks to a semantic mask and store
+                it on disk.
             transforms: A function/transforms that takes in an image and a target
                 mask and returns the transformed versions of both.
         """
@@ -76,7 +94,13 @@ class TotalSegmentator2D(base.ImageSegmentation):
         self._split = split
         self._version = version
         self._download = download
-        self.
+        self._classes = classes
+        self._optimize_mask_loading = optimize_mask_loading
+
+        if self._optimize_mask_loading and self._classes is not None:
+            raise ValueError(
+                "To use customize classes please set the optimize_mask_loading to `False`."
+            )
 
         self._samples_dirs: List[str] = []
         self._indices: List[Tuple[int, int]] = []
@@ -91,7 +115,13 @@ class TotalSegmentator2D(base.ImageSegmentation):
         first_sample_labels = os.path.join(
             self._root, self._samples_dirs[0], "segmentations", "*.nii.gz"
         )
-
+        all_classes = sorted(map(get_filename, glob(first_sample_labels)))
+        if self._classes:
+            is_subset = all(name in all_classes for name in self._classes)
+            if not is_subset:
+                raise ValueError("Provided class names are not subset of the dataset onces.")
+
+        return all_classes if self._classes is None else self._classes
 
     @property
     @override
@@ -99,7 +129,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
         return {label: index for index, label in enumerate(self.classes)}
 
     @override
-    def filename(self, index: int) -> str:
+    def filename(self, index: int, segmented: bool = True) -> str:
         sample_idx, _ = self._indices[index]
         sample_dir = self._samples_dirs[sample_idx]
         return os.path.join(sample_dir, "ct.nii.gz")
@@ -113,17 +143,23 @@ class TotalSegmentator2D(base.ImageSegmentation):
     def configure(self) -> None:
         self._samples_dirs = self._fetch_samples_dirs()
         self._indices = self._create_indices()
+        if self._optimize_mask_loading:
+            self._export_semantic_label_masks()
 
     @override
     def validate(self) -> None:
-        if self._version is None:
+        if self._version is None or self._sample_every_n_slices is not None:
             return
 
         _validators.check_dataset_integrity(
             self,
             length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0),
-            n_classes=117,
-            first_and_last_labels=(
+            n_classes=len(self._classes) if self._classes else 117,
+            first_and_last_labels=(
+                (self._classes[0], self._classes[-1])
+                if self._classes
+                else ("adrenal_gland_left", "vertebrae_T9")
+            ),
         )
 
     @override
@@ -134,25 +170,68 @@ class TotalSegmentator2D(base.ImageSegmentation):
     def load_image(self, index: int) -> tv_tensors.Image:
         sample_index, slice_index = self._indices[index]
         image_path = self._get_image_path(sample_index)
-        image_array = io.
-        if self._as_uint8:
-            image_array = convert.to_8bit(image_array)
+        image_array = io.read_nifti(image_path, slice_index)
         image_rgb_array = image_array.repeat(3, axis=2)
         return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1))
 
     @override
     def load_mask(self, index: int) -> tv_tensors.Mask:
+        if self._optimize_mask_loading:
+            return self._load_semantic_label_mask(index)
+        return self._load_mask(index)
+
+    @override
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        _, slice_index = self._indices[index]
+        return {"slice_index": slice_index}
+
+    def _load_mask(self, index: int) -> tv_tensors.Mask:
+        """Loads and builds the segmentation mask from NifTi files."""
+        sample_index, slice_index = self._indices[index]
+        semantic_labels = self._load_masks_as_semantic_label(sample_index, slice_index)
+        return tv_tensors.Mask(semantic_labels, dtype=torch.int64)  # type: ignore[reportCallIssue]
+
+    def _load_semantic_label_mask(self, index: int) -> tv_tensors.Mask:
+        """Loads the segmentation mask from a semantic label NifTi file."""
         sample_index, slice_index = self._indices[index]
         masks_dir = self._get_masks_dir(sample_index)
-
-
-
-
-
-
-
-
-
+        filename = os.path.join(masks_dir, "semantic_labels", "masks.nii.gz")
+        semantic_labels = io.read_nifti(filename, slice_index)
+        return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
+
+    def _load_masks_as_semantic_label(
+        self, sample_index: int, slice_index: int | None = None
+    ) -> npt.NDArray[Any]:
+        """Loads binary masks as a semantic label mask.
+
+        Args:
+            sample_index: The data sample index.
+            slice_index: Whether to return only a specific slice.
+        """
+        masks_dir = self._get_masks_dir(sample_index)
+        mask_paths = [os.path.join(masks_dir, label + ".nii.gz") for label in self.classes]
+        binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]
+        background_mask = np.zeros_like(binary_masks[0])
+        return np.argmax([background_mask] + binary_masks, axis=0)
+
+    def _export_semantic_label_masks(self) -> None:
+        """Exports the segmentation binary masks (one-hot) to semantic labels."""
+        total_samples = len(self._samples_dirs)
+        masks_dirs = map(self._get_masks_dir, range(total_samples))
+        semantic_labels = [
+            (index, os.path.join(directory, "semantic_labels", "masks.nii.gz"))
+            for index, directory in enumerate(masks_dirs)
+        ]
+        to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels)
+
+        for sample_index, filename in tqdm.tqdm(
+            list(to_export),
+            desc=">> Exporting optimized semantic masks",
+            leave=False,
+        ):
+            semantic_labels = self._load_masks_as_semantic_label(sample_index)
+            os.makedirs(os.path.dirname(filename), exist_ok=True)
+            io.save_array_as_nifti(semantic_labels, filename)
 
     def _get_image_path(self, sample_index: int) -> str:
         """Returns the corresponding image path."""
@@ -164,10 +243,16 @@ class TotalSegmentator2D(base.ImageSegmentation):
         sample_dir = self._samples_dirs[sample_index]
         return os.path.join(self._root, sample_dir, "segmentations")
 
+    def _get_semantic_labels_filename(self, sample_index: int) -> str:
+        """Returns the semantic label filename."""
+        masks_dir = self._get_masks_dir(sample_index)
+        return os.path.join(masks_dir, "semantic_labels", "masks.nii.gz")
+
     def _get_number_of_slices_per_sample(self, sample_index: int) -> int:
         """Returns the total amount of slices of a sample."""
        image_path = self._get_image_path(sample_index)
-
+        image_shape = io.fetch_nifti_shape(image_path)
+        return image_shape[-1]
 
     def _fetch_samples_dirs(self) -> List[str]:
         """Returns the name of all the samples of all the splits of the dataset."""
@@ -180,16 +265,20 @@ class TotalSegmentator2D(base.ImageSegmentation):
 
     def _get_split_indices(self) -> List[int]:
         """Returns the samples indices that corresponding the dataset split and version."""
-
-
-
-
-        case "
-
+        metadata_file = os.path.join(self._root, "meta.csv")
+        metadata = io.read_csv(metadata_file, delimiter=";", encoding="utf-8-sig")
+
+        match self._split:
+            case "train":
+                image_ids = [item["image_id"] for item in metadata if item["split"] == "train"]
+            case "val":
+                image_ids = [item["image_id"] for item in metadata if item["split"] == "val"]
+            case "test":
+                image_ids = [item["image_id"] for item in metadata if item["split"] == "test"]
             case _:
-
+                image_ids = self._samples_dirs
 
-        return
+        return sorted(map(self._samples_dirs.index, image_ids))
 
     def _create_indices(self) -> List[Tuple[int, int]]:
         """Builds the dataset indices for the specified split.
@@ -219,6 +308,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
                 f"Can't download data version '{self._version}'. Use 'small' or 'full'."
             )
 
+        self._print_license()
         for resource in resources:
             if os.path.isdir(self._root):
                 continue
@@ -229,3 +319,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
             filename=resource.filename,
             remove_finished=True,
        )
+
+    def _print_license(self) -> None:
+        """Prints the dataset license."""
+        print(f"Dataset license: {self._license}")
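The reworked TotalSegmentator2D dataset adds a classes subset option and an optimize_mask_loading flag that pre-exports per-sample semantic label masks during configure. Below is a minimal usage sketch, not part of the diff: the root path is a placeholder, and the explicit prepare_data()/configure() calls stand in for what eva's data module would normally invoke.

from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D

# Placeholder root; with optimize_mask_loading=True, configure() also writes
# <sample>/segmentations/semantic_labels/masks.nii.gz for each sample.
dataset = TotalSegmentator2D(
    root="data/total_segmentator",
    split="train",
    version="small",
    download=False,
    classes=None,                  # a class subset requires optimize_mask_loading=False
    optimize_mask_loading=True,
)
dataset.prepare_data()
dataset.configure()
image = dataset.load_image(0)        # tv_tensors.Image (CT slice repeated to 3 channels)
mask = dataset.load_mask(0)          # tv_tensors.Mask with integer semantic labels
metadata = dataset.load_metadata(0)  # {"slice_index": ...}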
eva/vision/data/datasets/wsi.py

@@ -0,0 +1,187 @@
+"""Dataset classes for whole-slide images."""
+
+import bisect
+import os
+from typing import Callable, List
+
+from loguru import logger
+from torch.utils.data import dataset as torch_datasets
+from torchvision import tv_tensors
+from torchvision.transforms.v2 import functional
+from typing_extensions import override
+
+from eva.vision.data import wsi
+from eva.vision.data.datasets import vision
+from eva.vision.data.wsi.patching import samplers
+
+
+class WsiDataset(vision.VisionDataset):
+    """Dataset class for reading patches from whole-slide images."""
+
+    def __init__(
+        self,
+        file_path: str,
+        width: int,
+        height: int,
+        sampler: samplers.Sampler,
+        target_mpp: float,
+        overwrite_mpp: float | None = None,
+        backend: str = "openslide",
+        image_transforms: Callable | None = None,
+    ):
+        """Initializes a new dataset instance.
+
+        Args:
+            file_path: Path to the whole-slide image file.
+            width: Width of the patches to be extracted, in pixels.
+            height: Height of the patches to be extracted, in pixels.
+            sampler: The sampler to use for sampling patch coordinates.
+            target_mpp: Target microns per pixel (mpp) for the patches.
+            overwrite_mpp: The microns per pixel (mpp) value to use when missing in WSI metadata.
+            backend: The backend to use for reading the whole-slide images.
+            image_transforms: Transforms to apply to the extracted image patches.
+        """
+        super().__init__()
+
+        self._file_path = file_path
+        self._width = width
+        self._height = height
+        self._sampler = sampler
+        self._target_mpp = target_mpp
+        self._overwrite_mpp = overwrite_mpp
+        self._backend = backend
+        self._image_transforms = image_transforms
+
+    @override
+    def __len__(self):
+        return len(self._coords.x_y)
+
+    @override
+    def filename(self, index: int) -> str:
+        return f"{self._file_path}_{index}"
+
+    @property
+    def _wsi(self) -> wsi.Wsi:
+        return wsi.get_cached_wsi(self._file_path, self._backend, self._overwrite_mpp)
+
+    @property
+    def _coords(self) -> wsi.PatchCoordinates:
+        return wsi.get_cached_coords(
+            file_path=self._file_path,
+            width=self._width,
+            height=self._height,
+            target_mpp=self._target_mpp,
+            overwrite_mpp=self._overwrite_mpp,
+            sampler=self._sampler,
+            backend=self._backend,
+        )
+
+    @override
+    def __getitem__(self, index: int) -> tv_tensors.Image:
+        x, y = self._coords.x_y[index]
+        width, height, level_idx = self._coords.width, self._coords.height, self._coords.level_idx
+        patch = self._wsi.read_region((x, y), level_idx, (width, height))
+        patch = functional.to_image(patch)
+        patch = self._apply_transforms(patch)
+        return patch
+
+    def _apply_transforms(self, image: tv_tensors.Image) -> tv_tensors.Image:
+        if self._image_transforms is not None:
+            image = self._image_transforms(image)
+        return image
+
+
+class MultiWsiDataset(vision.VisionDataset):
+    """Dataset class for reading patches from multiple whole-slide images."""
+
+    def __init__(
+        self,
+        root: str,
+        file_paths: List[str],
+        width: int,
+        height: int,
+        sampler: samplers.Sampler,
+        target_mpp: float,
+        overwrite_mpp: float | None = None,
+        backend: str = "openslide",
+        image_transforms: Callable | None = None,
+    ):
+        """Initializes a new dataset instance.
+
+        Args:
+            root: Root directory of the dataset.
+            file_paths: List of paths to the whole-slide image files, relative to the root.
+            width: Width of the patches to be extracted, in pixels.
+            height: Height of the patches to be extracted, in pixels.
+            target_mpp: Target microns per pixel (mpp) for the patches.
+            overwrite_mpp: The microns per pixel (mpp) value to use when missing in WSI metadata.
+            sampler: The sampler to use for sampling patch coordinates.
+            backend: The backend to use for reading the whole-slide images.
+            image_transforms: Transforms to apply to the extracted image patches.
+        """
+        super().__init__()
+
+        self._root = root
+        self._file_paths = file_paths
+        self._width = width
+        self._height = height
+        self._target_mpp = target_mpp
+        self._overwrite_mpp = overwrite_mpp
+        self._sampler = sampler
+        self._backend = backend
+        self._image_transforms = image_transforms
+
+        self._concat_dataset: torch_datasets.ConcatDataset
+
+    @property
+    def datasets(self) -> List[WsiDataset]:
+        """Returns the list of WSI datasets."""
+        return self._concat_dataset.datasets  # type: ignore
+
+    @property
+    def cumulative_sizes(self) -> List[int]:
+        """Returns the cumulative sizes of the WSI datasets."""
+        return self._concat_dataset.cumulative_sizes
+
+    @override
+    def configure(self) -> None:
+        self._concat_dataset = torch_datasets.ConcatDataset(datasets=self._load_datasets())
+
+    @override
+    def __len__(self) -> int:
+        return len(self._concat_dataset)
+
+    @override
+    def __getitem__(self, index: int) -> tv_tensors.Image:
+        return self._concat_dataset[index]
+
+    @override
+    def filename(self, index: int) -> str:
+        return os.path.basename(self._file_paths[self._get_dataset_idx(index)])
+
+    def _load_datasets(self) -> list[WsiDataset]:
+        logger.info(f"Initializing dataset with {len(self._file_paths)} WSIs ...")
+        wsi_datasets = []
+        for file_path in self._file_paths:
+            file_path = (
+                os.path.join(self._root, file_path) if self._root not in file_path else file_path
+            )
+            if not os.path.exists(file_path):
+                raise FileNotFoundError(f"File not found: {file_path}")
+
+            wsi_datasets.append(
+                WsiDataset(
+                    file_path=file_path,
+                    width=self._width,
+                    height=self._height,
+                    sampler=self._sampler,
+                    target_mpp=self._target_mpp,
+                    overwrite_mpp=self._overwrite_mpp,
+                    backend=self._backend,
+                    image_transforms=self._image_transforms,
+                )
+            )
+        return wsi_datasets
+
+    def _get_dataset_idx(self, index: int) -> int:
+        return bisect.bisect_right(self.cumulative_sizes, index)
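WsiDataset reads patches lazily through cached slide handles and patch coordinates, and MultiWsiDataset concatenates one WsiDataset per slide, mapping a flat index back to its slide via bisect over the cumulative sizes. A minimal sketch of reading one patch follows; the slide path and the GridSampler name and arguments are illustrative assumptions, since the diff only shows that a samplers package exists.

from eva.vision.data.datasets.wsi import WsiDataset
from eva.vision.data.wsi.patching import samplers

dataset = WsiDataset(
    file_path="slides/example.svs",      # placeholder slide path
    width=224,
    height=224,
    sampler=samplers.GridSampler(),      # hypothetical sampler; see the new samplers package
    target_mpp=0.5,
    backend="openslide",
)
print(len(dataset))   # number of sampled patch coordinates for this slide
patch = dataset[0]    # a tv_tensors.Image patch read via read_region()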
eva/vision/data/transforms/__init__.py

@@ -1,5 +1,6 @@
 """Vision data transforms."""
 
-from eva.vision.data.transforms.common import ResizeAndCrop
+from eva.vision.data.transforms.common import ResizeAndClamp, ResizeAndCrop
+from eva.vision.data.transforms.normalization import Clamp, RescaleIntensity
 
-__all__ = ["ResizeAndCrop"]
+__all__ = ["ResizeAndCrop", "ResizeAndClamp", "Clamp", "RescaleIntensity"]
eva/vision/data/transforms/common/__init__.py

@@ -1,5 +1,6 @@
 """Common vision transforms."""
 
+from eva.vision.data.transforms.common.resize_and_clamp import ResizeAndClamp
 from eva.vision.data.transforms.common.resize_and_crop import ResizeAndCrop
 
-__all__ = ["ResizeAndCrop"]
+__all__ = ["ResizeAndClamp", "ResizeAndCrop"]
eva/vision/data/transforms/common/resize_and_clamp.py

@@ -0,0 +1,51 @@
+"""Specialized transforms for resizing, clamping and range normalizing."""
+
+from typing import Callable, Sequence, Tuple
+
+from torchvision.transforms import v2
+
+from eva.vision.data.transforms import normalization
+
+
+class ResizeAndClamp(v2.Compose):
+    """Resizes, crops, clamps and normalizes an input image."""
+
+    def __init__(
+        self,
+        size: int | Sequence[int] = 224,
+        clamp_range: Tuple[int, int] = (-1024, 1024),
+        mean: Sequence[float] = (0.0, 0.0, 0.0),
+        std: Sequence[float] = (1.0, 1.0, 1.0),
+    ) -> None:
+        """Initializes the transform object.
+
+        Args:
+            size: Desired output size of the crop. If size is an `int` instead
+                of sequence like (h, w), a square crop (size, size) is made.
+            clamp_range: The lower and upper bound to clamp the pixel values.
+            mean: Sequence of means for each image channel.
+            std: Sequence of standard deviations for each image channel.
+        """
+        self._size = size
+        self._clamp_range = clamp_range
+        self._mean = mean
+        self._std = std
+
+        super().__init__(transforms=self._build_transforms())
+
+    def _build_transforms(self) -> Sequence[Callable]:
+        """Builds and returns the list of transforms."""
+        transforms = [
+            v2.Resize(size=self._size),
+            v2.CenterCrop(size=self._size),
+            normalization.Clamp(out_range=self._clamp_range),
+            normalization.RescaleIntensity(
+                in_range=self._clamp_range,
+                out_range=(0.0, 1.0),
+            ),
+            v2.Normalize(
+                mean=self._mean,
+                std=self._std,
+            ),
+        ]
+        return transforms
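ResizeAndClamp chains resize, center-crop, Clamp, RescaleIntensity and Normalize, which maps a clamped Hounsfield-unit range onto [0, 1] before channel normalization. A small sketch with a synthetic CT-like tensor (not part of the diff), assuming the default mean/std, which leave the rescaled values unchanged:

import torch
from torchvision import tv_tensors

from eva.vision.data.transforms import ResizeAndClamp

transform = ResizeAndClamp(size=224, clamp_range=(-1024, 1024))
ct_slice = tv_tensors.Image(torch.randint(-2048, 2048, (3, 512, 512)).float())
output = transform(ct_slice)  # 3 x 224 x 224, clamped and rescaled to [0, 1]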
eva/vision/data/transforms/common/resize_and_crop.py

@@ -3,10 +3,10 @@
 from typing import Callable, Sequence
 
 import torch
-
+from torchvision.transforms import v2
 
 
-class ResizeAndCrop(
+class ResizeAndCrop(v2.Compose):
     """Resizes, crops and normalizes an input image while preserving its aspect ratio."""
 
     def __init__(
@@ -32,11 +32,10 @@ class ResizeAndCrop(torch_transforms.Compose):
     def _build_transforms(self) -> Sequence[Callable]:
         """Builds and returns the list of transforms."""
         transforms = [
-
-
-
-
-            torch_transforms.Normalize(
+            v2.Resize(size=self._size),
+            v2.CenterCrop(size=self._size),
+            v2.ToDtype(torch.float32, scale=True),
+            v2.Normalize(
                 mean=self._mean,
                 std=self._std,
             ),
eva/vision/data/transforms/normalization/clamp.py

@@ -0,0 +1,43 @@
+"""Image clamp transform."""
+
+import functools
+from typing import Any, Dict, Tuple
+
+import torch
+import torchvision.transforms.v2 as torch_transforms
+from torchvision import tv_tensors
+from typing_extensions import override
+
+
+class Clamp(torch_transforms.Transform):
+    """Clamps all elements in input into a specific range."""
+
+    def __init__(self, out_range: Tuple[int, int]) -> None:
+        """Initializes the transform.
+
+        Args:
+            out_range: The lower and upper bound of the range to
+                be clamped to.
+        """
+        super().__init__()
+
+        self._out_range = out_range
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(torch.Tensor)
+    def _(self, inpt: torch.Tensor, params: Dict[str, Any]) -> Any:
+        return torch.clamp(inpt, min=self._out_range[0], max=self._out_range[1])
+
+    @_transform.register(tv_tensors.Image)
+    def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+        inpt_clamp = torch.clamp(inpt, min=self._out_range[0], max=self._out_range[1])
+        return tv_tensors.wrap(inpt_clamp, like=inpt)
+
+    @_transform.register(tv_tensors.BoundingBoxes)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: tv_tensors.BoundingBoxes | tv_tensors.Mask, params: Dict[str, Any]) -> Any:
+        return inpt
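Because _transform is a functools.singledispatchmethod, Clamp only clamps plain tensors and tv_tensors.Image inputs and passes masks and bounding boxes through untouched, so it can sit inside a joint image/target pipeline. A small illustrative sketch (not part of the diff):

import torch
from torchvision import tv_tensors

from eva.vision.data.transforms import Clamp

clamp = Clamp(out_range=(-1024, 1024))
image = tv_tensors.Image(torch.tensor([[[-2000.0, 0.0, 3000.0]]]))
mask = tv_tensors.Mask(torch.tensor([[[0, 1, 2]]]))
clamped_image, same_mask = clamp(image, mask)  # image -> [-1024., 0., 1024.]; mask unchanged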
eva/vision/data/transforms/normalization/functional/rescale_intensity.py

@@ -0,0 +1,28 @@
+"""Intensity level functions."""
+
+import sys
+from typing import Tuple
+
+import torch
+
+
+def rescale_intensity(
+    image: torch.Tensor,
+    in_range: Tuple[float, float] | None = None,
+    out_range: Tuple[float, float] = (0.0, 1.0),
+) -> torch.Tensor:
+    """Stretches or shrinks the image intensity levels.
+
+    Args:
+        image: The image tensor as float-type.
+        in_range: The input data range. If `None`, it will
+            fetch the min and max of the input image.
+        out_range: The desired intensity range of the output.
+
+    Returns:
+        The image tensor after stretching or shrinking its intensity levels.
+    """
+    imin, imax = in_range or (image.min(), image.max())
+    omin, omax = out_range
+    image_scaled = (image - imin) / (imax - imin + sys.float_info.epsilon)
+    return image_scaled * (omax - omin) + omin
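The function is a standard min-max rescale: scaled = (x - imin) / (imax - imin + eps), then scaled * (omax - omin) + omin. A quick numeric check (not part of the diff), importing from the module path shown above:

import torch

from eva.vision.data.transforms.normalization.functional.rescale_intensity import (
    rescale_intensity,
)

x = torch.tensor([-1024.0, 0.0, 1024.0])
y = rescale_intensity(x, in_range=(-1024, 1024), out_range=(0.0, 1.0))
print(y)  # approximately tensor([0.0000, 0.5000, 1.0000])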