kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaiko-eva might be problematic. Click here for more details.

Files changed (168) hide show
  1. eva/core/callbacks/__init__.py +3 -2
  2. eva/core/callbacks/config.py +143 -0
  3. eva/core/callbacks/writers/__init__.py +6 -3
  4. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  5. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  6. eva/core/callbacks/writers/embeddings/base.py +192 -0
  7. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  8. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  9. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  10. eva/core/data/datasets/__init__.py +10 -2
  11. eva/core/data/datasets/classification/__init__.py +5 -2
  12. eva/core/data/datasets/classification/embeddings.py +15 -135
  13. eva/core/data/datasets/classification/multi_embeddings.py +110 -0
  14. eva/core/data/datasets/embeddings.py +167 -0
  15. eva/core/data/splitting/__init__.py +6 -0
  16. eva/core/data/splitting/random.py +41 -0
  17. eva/core/data/splitting/stratified.py +56 -0
  18. eva/core/data/transforms/__init__.py +3 -1
  19. eva/core/data/transforms/padding/__init__.py +5 -0
  20. eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
  21. eva/core/data/transforms/sampling/__init__.py +5 -0
  22. eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
  23. eva/core/loggers/__init__.py +7 -0
  24. eva/core/loggers/dummy.py +38 -0
  25. eva/core/loggers/experimental_loggers.py +8 -0
  26. eva/core/loggers/log/__init__.py +6 -0
  27. eva/core/loggers/log/image.py +71 -0
  28. eva/core/loggers/log/parameters.py +74 -0
  29. eva/core/loggers/log/utils.py +13 -0
  30. eva/core/loggers/loggers.py +6 -0
  31. eva/core/metrics/__init__.py +6 -2
  32. eva/core/metrics/defaults/__init__.py +10 -3
  33. eva/core/metrics/defaults/classification/__init__.py +1 -1
  34. eva/core/metrics/defaults/classification/binary.py +0 -9
  35. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  36. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  37. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  38. eva/core/metrics/generalized_dice.py +59 -0
  39. eva/core/metrics/mean_iou.py +120 -0
  40. eva/core/metrics/structs/schemas.py +3 -1
  41. eva/core/models/__init__.py +3 -1
  42. eva/core/models/modules/head.py +16 -15
  43. eva/core/models/modules/module.py +25 -1
  44. eva/core/models/modules/typings.py +14 -1
  45. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  46. eva/core/models/networks/__init__.py +1 -2
  47. eva/core/models/networks/mlp.py +2 -2
  48. eva/core/models/transforms/__init__.py +6 -0
  49. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  50. eva/core/models/transforms/extract_patch_features.py +47 -0
  51. eva/core/models/wrappers/__init__.py +13 -0
  52. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  53. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  54. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  55. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  56. eva/core/trainers/_recorder.py +69 -7
  57. eva/core/trainers/functional.py +23 -5
  58. eva/core/trainers/trainer.py +20 -6
  59. eva/core/utils/__init__.py +6 -0
  60. eva/core/utils/clone.py +27 -0
  61. eva/core/utils/memory.py +28 -0
  62. eva/core/utils/operations.py +26 -0
  63. eva/core/utils/parser.py +20 -0
  64. eva/vision/__init__.py +2 -2
  65. eva/vision/callbacks/__init__.py +5 -0
  66. eva/vision/callbacks/loggers/__init__.py +5 -0
  67. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  68. eva/vision/callbacks/loggers/batch/base.py +130 -0
  69. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  70. eva/vision/data/datasets/__init__.py +24 -4
  71. eva/vision/data/datasets/_utils.py +3 -3
  72. eva/vision/data/datasets/_validators.py +15 -2
  73. eva/vision/data/datasets/classification/__init__.py +6 -2
  74. eva/vision/data/datasets/classification/bach.py +10 -15
  75. eva/vision/data/datasets/classification/base.py +17 -24
  76. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  77. eva/vision/data/datasets/classification/crc.py +10 -15
  78. eva/vision/data/datasets/classification/mhist.py +10 -15
  79. eva/vision/data/datasets/classification/panda.py +184 -0
  80. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  81. eva/vision/data/datasets/classification/wsi.py +105 -0
  82. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  83. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  84. eva/vision/data/datasets/segmentation/base.py +31 -47
  85. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  86. eva/vision/data/datasets/segmentation/consep.py +156 -0
  87. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  88. eva/vision/data/datasets/segmentation/lits.py +178 -0
  89. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
  91. eva/vision/data/datasets/wsi.py +187 -0
  92. eva/vision/data/transforms/__init__.py +3 -2
  93. eva/vision/data/transforms/common/__init__.py +2 -1
  94. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  96. eva/vision/data/transforms/normalization/__init__.py +6 -0
  97. eva/vision/data/transforms/normalization/clamp.py +43 -0
  98. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  99. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  100. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  101. eva/vision/data/wsi/__init__.py +16 -0
  102. eva/vision/data/wsi/backends/__init__.py +69 -0
  103. eva/vision/data/wsi/backends/base.py +115 -0
  104. eva/vision/data/wsi/backends/openslide.py +73 -0
  105. eva/vision/data/wsi/backends/pil.py +52 -0
  106. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  107. eva/vision/data/wsi/patching/__init__.py +6 -0
  108. eva/vision/data/wsi/patching/coordinates.py +98 -0
  109. eva/vision/data/wsi/patching/mask.py +123 -0
  110. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  111. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  112. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  113. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  114. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  115. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  116. eva/vision/losses/__init__.py +5 -0
  117. eva/vision/losses/dice.py +40 -0
  118. eva/vision/models/__init__.py +4 -2
  119. eva/vision/models/modules/__init__.py +5 -0
  120. eva/vision/models/modules/semantic_segmentation.py +161 -0
  121. eva/vision/models/networks/__init__.py +1 -2
  122. eva/vision/models/networks/backbones/__init__.py +6 -0
  123. eva/vision/models/networks/backbones/_utils.py +39 -0
  124. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  125. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  126. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  127. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  128. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  129. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  130. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  131. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  132. eva/vision/models/networks/backbones/registry.py +47 -0
  133. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  134. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  135. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  136. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  137. eva/vision/models/networks/decoders/__init__.py +6 -0
  138. eva/vision/models/networks/decoders/decoder.py +7 -0
  139. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  140. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  141. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  142. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  143. eva/vision/models/wrappers/__init__.py +6 -0
  144. eva/vision/models/wrappers/from_registry.py +48 -0
  145. eva/vision/models/wrappers/from_timm.py +68 -0
  146. eva/vision/utils/colormap.py +77 -0
  147. eva/vision/utils/convert.py +67 -0
  148. eva/vision/utils/io/__init__.py +10 -4
  149. eva/vision/utils/io/image.py +21 -2
  150. eva/vision/utils/io/mat.py +36 -0
  151. eva/vision/utils/io/nifti.py +40 -15
  152. eva/vision/utils/io/text.py +10 -3
  153. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  154. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  155. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  156. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  157. eva/core/callbacks/writers/embeddings.py +0 -169
  158. eva/core/callbacks/writers/typings.py +0 -23
  159. eva/core/models/networks/transforms/__init__.py +0 -5
  160. eva/core/models/networks/wrappers/__init__.py +0 -8
  161. eva/vision/data/datasets/classification/total_segmentator.py +0 -213
  162. eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
  163. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  164. eva/vision/models/networks/postprocesses/cls.py +0 -25
  165. kaiko_eva-0.0.1.dist-info/METADATA +0 -405
  166. kaiko_eva-0.0.1.dist-info/RECORD +0 -110
  167. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  168. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,236 @@
