kaiko-eva 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff reflects the changes between package versions as published to their public registries and is provided for informational purposes only.

Files changed (159)
  1. eva/core/callbacks/__init__.py +2 -2
  2. eva/core/callbacks/writers/__init__.py +6 -3
  3. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  4. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  5. eva/core/callbacks/writers/embeddings/base.py +192 -0
  6. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  7. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  8. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  9. eva/core/data/datasets/__init__.py +2 -2
  10. eva/core/data/datasets/classification/__init__.py +8 -0
  11. eva/core/data/datasets/classification/embeddings.py +34 -0
  12. eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
  13. eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
  14. eva/core/data/splitting/__init__.py +6 -0
  15. eva/core/data/splitting/random.py +41 -0
  16. eva/core/data/splitting/stratified.py +56 -0
  17. eva/core/loggers/experimental_loggers.py +2 -2
  18. eva/core/loggers/log/__init__.py +3 -2
  19. eva/core/loggers/log/image.py +71 -0
  20. eva/core/loggers/log/parameters.py +10 -0
  21. eva/core/loggers/loggers.py +6 -0
  22. eva/core/metrics/__init__.py +6 -2
  23. eva/core/metrics/defaults/__init__.py +10 -3
  24. eva/core/metrics/defaults/classification/__init__.py +1 -1
  25. eva/core/metrics/defaults/classification/binary.py +0 -9
  26. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  27. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  28. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  29. eva/core/metrics/generalized_dice.py +59 -0
  30. eva/core/metrics/mean_iou.py +120 -0
  31. eva/core/metrics/structs/schemas.py +3 -1
  32. eva/core/models/__init__.py +3 -1
  33. eva/core/models/modules/head.py +10 -4
  34. eva/core/models/modules/typings.py +14 -1
  35. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  36. eva/core/models/networks/__init__.py +1 -2
  37. eva/core/models/networks/mlp.py +2 -2
  38. eva/core/models/transforms/__init__.py +6 -0
  39. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  40. eva/core/models/transforms/extract_patch_features.py +47 -0
  41. eva/core/models/wrappers/__init__.py +13 -0
  42. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  43. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  44. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  45. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  46. eva/core/trainers/functional.py +1 -0
  47. eva/core/utils/__init__.py +6 -0
  48. eva/core/utils/clone.py +27 -0
  49. eva/core/utils/memory.py +28 -0
  50. eva/core/utils/operations.py +26 -0
  51. eva/core/utils/parser.py +20 -0
  52. eva/vision/__init__.py +2 -2
  53. eva/vision/callbacks/__init__.py +5 -0
  54. eva/vision/callbacks/loggers/__init__.py +5 -0
  55. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  56. eva/vision/callbacks/loggers/batch/base.py +130 -0
  57. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  58. eva/vision/data/datasets/__init__.py +30 -3
  59. eva/vision/data/datasets/_validators.py +15 -2
  60. eva/vision/data/datasets/classification/__init__.py +12 -1
  61. eva/vision/data/datasets/classification/bach.py +10 -15
  62. eva/vision/data/datasets/classification/base.py +17 -24
  63. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  64. eva/vision/data/datasets/classification/crc.py +10 -15
  65. eva/vision/data/datasets/classification/mhist.py +10 -15
  66. eva/vision/data/datasets/classification/panda.py +184 -0
  67. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  68. eva/vision/data/datasets/classification/wsi.py +105 -0
  69. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  70. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  71. eva/vision/data/datasets/segmentation/base.py +16 -17
  72. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  73. eva/vision/data/datasets/segmentation/consep.py +156 -0
  74. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  75. eva/vision/data/datasets/segmentation/lits.py +178 -0
  76. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  77. eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
  78. eva/vision/data/datasets/wsi.py +187 -0
  79. eva/vision/data/transforms/__init__.py +3 -2
  80. eva/vision/data/transforms/common/__init__.py +2 -1
  81. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  82. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  83. eva/vision/data/transforms/normalization/__init__.py +6 -0
  84. eva/vision/data/transforms/normalization/clamp.py +43 -0
  85. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  86. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  87. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  88. eva/vision/data/wsi/__init__.py +16 -0
  89. eva/vision/data/wsi/backends/__init__.py +69 -0
  90. eva/vision/data/wsi/backends/base.py +115 -0
  91. eva/vision/data/wsi/backends/openslide.py +73 -0
  92. eva/vision/data/wsi/backends/pil.py +52 -0
  93. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  94. eva/vision/data/wsi/patching/__init__.py +6 -0
  95. eva/vision/data/wsi/patching/coordinates.py +98 -0
  96. eva/vision/data/wsi/patching/mask.py +123 -0
  97. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  98. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  99. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  100. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  101. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  102. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  103. eva/vision/losses/__init__.py +5 -0
  104. eva/vision/losses/dice.py +40 -0
  105. eva/vision/models/__init__.py +4 -2
  106. eva/vision/models/modules/__init__.py +5 -0
  107. eva/vision/models/modules/semantic_segmentation.py +161 -0
  108. eva/vision/models/networks/__init__.py +1 -2
  109. eva/vision/models/networks/backbones/__init__.py +6 -0
  110. eva/vision/models/networks/backbones/_utils.py +39 -0
  111. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  112. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  113. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  114. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  115. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  116. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  117. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  118. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  119. eva/vision/models/networks/backbones/registry.py +47 -0
  120. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  121. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  122. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  123. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  124. eva/vision/models/networks/decoders/__init__.py +6 -0
  125. eva/vision/models/networks/decoders/decoder.py +7 -0
  126. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  127. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  128. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  129. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  130. eva/vision/models/wrappers/__init__.py +6 -0
  131. eva/vision/models/wrappers/from_registry.py +48 -0
  132. eva/vision/models/wrappers/from_timm.py +68 -0
  133. eva/vision/utils/colormap.py +77 -0
  134. eva/vision/utils/convert.py +56 -13
  135. eva/vision/utils/io/__init__.py +10 -4
  136. eva/vision/utils/io/image.py +21 -2
  137. eva/vision/utils/io/mat.py +36 -0
  138. eva/vision/utils/io/nifti.py +33 -12
  139. eva/vision/utils/io/text.py +10 -3
  140. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  141. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  142. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  143. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  144. eva/.DS_Store +0 -0
  145. eva/core/callbacks/writers/embeddings.py +0 -169
  146. eva/core/callbacks/writers/typings.py +0 -23
  147. eva/core/data/datasets/embeddings/__init__.py +0 -13
  148. eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
  149. eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
  150. eva/core/models/networks/transforms/__init__.py +0 -5
  151. eva/core/models/networks/wrappers/__init__.py +0 -8
  152. eva/vision/models/.DS_Store +0 -0
  153. eva/vision/models/networks/.DS_Store +0 -0
  154. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  155. eva/vision/models/networks/postprocesses/cls.py +0 -25
  156. kaiko_eva-0.0.2.dist-info/METADATA +0 -431
  157. kaiko_eva-0.0.2.dist-info/RECORD +0 -127
  158. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  159. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
--- a/eva/vision/data/datasets/segmentation/total_segmentator.py
+++ b/eva/vision/data/datasets/segmentation/total_segmentator_2d.py
@@ -3,24 +3,30 @@
 import functools
 import os
 from glob import glob
-from typing import Callable, Dict, List, Literal, Tuple
+from typing import Any, Callable, Dict, List, Literal, Tuple

 import numpy as np
+import numpy.typing as npt
+import torch
+import tqdm
 from torchvision import tv_tensors
 from torchvision.datasets import utils
 from typing_extensions import override

-from eva.vision.data.datasets import _utils, _validators, structs
+from eva.vision.data.datasets import _validators, structs
 from eva.vision.data.datasets.segmentation import base
-from eva.vision.utils import convert, io
+from eva.vision.utils import io


 class TotalSegmentator2D(base.ImageSegmentation):
     """TotalSegmentator 2D segmentation dataset."""

     _expected_dataset_lengths: Dict[str, int] = {
-        "train_small": 29892,
-        "val_small": 6480,
+        "train_small": 35089,
+        "val_small": 1283,
+        "train_full": 278190,
+        "val_full": 14095,
+        "test_full": 25578,
     }
     """Dataset version and split to the expected size."""

@@ -45,13 +51,20 @@ class TotalSegmentator2D(base.ImageSegmentation):
     ]
     """Resources for the small dataset version."""

+    _license: str = (
+        "Creative Commons Attribution 4.0 International "
+        "(https://creativecommons.org/licenses/by/4.0/deed.en)"
+    )
+    """Dataset license."""
+
     def __init__(
         self,
         root: str,
-        split: Literal["train", "val"] | None,
-        version: Literal["small", "full"] | None = "small",
+        split: Literal["train", "val", "test"] | None,
+        version: Literal["small", "full"] | None = "full",
         download: bool = False,
-        as_uint8: bool = True,
+        classes: List[str] | None = None,
+        optimize_mask_loading: bool = True,
         transforms: Callable | None = None,
     ) -> None:
         """Initialize dataset.
@@ -66,7 +79,12 @@ class TotalSegmentator2D(base.ImageSegmentation):
                 Note that the download will be executed only by additionally
                 calling the :meth:`prepare_data` method and if the data does not
                 exist yet on disk.
-            as_uint8: Whether to convert and return the images as a 8-bit.
+            classes: Whether to configure the dataset with a subset of classes.
+                If `None`, it will use all of them.
+            optimize_mask_loading: Whether to pre-process the segmentation masks
+                in order to optimize the loading time. In the `setup` method, it
+                will reformat the binary one-hot masks to a semantic mask and store
+                it on disk.
             transforms: A function/transforms that takes in an image and a target
                 mask and returns the transformed versions of both.
         """
@@ -76,7 +94,13 @@ class TotalSegmentator2D(base.ImageSegmentation):
         self._split = split
         self._version = version
         self._download = download
-        self._as_uint8 = as_uint8
+        self._classes = classes
+        self._optimize_mask_loading = optimize_mask_loading
+
+        if self._optimize_mask_loading and self._classes is not None:
+            raise ValueError(
+                "To use customize classes please set the optimize_mask_loading to `False`."
+            )

         self._samples_dirs: List[str] = []
         self._indices: List[Tuple[int, int]] = []
@@ -91,7 +115,13 @@ class TotalSegmentator2D(base.ImageSegmentation):
         first_sample_labels = os.path.join(
             self._root, self._samples_dirs[0], "segmentations", "*.nii.gz"
         )
-        return sorted(map(get_filename, glob(first_sample_labels)))
+        all_classes = sorted(map(get_filename, glob(first_sample_labels)))
+        if self._classes:
+            is_subset = all(name in all_classes for name in self._classes)
+            if not is_subset:
+                raise ValueError("Provided class names are not subset of the dataset onces.")
+
+        return all_classes if self._classes is None else self._classes

     @property
     @override
@@ -99,7 +129,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
         return {label: index for index, label in enumerate(self.classes)}

     @override
-    def filename(self, index: int) -> str:
+    def filename(self, index: int, segmented: bool = True) -> str:
         sample_idx, _ = self._indices[index]
         sample_dir = self._samples_dirs[sample_idx]
         return os.path.join(sample_dir, "ct.nii.gz")
@@ -113,17 +143,23 @@ class TotalSegmentator2D(base.ImageSegmentation):
     def configure(self) -> None:
         self._samples_dirs = self._fetch_samples_dirs()
         self._indices = self._create_indices()
+        if self._optimize_mask_loading:
+            self._export_semantic_label_masks()

     @override
     def validate(self) -> None:
-        if self._version is None:
+        if self._version is None or self._sample_every_n_slices is not None:
             return

         _validators.check_dataset_integrity(
             self,
             length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0),
-            n_classes=117,
-            first_and_last_labels=("adrenal_gland_left", "vertebrae_T9"),
+            n_classes=len(self._classes) if self._classes else 117,
+            first_and_last_labels=(
+                (self._classes[0], self._classes[-1])
+                if self._classes
+                else ("adrenal_gland_left", "vertebrae_T9")
+            ),
         )

     @override
@@ -134,25 +170,68 @@ class TotalSegmentator2D(base.ImageSegmentation):
     def load_image(self, index: int) -> tv_tensors.Image:
         sample_index, slice_index = self._indices[index]
         image_path = self._get_image_path(sample_index)
-        image_array = io.read_nifti_slice(image_path, slice_index)
-        if self._as_uint8:
-            image_array = convert.to_8bit(image_array)
+        image_array = io.read_nifti(image_path, slice_index)
         image_rgb_array = image_array.repeat(3, axis=2)
         return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1))

     @override
     def load_mask(self, index: int) -> tv_tensors.Mask:
+        if self._optimize_mask_loading:
+            return self._load_semantic_label_mask(index)
+        return self._load_mask(index)
+
+    @override
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        _, slice_index = self._indices[index]
+        return {"slice_index": slice_index}
+
+    def _load_mask(self, index: int) -> tv_tensors.Mask:
+        """Loads and builds the segmentation mask from NifTi files."""
+        sample_index, slice_index = self._indices[index]
+        semantic_labels = self._load_masks_as_semantic_label(sample_index, slice_index)
+        return tv_tensors.Mask(semantic_labels, dtype=torch.int64)  # type: ignore[reportCallIssue]
+
+    def _load_semantic_label_mask(self, index: int) -> tv_tensors.Mask:
+        """Loads the segmentation mask from a semantic label NifTi file."""
         sample_index, slice_index = self._indices[index]
         masks_dir = self._get_masks_dir(sample_index)
-        mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes)
-        one_hot_encoded = np.concatenate(
-            [io.read_nifti_slice(path, slice_index) for path in mask_paths],
-            axis=2,
-        )
-        background_mask = one_hot_encoded.sum(axis=2, keepdims=True) == 0
-        one_hot_encoded_with_bg = np.concatenate([background_mask, one_hot_encoded], axis=2)
-        segmentation_label = np.argmax(one_hot_encoded_with_bg, axis=2)
-        return tv_tensors.Mask(segmentation_label)
+        filename = os.path.join(masks_dir, "semantic_labels", "masks.nii.gz")
+        semantic_labels = io.read_nifti(filename, slice_index)
+        return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
+
+    def _load_masks_as_semantic_label(
+        self, sample_index: int, slice_index: int | None = None
+    ) -> npt.NDArray[Any]:
+        """Loads binary masks as a semantic label mask.
+
+        Args:
+            sample_index: The data sample index.
+            slice_index: Whether to return only a specific slice.
+        """
+        masks_dir = self._get_masks_dir(sample_index)
+        mask_paths = [os.path.join(masks_dir, label + ".nii.gz") for label in self.classes]
+        binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]
+        background_mask = np.zeros_like(binary_masks[0])
+        return np.argmax([background_mask] + binary_masks, axis=0)
+
+    def _export_semantic_label_masks(self) -> None:
+        """Exports the segmentation binary masks (one-hot) to semantic labels."""
+        total_samples = len(self._samples_dirs)
+        masks_dirs = map(self._get_masks_dir, range(total_samples))
+        semantic_labels = [
+            (index, os.path.join(directory, "semantic_labels", "masks.nii.gz"))
+            for index, directory in enumerate(masks_dirs)
+        ]
+        to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels)
+
+        for sample_index, filename in tqdm.tqdm(
+            list(to_export),
+            desc=">> Exporting optimized semantic masks",
+            leave=False,
+        ):
+            semantic_labels = self._load_masks_as_semantic_label(sample_index)
+            os.makedirs(os.path.dirname(filename), exist_ok=True)
+            io.save_array_as_nifti(semantic_labels, filename)

     def _get_image_path(self, sample_index: int) -> str:
         """Returns the corresponding image path."""
@@ -164,10 +243,16 @@ class TotalSegmentator2D(base.ImageSegmentation):
         sample_dir = self._samples_dirs[sample_index]
         return os.path.join(self._root, sample_dir, "segmentations")

+    def _get_semantic_labels_filename(self, sample_index: int) -> str:
+        """Returns the semantic label filename."""
+        masks_dir = self._get_masks_dir(sample_index)
+        return os.path.join(masks_dir, "semantic_labels", "masks.nii.gz")
+
     def _get_number_of_slices_per_sample(self, sample_index: int) -> int:
         """Returns the total amount of slices of a sample."""
         image_path = self._get_image_path(sample_index)
-        return io.fetch_total_nifti_slices(image_path)
+        image_shape = io.fetch_nifti_shape(image_path)
+        return image_shape[-1]

     def _fetch_samples_dirs(self) -> List[str]:
         """Returns the name of all the samples of all the splits of the dataset."""
@@ -180,16 +265,20 @@ class TotalSegmentator2D(base.ImageSegmentation):

     def _get_split_indices(self) -> List[int]:
         """Returns the samples indices that corresponding the dataset split and version."""
-        key = f"{self._split}_{self._version}"
-        match key:
-            case "train_small":
-                index_ranges = [(0, 83)]
-            case "val_small":
-                index_ranges = [(83, 102)]
+        metadata_file = os.path.join(self._root, "meta.csv")
+        metadata = io.read_csv(metadata_file, delimiter=";", encoding="utf-8-sig")
+
+        match self._split:
+            case "train":
+                image_ids = [item["image_id"] for item in metadata if item["split"] == "train"]
+            case "val":
+                image_ids = [item["image_id"] for item in metadata if item["split"] == "val"]
+            case "test":
+                image_ids = [item["image_id"] for item in metadata if item["split"] == "test"]
             case _:
-                index_ranges = [(0, len(self._samples_dirs))]
+                image_ids = self._samples_dirs

-        return _utils.ranges_to_indices(index_ranges)
+        return sorted(map(self._samples_dirs.index, image_ids))

     def _create_indices(self) -> List[Tuple[int, int]]:
         """Builds the dataset indices for the specified split.
@@ -219,6 +308,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
                 f"Can't download data version '{self._version}'. Use 'small' or 'full'."
             )

+        self._print_license()
         for resource in resources:
             if os.path.isdir(self._root):
                 continue
@@ -229,3 +319,7 @@
                 filename=resource.filename,
                 remove_finished=True,
             )
+
+    def _print_license(self) -> None:
+        """Prints the dataset license."""
+        print(f"Dataset license: {self._license}")
--- /dev/null
+++ b/eva/vision/data/datasets/wsi.py
@@ -0,0 +1,187 @@
+"""Dataset classes for whole-slide images."""
+
+import bisect
+import os
+from typing import Callable, List
+
+from loguru import logger
+from torch.utils.data import dataset as torch_datasets
+from torchvision import tv_tensors
+from torchvision.transforms.v2 import functional
+from typing_extensions import override
+
+from eva.vision.data import wsi
+from eva.vision.data.datasets import vision
+from eva.vision.data.wsi.patching import samplers
+
+
+class WsiDataset(vision.VisionDataset):
+    """Dataset class for reading patches from whole-slide images."""
+
+    def __init__(
+        self,
+        file_path: str,
+        width: int,
+        height: int,
+        sampler: samplers.Sampler,
+        target_mpp: float,
+        overwrite_mpp: float | None = None,
+        backend: str = "openslide",
+        image_transforms: Callable | None = None,
+    ):
+        """Initializes a new dataset instance.
+
+        Args:
+            file_path: Path to the whole-slide image file.
+            width: Width of the patches to be extracted, in pixels.
+            height: Height of the patches to be extracted, in pixels.
+            sampler: The sampler to use for sampling patch coordinates.
+            target_mpp: Target microns per pixel (mpp) for the patches.
+            overwrite_mpp: The microns per pixel (mpp) value to use when missing in WSI metadata.
+            backend: The backend to use for reading the whole-slide images.
+            image_transforms: Transforms to apply to the extracted image patches.
+        """
+        super().__init__()
+
+        self._file_path = file_path
+        self._width = width
+        self._height = height
+        self._sampler = sampler
+        self._target_mpp = target_mpp
+        self._overwrite_mpp = overwrite_mpp
+        self._backend = backend
+        self._image_transforms = image_transforms
+
+    @override
+    def __len__(self):
+        return len(self._coords.x_y)
+
+    @override
+    def filename(self, index: int) -> str:
+        return f"{self._file_path}_{index}"
+
+    @property
+    def _wsi(self) -> wsi.Wsi:
+        return wsi.get_cached_wsi(self._file_path, self._backend, self._overwrite_mpp)
+
+    @property
+    def _coords(self) -> wsi.PatchCoordinates:
+        return wsi.get_cached_coords(
+            file_path=self._file_path,
+            width=self._width,
+            height=self._height,
+            target_mpp=self._target_mpp,
+            overwrite_mpp=self._overwrite_mpp,
+            sampler=self._sampler,
+            backend=self._backend,
+        )
+
+    @override
+    def __getitem__(self, index: int) -> tv_tensors.Image:
+        x, y = self._coords.x_y[index]
+        width, height, level_idx = self._coords.width, self._coords.height, self._coords.level_idx
+        patch = self._wsi.read_region((x, y), level_idx, (width, height))
+        patch = functional.to_image(patch)
+        patch = self._apply_transforms(patch)
+        return patch
+
+    def _apply_transforms(self, image: tv_tensors.Image) -> tv_tensors.Image:
+        if self._image_transforms is not None:
+            image = self._image_transforms(image)
+        return image
+
+
+class MultiWsiDataset(vision.VisionDataset):
+    """Dataset class for reading patches from multiple whole-slide images."""
+
+    def __init__(
+        self,
+        root: str,
+        file_paths: List[str],
+        width: int,
+        height: int,
+        sampler: samplers.Sampler,
+        target_mpp: float,
+        overwrite_mpp: float | None = None,
+        backend: str = "openslide",
+        image_transforms: Callable | None = None,
+    ):
+        """Initializes a new dataset instance.
+
+        Args:
+            root: Root directory of the dataset.
+            file_paths: List of paths to the whole-slide image files, relative to the root.
+            width: Width of the patches to be extracted, in pixels.
+            height: Height of the patches to be extracted, in pixels.
+            target_mpp: Target microns per pixel (mpp) for the patches.
+            overwrite_mpp: The microns per pixel (mpp) value to use when missing in WSI metadata.
+            sampler: The sampler to use for sampling patch coordinates.
+            backend: The backend to use for reading the whole-slide images.
+            image_transforms: Transforms to apply to the extracted image patches.
+        """
+        super().__init__()
+
+        self._root = root
+        self._file_paths = file_paths
+        self._width = width
+        self._height = height
+        self._target_mpp = target_mpp
+        self._overwrite_mpp = overwrite_mpp
+        self._sampler = sampler
+        self._backend = backend
+        self._image_transforms = image_transforms
+
+        self._concat_dataset: torch_datasets.ConcatDataset
+
+    @property
+    def datasets(self) -> List[WsiDataset]:
+        """Returns the list of WSI datasets."""
+        return self._concat_dataset.datasets  # type: ignore
+
+    @property
+    def cumulative_sizes(self) -> List[int]:
+        """Returns the cumulative sizes of the WSI datasets."""
+        return self._concat_dataset.cumulative_sizes
+
+    @override
+    def configure(self) -> None:
+        self._concat_dataset = torch_datasets.ConcatDataset(datasets=self._load_datasets())
+
+    @override
+    def __len__(self) -> int:
+        return len(self._concat_dataset)
+
+    @override
+    def __getitem__(self, index: int) -> tv_tensors.Image:
+        return self._concat_dataset[index]
+
+    @override
+    def filename(self, index: int) -> str:
+        return os.path.basename(self._file_paths[self._get_dataset_idx(index)])
+
+    def _load_datasets(self) -> list[WsiDataset]:
+        logger.info(f"Initializing dataset with {len(self._file_paths)} WSIs ...")
+        wsi_datasets = []
+        for file_path in self._file_paths:
+            file_path = (
+                os.path.join(self._root, file_path) if self._root not in file_path else file_path
+            )
+            if not os.path.exists(file_path):
+                raise FileNotFoundError(f"File not found: {file_path}")
+
+            wsi_datasets.append(
+                WsiDataset(
+                    file_path=file_path,
+                    width=self._width,
+                    height=self._height,
+                    sampler=self._sampler,
+                    target_mpp=self._target_mpp,
+                    overwrite_mpp=self._overwrite_mpp,
+                    backend=self._backend,
+                    image_transforms=self._image_transforms,
+                )
+            )
+        return wsi_datasets
+
+    def _get_dataset_idx(self, index: int) -> int:
+        return bisect.bisect_right(self.cumulative_sizes, index)
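Global patch indices are mapped back to their source slide with `bisect.bisect_right` over `ConcatDataset.cumulative_sizes`, which is how `filename` recovers the originating WSI. A hypothetical usage sketch follows; the `GridSampler` name is assumed from the new `samplers/grid.py` module, and `configure()` is assumed to be called by the framework or the caller before indexing:

from eva.vision.data.datasets import wsi as wsi_datasets
from eva.vision.data.wsi.patching import samplers

dataset = wsi_datasets.MultiWsiDataset(
    root="/data/slides",
    file_paths=["case_0.svs", "case_1.svs"],  # resolved relative to `root`
    width=224,
    height=224,
    sampler=samplers.GridSampler(),  # assumed sampler class name
    target_mpp=0.5,
    backend="openslide",
)
dataset.configure()  # builds the internal ConcatDataset over per-slide datasets
patch = dataset[0]   # tv_tensors.Image patch read from the first slide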
--- a/eva/vision/data/transforms/__init__.py
+++ b/eva/vision/data/transforms/__init__.py
@@ -1,5 +1,6 @@
 """Vision data transforms."""

-from eva.vision.data.transforms.common import ResizeAndCrop
+from eva.vision.data.transforms.common import ResizeAndClamp, ResizeAndCrop
+from eva.vision.data.transforms.normalization import Clamp, RescaleIntensity

-__all__ = ["ResizeAndCrop"]
+__all__ = ["ResizeAndCrop", "ResizeAndClamp", "Clamp", "RescaleIntensity"]
--- a/eva/vision/data/transforms/common/__init__.py
+++ b/eva/vision/data/transforms/common/__init__.py
@@ -1,5 +1,6 @@
 """Common vision transforms."""

+from eva.vision.data.transforms.common.resize_and_clamp import ResizeAndClamp
 from eva.vision.data.transforms.common.resize_and_crop import ResizeAndCrop

-__all__ = ["ResizeAndCrop"]
+__all__ = ["ResizeAndClamp", "ResizeAndCrop"]
--- /dev/null
+++ b/eva/vision/data/transforms/common/resize_and_clamp.py
@@ -0,0 +1,51 @@
+"""Specialized transforms for resizing, clamping and range normalizing."""
+
+from typing import Callable, Sequence, Tuple
+
+from torchvision.transforms import v2
+
+from eva.vision.data.transforms import normalization
+
+
+class ResizeAndClamp(v2.Compose):
+    """Resizes, crops, clamps and normalizes an input image."""
+
+    def __init__(
+        self,
+        size: int | Sequence[int] = 224,
+        clamp_range: Tuple[int, int] = (-1024, 1024),
+        mean: Sequence[float] = (0.0, 0.0, 0.0),
+        std: Sequence[float] = (1.0, 1.0, 1.0),
+    ) -> None:
+        """Initializes the transform object.
+
+        Args:
+            size: Desired output size of the crop. If size is an `int` instead
+                of sequence like (h, w), a square crop (size, size) is made.
+            clamp_range: The lower and upper bound to clamp the pixel values.
+            mean: Sequence of means for each image channel.
+            std: Sequence of standard deviations for each image channel.
+        """
+        self._size = size
+        self._clamp_range = clamp_range
+        self._mean = mean
+        self._std = std
+
+        super().__init__(transforms=self._build_transforms())
+
+    def _build_transforms(self) -> Sequence[Callable]:
+        """Builds and returns the list of transforms."""
+        transforms = [
+            v2.Resize(size=self._size),
+            v2.CenterCrop(size=self._size),
+            normalization.Clamp(out_range=self._clamp_range),
+            normalization.RescaleIntensity(
+                in_range=self._clamp_range,
+                out_range=(0.0, 1.0),
+            ),
+            v2.Normalize(
+                mean=self._mean,
+                std=self._std,
+            ),
+        ]
+        return transforms
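A short sketch of the intended pipeline: clamp CT-style intensities into a fixed window, rescale that window to [0, 1], then apply channel normalization. Exact tensor handling depends on the surrounding eva transforms, so treat this as illustrative rather than definitive:

import torch
from eva.vision.data.transforms import ResizeAndClamp

transform = ResizeAndClamp(size=224, clamp_range=(-1024, 1024))
ct_slice = torch.randint(-2048, 2048, (3, 512, 512)).float()  # fake CT-like values
output = transform(ct_slice)
print(output.shape)  # torch.Size([3, 224, 224]), values rescaled into [0, 1]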
--- a/eva/vision/data/transforms/common/resize_and_crop.py
+++ b/eva/vision/data/transforms/common/resize_and_crop.py
@@ -3,10 +3,10 @@
 from typing import Callable, Sequence

 import torch
-import torchvision.transforms.v2 as torch_transforms
+from torchvision.transforms import v2


-class ResizeAndCrop(torch_transforms.Compose):
+class ResizeAndCrop(v2.Compose):
     """Resizes, crops and normalizes an input image while preserving its aspect ratio."""

     def __init__(
@@ -32,11 +32,10 @@ class ResizeAndCrop(torch_transforms.Compose):
     def _build_transforms(self) -> Sequence[Callable]:
         """Builds and returns the list of transforms."""
         transforms = [
-            torch_transforms.ToImage(),
-            torch_transforms.Resize(size=self._size),
-            torch_transforms.CenterCrop(size=self._size),
-            torch_transforms.ToDtype(torch.float32, scale=True),
-            torch_transforms.Normalize(
+            v2.Resize(size=self._size),
+            v2.CenterCrop(size=self._size),
+            v2.ToDtype(torch.float32, scale=True),
+            v2.Normalize(
                 mean=self._mean,
                 std=self._std,
             ),
--- /dev/null
+++ b/eva/vision/data/transforms/normalization/__init__.py
@@ -0,0 +1,6 @@
+"""Normalization related transformations."""
+
+from eva.vision.data.transforms.normalization.clamp import Clamp
+from eva.vision.data.transforms.normalization.rescale_intensity import RescaleIntensity
+
+__all__ = ["Clamp", "RescaleIntensity"]
--- /dev/null
+++ b/eva/vision/data/transforms/normalization/clamp.py
@@ -0,0 +1,43 @@
+"""Image clamp transform."""
+
+import functools
+from typing import Any, Dict, Tuple
+
+import torch
+import torchvision.transforms.v2 as torch_transforms
+from torchvision import tv_tensors
+from typing_extensions import override
+
+
+class Clamp(torch_transforms.Transform):
+    """Clamps all elements in input into a specific range."""
+
+    def __init__(self, out_range: Tuple[int, int]) -> None:
+        """Initializes the transform.
+
+        Args:
+            out_range: The lower and upper bound of the range to
+                be clamped to.
+        """
+        super().__init__()
+
+        self._out_range = out_range
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(torch.Tensor)
+    def _(self, inpt: torch.Tensor, params: Dict[str, Any]) -> Any:
+        return torch.clamp(inpt, min=self._out_range[0], max=self._out_range[1])
+
+    @_transform.register(tv_tensors.Image)
+    def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+        inpt_clamp = torch.clamp(inpt, min=self._out_range[0], max=self._out_range[1])
+        return tv_tensors.wrap(inpt_clamp, like=inpt)
+
+    @_transform.register(tv_tensors.BoundingBoxes)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: tv_tensors.BoundingBoxes | tv_tensors.Mask, params: Dict[str, Any]) -> Any:
+        return inpt
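The `functools.singledispatchmethod` registrations above make the transform type-aware inside a v2 pipeline: plain tensors and `tv_tensors.Image` inputs are clamped, while `Mask` and `BoundingBoxes` inputs pass through untouched, since labels and coordinates are not intensities. A small illustrative check (assuming direct invocation on single inputs behaves as in standard torchvision v2 transforms):

import torch
from torchvision import tv_tensors
from eva.vision.data.transforms.normalization import Clamp

clamp = Clamp(out_range=(-1024, 1024))
print(clamp(torch.tensor([-3000.0, 0.0, 3000.0])))        # tensor([-1024., 0., 1024.])
print(clamp(tv_tensors.Mask(torch.tensor([[5, 9000]]))))  # returned unchanged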
--- /dev/null
+++ b/eva/vision/data/transforms/normalization/functional/__init__.py
@@ -0,0 +1,5 @@
+"""Functional normalization related transformations API."""
+
+from eva.vision.data.transforms.normalization.functional.rescale_intensity import rescale_intensity
+
+__all__ = ["rescale_intensity"]
--- /dev/null
+++ b/eva/vision/data/transforms/normalization/functional/rescale_intensity.py
@@ -0,0 +1,28 @@
+"""Intensity level functions."""
+
+import sys
+from typing import Tuple
+
+import torch
+
+
+def rescale_intensity(
+    image: torch.Tensor,
+    in_range: Tuple[float, float] | None = None,
+    out_range: Tuple[float, float] = (0.0, 1.0),
+) -> torch.Tensor:
+    """Stretches or shrinks the image intensity levels.
+
+    Args:
+        image: The image tensor as float-type.
+        in_range: The input data range. If `None`, it will
+            fetch the min and max of the input image.
+        out_range: The desired intensity range of the output.
+
+    Returns:
+        The image tensor after stretching or shrinking its intensity levels.
+    """
+    imin, imax = in_range or (image.min(), image.max())
+    omin, omax = out_range
+    image_scaled = (image - imin) / (imax - imin + sys.float_info.epsilon)
+    return image_scaled * (omax - omin) + omin
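A worked example of the linear mapping implemented above, `(x - imin) / (imax - imin) * (omax - omin) + omin`, using the CT clamp window that `ResizeAndClamp` passes in:

import torch
from eva.vision.data.transforms.normalization.functional import rescale_intensity

image = torch.tensor([-1024.0, 0.0, 1024.0])
print(rescale_intensity(image, in_range=(-1024, 1024)))
# tensor([0.0000, 0.5000, 1.0000])  (up to the epsilon added to the denominator)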