kaiko-eva 0.1.8__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. eva/core/data/datasets/base.py +7 -2
  2. eva/core/data/datasets/classification/embeddings.py +2 -2
  3. eva/core/data/datasets/classification/multi_embeddings.py +2 -2
  4. eva/core/data/datasets/embeddings.py +4 -4
  5. eva/core/data/samplers/classification/balanced.py +19 -18
  6. eva/core/loggers/utils/wandb.py +33 -0
  7. eva/core/models/modules/head.py +5 -3
  8. eva/core/models/modules/typings.py +2 -2
  9. eva/core/models/transforms/__init__.py +2 -1
  10. eva/core/models/transforms/as_discrete.py +57 -0
  11. eva/core/models/wrappers/_utils.py +121 -1
  12. eva/core/trainers/functional.py +8 -5
  13. eva/core/trainers/trainer.py +32 -17
  14. eva/core/utils/suppress_logs.py +28 -0
  15. eva/vision/data/__init__.py +2 -2
  16. eva/vision/data/dataloaders/__init__.py +5 -0
  17. eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
  18. eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
  19. eva/vision/data/datasets/__init__.py +10 -2
  20. eva/vision/data/datasets/classification/__init__.py +9 -0
  21. eva/vision/data/datasets/classification/bach.py +3 -4
  22. eva/vision/data/datasets/classification/bracs.py +111 -0
  23. eva/vision/data/datasets/classification/breakhis.py +209 -0
  24. eva/vision/data/datasets/classification/camelyon16.py +4 -5
  25. eva/vision/data/datasets/classification/crc.py +3 -4
  26. eva/vision/data/datasets/classification/gleason_arvaniti.py +171 -0
  27. eva/vision/data/datasets/classification/mhist.py +3 -4
  28. eva/vision/data/datasets/classification/panda.py +4 -5
  29. eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
  30. eva/vision/data/datasets/classification/unitopatho.py +158 -0
  31. eva/vision/data/datasets/classification/wsi.py +6 -5
  32. eva/vision/data/datasets/segmentation/__init__.py +2 -2
  33. eva/vision/data/datasets/segmentation/_utils.py +47 -0
  34. eva/vision/data/datasets/segmentation/bcss.py +7 -8
  35. eva/vision/data/datasets/segmentation/btcv.py +236 -0
  36. eva/vision/data/datasets/segmentation/consep.py +6 -7
  37. eva/vision/data/datasets/segmentation/embeddings.py +2 -2
  38. eva/vision/data/datasets/segmentation/lits.py +9 -8
  39. eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
  40. eva/vision/data/datasets/segmentation/monusac.py +4 -5
  41. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
  42. eva/vision/data/datasets/vision.py +95 -4
  43. eva/vision/data/datasets/wsi.py +5 -5
  44. eva/vision/data/transforms/__init__.py +22 -3
  45. eva/vision/data/transforms/common/__init__.py +1 -2
  46. eva/vision/data/transforms/croppad/__init__.py +11 -0
  47. eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
  48. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
  49. eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
  50. eva/vision/data/transforms/intensity/__init__.py +11 -0
  51. eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
  52. eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
  53. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
  54. eva/vision/data/transforms/spatial/__init__.py +7 -0
  55. eva/vision/data/transforms/spatial/flip.py +72 -0
  56. eva/vision/data/transforms/spatial/rotate.py +53 -0
  57. eva/vision/data/transforms/spatial/spacing.py +69 -0
  58. eva/vision/data/transforms/utility/__init__.py +5 -0
  59. eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
  60. eva/vision/data/tv_tensors/__init__.py +5 -0
  61. eva/vision/data/tv_tensors/volume.py +61 -0
  62. eva/vision/metrics/segmentation/monai_dice.py +9 -2
  63. eva/vision/models/modules/semantic_segmentation.py +28 -20
  64. eva/vision/models/networks/backbones/__init__.py +9 -2
  65. eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
  66. eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
  67. eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
  68. eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
  69. eva/vision/models/networks/backbones/pathology/mahmood.py +46 -19
  70. eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
  71. eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
  72. eva/vision/models/networks/backbones/radiology/voco.py +75 -0
  73. eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
  74. eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
  75. eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
  76. eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
  77. eva/vision/utils/io/__init__.py +2 -0
  78. eva/vision/utils/io/nifti.py +91 -11
  79. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
  80. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +83 -62
  81. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
  82. eva/vision/data/datasets/classification/base.py +0 -96
  83. eva/vision/data/datasets/segmentation/base.py +0 -96
  84. eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
  85. eva/vision/data/transforms/normalization/__init__.py +0 -6
  86. eva/vision/data/transforms/normalization/clamp.py +0 -43
  87. eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
  88. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
  89. eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
  90. eva/vision/metrics/segmentation/BUILD +0 -1
  91. eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
  92. eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
  93. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
  94. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,18 +1,27 @@
