kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (168)
  1. eva/core/callbacks/__init__.py +3 -2
  2. eva/core/callbacks/config.py +143 -0
  3. eva/core/callbacks/writers/__init__.py +6 -3
  4. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  5. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  6. eva/core/callbacks/writers/embeddings/base.py +192 -0
  7. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  8. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  9. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  10. eva/core/data/datasets/__init__.py +10 -2
  11. eva/core/data/datasets/classification/__init__.py +5 -2
  12. eva/core/data/datasets/classification/embeddings.py +15 -135
  13. eva/core/data/datasets/classification/multi_embeddings.py +110 -0
  14. eva/core/data/datasets/embeddings.py +167 -0
  15. eva/core/data/splitting/__init__.py +6 -0
  16. eva/core/data/splitting/random.py +41 -0
  17. eva/core/data/splitting/stratified.py +56 -0
  18. eva/core/data/transforms/__init__.py +3 -1
  19. eva/core/data/transforms/padding/__init__.py +5 -0
  20. eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
  21. eva/core/data/transforms/sampling/__init__.py +5 -0
  22. eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
  23. eva/core/loggers/__init__.py +7 -0
  24. eva/core/loggers/dummy.py +38 -0
  25. eva/core/loggers/experimental_loggers.py +8 -0
  26. eva/core/loggers/log/__init__.py +6 -0
  27. eva/core/loggers/log/image.py +71 -0
  28. eva/core/loggers/log/parameters.py +74 -0
  29. eva/core/loggers/log/utils.py +13 -0
  30. eva/core/loggers/loggers.py +6 -0
  31. eva/core/metrics/__init__.py +6 -2
  32. eva/core/metrics/defaults/__init__.py +10 -3
  33. eva/core/metrics/defaults/classification/__init__.py +1 -1
  34. eva/core/metrics/defaults/classification/binary.py +0 -9
  35. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  36. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  37. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  38. eva/core/metrics/generalized_dice.py +59 -0
  39. eva/core/metrics/mean_iou.py +120 -0
  40. eva/core/metrics/structs/schemas.py +3 -1
  41. eva/core/models/__init__.py +3 -1
  42. eva/core/models/modules/head.py +16 -15
  43. eva/core/models/modules/module.py +25 -1
  44. eva/core/models/modules/typings.py +14 -1
  45. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  46. eva/core/models/networks/__init__.py +1 -2
  47. eva/core/models/networks/mlp.py +2 -2
  48. eva/core/models/transforms/__init__.py +6 -0
  49. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  50. eva/core/models/transforms/extract_patch_features.py +47 -0
  51. eva/core/models/wrappers/__init__.py +13 -0
  52. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  53. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  54. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  55. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  56. eva/core/trainers/_recorder.py +69 -7
  57. eva/core/trainers/functional.py +23 -5
  58. eva/core/trainers/trainer.py +20 -6
  59. eva/core/utils/__init__.py +6 -0
  60. eva/core/utils/clone.py +27 -0
  61. eva/core/utils/memory.py +28 -0
  62. eva/core/utils/operations.py +26 -0
  63. eva/core/utils/parser.py +20 -0
  64. eva/vision/__init__.py +2 -2
  65. eva/vision/callbacks/__init__.py +5 -0
  66. eva/vision/callbacks/loggers/__init__.py +5 -0
  67. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  68. eva/vision/callbacks/loggers/batch/base.py +130 -0
  69. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  70. eva/vision/data/datasets/__init__.py +24 -4
  71. eva/vision/data/datasets/_utils.py +3 -3
  72. eva/vision/data/datasets/_validators.py +15 -2
  73. eva/vision/data/datasets/classification/__init__.py +6 -2
  74. eva/vision/data/datasets/classification/bach.py +10 -15
  75. eva/vision/data/datasets/classification/base.py +17 -24
  76. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  77. eva/vision/data/datasets/classification/crc.py +10 -15
  78. eva/vision/data/datasets/classification/mhist.py +10 -15
  79. eva/vision/data/datasets/classification/panda.py +184 -0
  80. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  81. eva/vision/data/datasets/classification/wsi.py +105 -0
  82. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  83. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  84. eva/vision/data/datasets/segmentation/base.py +31 -47
  85. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  86. eva/vision/data/datasets/segmentation/consep.py +156 -0
  87. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  88. eva/vision/data/datasets/segmentation/lits.py +178 -0
  89. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
  91. eva/vision/data/datasets/wsi.py +187 -0
  92. eva/vision/data/transforms/__init__.py +3 -2
  93. eva/vision/data/transforms/common/__init__.py +2 -1
  94. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  96. eva/vision/data/transforms/normalization/__init__.py +6 -0
  97. eva/vision/data/transforms/normalization/clamp.py +43 -0
  98. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  99. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  100. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  101. eva/vision/data/wsi/__init__.py +16 -0
  102. eva/vision/data/wsi/backends/__init__.py +69 -0
  103. eva/vision/data/wsi/backends/base.py +115 -0
  104. eva/vision/data/wsi/backends/openslide.py +73 -0
  105. eva/vision/data/wsi/backends/pil.py +52 -0
  106. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  107. eva/vision/data/wsi/patching/__init__.py +6 -0
  108. eva/vision/data/wsi/patching/coordinates.py +98 -0
  109. eva/vision/data/wsi/patching/mask.py +123 -0
  110. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  111. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  112. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  113. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  114. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  115. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  116. eva/vision/losses/__init__.py +5 -0
  117. eva/vision/losses/dice.py +40 -0
  118. eva/vision/models/__init__.py +4 -2
  119. eva/vision/models/modules/__init__.py +5 -0
  120. eva/vision/models/modules/semantic_segmentation.py +161 -0
  121. eva/vision/models/networks/__init__.py +1 -2
  122. eva/vision/models/networks/backbones/__init__.py +6 -0
  123. eva/vision/models/networks/backbones/_utils.py +39 -0
  124. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  125. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  126. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  127. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  128. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  129. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  130. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  131. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  132. eva/vision/models/networks/backbones/registry.py +47 -0
  133. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  134. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  135. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  136. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  137. eva/vision/models/networks/decoders/__init__.py +6 -0
  138. eva/vision/models/networks/decoders/decoder.py +7 -0
  139. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  140. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  141. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  142. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  143. eva/vision/models/wrappers/__init__.py +6 -0
  144. eva/vision/models/wrappers/from_registry.py +48 -0
  145. eva/vision/models/wrappers/from_timm.py +68 -0
  146. eva/vision/utils/colormap.py +77 -0
  147. eva/vision/utils/convert.py +67 -0
  148. eva/vision/utils/io/__init__.py +10 -4
  149. eva/vision/utils/io/image.py +21 -2
  150. eva/vision/utils/io/mat.py +36 -0
  151. eva/vision/utils/io/nifti.py +40 -15
  152. eva/vision/utils/io/text.py +10 -3
  153. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  154. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  155. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  156. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  157. eva/core/callbacks/writers/embeddings.py +0 -169
  158. eva/core/callbacks/writers/typings.py +0 -23
  159. eva/core/models/networks/transforms/__init__.py +0 -5
  160. eva/core/models/networks/wrappers/__init__.py +0 -8
  161. eva/vision/data/datasets/classification/total_segmentator.py +0 -213
  162. eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
  163. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  164. eva/vision/models/networks/postprocesses/cls.py +0 -25
  165. kaiko_eva-0.0.1.dist-info/METADATA +0 -405
  166. kaiko_eva-0.0.1.dist-info/RECORD +0 -110
  167. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  168. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0

eva/vision/data/datasets/classification/panda.py
@@ -0,0 +1,184 @@
+"""PANDA dataset class."""
+
+import functools
+import glob
+import os
+from typing import Any, Callable, Dict, List, Literal, Tuple
+
+import pandas as pd
+import torch
+from torchvision import tv_tensors
+from torchvision.datasets import utils
+from torchvision.transforms.v2 import functional
+from typing_extensions import override
+
+from eva.core.data import splitting
+from eva.vision.data.datasets import _validators, structs, wsi
+from eva.vision.data.datasets.classification import base
+from eva.vision.data.wsi.patching import samplers
+
+
+class PANDA(wsi.MultiWsiDataset, base.ImageClassification):
+    """Dataset class for PANDA images and corresponding targets."""
+
+    _train_split_ratio: float = 0.7
+    """Train split ratio."""
+
+    _val_split_ratio: float = 0.15
+    """Validation split ratio."""
+
+    _test_split_ratio: float = 0.15
+    """Test split ratio."""
+
+    _resources: List[structs.DownloadResource] = [
+        structs.DownloadResource(
+            filename="train_with_noisy_labels.csv",
+            url="https://raw.githubusercontent.com/analokmaus/kaggle-panda-challenge-public/master/train.csv",
+            md5="5e4bfc78bda9603d2e2faf3ed4b21dfa",
+        )
+    ]
+    """Download resources."""
+
+    def __init__(
+        self,
+        root: str,
+        sampler: samplers.Sampler,
+        split: Literal["train", "val", "test"] | None = None,
+        width: int = 224,
+        height: int = 224,
+        target_mpp: float = 0.5,
+        backend: str = "openslide",
+        image_transforms: Callable | None = None,
+        seed: int = 42,
+    ) -> None:
+        """Initializes the dataset.
+
+        Args:
+            root: Root directory of the dataset.
+            sampler: The sampler to use for sampling patch coordinates.
+            split: Dataset split to use. If `None`, the entire dataset is used.
+            width: Width of the patches to be extracted, in pixels.
+            height: Height of the patches to be extracted, in pixels.
+            target_mpp: Target microns per pixel (mpp) for the patches.
+            backend: The backend to use for reading the whole-slide images.
+            image_transforms: Transforms to apply to the extracted image patches.
+            seed: Random seed for reproducibility.
+        """
+        self._split = split
+        self._root = root
+        self._seed = seed
+
+        self._download_resources()
+
+        wsi.MultiWsiDataset.__init__(
+            self,
+            root=root,
+            file_paths=self._load_file_paths(split),
+            width=width,
+            height=height,
+            sampler=sampler,
+            target_mpp=target_mpp,
+            backend=backend,
+            image_transforms=image_transforms,
+        )
+
+    @property
+    @override
+    def classes(self) -> List[str]:
+        return ["0", "1", "2", "3", "4", "5"]
+
+    @functools.cached_property
+    def annotations(self) -> pd.DataFrame:
+        """Loads the dataset labels."""
+        path = os.path.join(self._root, "train_with_noisy_labels.csv")
+        return pd.read_csv(path, index_col="image_id")
+
+    @override
+    def prepare_data(self) -> None:
+        _validators.check_dataset_exists(self._root, False)
+
+        if not os.path.isdir(os.path.join(self._root, "train_images")):
+            raise FileNotFoundError("'train_images' directory not found in the root folder.")
+        if not os.path.isfile(os.path.join(self._root, "train_with_noisy_labels.csv")):
+            raise FileNotFoundError("'train_with_noisy_labels.csv' file not found in the root folder.")
+
+    def _download_resources(self) -> None:
+        """Downloads the dataset resources."""
+        for resource in self._resources:
+            utils.download_url(resource.url, self._root, resource.filename, resource.md5)
+
+    @override
+    def validate(self) -> None:
+        _validators.check_dataset_integrity(
+            self,
+            length=None,
+            n_classes=6,
+            first_and_last_labels=("0", "5"),
+        )
+
+    @override
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
+        return base.ImageClassification.__getitem__(self, index)
+
+    @override
+    def load_image(self, index: int) -> tv_tensors.Image:
+        image_array = wsi.MultiWsiDataset.__getitem__(self, index)
+        return functional.to_image(image_array)
+
+    @override
+    def load_target(self, index: int) -> torch.Tensor:
+        file_path = self._file_paths[self._get_dataset_idx(index)]
+        return torch.tensor(self._get_target_from_path(file_path), dtype=torch.int64)
+
+    @override
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        return {"wsi_id": self.filename(index).split(".")[0]}
+
+    def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
+        """Loads the file paths of the corresponding dataset split."""
+        image_dir = os.path.join(self._root, "train_images")
+        file_paths = sorted(glob.glob(os.path.join(image_dir, "*.tiff")))
+        file_paths = [os.path.relpath(path, self._root) for path in file_paths]
+        if len(file_paths) != len(self.annotations):
+            raise ValueError(
+                f"Expected {len(self.annotations)} images, found {len(file_paths)} in {image_dir}."
+            )
+        file_paths = self._filter_noisy_labels(file_paths)
+        targets = [self._get_target_from_path(file_path) for file_path in file_paths]
+
+        train_indices, val_indices, test_indices = splitting.stratified_split(
+            samples=file_paths,
+            targets=targets,
+            train_ratio=self._train_split_ratio,
+            val_ratio=self._val_split_ratio,
+            test_ratio=self._test_split_ratio,
+            seed=self._seed,
+        )
+
+        match split:
+            case "train":
+                return [file_paths[i] for i in train_indices]
+            case "val":
+                return [file_paths[i] for i in val_indices]
+            case "test":
+                return [file_paths[i] for i in test_indices or []]
+            case None:
+                return file_paths
+            case _:
+                raise ValueError("Invalid split. Use 'train', 'val', 'test' or `None`.")
+
+    def _filter_noisy_labels(self, file_paths: List[str]):
+        is_noisy_filter = self.annotations["noise_ratio_10"] > 0
+        non_noisy_image_ids = set(self.annotations.loc[~is_noisy_filter].index)
+        filtered_file_paths = [
+            file_path
+            for file_path in file_paths
+            if self._get_id_from_path(file_path) in non_noisy_image_ids
+        ]
+        return filtered_file_paths
+
+    def _get_target_from_path(self, file_path: str) -> int:
+        return self.annotations.loc[self._get_id_from_path(file_path), "isup_grade"]
+
+    def _get_id_from_path(self, file_path: str) -> str:
+        return os.path.basename(file_path).replace(".tiff", "")
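
For orientation, here is a minimal usage sketch of the new PANDA class. It is illustrative, not part of the release: the data path is assumed, and the RandomSampler constructor arguments are inferred from the new samplers module (eva/vision/data/wsi/patching/samplers/random.py) and may differ.

# Illustrative sketch, not taken from the package.
from eva.vision.data.datasets.classification import PANDA
from eva.vision.data.wsi.patching import samplers

dataset = PANDA(
    root="./data/panda",  # assumed local copy of the Kaggle PANDA data
    sampler=samplers.RandomSampler(n_samples=16, seed=42),  # hypothetical arguments
    split="train",
)
# The constructor already downloads train_with_noisy_labels.csv; prepare_data()
# verifies that train_images/ and the labels CSV are in place.
dataset.prepare_data()
image, target, metadata = dataset[0]  # tv_tensors.Image, torch.Tensor, dict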

eva/vision/data/datasets/classification/patch_camelyon.py
@@ -4,8 +4,10 @@ import os
 from typing import Callable, Dict, List, Literal
 
 import h5py
-import numpy as np
+import torch
+from torchvision import tv_tensors
 from torchvision.datasets import utils
+from torchvision.transforms.v2 import functional
 from typing_extensions import override
 
 from eva.vision.data.datasets import _validators, structs
@@ -70,8 +72,7 @@ class PatchCamelyon(base.ImageClassification):
         root: str,
         split: Literal["train", "val", "test"],
         download: bool = False,
-        image_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ) -> None:
         """Initializes the dataset.
 
@@ -82,15 +83,10 @@ class PatchCamelyon(base.ImageClassification):
            download: Whether to download the data for the specified split.
                Note that the download will be executed only by additionally
                calling the :meth:`prepare_data` method.
-            image_transforms: A function/transform that takes in an image
-                and returns a transformed version.
-            target_transforms: A function/transform that takes in the target
-                and transforms it.
+            transforms: A function/transform which returns a transformed
+                version of the raw data samples.
         """
-        super().__init__(
-            image_transforms=image_transforms,
-            target_transforms=target_transforms,
-        )
+        super().__init__(transforms=transforms)
 
         self._root = root
         self._split = split
@@ -131,13 +127,13 @@ class PatchCamelyon(base.ImageClassification):
         )
 
     @override
-    def load_image(self, index: int) -> np.ndarray:
+    def load_image(self, index: int) -> tv_tensors.Image:
         return self._load_from_h5("x", index)
 
     @override
-    def load_target(self, index: int) -> np.ndarray:
+    def load_target(self, index: int) -> torch.Tensor:
         target = self._load_from_h5("y", index).squeeze()
-        return np.asarray(target, dtype=np.int64)
+        return torch.tensor(target, dtype=torch.float32)
 
     @override
     def __len__(self) -> int:
@@ -162,7 +158,7 @@ class PatchCamelyon(base.ImageClassification):
         self,
         data_key: Literal["x", "y"],
         index: int | None = None,
-    ) -> np.ndarray:
+    ) -> tv_tensors.Image:
         """Load data or targets from an HDF5 file.
 
         Args:
@@ -176,7 +172,8 @@ class PatchCamelyon(base.ImageClassification):
        h5_file = self._h5_file(data_key)
        with h5py.File(h5_file, "r") as file:
            data = file[data_key]
-            return data[:] if index is None else data[index]  # type: ignore
+            image_array = data[:] if index is None else data[index]  # type: ignore
+            return functional.to_image(image_array)  # type: ignore
 
     def _fetch_dataset_length(self) -> int:
         """Fetches the dataset split length from its HDF5 file."""

eva/vision/data/datasets/classification/wsi.py
@@ -0,0 +1,105 @@
+"""WSI classification dataset."""
+
+import os
+from typing import Any, Callable, Dict, Literal, Tuple
+
+import numpy as np
+import pandas as pd
+import torch
+from torchvision import tv_tensors
+from typing_extensions import override
+
+from eva.vision.data.datasets import wsi
+from eva.vision.data.datasets.classification import base
+from eva.vision.data.wsi.patching import samplers
+
+
+class WsiClassificationDataset(wsi.MultiWsiDataset, base.ImageClassification):
+    """A general dataset class for whole-slide image classification using manifest files."""
+
+    default_column_mapping: Dict[str, str] = {
+        "path": "path",
+        "target": "target",
+        "split": "split",
+    }
+
+    def __init__(
+        self,
+        root: str,
+        manifest_file: str,
+        width: int,
+        height: int,
+        target_mpp: float,
+        sampler: samplers.Sampler,
+        backend: str = "openslide",
+        split: Literal["train", "val", "test"] | None = None,
+        image_transforms: Callable | None = None,
+        column_mapping: Dict[str, str] = default_column_mapping,
+    ):
+        """Initializes the dataset.
+
+        Args:
+            root: Root directory of the dataset.
+            manifest_file: The path to the manifest file, relative to
+                the `root` argument. The `path` column is expected to contain
+                relative paths to the whole-slide images.
+            width: Width of the patches to be extracted, in pixels.
+            height: Height of the patches to be extracted, in pixels.
+            target_mpp: Target microns per pixel (mpp) for the patches.
+            sampler: The sampler to use for sampling patch coordinates.
+            backend: The backend to use for reading the whole-slide images.
+            split: The split of the dataset to load.
+            image_transforms: Transforms to apply to the extracted image patches.
+            column_mapping: Mapping of the columns in the manifest file.
+        """
+        self._split = split
+        self._column_mapping = self.default_column_mapping | column_mapping
+        self._manifest = self._load_manifest(os.path.join(root, manifest_file))
+
+        wsi.MultiWsiDataset.__init__(
+            self,
+            root=root,
+            file_paths=self._manifest[self._column_mapping["path"]].tolist(),
+            width=width,
+            height=height,
+            sampler=sampler,
+            target_mpp=target_mpp,
+            backend=backend,
+            image_transforms=image_transforms,
+        )
+
+    @override
+    def filename(self, index: int) -> str:
+        path = self._manifest.at[self._get_dataset_idx(index), self._column_mapping["path"]]
+        return os.path.basename(path) if os.path.isabs(path) else path
+
+    @override
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
+        return base.ImageClassification.__getitem__(self, index)
+
+    @override
+    def load_image(self, index: int) -> tv_tensors.Image:
+        return wsi.MultiWsiDataset.__getitem__(self, index)
+
+    @override
+    def load_target(self, index: int) -> np.ndarray:
+        target = self._manifest.at[self._get_dataset_idx(index), self._column_mapping["target"]]
+        return np.asarray(target)
+
+    @override
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        return {"wsi_id": self.filename(index).split(".")[0]}
+
+    def _load_manifest(self, manifest_path: str) -> pd.DataFrame:
+        df = pd.read_csv(manifest_path)
+
+        missing_columns = set(self._column_mapping.values()) - set(df.columns)
+        if self._split is None:
+            missing_columns = missing_columns - {self._column_mapping["split"]}
+        if missing_columns:
+            raise ValueError(f"Missing columns in the manifest file: {missing_columns}")
+
+        if self._split is not None:
+            df = df.loc[df[self._column_mapping["split"]] == self._split]
+
+        return df.reset_index(drop=True)
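
Since WsiClassificationDataset is manifest-driven, a small example helps. The CSV layout below follows the default_column_mapping; the file paths and the GridSampler arguments are assumptions.

# manifest.csv (illustrative), resolved relative to `root`:
#
#     path,target,split
#     slides/case_0001.tiff,0,train
#     slides/case_0002.tiff,1,val
#
from eva.vision.data.datasets.classification import WsiClassificationDataset
from eva.vision.data.wsi.patching import samplers

dataset = WsiClassificationDataset(
    root="./data/my_wsi_task",  # assumed
    manifest_file="manifest.csv",
    width=224,
    height=224,
    target_mpp=0.5,
    sampler=samplers.GridSampler(),  # hypothetical constructor arguments
    split="train",
    # column_mapping={"target": "label"} would remap a custom CSV header.
)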

eva/vision/data/datasets/segmentation/__init__.py
@@ -1,6 +1,19 @@
 """Segmentation datasets API."""
 
 from eva.vision.data.datasets.segmentation.base import ImageSegmentation
-from eva.vision.data.datasets.segmentation.total_segmentator import TotalSegmentator2D
+from eva.vision.data.datasets.segmentation.bcss import BCSS
+from eva.vision.data.datasets.segmentation.consep import CoNSeP
+from eva.vision.data.datasets.segmentation.embeddings import EmbeddingsSegmentationDataset
+from eva.vision.data.datasets.segmentation.lits import LiTS
+from eva.vision.data.datasets.segmentation.monusac import MoNuSAC
+from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D
 
-__all__ = ["ImageSegmentation", "TotalSegmentator2D"]
+__all__ = [
+    "ImageSegmentation",
+    "BCSS",
+    "CoNSeP",
+    "EmbeddingsSegmentationDataset",
+    "LiTS",
+    "MoNuSAC",
+    "TotalSegmentator2D",
+]

eva/vision/data/datasets/segmentation/_utils.py
@@ -0,0 +1,38 @@
+from typing import Any, Tuple
+
+import numpy.typing as npt
+
+from eva.vision.data.datasets import wsi
+
+
+def get_coords_at_index(
+    dataset: wsi.MultiWsiDataset, index: int
+) -> Tuple[Tuple[int, int], int, int]:
+    """Returns the coordinates ((x,y),width,height) of the patch at the given index.
+
+    Args:
+        dataset: The WSI dataset instance.
+        index: The sample index.
+    """
+    image_index = dataset._get_dataset_idx(index)
+    patch_index = index if image_index == 0 else index - dataset.cumulative_sizes[image_index - 1]
+    wsi_dataset = dataset.datasets[image_index]
+    if isinstance(wsi_dataset, wsi.WsiDataset):
+        coords = wsi_dataset._coords
+        return coords.x_y[patch_index], coords.width, coords.height
+    else:
+        raise Exception(f"Expected WsiDataset, got {type(wsi_dataset)}")
+
+
+def extract_mask_patch(
+    mask: npt.NDArray[Any], dataset: wsi.MultiWsiDataset, index: int
+) -> npt.NDArray[Any]:
+    """Reads the mask patch at the coordinates corresponding to the dataset index.
+
+    Args:
+        mask: The mask array.
+        dataset: The WSI dataset instance.
+        index: The sample index.
+    """
+    (x, y), width, height = get_coords_at_index(dataset, index)
+    return mask[y : y + height, x : x + width]
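
A short sketch of how these helpers compose; the dataset and the slide-level mask are placeholders, not values the package provides.

# Illustrative only: crop a slide-level annotation mask to the patch that a
# MultiWsiDataset yields at a given index. `dataset` stands for any
# MultiWsiDataset instance built elsewhere.
import numpy as np

from eva.vision.data.datasets.segmentation import _utils

mask = np.zeros((10_000, 8_000), dtype=np.uint8)  # hypothetical slide-level mask
(x, y), width, height = _utils.get_coords_at_index(dataset, index=0)
patch_mask = _utils.extract_mask_patch(mask, dataset, index=0)
assert patch_mask.shape == (height, width)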

eva/vision/data/datasets/segmentation/base.py
@@ -3,38 +3,25 @@
 import abc
 from typing import Any, Callable, Dict, List, Tuple
 
-import numpy as np
+from torchvision import tv_tensors
 from typing_extensions import override
 
 from eva.vision.data.datasets import vision
 
 
-class ImageSegmentation(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
+class ImageSegmentation(vision.VisionDataset[Tuple[tv_tensors.Image, tv_tensors.Mask]], abc.ABC):
     """Image segmentation abstract dataset."""
 
-    def __init__(
-        self,
-        image_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
-        image_target_transforms: Callable | None = None,
-    ) -> None:
+    def __init__(self, transforms: Callable | None = None) -> None:
         """Initializes the image segmentation base class.
 
         Args:
-            image_transforms: A function/transform that takes in an image
-                and returns a transformed version.
-            target_transforms: A function/transform that takes in the target
-                and transforms it.
-            image_target_transforms: A function/transforms that takes in an
+            transforms: A function/transforms that takes in an
                 image and a label and returns the transformed versions of both.
-                This transform happens after the `image_transforms` and
-                `target_transforms`.
         """
         super().__init__()
 
-        self._image_transforms = image_transforms
-        self._target_transforms = target_transforms
-        self._image_target_transforms = image_target_transforms
+        self._transforms = transforms
 
     @property
     def classes(self) -> List[str] | None:
@@ -44,37 +31,38 @@ class ImageSegmentation(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc
     def class_to_idx(self) -> Dict[str, int] | None:
         """Returns a mapping of the class name to its target index."""
 
-    def load_metadata(self, index: int | None) -> Dict[str, Any] | List[Dict[str, Any]] | None:
-        """Returns the dataset metadata.
+    @abc.abstractmethod
+    def load_image(self, index: int) -> tv_tensors.Image:
+        """Loads and returns the `index`'th image sample.
 
         Args:
-            index: The index of the data sample to return the metadata of.
-                If `None`, it will return the metadata of the current dataset.
+            index: The index of the data sample to load.
 
         Returns:
-            The sample metadata.
+            An image torchvision tensor (channels, height, width).
         """
 
     @abc.abstractmethod
-    def load_image(self, index: int) -> np.ndarray:
-        """Loads and returns the `index`'th image sample.
+    def load_mask(self, index: int) -> tv_tensors.Mask:
+        """Returns the `index`'th target masks sample.
 
         Args:
-            index: The index of the data sample to load.
+            index: The index of the data sample target masks to load.
 
         Returns:
-            The image as a numpy array.
+            The semantic mask as a (H x W) shaped tensor with integer
+                values which represent the pixel class id.
         """
 
-    @abc.abstractmethod
-    def load_mask(self, index: int) -> np.ndarray:
-        """Returns the `index`'th target mask sample.
+    def load_metadata(self, index: int) -> Dict[str, Any] | None:
+        """Returns the dataset metadata.
 
         Args:
-            index: The index of the data sample target mask to load.
+            index: The index of the data sample to return the metadata of.
+                If `None`, it will return the metadata of the current dataset.
 
         Returns:
-            The sample mask as a stack of binary mask arrays (label, height, width).
+            The sample metadata.
         """
 
     @abc.abstractmethod
@@ -83,30 +71,26 @@ class ImageSegmentation(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc
         raise NotImplementedError
 
     @override
-    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
         image = self.load_image(index)
         mask = self.load_mask(index)
-        return self._apply_transforms(image, mask)
+        metadata = self.load_metadata(index) or {}
+        image_tensor, mask_tensor = self._apply_transforms(image, mask)
+        return image_tensor, mask_tensor, metadata
 
     def _apply_transforms(
-        self, image: np.ndarray, target: np.ndarray
-    ) -> Tuple[np.ndarray, np.ndarray]:
+        self, image: tv_tensors.Image, mask: tv_tensors.Mask
+    ) -> Tuple[tv_tensors.Image, tv_tensors.Mask]:
         """Applies the transforms to the provided data and returns them.
 
         Args:
            image: The desired image.
-            target: The target of the image.
+            mask: The target segmentation mask.
 
         Returns:
-            A tuple with the image and the target transformed.
+            A tuple with the image and the masks transformed.
         """
-        if self._image_transforms is not None:
-            image = self._image_transforms(image)
-
-        if self._target_transforms is not None:
-            target = self._target_transforms(target)
-
-        if self._image_target_transforms is not None:
-            image, target = self._image_target_transforms(image, target)
+        if self._transforms is not None:
+            image, mask = self._transforms(image, mask)
 
-        return image, target
+        return image, mask
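
To illustrate the reworked ImageSegmentation contract, here is a toy subclass under the new single-transforms API. It is a sketch, not package code, and assumes vision.VisionDataset imposes no abstract members beyond those visible in this diff.

# Toy subclass sketch: random images with all-background masks.
from typing import List

import torch
from torchvision import tv_tensors
from typing_extensions import override

from eva.vision.data.datasets.segmentation import ImageSegmentation


class ToySegmentation(ImageSegmentation):
    @property
    @override
    def classes(self) -> List[str]:
        return ["background", "foreground"]

    @override
    def load_image(self, index: int) -> tv_tensors.Image:
        return tv_tensors.Image(torch.rand(3, 64, 64))

    @override
    def load_mask(self, index: int) -> tv_tensors.Mask:
        return tv_tensors.Mask(torch.zeros(64, 64, dtype=torch.long))

    def __len__(self) -> int:
        return 10


image, mask, metadata = ToySegmentation()[0]  # metadata defaults to {}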