kaiko-eva 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic. Click here for more details.
- eva/core/data/datasets/base.py +7 -2
- eva/core/models/modules/head.py +4 -2
- eva/core/models/modules/typings.py +2 -2
- eva/core/models/transforms/__init__.py +2 -1
- eva/core/models/transforms/as_discrete.py +57 -0
- eva/core/models/wrappers/_utils.py +121 -1
- eva/core/trainers/_recorder.py +4 -1
- eva/core/utils/suppress_logs.py +28 -0
- eva/vision/data/__init__.py +2 -2
- eva/vision/data/dataloaders/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
- eva/vision/data/datasets/__init__.py +2 -2
- eva/vision/data/datasets/classification/bach.py +3 -4
- eva/vision/data/datasets/classification/bracs.py +3 -4
- eva/vision/data/datasets/classification/breakhis.py +3 -4
- eva/vision/data/datasets/classification/camelyon16.py +4 -5
- eva/vision/data/datasets/classification/crc.py +3 -4
- eva/vision/data/datasets/classification/gleason_arvaniti.py +3 -4
- eva/vision/data/datasets/classification/mhist.py +3 -4
- eva/vision/data/datasets/classification/panda.py +4 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
- eva/vision/data/datasets/classification/unitopatho.py +3 -4
- eva/vision/data/datasets/classification/wsi.py +6 -5
- eva/vision/data/datasets/segmentation/__init__.py +2 -2
- eva/vision/data/datasets/segmentation/_utils.py +47 -0
- eva/vision/data/datasets/segmentation/bcss.py +7 -8
- eva/vision/data/datasets/segmentation/btcv.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +6 -7
- eva/vision/data/datasets/segmentation/lits.py +9 -8
- eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
- eva/vision/data/datasets/segmentation/monusac.py +4 -5
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
- eva/vision/data/datasets/vision.py +95 -4
- eva/vision/data/datasets/wsi.py +5 -5
- eva/vision/data/transforms/__init__.py +22 -3
- eva/vision/data/transforms/common/__init__.py +1 -2
- eva/vision/data/transforms/croppad/__init__.py +11 -0
- eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
- eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
- eva/vision/data/transforms/intensity/__init__.py +11 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
- eva/vision/data/transforms/spatial/__init__.py +7 -0
- eva/vision/data/transforms/spatial/flip.py +72 -0
- eva/vision/data/transforms/spatial/rotate.py +53 -0
- eva/vision/data/transforms/spatial/spacing.py +69 -0
- eva/vision/data/transforms/utility/__init__.py +5 -0
- eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
- eva/vision/data/tv_tensors/__init__.py +5 -0
- eva/vision/data/tv_tensors/volume.py +61 -0
- eva/vision/metrics/segmentation/monai_dice.py +9 -2
- eva/vision/models/modules/semantic_segmentation.py +32 -19
- eva/vision/models/networks/backbones/__init__.py +9 -2
- eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
- eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
- eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
- eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
- eva/vision/models/networks/backbones/radiology/voco.py +75 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
- eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
- eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
- eva/vision/utils/io/__init__.py +2 -0
- eva/vision/utils/io/nifti.py +91 -11
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/METADATA +16 -12
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/RECORD +74 -58
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/classification/base.py +0 -96
- eva/vision/data/datasets/segmentation/base.py +0 -96
- eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
- eva/vision/data/transforms/normalization/__init__.py +0 -6
- eva/vision/data/transforms/normalization/clamp.py +0 -43
- eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
- eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
- eva/vision/metrics/segmentation/BUILD +0 -1
- eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
- eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Segmentation datasets API."""
|
|
2
2
|
|
|
3
|
-
from eva.vision.data.datasets.segmentation.base import ImageSegmentation
|
|
4
3
|
from eva.vision.data.datasets.segmentation.bcss import BCSS
|
|
4
|
+
from eva.vision.data.datasets.segmentation.btcv import BTCV
|
|
5
5
|
from eva.vision.data.datasets.segmentation.consep import CoNSeP
|
|
6
6
|
from eva.vision.data.datasets.segmentation.embeddings import EmbeddingsSegmentationDataset
|
|
7
7
|
from eva.vision.data.datasets.segmentation.lits import LiTS
|
|
@@ -10,8 +10,8 @@ from eva.vision.data.datasets.segmentation.monusac import MoNuSAC
|
|
|
10
10
|
from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D
|
|
11
11
|
|
|
12
12
|
__all__ = [
|
|
13
|
-
"ImageSegmentation",
|
|
14
13
|
"BCSS",
|
|
14
|
+
"BTCV",
|
|
15
15
|
"CoNSeP",
|
|
16
16
|
"EmbeddingsSegmentationDataset",
|
|
17
17
|
"LiTS",
|
|
@@ -1,8 +1,12 @@
|
|
|
1
1
|
from typing import Any, Tuple
|
|
2
2
|
|
|
3
3
|
import numpy.typing as npt
|
|
4
|
+
import torch
|
|
5
|
+
from torchvision import tv_tensors
|
|
4
6
|
|
|
7
|
+
from eva.vision.data import tv_tensors as eva_tv_tensors
|
|
5
8
|
from eva.vision.data.datasets import wsi
|
|
9
|
+
from eva.vision.utils import io
|
|
6
10
|
|
|
7
11
|
|
|
8
12
|
def get_coords_at_index(
|
|
@@ -36,3 +40,46 @@ def extract_mask_patch(
|
|
|
36
40
|
"""
|
|
37
41
|
(x, y), width, height = get_coords_at_index(dataset, index)
|
|
38
42
|
return mask[y : y + height, x : x + width]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def load_volume_tensor(file: str, orientation: str = "PLS") -> eva_tv_tensors.Volume:
    """Reads a NIfTI file and wraps it as an :class:`eva.vision.data.tv_tensors.Volume`.

    Args:
        file: Path of the NIfTI file to read.
        orientation: Orientation code the image is reoriented to while reading.

    Returns:
        A float32 volume tensor of shape `[T, C, H, W]`.

    Raises:
        ValueError: If the NIfTI image carries no affine matrix.
    """
    nifti_image = io.read_nifti(file, orientation=orientation)
    voxels = io.nifti_to_array(nifti_image)
    # Assumes a 3D array: add a leading channel axis, then move the last
    # spatial axis to the front to obtain the [T, C, H, W] layout.
    voxels_tchw = voxels[None, :, :, :].transpose(3, 0, 1, 2)

    if nifti_image.affine is None:
        raise ValueError(f"Affine matrix is missing for {file}.")
    # Permute the affine's spatial columns the same way the axes were permuted
    # above (original axis 2 first, then axes 0 and 1); the translation column stays last.
    affine_matrix = torch.tensor(nifti_image.affine[:, [2, 0, 1, 3]], dtype=torch.float32)

    return eva_tv_tensors.Volume(voxels_tchw, affine=affine_matrix, dtype=torch.float32)  # type: ignore
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def load_mask_tensor(
    file: str, volume_file: str | None = None, orientation: str = "PLS"
) -> tv_tensors.Mask:
    """Loads a segmentation mask from a NIfTI file as :class:`torchvision.tv_tensors.Mask`.

    Args:
        file: The path to the NIfTI file containing the mask.
        volume_file: The path to the volume file used as orientation reference in case
            the mask file is missing the pixdim array in the NIfTI header.
        orientation: The orientation code to reorient the nifti image. It is forwarded
            to the reader, so masks can be aligned with volumes loaded with the
            same orientation.

    Returns:
        Mask tensor (dtype long) of shape `[T, C, H, W]`.
    """
    # Honour the caller's orientation instead of a hard-coded "PLS"; the default
    # keeps the previous behaviour for callers that do not pass it.
    nii = io.read_nifti(file, orientation=orientation, orientation_reference=volume_file)
    array = io.nifti_to_array(nii)
    # Same axis layout as `load_volume_tensor`: add a channel axis, then move
    # the last spatial axis to the front.
    array_reshaped_tchw = array[None, :, :, :].transpose(3, 0, 1, 2)
    return tv_tensors.Mask(array_reshaped_tchw, dtype=torch.long)  # type: ignore