1
+ """BCSS dataset."""
2
+
3
+ import glob
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Any, Callable, Dict, List, Literal, Tuple
7
+
8
+ import numpy as np
9
+ import numpy.typing as npt
10
+ import torch
11
+ from torchvision import tv_tensors
12
+ from torchvision.transforms.v2 import functional
13
+ from typing_extensions import override
14
+
15
+ from eva.vision.data.datasets import _validators, wsi
16
+ from eva.vision.data.datasets.segmentation import _utils, base
17
+ from eva.vision.data.wsi.patching import samplers
18
+ from eva.vision.utils import io
19
+
20
+
21
+ class BCSS(wsi.MultiWsiDataset, base.ImageSegmentation):
22
+ """Dataset class for BCSS semantic segmentation task.
23
+
24
+ Source: https://github.com/PathologyDataScience/BCSS
25
+
26
+ We apply the the class grouping proposed by the challenge baseline:
27
+ https://bcsegmentation.grand-challenge.org/Baseline/
28
+
29
+ outside_roi: outside_roi
30
+ tumor: angioinvasion, dcis
31
+ stroma: stroma
32
+ inflammatory: lymphocytic_infiltrate, plasma_cells, other_immune_infiltrate
33
+ necrosis: necrosis_or_debris
34
+ other: remaining
35
+
36
+ Be aware that outside_roi should be assigned zero-weight during model training.
37
+ """
38
+
39
+ _train_split_ratio: float = 0.8
40
+ """Train split ratio."""
41
+
42
+ _val_split_ratio: float = 0.2
43
+ """Validation split ratio."""
44
+
45
+ _expected_length: int = 151
46
+ """Expected dataset length."""
47
+
48
+ _val_institutes = {"BH", "C8", "A8", "A1", "E9"}
49
+ """Medical institutes to use for the validation split."""
50
+
51
+ _test_institutes = {"OL", "LL", "E2", "EW", "GM", "S3"}
52
+ """Medical institutes to use for the test split."""
53
+
54
+ def __init__(
55
+ self,
56
+ root: str,
57
+ sampler: samplers.Sampler,
58
+ split: Literal["train", "val", "trainval", "test"] | None = None,
59
+ width: int = 224,
60
+ height: int = 224,
61
+ target_mpp: float = 0.5,
62
+ transforms: Callable | None = None,
63
+ ) -> None:
64
+ """Initializes the dataset.
65
+
66
+ Args:
67
+ root: Root directory of the dataset.
68
+ sampler: The sampler to use for sampling patch coordinates.
69
+ If `None`, it will use the ::class::`GridSampler` sampler.
70
+ split: Dataset split to use. If `None`, the entire dataset is used.
71
+ width: Width of the patches to be extracted, in pixels.
72
+ height: Height of the patches to be extracted, in pixels.
73
+ target_mpp: Target microns per pixel (mpp) for the patches.
74
+ backend: The backend to use for reading the whole-slide images.
75
+ transforms: Transforms to apply to the extracted image & mask patches.
76
+ """
77
+ self._split = split
78
+ self._root = root
79
+
80
+ self.datasets: List[wsi.WsiDataset] # type: ignore
81
+
82
+ wsi.MultiWsiDataset.__init__(
83
+ self,
84
+ root=root,
85
+ file_paths=self._load_file_paths(split),
86
+ width=width,
87
+ height=height,
88
+ sampler=sampler or samplers.GridSampler(max_samples=1000),
89
+ target_mpp=target_mpp,
90
+ overwrite_mpp=0.25,
91
+ backend="pil",
92
+ )
93
+ base.ImageSegmentation.__init__(self, transforms=transforms)
94
+
95
+ @property
96
+ @override
97
+ def classes(self) -> List[str]:
98
+ return list(self.class_to_idx.keys())
99
+
100
+ @property
101
+ @override
102
+ def class_to_idx(self) -> Dict[str, int]:
103
+ return {
104
+ "outside_roi": 0,
105
+ "tumor": 1,
106
+ "stroma": 2,
107
+ "inflammatory": 3,
108
+ "necrosis": 4,
109
+ "other": 5,
110
+ }
111
+
112
+ @override
113
+ def prepare_data(self) -> None:
114
+ _validators.check_dataset_exists(self._root, True)
115
+
116
+ if not os.path.isdir(os.path.join(self._root, "masks")):
117
+ raise FileNotFoundError(f"'masks' directory not found in {self._root}.")
118
+ if not os.path.isdir(os.path.join(self._root, "rgbs_colorNormalized")):
119
+ raise FileNotFoundError(f"'rgbs_colorNormalized' directory not found in {self._root}.")
120
+
121
+ @override
122
+ def validate(self) -> None:
123
+ _validators.check_dataset_integrity(
124
+ self,
125
+ length=None,
126
+ n_classes=6,
127
+ first_and_last_labels=((self.classes[0], self.classes[-1])),
128
+ )
129
+
130
+ @override
131
+ def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
132
+ return base.ImageSegmentation.__getitem__(self, index)
133
+
134
+ @override
135
+ def load_image(self, index: int) -> tv_tensors.Image:
136
+ image_array = wsi.MultiWsiDataset.__getitem__(self, index)
137
+ return functional.to_image(image_array)
138
+
139
+ @override
140
+ def load_mask(self, index: int) -> tv_tensors.Mask:
141
+ path = self._get_mask_path(index)
142
+ mask = io.read_image_as_array(path)
143
+ mask_patch = _utils.extract_mask_patch(mask, self, index)
144
+ mask_patch = self._map_classes(mask_patch)
145
+ return tv_tensors.Mask(mask_patch, dtype=torch.int64) # type: ignore[reportCallIssue]
146
+
147
+ @override
148
+ def load_metadata(self, index: int) -> Dict[str, Any]:
149
+ (x, y), width, height = _utils.get_coords_at_index(self, index)
150
+ return {"coords": f"{x},{y},{width},{height}"}
151
+
152
+ def _load_file_paths(
153
+ self, split: Literal["train", "val", "trainval", "test"] | None = None
154
+ ) -> List[str]:
155
+ """Loads the file paths of the corresponding dataset split."""
156
+ file_paths = sorted(glob.glob(os.path.join(self._root, "rgbs_colorNormalized/*.png")))
157
+ if len(file_paths) != self._expected_length:
158
+ raise ValueError(
159
+ f"Expected {self._expected_length} images, found {len(file_paths)} in {self._root}."
160
+ )
161
+
162
+ train_indices, val_indices, test_indices = [], [], []
163
+ for i, path in enumerate(file_paths):
164
+ institute = Path(path).stem.split("-")[1]
165
+ if institute in self._test_institutes:
166
+ test_indices.append(i)
167
+ elif institute in self._val_institutes:
168
+ val_indices.append(i)
169
+ else:
170
+ train_indices.append(i)
171
+
172
+ match split:
173
+ case "train":
174
+ return [file_paths[i] for i in train_indices]
175
+ case "val":
176
+ return [file_paths[i] for i in val_indices]
177
+ case "trainval":
178
+ return [file_paths[i] for i in train_indices + val_indices]
179
+ case "test":
180
+ return [file_paths[i] for i in test_indices]
181
+ case None:
182
+ return file_paths
183
+ case _:
184
+ raise ValueError("Invalid split. Use 'train', 'val', 'test' or `None`.")
185
+
186
+ def _get_mask_path(self, index):
187
+ """Returns the path to the mask file corresponding to the patch at the given index."""
188
+ return os.path.join(self._root, "masks", self.filename(index))
189
+
190
+ def _map_classes(self, array: npt.NDArray[Any]) -> npt.NDArray[Any]:
191
+ """Maps the classes of the mask array to the grouped tissue type classes."""
192
+ original_to_grouped_class_mapping = {
193
+ "outside_roi": "outside_roi",
194
+ "angioinvasion": "tumor",
195
+ "dcis": "tumor",
196
+ "stroma": "stroma",
197
+ "lymphocytic_infiltrate": "inflammatory",
198
+ "plasma_cells": "inflammatory",
199
+ "other_immune_infiltrate": "inflammatory",
200
+ "necrosis_or_debris": "necrosis",
201
+ }
202
+
203
+ mapped_array = np.full_like(array, fill_value=self.class_to_idx["other"], dtype=int)
204
+
205
+ for original_class, grouped_class in original_to_grouped_class_mapping.items():
206
+ original_class_idx = _original_class_to_idx[original_class]
207
+ grouped_class_idx = self.class_to_idx[grouped_class]
208
+ mapped_array[array == original_class_idx] = grouped_class_idx
209
+
210
+ return mapped_array
211
+
212
+
213
# Lookup from the original BCSS label names to the integer ids which
# `BCSS._map_classes` matches against the raw mask values. The ids are
# the positions of the names in the tuple below (starting at 0).
_original_class_to_idx = {
    label: index
    for index, label in enumerate(
        (
            "outside_roi",
            "tumor",
            "stroma",
            "lymphocytic_infiltrate",
            "necrosis_or_debris",
            "glandular_secretions",
            "blood",
            "exclude",
            "metaplasia_NOS",
            "fat",
            "plasma_cells",
            "other_immune_infiltrate",
            "mucoid_material",
            "normal_acinus_or_duct",
            "lymphatics",
            "undetermined",
            "nerve",
            "skin_adnexa",
            "blood_vessel",
            "angioinvasion",
            "dcis",
            "other",
        )
    )
}
@@ -0,0 +1,156 @@
1
+ """CoNSeP dataset."""
2
+
3
+ import glob
4
+ import os
5
+ from typing import Any, Callable, Dict, List, Literal, Tuple
6
+
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+ import torch
10
+ from torchvision import tv_tensors
11
+ from torchvision.transforms.v2 import functional
12
+ from typing_extensions import override
13
+
14
+ from eva.vision.data.datasets import _validators, wsi
15
+ from eva.vision.data.datasets.segmentation import _utils, base
16
+ from eva.vision.data.wsi.patching import samplers
17
+ from eva.vision.utils import io
18
+
19
+
20
class CoNSeP(wsi.MultiWsiDataset, base.ImageSegmentation):
    """Dataset class for CoNSeP semantic segmentation task.

    We combine classes 3 (healthy epithelial) & 4 (dysplastic/malignant epithelial)
    into the epithelial class and 5 (fibroblast), 6 (muscle) & 7 (endothelial) into
    the spindle-shaped class.
    """

    _expected_dataset_lengths: Dict[str | None, int] = {
        "train": 27,
        "val": 14,
        None: 41,
    }
    """Expected dataset lengths for the splits and complete dataset."""

    def __init__(
        self,
        root: str,
        sampler: samplers.Sampler | None = None,
        split: Literal["train", "val"] | None = None,
        width: int = 224,
        height: int = 224,
        target_mpp: float = 0.25,
        transforms: Callable | None = None,
    ) -> None:
        """Initializes the dataset.

        Args:
            root: Root directory of the dataset.
            sampler: The sampler to use for sampling patch coordinates.
                If `None`, it will use the ::class::`ForegroundGridSampler`
                sampler with `max_samples=25`.
            split: Dataset split to use. If `None`, the entire dataset is used.
            width: Width of the patches to be extracted, in pixels.
            height: Height of the patches to be extracted, in pixels.
            target_mpp: Target microns per pixel (mpp) for the patches.
            transforms: Transforms to apply to the extracted image & mask patches.
        """
        # Both attributes must be set before `_load_file_paths` is called below.
        self._split = split
        self._root = root

        self.datasets: List[wsi.WsiDataset]  # type: ignore

        wsi.MultiWsiDataset.__init__(
            self,
            root=root,
            file_paths=self._load_file_paths(split),
            width=width,
            height=height,
            sampler=sampler or samplers.ForegroundGridSampler(max_samples=25),
            target_mpp=target_mpp,
            overwrite_mpp=0.25,
            backend="pil",
        )
        # The transforms are handed to the segmentation base (as the sibling
        # BCSS dataset does) so that image and mask are transformed together,
        # instead of running the transforms once per image and once per mask.
        base.ImageSegmentation.__init__(self, transforms=transforms)

    @property
    @override
    def classes(self) -> List[str]:
        return [
            "background",
            "other",
            "inflammatory",
            "epithelial",
            "spindle-shaped",
        ]

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        return {label: index for index, label in enumerate(self.classes)}

    @override
    def prepare_data(self) -> None:
        _validators.check_dataset_exists(self._root, True)

        if not os.path.isdir(os.path.join(self._root, "Train")):
            raise FileNotFoundError(f"Train directory not found in {self._root}.")
        if not os.path.isdir(os.path.join(self._root, "Test")):
            raise FileNotFoundError(f"Test directory not found in {self._root}.")

    @override
    def validate(self) -> None:
        _validators.check_dataset_integrity(
            self,
            length=None,
            n_classes=5,
            first_and_last_labels=((self.classes[0], self.classes[-1])),
        )

    @override
    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
        # Both parents define `__getitem__`; the segmentation one is the public
        # entry point, the WSI one is only used to read the image patch.
        return base.ImageSegmentation.__getitem__(self, index)

    @override
    def load_image(self, index: int) -> tv_tensors.Image:
        image_array = wsi.MultiWsiDataset.__getitem__(self, index)
        return functional.to_image(image_array)

    @override
    def load_mask(self, index: int) -> tv_tensors.Mask:
        path = self._get_mask_path(index)
        mask = np.array(io.read_mat(path)["type_map"])
        mask_patch = _utils.extract_mask_patch(mask, self, index)
        mask_patch = self._map_classes(mask_patch)
        return tv_tensors.Mask(mask_patch, dtype=torch.int64)  # type: ignore[reportCallIssue]

    @override
    def load_metadata(self, index: int) -> Dict[str, Any]:
        (x, y), width, height = _utils.get_coords_at_index(self, index)
        return {"coords": f"{x},{y},{width},{height}"}

    def _load_file_paths(self, split: Literal["train", "val"] | None = None) -> List[str]:
        """Loads the file paths of the corresponding dataset split.

        Args:
            split: The split to load, or `None` for all the files. The
                `train` split reads from the `Train` folder and the `val`
                split from the `Test` folder.

        Returns:
            The sorted image file paths which belong to the split.

        Raises:
            ValueError: If the total number of images found differs from
                the expected one.
            KeyError: If `split` is neither `train`, `val` nor `None`.
        """
        paths = glob.glob(os.path.join(self._root, "**/Images/*.png"), recursive=True)
        n_expected = self._expected_dataset_lengths[None]
        if len(paths) != n_expected:
            raise ValueError(f"Expected {n_expected} images, found {len(paths)} in {self._root}.")

        if split is not None:
            folder = {"train": "Train", "val": "Test"}[split]
            # Expected layout is `<folder>/Images/<file>.png`; the path is split on
            # the OS separator so that the check also works with `\\` separators.
            paths = [
                path for path in paths if os.path.normpath(path).split(os.sep)[-3] == folder
            ]

        return sorted(paths)

    def _get_mask_path(self, index: int) -> str:
        """Returns the path to the mask file corresponding to the patch at the given index."""
        filename = self.filename(index).split(".")[0]
        mask_dir = "Train" if filename.startswith("train") else "Test"
        return os.path.join(self._root, mask_dir, "Labels", f"{filename}.mat")

    def _map_classes(self, array: npt.NDArray[Any]) -> npt.NDArray[Any]:
        """Merges class 4 into class 3, and every class above 4 into class 4.

        Args:
            array: The mask holding the original CoNSeP label indices.

        Returns:
            The mask with the grouped class indices.
        """
        array = np.where(array == 4, 3, array)
        array = np.where(array > 4, 4, array)
        return array
