kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kaiko-eva might be problematic.
- eva/core/callbacks/__init__.py +3 -2
- eva/core/callbacks/config.py +143 -0
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +10 -2
- eva/core/data/datasets/classification/__init__.py +5 -2
- eva/core/data/datasets/classification/embeddings.py +15 -135
- eva/core/data/datasets/classification/multi_embeddings.py +110 -0
- eva/core/data/datasets/embeddings.py +167 -0
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/data/transforms/__init__.py +3 -1
- eva/core/data/transforms/padding/__init__.py +5 -0
- eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
- eva/core/data/transforms/sampling/__init__.py +5 -0
- eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
- eva/core/loggers/__init__.py +7 -0
- eva/core/loggers/dummy.py +38 -0
- eva/core/loggers/experimental_loggers.py +8 -0
- eva/core/loggers/log/__init__.py +6 -0
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +74 -0
- eva/core/loggers/log/utils.py +13 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +16 -15
- eva/core/models/modules/module.py +25 -1
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/_recorder.py +69 -7
- eva/core/trainers/functional.py +23 -5
- eva/core/trainers/trainer.py +20 -6
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +24 -4
- eva/vision/data/datasets/_utils.py +3 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +6 -2
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +31 -47
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +67 -0
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +40 -15
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.0.dist-info/METADATA +553 -0
- kaiko_eva-0.1.0.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/data/datasets/classification/total_segmentator.py +0 -213
- eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.1.dist-info/METADATA +0 -405
- kaiko_eva-0.0.1.dist-info/RECORD +0 -110
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
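Much of the structural churn above is the relocation of the model wrapper sub-package out of eva.core.models.networks. A minimal import sketch of that change, inferred only from the renamed paths in the list (any additional re-exports are not confirmed by this diff):

# In 0.0.1 these modules were imported from eva.core.models.networks.wrappers;
# in 0.1.0 the sub-package lives directly under eva.core.models
# (see the {networks/wrappers -> wrappers} renames above).
from eva.core.models import wrappers  # contains the base, from_function, huggingface and onnx modules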
@@ -0,0 +1,236 @@ eva/vision/data/datasets/segmentation/monusac.py (new file)

"""MoNuSAC dataset."""

import functools
import glob
import os
from typing import Any, Callable, Dict, List, Literal
from xml.etree import ElementTree  # nosec

import imagesize
import numpy as np
import numpy.typing as npt
import torch
import tqdm
from skimage import draw
from torchvision import tv_tensors
from torchvision.datasets import utils
from typing_extensions import override

from eva.vision.data.datasets import _validators, structs
from eva.vision.data.datasets.segmentation import base
from eva.vision.utils import io


class MoNuSAC(base.ImageSegmentation):
    """MoNuSAC2020: A Multi-organ Nuclei Segmentation and Classification Challenge.

    Webpage: https://monusac-2020.grand-challenge.org/
    """

    _expected_dataset_lengths: Dict[str, int] = {
        "train": 209,
        "test": 85,
    }
    """Dataset version and split to the expected size."""

    _resources: List[structs.DownloadResource] = [
        structs.DownloadResource(
            filename="MoNuSAC_images_and_annotations.zip",
            url="https://drive.google.com/file/d/1lxMZaAPSpEHLSxGA9KKMt_r-4S8dwLhq/view?usp=sharing",
        ),
        structs.DownloadResource(
            filename="MoNuSAC Testing Data and Annotations.zip",
            url="https://drive.google.com/file/d/1G54vsOdxWY1hG7dzmkeK3r0xz9s-heyQ/view?usp=sharing",
        ),
    ]
    """Resources for the full dataset version."""

    _license: str = (
        "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International "
        "(https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode)"
    )
    """Dataset license."""

    def __init__(
        self,
        root: str,
        split: Literal["train", "test"],
        export_masks: bool = True,
        download: bool = False,
        transforms: Callable | None = None,
    ) -> None:
        """Initialize dataset.

        Args:
            root: Path to the root directory of the dataset. The dataset will
                be downloaded and extracted here, if it does not already exist.
            split: Dataset split to use.
            export_masks: Whether to export, save and use the semantic label masks
                from disk.
            download: Whether to download the data for the specified split.
                Note that the download will be executed only by additionally
                calling the :meth:`prepare_data` method and if the data does not
                exist yet on disk.
            transforms: A function/transforms that takes in an image and a target
                mask and returns the transformed versions of both.
        """
        super().__init__(transforms=transforms)

        self._root = root
        self._split = split
        self._export_masks = export_masks
        self._download = download

    @property
    @override
    def classes(self) -> List[str]:
        return ["Epithelial", "Lymphocyte", "Neutrophil", "Macrophage"]

    @functools.cached_property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        return {label: index for index, label in enumerate(self.classes)}

    @override
    def filename(self, index: int) -> str:
        return os.path.relpath(self._image_files[index], self._root)

    @override
    def prepare_data(self) -> None:
        if self._download:
            self._download_dataset()
        if self._export_masks:
            self._export_semantic_label_masks()

    @override
    def validate(self) -> None:
        _validators.check_dataset_integrity(
            self,
            length=self._expected_dataset_lengths.get(self._split, 0),
            n_classes=4,
            first_and_last_labels=("Epithelial", "Macrophage"),
        )

    @override
    def load_image(self, index: int) -> tv_tensors.Image:
        image_path = self._image_files[index]
        image_rgb_array = io.read_image(image_path)
        return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1))

    @override
    def load_mask(self, index: int) -> tv_tensors.Mask:
        semantic_labels = (
            self._load_semantic_mask_file(index)
            if self._export_masks
            else self._get_semantic_mask(index)
        )
        return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]

    @override
    def __len__(self) -> int:
        return len(self._image_files)

    @functools.cached_property
    def _image_files(self) -> List[str]:
        """Return the list of image files in the dataset.

        Returns:
            List of image file paths.
        """
        files_pattern = os.path.join(self._data_directory, "**", "*.tif")
        image_files = glob.glob(files_pattern, recursive=True)
        return sorted(image_files)

    @functools.cached_property
    def _data_directory(self) -> str:
        """Returns the data directory of the dataset."""
        match self._split:
            case "train":
                directory = "MoNuSAC_images_and_annotations"
            case "test":
                directory = "MoNuSAC Testing Data and Annotations"
            case _:
                raise ValueError(f"Invalid 'split' value '{self._split}'.")

        return os.path.join(self._root, directory)

    def _export_semantic_label_masks(self) -> None:
        """Export semantic label masks to disk."""
        mask_files = [
            (index, filename.replace(".tif", ".npy"))
            for index, filename in enumerate(self._image_files)
        ]
        to_export = filter(lambda x: not os.path.isfile(x[1]), mask_files)
        for sample_index, filename in tqdm.tqdm(
            list(to_export),
            desc=">> Exporting semantic masks",
            leave=False,
        ):
            semantic_labels = self._get_semantic_mask(sample_index)
            np.save(filename, semantic_labels)

    def _load_semantic_mask_file(self, index: int) -> npt.NDArray[Any]:
        """Load a semantic mask file from disk.

        Args:
            index: Index of the mask file to load.

        Returns:
            Loaded mask as a numpy array.
        """
        mask_filename = self._image_files[index].replace(".tif", ".npy")
        return np.load(mask_filename)

    def _get_semantic_mask(self, index: int) -> npt.NDArray[Any]:
        """Builds and loads the semantic label mask from the XML annotations.

        Args:
            index: Index of the image file.

        Returns:
            Semantic label mask as a numpy array.
        """
        image_path = self._image_files[index]
        width, height = imagesize.get(image_path)
        annotation_path = image_path.replace(".tif", ".xml")
        element_tree = ElementTree.parse(annotation_path)  # nosec
        root = element_tree.getroot()

        semantic_labels = np.zeros((height, width), "uint8")  # type: ignore[reportCallIssue]
        for level in range(len(root)):
            label = [item.attrib["Name"] for item in root[level][0]][0]
            class_id = self.class_to_idx.get(label, 254) + 1
            # for the test dataset an additional class 'Ambiguous' was added for
            # difficult regions with fuzzy boundaries - we return it as 255
            regions = [item for child in root[level] for item in child if item.tag == "Region"]
            for region in regions:
                vertices = np.array(
                    [(vertex.attrib["X"], vertex.attrib["Y"]) for vertex in region[1]],
                    dtype=np.dtype(float),
                )
                fill_row_coords, fill_col_coords = draw.polygon(
                    vertices[:, 0],
                    vertices[:, 1],
                    (width, height),
                )
                semantic_labels[fill_col_coords, fill_row_coords] = class_id

        return semantic_labels

    def _download_dataset(self) -> None:
        """Downloads the dataset."""
        self._print_license()
        for resource in self._resources:
            if os.path.isdir(self._data_directory):
                continue

            utils.download_and_extract_archive(
                resource.url,
                download_root=self._root,
                filename=resource.filename,
                remove_finished=True,
            )

    def _print_license(self) -> None:
        """Prints the dataset license."""
        print(f"Dataset license: {self._license}")
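A minimal usage sketch for the MoNuSAC dataset added above. The root path is hypothetical and the surrounding datamodule/trainer wiring from eva is not part of this diff; only methods shown in the hunk are called here.

from eva.vision.data.datasets.segmentation.monusac import MoNuSAC

dataset = MoNuSAC(
    root="./data/monusac",   # assumed local path; archives are downloaded and extracted here
    split="train",
    export_masks=True,       # pre-computes semantic .npy masks next to the .tif images
    download=True,           # downloads only once prepare_data() is called
)
dataset.prepare_data()       # fetches the archives (if missing) and exports the semantic masks
dataset.validate()           # checks the expected length (209 train / 85 test samples)

image = dataset.load_image(0)  # tv_tensors.Image in CHW layout
mask = dataset.load_mask(0)    # tv_tensors.Mask: 0 = background, 1-4 = classes, 255 = ambiguous (test split)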
@@ -0,0 +1,325 @@ eva/vision/data/datasets/segmentation/total_segmentator_2d.py (new file)

"""TotalSegmentator 2D segmentation dataset class."""

import functools
import os
from glob import glob
from typing import Any, Callable, Dict, List, Literal, Tuple

import numpy as np
import numpy.typing as npt
import torch
import tqdm
from torchvision import tv_tensors
from torchvision.datasets import utils
from typing_extensions import override

from eva.vision.data.datasets import _validators, structs
from eva.vision.data.datasets.segmentation import base
from eva.vision.utils import io


class TotalSegmentator2D(base.ImageSegmentation):
    """TotalSegmentator 2D segmentation dataset."""

    _expected_dataset_lengths: Dict[str, int] = {
        "train_small": 35089,
        "val_small": 1283,
        "train_full": 278190,
        "val_full": 14095,
        "test_full": 25578,
    }
    """Dataset version and split to the expected size."""

    _sample_every_n_slices: int | None = None
    """The amount of slices to sub-sample per 3D CT scan image."""

    _resources_full: List[structs.DownloadResource] = [
        structs.DownloadResource(
            filename="Totalsegmentator_dataset_v201.zip",
            url="https://zenodo.org/records/10047292/files/Totalsegmentator_dataset_v201.zip",
            md5="fe250e5718e0a3b5df4c4ea9d58a62fe",
        ),
    ]
    """Resources for the full dataset version."""

    _resources_small: List[structs.DownloadResource] = [
        structs.DownloadResource(
            filename="Totalsegmentator_dataset_small_v201.zip",
            url="https://zenodo.org/records/10047263/files/Totalsegmentator_dataset_small_v201.zip",
            md5="6b5524af4b15e6ba06ef2d700c0c73e0",
        ),
    ]
    """Resources for the small dataset version."""

    _license: str = (
        "Creative Commons Attribution 4.0 International "
        "(https://creativecommons.org/licenses/by/4.0/deed.en)"
    )
    """Dataset license."""

    def __init__(
        self,
        root: str,
        split: Literal["train", "val", "test"] | None,
        version: Literal["small", "full"] | None = "full",
        download: bool = False,
        classes: List[str] | None = None,
        optimize_mask_loading: bool = True,
        transforms: Callable | None = None,
    ) -> None:
        """Initialize dataset.

        Args:
            root: Path to the root directory of the dataset. The dataset will
                be downloaded and extracted here, if it does not already exist.
            split: Dataset split to use. If `None`, the entire dataset is used.
            version: The version of the dataset to initialize. If `None`, it will
                use the files located at root as is and wont perform any checks.
            download: Whether to download the data for the specified split.
                Note that the download will be executed only by additionally
                calling the :meth:`prepare_data` method and if the data does not
                exist yet on disk.
            classes: Whether to configure the dataset with a subset of classes.
                If `None`, it will use all of them.
            optimize_mask_loading: Whether to pre-process the segmentation masks
                in order to optimize the loading time. In the `setup` method, it
                will reformat the binary one-hot masks to a semantic mask and store
                it on disk.
            transforms: A function/transforms that takes in an image and a target
                mask and returns the transformed versions of both.
        """
        super().__init__(transforms=transforms)

        self._root = root
        self._split = split
        self._version = version
        self._download = download
        self._classes = classes
        self._optimize_mask_loading = optimize_mask_loading

        if self._optimize_mask_loading and self._classes is not None:
            raise ValueError(
                "To use customize classes please set the optimize_mask_loading to `False`."
            )

        self._samples_dirs: List[str] = []
        self._indices: List[Tuple[int, int]] = []

    @functools.cached_property
    @override
    def classes(self) -> List[str]:
        def get_filename(path: str) -> str:
            """Returns the filename from the full path."""
            return os.path.basename(path).split(".")[0]

        first_sample_labels = os.path.join(
            self._root, self._samples_dirs[0], "segmentations", "*.nii.gz"
        )
        all_classes = sorted(map(get_filename, glob(first_sample_labels)))
        if self._classes:
            is_subset = all(name in all_classes for name in self._classes)
            if not is_subset:
                raise ValueError("Provided class names are not subset of the dataset onces.")

        return all_classes if self._classes is None else self._classes

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        return {label: index for index, label in enumerate(self.classes)}

    @override
    def filename(self, index: int, segmented: bool = True) -> str:
        sample_idx, _ = self._indices[index]
        sample_dir = self._samples_dirs[sample_idx]
        return os.path.join(sample_dir, "ct.nii.gz")

    @override
    def prepare_data(self) -> None:
        if self._download:
            self._download_dataset()

    @override
    def configure(self) -> None:
        self._samples_dirs = self._fetch_samples_dirs()
        self._indices = self._create_indices()
        if self._optimize_mask_loading:
            self._export_semantic_label_masks()

    @override
    def validate(self) -> None:
        if self._version is None or self._sample_every_n_slices is not None:
            return

        _validators.check_dataset_integrity(
            self,
            length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0),
            n_classes=len(self._classes) if self._classes else 117,
            first_and_last_labels=(
                (self._classes[0], self._classes[-1])
                if self._classes
                else ("adrenal_gland_left", "vertebrae_T9")
            ),
        )

    @override
    def __len__(self) -> int:
        return len(self._indices)

    @override
    def load_image(self, index: int) -> tv_tensors.Image:
        sample_index, slice_index = self._indices[index]
        image_path = self._get_image_path(sample_index)
        image_array = io.read_nifti(image_path, slice_index)
        image_rgb_array = image_array.repeat(3, axis=2)
        return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1))

    @override
    def load_mask(self, index: int) -> tv_tensors.Mask:
        if self._optimize_mask_loading:
            return self._load_semantic_label_mask(index)
        return self._load_mask(index)

    @override
    def load_metadata(self, index: int) -> Dict[str, Any]:
        _, slice_index = self._indices[index]
        return {"slice_index": slice_index}

    def _load_mask(self, index: int) -> tv_tensors.Mask:
        """Loads and builds the segmentation mask from NifTi files."""
        sample_index, slice_index = self._indices[index]
        semantic_labels = self._load_masks_as_semantic_label(sample_index, slice_index)
        return tv_tensors.Mask(semantic_labels, dtype=torch.int64)  # type: ignore[reportCallIssue]

    def _load_semantic_label_mask(self, index: int) -> tv_tensors.Mask:
        """Loads the segmentation mask from a semantic label NifTi file."""
        sample_index, slice_index = self._indices[index]
        masks_dir = self._get_masks_dir(sample_index)
        filename = os.path.join(masks_dir, "semantic_labels", "masks.nii.gz")
        semantic_labels = io.read_nifti(filename, slice_index)
        return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]

    def _load_masks_as_semantic_label(
        self, sample_index: int, slice_index: int | None = None
    ) -> npt.NDArray[Any]:
        """Loads binary masks as a semantic label mask.

        Args:
            sample_index: The data sample index.
            slice_index: Whether to return only a specific slice.
        """
        masks_dir = self._get_masks_dir(sample_index)
        mask_paths = [os.path.join(masks_dir, label + ".nii.gz") for label in self.classes]
        binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]
        background_mask = np.zeros_like(binary_masks[0])
        return np.argmax([background_mask] + binary_masks, axis=0)

    def _export_semantic_label_masks(self) -> None:
        """Exports the segmentation binary masks (one-hot) to semantic labels."""
        total_samples = len(self._samples_dirs)
        masks_dirs = map(self._get_masks_dir, range(total_samples))
        semantic_labels = [
            (index, os.path.join(directory, "semantic_labels", "masks.nii.gz"))
            for index, directory in enumerate(masks_dirs)
        ]
        to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels)

        for sample_index, filename in tqdm.tqdm(
            list(to_export),
            desc=">> Exporting optimized semantic masks",
            leave=False,
        ):
            semantic_labels = self._load_masks_as_semantic_label(sample_index)
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            io.save_array_as_nifti(semantic_labels, filename)

    def _get_image_path(self, sample_index: int) -> str:
        """Returns the corresponding image path."""
        sample_dir = self._samples_dirs[sample_index]
        return os.path.join(self._root, sample_dir, "ct.nii.gz")

    def _get_masks_dir(self, sample_index: int) -> str:
        """Returns the directory of the corresponding masks."""
        sample_dir = self._samples_dirs[sample_index]
        return os.path.join(self._root, sample_dir, "segmentations")

    def _get_semantic_labels_filename(self, sample_index: int) -> str:
        """Returns the semantic label filename."""
        masks_dir = self._get_masks_dir(sample_index)
        return os.path.join(masks_dir, "semantic_labels", "masks.nii.gz")

    def _get_number_of_slices_per_sample(self, sample_index: int) -> int:
        """Returns the total amount of slices of a sample."""
        image_path = self._get_image_path(sample_index)
        image_shape = io.fetch_nifti_shape(image_path)
        return image_shape[-1]

    def _fetch_samples_dirs(self) -> List[str]:
        """Returns the name of all the samples of all the splits of the dataset."""
        sample_filenames = [
            filename
            for filename in os.listdir(self._root)
            if os.path.isdir(os.path.join(self._root, filename))
        ]
        return sorted(sample_filenames)

    def _get_split_indices(self) -> List[int]:
        """Returns the samples indices that corresponding the dataset split and version."""
        metadata_file = os.path.join(self._root, "meta.csv")
        metadata = io.read_csv(metadata_file, delimiter=";", encoding="utf-8-sig")

        match self._split:
            case "train":
                image_ids = [item["image_id"] for item in metadata if item["split"] == "train"]
            case "val":
                image_ids = [item["image_id"] for item in metadata if item["split"] == "val"]
            case "test":
                image_ids = [item["image_id"] for item in metadata if item["split"] == "test"]
            case _:
                image_ids = self._samples_dirs

        return sorted(map(self._samples_dirs.index, image_ids))

    def _create_indices(self) -> List[Tuple[int, int]]:
        """Builds the dataset indices for the specified split.

        Returns:
            A list of tuples, where the first value indicates the
            sample index which the second its corresponding slice
            index.
        """
        indices = [
            (sample_idx, slide_idx)
            for sample_idx in self._get_split_indices()
            for slide_idx in range(self._get_number_of_slices_per_sample(sample_idx))
            if slide_idx % (self._sample_every_n_slices or 1) == 0
        ]
        return indices

    def _download_dataset(self) -> None:
        """Downloads the dataset."""
        dataset_resources = {
            "small": self._resources_small,
            "full": self._resources_full,
        }
        resources = dataset_resources.get(self._version or "")
        if resources is None:
            raise ValueError(
                f"Can't download data version '{self._version}'. Use 'small' or 'full'."
            )

        self._print_license()
        for resource in resources:
            if os.path.isdir(self._root):
                continue

            utils.download_and_extract_archive(
                resource.url,
                download_root=self._root,
                filename=resource.filename,
                remove_finished=True,
            )

    def _print_license(self) -> None:
        """Prints the dataset license."""
        print(f"Dataset license: {self._license}")
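A minimal usage sketch for TotalSegmentator2D as added above. The root path and the class-name subset are hypothetical (class names come from the per-sample "segmentations" folder, which is not part of this diff); the prepare_data/configure hooks are normally invoked by eva's datamodule rather than called by hand.

from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D

dataset = TotalSegmentator2D(
    root="./data/total_segmentator",   # assumed local path; the Zenodo archive is extracted here
    split="val",
    version="small",                   # 1283 validation slices expected for this version
    download=True,
    classes=["liver", "spleen"],       # hypothetical subset; requires optimize_mask_loading=False
    optimize_mask_loading=False,
)
dataset.prepare_data()   # downloads and extracts the archive if it is missing
dataset.configure()      # scans the sample directories and builds the (sample, slice) indices

image = dataset.load_image(0)    # one CT slice repeated to 3 channels, CHW
mask = dataset.load_mask(0)      # semantic mask assembled from the per-class NIfTI files
meta = dataset.load_metadata(0)  # {"slice_index": ...}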