kaiko-eva 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaiko-eva might be problematic. Click here for more details.

Files changed (159)
  1. eva/core/callbacks/__init__.py +2 -2
  2. eva/core/callbacks/writers/__init__.py +6 -3
  3. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  4. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  5. eva/core/callbacks/writers/embeddings/base.py +192 -0
  6. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  7. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  8. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  9. eva/core/data/datasets/__init__.py +2 -2
  10. eva/core/data/datasets/classification/__init__.py +8 -0
  11. eva/core/data/datasets/classification/embeddings.py +34 -0
  12. eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
  13. eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
  14. eva/core/data/splitting/__init__.py +6 -0
  15. eva/core/data/splitting/random.py +41 -0
  16. eva/core/data/splitting/stratified.py +56 -0
  17. eva/core/loggers/experimental_loggers.py +2 -2
  18. eva/core/loggers/log/__init__.py +3 -2
  19. eva/core/loggers/log/image.py +71 -0
  20. eva/core/loggers/log/parameters.py +10 -0
  21. eva/core/loggers/loggers.py +6 -0
  22. eva/core/metrics/__init__.py +6 -2
  23. eva/core/metrics/defaults/__init__.py +10 -3
  24. eva/core/metrics/defaults/classification/__init__.py +1 -1
  25. eva/core/metrics/defaults/classification/binary.py +0 -9
  26. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  27. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  28. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  29. eva/core/metrics/generalized_dice.py +59 -0
  30. eva/core/metrics/mean_iou.py +120 -0
  31. eva/core/metrics/structs/schemas.py +3 -1
  32. eva/core/models/__init__.py +3 -1
  33. eva/core/models/modules/head.py +10 -4
  34. eva/core/models/modules/typings.py +14 -1
  35. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  36. eva/core/models/networks/__init__.py +1 -2
  37. eva/core/models/networks/mlp.py +2 -2
  38. eva/core/models/transforms/__init__.py +6 -0
  39. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  40. eva/core/models/transforms/extract_patch_features.py +47 -0
  41. eva/core/models/wrappers/__init__.py +13 -0
  42. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  43. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  44. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  45. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  46. eva/core/trainers/functional.py +1 -0
  47. eva/core/utils/__init__.py +6 -0
  48. eva/core/utils/clone.py +27 -0
  49. eva/core/utils/memory.py +28 -0
  50. eva/core/utils/operations.py +26 -0
  51. eva/core/utils/parser.py +20 -0
  52. eva/vision/__init__.py +2 -2
  53. eva/vision/callbacks/__init__.py +5 -0
  54. eva/vision/callbacks/loggers/__init__.py +5 -0
  55. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  56. eva/vision/callbacks/loggers/batch/base.py +130 -0
  57. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  58. eva/vision/data/datasets/__init__.py +30 -3
  59. eva/vision/data/datasets/_validators.py +15 -2
  60. eva/vision/data/datasets/classification/__init__.py +12 -1
  61. eva/vision/data/datasets/classification/bach.py +10 -15
  62. eva/vision/data/datasets/classification/base.py +17 -24
  63. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  64. eva/vision/data/datasets/classification/crc.py +10 -15
  65. eva/vision/data/datasets/classification/mhist.py +10 -15
  66. eva/vision/data/datasets/classification/panda.py +184 -0
  67. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  68. eva/vision/data/datasets/classification/wsi.py +105 -0
  69. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  70. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  71. eva/vision/data/datasets/segmentation/base.py +16 -17
  72. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  73. eva/vision/data/datasets/segmentation/consep.py +156 -0
  74. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  75. eva/vision/data/datasets/segmentation/lits.py +178 -0
  76. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  77. eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
  78. eva/vision/data/datasets/wsi.py +187 -0
  79. eva/vision/data/transforms/__init__.py +3 -2
  80. eva/vision/data/transforms/common/__init__.py +2 -1
  81. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  82. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  83. eva/vision/data/transforms/normalization/__init__.py +6 -0
  84. eva/vision/data/transforms/normalization/clamp.py +43 -0
  85. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  86. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  87. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  88. eva/vision/data/wsi/__init__.py +16 -0
  89. eva/vision/data/wsi/backends/__init__.py +69 -0
  90. eva/vision/data/wsi/backends/base.py +115 -0
  91. eva/vision/data/wsi/backends/openslide.py +73 -0
  92. eva/vision/data/wsi/backends/pil.py +52 -0
  93. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  94. eva/vision/data/wsi/patching/__init__.py +6 -0
  95. eva/vision/data/wsi/patching/coordinates.py +98 -0
  96. eva/vision/data/wsi/patching/mask.py +123 -0
  97. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  98. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  99. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  100. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  101. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  102. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  103. eva/vision/losses/__init__.py +5 -0
  104. eva/vision/losses/dice.py +40 -0
  105. eva/vision/models/__init__.py +4 -2
  106. eva/vision/models/modules/__init__.py +5 -0
  107. eva/vision/models/modules/semantic_segmentation.py +161 -0
  108. eva/vision/models/networks/__init__.py +1 -2
  109. eva/vision/models/networks/backbones/__init__.py +6 -0
  110. eva/vision/models/networks/backbones/_utils.py +39 -0
  111. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  112. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  113. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  114. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  115. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  116. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  117. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  118. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  119. eva/vision/models/networks/backbones/registry.py +47 -0
  120. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  121. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  122. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  123. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  124. eva/vision/models/networks/decoders/__init__.py +6 -0
  125. eva/vision/models/networks/decoders/decoder.py +7 -0
  126. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  127. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  128. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  129. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  130. eva/vision/models/wrappers/__init__.py +6 -0
  131. eva/vision/models/wrappers/from_registry.py +48 -0
  132. eva/vision/models/wrappers/from_timm.py +68 -0
  133. eva/vision/utils/colormap.py +77 -0
  134. eva/vision/utils/convert.py +56 -13
  135. eva/vision/utils/io/__init__.py +10 -4
  136. eva/vision/utils/io/image.py +21 -2
  137. eva/vision/utils/io/mat.py +36 -0
  138. eva/vision/utils/io/nifti.py +33 -12
  139. eva/vision/utils/io/text.py +10 -3
  140. kaiko_eva-0.1.1.dist-info/METADATA +553 -0
  141. kaiko_eva-0.1.1.dist-info/RECORD +205 -0
  142. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/WHEEL +1 -1
  143. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/entry_points.txt +2 -0
  144. eva/.DS_Store +0 -0
  145. eva/core/callbacks/writers/embeddings.py +0 -169
  146. eva/core/callbacks/writers/typings.py +0 -23
  147. eva/core/data/datasets/embeddings/__init__.py +0 -13
  148. eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
  149. eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
  150. eva/core/models/networks/transforms/__init__.py +0 -5
  151. eva/core/models/networks/wrappers/__init__.py +0 -8
  152. eva/vision/models/.DS_Store +0 -0
  153. eva/vision/models/networks/.DS_Store +0 -0
  154. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  155. eva/vision/models/networks/postprocesses/cls.py +0 -25
  156. kaiko_eva-0.0.2.dist-info/METADATA +0 -431
  157. kaiko_eva-0.0.2.dist-info/RECORD +0 -127
  158. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  159. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,244 @@
