kaiko-eva 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. eva/core/data/datasets/base.py +7 -2
  2. eva/core/models/modules/head.py +4 -2
  3. eva/core/models/modules/typings.py +2 -2
  4. eva/core/models/transforms/__init__.py +2 -1
  5. eva/core/models/transforms/as_discrete.py +57 -0
  6. eva/core/models/wrappers/_utils.py +121 -1
  7. eva/core/utils/suppress_logs.py +28 -0
  8. eva/vision/data/__init__.py +2 -2
  9. eva/vision/data/dataloaders/__init__.py +5 -0
  10. eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
  11. eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
  12. eva/vision/data/datasets/__init__.py +2 -2
  13. eva/vision/data/datasets/classification/bach.py +3 -4
  14. eva/vision/data/datasets/classification/bracs.py +3 -4
  15. eva/vision/data/datasets/classification/breakhis.py +3 -4
  16. eva/vision/data/datasets/classification/camelyon16.py +4 -5
  17. eva/vision/data/datasets/classification/crc.py +3 -4
  18. eva/vision/data/datasets/classification/gleason_arvaniti.py +3 -4
  19. eva/vision/data/datasets/classification/mhist.py +3 -4
  20. eva/vision/data/datasets/classification/panda.py +4 -5
  21. eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
  22. eva/vision/data/datasets/classification/unitopatho.py +3 -4
  23. eva/vision/data/datasets/classification/wsi.py +6 -5
  24. eva/vision/data/datasets/segmentation/__init__.py +2 -2
  25. eva/vision/data/datasets/segmentation/_utils.py +47 -0
  26. eva/vision/data/datasets/segmentation/bcss.py +7 -8
  27. eva/vision/data/datasets/segmentation/btcv.py +236 -0
  28. eva/vision/data/datasets/segmentation/consep.py +6 -7
  29. eva/vision/data/datasets/segmentation/lits.py +9 -8
  30. eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
  31. eva/vision/data/datasets/segmentation/monusac.py +4 -5
  32. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
  33. eva/vision/data/datasets/vision.py +95 -4
  34. eva/vision/data/datasets/wsi.py +5 -5
  35. eva/vision/data/transforms/__init__.py +22 -3
  36. eva/vision/data/transforms/common/__init__.py +1 -2
  37. eva/vision/data/transforms/croppad/__init__.py +11 -0
  38. eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
  39. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
  40. eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
  41. eva/vision/data/transforms/intensity/__init__.py +11 -0
  42. eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
  43. eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
  44. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
  45. eva/vision/data/transforms/spatial/__init__.py +7 -0
  46. eva/vision/data/transforms/spatial/flip.py +72 -0
  47. eva/vision/data/transforms/spatial/rotate.py +53 -0
  48. eva/vision/data/transforms/spatial/spacing.py +69 -0
  49. eva/vision/data/transforms/utility/__init__.py +5 -0
  50. eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
  51. eva/vision/data/tv_tensors/__init__.py +5 -0
  52. eva/vision/data/tv_tensors/volume.py +61 -0
  53. eva/vision/metrics/segmentation/monai_dice.py +9 -2
  54. eva/vision/models/modules/semantic_segmentation.py +28 -20
  55. eva/vision/models/networks/backbones/__init__.py +9 -2
  56. eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
  57. eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
  58. eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
  59. eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
  60. eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
  61. eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
  62. eva/vision/models/networks/backbones/radiology/voco.py +75 -0
  63. eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
  64. eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
  65. eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
  66. eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
  67. eva/vision/utils/io/__init__.py +2 -0
  68. eva/vision/utils/io/nifti.py +91 -11
  69. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
  70. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +73 -57
  71. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
  72. eva/vision/data/datasets/classification/base.py +0 -96
  73. eva/vision/data/datasets/segmentation/base.py +0 -96
  74. eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
  75. eva/vision/data/transforms/normalization/__init__.py +0 -6
  76. eva/vision/data/transforms/normalization/clamp.py +0 -43
  77. eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
  78. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
  79. eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
  80. eva/vision/metrics/segmentation/BUILD +0 -1
  81. eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
  82. eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
  83. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
  84. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,7 @@
1
1
  """Segmentation datasets API."""
2
2
 
3
- from eva.vision.data.datasets.segmentation.base import ImageSegmentation
4
3
  from eva.vision.data.datasets.segmentation.bcss import BCSS