|
|
@@ -12,13 +12,13 @@ from torchvision import tv_tensors
|
|
|
12
12
|
from torchvision.transforms.v2 import functional
|
|
13
13
|
from typing_extensions import override
|
|
14
14
|
|
|
15
|
-
from eva.vision.data.datasets import _validators, wsi
|
|
16
|
-
from eva.vision.data.datasets.segmentation import _utils
|
|
15
|
+
from eva.vision.data.datasets import _validators, vision, wsi
|
|
16
|
+
from eva.vision.data.datasets.segmentation import _utils
|
|
17
17
|
from eva.vision.data.wsi.patching import samplers
|
|
18
18
|
from eva.vision.utils import io
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
class BCSS(wsi.MultiWsiDataset,
|
|
21
|
+
class BCSS(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
|
|
22
22
|
"""Dataset class for BCSS semantic segmentation task.
|
|
23
23
|
|
|
24
24
|
Source: https://github.com/PathologyDataScience/BCSS
|
|
@@ -71,7 +71,6 @@ class BCSS(wsi.MultiWsiDataset, base.ImageSegmentation):
|
|
|
71
71
|
width: Width of the patches to be extracted, in pixels.
|
|
72
72
|
height: Height of the patches to be extracted, in pixels.
|
|
73
73
|
target_mpp: Target microns per pixel (mpp) for the patches.
|
|
74
|
-
backend: The backend to use for reading the whole-slide images.
|
|
75
74
|
transforms: Transforms to apply to the extracted image & mask patches.
|
|
76
75
|
"""
|
|
77
76
|
self._split = split
|
|
@@ -90,7 +89,7 @@ class BCSS(wsi.MultiWsiDataset, base.ImageSegmentation):
|
|
|
90
89
|
overwrite_mpp=0.25,
|
|
91
90
|
backend="pil",
|
|
92
91
|
)
|
|
93
|
-
|
|
92
|
+
vision.VisionDataset.__init__(self, transforms=transforms)
|
|
94
93
|
|
|
95
94
|
@property
|
|
96
95
|
@override
|
|
@@ -129,15 +128,15 @@ class BCSS(wsi.MultiWsiDataset, base.ImageSegmentation):
|
|
|
129
128
|
|
|
130
129
|
@override
def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
    """Returns the `(image, mask, metadata)` triple of the `index`'th patch.

    ``wsi.MultiWsiDataset`` comes first in the MRO and defines its own
    ``__getitem__`` (used by ``load_data``), so the vision pipeline is
    invoked explicitly here.
    """
    sample = vision.VisionDataset.__getitem__(self, index)
    return sample
|
|
133
132
|
|
|
134
133
|
@override
def load_data(self, index: int) -> tv_tensors.Image:
    """Reads the `index`'th whole-slide patch and converts it to an image tensor."""
    patch = wsi.MultiWsiDataset.__getitem__(self, index)
    image = functional.to_image(patch)
    return image
|
|
138
137
|
|
|
139
138
|
@override
|
|
140
|
-
def
|
|
139
|
+
def load_target(self, index: int) -> tv_tensors.Mask:
|
|
141
140
|
path = self._get_mask_path(index)
|
|
142
141
|
mask = io.read_image_as_array(path)
|
|
143
142
|
mask_patch = _utils.extract_mask_patch(mask, self, index)
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
"""BTCV dataset."""
|
|
2
|
+
|
|
3
|
+
import glob
|
|
4
|
+
import os
|
|
5
|
+
import re
|
|
6
|
+
from typing import Any, Callable, Dict, List, Literal, Tuple
|
|
7
|
+
|
|
8
|
+
import huggingface_hub
|
|
9
|
+
from torchvision import tv_tensors
|
|
10
|
+
from torchvision.datasets import utils as data_utils
|
|
11
|
+
from typing_extensions import override
|
|
12
|
+
|
|
13
|
+
from eva.vision.data import tv_tensors as eva_tv_tensors
|
|
14
|
+
from eva.vision.data.datasets import _utils as _data_utils
|
|
15
|
+
from eva.vision.data.datasets.segmentation import _utils
|
|
16
|
+
from eva.vision.data.datasets.vision import VisionDataset
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BTCV(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):
    """Beyond the Cranial Vault (BTCV) Abdomen dataset.

    The BTCV dataset comprises abdominal CT acquired at the Vanderbilt
    University Medical Center from metastatic liver cancer patients or
    post-operative ventral hernia patients. The dataset contains one
    background class and thirteen organ classes.

    More info:
      - Multi-organ Abdominal CT Reference Standard Segmentations
        https://zenodo.org/records/1169361
      - Dataset Split
        https://github.com/Luffy03/Large-Scale-Medical/blob/main/Downstream/monai/BTCV/dataset/dataset_0.json
    """

    # Index ranges into the sorted sample list for each split; `None` selects
    # every sample.
    _split_index_ranges = {
        "train": [(0, 24)],
        "val": [(24, 30)],
        None: [(0, 30)],
    }

    def __init__(
        self,
        root: str,
        split: Literal["train", "val"] | None = None,
        download: bool = False,
        transforms: Callable | None = None,
    ) -> None:
        """Initializes the dataset.

        Args:
            root: Path to the dataset root directory.
            split: Dataset split to use ('train' or 'val').
                If None, it uses the full dataset.
            download: Whether to download the dataset.
            transforms: A callable object for applying data transformations.
                If None, no transformations are applied.
        """
        super().__init__(transforms=transforms)

        self._root = root
        self._split = split
        self._download = download

        # Populated by `configure`.
        self._samples: List[Tuple[str, str]]
        self._indices: List[int]

    @property
    @override
    def classes(self) -> List[str]:
        return [
            "background",
            "spleen",
            "right_kidney",
            "left_kidney",
            "gallbladder",
            "esophagus",
            "liver",
            "stomach",
            "aorta",
            "inferior_vena_cava",
            "portal_and_splenic_vein",
            "pancreas",
            "right_adrenal_gland",
            "left_adrenal_gland",
        ]

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        return {label: index for index, label in enumerate(self.classes)}

    @override
    def filename(self, index: int) -> str:
        return os.path.basename(self._samples[self._indices[index]][0])

    @override
    def prepare_data(self) -> None:
        if self._download:
            self._download_dataset()

    @override
    def configure(self) -> None:
        self._samples = self._find_samples()
        self._indices = self._make_indices()

    @override
    def validate(self) -> None:
        """Checks that every sample of the selected split is present on disk.

        Raises:
            OSError: If fewer samples were found than the split requires, or if
                a volume/segmentation file of the split is missing.
        """

        def _valid_sample(index: int) -> bool:
            """Indicates if the sample files exist and are reachable."""
            volume_file, segmentation_file = self._samples[self._indices[index]]
            return os.path.isfile(volume_file) and os.path.isfile(segmentation_file)

        # Split indices are positions in `self._samples` (e.g. 24..29 for "val"),
        # so the sample list must reach the largest index, not merely contain as
        # many entries as the split has indices.
        required_samples = max(self._indices, default=-1) + 1
        if len(self._samples) < required_samples:
            raise OSError(f"Dataset is missing {required_samples - len(self._samples)} files.")

        # Report the samples that were actually checked (mapped through the split indices).
        invalid_samples = [
            self._samples[self._indices[i]] for i in range(len(self)) if not _valid_sample(i)
        ]
        if invalid_samples:
            raise OSError(
                f"Dataset '{self.__class__.__qualname__}' contains missing or "
                f"corrupted samples ({len(invalid_samples)} in total). "
                f"Examples of missing folders: {str(invalid_samples[:10])[:-1]}, ...]. "
            )

    @override
    def __getitem__(
        self, index: int
    ) -> tuple[eva_tv_tensors.Volume, tv_tensors.Mask, dict[str, Any]]:
        """Returns the transformed `(volume, mask, metadata)` of the `index`'th sample."""
        volume = self.load_data(index)
        mask = self.load_target(index)
        metadata = self.load_metadata(index) or {}
        volume_tensor, mask_tensor = self._apply_transforms(volume, mask)
        return volume_tensor, mask_tensor, metadata

    @override
    def __len__(self) -> int:
        return len(self._indices)

    @override
    def load_data(self, index: int) -> eva_tv_tensors.Volume:
        """Loads the CT volume for a given sample.

        Args:
            index: The index of the desired sample.

        Returns:
            Tensor representing the CT volume of shape `[T, C, H, W]`.
        """
        ct_scan_file, _ = self._samples[self._indices[index]]
        return _utils.load_volume_tensor(ct_scan_file)

    @override
    def load_target(self, index: int) -> tv_tensors.Mask:
        """Loads the segmentation mask for a given sample.

        Args:
            index: The index of the desired sample.

        Returns:
            Tensor representing the segmentation mask of shape `[T, C, H, W]`.
        """
        ct_scan_file, mask_file = self._samples[self._indices[index]]
        # The CT file is passed as orientation reference for the mask.
        return _utils.load_mask_tensor(mask_file, ct_scan_file)

    def _apply_transforms(
        self, ct_scan: eva_tv_tensors.Volume, mask: tv_tensors.Mask
    ) -> tuple[eva_tv_tensors.Volume, tv_tensors.Mask]:
        """Applies transformations to the provided data.

        Args:
            ct_scan: The CT volume tensor.
            mask: The segmentation mask tensor.

        Returns:
            A tuple containing the transformed CT and mask tensors.
        """
        return self._transforms(ct_scan, mask) if self._transforms else (ct_scan, mask)

    def _find_samples(self) -> list[tuple[str, str]]:
        """Retrieves the file paths for the CT volumes and segmentations.

        Returns:
            A list of `(volume_file, segmentation_file)` pairs, sorted by the
            numeric id in their paths and restricted to ids present in both
            `imagesTr` and `labelsTr`.
        """

        def filename_id(filename: str) -> int:
            # The greedy `.*` makes this capture the last run of digits in the path.
            matches = re.match(r".*(?:\D|^)(\d+)", filename)
            if matches is None:
                raise ValueError(f"Filename '{filename}' is not valid.")

            return int(matches.group(1))

        # The archive may extract into a `BTCV` subdirectory of `root`.
        subdir = os.path.join(self._root, "BTCV")
        root = subdir if os.path.isdir(subdir) else self._root

        volume_files_pattern = os.path.join(root, "imagesTr", "*.nii.gz")
        volume_filenames = glob.glob(volume_files_pattern)
        volume_ids = {filename_id(filename): filename for filename in volume_filenames}

        segmentation_files_pattern = os.path.join(root, "labelsTr", "*.nii.gz")
        segmentation_filenames = glob.glob(segmentation_files_pattern)
        segmentation_ids = {filename_id(filename): filename for filename in segmentation_filenames}

        return [
            (volume_ids[file_id], segmentation_ids[file_id])
            for file_id in sorted(volume_ids.keys() & segmentation_ids.keys())
        ]

    def _make_indices(self) -> list[int]:
        """Builds the dataset indices for the specified split."""
        index_ranges = self._split_index_ranges.get(self._split)
        if index_ranges is None:
            raise ValueError("Invalid data split. Use 'train', 'val' or `None`.")

        return _data_utils.ranges_to_indices(index_ranges)

    def _download_dataset(self) -> None:
        """Downloads `BTCV.zip` from the Hugging Face Hub and extracts it into `root`.

        Raises:
            ValueError: If the `HF_TOKEN` environment variable is not set.
            FileNotFoundError: If the archive is not present after the download.
        """
        hf_token = os.getenv("HF_TOKEN")
        if not hf_token:
            raise ValueError("Huggingface token required, please set the HF_TOKEN env variable.")

        huggingface_hub.snapshot_download(
            "Luffy503/VoCo_Downstream",
            repo_type="dataset",
            token=hf_token,
            local_dir=self._root,
            ignore_patterns=[".git*"],
            allow_patterns=["BTCV.zip"],
        )

        zip_path = os.path.join(self._root, "BTCV.zip")
        if not os.path.exists(zip_path):
            raise FileNotFoundError(
                f"BTCV.zip not found in {self._root}, something with the download went wrong."
            )

        data_utils.extract_archive(zip_path, self._root, remove_finished=True)
|
|
@@ -11,13 +11,13 @@ from torchvision import tv_tensors
|
|
|
11
11
|
from torchvision.transforms.v2 import functional
|
|
12
12
|
from typing_extensions import override
|
|
13
13
|
|
|
14
|
-
from eva.vision.data.datasets import _validators, wsi
|
|
15
|
-
from eva.vision.data.datasets.segmentation import _utils
|
|
14
|
+
from eva.vision.data.datasets import _validators, vision, wsi
|
|
15
|
+
from eva.vision.data.datasets.segmentation import _utils
|
|
16
16
|
from eva.vision.data.wsi.patching import samplers
|
|
17
17
|
from eva.vision.utils import io
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
class CoNSeP(wsi.MultiWsiDataset,
|
|
20
|
+
class CoNSeP(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
|
|
21
21
|
"""Dataset class for CoNSeP semantic segmentation task.
|
|
22
22
|
|
|
23
23
|
As in [1], we combine classes 3 (healthy epithelial) & 4 (dysplastic/malignant epithelial)
|
|
@@ -55,7 +55,6 @@ class CoNSeP(wsi.MultiWsiDataset, base.ImageSegmentation):
|
|
|
55
55
|
width: Width of the patches to be extracted, in pixels.
|
|
56
56
|
height: Height of the patches to be extracted, in pixels.
|
|
57
57
|
target_mpp: Target microns per pixel (mpp) for the patches.
|
|
58
|
-
backend: The backend to use for reading the whole-slide images.
|
|
59
58
|
transforms: Transforms to apply to the extracted image & mask patches.
|
|
60
59
|
"""
|
|
61
60
|
self._split = split
|
|
@@ -112,15 +111,15 @@ class CoNSeP(wsi.MultiWsiDataset, base.ImageSegmentation):
|
|
|
112
111
|
|
|
113
112
|
@override
def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
    """Returns the `(image, mask, metadata)` triple of the `index`'th patch.

    ``wsi.MultiWsiDataset`` precedes the vision base class in the MRO and has
    its own ``__getitem__`` (used by ``load_data``), hence the explicit call.
    """
    sample = vision.VisionDataset.__getitem__(self, index)
    return sample
|
|
116
115
|
|
|
117
116
|
@override
def load_data(self, index: int) -> tv_tensors.Image:
    """Reads the `index`'th whole-slide patch and converts it to an image tensor."""
    patch = wsi.MultiWsiDataset.__getitem__(self, index)
    image = functional.to_image(patch)
    return image
|
|
121
120
|
|
|
122
121
|
@override
|
|
123
|
-
def
|
|
122
|
+
def load_target(self, index: int) -> tv_tensors.Mask:
|
|
124
123
|
path = self._get_mask_path(index)
|
|
125
124
|
mask = np.array(io.read_mat(path)["type_map"])
|
|
126
125
|
mask_patch = _utils.extract_mask_patch(mask, self, index)
|
|
@@ -13,12 +13,11 @@ from typing_extensions import override
|
|
|
13
13
|
|
|
14
14
|
from eva.core import utils
|
|
15
15
|
from eva.core.data import splitting
|
|
16
|
-
from eva.vision.data.datasets import _validators
|
|
17
|
-
from eva.vision.data.datasets.segmentation import base
|
|
16
|
+
from eva.vision.data.datasets import _validators, vision
|
|
18
17
|
from eva.vision.utils import io
|
|
19
18
|
|
|
20
19
|
|
|
21
|
-
class LiTS(
|
|
20
|
+
class LiTS(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
|
|
22
21
|
"""LiTS - Liver Tumor Segmentation Challenge.
|
|
23
22
|
|
|
24
23
|
Webpage: https://competitions.codalab.org/competitions/17094
|
|
@@ -110,21 +109,23 @@ class LiTS(base.ImageSegmentation):
|
|
|
110
109
|
)
|
|
111
110
|
|
|
112
111
|
@override
def load_data(self, index: int) -> tv_tensors.Image:
    """Loads the image slice addressed by the `index`'th `(sample, slice)` pair."""
    sample_idx, slice_idx = self._indices[index]
    nii = io.read_nifti(self._volume_files[sample_idx], slice_idx)
    pixels = io.nifti_to_array(nii)
    if self._fix_orientation:
        pixels = self._orientation(pixels, sample_idx)
    # Move the last axis to the front (presumably HWC -> CHW).
    return tv_tensors.Image(pixels.transpose(2, 0, 1))
|
|
120
120
|
|
|
121
121
|
@override
def load_target(self, index: int) -> tv_tensors.Mask:
    """Loads the segmentation mask slice addressed by the `index`'th `(sample, slice)` pair.

    Args:
        index: The index of the desired sample.

    Returns:
        The squeezed int64 segmentation mask of that slice.
    """
    sample_index, slice_index = self._indices[index]
    segmentation_path = self._segmentation_file(sample_index)
    mask_nii = io.read_nifti(segmentation_path, slice_index)
    # Bind the labels unconditionally; previously they were only assigned inside
    # the orientation branch, raising UnboundLocalError when it was disabled.
    semantic_labels = io.nifti_to_array(mask_nii)
    if self._fix_orientation:
        semantic_labels = self._orientation(semantic_labels, sample_index)
    return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
|
|
129
130
|
|
|
130
131
|
def _orientation(self, array: npt.NDArray, sample_index: int) -> npt.NDArray:
|
|
@@ -64,7 +64,8 @@ class LiTSBalanced(lits.LiTS):
|
|
|
64
64
|
if sample_idx not in split_indices:
|
|
65
65
|
continue
|
|
66
66
|
|
|
67
|
-
|
|
67
|
+
segmentation_nii = io.read_nifti(self._segmentation_file(sample_idx))
|
|
68
|
+
segmentation = io.nifti_to_array(segmentation_nii)
|
|
68
69
|
tumor_filter = segmentation == 2
|
|
69
70
|
tumor_slice_filter = tumor_filter.sum(axis=(0, 1)) > 0
|
|
70
71
|
|
|
@@ -16,12 +16,11 @@ from torchvision.datasets import utils
|
|
|
16
16
|
from typing_extensions import override
|
|
17
17
|
|
|
18
18
|
from eva.core.utils.progress_bar import tqdm
|
|
19
|
-
from eva.vision.data.datasets import _validators, structs
|
|
20
|
-
from eva.vision.data.datasets.segmentation import base
|
|
19
|
+
from eva.vision.data.datasets import _validators, structs, vision
|
|
21
20
|
from eva.vision.utils import io
|
|
22
21
|
|
|
23
22
|
|
|
24
|
-
class MoNuSAC(
|
|
23
|
+
class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
|
|
25
24
|
"""MoNuSAC2020: A Multi-organ Nuclei Segmentation and Classification Challenge.
|
|
26
25
|
|
|
27
26
|
Webpage: https://monusac-2020.grand-challenge.org/
|
|
@@ -112,13 +111,13 @@ class MoNuSAC(base.ImageSegmentation):
|
|
|
112
111
|
)
|
|
113
112
|
|
|
114
113
|
@override
def load_data(self, index: int) -> tv_tensors.Image:
    """Reads the `index`'th RGB image and moves its last axis to the front."""
    rgb = io.read_image(self._image_files[index])
    return tv_tensors.Image(rgb.transpose(2, 0, 1))
|
|
119
118
|
|
|
120
119
|
@override
|
|
121
|
-
def
|
|
120
|
+
def load_target(self, index: int) -> tv_tensors.Mask:
|
|
122
121
|
semantic_labels = (
|
|
123
122
|
self._load_semantic_mask_file(index)
|
|
124
123
|
if self._export_masks
|
|
@@ -17,12 +17,12 @@ from typing_extensions import override
|
|
|
17
17
|
|
|
18
18
|
from eva.core.utils import io as core_io
|
|
19
19
|
from eva.core.utils import multiprocessing
|
|
20
|
-
from eva.vision.data.datasets import _validators, structs
|
|
21
|
-
from eva.vision.data.datasets.segmentation import _total_segmentator
|
|
20
|
+
from eva.vision.data.datasets import _validators, structs, vision
|
|
21
|
+
from eva.vision.data.datasets.segmentation import _total_segmentator
|
|
22
22
|
from eva.vision.utils import io
|
|
23
23
|
|
|
24
24
|
|
|
25
|
-
class TotalSegmentator2D(
|
|
25
|
+
class TotalSegmentator2D(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
|
|
26
26
|
"""TotalSegmentator 2D segmentation dataset."""
|
|
27
27
|
|
|
28
28
|
_expected_dataset_lengths: Dict[str, int] = {
|
|
@@ -206,19 +206,20 @@ class TotalSegmentator2D(base.ImageSegmentation):
|
|
|
206
206
|
return len(self._indices)
|
|
207
207
|
|
|
208
208
|
@override
def load_data(self, index: int) -> tv_tensors.Image:
    """Loads and reorients the image slice of the `index`'th `(sample, slice)` pair."""
    sample_idx, slice_idx = self._indices[index]
    nii = io.read_nifti(self._get_image_path(sample_idx), slice_idx)
    pixels = self._fix_orientation(io.nifti_to_array(nii))
    # copy() detaches the data from the oriented array before permuting axes.
    return tv_tensors.Image(pixels.copy().transpose(2, 0, 1))
|
|
215
216
|
|
|
216
217
|
@override
def load_target(self, index: int) -> tv_tensors.Mask:
    """Loads, reorients and squeezes the `index`'th semantic label mask (int64)."""
    mask = (
        self._load_semantic_label_mask(index)
        if self._optimize_mask_loading
        else self._load_target(index)
    )
    oriented = self._fix_orientation(mask)
    return tv_tensors.Mask(oriented.copy().squeeze(), dtype=torch.int64)  # type: ignore
|
|
224
225
|
|
|
@@ -227,14 +228,15 @@ class TotalSegmentator2D(base.ImageSegmentation):
|
|
|
227
228
|
_, slice_index = self._indices[index]
|
|
228
229
|
return {"slice_index": slice_index}
|
|
229
230
|
|
|
230
|
-
def
|
|
231
|
+
def _load_target(self, index: int) -> npt.NDArray[Any]:
    """Builds the semantic label mask of the `index`'th slice from per-class mask files."""
    sample_idx, slice_idx = self._indices[index]
    semantic_mask = self._load_masks_as_semantic_label(sample_idx, slice_idx)
    return semantic_mask
|
|
233
234
|
|
|
234
235
|
def _load_semantic_label_mask(self, index: int) -> npt.NDArray[Any]:
    """Reads the `index`'th slice from the optimized semantic-label NIfTI file."""
    sample_idx, slice_idx = self._indices[index]
    masks_file = self._get_optimized_masks_file(sample_idx)
    return io.nifti_to_array(io.read_nifti(masks_file, slice_idx))
|
|
238
240
|
|
|
239
241
|
def _load_masks_as_semantic_label(
|
|
240
242
|
self, sample_index: int, slice_index: int | None = None
|
|
@@ -248,7 +250,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
|
|
|
248
250
|
masks_dir = self._get_masks_dir(sample_index)
|
|
249
251
|
classes = self._class_mappings.keys() if self._class_mappings else self.classes[1:]
|
|
250
252
|
mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in classes]
|
|
251
|
-
binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]
|
|
253
|
+
binary_masks = [io.nifti_to_array(io.read_nifti(path, slice_index)) for path in mask_paths]
|
|
252
254
|
|
|
253
255
|
if self._class_mappings:
|
|
254
256
|
mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(
|
|
@@ -1,17 +1,92 @@
|
|
|
1
1
|
"""Vision Dataset base class."""
|
|
2
2
|
|
|
3
3
|
import abc
|
|
4
|
-
from typing import Generic, TypeVar
|
|
4
|
+
from typing import Any, Callable, Dict, Generic, List, Tuple, TypeVar
|
|
5
5
|
|
|
6
6
|
from eva.core.data.datasets import base
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
"""The data
|
|
8
|
+
# Type of the sample returned by `load_data` (e.g. an image or a volume).
InputType = TypeVar("InputType")

# Type of the target returned by `load_target` (e.g. a segmentation mask).
TargetType = TypeVar("TargetType")
|
|
11
13
|
|
|
12
|
-
|
|
14
|
+
|
|
15
|
+
class VisionDataset(
|
|
16
|
+
base.MapDataset[Tuple[InputType, TargetType, Dict[str, Any]]],
|
|
17
|
+
abc.ABC,
|
|
18
|
+
Generic[InputType, TargetType],
|
|
19
|
+
):
|
|
13
20
|
"""Base dataset class for vision tasks."""
|
|
14
21
|
|
|
22
|
+
def __init__(
    self,
    transforms: Callable | None = None,
) -> None:
    """Sets up the base vision dataset.

    Args:
        transforms: Optional callable invoked with `(data, target)` by
            `_apply_transforms`; when None, samples are returned unchanged.
    """
    super().__init__()
    self._transforms = transforms
|
|
35
|
+
|
|
36
|
+
@property
|
|
37
|
+
def classes(self) -> List[str] | None:
|
|
38
|
+
"""Returns the list with names of the dataset names."""
|
|
39
|
+
|
|
40
|
+
@property
|
|
41
|
+
def class_to_idx(self) -> Dict[str, int] | None:
|
|
42
|
+
"""Returns a mapping of the class name to its target index."""
|
|
43
|
+
|
|
44
|
+
def __getitem__(self, index: int) -> Tuple[InputType, TargetType, Dict[str, Any]]:
    """Builds the `index`'th sample.

    Args:
        index: Position of the sample to build.

    Returns:
        The transformed data, the transformed target and the metadata dict
        (an empty dict when `load_metadata` returns a falsy value).
    """
    data, target = self._apply_transforms(self.load_data(index), self.load_target(index))
    return data, target, self.load_metadata(index) or {}
|
|
57
|
+
|
|
58
|
+
def load_metadata(self, index: int) -> Dict[str, Any] | None:
|
|
59
|
+
"""Returns the dataset metadata.
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
index: The index of the data sample to return the metadata of.
|
|
63
|
+
|
|
64
|
+
Returns:
|
|
65
|
+
The sample metadata.
|
|
66
|
+
"""
|
|
67
|
+
|
|
68
|
+
@abc.abstractmethod
def load_data(self, index: int) -> InputType:
    """Loads the raw (untransformed) input of the `index`'th sample.

    Args:
        index: Position of the sample to load.

    Returns:
        The sample input data.
    """
|
|
78
|
+
|
|
79
|
+
@abc.abstractmethod
def load_target(self, index: int) -> TargetType:
    """Loads the raw (untransformed) target of the `index`'th sample.

    Args:
        index: Position of the sample whose target is loaded.

    Returns:
        The sample target.
    """
|
|
89
|
+
|
|
15
90
|
@abc.abstractmethod
|
|
16
91
|
def filename(self, index: int) -> str:
|
|
17
92
|
"""Returns the filename of the `index`'th data sample.
|
|
@@ -24,3 +99,19 @@ class VisionDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
|
|
|
24
99
|
Returns:
|
|
25
100
|
The filename of the `index`'th data sample.
|
|
26
101
|
"""
|
|
102
|
+
|
|
103
|
+
def _apply_transforms(
    self, image: InputType, target: TargetType
) -> Tuple[InputType, TargetType]:
    """Runs the configured transforms jointly on a data/target pair.

    Args:
        image: The input data to transform.
        target: The target belonging to `image`.

    Returns:
        The `(image, target)` pair, transformed when transforms are configured.
    """
    if self._transforms is None:
        return image, target
    transformed_image, transformed_target = self._transforms(image, target)
    return transformed_image, transformed_target
|