1
+ """Camelyon16 dataset class."""
2
+
3
+ import functools
4
+ import glob
5
+ import os
6
+ from typing import Any, Callable, Dict, List, Literal, Tuple
7
+
8
+ import pandas as pd
9
+ import torch
10
+ from torchvision import tv_tensors
11
+ from torchvision.transforms.v2 import functional
12
+ from typing_extensions import override
13
+
14
+ from eva.vision.data.datasets import _validators, wsi
15
+ from eva.vision.data.datasets.classification import base
16
+ from eva.vision.data.wsi.patching import samplers
17
+
18
+
19
+ class Camelyon16(wsi.MultiWsiDataset, base.ImageClassification):
20
+ """Dataset class for Camelyon16 images and corresponding targets."""
21
+
22
+ _val_slides = [
23
+ "normal_010",
24
+ "normal_013",
25
+ "normal_016",
26
+ "normal_017",
27
+ "normal_019",
28
+ "normal_020",
29
+ "normal_025",
30
+ "normal_030",
31
+ "normal_031",
32
+ "normal_032",
33
+ "normal_052",
34
+ "normal_056",
35
+ "normal_057",
36
+ "normal_067",
37
+ "normal_076",
38
+ "normal_079",
39
+ "normal_085",
40
+ "normal_095",
41
+ "normal_098",
42
+ "normal_099",
43
+ "normal_101",
44
+ "normal_102",
45
+ "normal_105",
46
+ "normal_106",
47
+ "normal_109",
48
+ "normal_129",
49
+ "normal_132",
50
+ "normal_137",
51
+ "normal_142",
52
+ "normal_143",
53
+ "normal_148",
54
+ "normal_152",
55
+ "tumor_001",
56
+ "tumor_005",
57
+ "tumor_011",
58
+ "tumor_012",
59
+ "tumor_013",
60
+ "tumor_019",
61
+ "tumor_031",
62
+ "tumor_037",
63
+ "tumor_043",
64
+ "tumor_046",
65
+ "tumor_057",
66
+ "tumor_065",
67
+ "tumor_069",
68
+ "tumor_071",
69
+ "tumor_073",
70
+ "tumor_079",
71
+ "tumor_080",
72
+ "tumor_081",
73
+ "tumor_082",
74
+ "tumor_085",
75
+ "tumor_097",
76
+ "tumor_109",
77
+ ]
78
+ """Validation slide names, same as the ones in patch camelyon."""
79
+
80
+ def __init__(
81
+ self,
82
+ root: str,
83
+ sampler: samplers.Sampler,
84
+ split: Literal["train", "val", "test"] | None = None,
85
+ width: int = 224,
86
+ height: int = 224,
87
+ target_mpp: float = 0.5,
88
+ backend: str = "openslide",
89
+ image_transforms: Callable | None = None,
90
+ seed: int = 42,
91
+ ) -> None:
92
+ """Initializes the dataset.
93
+
94
+ Args:
95
+ root: Root directory of the dataset.
96
+ sampler: The sampler to use for sampling patch coordinates.
97
+ split: Dataset split to use. If `None`, the entire dataset is used.
98
+ width: Width of the patches to be extracted, in pixels.
99
+ height: Height of the patches to be extracted, in pixels.
100
+ target_mpp: Target microns per pixel (mpp) for the patches.
101
+ backend: The backend to use for reading the whole-slide images.
102
+ image_transforms: Transforms to apply to the extracted image patches.
103
+ seed: Random seed for reproducibility.
104
+ """
105
+ self._split = split
106
+ self._root = root
107
+ self._width = width
108
+ self._height = height
109
+ self._target_mpp = target_mpp
110
+ self._seed = seed
111
+
112
+ wsi.MultiWsiDataset.__init__(
113
+ self,
114
+ root=root,
115
+ file_paths=self._load_file_paths(split),
116
+ width=width,
117
+ height=height,
118
+ sampler=sampler,
119
+ target_mpp=target_mpp,
120
+ backend=backend,
121
+ image_transforms=image_transforms,
122
+ )
123
+
124
+ @property
125
+ @override
126
+ def classes(self) -> List[str]:
127
+ return ["normal", "tumor"]
128
+
129
+ @property
130
+ @override
131
+ def class_to_idx(self) -> Dict[str, int]:
132
+ return {"normal": 0, "tumor": 1}
133
+
134
+ @functools.cached_property
135
+ def annotations_test_set(self) -> Dict[str, str]:
136
+ """Loads the dataset labels."""
137
+ path = os.path.join(self._root, "testing/reference.csv")
138
+ reference_df = pd.read_csv(path, header=None)
139
+ return {k: v.lower() for k, v in reference_df[[0, 1]].itertuples(index=False)}
140
+
141
+ @functools.cached_property
142
+ def annotations(self) -> Dict[str, str]:
143
+ """Loads the dataset labels."""
144
+ annotations = {}
145
+ if self._split in ["test", None]:
146
+ path = os.path.join(self._root, "testing/reference.csv")
147
+ reference_df = pd.read_csv(path, header=None)
148
+ annotations.update(
149
+ {k: v.lower() for k, v in reference_df[[0, 1]].itertuples(index=False)}
150
+ )
151
+
152
+ if self._split in ["train", "val", None]:
153
+ annotations.update(
154
+ {
155
+ self._get_id_from_path(file_path): self._get_class_from_path(file_path)
156
+ for file_path in self._file_paths
157
+ if "test" not in file_path
158
+ }
159
+ )
160
+ return annotations
161
+
162
+ @override
163
+ def prepare_data(self) -> None:
164
+ _validators.check_dataset_exists(self._root, False)
165
+
166
+ expected_directories = ["training/normal", "training/tumor", "testing/images"]
167
+ for resource in expected_directories:
168
+ if not os.path.isdir(os.path.join(self._root, resource)):
169
+ raise FileNotFoundError(f"'{resource}' not found in the root folder.")
170
+
171
+ if not os.path.isfile(os.path.join(self._root, "testing/reference.csv")):
172
+ raise FileNotFoundError("'reference.csv' file not found in the testing folder.")
173
+
174
+ @override
175
+ def validate(self) -> None:
176
+
177
+ expected_n_files = {
178
+ "train": 216,
179
+ "val": 54,
180
+ "test": 129,
181
+ None: 399,
182
+ }
183
+ _validators.check_number_of_files(
184
+ self._file_paths, expected_n_files[self._split], self._split
185
+ )
186
+ _validators.check_dataset_integrity(
187
+ self,
188
+ length=None,
189
+ n_classes=2,
190
+ first_and_last_labels=("normal", "tumor"),
191
+ )
192
+
193
+ @override
194
+ def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
195
+ return base.ImageClassification.__getitem__(self, index)
196
+
197
+ @override
198
+ def load_image(self, index: int) -> tv_tensors.Image:
199
+ image_array = wsi.MultiWsiDataset.__getitem__(self, index)
200
+ return functional.to_image(image_array)
201
+
202
+ @override
203
+ def load_target(self, index: int) -> torch.Tensor:
204
+ file_path = self._file_paths[self._get_dataset_idx(index)]
205
+ class_name = self.annotations[self._get_id_from_path(file_path)]
206
+ return torch.tensor(self.class_to_idx[class_name], dtype=torch.int64)
207
+
208
+ @override
209
+ def load_metadata(self, index: int) -> Dict[str, Any]:
210
+ return {"wsi_id": self.filename(index).split(".")[0]}
211
+
212
+ def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
213
+ """Loads the file paths of the corresponding dataset split."""
214
+ train_paths, val_paths = [], []
215
+ for path in glob.glob(os.path.join(self._root, "training/**/*.tif")):
216
+ if self._get_id_from_path(path) in self._val_slides:
217
+ val_paths.append(path)
218
+ else:
219
+ train_paths.append(path)
220
+ test_paths = glob.glob(os.path.join(self._root, "testing/images", "*.tif"))
221
+
222
+ match split:
223
+ case "train":
224
+ paths = train_paths
225
+ case "val":
226
+ paths = val_paths
227
+ case "test":
228
+ paths = test_paths
229
+ case None:
230
+ paths = train_paths + val_paths + test_paths
231
+ case _:
232
+ raise ValueError("Invalid split. Use 'train', 'val' or `None`.")
233
+ return sorted([os.path.relpath(path, self._root) for path in paths])
234
+
235
+ def _get_id_from_path(self, file_path: str) -> str:
236
+ """Extracts the slide ID from the file path."""
237
+ return os.path.basename(file_path).replace(".tif", "")
238
+
239
+ def _get_class_from_path(self, file_path: str) -> str:
240
+ """Extracts the class name from the file path."""
241
+ class_name = self._get_id_from_path(file_path).split("_")[0]
242
+ if class_name not in self.classes:
243
+ raise ValueError(f"Invalid class name '{class_name}' in file path '{file_path}'.")
244
+ return class_name
@@ -3,7 +3,8 @@
3
3
  import os