4
+ from eva.vision.data.datasets.segmentation.btcv import BTCV
5
5
  from eva.vision.data.datasets.segmentation.consep import CoNSeP
6
6
  from eva.vision.data.datasets.segmentation.embeddings import EmbeddingsSegmentationDataset
7
7
  from eva.vision.data.datasets.segmentation.lits import LiTS
@@ -10,8 +10,8 @@ from eva.vision.data.datasets.segmentation.monusac import MoNuSAC
10
10
  from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D
11
11
 
12
12
  __all__ = [
13
- "ImageSegmentation",
14
13
  "BCSS",
14
+ "BTCV",
15
15
  "CoNSeP",
16
16
  "EmbeddingsSegmentationDataset",
17
17
  "LiTS",
@@ -1,8 +1,12 @@
1
1
  from typing import Any, Tuple
2
2
 
3
3
  import numpy.typing as npt
4
+ import torch
5
+ from torchvision import tv_tensors
4
6
 
7
+ from eva.vision.data import tv_tensors as eva_tv_tensors
5
8
  from eva.vision.data.datasets import wsi
9
+ from eva.vision.utils import io
6
10
 
7
11
 
8
12
  def get_coords_at_index(
@@ -36,3 +40,46 @@ def extract_mask_patch(
36
40
  """
37
41
  (x, y), width, height = get_coords_at_index(dataset, index)
38
42
  return mask[y : y + height, x : x + width]
43
+
44
+
45
+ def load_volume_tensor(file: str, orientation: str = "PLS") -> eva_tv_tensors.Volume:
46
+ """Load a volume from NIfTI file as :class:`eva.vision.data.tv_tensors.Volume`.
47
+
48
+ Args:
49
+ file: The path to the NIfTI file.
50
+ orientation: The orientation code to reorient the nifti image.
51
+
52
+ Returns:
53
+ Volume tensor representing of shape `[T, C, H, W]`.
54
+ """
55
+ nii = io.read_nifti(file, orientation=orientation)
56
+ array = io.nifti_to_array(nii)
57
+ array_reshaped_tchw = array[None, :, :, :].transpose(3, 0, 1, 2)
58
+
59
+ if nii.affine is None:
60
+ raise ValueError(f"Affine matrix is missing for {file}.")
61
+ affine = torch.tensor(nii.affine[:, [2, 0, 1, 3]], dtype=torch.float32)
62
+
63
+ return eva_tv_tensors.Volume(
64
+ array_reshaped_tchw, affine=affine, dtype=torch.float32
65
+ ) # type: ignore
66
+
67
+
68
+ def load_mask_tensor(
69
+ file: str, volume_file: str | None = None, orientation: str = "PLS"
70
+ ) -> tv_tensors.Mask:
71
+ """Load a segmentation mask from NIfTI file as :class:`torchvision.tv_tensors.Mask`.
72
+
73
+ Args:
74
+ file: The path to the NIfTI file containing the mask.
75
+ volume_file: The path to the volume file used as orientation reference in case
76
+ the mask file is missing the pixdim array in the NIfTI header.
77
+ orientation: The orientation code to reorient the nifti image.
78
+
79
+ Returns:
80
+ Mask tensor of shape `[T, C, H, W]`.
81
+ """
82
+ nii = io.read_nifti(file, orientation=orientation, orientation_reference=volume_file)
83
+ array = io.nifti_to_array(nii)
84
+ array_reshaped_tchw = array[None, :, :, :].transpose(3, 0, 1, 2)
85
+ return tv_tensors.Mask(array_reshaped_tchw, dtype=torch.long) # type: ignore
@@ -12,13 +12,13 @@ from torchvision import tv_tensors
12
12
  from torchvision.transforms.v2 import functional
13
13
  from typing_extensions import override
14
14
 
15
- from eva.vision.data.datasets import _validators, wsi
16
- from eva.vision.data.datasets.segmentation import _utils, base
15
+ from eva.vision.data.datasets import _validators, vision, wsi
16
+ from eva.vision.data.datasets.segmentation import _utils
17
17
  from eva.vision.data.wsi.patching import samplers
18
18
  from eva.vision.utils import io
19
19
 
20
20
 
21
- class BCSS(wsi.MultiWsiDataset, base.ImageSegmentation):
21
+ class BCSS(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
22
22
  """Dataset class for BCSS semantic segmentation task.
23
23
 
24
24
  Source: https://github.com/PathologyDataScience/BCSS
@@ -71,7 +71,6 @@ class BCSS(wsi.MultiWsiDataset, base.ImageSegmentation):
71
71
  width: Width of the patches to be extracted, in pixels.
72
72
  height: Height of the patches to be extracted, in pixels.
73
73
  target_mpp: Target microns per pixel (mpp) for the patches.
74
- backend: The backend to use for reading the whole-slide images.
75
74
  transforms: Transforms to apply to the extracted image & mask patches.
76
75
  """
77
76
  self._split = split
@@ -90,7 +89,7 @@ class BCSS(wsi.MultiWsiDataset, base.ImageSegmentation):
90
89
  overwrite_mpp=0.25,
91
90
  backend="pil",
92
91
  )
93
- base.ImageSegmentation.__init__(self, transforms=transforms)
92
+ vision.VisionDataset.__init__(self, transforms=transforms)
94
93
 
95
94
  @property
96
95
  @override
@@ -129,15 +128,15 @@ class BCSS(wsi.MultiWsiDataset, base.ImageSegmentation):
129
128
 
130
129
  @override
131
130
  def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
132
- return base.ImageSegmentation.__getitem__(self, index)
131
+ return vision.VisionDataset.__getitem__(self, index)
133
132
 
134
133
  @override
135
- def load_image(self, index: int) -> tv_tensors.Image:
134
+ def load_data(self, index: int) -> tv_tensors.Image:
136
135
  image_array = wsi.MultiWsiDataset.__getitem__(self, index)
137
136
  return functional.to_image(image_array)
138
137
 
139
138
  @override
140
- def load_mask(self, index: int) -> tv_tensors.Mask:
139
+ def load_target(self, index: int) -> tv_tensors.Mask:
141
140
  path = self._get_mask_path(index)
142
141
  mask = io.read_image_as_array(path)
143
142
  mask_patch = _utils.extract_mask_patch(mask, self, index)
@@ -0,0 +1,236 @@
1
+ """BTCV dataset."""
2
+
3
+ import glob
4
+ import os
5
+ import re
6
+ from typing import Any, Callable, Dict, List, Literal, Tuple
7
+
8
+ import huggingface_hub
9
+ from torchvision import tv_tensors
10
+ from torchvision.datasets import utils as data_utils
11
+ from typing_extensions import override
12
+
13
+ from eva.vision.data import tv_tensors as eva_tv_tensors
14
+ from eva.vision.data.datasets import _utils as _data_utils
15
+ from eva.vision.data.datasets.segmentation import _utils
16
+ from eva.vision.data.datasets.vision import VisionDataset
17
+
18
+
19
+ class BTCV(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):
20
+ """Beyond the Cranial Vault (BTCV) Abdomen dataset.
21
+
22
+ The BTCV dataset comprises abdominal CT acquired at the Vanderbilt
23
+ University Medical Center from metastatic liver cancer patients or
24
+ post-operative ventral hernia patients. The dataset contains one
25
+ background class and thirteen organ classes.
26
+
27
+ More info:
28
+ - Multi-organ Abdominal CT Reference Standard Segmentations
29
+ https://zenodo.org/records/1169361
30
+ - Dataset Split
31
+ https://github.com/Luffy03/Large-Scale-Medical/blob/main/Downstream/monai/BTCV/dataset/dataset_0.json
32
+ """
33
+
34
+ _split_index_ranges = {
35
+ "train": [(0, 24)],
36
+ "val": [(24, 30)],
37
+ None: [(0, 30)],
38
+ }
39
+ """Sample indices for the dataset splits."""
40
+
41
+ def __init__(
42
+ self,
43
+ root: str,
44
+ split: Literal["train", "val"] | None = None,
45
+ download: bool = False,
46
+ transforms: Callable | None = None,
47
+ ) -> None:
48
+ """Initializes the dataset.
49
+
50
+ Args:
51
+ root: Path to the dataset root directory.
52
+ split: Dataset split to use ('train' or 'val').
53
+ If None, it uses the full dataset.
54
+ download: Whether to download the dataset.
55
+ transforms: A callable object for applying data transformations.
56
+ If None, no transformations are applied.
57
+ """
58
+ super().__init__(transforms=transforms)
59
+
60
+ self._root = root
61
+ self._split = split
62
+ self._download = download
63
+
64
+ self._samples: List[Tuple[str, str]]
65
+ self._indices: List[int]
66
+
67
+ @property
68
+ @override
69
+ def classes(self) -> List[str]:
70
+ return [
71
+ "background",
72
+ "spleen",
73
+ "right_kidney",
74
+ "left_kidney",
75
+ "gallbladder",
76
+ "esophagus",
77
+ "liver",
78
+ "stomach",
79
+ "aorta",
80
+ "inferior_vena_cava",
81
+ "portal_and_splenic_vein",
82
+ "pancreas",
83
+ "right_adrenal_gland",
84
+ "left_adrenal_gland",
85
+ ]
86
+
87
+ @property
88
+ @override
89
+ def class_to_idx(self) -> Dict[str, int]:
90
+ return {label: index for index, label in enumerate(self.classes)}
91
+
92
+ @override
93
+ def filename(self, index: int) -> str:
94
+ return os.path.basename(self._samples[self._indices[index]][0])
95
+
96
+ @override
97
+ def prepare_data(self) -> None:
98
+ if self._download:
99
+ self._download_dataset()
100
+
101
+ @override
102
+ def configure(self) -> None:
103
+ self._samples = self._find_samples()
104
+ self._indices = self._make_indices()
105
+
106
+ @override
107
+ def validate(self) -> None:
108
+ def _valid_sample(sample_index: int) -> bool:
109
+ """Indicates if the files of the given sample exist and are reachable."""
110
+ volume_file, segmentation_file = self._samples[sample_index]
111
+ return os.path.isfile(volume_file) and os.path.isfile(segmentation_file)
112
+
113
+ if len(self._samples) < len(self._indices):
114
+ raise OSError(f"Dataset is missing {len(self._indices) - len(self._samples)} files.")
115
+
116
+ invalid_samples = [self._samples[i] for i in self._indices if not _valid_sample(i)]
117
+ if invalid_samples:
118
+ raise OSError(
119
+ f"Dataset '{self.__class__.__qualname__}' contains missing or "
120
+ f"corrupted samples ({len(invalid_samples)} in total). "
121
+ f"Examples of missing folders: {str(invalid_samples[:10])[:-1]}, ...]. "
122
+ )
123
+
124
+ @override
125
+ def __getitem__(
126
+ self, index: int
127
+ ) -> tuple[eva_tv_tensors.Volume, tv_tensors.Mask, dict[str, Any]]:
128
+ volume = self.load_data(index)
129
+ mask = self.load_target(index)
130
+ metadata = self.load_metadata(index) or {}
131
+ volume_tensor, mask_tensor = self._apply_transforms(volume, mask)
132
+ return volume_tensor, mask_tensor, metadata
133
+
134
+ @override
135
+ def __len__(self) -> int:
136
+ return len(self._indices)
137
+
138
+ @override
139
+ def load_data(self, index: int) -> eva_tv_tensors.Volume:
140
+ """Loads the CT volume for a given sample.
141
+
142
+ Args:
143
+ index: The index of the desired sample.
144
+
145
+ Returns:
146
+ Tensor representing the CT volume of shape `[T, C, H, W]`.
147
+ """
148
+ ct_scan_file, _ = self._samples[self._indices[index]]
149
+ return _utils.load_volume_tensor(ct_scan_file)
150
+
151
+ @override
152
+ def load_target(self, index: int) -> tv_tensors.Mask:
153
+ """Loads the segmentation mask for a given sample.
154
+
155
+ Args:
156
+ index: The index of the desired sample.
157
+
158
+ Returns:
159
+ Tensor representing the segmentation mask of shape `[T, C, H, W]`.
160
+ """
161
+ ct_scan_file, mask_file = self._samples[self._indices[index]]
162
+ return _utils.load_mask_tensor(mask_file, ct_scan_file)
163
+
164
+ def _apply_transforms(
165
+ self, ct_scan: eva_tv_tensors.Volume, mask: tv_tensors.Mask
166
+ ) -> tuple[eva_tv_tensors.Volume, tv_tensors.Mask]:
167
+ """Applies transformations to the provided data.
168
+
169
+ Args:
170
+ ct_scan: The CT volume tensor.
171
+ mask: The segmentation mask tensor.
172
+
173
+ Returns:
174
+ A tuple containing the transformed CT and mask tensors.
175
+ """
176
+ return self._transforms(ct_scan, mask) if self._transforms else (ct_scan, mask)
177
+
178
+ def _find_samples(self) -> list[tuple[str, str]]:
179
+ """Retrieves the file paths for the CT volumes and segmentation.
180
+
181
+ Returns:
182
+ The a list of file path to the CT volumes and segmentation.
183
+ """
184
+
185
+ def filename_id(filename: str) -> int:
186
+ matches = re.match(r".*(?:\D|^)(\d+)", filename)
187
+ if matches is None:
188
+ raise ValueError(f"Filename '{filename}' is not valid.")
189
+
190
+ return int(matches.group(1))
191
+
192
+ subdir = os.path.join(self._root, "BTCV")
193
+ root = subdir if os.path.isdir(subdir) else self._root
194
+
195
+ volume_files_pattern = os.path.join(root, "imagesTr", "*.nii.gz")
196
+ volume_filenames = glob.glob(volume_files_pattern)
197
+ volume_ids = {filename_id(filename): filename for filename in volume_filenames}
198
+
199
+ segmentation_files_pattern = os.path.join(root, "labelsTr", "*.nii.gz")
200
+ segmentation_filenames = glob.glob(segmentation_files_pattern)
201
+ segmentation_ids = {filename_id(filename): filename for filename in segmentation_filenames}
202
+
203
+ return [
204
+ (volume_ids[file_id], segmentation_ids[file_id])
205
+ for file_id in sorted(volume_ids.keys() & segmentation_ids.keys())
206
+ ]
207
+
208
+ def _make_indices(self) -> list[int]:
209
+ """Builds the dataset indices for the specified split."""
210
+ index_ranges = self._split_index_ranges.get(self._split)
211
+ if index_ranges is None:
212
+ raise ValueError("Invalid data split. Use 'train', 'val' or `None`.")
213
+
214
+ return _data_utils.ranges_to_indices(index_ranges)
215
+
216
+ def _download_dataset(self) -> None:
217
+ hf_token = os.getenv("HF_TOKEN")
218
+ if not hf_token:
219
+ raise ValueError("Huggingface token required, please set the HF_TOKEN env variable.")
220
+
221
+ huggingface_hub.snapshot_download(
222
+ "Luffy503/VoCo_Downstream",
223
+ repo_type="dataset",
224
+ token=hf_token,
225
+ local_dir=self._root,
226
+ ignore_patterns=[".git*"],
227
+ allow_patterns=["BTCV.zip"],
228
+ )
229
+
230
+ zip_path = os.path.join(self._root, "BTCV.zip")
231
+ if not os.path.exists(zip_path):
232
+ raise FileNotFoundError(
233
+ f"BTCV.zip not found in {self._root}, something with the download went wrong."
234
+ )
235
+
236
+ data_utils.extract_archive(zip_path, self._root, remove_finished=True)
@@ -11,13 +11,13 @@ from torchvision import tv_tensors
11
11
  from torchvision.transforms.v2 import functional
12
12
  from typing_extensions import override
13
13
 
14
- from eva.vision.data.datasets import _validators, wsi
15
- from eva.vision.data.datasets.segmentation import _utils, base
14
+ from eva.vision.data.datasets import _validators, vision, wsi
15
+ from eva.vision.data.datasets.segmentation import _utils
16
16
  from eva.vision.data.wsi.patching import samplers
17
17
  from eva.vision.utils import io
18
18
 
19
19
 
20
- class CoNSeP(wsi.MultiWsiDataset, base.ImageSegmentation):
20
+ class CoNSeP(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
21
21
  """Dataset class for CoNSeP semantic segmentation task.
22
22
 
23
23
  As in [1], we combine classes 3 (healthy epithelial) & 4 (dysplastic/malignant epithelial)
@@ -55,7 +55,6 @@ class CoNSeP(wsi.MultiWsiDataset, base.ImageSegmentation):
55
55
  width: Width of the patches to be extracted, in pixels.
56
56
  height: Height of the patches to be extracted, in pixels.
57
57
  target_mpp: Target microns per pixel (mpp) for the patches.
58
- backend: The backend to use for reading the whole-slide images.
59
58
  transforms: Transforms to apply to the extracted image & mask patches.
60
59
  """
61
60
  self._split = split
@@ -112,15 +111,15 @@ class CoNSeP(wsi.MultiWsiDataset, base.ImageSegmentation):
112
111
 
113
112
  @override
114
113
  def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
115
- return base.ImageSegmentation.__getitem__(self, index)
114
+ return vision.VisionDataset.__getitem__(self, index)
116
115
 
117
116
  @override
118
- def load_image(self, index: int) -> tv_tensors.Image:
117
+ def load_data(self, index: int) -> tv_tensors.Image:
119
118
  image_array = wsi.MultiWsiDataset.__getitem__(self, index)
120
119
  return functional.to_image(image_array)
121
120
 
122
121
  @override
123
- def load_mask(self, index: int) -> tv_tensors.Mask:
122
+ def load_target(self, index: int) -> tv_tensors.Mask:
124
123
  path = self._get_mask_path(index)
125
124
  mask = np.array(io.read_mat(path)["type_map"])
126
125
  mask_patch = _utils.extract_mask_patch(mask, self, index)
@@ -13,12 +13,11 @@ from typing_extensions import override
13
13
 
14
14
  from eva.core import utils
15
15
  from eva.core.data import splitting
16
- from eva.vision.data.datasets import _validators
17
- from eva.vision.data.datasets.segmentation import base
16
+ from eva.vision.data.datasets import _validators, vision
18
17
  from eva.vision.utils import io
19
18
 
20
19
 
21
- class LiTS(base.ImageSegmentation):
20
+ class LiTS(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
22
21
  """LiTS - Liver Tumor Segmentation Challenge.
23
22
 
24
23
  Webpage: https://competitions.codalab.org/competitions/17094
@@ -110,21 +109,23 @@ class LiTS(base.ImageSegmentation):
110
109
  )
111
110
 
112
111
  @override
113
- def load_image(self, index: int) -> tv_tensors.Image:
112
+ def load_data(self, index: int) -> tv_tensors.Image:
114
113
  sample_index, slice_index = self._indices[index]
115
114
  volume_path = self._volume_files[sample_index]
116
- image_array = io.read_nifti(volume_path, slice_index)
115
+ image_nii = io.read_nifti(volume_path, slice_index)
116
+ image_array = io.nifti_to_array(image_nii)
117
117
  if self._fix_orientation:
118
118
  image_array = self._orientation(image_array, sample_index)
119
119
  return tv_tensors.Image(image_array.transpose(2, 0, 1))
120
120
 
121
121
  @override
122
- def load_mask(self, index: int) -> tv_tensors.Mask:
122
+ def load_target(self, index: int) -> tv_tensors.Mask:
123
123
  sample_index, slice_index = self._indices[index]
124
124
  segmentation_path = self._segmentation_file(sample_index)
125
- semantic_labels = io.read_nifti(segmentation_path, slice_index)
125
+ mask_nii = io.read_nifti(segmentation_path, slice_index)
126
+ semantic_labels = io.nifti_to_array(mask_nii)
126
127
  if self._fix_orientation:
127
128
  semantic_labels = self._orientation(semantic_labels, sample_index)
128
129
  return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue]
129
130
 
130
131
  def _orientation(self, array: npt.NDArray, sample_index: int) -> npt.NDArray:
@@ -64,7 +64,8 @@ class LiTSBalanced(lits.LiTS):
64
64
  if sample_idx not in split_indices:
65
65
  continue
66
66
 
67
- segmentation = io.read_nifti(self._segmentation_file(sample_idx))
67
+ segmentation_nii = io.read_nifti(self._segmentation_file(sample_idx))
68
+ segmentation = io.nifti_to_array(segmentation_nii)
68
69
  tumor_filter = segmentation == 2
69
70
  tumor_slice_filter = tumor_filter.sum(axis=(0, 1)) > 0
70
71
 
@@ -16,12 +16,11 @@ from torchvision.datasets import utils
16
16
  from typing_extensions import override
17
17
 
18
18
  from eva.core.utils.progress_bar import tqdm
19
- from eva.vision.data.datasets import _validators, structs
20
- from eva.vision.data.datasets.segmentation import base
19
+ from eva.vision.data.datasets import _validators, structs, vision
21
20
  from eva.vision.utils import io
22
21
 
23
22
 
24
- class MoNuSAC(base.ImageSegmentation):
23
+ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
25
24
  """MoNuSAC2020: A Multi-organ Nuclei Segmentation and Classification Challenge.
26
25
 
27
26
  Webpage: https://monusac-2020.grand-challenge.org/
@@ -112,13 +111,13 @@ class MoNuSAC(base.ImageSegmentation):
112
111
  )
113
112
 
114
113
  @override
115
- def load_image(self, index: int) -> tv_tensors.Image:
114
+ def load_data(self, index: int) -> tv_tensors.Image:
116
115
  image_path = self._image_files[index]
117
116
  image_rgb_array = io.read_image(image_path)
118
117
  return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1))
119
118
 
120
119
  @override
121
- def load_mask(self, index: int) -> tv_tensors.Mask:
120
+ def load_target(self, index: int) -> tv_tensors.Mask:
122
121
  semantic_labels = (
123
122
  self._load_semantic_mask_file(index)
124
123
  if self._export_masks
@@ -17,12 +17,12 @@ from typing_extensions import override
17
17
 
18
18
  from eva.core.utils import io as core_io
19
19
  from eva.core.utils import multiprocessing
20
- from eva.vision.data.datasets import _validators, structs
21
- from eva.vision.data.datasets.segmentation import _total_segmentator, base
20
+ from eva.vision.data.datasets import _validators, structs, vision
21
+ from eva.vision.data.datasets.segmentation import _total_segmentator
22
22
  from eva.vision.utils import io
23
23
 
24
24
 
25
- class TotalSegmentator2D(base.ImageSegmentation):
25
+ class TotalSegmentator2D(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
26
26
  """TotalSegmentator 2D segmentation dataset."""
27
27
 
28
28
  _expected_dataset_lengths: Dict[str, int] = {
@@ -206,19 +206,20 @@ class TotalSegmentator2D(base.ImageSegmentation):
206
206
  return len(self._indices)
207
207
 
208
208
  @override
209
- def load_image(self, index: int) -> tv_tensors.Image:
209
+ def load_data(self, index: int) -> tv_tensors.Image:
210
210
  sample_index, slice_index = self._indices[index]
211
211
  image_path = self._get_image_path(sample_index)
212
- image_array = io.read_nifti(image_path, slice_index)
212
+ image_nii = io.read_nifti(image_path, slice_index)
213
+ image_array = io.nifti_to_array(image_nii)
213
214
  image_array = self._fix_orientation(image_array)
214
215
  return tv_tensors.Image(image_array.copy().transpose(2, 0, 1))
215
216
 
216
217
  @override
217
- def load_mask(self, index: int) -> tv_tensors.Mask:
218
+ def load_target(self, index: int) -> tv_tensors.Mask:
218
219
  if self._optimize_mask_loading:
219
220
  mask = self._load_semantic_label_mask(index)
220
221
  else:
221
- mask = self._load_mask(index)
222
+ mask = self._load_target(index)
222
223
  mask = self._fix_orientation(mask)
223
224
  return tv_tensors.Mask(mask.copy().squeeze(), dtype=torch.int64) # type: ignore
224
225
 
@@ -227,14 +228,15 @@ class TotalSegmentator2D(base.ImageSegmentation):
227
228
  _, slice_index = self._indices[index]
228
229
  return {"slice_index": slice_index}
229
230
 
230
- def _load_mask(self, index: int) -> npt.NDArray[Any]:
231
+ def _load_target(self, index: int) -> npt.NDArray[Any]:
231
232
  sample_index, slice_index = self._indices[index]
232
233
  return self._load_masks_as_semantic_label(sample_index, slice_index)
233
234
 
234
235
  def _load_semantic_label_mask(self, index: int) -> npt.NDArray[Any]:
235
236
  """Loads the segmentation mask from a semantic label NifTi file."""
236
237
  sample_index, slice_index = self._indices[index]
237
- return io.read_nifti(self._get_optimized_masks_file(sample_index), slice_index)
238
+ nii = io.read_nifti(self._get_optimized_masks_file(sample_index), slice_index)
239
+ return io.nifti_to_array(nii)
238
240
 
239
241
  def _load_masks_as_semantic_label(
240
242
  self, sample_index: int, slice_index: int | None = None
@@ -248,7 +250,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
248
250
  masks_dir = self._get_masks_dir(sample_index)
249
251
  classes = self._class_mappings.keys() if self._class_mappings else self.classes[1:]
250
252
  mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in classes]
251
- binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]
253
+ binary_masks = [io.nifti_to_array(io.read_nifti(path, slice_index)) for path in mask_paths]
252
254
 
253
255
  if self._class_mappings:
254
256
  mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(
@@ -1,17 +1,92 @@
1
1
  """Vision Dataset base class."""
2
2
 
3
3
  import abc
4
- from typing import Generic, TypeVar
4
+ from typing import Any, Callable, Dict, Generic, List, Tuple, TypeVar
5
5
 
6
6
  from eva.core.data.datasets import base
7
7
 
8
- DataSample = TypeVar("DataSample")
9
- """The data sample type."""
8
+ InputType = TypeVar("InputType")
9
+ """The input data type."""
10
10
 
11
+ TargetType = TypeVar("TargetType")
12
+ """The target data type."""
11
13
 
12
- class VisionDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
14
+
15
+ class VisionDataset(
16
+ base.MapDataset[Tuple[InputType, TargetType, Dict[str, Any]]],
17
+ abc.ABC,
18
+ Generic[InputType, TargetType],
19
+ ):
13
20
  """Base dataset class for vision tasks."""
14
21
 
22
+ def __init__(
23
+ self,
24
+ transforms: Callable | None = None,
25
+ ) -> None:
26
+ """Initializes the dataset.
27
+
28
+ Args:
29
+ transforms: A function/transform which returns a transformed
30
+ version of the raw data samples.
31
+ """
32
+ super().__init__()
33
+
34
+ self._transforms = transforms
35
+
36
+ @property
37
+ def classes(self) -> List[str] | None:
38
+ """Returns the list with names of the dataset names."""
39
+
40
+ @property
41
+ def class_to_idx(self) -> Dict[str, int] | None:
42
+ """Returns a mapping of the class name to its target index."""
43
+
44
+ def __getitem__(self, index: int) -> Tuple[InputType, TargetType, Dict[str, Any]]:
45
+ """Returns the `index`'th data sample.
46
+
47
+ Args:
48
+ index: The index of the data sample to load.
49
+
50
+ Returns:
51
+ A tuple with the image, the target and the metadata.
52
+ """
53
+ image = self.load_data(index)
54
+ target = self.load_target(index)
55
+ image, target = self._apply_transforms(image, target)
56
+ return image, target, self.load_metadata(index) or {}
57
+
58
+ def load_metadata(self, index: int) -> Dict[str, Any] | None:
59
+ """Returns the dataset metadata.
60
+
61
+ Args:
62
+ index: The index of the data sample to return the metadata of.
63
+
64
+ Returns:
65
+ The sample metadata.
66
+ """
67
+
68
+ @abc.abstractmethod
69
+ def load_data(self, index: int) -> InputType:
70
+ """Returns the `index`'th data sample.
71
+
72
+ Args:
73
+ index: The index of the data sample to load.
74
+
75
+ Returns:
76
+ The sample data.
77
+ """
78
+
79
+ @abc.abstractmethod
80
+ def load_target(self, index: int) -> TargetType:
81
+ """Returns the `index`'th target sample.
82
+
83
+ Args:
84
+ index: The index of the data sample to load.
85
+
86
+ Returns:
87
+ The sample target.
88
+ """
89
+
15
90
  @abc.abstractmethod
16
91
  def filename(self, index: int) -> str:
17
92
  """Returns the filename of the `index`'th data sample.
@@ -24,3 +99,19 @@ class VisionDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
24
99
  Returns:
25
100
  The filename of the `index`'th data sample.
26
101
  """
102
+
103
+ def _apply_transforms(
104
+ self, image: InputType, target: TargetType
105
+ ) -> Tuple[InputType, TargetType]:
106
+ """Applies the transforms to the provided data and returns them.
107
+
108
+ Args:
109
+ image: The desired image.
110
+ target: The target of the image.
111
+
112
+ Returns:
113
+ A tuple with the image and the target transformed.
114
+ """
115
+ if self._transforms is not None:
116
+ image, target = self._transforms(image, target)
117
+ return image, target