@@ -0,0 +1,34 @@
1
+ """Embeddings based semantic segmentation dataset."""
2
+
3
+ import os
4
+ from typing import List
5
+
6
+ import torch
7
+ from torchvision import tv_tensors
8
+ from typing_extensions import override
9
+
10
+ from eva.core.data.datasets import embeddings as embeddings_base
11
+
12
+
13
class EmbeddingsSegmentationDataset(embeddings_base.EmbeddingsDataset[tv_tensors.Mask]):
    """Segmentation dataset which reads stored embeddings and mask targets."""

    @override
    def _load_embeddings(self, index: int) -> List[torch.Tensor]:
        path = os.path.join(self._root, self.filename(index))
        loaded = torch.load(path, map_location="cpu")
        # A file may hold a single tensor or a sequence of tensors;
        # a single tensor is wrapped so both cases share the same path.
        tensors = [loaded] if isinstance(loaded, torch.Tensor) else loaded
        return [tensor.squeeze(0) for tensor in tensors]

    @override
    def _load_target(self, index: int) -> tv_tensors.Mask:
        target_column = self._column_mapping["target"]
        path = os.path.join(self._root, self._data.at[index, target_column])
        labels = torch.load(path, map_location="cpu")
        return tv_tensors.Mask(labels, dtype=torch.int64)  # type: ignore[reportCallIssue]

    @override
    def __len__(self) -> int:
        return len(self._data)