4
4
  from typing import Callable, Dict, List, Literal, Tuple
5
5
 
6
- import numpy as np
6
+ import torch
7
+ from torchvision import tv_tensors
7
8
  from torchvision.datasets import folder, utils
8
9
  from typing_extensions import override
9
10
 
@@ -37,8 +38,7 @@ class CRC(base.ImageClassification):
37
38
  root: str,
38
39
  split: Literal["train", "val"],
39
40
  download: bool = False,
40
- image_transforms: Callable | None = None,
41
- target_transforms: Callable | None = None,
41
+ transforms: Callable | None = None,
42
42
  ) -> None:
43
43
  """Initializes the dataset.
44
44
 
@@ -56,15 +56,10 @@ class CRC(base.ImageClassification):
56
56
  Note that the download will be executed only by additionally
57
57
  calling the :meth:`prepare_data` method and if the data does
58
58
  not yet exist on disk.
59
- image_transforms: A function/transform that takes in an image
60
- and returns a transformed version.
61
- target_transforms: A function/transform that takes in the target
62
- and transforms it.
59
+ transforms: A function/transform which returns a transformed
60
+ version of the raw data samples.
63
61
  """
64
- super().__init__(
65
- image_transforms=image_transforms,
66
- target_transforms=target_transforms,
67
- )
62
+ super().__init__(transforms=transforms)
68
63
 