1
1
  """Image classification datasets API."""
2
2
 
3
3
  from eva.vision.data.datasets.classification.bach import BACH
4
+ from eva.vision.data.datasets.classification.bracs import BRACS
5
+ from eva.vision.data.datasets.classification.breakhis import BreaKHis
4
6
  from eva.vision.data.datasets.classification.camelyon16 import Camelyon16
5
7
  from eva.vision.data.datasets.classification.crc import CRC
8
+ from eva.vision.data.datasets.classification.gleason_arvaniti import GleasonArvaniti
6
9
  from eva.vision.data.datasets.classification.mhist import MHIST
7
10
  from eva.vision.data.datasets.classification.panda import PANDA, PANDASmall
8
11
  from eva.vision.data.datasets.classification.patch_camelyon import PatchCamelyon
12
+ from eva.vision.data.datasets.classification.unitopatho import UniToPatho
9
13
  from eva.vision.data.datasets.classification.wsi import WsiClassificationDataset
10
14
 
11
15
  __all__ = [
12
16
  "BACH",
17
+ "BreaKHis",
18
+ "BRACS",
19
+ "Camelyon16",
13
20
  "CRC",
21
+ "GleasonArvaniti",
14
22
  "MHIST",
15
23
  "PatchCamelyon",
24
+ "UniToPatho",
16
25
  "WsiClassificationDataset",
17
26
  "PANDA",
18
27
  "PANDASmall",
@@ -8,12 +8,11 @@ from torchvision import tv_tensors
8
8
  from torchvision.datasets import folder, utils
9
9
  from typing_extensions import override
10
10
 
11
- from eva.vision.data.datasets import _utils, _validators, structs
12
- from eva.vision.data.datasets.classification import base
11
+ from eva.vision.data.datasets import _utils, _validators, structs, vision
13
12
  from eva.vision.utils import io
14
13
 
15
14
 
16
- class BACH(base.ImageClassification):
15
+ class BACH(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
17
16
  """Dataset class for BACH images and corresponding targets."""
18
17
 
19
18
  _train_index_ranges: List[Tuple[int, int]] = [
@@ -125,7 +124,7 @@ class BACH(base.ImageClassification):
125
124
  )
126
125
 
127
126
  @override
128
- def load_image(self, index: int) -> tv_tensors.Image:
127
+ def load_data(self, index: int) -> tv_tensors.Image:
129
128
  image_path, _ = self._samples[self._indices[index]]
130
129
  return io.read_image_as_tensor(image_path)
131
130
 
@@ -0,0 +1,111 @@
1
+ """BRACS dataset class."""
2
+
3
+ import os
4
+ from typing import Callable, Dict, List, Literal, Tuple
5
+
6
+ import torch
7
+ from torchvision import tv_tensors
8
+ from torchvision.datasets import folder
9
+ from typing_extensions import override
10
+
11
+ from eva.vision.data.datasets import _validators, vision
12
+ from eva.vision.utils import io
13
+
14
+
15
+ class BRACS(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
16
+ """Dataset class for BRACS images and corresponding targets."""
17
+
18
+ _expected_dataset_lengths: Dict[str, int] = {
19
+ "train": 3657,
20
+ "val": 312,
21
+ "test": 570,
22
+ }
23
+ """Expected dataset lengths for the splits and complete dataset."""
24
+
25
+ _license: str = "CC BY-NC 4.0 (https://creativecommons.org/licenses/by-nc/4.0/)"
26
+ """Dataset license."""
27
+
28
+ def __init__(
29
+ self,
30
+ root: str,
31
+ split: Literal["train", "val", "test"],
32
+ transforms: Callable | None = None,
33
+ ) -> None:
34
+ """Initializes the dataset.
35
+
36
+ Args:
37
+ root: Path to the root directory of the dataset.
38
+ split: Dataset split to use.
39
+ transforms: A function/transform which returns a transformed
40
+ version of the raw data samples.
41
+ """
42
+ super().__init__(transforms=transforms)
43
+
44
+ self._root = root
45
+ self._split = split
46
+
47
+ self._samples: List[Tuple[str, int]] = []
48
+
49
+ @property
50
+ @override
51
+ def classes(self) -> List[str]:
52
+ return ["0_N", "1_PB", "2_UDH", "3_FEA", "4_ADH", "5_DCIS", "6_IC"]
53
+
54
+ @property
55
+ @override
56
+ def class_to_idx(self) -> Dict[str, int]:
57
+ return {name: index for index, name in enumerate(self.classes)}
58
+
59
+ @override
60
+ def filename(self, index: int) -> str:
61
+ image_path, *_ = self._samples[index]
62
+ return os.path.relpath(image_path, self._dataset_path)
63
+
64
+ @override
65
+ def prepare_data(self) -> None:
66
+ _validators.check_dataset_exists(self._root, True)
67
+
68
+ @override
69
+ def configure(self) -> None:
70
+ self._samples = self._make_dataset()
71
+
72
+ @override
73
+ def validate(self) -> None:
74
+ _validators.check_dataset_integrity(
75
+ self,
76
+ length=self._expected_dataset_lengths[self._split],
77
+ n_classes=7,
78
+ first_and_last_labels=("0_N", "6_IC"),
79
+ )
80
+
81
+ @override
82
+ def load_data(self, index: int) -> tv_tensors.Image:
83
+ image_path, _ = self._samples[index]
84
+ return io.read_image_as_tensor(image_path)
85
+
86
+ @override
87
+ def load_target(self, index: int) -> torch.Tensor:
88
+ _, target = self._samples[index]
89
+ return torch.tensor(target, dtype=torch.long)
90
+
91
+ @override
92
+ def __len__(self) -> int:
93
+ return len(self._samples)
94
+
95
+ @property
96
+ def _dataset_path(self) -> str:
97
+ """Returns the full path of dataset directory."""
98
+ return os.path.join(self._root, "BRACS_RoI/latest_version")
99
+
100
+ def _make_dataset(self) -> List[Tuple[str, int]]:
101
+ """Builds the dataset for the specified split."""
102
+ dataset = folder.make_dataset(
103
+ directory=os.path.join(self._dataset_path, self._split),
104
+ class_to_idx=self.class_to_idx,
105
+ extensions=(".png"),
106
+ )
107
+ return dataset
108
+
109
+ def _print_license(self) -> None:
110
+ """Prints the dataset license."""
111
+ print(f"Dataset license: {self._license}")
@@ -0,0 +1,209 @@
1
+ """BreaKHis dataset class."""
2
+
3
+ import functools
4
+ import glob
5
+ import os
6
+ from typing import Any, Callable, Dict, List, Literal, Set
7
+
8
+ import torch
9
+ from torchvision import tv_tensors
10
+ from torchvision.datasets import utils
11
+ from typing_extensions import override
12
+
13
+ from eva.vision.data.datasets import _validators, structs, vision
14
+ from eva.vision.utils import io
15
+
16
+
17
+ class BreaKHis(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
18
+ """Dataset class for BreaKHis images and corresponding targets."""
19
+
20
+ _resources: List[structs.DownloadResource] = [
21
+ structs.DownloadResource(
22
+ filename="BreaKHis_v1.tar.gz",
23
+ url="http://www.inf.ufpr.br/vri/databases/BreaKHis_v1.tar.gz",
24
+ ),
25
+ ]
26
+ """Dataset resources."""
27
+
28
+ _val_patient_ids: Set[str] = {
29
+ "18842D",
30
+ "19979",
31
+ "15275",
32
+ "15792",
33
+ "16875",
34
+ "3909",
35
+ "5287",
36
+ "16716",
37
+ "2773",
38
+ "5695",
39
+ "16184CD",
40
+ "23060CD",
41
+ "21998CD",
42
+ "21998EF",
43
+ }
44
+ """Patient IDs to use for dataset splits."""
45
+
46
+ _expected_dataset_lengths: Dict[str | None, int] = {
47
+ "train": 1132,
48
+ "val": 339,
49
+ None: 1471,
50
+ }
51
+ """Expected dataset lengths for the splits and complete dataset."""
52
+
53
+ _default_magnifications = ["40X"]
54
+ """Default magnification to use for images in train/val datasets."""
55
+
56
+ _license: str = "CC BY 4.0 (https://creativecommons.org/licenses/by/4.0/)"
57
+ """Dataset license."""
58
+
59
+ def __init__(
60
+ self,
61
+ root: str,
62
+ split: Literal["train", "val"] | None = None,
63
+ magnifications: List[Literal["40X", "100X", "200X", "400X"]] | None = None,
64
+ download: bool = False,
65
+ transforms: Callable | None = None,
66
+ ) -> None:
67
+ """Initialize the dataset.
68
+
69
+ The dataset is split into train and validation by taking into account
70
+ the patient IDs to avoid any data leakage.
71
+
72
+ Args:
73
+ root: Path to the root directory of the dataset. The dataset will
74
+ be downloaded and extracted here, if it does not already exist.
75
+ split: Dataset split to use. If `None`, the entire dataset is used.
76
+ magnifications: A list of the WSI magnifications to select. By default
77
+ only 40X images are used.
78
+ download: Whether to download the data for the specified split.
79
+ Note that the download will be executed only by additionally
80
+ calling the :meth:`prepare_data` method and if the data does
81
+ not yet exist on disk.
82
+ transforms: A function/transform which returns a transformed
83
+ version of the raw data samples.
84
+ """
85
+ super().__init__(transforms=transforms)
86
+
87
+ self._root = root
88
+ self._split = split
89
+ self._download = download
90
+
91
+ self._magnifications = magnifications or self._default_magnifications
92
+ self._indices: List[int] = []
93
+
94
+ @property
95
+ @override
96
+ def classes(self) -> List[str]:
97
+ return ["TA", "MC", "F", "DC"]
98
+
99
+ @property
100
+ @override
101
+ def class_to_idx(self) -> Dict[str, int]:
102
+ return {label: index for index, label in enumerate(self.classes)}
103
+
104
+ @property
105
+ def _dataset_path(self) -> str:
106
+ """Returns the path of the image data of the dataset."""
107
+ return os.path.join(self._root, "BreaKHis_v1", "histology_slides")
108
+
109
+ @functools.cached_property
110
+ def _image_files(self) -> List[str]:
111
+ """Return the list of image files in the dataset.
112
+
113
+ Returns:
114
+ List of image file paths.
115
+ """
116
+ image_files = []
117
+ for magnification in self._magnifications:
118
+ files_pattern = os.path.join(self._dataset_path, f"**/{magnification}", "*.png")
119
+ image_files.extend(list(glob.glob(files_pattern, recursive=True)))
120
+ return sorted(image_files)
121
+
122
+ @override
123
+ def filename(self, index: int) -> str:
124
+ image_path = self._image_files[self._indices[index]]
125
+ return os.path.relpath(image_path, self._dataset_path)
126
+
127
+ @override
128
+ def prepare_data(self) -> None:
129
+ if self._download:
130
+ self._download_dataset()
131
+ _validators.check_dataset_exists(self._root, True)
132
+
133
+ @override
134
+ def configure(self) -> None:
135
+ self._indices = self._make_indices()
136
+
137
+ @override
138
+ def validate(self) -> None:
139
+ _validators.check_dataset_integrity(
140
+ self,
141
+ length=self._expected_dataset_lengths[self._split],
142
+ n_classes=4,
143
+ first_and_last_labels=("TA", "DC"),
144
+ )
145
+
146
+ @override
147
+ def load_data(self, index: int) -> tv_tensors.Image:
148
+ image_path = self._image_files[self._indices[index]]
149
+ return io.read_image_as_tensor(image_path)
150
+
151
+ @override
152
+ def load_target(self, index: int) -> torch.Tensor:
153
+ class_name = self._extract_class(self._image_files[self._indices[index]])
154
+ return torch.tensor(self.class_to_idx[class_name], dtype=torch.long)
155
+
156
+ @override
157
+ def load_metadata(self, index: int) -> Dict[str, Any]:
158
+ return {"patient_id": self._extract_patient_id(self._image_files[self._indices[index]])}
159
+
160
+ @override
161
+ def __len__(self) -> int:
162
+ return len(self._indices)
163
+
164
+ def _download_dataset(self) -> None:
165
+ """Downloads the dataset."""
166
+ for resource in self._resources:
167
+ if os.path.isdir(self._dataset_path):
168
+ continue
169
+
170
+ self._print_license()
171
+ utils.download_and_extract_archive(
172
+ resource.url,
173
+ download_root=self._root,
174
+ filename=resource.filename,
175
+ remove_finished=True,
176
+ )
177
+
178
+ def _print_license(self) -> None:
179
+ """Prints the dataset license."""
180
+ print(f"Dataset license: {self._license}")
181
+
182
+ def _extract_patient_id(self, image_file: str) -> str:
183
+ """Extracts the patient ID from the image file name."""
184
+ return os.path.basename(image_file).split("-")[2]
185
+
186
+ def _extract_class(self, file: str) -> str:
187
+ return os.path.basename(file).split("-")[0].split("_")[-1]
188
+
189
+ def _make_indices(self) -> List[int]:
190
+ """Builds the dataset indices for the specified split."""
191
+ train_indices = []
192
+ val_indices = []
193
+
194
+ for index, image_file in enumerate(self._image_files):
195
+ if self._extract_class(image_file) not in self.classes:
196
+ continue
197
+ patient_id = self._extract_patient_id(image_file)
198
+ if patient_id in self._val_patient_ids:
199
+ val_indices.append(index)
200
+ else:
201
+ train_indices.append(index)
202
+
203
+ split_indices = {
204
+ "train": train_indices,
205
+ "val": val_indices,
206
+ None: train_indices + val_indices,
207
+ }
208
+
209
+ return split_indices[self._split]
@@ -11,12 +11,11 @@ from torchvision import tv_tensors
11
11
  from torchvision.transforms.v2 import functional
12
12
  from typing_extensions import override
13
13
 
14
- from eva.vision.data.datasets import _validators, wsi
15
- from eva.vision.data.datasets.classification import base
14
+ from eva.vision.data.datasets import _validators, vision, wsi
16
15
  from eva.vision.data.wsi.patching import samplers
17
16
 
18
17
 
19
- class Camelyon16(wsi.MultiWsiDataset, base.ImageClassification):
18
+ class Camelyon16(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
20
19
  """Dataset class for Camelyon16 images and corresponding targets."""
21
20
 
22
21
  _val_slides = [
@@ -195,10 +194,10 @@ class Camelyon16(wsi.MultiWsiDataset, base.ImageClassification):
195
194
 
196
195
  @override
197
196
  def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
198
- return base.ImageClassification.__getitem__(self, index)
197
+ return vision.VisionDataset.__getitem__(self, index)
199
198
 
200
199
  @override
201
- def load_image(self, index: int) -> tv_tensors.Image:
200
+ def load_data(self, index: int) -> tv_tensors.Image:
202
201
  image_array = wsi.MultiWsiDataset.__getitem__(self, index)
203
202
  return functional.to_image(image_array)
204
203
 
@@ -8,12 +8,11 @@ from torchvision import tv_tensors
8
8
  from torchvision.datasets import folder, utils
9
9
  from typing_extensions import override
10
10
 
11
- from eva.vision.data.datasets import _validators, structs
12
- from eva.vision.data.datasets.classification import base
11
+ from eva.vision.data.datasets import _validators, structs, vision
13
12
  from eva.vision.utils import io
14
13
 
15
14
 
16
- class CRC(base.ImageClassification):
15
+ class CRC(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
17
16
  """Dataset class for CRC images and corresponding targets."""
18
17
 
19
18
  _train_resource: structs.DownloadResource = structs.DownloadResource(
@@ -117,7 +116,7 @@ class CRC(base.ImageClassification):
117
116
  )
118
117
 
119
118
  @override
120
- def load_image(self, index: int) -> tv_tensors.Image:
119
+ def load_data(self, index: int) -> tv_tensors.Image:
121
120
  image_path, _ = self._samples[index]
122
121
  return io.read_image_as_tensor(image_path)
123
122
 
@@ -0,0 +1,171 @@
1
+ """GleasonArvaniti dataset class."""
2
+
3
+ import functools
4
+ import glob
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Callable, Dict, List, Literal
8
+
9
+ import pandas as pd
10
+ import torch
11
+ from loguru import logger
12
+ from torchvision import tv_tensors
13
+ from typing_extensions import override
14
+
15
+ from eva.vision.data.datasets import _validators, vision
16
+ from eva.vision.utils import io
17
+
18
+
19
+ class GleasonArvaniti(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
20
+ """Dataset class for GleasonArvaniti images and corresponding targets."""
21
+
22
+ _expected_dataset_lengths: Dict[str | None, int] = {
23
+ "train": 15303,
24
+ "val": 2482,
25
+ "test": 4967,
26
+ None: 22752,
27
+ }
28
+ """Expected dataset lengths for the splits and complete dataset."""
29
+
30
+ _license: str = "CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)"
31
+ """Dataset license."""
32
+
33
+ def __init__(
34
+ self,
35
+ root: str,
36
+ split: Literal["train", "val", "test"] | None = None,
37
+ transforms: Callable | None = None,
38
+ ) -> None:
39
+ """Initialize the dataset.
40
+
41
+ Args:
42
+ root: Path to the root directory of the dataset.
43
+ split: Dataset split to use. If `None`, the entire dataset is used.
44
+ transforms: A function/transform which returns a transformed
45
+ version of the raw data samples.
46
+ """
47
+ super().__init__(transforms=transforms)
48
+
49
+ self._root = root
50
+ self._split = split
51
+
52
+ self._indices: List[int] = []
53
+
54
+ @property
55
+ @override
56
+ def classes(self) -> List[str]:
57
+ return ["benign", "gleason_3", "gleason_4", "gleason_5"]
58
+
59
+ @property
60
+ @override
61
+ def class_to_idx(self) -> Dict[str, int]:
62
+ return {name: index for index, name in enumerate(self.classes)}
63
+
64
+ @functools.cached_property
65
+ def _image_files(self) -> List[str]:
66
+ """Return the list of image files in the dataset.
67
+
68
+ Returns:
69
+ List of image file paths.
70
+ """
71
+ subdirs = ["train_validation_patches_750", "test_patches_750/patho_1"]
72
+
73
+ image_files = []
74
+ for subdir in subdirs:
75
+ files_pattern = os.path.join(self._root, subdir, "**/*.jpg")
76
+ image_files += list(glob.glob(files_pattern, recursive=True))
77
+
78
+ return sorted(image_files)
79
+
80
+ @functools.cached_property
81
+ def _manifest(self) -> pd.DataFrame:
82
+ """Returns the train.csv & test.csv files as dataframe."""
83
+ df_train = pd.read_csv(os.path.join(self._root, "train.csv"))
84
+ df_val = pd.read_csv(os.path.join(self._root, "test.csv"))
85
+ df_train["split"], df_val["split"] = "train", "val"
86
+ return pd.concat([df_train, df_val], axis=0).set_index("image_id")
87
+
88
+ @override
89
+ def filename(self, index: int) -> str:
90
+ image_path = self._image_files[self._indices[index]]
91
+ return os.path.relpath(image_path, self._root)
92
+
93
+ @override
94
+ def prepare_data(self) -> None:
95
+ _validators.check_dataset_exists(self._root, download_available=False)
96
+ if not os.path.isdir(os.path.join(self._root, "train_validation_patches_750")):
97
+ raise FileNotFoundError(
98
+ f"`train_validation_patches_750` directory not found in {self._root}"
99
+ )
100
+ if not os.path.isdir(os.path.join(self._root, "test_patches_750")):
101
+ raise FileNotFoundError(f"`test_patches_750` directory not found in {self._root}")
102
+
103
+ if self._split == "test":
104
+ logger.warning(
105
+ "The test split currently leads to unstable evaluation results. "
106
+ "We recommend using the validation split instead."
107
+ )
108
+
109
+ @override
110
+ def configure(self) -> None:
111
+ self._indices = self._make_indices()
112
+
113
+ @override
114
+ def validate(self) -> None:
115
+ _validators.check_dataset_integrity(
116
+ self,
117
+ length=self._expected_dataset_lengths[self._split],
118
+ n_classes=4,
119
+ first_and_last_labels=("benign", "gleason_5"),
120
+ )
121
+
122
+ @override
123
+ def load_data(self, index: int) -> tv_tensors.Image:
124
+ image_path = self._image_files[self._indices[index]]
125
+ return io.read_image_as_tensor(image_path)
126
+
127
+ @override
128
+ def load_target(self, index: int) -> torch.Tensor:
129
+ target = self._extract_class(self._image_files[self._indices[index]])
130
+ return torch.tensor(target, dtype=torch.long)
131
+
132
+ @override
133
+ def __len__(self) -> int:
134
+ return len(self._indices)
135
+
136
+ def _print_license(self) -> None:
137
+ """Prints the dataset license."""
138
+ print(f"Dataset license: {self._license}")
139
+
140
+ def _extract_micro_array_id(self, file: str) -> str:
141
+ """Extracts the ID of the tissue micro array from the file name."""
142
+ return Path(file).stem.split("_")[0]
143
+
144
+ def _extract_class(self, file: str) -> int:
145
+ """Extracts the class label from the file name."""
146
+ return int(Path(file).stem.split("_")[-1])
147
+
148
+ def _make_indices(self) -> List[int]:
149
+ """Builds the dataset indices for the specified split."""
150
+ train_indices, val_indices, test_indices = [], [], []
151
+
152
+ for index, image_file in enumerate(self._image_files):
153
+ array_id = self._extract_micro_array_id(image_file)
154
+
155
+ if array_id == "ZT76":
156
+ val_indices.append(index)
157
+ elif array_id in {"ZT111", "ZT199", "ZT204"}:
158
+ train_indices.append(index)
159
+ elif "test_patches_750" in image_file:
160
+ test_indices.append(index)
161
+ else:
162
+ raise ValueError(f"Invalid microarray value found for file {image_file}")
163
+
164
+ split_indices = {
165
+ "train": train_indices,
166
+ "val": val_indices,
167
+ "test": test_indices,
168
+ None: train_indices + val_indices + test_indices,
169
+ }
170
+
171
+ return split_indices[self._split]
@@ -7,12 +7,11 @@ import torch
7
7
  from torchvision import tv_tensors
8
8
  from typing_extensions import override
9
9
 
10
- from eva.vision.data.datasets import _validators
11
- from eva.vision.data.datasets.classification import base
10
+ from eva.vision.data.datasets import _validators, vision
12
11
  from eva.vision.utils import io
13
12
 
14
13
 
15
- class MHIST(base.ImageClassification):
14
+ class MHIST(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
16
15
  """MHIST dataset."""
17
16
 
18
17
  def __init__(
@@ -69,7 +68,7 @@ class MHIST(base.ImageClassification):
69
68
  )
70
69
 
71
70
  @override
72
- def load_image(self, index: int) -> tv_tensors.Image:
71
+ def load_data(self, index: int) -> tv_tensors.Image:
73
72
  image_filename, _ = self._samples[index]
74
73
  image_path = os.path.join(self._dataset_path, image_filename)
75
74
  return io.read_image_as_tensor(image_path)
@@ -13,12 +13,11 @@ from torchvision.transforms.v2 import functional
13
13
  from typing_extensions import override
14
14
 
15
15
  from eva.core.data import splitting
16
- from eva.vision.data.datasets import _validators, structs, wsi
17
- from eva.vision.data.datasets.classification import base
16
+ from eva.vision.data.datasets import _validators, structs, vision, wsi
18
17
  from eva.vision.data.wsi.patching import samplers
19
18
 
20
19
 
21
- class PANDA(wsi.MultiWsiDataset, base.ImageClassification):
20
+ class PANDA(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
22
21
  """Dataset class for PANDA images and corresponding targets."""
23
22
 
24
23
  _train_split_ratio: float = 0.7
@@ -121,10 +120,10 @@ class PANDA(wsi.MultiWsiDataset, base.ImageClassification):
121
120
 
122
121
  @override
123
122
  def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
124
- return base.ImageClassification.__getitem__(self, index)
123
+ return vision.VisionDataset.__getitem__(self, index)
125
124
 
126
125
  @override
127
- def load_image(self, index: int) -> tv_tensors.Image:
126
+ def load_data(self, index: int) -> tv_tensors.Image:
128
127
  image_array = wsi.MultiWsiDataset.__getitem__(self, index)
129
128
  return functional.to_image(image_array)
130
129
 
@@ -10,14 +10,13 @@ from torchvision.datasets import utils
10
10
  from torchvision.transforms.v2 import functional
11
11
  from typing_extensions import override
12
12
 
13
- from eva.vision.data.datasets import _validators, structs
14
- from eva.vision.data.datasets.classification import base
13
+ from eva.vision.data.datasets import _validators, structs, vision
15
14
 
16
15
  _URL_TEMPLATE = "https://zenodo.org/records/2546921/files/{filename}.gz?download=1"
17
16
  """PatchCamelyon URL files templates."""
18
17
 
19
18
 
20
- class PatchCamelyon(base.ImageClassification):
19
+ class PatchCamelyon(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
21
20
  """Dataset class for PatchCamelyon images and corresponding targets."""
22
21
 
23
22
  _train_resources: List[structs.DownloadResource] = [
@@ -127,7 +126,7 @@ class PatchCamelyon(base.ImageClassification):
127
126
  )
128
127
 
129
128
  @override
130
- def load_image(self, index: int) -> tv_tensors.Image:
129
+ def load_data(self, index: int) -> tv_tensors.Image:
131
130
  return self._load_from_h5("x", index)
132
131
 
133
132
  @override