@@ -0,0 +1,178 @@
1
+ """LiTS dataset."""
2
+
3
+ import functools
4
+ import glob
5
+ import os
6
+ from typing import Any, Callable, Dict, List, Literal, Tuple
7
+
8
+ import torch
9
+ from torchvision import tv_tensors
10
+ from typing_extensions import override
11
+
12
+ from eva.core import utils
13
+ from eva.vision.data.datasets import _utils as data_utils
14
+ from eva.vision.data.datasets import _validators
15
+ from eva.vision.data.datasets.segmentation import base
16
+ from eva.vision.utils import io
17
+
18
+
19
class LiTS(base.ImageSegmentation):
    """LiTS - Liver Tumor Segmentation Challenge.

    Webpage: https://competitions.codalab.org/competitions/17094

    For the splits we follow: https://arxiv.org/pdf/2010.01663v2
    """

    _train_index_ranges: List[Tuple[int, int]] = [(0, 102)]
    _val_index_ranges: List[Tuple[int, int]] = [(102, 117)]
    _test_index_ranges: List[Tuple[int, int]] = [(117, 131)]
    """Index ranges per split."""

    _sample_every_n_slices: int | None = None
    """The amount of slices to sub-sample per 3D CT scan image."""

    _expected_dataset_lengths: Dict[str | None, int] = {
        "train": 39307,
        "val": 12045,
        "test": 7286,
        None: 58638,
    }
    """Dataset version and split to the expected size."""

    _license: str = (
        "Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License "
        "(https://creativecommons.org/licenses/by-nc-nd/4.0/deed.en)"
    )
    """Dataset license."""

    def __init__(
        self,
        root: str,
        split: Literal["train", "val", "test"] | None = None,
        transforms: Callable | None = None,
    ) -> None:
        """Initialize dataset.

        Args:
            root: Path to the root directory of the dataset. The volume and
                segmentation files are searched recursively under it.
            split: Dataset split to use. If `None`, the entire dataset is used.
            transforms: A function/transforms that takes in an image and a target
                mask and returns the transformed versions of both.
        """
        super().__init__(transforms=transforms)

        self._root = root
        self._split = split

        # Populated in `configure` as (sample index, slice index) pairs.
        self._indices: List[Tuple[int, int]] = []

    @property
    @override
    def classes(self) -> List[str]:
        # The LiTS segmentation volumes label 0 as background, 1 as liver and
        # 2 as tumor, so "background" is required for `class_to_idx` to agree
        # with the values stored in the masks.
        return ["background", "liver", "tumor"]

    @functools.cached_property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        return {label: index for index, label in enumerate(self.classes)}

    @override
    def filename(self, index: int) -> str:
        sample_index, _ = self._indices[index]
        volume_file_path = self._volume_files[sample_index]
        return os.path.relpath(volume_file_path, self._root)

    @override
    def configure(self) -> None:
        self._indices = self._create_indices()

    @override
    def validate(self) -> None:
        if len(self._volume_files) != len(self._segmentation_files):
            raise ValueError(
                "The number of volume files does not match the number of the segmentation ones."
            )

        _validators.check_dataset_integrity(
            self,
            length=self._expected_dataset_lengths.get(self._split, 0),
            n_classes=3,
            first_and_last_labels=("background", "tumor"),
        )

    @override
    def load_image(self, index: int) -> tv_tensors.Image:
        sample_index, slice_index = self._indices[index]
        volume_path = self._volume_files[sample_index]
        image_array = io.read_nifti(volume_path, slice_index)
        # Moves the last axis of the slice array to the front.
        return tv_tensors.Image(image_array.transpose(2, 0, 1))

    @override
    def load_mask(self, index: int) -> tv_tensors.Mask:
        sample_index, slice_index = self._indices[index]
        segmentation_path = self._segmentation_files[sample_index]
        semantic_labels = io.read_nifti(segmentation_path, slice_index)
        return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]

    @override
    def load_metadata(self, index: int) -> Dict[str, Any]:
        _, slice_index = self._indices[index]
        return {"slice_index": slice_index}

    @override
    def __len__(self) -> int:
        return len(self._indices)

    def _get_number_of_slices_per_volume(self, sample_index: int) -> int:
        """Returns the total amount of slices of a volume."""
        file_path = self._volume_files[sample_index]
        volume_shape = io.fetch_nifti_shape(file_path)
        return volume_shape[-1]

    @functools.cached_property
    def _volume_files(self) -> List[str]:
        """The numerically sorted `volume-*.nii` file paths found under the root."""
        files_pattern = os.path.join(self._root, "**", "volume-*.nii")
        files = glob.glob(files_pattern, recursive=True)
        return utils.numeric_sort(files)

    @functools.cached_property
    def _segmentation_files(self) -> List[str]:
        """The numerically sorted `segmentation-*.nii` file paths found under the root."""
        files_pattern = os.path.join(self._root, "**", "segmentation-*.nii")
        files = glob.glob(files_pattern, recursive=True)
        return utils.numeric_sort(files)

    def _create_indices(self) -> List[Tuple[int, int]]:
        """Builds the dataset indices for the specified split.

        Returns:
            A list of tuples, where the first value indicates the
            sample index while the second its corresponding slice
            index.
        """
        # A step of 1 keeps every slice when no sub-sampling is configured.
        step = self._sample_every_n_slices or 1
        indices = [
            (sample_idx, slice_idx)
            for sample_idx in self._get_split_indices()
            for slice_idx in range(self._get_number_of_slices_per_volume(sample_idx))
            if slice_idx % step == 0
        ]
        return indices

    def _get_split_indices(self) -> List[int]:
        """Returns the sample indices for the specified dataset split.

        Raises:
            ValueError: If the dataset split is not a supported value.
        """
        split_index_ranges = {
            "train": self._train_index_ranges,
            "val": self._val_index_ranges,
            "test": self._test_index_ranges,
            None: [(0, len(self._volume_files))],
        }
        index_ranges = split_index_ranges.get(self._split)
        if index_ranges is None:
            raise ValueError("Invalid data split. Use 'train', 'val', 'test' or `None`.")

        return data_utils.ranges_to_indices(index_ranges)

    def _print_license(self) -> None:
        """Prints the dataset license."""
        print(f"Dataset license: {self._license}")