69
64
  self._root = root
70
65
  self._split = split
@@ -122,14 +117,14 @@ class CRC(base.ImageClassification):
122
117
  )
123
118
 
124
119
  @override
125
- def load_image(self, index: int) -> np.ndarray:
120
+ def load_image(self, index: int) -> tv_tensors.Image:
126
121
  image_path, _ = self._samples[index]
127
- return io.read_image(image_path)
122
+ return io.read_image_as_tensor(image_path)
128
123
 
129
124
  @override
130
- def load_target(self, index: int) -> np.ndarray:
125
+ def load_target(self, index: int) -> torch.Tensor:
131
126
  _, target = self._samples[index]
132
- return np.asarray(target, dtype=np.int64)
127
+ return torch.tensor(target, dtype=torch.long)
133
128
 
134
129
  @override
135
130
  def __len__(self) -> int:
@@ -3,7 +3,8 @@
3
3
  import os
4
4
  from typing import Callable, Dict, List, Literal, Tuple
5
5
 
6
- import numpy as np
6
+ import torch
7
+ from torchvision import tv_tensors
7
8
  from typing_extensions import override
8
9
 
9
10
  from eva.vision.data.datasets import _validators
@@ -18,23 +19,17 @@ class MHIST(base.ImageClassification):
18
19
  self,
19
20
  root: str,
20
21
  split: Literal["train", "test"],
21
- image_transforms: Callable | None = None,
22
- target_transforms: Callable | None = None,
22
+ transforms: Callable | None = None,
23
23
  ) -> None:
24
24
  """Initialize the dataset.
25
25
 
26
26
  Args:
27
27
  root: Path to the root directory of the dataset.
28
28
  split: Dataset split to use.
29
- image_transforms: A function/transform that takes in an image
30
- and returns a transformed version.
31
- target_transforms: A function/transform that takes in the target
32
- and transforms it.
29
+ transforms: A function/transform which returns a transformed
30
+ version of the raw data samples.
33
31
  """
34
- super().__init__(
35
- image_transforms=image_transforms,
36
- target_transforms=target_transforms,
37
- )
32
+ super().__init__(transforms=transforms)
38
33
 
39
34
  self._root = root
40
35
  self._split = split
@@ -74,16 +69,16 @@ class MHIST(base.ImageClassification):
74
69
  )
75
70
 
76
71
  @override
77
- def load_image(self, index: int) -> np.ndarray:
72
+ def load_image(self, index: int) -> tv_tensors.Image:
78
73
  image_filename, _ = self._samples[index]
79
74
  image_path = os.path.join(self._dataset_path, image_filename)
80
- return io.read_image(image_path)
75
+ return io.read_image_as_tensor(image_path)
81
76
 
82
77
  @override
83
- def load_target(self, index: int) -> np.ndarray:
78
+ def load_target(self, index: int) -> torch.Tensor:
84
79
  _, label = self._samples[index]
85
80
  target = self.class_to_idx[label]
86
- return np.asarray(target, dtype=np.int64)
81
+ return torch.tensor(target, dtype=torch.float32)
87
82
 
88
83
  @override
89
84
  def __len__(self) -> int:
@@ -0,0 +1,184 @@
1
+ """PANDA dataset class."""
2
+
3
+ import functools
4
+ import glob
5
+ import os
6
+ from typing import Any, Callable, Dict, List, Literal, Tuple
7
+
8
+ import pandas as pd
9
+ import torch
10
+ from torchvision import tv_tensors
11
+ from torchvision.datasets import utils
12
+ from torchvision.transforms.v2 import functional
13
+ from typing_extensions import override
14
+
15
+ from eva.core.data import splitting
16
+ from eva.vision.data.datasets import _validators, structs, wsi
17
+ from eva.vision.data.datasets.classification import base
18
+ from eva.vision.data.wsi.patching import samplers
19
+
20
+
21
+ class PANDA(wsi.MultiWsiDataset, base.ImageClassification):
22
+ """Dataset class for PANDA images and corresponding targets."""
23
+
24
+ _train_split_ratio: float = 0.7
25
+ """Train split ratio."""
26
+
27
+ _val_split_ratio: float = 0.15
28
+ """Validation split ratio."""
29
+
30
+ _test_split_ratio: float = 0.15
31
+ """Test split ratio."""
32
+
33
+ _resources: List[structs.DownloadResource] = [
34
+ structs.DownloadResource(
35
+ filename="train_with_noisy_labels.csv",
36
+ url="https://raw.githubusercontent.com/analokmaus/kaggle-panda-challenge-public/master/train.csv",
37
+ md5="5e4bfc78bda9603d2e2faf3ed4b21dfa",
38
+ )
39
+ ]
40
+ """Download resources."""
41
+
42
+ def __init__(
43
+ self,
44
+ root: str,
45
+ sampler: samplers.Sampler,
46
+ split: Literal["train", "val", "test"] | None = None,
47
+ width: int = 224,
48
+ height: int = 224,
49
+ target_mpp: float = 0.5,
50
+ backend: str = "openslide",
51
+ image_transforms: Callable | None = None,
52
+ seed: int = 42,
53
+ ) -> None:
54
+ """Initializes the dataset.
55
+
56
+ Args:
57
+ root: Root directory of the dataset.
58
+ sampler: The sampler to use for sampling patch coordinates.
59
+ split: Dataset split to use. If `None`, the entire dataset is used.
60
+ width: Width of the patches to be extracted, in pixels.
61
+ height: Height of the patches to be extracted, in pixels.
62
+ target_mpp: Target microns per pixel (mpp) for the patches.
63
+ backend: The backend to use for reading the whole-slide images.
64
+ image_transforms: Transforms to apply to the extracted image patches.
65
+ seed: Random seed for reproducibility.
66
+ """
67
+ self._split = split
68
+ self._root = root
69
+ self._seed = seed
70
+
71
+ self._download_resources()
72
+
73
+ wsi.MultiWsiDataset.__init__(
74
+ self,
75
+ root=root,
76
+ file_paths=self._load_file_paths(split),
77
+ width=width,
78
+ height=height,
79
+ sampler=sampler,
80
+ target_mpp=target_mpp,
81
+ backend=backend,
82
+ image_transforms=image_transforms,
83
+ )
84
+
85
+ @property
86
+ @override
87
+ def classes(self) -> List[str]:
88
+ return ["0", "1", "2", "3", "4", "5"]
89
+
90
+ @functools.cached_property
91
+ def annotations(self) -> pd.DataFrame:
92
+ """Loads the dataset labels."""
93
+ path = os.path.join(self._root, "train_with_noisy_labels.csv")
94
+ return pd.read_csv(path, index_col="image_id")
95
+
96
+ @override
97
+ def prepare_data(self) -> None:
98
+ _validators.check_dataset_exists(self._root, False)
99
+
100
+ if not os.path.isdir(os.path.join(self._root, "train_images")):
101
+ raise FileNotFoundError("'train_images' directory not found in the root folder.")
102
+ if not os.path.isfile(os.path.join(self._root, "train_with_noisy_labels.csv")):
103
+ raise FileNotFoundError("'train.csv' file not found in the root folder.")
104
+
105
+ def _download_resources(self) -> None:
106
+ """Downloads the dataset resources."""
107
+ for resource in self._resources:
108
+ utils.download_url(resource.url, self._root, resource.filename, resource.md5)
109
+
110
+ @override
111
+ def validate(self) -> None:
112
+ _validators.check_dataset_integrity(
113
+ self,
114
+ length=None,
115
+ n_classes=6,
116
+ first_and_last_labels=("0", "5"),
117
+ )
118
+
119
+ @override
120
+ def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
121
+ return base.ImageClassification.__getitem__(self, index)
122
+
123
+ @override
124
+ def load_image(self, index: int) -> tv_tensors.Image:
125
+ image_array = wsi.MultiWsiDataset.__getitem__(self, index)
126
+ return functional.to_image(image_array)
127
+
128
+ @override
129
+ def load_target(self, index: int) -> torch.Tensor:
130
+ file_path = self._file_paths[self._get_dataset_idx(index)]
131
+ return torch.tensor(self._get_target_from_path(file_path), dtype=torch.int64)
132
+
133
+ @override
134
+ def load_metadata(self, index: int) -> Dict[str, Any]:
135
+ return {"wsi_id": self.filename(index).split(".")[0]}
136
+
137
+ def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
138
+ """Loads the file paths of the corresponding dataset split."""
139
+ image_dir = os.path.join(self._root, "train_images")
140
+ file_paths = sorted(glob.glob(os.path.join(image_dir, "*.tiff")))
141
+ file_paths = [os.path.relpath(path, self._root) for path in file_paths]
142
+ if len(file_paths) != len(self.annotations):
143
+ raise ValueError(
144
+ f"Expected {len(self.annotations)} images, found {len(file_paths)} in {image_dir}."
145
+ )
146
+ file_paths = self._filter_noisy_labels(file_paths)
147
+ targets = [self._get_target_from_path(file_path) for file_path in file_paths]
148
+
149
+ train_indices, val_indices, test_indices = splitting.stratified_split(
150
+ samples=file_paths,
151
+ targets=targets,
152
+ train_ratio=self._train_split_ratio,
153
+ val_ratio=self._val_split_ratio,
154
+ test_ratio=self._test_split_ratio,
155
+ seed=self._seed,
156
+ )
157
+
158
+ match split:
159
+ case "train":
160
+ return [file_paths[i] for i in train_indices]
161
+ case "val":
162
+ return [file_paths[i] for i in val_indices]
163
+ case "test":
164
+ return [file_paths[i] for i in test_indices or []]
165
+ case None:
166
+ return file_paths
167
+ case _:
168
+ raise ValueError("Invalid split. Use 'train', 'val', 'test' or `None`.")
169
+
170
+ def _filter_noisy_labels(self, file_paths: List[str]):
171
+ is_noisy_filter = self.annotations["noise_ratio_10"] == 0
172
+ non_noisy_image_ids = set(self.annotations.loc[~is_noisy_filter].index)
173
+ filtered_file_paths = [
174
+ file_path
175
+ for file_path in file_paths
176
+ if self._get_id_from_path(file_path) in non_noisy_image_ids
177
+ ]
178
+ return filtered_file_paths
179
+
180
+ def _get_target_from_path(self, file_path: str) -> int:
181
+ return self.annotations.loc[self._get_id_from_path(file_path), "isup_grade"]
182
+
183
+ def _get_id_from_path(self, file_path: str) -> str:
184
+ return os.path.basename(file_path).replace(".tiff", "")
@@ -4,8 +4,10 @@ import os
4
4
  from typing import Callable, Dict, List, Literal
5
5
 
6
6
  import h5py
7
- import numpy as np
7
+ import torch
8
+ from torchvision import tv_tensors
8
9
  from torchvision.datasets import utils
10
+ from torchvision.transforms.v2 import functional
9
11
  from typing_extensions import override
10
12
 
11
13
  from eva.vision.data.datasets import _validators, structs
@@ -70,8 +72,7 @@ class PatchCamelyon(base.ImageClassification):
70
72
  root: str,
71
73
  split: Literal["train", "val", "test"],
72
74
  download: bool = False,
73
- image_transforms: Callable | None = None,
74
- target_transforms: Callable | None = None,
75
+ transforms: Callable | None = None,
75
76
  ) -> None:
76
77
  """Initializes the dataset.
77
78
 
@@ -82,15 +83,10 @@ class PatchCamelyon(base.ImageClassification):
82
83
  download: Whether to download the data for the specified split.
83
84
  Note that the download will be executed only by additionally
84
85
  calling the :meth:`prepare_data` method.
85
- image_transforms: A function/transform that takes in an image
86
- and returns a transformed version.
87
- target_transforms: A function/transform that takes in the target
88
- and transforms it.
86
+ transforms: A function/transform which returns a transformed
87
+ version of the raw data samples.
89
88
  """
90
- super().__init__(
91
- image_transforms=image_transforms,
92
- target_transforms=target_transforms,
93
- )
89
+ super().__init__(transforms=transforms)
94
90
 
95
91
  self._root = root
96
92
  self._split = split
@@ -131,13 +127,13 @@ class PatchCamelyon(base.ImageClassification):
131
127
  )
132
128
 
133
129
  @override
134
- def load_image(self, index: int) -> np.ndarray:
130
+ def load_image(self, index: int) -> tv_tensors.Image:
135
131
  return self._load_from_h5("x", index)
136
132
 
137
133
  @override
138
- def load_target(self, index: int) -> np.ndarray:
134
+ def load_target(self, index: int) -> torch.Tensor:
139
135
  target = self._load_from_h5("y", index).squeeze()
140
- return np.asarray(target, dtype=np.int64)
136
+ return torch.tensor(target, dtype=torch.float32)
141
137
 
142
138
  @override
143
139
  def __len__(self) -> int:
@@ -162,7 +158,7 @@ class PatchCamelyon(base.ImageClassification):
162
158
  self,
163
159
  data_key: Literal["x", "y"],
164
160
  index: int | None = None,
165
- ) -> np.ndarray:
161
+ ) -> tv_tensors.Image:
166
162
  """Load data or targets from an HDF5 file.
167
163
 
168
164
  Args:
@@ -176,7 +172,8 @@ class PatchCamelyon(base.ImageClassification):
176
172
  h5_file = self._h5_file(data_key)
177
173
  with h5py.File(h5_file, "r") as file:
178
174
  data = file[data_key]
179
- return data[:] if index is None else data[index] # type: ignore
175
+ image_array = data[:] if index is None else data[index] # type: ignore
176
+ return functional.to_image(image_array) # type: ignore
180
177
 
181
178
  def _fetch_dataset_length(self) -> int:
182
179
  """Fetches the dataset split length from its HDF5 file."""